odak.learn.wave.stochastic_gradient_descent

Definition to generate phase and reconstruction from target image via stochastic gradient descent.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| field | torch.Tensor | Target field amplitude. | required |
| wavelength | double | Wavelength of the light. | required |
| distance | double | Hologram plane distance with respect to the SLM plane. | required |
| dx | float | SLM pixel pitch. | required |
| resolution | array | SLM resolution. | required |
| propogation_type | str | Type of the propagation (IR Fresnel, Angular Spectrum, Bandlimited Angular Spectrum, TR Fresnel, Fraunhofer). | required |
| n_iteration | int | Maximum number of iterations. | 100 |
| loss_function | function | If None, set to L2 (MSE) loss. | None |
| cuda | boolean | GPU enabled. | False |
| learning_rate | float | Learning rate. | 0.1 |

Returns:

| Name | Type | Description |
|---|---|---|
| hologram | torch.Tensor | Phase-only hologram as a torch tensor. |
| reconstruction_intensity | torch.Tensor | Reconstruction intensity as a torch tensor. |

Source code in odak/learn/wave/classical.py
def stochastic_gradient_descent(field, wavelength, distance, dx, resolution, propogation_type, n_iteration=100, loss_function=None, cuda=False, learning_rate=0.1):
    """
    Definition to generate phase and reconstruction from target image via stochastic gradient descent.

    Parameters
    ----------
    field                   : torch.Tensor
                              Target field amplitude.
    wavelength              : double
                              Wavelength of the light.
    distance                : double
                              Hologram plane distance with respect to the SLM plane.
    dx                      : float
                              SLM pixel pitch.
    resolution              : array
                              SLM resolution.
    propogation_type        : str
                              Type of the propagation (IR Fresnel, Angular Spectrum, Bandlimited Angular Spectrum, TR Fresnel, Fraunhofer).
    n_iteration             : int
                              Maximum number of iterations.
    loss_function           : function
                              If None, set to L2 (MSE) loss.
    cuda                    : boolean
                              GPU enabled
    learning_rate           : float
                              Learning rate.

    Returns
    -------
    hologram                : torch.Tensor
                              Phase only hologram as torch array

    reconstruction_intensity: torch.Tensor
                              Reconstruction as torch array

    """
    torch.cuda.empty_cache()
    torch.manual_seed(0)
    device = torch.device("cuda" if cuda else "cpu")
    field = field.to(device)
    phase = torch.rand(resolution[0], resolution[1]).detach().to(
        device).requires_grad_()
    amplitude = torch.ones(
        resolution[0], resolution[1], requires_grad=False).to(device)
    k = wavenumber(wavelength)
    optimizer = torch.optim.Adam([{'params': [phase]}], lr=learning_rate)
    if loss_function is None:
        loss_function = torch.nn.MSELoss().to(device)
    t = tqdm(range(n_iteration), leave=False)
    for i in t:
        optimizer.zero_grad()
        hologram = generate_complex_field(amplitude, phase)
        hologram_padded = zero_pad(hologram)
        reconstruction_padded = propagate_beam(
            hologram_padded, k, distance, dx, wavelength, propogation_type)
        reconstruction = crop_center(reconstruction_padded)
        reconstruction_intensity = calculate_amplitude(reconstruction)**2
        loss = loss_function(reconstruction_intensity, field)
        description = "Iteration: {} loss:{:.4f}".format(i, loss.item())
        loss.backward(retain_graph=True)
        optimizer.step()
        t.set_description(description)
    print(description)
    with torch.no_grad():
        hologram = generate_complex_field(amplitude, phase)
        hologram_padded = zero_pad(hologram)
        reconstruction_padded = propagate_beam(
            hologram_padded, k, distance, dx, wavelength, propogation_type)
        reconstruction = crop_center(reconstruction_padded)
        hologram = crop_center(hologram_padded)
    return hologram.detach(), reconstruction.detach()
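The structure of the optimization loop above can be sketched in plain PyTorch without odak. In this sketch a 2D FFT stands in for `propagate_beam` (an assumption for illustration; the real function supports the propagation models listed in the parameters), and `sgd_phase_retrieval` is a hypothetical name, not part of odak:

```python
import torch

def sgd_phase_retrieval(target_amplitude, n_iteration=100, learning_rate=0.1):
    torch.manual_seed(0)
    # Random initial phase; this is the only optimized variable.
    phase = torch.rand_like(target_amplitude, requires_grad=True)
    optimizer = torch.optim.Adam([phase], lr=learning_rate)
    loss_function = torch.nn.MSELoss()
    for _ in range(n_iteration):
        optimizer.zero_grad()
        hologram = torch.exp(1j * phase)           # phase-only hologram
        reconstruction = torch.fft.fft2(hologram)  # FFT as stand-in propagation
        intensity = reconstruction.abs() ** 2
        intensity = intensity / intensity.max()    # normalize for comparison
        loss = loss_function(intensity, target_amplitude)
        loss.backward()
        optimizer.step()
    return phase.detach()

target = torch.zeros(64, 64)
target[28:36, 28:36] = 1.0
phase = sgd_phase_retrieval(target, n_iteration=20)
```

The key design point, shared with the source above, is that only the phase carries gradients; the amplitude on the hologram plane stays fixed at one, so the result is directly displayable on a phase-only SLM.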

Notes

To optimize a phase-only hologram using stochastic gradient descent, see the example below:

import torch
from odak.learn.wave import stochastic_gradient_descent
wavelength               = 0.000000532
dx                       = 0.0000064
distance                 = 0.1
cuda                     = False
resolution               = [1080,1920]
target_field             = torch.zeros(resolution[0],resolution[1])
target_field[500::600,:] = 1
iteration_number         = 5
hologram,reconstructed   = stochastic_gradient_descent(
                                                       target_field,
                                                       wavelength,
                                                       distance,
                                                       dx,
                                                       resolution,
                                                       'TR Fresnel',
                                                       iteration_number,
                                                       learning_rate=0.1,
                                                       cuda=cuda
                                                      )
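The returned hologram is a complex field, so the pattern to display on an SLM is its phase. A minimal sketch of extracting it with `torch.angle` (odak also provides a `calculate_phase` utility for this); the `hologram` below is a random stand-in for the tensor returned above:

```python
import torch

# Stand-in complex hologram at SLM resolution; in practice this is the
# tensor returned by stochastic_gradient_descent.
hologram = torch.exp(1j * torch.rand(1080, 1920))
phase_pattern = torch.angle(hologram)  # radians in (-pi, pi]
```

The phase pattern can then be wrapped and quantized to the SLM's bit depth before display.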

See also