Beispiel #1
0
        def tf_fn():
            """FFT-convolve ``d`` with kernel ``k`` and crop back to ``d``'s size.

            ``d``, ``k`` and ``mode`` are taken from the enclosing scope; both
            arrays are wrapped as float32 constants before any FFT work.

            Returns:
                Whatever ``fft_utils_tf.extract`` yields for the convolved tensor,
                given the runtime shape of the data tensor and the pad dims.
            """
            data_t = tf.constant(d, dtype=tf.float32)
            kernel_t = tf.constant(k, dtype=tf.float32)

            # Padded sizes, then the size actually used for the transforms
            # (adjusted according to ``mode``), plus the matching FFT pair.
            padded = fft_utils_tf.get_fft_pad_dims(data_t, kernel_t)
            fft_dims = fft_utils_tf.optimize_dims(padded, mode)
            forward, inverse = fft_utils_tf.get_fft_tf_fns(data_t.shape.ndims)

            # Transform the kernel once, then let the helper convolve the data.
            kernel_spec = forward(kernel_t, fft_length=fft_dims)
            conv = fft_utils_tf.convolve(
                data_t, kernel_spec, fft_dims, forward, inverse
            )

            # Crop the result so it lines up with the original data array.
            data_shape = tf.shape(data_t)
            return fft_utils_tf.extract(conv, data_shape, padded)
Beispiel #2
0
 def tf_fn():
     """Return the FFT dimensions ``optimize_dims`` picks for ``d`` and ``k``.

     ``d``, ``k`` and ``mode`` come from the enclosing scope; the arrays are
     wrapped as float32 constants before the padded dims are computed.
     """
     data_t = tf.constant(d, dtype=tf.float32)
     kernel_t = tf.constant(k, dtype=tf.float32)
     padded = fft_utils_tf.get_fft_pad_dims(data_t, kernel_t)
     optimized = fft_utils_tf.optimize_dims(padded, mode)
     return optimized
Beispiel #3
0
 def tf_fn():
     """Return ``get_fft_pad_dims`` for ``d`` and ``k`` as float32 constants.

     ``d`` and ``k`` come from the enclosing scope.
     """
     data_t = tf.constant(d, dtype=tf.float32)
     kernel_t = tf.constant(k, dtype=tf.float32)
     padded = fft_utils_tf.get_fft_pad_dims(data_t, kernel_t)
     return padded