The class `phase_gradient` provides a regularization function that measures the variation (gradient or Laplacian) of the phase of a complex amplitude.

This is implemented as a convolution of the phase with a kernel.

By default the kernel is a simple 3-by-3 Laplacian kernel, but other edge-detection kernels can also be used.

Source code in odak/learn/wave/loss.py
class phase_gradient(nn.Module):
    """
    The class 'phase_gradient' provides a regularization function that measures the variation (gradient or Laplacian) of the phase of a complex amplitude.

    This is implemented as a convolution of the phase with a kernel.

    By default the kernel is a simple 3 by 3 Laplacian kernel, but other edge-detection kernels can also be supplied.
    """

    def __init__(self, kernel = None, loss = nn.MSELoss(), device=torch.device("cpu")):
        """
        Parameters
        ----------
        kernel                  : torch.tensor
                                  Convolution filter kernel, either 2-D (H, W) or 4-D
                                  (out_channels, in_channels, H, W).
                                  3 by 3 Laplacian kernel by default.
        loss                    : torch.nn.Module
                                  Loss function, L2 (MSE) loss by default.
        device                  : torch.device
                                  Device the kernel is moved to, CPU by default.
        """
        # nn.Module.__init__ must run before any attribute assignment; without it,
        # assigning the `loss` submodule below raises AttributeError.
        super().__init__()
        self.device = device
        self.loss = loss
        if kernel is None:
            # 3x3 Laplacian divided by 8, so the centre weight is 1 and the weights sum to 0.
            kernel = torch.tensor([[[[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]]], dtype=torch.float32) / 8
        elif len(kernel.shape) != 4:
            # Promote a 2-D (H, W) kernel to the (1, 1, H, W) layout conv2d expects.
            kernel = kernel.reshape((1, 1, kernel.shape[0], kernel.shape[1]))
        # detach() is equivalent to the deprecated Variable(...) wrapper: a copy of the
        # kernel that does not require gradients. It is a plain attribute, not a
        # registered buffer, so Module.to() does not move it; use `device` instead.
        self.kernel = kernel.to(self.device).detach()

    def forward(self, phase):
        """
        Compute the phase-variation regularization loss.

        Parameters
        ----------
        phase                  : torch.tensor
                                 Phase of the complex amplitude. A 2-D (H, W) tensor is
                                 promoted to (1, 1, H, W); other shapes are passed to
                                 the convolution unchanged.

        Returns
        -------
        loss_value              : torch.tensor
                                  The computed loss: `self.loss` between the filtered
                                  phase and an all-zero target.
        """
        if len(phase.shape) == 2:
            phase = phase.reshape((1, 1, phase.shape[0], phase.shape[1]))
        edge_detect = self.functional_conv2d(phase)
        # Comparing against zeros penalizes any non-zero filter response,
        # i.e. any local variation of the phase.
        loss_value = self.loss(edge_detect, torch.zeros_like(edge_detect))
        return loss_value

    def functional_conv2d(self, phase):
        """
        Convolve the phase with the filter kernel.

        Parameters
        ----------
        phase                  : torch.tensor
                                 Phase of the complex amplitude, in a layout accepted by
                                 torch.nn.functional.conv2d (e.g. (N, C, H, W) with C
                                 matching the kernel's input channels).

        Returns
        -------
        edge_detect              : torch.tensor
                                   Filter response. Zero padding of half the kernel size
                                   on each axis keeps H and W unchanged for odd-sized
                                   kernels.
        """
        padding = (self.kernel.shape[-2] // 2, self.kernel.shape[-1] // 2)
        edge_detect = torch.nn.functional.conv2d(phase, self.kernel, padding=padding)
        return edge_detect


## `__init__(self, kernel=None, loss=MSELoss(), device=device(type='cpu'))` (special method)

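Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `kernel` | `torch.tensor` | Convolution filter kernel, 3 by 3 Laplacian kernel by default. | `None` |
| `loss` | `torch.nn.Module` | Loss function, L2 loss by default. | `MSELoss()` |
| `device` | `torch.device` | Device on which the kernel is stored. | `device(type='cpu')` |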
Source code in odak/learn/wave/loss.py
def __init__(self, kernel = None, loss = nn.MSELoss(), device=torch.device("cpu")):
    """
    Parameters
    ----------
    kernel                  : torch.tensor
                              Convolution filter kernel, either 2-D (H, W) or 4-D
                              (out_channels, in_channels, H, W).
                              3 by 3 Laplacian kernel by default.
    loss                    : torch.nn.Module
                              Loss function, L2 (MSE) loss by default.
    device                  : torch.device
                              Device the kernel is moved to, CPU by default.
    """
    # nn.Module.__init__ must run before any attribute assignment; without it,
    # assigning the `loss` submodule below raises AttributeError.
    super().__init__()
    self.device = device
    self.loss = loss
    if kernel is None:
        # 3x3 Laplacian divided by 8, so the centre weight is 1 and the weights sum to 0.
        kernel = torch.tensor([[[[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]]], dtype=torch.float32) / 8
    elif len(kernel.shape) != 4:
        # Promote a 2-D (H, W) kernel to the (1, 1, H, W) layout conv2d expects.
        kernel = kernel.reshape((1, 1, kernel.shape[0], kernel.shape[1]))
    # detach() is equivalent to the deprecated Variable(...) wrapper: a copy of the
    # kernel that does not require gradients.
    self.kernel = kernel.to(self.device).detach()


## `forward(self, phase)`

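Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `phase` | `torch.tensor` | Phase of the complex amplitude. | required |

Returns:

| Type | Description |
|------|-------------|
| `torch.tensor` | The computed loss. |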

Source code in odak/learn/wave/loss.py
def forward(self, phase):
    """
    Evaluate the phase-variation penalty.

    Parameters
    ----------
    phase                  : torch.tensor
                             Phase of the complex amplitude. A 2-D (H, W) tensor is
                             reshaped to (1, 1, H, W) before filtering.

    Returns
    -------
    loss_value              : torch.tensor
                              The computed loss between the filtered phase and an
                              all-zero target.
    """
    if phase.dim() == 2:
        height, width = phase.shape
        phase = phase.reshape((1, 1, height, width))
    filtered = self.functional_conv2d(phase)
    # An all-zero target means every non-zero filter response is penalized.
    target = torch.zeros_like(filtered)
    return self.loss(filtered, target)


## `functional_conv2d(self, phase)`

Convolves the phase with the filter kernel to measure its variation (a Laplacian response with the default kernel).

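Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `phase` | `torch.tensor` | Phase of the complex amplitude. | required |

Returns:

| Type | Description |
|------|-------------|
| `torch.tensor` | The filtered phase (edge-detection response). |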

Source code in odak/learn/wave/loss.py
def functional_conv2d(self, phase):
"""
Calculates the gradient of the phase.

Parameters
----------
phase                  : torch.tensor
Phase of the complex amplitude.

Returns
-------

edge_detect              : torch.tensor