odak.learn.wave.gradient_descent
Generates a phase-only hologram and its reconstruction from a target image via gradient descent.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
field | torch.Tensor | Target field intensity. | required |
wavelength | double | Wavelength of the light. | required |
distance | double | Hologram plane distance with respect to the SLM plane. | required |
dx | float | SLM pixel pitch. | required |
resolution | array | SLM resolution. | required |
propagation_type | str | Type of propagation (IR Fresnel, Angular Spectrum, Bandlimited Angular Spectrum, TR Fresnel, Fraunhofer). | required |
n_iteration | int | Maximum number of iterations. | 100 |
cuda | boolean | Set to True to run on the GPU. | False |
alpha | float | Step size applied to the back-propagated loss when updating the hologram. | 0.1 |
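As a rough illustration of how these arguments relate to one another, the sketch below collects one plausible set of values in plain Python. All numeric values (a 515 nm source, 8 µm pixel pitch, 1080 x 1920 panel, 15 cm distance) are hypothetical and chosen only for illustration; they are not library defaults. The wavenumber line assumes the usual definition k = 2π/λ, which is what the `wavenumber(wavelength)` call in the source presumably computes.

```python
import math

# Hypothetical example values, expressed in meters; not defaults of the library.
wavelength = 515e-9        # wavelength of the light
distance = 0.15            # hologram plane distance from the SLM plane
dx = 8e-6                  # SLM pixel pitch
resolution = [1080, 1920]  # SLM resolution as [rows, columns]

# Assumed definition of the wavenumber: k = 2 * pi / wavelength.
k = 2.0 * math.pi / wavelength

# Physical size of the SLM active area implied by dx and resolution.
slm_height = resolution[0] * dx
slm_width = resolution[1] * dx

print(f"k = {k:.4e} rad/m")
print(f"SLM area = {slm_height * 1e3:.2f} mm x {slm_width * 1e3:.2f} mm")
```

Keeping `wavelength`, `distance`, and `dx` in the same length unit avoids scale mismatches during propagation.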
Returns:
Type | Description |
---|---|
torch.Tensor | Phase-only hologram as a torch tensor. |
torch.Tensor | Reconstruction as a torch tensor. |
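In the source listing, the returned fields are built with `generate_complex_field`, so they appear to be complex-valued. A minimal sketch of how a phase map and an intensity image could be extracted from such fields is given below. It uses NumPy arrays as a stand-in for torch tensors (an assumption made to keep the example self-contained), and the random `hologram` array is a placeholder rather than real output of the function.

```python
import numpy as np

# Placeholder for a returned phase-only hologram: unit amplitude, arbitrary phase.
rng = np.random.default_rng(0)
hologram = np.exp(1j * rng.uniform(-np.pi, np.pi, size=(4, 4)))

# Phase wrapped into [0, 2*pi), a common convention when driving an SLM.
phase = np.mod(np.angle(hologram), 2.0 * np.pi)

# Intensity of a complex field is its squared magnitude.
intensity = np.abs(hologram) ** 2

print(phase.shape, intensity.shape)
```

With torch tensors, `torch.angle` and `torch.abs` play the same roles as `np.angle` and `np.abs` here.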
Source code in odak/learn/wave/classical.py
def gradient_descent(field, wavelength, distance, dx, resolution, propagation_type, n_iteration=100, cuda=False, alpha=0.1):
    """
    Definition to generate phase and reconstruction from target image via gradient descent.

    Parameters
    ----------
    field : torch.Tensor
        Target field intensity.
    wavelength : double
        Wavelength of the light.
    distance : double
        Hologram plane distance with respect to the SLM plane.
    dx : float
        SLM pixel pitch.
    resolution : array
        SLM resolution.
    propagation_type : str
        Type of propagation (IR Fresnel, Angular Spectrum, Bandlimited Angular Spectrum, TR Fresnel, Fraunhofer).
    n_iteration : int
        Maximum number of iterations.
    cuda : boolean
        Set to True to run on the GPU.
    alpha : float
        Step size applied to the back-propagated loss when updating the hologram.

    Returns
    -------
    hologram : torch.Tensor
        Phase-only hologram as a torch tensor.
    reconstruction : torch.Tensor
        Reconstruction as a torch tensor.
    """
    torch.cuda.empty_cache()
    torch.manual_seed(0)
    device = torch.device("cuda" if cuda else "cpu")
    field = field.to(device)
    # Start from a random phase with unit amplitude.
    phase = torch.rand(resolution[0], resolution[1])
    amplitude = torch.ones(
        resolution[0], resolution[1], requires_grad=False).to(device)
    k = wavenumber(wavelength)
    loss_function = torch.nn.MSELoss(reduction='none').to(device)
    t = tqdm(range(n_iteration), leave=False)
    hologram = generate_complex_field(amplitude, phase)
    for i in t:
        # Forward pass: propagate the padded hologram to the target plane.
        hologram_padded = zero_pad(hologram)
        reconstruction_padded = propagate_beam(
            hologram_padded, k, distance, dx, wavelength, propagation_type)
        reconstruction = crop_center(reconstruction_padded)
        reconstruction_intensity = calculate_amplitude(reconstruction)**2
        # Per-pixel squared error against the target intensity.
        loss = loss_function(reconstruction_intensity, field)
        # Propagate the error, carrying the reconstruction phase, back to the SLM plane.
        loss_field = generate_complex_field(loss, calculate_phase(reconstruction))
        loss_field_padded = zero_pad(loss_field)
        loss_propagated_padded = propagate_beam(loss_field_padded, k, -distance, dx, wavelength, propagation_type)
        loss_propagated = crop_center(loss_propagated_padded)
        # Update step, then keep only the phase to enforce unit amplitude.
        hologram_updated = hologram - alpha * loss_propagated
        hologram_phase = calculate_phase(hologram_updated)
        hologram = generate_complex_field(amplitude, hologram_phase)
        description = "Iteration: {} loss:{:.4f}".format(i, torch.mean(loss))
        t.set_description(description)
    print(description)
    # Note: called as a bare statement, torch.no_grad() does not disable gradient tracking.
    torch.no_grad()
    # Note: `phase` is still the initial random phase here, not `hologram_phase` from the loop.
    hologram = generate_complex_field(amplitude, phase)
    hologram_padded = zero_pad(hologram)
    reconstruction_padded = propagate_beam(
        hologram_padded, k, distance, dx, wavelength, propagation_type)
    reconstruction = crop_center(reconstruction_padded)
    hologram = crop_center(hologram_padded)
    return hologram.detach(), reconstruction.detach()
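
The update pattern in the loop above (propagate forward, measure the per-pixel error, propagate the error back, take a step, keep only the phase) can be sketched in a self-contained form. The version below is not the library's implementation: it uses NumPy instead of torch, skips zero padding and cropping, and replaces `propagate_beam` with a hypothetical stand-in propagator (a unitary, centered 2D FFT). The names `propagate`, `back_propagate`, and `phase_only_descent` are illustrative only.

```python
import numpy as np


def propagate(field):
    # Hypothetical stand-in propagator: a unitary, centered 2D FFT.
    return np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(field), norm="ortho"))


def back_propagate(field):
    # Exact inverse of propagate() above.
    return np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(field), norm="ortho"))


def phase_only_descent(target, n_iteration=20, alpha=0.1, seed=0):
    rng = np.random.default_rng(seed)
    # Start from a random phase with unit amplitude.
    hologram = np.exp(1j * rng.random(target.shape))
    losses = []
    for _ in range(n_iteration):
        reconstruction = propagate(hologram)
        intensity = np.abs(reconstruction) ** 2
        # Per-pixel squared error, analogous to MSELoss(reduction='none').
        error = (intensity - target) ** 2
        losses.append(float(error.mean()))
        # Attach the reconstruction phase to the error and send it back.
        error_field = error * np.exp(1j * np.angle(reconstruction))
        update = back_propagate(error_field)
        # Take a step, then project onto unit amplitude (phase only).
        hologram = np.exp(1j * np.angle(hologram - alpha * update))
    final_intensity = np.abs(propagate(hologram)) ** 2
    return hologram, final_intensity, losses


# Toy target: a bright square on a dark background.
target = np.zeros((32, 32))
target[12:20, 12:20] = 1.0
hologram, intensity, losses = phase_only_descent(target)
print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```

Returning the `losses` list makes it easy to inspect how the error evolves for a given `alpha`. Unlike the listing above, this sketch builds the final reconstruction from the optimized hologram.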