Source code for wofryimpl.propagator.propagators2D.fresnel_convolution

"""
FresnelConvolution2D — 2-D near-field Fresnel propagator that convolves the wavefront with the Fresnel kernel sampled in the spatial domain (the convolution itself is evaluated with scipy.signal.fftconvolve).

Prefer ``Fresnel2D`` (FFT transfer-function method) for better accuracy.
"""
import numpy

from wofry.propagator.wavefront2D.generic_wavefront import GenericWavefront2D
from wofry.propagator.propagator import Propagator2D

class FresnelConvolution2D(Propagator2D):
    """2-D Fresnel propagator based on convolution with the spatial-domain kernel.

    The input complex amplitude is convolved (``scipy.signal.fftconvolve``,
    ``mode='same'``) with the Fresnel kernel sampled on the wavefront's own
    mesh, so the output lives on the same grid as the input.
    """

    HANDLER_NAME = "FRESNEL_CONVOLUTION_2D"

    def get_handler_name(self):
        """Return the handler name of this propagator (``HANDLER_NAME``)."""
        return self.HANDLER_NAME

    def do_specific_progation_after(self, wavefront, propagation_distance, parameters, element_index=None):
        """Propagate after an element; delegates to ``do_specific_progation``."""
        return self.do_specific_progation(wavefront, propagation_distance, parameters,
                                          element_index=element_index)

    def do_specific_progation_before(self, wavefront, propagation_distance, parameters, element_index=None):
        """Propagate before an element; delegates to ``do_specific_progation``."""
        return self.do_specific_progation(wavefront, propagation_distance, parameters,
                                          element_index=element_index)

    def do_specific_progation(self, wavefront, propagation_distance, parameters, element_index=None):
        """
        Propagate a 2-D wavefront by convolving it with the Fresnel kernel.

        Parameters
        ----------
        wavefront : GenericWavefront2D
            Input wavefront. Any other object is converted through its
            ``toGenericWavefront()`` method.
        propagation_distance : float
            Propagation distance [m].
        parameters : PropagationParameters
            Propagation parameter container (may include ``shift_half_pixel``,
            which defaults to False when absent).
        element_index : int, optional
            Index of the beamline element being propagated through.

        Returns
        -------
        GenericWavefront2D
            Propagated wavefront on the same spatial grid.
        """
        shift_half_pixel = self.get_additional_parameter("shift_half_pixel", False, parameters,
                                                         element_index=element_index)

        if not isinstance(wavefront, GenericWavefront2D):
            wavefront = wavefront.toGenericWavefront()

        return self.propagate_wavefront(wavefront, propagation_distance,
                                        shift_half_pixel=shift_half_pixel)

    @classmethod
    def propagate_wavefront(cls, wavefront, propagation_distance, shift_half_pixel=False):
        """
        Convolve the wavefront amplitude with the Fresnel kernel.

        The kernel is ``exp(i k (X^2 + Y^2) / (2 z)) * exp(i k z) / (i lambda z)``
        with ``k = 2 pi / lambda`` and ``z = propagation_distance``, evaluated
        on the wavefront's own mesh.

        Parameters
        ----------
        wavefront : GenericWavefront2D
            Input wavefront (not modified).
        propagation_distance : float
            Propagation distance [m]. A distance of exactly zero returns an
            unpropagated copy, because the kernel is undefined there.
        shift_half_pixel : bool, optional
            If True, the kernel mesh is shifted by half a pixel in x and y.

        Returns
        -------
        GenericWavefront2D
            New wavefront on the same coordinates, rescaled so that its summed
            intensity equals that of the input whenever the output intensity
            is non-zero.
        """
        from scipy.signal import fftconvolve

        wavelength = wavefront.get_wavelength()
        x = wavefront.get_coordinate_x()
        y = wavefront.get_coordinate_y()

        if propagation_distance == 0:
            # The kernel divides by the distance; without this guard the result
            # would be an all-NaN field. Zero distance means no propagation.
            return GenericWavefront2D.initialize_wavefront_from_arrays(
                x_array=x,
                y_array=y,
                z_array=numpy.array(wavefront.get_complex_amplitude(), copy=True),
                wavelength=wavelength)

        X = wavefront.get_mesh_x()
        Y = wavefront.get_mesh_y()

        if shift_half_pixel:
            # Non-mutating add: do not rely on get_mesh_x/y handing out fresh arrays.
            X = X + 0.5 * numpy.abs(x[0] - x[1])
            Y = Y + 0.5 * numpy.abs(y[0] - y[1])

        kernel = numpy.exp(1j * 2 * numpy.pi / wavelength * (X**2 + Y**2) / 2 / propagation_distance)
        kernel *= numpy.exp(1j * 2 * numpy.pi / wavelength * propagation_distance)
        kernel /= 1j * wavelength * propagation_distance

        wavefront_out = GenericWavefront2D.initialize_wavefront_from_arrays(
            x_array=x,
            y_array=y,
            z_array=fftconvolve(wavefront.get_complex_amplitude(), kernel, mode='same'),
            wavelength=wavelength)

        # added srio@esrf.eu 2018-03-23 to conserve energy - TODO: review method!
        intensity_in = wavefront.get_intensity().sum()
        intensity_out = wavefront_out.get_intensity().sum()
        if intensity_out > 0:
            # Skipped for an empty (all-zero) field, where 0/0 would poison the
            # whole output with NaN.
            wavefront_out.rescale_amplitude(numpy.sqrt(intensity_in / intensity_out))

        return wavefront_out