Skip to content

odak.learn.wave.gradient_descent

Definition to generate phase and reconstruction from target image via gradient descent.

Parameters:

Name Type Description Default
field torch.Tensor

Target field intensity.

required
wavelength double

Wavelength of light.

required
distance double

Distance of the hologram (reconstruction) plane with respect to the SLM plane.

required
dx float

SLM pixel pitch

required
resolution array

SLM resolution

required
propagation_type str

Type of the propagation (IR Fresnel, Angular Spectrum, Bandlimited Angular Spectrum, TR Fresnel, Fraunhofer)

required
n_iteration int

Maximum number of iterations.

100
cuda boolean

GPU enabled

False
alpha float

Step size of the gradient update applied to the hologram at each iteration.

0.1

Returns:

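Type Description
torch.Tensor

Phase-only hologram as torch array.

torch.Tensor

Reconstructed field (not intensity) obtained by propagating the hologram to the target plane.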

Source code in odak/learn/wave/classical.py
def gradient_descent(field, wavelength, distance, dx, resolution, propagation_type, n_iteration=100, cuda=False, alpha=0.1):
    """
    Definition to generate a phase-only hologram and its reconstruction from a target image via gradient descent.

    Starting from a random phase (seeded with ``torch.manual_seed(0)``), each iteration
    propagates the current hologram to the target plane and compares the reconstructed
    intensity against ``field``. It then back-propagates the error field, takes a step of
    size ``alpha`` and re-imposes unit amplitude so that the hologram stays phase-only.

    Parameters
    ----------
    field                   : torch.Tensor
                              Target field intensity.
    wavelength              : float
                              Wavelength of light.
    distance                : float
                              Hologram plane distance wrt SLM plane.
    dx                      : float
                              SLM pixel pitch.
    resolution              : array
                              SLM resolution.
    propagation_type        : str
                              Type of the propagation (IR Fresnel, Angular Spectrum, Bandlimited Angular Spectrum, TR Fresnel, Fraunhofer).
    n_iteration             : int
                              Maximum number of iterations.
    cuda                    : boolean
                              GPU enabled.
    alpha                   : float
                              Step size of the gradient update.

    Returns
    -------
    hologram                : torch.Tensor
                              Optimized phase-only hologram as torch array.
    reconstruction          : torch.Tensor
                              Reconstructed field (not intensity) of the optimized hologram.
    """
    torch.cuda.empty_cache()
    torch.manual_seed(0)
    device = torch.device("cuda" if cuda else "cpu")
    field = field.to(device)
    # The phase must live on the same device as the amplitude before both are combined.
    phase = torch.rand(resolution[0], resolution[1]).to(device)
    amplitude = torch.ones(
        resolution[0], resolution[1], requires_grad=False).to(device)
    k = wavenumber(wavelength)
    loss_function = torch.nn.MSELoss(reduction='none').to(device)

    def propagate(input_field, propagation_distance):
        # Zero-pad, propagate over the given distance, then crop back to the input size.
        padded = zero_pad(input_field)
        propagated = propagate_beam(
            padded, k, propagation_distance, dx, wavelength, propagation_type)
        return crop_center(propagated)

    t = tqdm(range(n_iteration), leave=False)
    # Stays None when n_iteration == 0, so the final print is skipped instead of raising NameError.
    description = None
    # The updates are explicit, so autograd is not needed anywhere in this routine.
    with torch.no_grad():
        hologram = generate_complex_field(amplitude, phase)
        for i in t:
            reconstruction = propagate(hologram, distance)
            reconstruction_intensity = calculate_amplitude(reconstruction) ** 2
            loss = loss_function(reconstruction_intensity, field)
            # The error magnitude takes the reconstruction's phase and is sent back to the SLM plane.
            loss_field = generate_complex_field(loss, calculate_phase(reconstruction))
            loss_propagated = propagate(loss_field, -distance)
            hologram_updated = hologram - alpha * loss_propagated
            # Re-impose unit amplitude so the hologram stays phase-only.
            hologram = generate_complex_field(amplitude, calculate_phase(hologram_updated))
            description = "Iteration: {} loss:{:.4f}".format(i, loss.mean().item())
            t.set_description(description)
        if description is not None:
            print(description)
        # Use the optimized hologram for the final reconstruction, not the initial random `phase`.
        hologram_padded = zero_pad(hologram)
        reconstruction_padded = propagate_beam(
            hologram_padded, k, distance, dx, wavelength, propagation_type)
        reconstruction = crop_center(reconstruction_padded)
        hologram = crop_center(hologram_padded)
    return hologram.detach(), reconstruction.detach()

See also