"""
Fraunhofer2D — 2-D far-field (Fraunhofer) propagator via 2-D FFT.
Note: phases are approximate; not suitable for compound beamlines or in combination with lenses.
"""
import numpy
from wofry.propagator.wavefront2D.generic_wavefront import GenericWavefront2D
from wofry.propagator.propagator import Propagator2D
class Fraunhofer2D(Propagator2D):
    """
    2-D far-field (Fraunhofer) propagator based on a single 2-D FFT.

    Output grid, per axis, for an input grid with ``n`` pixels of size ``d``:
    ``fftshift(fftfreq(n, d)) * propagation_distance * wavelength`` [m].
    When the distance is 0, the grid is instead
    ``fftshift(fftfreq(n, d)) * wavelength``, i.e. angles [rad].
    """

    HANDLER_NAME = "FRAUNHOFER_2D"

    def get_handler_name(self):
        """Return the handler name (``HANDLER_NAME``) identifying this propagator."""
        return self.HANDLER_NAME

    def do_specific_progation_after(self, wavefront, propagation_distance, parameters, element_index=None):
        """Propagate after an element; delegates to :meth:`do_specific_progation`."""
        return self.do_specific_progation(wavefront, propagation_distance, parameters, element_index=element_index)

    def do_specific_progation_before(self, wavefront, propagation_distance, parameters, element_index=None):
        """Propagate before an element; delegates to :meth:`do_specific_progation`."""
        return self.do_specific_progation(wavefront, propagation_distance, parameters, element_index=element_index)

    def do_specific_progation(self, wavefront, propagation_distance, parameters, element_index=None):
        """
        Propagate a 2-D wavefront using the Fraunhofer (far-field) approximation.

        Parameters
        ----------
        wavefront : GenericWavefront2D
            Input wavefront.
        propagation_distance : float
            Propagation distance [m]. If zero, the output abscissas are in
            angle [rad].
        parameters : PropagationParameters
            Propagation parameter container (may include ``shift_half_pixel``,
            default False).
        element_index : int, optional
            Index of the beamline element being propagated through.

        Returns
        -------
        GenericWavefront2D
            Propagated wavefront on the far-field grid.
        """
        shift_half_pixel = self.get_additional_parameter("shift_half_pixel", False, parameters,
                                                         element_index=element_index)
        return self.propagate_wavefront(wavefront, propagation_distance, shift_half_pixel=shift_half_pixel)

    @staticmethod
    def _centred_frequencies(npixels, pixelsize):
        """Return the FFT spatial frequencies of one axis, zero frequency centred."""
        return numpy.fft.fftshift(numpy.fft.fftfreq(npixels, d=pixelsize))

    @staticmethod
    def _shift_half_pixel(coordinates):
        """Shift a regular coordinate array by half a step; arrays with < 2 samples are returned as-is."""
        if coordinates.size < 2:
            return coordinates
        return coordinates - 0.5 * numpy.abs(coordinates[1] - coordinates[0])

    @classmethod
    def propagate_wavefront(cls, wavefront, propagation_distance, shift_half_pixel=False):
        """
        Compute the Fraunhofer far field of ``wavefront`` with one 2-D FFT.

        Parameters
        ----------
        wavefront : GenericWavefront2D
            Input wavefront.
        propagation_distance : float
            Propagation distance [m]. If zero, the output abscissas are angles
            [rad] and the distance-dependent phase factors are omitted.
        shift_half_pixel : bool, optional
            If True, shift both output axes by half an output pixel.

        Returns
        -------
        GenericWavefront2D
            Far-field wavefront. Its amplitude is rescaled so the sum of
            intensities equals that of the input, unless the output is
            identically zero.

        Notes
        -----
        The FFT treats the first input sample as the coordinate origin. No
        linear phase correcting for the input grid offset is applied, so output
        phases are approximate; intensities are unaffected.
        """
        wavelength = wavefront.get_wavelength()
        wavenumber = wavefront.get_wavenumber()
        angular_output = propagation_distance == 0

        #
        # check validity (meaningless for the angular, infinite-distance case)
        #
        if not angular_output:
            x = wavefront.get_coordinate_x()
            y = wavefront.get_coordinate_y()
            half_max_aperture = 0.5 * numpy.array((x[-1] - x[0], y[-1] - y[0])).max()
            far_field_distance = half_max_aperture ** 2 / wavelength
            if propagation_distance < far_field_distance:
                print("WARNING: Fraunhoffer diffraction valid for distances > > half_max_aperture^2/lambda = %f m (propagating at %4.1f)" %
                      (far_field_distance, propagation_distance))

        #
        # output abscissas: centred spatial frequencies mapped to the far field
        #
        shape = wavefront.size()
        delta = wavefront.delta()
        freq_x = cls._centred_frequencies(shape[0], delta[0])
        freq_y = cls._centred_frequencies(shape[1], delta[1])
        if angular_output:
            x2 = freq_x * wavelength
            y2 = freq_y * wavelength
        else:
            x2 = freq_x * propagation_distance * wavelength
            y2 = freq_y * propagation_distance * wavelength

        #
        # Fourier transform (computed once), low frequencies moved to the centre
        #
        field = numpy.fft.fftshift(numpy.fft.fft2(wavefront.get_complex_amplitude()))

        if angular_output:
            # Keep only the distance-independent 1/(i*lambda) prefactor; the
            # amplitude normalisation is fixed by the energy rescale below.
            field /= 1.0j * wavelength
        else:
            # The field and x2/y2 share the same centred layout, so the quadratic
            # phase is applied sample-by-sample without any re-shifting. This is
            # correct for both odd and even grid sizes.
            f_x, f_y = numpy.meshgrid(x2, y2, indexing='ij')
            field *= numpy.exp(1.0j * wavenumber * propagation_distance)
            field *= numpy.exp(1.0j * wavenumber / 2 / propagation_distance * (f_x ** 2 + f_y ** 2))
            field /= 1.0j * wavelength * propagation_distance

        if shift_half_pixel:
            x2 = cls._shift_half_pixel(x2)
            y2 = cls._shift_half_pixel(y2)

        wavefront_out = GenericWavefront2D.initialize_wavefront_from_arrays(x_array=x2,
                                                                            y_array=y2,
                                                                            z_array=field,
                                                                            wavelength=wavelength)

        # added srio@esrf.eu 2018-03-23 to conserve energy - TODO: review method!
        # Conserves the sum of intensities (not the pixel-area-weighted integral).
        # Guarded so an all-zero field does not produce NaNs.
        output_intensity_sum = wavefront_out.get_intensity().sum()
        if output_intensity_sum > 0:
            wavefront_out.rescale_amplitude(numpy.sqrt(wavefront.get_intensity().sum() /
                                                       output_intensity_sum))
        return wavefront_out