Example #1
File: models.py Project: hyzcn/cnn_graph
    def _inference(self, x, dropout):
        with tf.name_scope('conv1'):
            # Transform to Fourier domain
            x_2d = tf.reshape(x, [-1, 28, 28])
            x_2d = tf.complex(x_2d, 0)
            xf_2d = tf.fft2d(x_2d)
            xf = tf.reshape(xf_2d, [-1, NFEATURES])
            xf = tf.expand_dims(xf, 1)  # NSAMPLES x 1 x NFEATURES
            xf = tf.transpose(xf)  # NFEATURES x 1 x NSAMPLES
            # Filter
            Wreal = self._weight_variable([int(NFEATURES/2), self.F, 1])
            Wimg = self._weight_variable([int(NFEATURES/2), self.F, 1])
            W = tf.complex(Wreal, Wimg)
            xf = xf[:int(NFEATURES/2), :, :]
            yf = tf.matmul(W, xf)  # for each feature
            yf = tf.concat([yf, tf.conj(yf)], axis=0)
            yf = tf.transpose(yf)  # NSAMPLES x NFILTERS x NFEATURES
            yf_2d = tf.reshape(yf, [-1, 28, 28])
            # Transform back to spatial domain
            y_2d = tf.ifft2d(yf_2d)
            y_2d = tf.real(y_2d)
            y = tf.reshape(y_2d, [-1, self.F, NFEATURES])
            # Bias and non-linearity
            b = self._bias_variable([1, self.F, 1])
#            b = self._bias_variable([1, self.F, NFEATURES])
            y += b  # NSAMPLES x NFILTERS x NFEATURES
            y = tf.nn.relu(y)
        with tf.name_scope('fc1'):
            W = self._weight_variable([self.F*NFEATURES, NCLASSES])
            b = self._bias_variable([NCLASSES])
            y = tf.reshape(y, [-1, self.F*NFEATURES])
            y = tf.matmul(y, W) + b
        return y
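The 'conv1' scope above implements a spectral convolution: a pointwise product in the 2-D Fourier domain. As a sanity check of the underlying identity, here is a minimal NumPy sketch (not from the project) showing that multiplying spectra matches circular convolution in the spatial domain:

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((28, 28))
h = rng.standard_normal((28, 28))

# Pointwise product of spectra, back-transformed...
spectral = np.real(np.fft.ifft2(np.fft.fft2(x) * np.fft.fft2(h)))

# ...equals circular (wrap-around) convolution in the spatial domain.
circular = np.zeros_like(x)
for i in range(28):
    for j in range(28):
        circular += x[i, j] * np.roll(np.roll(h, i, axis=0), j, axis=1)

assert np.allclose(spectral, circular)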
Example #2
    def get_reconstructed_image(self, real, imag, name=None):
        """
        :param real: real part of the k-space data
        :param imag: imaginary part of the k-space data
        :param name: name prefix for the created ops
        :return: two-channel (real, imaginary) fft-shifted reconstruction
        """
        complex_k_space_label = tf.complex(real=tf.squeeze(real), imag=tf.squeeze(imag), name=name+"_complex_k_space")
        rec_image_complex = tf.expand_dims(tf.ifft2d(complex_k_space_label), axis=1)
        
        rec_image_real = tf.reshape(tf.real(rec_image_complex), shape=[-1, 1, self.dims_out[1], self.dims_out[2]])
        rec_image_imag = tf.reshape(tf.imag(rec_image_complex), shape=[-1, 1, self.dims_out[1], self.dims_out[2]])

        # Shifting
        top, bottom = tf.split(rec_image_real, num_or_size_splits=2, axis=2)
        top_left, top_right = tf.split(top, num_or_size_splits=2, axis=3)
        bottom_left, bottom_right = tf.split(bottom, num_or_size_splits=2, axis=3)

        top_shift = tf.concat(axis=3, values=[bottom_right, bottom_left])
        bottom_shift = tf.concat(axis=3, values=[top_right, top_left])
        shifted_image = tf.concat(axis=2, values=[top_shift, bottom_shift])


        # Shifting
        top_imag, bottom_imag = tf.split(rec_image_imag, num_or_size_splits=2, axis=2)
        top_left_imag, top_right_imag = tf.split(top_imag, num_or_size_splits=2, axis=3)
        bottom_left_imag, bottom_right_imag = tf.split(bottom_imag, num_or_size_splits=2, axis=3)

        top_shift_imag = tf.concat(axis=3, values=[bottom_right_imag, bottom_left_imag])
        bottom_shift_imag = tf.concat(axis=3, values=[top_right_imag, top_left_imag])
        shifted_image_imag = tf.concat(axis=2, values=[top_shift_imag, bottom_shift_imag])

        shifted_image_two_channels = tf.stack([shifted_image[:,0,:,:], shifted_image_imag[:,0,:,:]], axis=1)
        return shifted_image_two_channels
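The two "Shifting" blocks above hand-roll an fftshift by swapping image quadrants diagonally. A small NumPy check of that equivalence (illustrative, even-sized input assumed):

import numpy as np

img = np.arange(16.0).reshape(4, 4)
top, bottom = np.split(img, 2, axis=0)
tl, tr = np.split(top, 2, axis=1)
bl, br = np.split(bottom, 2, axis=1)
# Diagonal quadrant swap, same pattern as the tf.split/tf.concat code above.
shifted = np.block([[br, bl], [tr, tl]])
assert np.array_equal(shifted, np.fft.fftshift(img))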
Example #3
def compute_fft(x, direction="C2C", inverse=False):

    if direction == 'C2R':
        inverse = True

    x_shape = x.get_shape().as_list()
    h, w = x_shape[-3], x_shape[-2]  # (..., h, w, 2) layout; the original swapped the two (harmless here, since only the product h*w is used)

    x_complex = tf.complex(x[..., 0], x[..., 1])

    if direction == 'C2R':
        out = tf.real(tf.ifft2d(x_complex)) * h * w
        return out

    else:
        if inverse:
            out = stack_real_imag(tf.ifft2d(x_complex)) * h * w
        else:
            out = stack_real_imag(tf.fft2d(x_complex))
        return out
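compute_fft relies on a stack_real_imag helper that is not shown in this example. A plausible definition, given how the real/imaginary channels are unpacked at the top of the function, would be (an assumption, not the project's actual code):

def stack_real_imag(x_complex):
    # Pack real and imaginary parts back into a trailing channel of size 2,
    # mirroring tf.complex(x[..., 0], x[..., 1]) at the top of compute_fft.
    return tf.stack([tf.real(x_complex), tf.imag(x_complex)], axis=-1)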
Example #4
    def body(i, fi, fi_1, fi_2):
        with tf.name_scope("Analysis_Trans"):
            x = fi + alpha * tf.multiply(mask, (g-fi))
            coeffs = tf.ifft2d(tf.multiply(tf.fft2d(x), dec_fft) )
        with tf.name_scope("Hard_Thresholding"):
            comp = tf.greater(tf.abs(coeffs), tf.multiply(thresholds[i], w_st) )
            coeffs = tf.multiply(tf.cast(comp, tf.complex64), coeffs)
        with tf.name_scope("Synthesis_Trans"):
            coeffs_fft = tf.multiply(tf.fft2d(coeffs), rec_fft )
            f_hat = tf.ifft2d(tf.reduce_sum(coeffs_fft, 1, keepdims = True))
        # two-step overrelaxation
        with tf.name_scope("Double_Overrelaxation"):
            beta1 = tf.divide( tf.reduce_sum((g - f_hat) * mask * (f_hat - fi_1), axis=[1, 2, 3], keepdims=True),
                                tf.reduce_sum((f_hat - fi_1) * mask * (f_hat - fi_1), axis=[1, 2, 3], keepdims=True) + num_tiny )
            beta1 = tf.clip_by_value(tf.cast(beta1, tf.float32), tf.constant(0, tf.float32), tf.constant(1, tf.float32))
            f_tilde = f_hat + tf.cast(beta1, tf.complex64) * (f_hat - fi_1)

            beta2 = tf.divide( tf.reduce_sum((g - f_tilde) * mask * (f_tilde - fi_2), axis=[1, 2, 3], keepdims=True), 
                                tf.reduce_sum((f_tilde - fi_2) * mask * (f_tilde - fi_2), axis=[1, 2, 3], keepdims=True) + num_tiny )
            beta2 = tf.clip_by_value(tf.cast(beta2, tf.float32), tf.constant(0, tf.float32), tf.constant(1, tf.float32))
            f_i_new = f_tilde + tf.cast(beta2, tf.complex64) * (f_tilde - fi_2)

        return tf.add(i, 1), f_i_new, fi, fi_1
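In the "Double_Overrelaxation" scope, each reduce_sum pair computes a mask-weighted inner product, giving a data-driven step size that is then clipped to [0, 1]. Written out for the first step (with ε the num_tiny constant; the full context for this loop body appears in Example #31 below):

\[
\beta_1 = \operatorname{clip}_{[0,1]}
\frac{\sum \big(g - \hat f\big)\, M \,\big(\hat f - f_{i-1}\big)}
     {\sum \big(\hat f - f_{i-1}\big)\, M \,\big(\hat f - f_{i-1}\big) + \epsilon},
\qquad
\tilde f = \hat f + \beta_1 \big(\hat f - f_{i-1}\big)
\]

with β₂ and f_{i-2} playing the same roles in the second step.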
Example #5
 def complexsaliency(self, phamap, ampmap):
     for i in range(self.batch_size):
         out_angle = phamap[i, :, :, 0]
         out_mag = ampmap[i, :, :, 0]
         outcomplex = tf.complex(out_mag * tf.cos(out_angle),
                                 out_mag * tf.sin(out_angle))
         outsalmap = tf.abs(tf.ifft2d(outcomplex))
         outsalmap = tf.expand_dims(outsalmap, -1)
         outsalmap = tf.expand_dims(outsalmap, 0)
         if i == 0:
             compredict = outsalmap
         else:
             compredict = tf.concat([compredict, outsalmap], axis=0)
     return compredict
Example #6
def inverse_filter(blurred, estimate, psf, gamma=None, init_gamma=2.):
    """Inverse filtering in the frequency domain.

    Args:
        blurred: image with shape (batch_size, height, width, num_img_channels)
        estimate: image with shape (batch_size, height, width, num_img_channels)
        psf: filters with shape (kernel_height, kernel_width, num_img_channels, num_filters)
        gamma: Optional. Scalar that determines regularization (higher --> more regularization, output is closer to
               "estimate", lower --> less regularization, output is closer to straight inverse filtered-result). If
               not passed, a trainable variable will be created.
        init_gamma: Optional. Scalar whose square gives the initial value of gamma (gamma is squared to keep it positive).
    """
    img_shape = blurred.shape.as_list()

    if gamma is None:  # Gamma (the regularization parameter) is also a trainable parameter.
        gamma_initializer = tf.constant_initializer(init_gamma)
        gamma = tf.get_variable(name="gamma",
                                shape=(),
                                dtype=tf.float32,
                                trainable=True,
                                initializer=gamma_initializer)
        gamma = tf.square(gamma)  # Enforces positivity of gamma.
        tf.summary.scalar('gamma', gamma)

    a_tensor_transp = tf.transpose(blurred, [0, 3, 1, 2])
    estimate_transp = tf.transpose(estimate, [0, 3, 1, 2])

    # Everything has shape (batch_size, num_channels, height, width)
    img_fft = tf.fft2d(tf.complex(a_tensor_transp, 0.))
    # otf = my_psf2otf(psf, output_size=img_shape[1:3])
    otf = psf2otf(psf, output_size=img_shape[1:3])
    otf = tf.transpose(otf, [2, 3, 0, 1])

    adj_conv = img_fft * tf.conj(otf)

    # This is a slight modification to standard inverse filtering - gamma not only regularizes the inverse filtering,
    # but also trades off between the regularized inverse filter and the unfiltered estimate_transp.
    numerator = adj_conv + tf.fft2d(tf.complex(gamma * estimate_transp, 0.))

    kernel_mags = tf.square(tf.abs(otf))  # Magnitudes of the blur kernel.

    denominator = tf.complex(kernel_mags + gamma, 0.0)
    filtered = tf.div(numerator, denominator)
    cplx_result = tf.ifft2d(filtered)
    real_result = tf.real(cplx_result)  # Discard complex parts.
    real_result = tf.maximum(1e-5,real_result)

    # Get back to (batch_size, num_channels, height, width)
    result = tf.transpose(real_result, [0, 2, 3, 1])
    return result
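A hypothetical usage sketch (the placeholder shapes and the gamma value are assumptions; psf2otf is the helper already referenced inside the function): deblurring a batch of single-channel images with a known 5x5 PSF and a fixed regularizer.

blurred = tf.placeholder(tf.float32, [8, 64, 64, 1])
estimate = tf.placeholder(tf.float32, [8, 64, 64, 1])  # e.g. a denoised initial guess
psf = tf.placeholder(tf.float32, [5, 5, 1, 1])
deblurred = inverse_filter(blurred, estimate, psf, gamma=1e-2)  # fixed gamma, so no trainable variable is created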
Example #7
def _mygrad(op, grad):
    # Spatial whitening by FFT assuming 1/sqrt(F) spectrum
    num_px = int(grad.shape[1])
    grad = tf.transpose(grad, [0, 3, 1, 2])
    grad_fft = tf.fft2d(tf.cast(grad, tf.complex64))
    t = np.minimum(np.arange(0, num_px),
                   np.arange(num_px, 0, -1),
                   dtype=np.float32)
    t = 1 / np.maximum(1.0, (t[None, :]**2 + t[:, None]**2)**(1 / 4))
    F = tf.constant(t / t.mean(), dtype=tf.float32, name='F')
    grad_fft *= tf.cast(F, tf.complex64)
    grad = tf.ifft2d(grad_fft)
    grad = tf.transpose(tf.cast(grad, tf.float32), [0, 2, 3, 1])
    return grad
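The envelope F built from t is, up to the mean normalization, an inverse square-root-of-frequency weighting over the wrapped frequency indices u, v produced by the two np.arange calls:

\[
F(u, v) = \frac{1}{\max\!\big(1,\, (u^2 + v^2)^{1/4}\big)}
\]

Assuming gradients with a roughly 1/f natural-image amplitude spectrum, as the comment at the top of _mygrad states, multiplying by F in the Fourier domain flattens (whitens) that spectrum.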
Example #8
    def get_noisy_dirty_image(self, weighting='natural', return_full=False):
        if weighting == 'uniform':
            return tf.transpose(tf.real(tf.ifft2d(self.vis_full_sampled_noisy) * \
                                tf.cast(tf.sqrt(self.N),dtype=tf.complex64)) ,
                                [0,2,3,1])[:,192//2:3*192//2,192//2:3*192//2,:]
        else:
            if self.vis_input:
                dim_full = tf.transpose(tf_fftshift(tf.real(tf.ifft2d(tf.multiply(self.vis_full_sampled_noisy,\
                                        tf.cast(self.UVGRID_full,dtype=tf.complex64))) * \
                                        tf.cast(tf.sqrt(self.N),dtype=tf.complex64))),[0,2,3,1])
            else:
                dim_full = tf.transpose(tf.real(tf.ifft2d(tf.multiply(self.vis_full_sampled_noisy,\
                                        tf.cast(self.UVGRID_full,dtype=tf.complex64))) * \
                                        tf.cast(tf.sqrt(self.N),dtype=tf.complex64)),[0,2,3,1])

            N, H, W, C = dim_full.get_shape().as_list()

            if not return_full:
                # return only the center 192 pixels (this is temporary)
                return dim_full[:, H // 2 - 192 // 2:H // 2 + 192 // 2,
                                W // 2 - 192 // 2:W // 2 + 192 // 2, :]
            else:
                return dim_full
Example #9
def data_consistency(generated, X_k, mask):
    gene_complex = real2complex(generated)
    gene_complex = tf.transpose(gene_complex, [0, 3, 1, 2])
    mask = tf.transpose(mask, [0, 3, 1, 2])
    X_k = tf.transpose(X_k, [0, 3, 1, 2])
    gene_fft = tf.fft2d(gene_complex)
    out_fft = X_k + gene_fft * (1.0 - mask)
    output_complex = tf.ifft2d(out_fft)
    output_complex = tf.transpose(output_complex, [0, 2, 3, 1])
    output_real = tf.cast(tf.real(output_complex), dtype=tf.float32)
    output_imag = tf.cast(tf.imag(output_complex), dtype=tf.float32)
    output = tf.concat([output_real, output_imag], axis=-1)

    return output
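data_consistency is the standard MRI data-consistency step: keep the acquired k-space samples where the mask is 1 and take the network's prediction elsewhere. With M the sampling mask, F the 2-D Fourier transform, and X_k assumed to be zero outside the sampled locations:

\[
\hat{x} = \mathcal{F}^{-1}\!\big( X_k + (1 - M)\,\mathcal{F}(x_{\mathrm{gen}}) \big)
\]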
Example #10
    def EhE_Op(self, img, mu):
        """
        Performs (E^h*E+ mu*I) x
        """
        with tf.name_scope('EhE'):
            coil_imgs = self.sens_maps * img
            kspace = tf_utils.tf_fftshift(tf.fft2d(tf_utils.tf_ifftshift(coil_imgs))) / self.scalar
            masked_kspace = kspace * self.mask
            image_space_coil_imgs = tf_utils.tf_ifftshift(tf.ifft2d(tf_utils.tf_fftshift(masked_kspace))) * self.scalar
            image_space_comb = tf.reduce_sum(image_space_coil_imgs * tf.conj(self.sens_maps), axis=0)

            ispace = image_space_comb + mu * img

        return ispace
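With S the coil sensitivity maps, F the centered 2-D FFT (via the fftshift pairs), and M the sampling mask, EhE_Op applies the regularized normal operator of a SENSE-type encoding, as its docstring states:

\[
(E^H E + \mu I)\,x \;=\; S^H \mathcal{F}^{-1} M\, \mathcal{F}\, S\, x \;+\; \mu x
\]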
Example #11
def spectral_pool(image, pool_size=4,
                  convert_grayscale=True):
    """ Perform a single spectral pool operation.
    Args:
        image: numpy array representing an image
        pool_size: number of dimensions to throw away in each dimension,
                   same as the filter size of max_pool
        convert_grayscale: bool, if True, the image will be converted to
                           grayscale
    Returns:
        An image of shape (n, n, 1) if grayscale is True or same as input
    """
    tf.reset_default_graph()
    im = tf.placeholder(shape=image.shape, dtype=tf.float32)
    if convert_grayscale:
        im_conv = tf.image.rgb_to_grayscale(im)
    else:
        im_conv = im
    # make channels first
    im_channel_first = tf.transpose(im_conv, perm=[2, 0, 1])
    im_fft = tf.fft2d(tf.cast(im_channel_first, tf.complex64))
    lowpass = tf.get_variable(name='lowpass',
                              initializer=get_low_pass_filter(
                                    im_channel_first.get_shape().as_list(),
                                    pool_size))
    im_magnitude = tf.multiply(tf.abs(im_fft), lowpass)
    im_angles = tf.angle(im_fft)
    part1 = tf.complex(real=im_magnitude,
                       imag=tf.zeros_like(im_angles))
    part2 = tf.exp(tf.complex(real=tf.zeros_like(im_magnitude),
                              imag=im_angles))
    im_fft_lowpass = tf.multiply(part1, part2)
    im_transformed = tf.ifft2d(im_fft_lowpass)
    # make channels last and real values:
    im_channel_last = tf.real(tf.transpose(im_transformed, perm=[1, 2, 0]))

    # normalize image:
    channel_max = tf.reduce_max(im_channel_last, axis=(0, 1))
    channel_min = tf.reduce_min(im_channel_last, axis=(0, 1))
    im_out = tf.divide(im_channel_last - channel_min,
                       channel_max - channel_min)

    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        im_fftout, im_new = sess.run([im_magnitude, im_out],
                                     feed_dict={im: image})

    return im_fftout, im_new
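A hypothetical usage sketch (the random image is illustrative; get_low_pass_filter is the project helper already used inside the function):

import numpy as np

image = np.random.rand(32, 32, 3).astype(np.float32)
magnitude, pooled = spectral_pool(image, pool_size=4, convert_grayscale=True)
print(pooled.shape)  # (32, 32, 1): low-pass filtering does not change the spatial size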
Example #12
def upsample(x, mask):

    image_complex = tf.ifft2d(x)
    image_size = [FLAGS.batch_size, FLAGS.sample_size,
                  FLAGS.sample_size_y]  #tf.shape(image_complex)

    #get real and imaginary parts
    image_real = tf.reshape(tf.real(image_complex),
                            [image_size[0], image_size[1], image_size[2], 1])
    image_imag = tf.reshape(tf.imag(image_complex),
                            [image_size[0], image_size[1], image_size[2], 1])

    out = tf.concat([image_real, image_imag], 3)

    return out
Example #13
 def ifft(self, k):
     rank = len(k.shape) - 2
     assert rank >= 1
     if rank == 1:
         return tf.stack([tf.ifft(c) for c in tf.unstack(k, axis=-1)],
                         axis=-1)
     elif rank == 2:
         return tf.stack([tf.ifft2d(c) for c in tf.unstack(k, axis=-1)],
                         axis=-1)
     elif rank == 3:
         return tf.stack([tf.ifft3d(c) for c in tf.unstack(k, axis=-1)],
                         axis=-1)
     else:
         raise NotImplementedError(
             'n-dimensional inverse FFT not implemented.')
Example #14
    def _build(self, inputs):
        outputs = tf.cast(inputs, tf.complex64)
        output_list = [tf.fft2d(e) for e in tf.unstack(outputs, axis=(-1))]

        output_list = [e * self._mask for e in output_list]
        image_list = [tf.abs(tf.ifft2d(e)) for e in output_list]

        image_list = [tf.clip_by_value(e, 0.0, 255.0) for e in image_list]
        outputs = tf.stack(image_list, axis=-1)

        with tf.control_dependencies(
            # tf.equal only builds a boolean tensor; assert_equal actually checks
            [tf.assert_equal(tf.shape(inputs), tf.shape(outputs))]):
            outputs = tf.identity(outputs)

        return outputs
Example #15
 def At_handle(A_val_tf, z):
     sign_vec = A_val_tf[0:n]
     z_padded = tf.sparse_tensor_dense_matmul(sparse_sampling_matrix,
                                              z,
                                              adjoint_a=True)
     z_padded = tf.reshape(z_padded,
                           [height_img, width_img, BATCH_SIZE])
     z_padded = tf.transpose(
         z_padded
     )  #Transpose because fft2d operates upon the last two axes
     Finv_z = tf.ifft2d(z_padded)
     Finv_z = tf.transpose(Finv_z)
     Finv_z = tf.reshape(Finv_z, [height_img * width_img, BATCH_SIZE])
     out = tf.multiply(tf.conj(sign_vec), Finv_z) * n_fp / np.sqrt(m)
     return out
Example #16
def Clip_OperatorNorm(conv, inp_shape, clip_to):
    conv_tr = tf.cast(tf.transpose(conv, perm=[2, 3, 0, 1]), tf.complex64)
    conv_shape = conv.get_shape().as_list()
    padding = tf.constant([[0, 0], [0, 0], [0, inp_shape[0] - conv_shape[0]],
                           [0, inp_shape[1] - conv_shape[1]]])
    transform_coeff = tf.fft2d(tf.pad(conv_tr, padding))
    D, U, V = tf.svd(tf.transpose(transform_coeff, perm=[2, 3, 0, 1]))
    norm = tf.reduce_max(D)
    D_clipped = tf.cast(tf.minimum(D, clip_to), tf.complex64)
    clipped_coeff = tf.matmul(
        U, tf.matmul(tf.linalg.diag(D_clipped), V, adjoint_b=True))
    clipped_conv_padded = tf.real(
        tf.ifft2d(tf.transpose(clipped_coeff, perm=[2, 3, 0, 1])))
    return tf.slice(tf.transpose(clipped_conv_padded, perm=[2, 3, 0, 1]),
                    [0] * len(conv_shape), conv_shape), norm
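A hypothetical usage sketch for Clip_OperatorNorm: clip the operator norm (the largest singular value across all spatial frequencies) of a 3x3 convolution, viewed as a linear map on 32x32 feature maps, to 1.0. The kernel layout [kh, kw, c_in, c_out] matches the transpose at the top of the function.

conv = tf.get_variable('conv_kernel', shape=[3, 3, 16, 16])
clipped_conv, op_norm = Clip_OperatorNorm(conv, inp_shape=[32, 32], clip_to=1.0)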
Example #17
def tf_FSPAS_FFT(self, wlength, z, dx, dy, ridx, *theta0):
    """
    Angular Spectrum Propagation of Coherent Wave Fields
    with optional filtering
    
    INPUTS : 
        U, wave-field in space domain
        wlenght : wavelength of the optical wave
        z : distance of propagation
        dx,dy : sampling intervals in space
        M,N : Size of simulation window
        theta0 : Optional BAndwidth Limitation in DEGREES 
            (if no filtering is desired, only EVANESCENT WAVE IS FILTERED)
    
    OUTPUT : 
        output, propagated wave-field in space domain
    
    """

    wlengtheff = wlength / ridx
    B, M, N = self.shape
    dfx = 1 / dx / M.value
    dfy = 1 / dy / N.value
    fx = tf.constant((np.arange(M.value) - (M.value) / 2) * dfx,
                     shape=[M.value, 1],
                     dtype=tf.float32)
    fy = tf.constant((np.arange(N.value) - (N.value) / 2) * dfy,
                     shape=[1, N.value],
                     dtype=tf.float32)
    fx2 = tf.matmul(fx**2, tf.ones((1, N.value), dtype=tf.float32))
    fy2 = tf.matmul(tf.ones((M.value, 1), dtype=tf.float32), fy**2)
    if theta0:  #BANDLIMIT OF THE FREE-SPACE PROPAGATION
        f0 = np.sin(np.deg2rad(theta0)) / wlengtheff
        Q = tf.to_float(((fx2 + fy2) <= (f0**2)))
    else:
        Q = tf.to_float(((fx2 + fy2) <= (1 / wlengtheff**2)))

    W = Q * (fx2 + fy2) * (wlengtheff**2)
    Hphase = 2 * np.pi / wlengtheff * z * (tf.ones(
        (M.value, N.value)) - W)**(0.5)
    HFSP = tf.complex(Q * tf.cos(Hphase), Q * tf.sin(Hphase))
    ASpectrum = tf.fft2d(self)
    ASpectrum = tf_fft_shift_2d(ASpectrum)
    ASpectrum_z = tf_ifft_shift_2d(tf.multiply(HFSP, ASpectrum))
    output = tf.ifft2d(ASpectrum_z)
    output = tf.slice(output, np.int32([0, 0, 0]),
                      np.int32([B, M.value, N.value]))
    return output
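HFSP above is the band-limited angular-spectrum transfer function; with λ_eff = λ/n the wavelength in the medium and Q the binary band limit:

\[
H(f_x, f_y) = Q \,\exp\!\left( i\, \frac{2\pi z}{\lambda_{\mathrm{eff}}} \sqrt{1 - \lambda_{\mathrm{eff}}^2 \big(f_x^2 + f_y^2\big)} \right)
\]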
Example #18
File: final.py Project: zjulds/Lyn
def _propogation(u0, d=delta, N=size, dL=dL, lmb=c/Hz, theta=0.0):
    # Parameters
    df = 1.0 / dL
    k = np.pi * 2.0 / lmb
    D = dL * dL / (N * lmb)

    # Quadratic phase term on the frequency grid
    def phase(i, j):
        i -= N // 2
        j -= N // 2
        return (i * df) * (i * df) + (j * df) * (j * df)
    ph = np.fromfunction(phase, shape=(N, N), dtype=np.float32)
    # Transfer function H
    H = np.exp(1.0j * k * d) * np.exp(-1.0j * lmb * np.pi * d * ph)
    # Result
    return tf.ifft2d(np.fft.fftshift(H) * tf.fft2d(u0) * dL * dL / (N * N)) * N * N * df * df
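H built in _propogation is the Fresnel (paraxial) transfer function, sampled on the frequency grid ph = f_x^2 + f_y^2:

\[
H(f_x, f_y) = e^{i k d}\, e^{-i \pi \lambda d \,(f_x^2 + f_y^2)}, \qquad k = \frac{2\pi}{\lambda}
\]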
Example #19
def DCLayer(incomings, data_shape, inv_noise_level):
    data, mask, sampled = incomings
    data = tf.cast(data, tf.complex64)
    dft2 = tf.fft2d(data, name='dc_dft2')
    dft2 = tf.cast(dft2, tf.float32)  # cast the transform (the original cast `data`, discarding the FFT)
    x = dft2
    if inv_noise_level:  # noisy case
        v = inv_noise_level  # `v` was undefined in the original snippet
        out = (x + v * sampled) / (1 + v)
    else:  # noiseless case
        out = (1 - mask) * x + sampled

    out = tf.cast(out, tf.complex64)
    idft2 = tf.ifft2d(out, name='dc_idft2')
    idft2 = tf.cast(idft2, tf.float32)
    return idft2
Example #20
def inverse_filter(blurred,
                   estimate,
                   psf,
                   gamma=None,
                   otf=None,
                   init_gamma=1.5):
    """Implements Weiner deconvolution in the frequency domain, with circular boundary conditions.

     Args:
         blurred: image with shape (batch_size, height, width, num_img_channels)
         estimate: image with shape (batch_size, height, width, num_img_channels)
         psf: filters with shape (kernel_height, kernel_width, num_img_channels, num_filters)

     TODO precompute OTF, adj_filt_img.
     """
    img_shape = blurred.shape.as_list()

    if gamma is None:
        gamma_initializer = tf.constant_initializer(init_gamma)
        gamma = tf.get_variable(name="wiener_gamma",
                                shape=(),
                                dtype=tf.float32,
                                trainable=True,
                                initializer=gamma_initializer)
        gamma = tf.square(gamma)
        tf.summary.scalar('gamma', gamma)

    a_tensor_transp = tf.transpose(blurred, [0, 3, 1, 2])
    estimate_transp = tf.transpose(estimate, [0, 3, 1, 2])
    # Everything has shape (batch_size, num_channels, height, width)
    img_fft = tf.fft2d(tf.complex(a_tensor_transp, 0.))
    if otf is None:
        otf = optics.psf2otf(psf, output_size=img_shape[1:3])
        otf = tf.transpose(otf, [2, 3, 0, 1])

    adj_conv = img_fft * tf.conj(otf)
    numerator = adj_conv + tf.fft2d(tf.complex(gamma * estimate_transp, 0.))

    kernel_mags = tf.square(tf.abs(otf))

    denominator = tf.complex(kernel_mags + gamma, 0.0)
    filtered = tf.div(numerator, denominator)
    cplx_result = tf.ifft2d(filtered)
    real_result = tf.real(cplx_result)
    # Get back to (batch_size, num_channels, height, width)
    result = tf.transpose(real_result, [0, 2, 3, 1])
    return result
Example #21
def hdrplus_merge(imgs, N, c, sig):
    def ccast_tf(x): return tf.complex(x, tf.zeros_like(x))

    # imgs is [batch, h, w, ch]
    rcw = tf.expand_dims(rcwindow(N), axis=-1)
    imgs = imgs * rcw
    imgs = tf.transpose(imgs, [0, 3, 1, 2])
    imgs_f = tf.fft2d(ccast_tf(imgs))
    imgs_f = tf.transpose(imgs_f, [0, 2, 3, 1])
    Dz2 = tf.square(tf.abs(imgs_f[..., 0:1] - imgs_f))
    Az = Dz2 / (Dz2 + c*sig**2)
    filt0 = 1 + tf.expand_dims(tf.reduce_sum(Az[..., 1:], axis=-1), axis=-1)
    filts = tf.concat([filt0, 1 - Az[..., 1:]], axis=-1)
    output_f = tf.reduce_mean(imgs_f * ccast_tf(filts), axis=-1)
    output_f = tf.real(tf.ifft2d(output_f))

    return output_f
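The per-frequency weights follow the HDR+ robust merge: with D_z the spectral difference between each alternate frame and the reference (channel 0), each alternate is shrunk toward the reference by a Wiener-style factor

\[
A_z = \frac{|D_z|^2}{|D_z|^2 + c\,\sigma^2}
\]

so tiles that disagree with the reference (large |D_z|, likely motion or noise) contribute less, and filt0 hands their rejected weight back to the reference frame.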
Example #22
    def build(self, vgg_fea_pca, model_alphaf, model_xf):

        vgg_fea_pca = tf.transpose(vgg_fea_pca, [2, 0, 1, 3])  #3249*7*6*100
        vgg_fea_pca = tf.reshape(
            vgg_fea_pca,
            [fea_sz[0], fea_sz[1], 7, nn_p, pca_energy])  #57*57*7*6*100
        vgg_fea_pca = tf.transpose(vgg_fea_pca, perm=[2, 3, 4, 0,
                                                      1])  #7*6*100*57*57
        vgg_fea_pca = tf.cast(vgg_fea_pca, dtype=tf.complex64)
        model_xf = tf.transpose(model_xf, perm=[1, 0, 2, 3, 4])  #1*6*100*57*57

        zf = tf.fft2d(vgg_fea_pca)  #7*6*100*57*57
        k_zf_xf = tf.reduce_sum(tf.multiply(zf, tf.conj(model_xf)),
                                axis=2) / M  #7*6*57*57

        response = tf.real(tf.ifft2d(k_zf_xf * model_alphaf))
        self.response = tf.expand_dims(response, axis=0)
Example #23
File: util.py Project: mdw771/beyond_dof
def free_propagate_paraxial(wavefront,
                            dist_cm,
                            r_cm,
                            wavelen_nm,
                            psize_cm,
                            h=None):
    m = (dist_cm + r_cm) / r_cm
    dist_nm = dist_cm * 1.e7
    dist_eff_nm = dist_nm / m
    psize_nm = psize_cm * 1.e7
    if h is None:
        h = get_kernel(dist_eff_nm, wavelen_nm, [psize_nm, psize_nm],
                       wavefront.shape)
        h = tf.convert_to_tensor(h, dtype=tf.complex64)
    wavefront = fftshift(tf.fft2d(wavefront)) * h
    wavefront = tf.ifft2d(ifftshift(wavefront))
    return wavefront, m
Example #24
File: util.py Project: mdw771/beyond_dof
 def modulate_and_propagate(i, wavefront):
     delta_slice = grid_delta_batch[:, :, :, i]
     # delta_slice = tf.cast(delta_slice, dtype=tf.complex64)
     beta_slice = grid_beta_batch[:, :, :, i]
     # beta_slice = tf.cast(beta_slice, dtype=tf.complex64)
     c = tf.exp(1j * k * delta_slice) * tf.exp(-k * beta_slice)
     wavefront = wavefront * c
     # wavefront = tf.ifft2d(tf.fft2d(wavefront) * h)
     if type == 'projection':
         wavefront, m = free_propagate_paraxial(wavefront, psize_cm,
                                                s_r_cm + psize_cm * i,
                                                lmbda_nm, psize_cm)
         wavefront = rescale_image(
             wavefront, m,
             [batch_size, obj_batch_shape[1], obj_batch_shape[2]])
     else:
         wavefront = tf.ifft2d(
             ifftshift(fftshift(tf.fft2d(wavefront)) * h))
     i = i + 1
     return (i, wavefront)
Example #25
    def get_reconstructed_image(self, real, imag, name=None):
        """
        :param real: real part of the k-space data
        :param imag: imaginary part of the k-space data
        :param name: name prefix for the created ops
        :return: fft-shifted magnitude reconstruction
        """
        complex_k_space_label = tf.complex(real=tf.squeeze(real), imag=tf.squeeze(imag), name=name+"_complex_k_space")
        rec_image_complex = tf.expand_dims(tf.ifft2d(complex_k_space_label), axis=3)
        rec_image = tf.reshape(tf.abs(rec_image_complex), shape=[-1, 256, 256, 1])

        # Shifting
        top, bottom = tf.split(rec_image, num_or_size_splits=2, axis=1)
        top_left, top_right = tf.split(top, num_or_size_splits=2, axis=2)
        bottom_left, bottom_right = tf.split(bottom, num_or_size_splits=2, axis=2)

        top_shift = tf.concat(axis=2, values=[bottom_right, bottom_left])
        bottom_shift = tf.concat(axis=2, values=[top_right, top_left])
        shifted_image = tf.concat(axis=1, values=[top_shift, bottom_shift])
        return shifted_image
Example #26
    def run(self, hr_img, lr_img):
        self.train_op = tf.train.AdamOptimizer(
            learning_rate=self.learning_rate).minimize(self.loss)
        self.sess.run(tf.global_variables_initializer())
        print('run: ->', hr_img.shape)
        # shape = np.zeros(hr_img.shape)
        # err_ = []
        # print(shape)
        for er in range(self.epoch):
            # image = tf.reshape(image,[image.shape[0],image.shape[1]])
            _, x = self.sess.run([self.train_op, self.loss],
                                 feed_dict={
                                     self.images: lr_img,
                                     self.label: hr_img
                                 })
            # source = self.sess.run([self.source_fft],feed_dict={self.images: lr_img, self.label:hr_img})
            # imshow_spectrum(np.squeeze(source))

            # _residual = self.sess.run([self.label_risidual],feed_dict={self.images: lr_img, self.label:hr_img})
            # _r = tf.abs(tf.ifft2d(np.squeeze(_residual)))
            # # imshow_spectrum(np.squeeze(_residual))
            # # print(np.abs(_residual))
            # plt_imshow(np.squeeze(self.sess.run(_r)))

            print(x)
        # w = self.sess.run([self.spectral_c1],feed_dict={self.images: lr_img, self.label:hr_img})
        # w = np.asarray(w)
        # # w =np.squeeze(w)
        # # w = w /(1e3*1e-5)
        # print(w)
        # print('----')
        # print(w[:,:,:,0])
        # # imshow_spectrum(w)

    # #
        result = self.pred.eval({self.images: lr_img, self.label: hr_img})
        result = np.squeeze(self.sess.run(tf.real(tf.ifft2d(result))))
        # result = result*255/(1e3*1e-5)
        # result = np.clip(result, 0.0, 255.0).astype(np.uint8)
        plt_imshow(((result)))
        print(np.abs(result))
Example #27
def ifft2c(im, name="ifft2c", do_orthonorm=False):
    """Centered iFFT2."""
    with tf.name_scope(name):
        im_out = im
        if do_orthonorm:
            fftscale = tf.sqrt(1.0 * im_out.get_shape().as_list()[-2] *
                               im_out.get_shape().as_list()[-3])
        else:
            fftscale = 1.0
        fftscale = tf.cast(fftscale, dtype=tf.complex64)
        if len(im.get_shape()) == 5:
            im_out = tf.transpose(im_out, [0, 3, 4, 1, 2])
            im_out = fftshift(im_out, axis=4)
            im_out = fftshift(im_out, axis=3)
        elif len(im.get_shape()) == 4:
            im_out = tf.transpose(im_out, [0, 3, 1, 2])
            im_out = fftshift(im_out, axis=3)
            im_out = fftshift(im_out, axis=2)
        else:
            im_out = tf.transpose(im_out, [2, 0, 1])
            im_out = fftshift(im_out, axis=2)
            im_out = fftshift(im_out, axis=1)

        with tf.device('/gpu:0'):
            # FFT is only supported on the GPU
            im_out = tf.ifft2d(im_out) * fftscale

        if len(im.get_shape()) == 5:
            im_out = fftshift(im_out, axis=4)
            im_out = fftshift(im_out, axis=3)
            im_out = tf.transpose(im_out, [0, 3, 4, 1, 2])
        elif len(im.get_shape()) == 4:
            im_out = fftshift(im_out, axis=3)
            im_out = fftshift(im_out, axis=2)
            im_out = tf.transpose(im_out, [0, 2, 3, 1])
        else:
            im_out = fftshift(im_out, axis=2)
            im_out = fftshift(im_out, axis=1)
            im_out = tf.transpose(im_out, [1, 2, 0])

    return im_out
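With do_orthonorm=True the sqrt(H*W) factor makes the inverse transform unitary (norm-preserving). A quick NumPy check of that scaling, independent of the TF code above:

import numpy as np

x = np.random.rand(8, 8) + 1j * np.random.rand(8, 8)
y = np.fft.ifft2(x) * np.sqrt(x.size)
assert np.isclose(np.linalg.norm(y), np.linalg.norm(x))  # Parseval with orthonormal scaling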
Example #28
def image_profile(image, img_size):
    image = tf.cast(image, dtype=tf.complex64)
    image = tf.reshape(image, [img_size, img_size])
    fft = tf.fft2d(image)
    congfft = tf.conj(fft)
    tot = fft * congfft
    ifft = tf.ifft2d(tot)
    autocorr = tf.abs(ifft) / (img_size * img_size)
    shape_at = tf.shape(autocorr)
    centrdImg = np.zeros([img_size, img_size])
    dm_hf = int(img_size / 2)
    topleft = tf.slice(autocorr, [0, 0], [dm_hf, dm_hf])
    topright = tf.slice(autocorr, [0, dm_hf], [dm_hf, dm_hf])
    bottomleft = tf.slice(autocorr, [dm_hf, 0], [dm_hf, dm_hf])
    bottomright = tf.slice(autocorr, [dm_hf, dm_hf], [dm_hf, dm_hf])
    bottomhalf = tf.concat([topright, topleft], 1)
    tophalf = tf.concat([bottomright, bottomleft], 1)
    centrdImg_tf = tf.concat([tophalf, bottomhalf], 0)
    center = [int(img_size / 2), int(img_size / 2)]
    image_prof, image_rad = radial_profile_tf(centrdImg_tf, center, img_size)
    return image_prof
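image_profile relies on the Wiener-Khinchin relation: the inverse FFT of the power spectrum fft * conj(fft) is the circular autocorrelation. A minimal NumPy check, using the fact that the zero-lag value equals the image's total energy:

import numpy as np

x = np.random.rand(16, 16)
F = np.fft.fft2(x)
auto = np.real(np.fft.ifft2(F * np.conj(F)))
assert np.isclose(auto[0, 0], np.sum(x ** 2))  # zero lag = sum of squares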
Example #29
    def dLdx(self, y, x, sigma=1.0):
        # TODO derive and validate
        """
        Args:
            x (tf.tensor): The input image
                shape is [None, width, height, channels],
                dtype is tf.float32
            y (tf.tensor): The outputs in k-space
                shape is [None, width, height, channels],
                dtype is tf.complex64

        Returns:
            (tf.tensor): the grad of the loss w.r.t x
                shape is [None, width, height, channels],
                dtype is tf.complex64
        """
        # gets the mask used in the forward process.
        # NOTE be careful here. will only return the correct mask if called
        # after the corresponding forward process
        y_ = self.mask * tf.fft2d(tf.cast(x, tf.complex64))  # 2-D transform of a complex cast; the original tf.fft(x) was 1-D and fed a float tensor
        return tf.ifft2d(y_ - y) / (sigma ** 2)
Example #30
def FhRh(freq, mask, name='FhRh', is_normalized=False):
	with tf.variable_scope(name+'_scope'):
		# Convert from 2 channel to complex number
		freq = tf_complex(freq)
		mask = tf_complex(mask) 

		# Under sample
		condition = tf.cast(tf.real(mask)>0.9, tf.bool)
		freq_full = freq
		freq_zero = tf.zeros_like(freq_full)
		freq_dest = tf.where(condition, freq_full, freq_zero, name='RfFf')

		# Inverse Fourier Transform
		image 	  = tf.ifft2d(freq_dest, name='FtRt')
		
		if is_normalized:
			image = tf.div(image, ((DIMX-1)*(DIMY-1)))

		# Convert from complex number to 2 channel
		image = tf_channel(image)
	return tf.identity(image, name)
Example #31
def sparisty_regularization(ss_epi_im, mask, thresholds, alpha, dec_fft, rec_fft, w_st):
    # Initialization
    f0 = tf.cast(ss_epi_im, tf.complex64, "epi_cast")
    mask = tf.cast(mask, tf.complex64, "mask_cast")
    g = f0 * mask
    with tf.name_scope("EPI_Initialization"):
        f0 = tf.ifft2d(tf.fft2d(f0) * dec_fft[-1] * rec_fft[-1]) # pre-filtering only using the low-pass filter
    niter = thresholds.shape[0]
    num_tiny = tf.constant(1e-6, tf.complex64)

    def condition(i, fi, fi_1, fi_2):
        return tf.less(i, niter)

    def body(i, fi, fi_1, fi_2):
        with tf.name_scope("Analysis_Trans"):
            x = fi + alpha * tf.multiply(mask, (g-fi))
            coeffs = tf.ifft2d(tf.multiply(tf.fft2d(x), dec_fft) )
        with tf.name_scope("Hard_Thresholding"):
            comp = tf.greater(tf.abs(coeffs), tf.multiply(thresholds[i], w_st) )
            coeffs = tf.multiply(tf.cast(comp, tf.complex64), coeffs)
        with tf.name_scope("Synthesis_Trans"):
            coeffs_fft = tf.multiply(tf.fft2d(coeffs), rec_fft )
            f_hat = tf.ifft2d(tf.reduce_sum(coeffs_fft, 1, keepdims = True))
        # two-step overrelaxation
        with tf.name_scope("Double_Overrelaxation"):
            beta1 = tf.divide( tf.reduce_sum((g - f_hat) * mask * (f_hat - fi_1), axis=[1, 2, 3], keepdims=True),
                                tf.reduce_sum((f_hat - fi_1) * mask * (f_hat - fi_1), axis=[1, 2, 3], keepdims=True) + num_tiny )
            beta1 = tf.clip_by_value(tf.cast(beta1, tf.float32), tf.constant(0, tf.float32), tf.constant(1, tf.float32))
            f_tilde = f_hat + tf.cast(beta1, tf.complex64) * (f_hat - fi_1)

            beta2 = tf.divide( tf.reduce_sum((g - f_tilde) * mask * (f_tilde - fi_2), axis=[1, 2, 3], keepdims=True), 
                                tf.reduce_sum((f_tilde - fi_2) * mask * (f_tilde - fi_2), axis=[1, 2, 3], keepdims=True) + num_tiny )
            beta2 = tf.clip_by_value(tf.cast(beta2, tf.float32), tf.constant(0, tf.float32), tf.constant(1, tf.float32))
            f_i_new = f_tilde + tf.cast(beta2, tf.complex64) * (f_tilde - fi_2)

        return tf.add(i, 1), f_i_new, fi, fi_1

    _, fi, _, _ = tf.while_loop(condition, body, [tf.constant(0), f0, f0, f0], name="Sparsity_Regularization")

    return tf.cast(fi, tf.float32)
Example #32
def recon_loss_L1_2chan_fixed3(y_true, y_pred):
    y_pred = tf.cast(tf.squeeze(y_pred, axis=0), dtype=tf.complex64)
    fft_img = tf.fft2d(y_pred[:, :, 0] + 1j * y_pred[:, :, 1])
    masked = tf.multiply(fft_img, mask)
    ifft = tf.ifft2d(masked)
    squeeze_ytrue = tf.cast(tf.squeeze(y_true, axis=0), dtype=tf.complex64)

    #New conjugate loss function
    imag_ytrue = squeeze_ytrue[:, :, 0] + 1j * squeeze_ytrue[:, :, 1]
    subtract = tf.cast(ifft - imag_ytrue, tf.complex64)
    conj = tf.cast(tf.real(subtract), dtype=tf.complex64) - 1j * tf.cast(
        tf.imag(subtract), dtype=tf.complex64)
    loss = tf.sqrt(tf.real(tf.multiply(subtract, conj)))

    loss = tf.reduce_sum(loss)

    print("ifft", ifft.shape)
    print("imag_ytrue", imag_ytrue.shape)
    print("squeeze_ytrue", squeeze_ytrue.shape)
    print("loss", loss.shape)

    return loss
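The "conjugate loss" simplifies to a complex-magnitude (L1-style) penalty: for z = ifft - imag_ytrue, z * conj(z) = |z|^2, so

\[
\text{loss} = \sum \sqrt{\operatorname{Re}\!\big(z \bar z\big)} = \sum |z|
\]

i.e. the sum of pointwise complex moduli of the reconstruction error.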
Example #33
 def feaIDFT(self, feapha, feaamp):
     input_shapes = feapha.get_shape().as_list()
     for i in range(self.batch_size):
         outfeachannel = []
         for j in range(input_shapes[-1]):
             out_angle = feapha[i, :, :, j]
             out_mag = feaamp[i, :, :, j]
             outcomplex = tf.complex(out_mag * tf.cos(out_angle),
                                     out_mag * tf.sin(out_angle))
             outcomplex = self.specshift(outcomplex)
             outfea = tf.abs(tf.ifft2d(outcomplex))
             outfea = tf.expand_dims(outfea, -1)
             if j == 0:
                 outfeachannel = outfea
             else:
                 outfeachannel = tf.concat([outfeachannel, outfea], axis=-1)
         outfeachannel = tf.expand_dims(outfeachannel, 0)
         if i == 0:
             complexfea = outfeachannel
         else:
             complexfea = tf.concat([complexfea, outfeachannel], axis=0)
     return complexfea
Example #34
    def get_reconstructed_image(self, real, imag, name=None):
        """
        :param real: real part of the k-space data (standardized below)
        :param imag: imaginary part of the k-space data (standardized below)
        :param name: name prefix for the created ops
        :return: magnitude reconstruction from the standardized k-space
        """
        factors = self.FLAGS.data_factors

        mu_r = np.float32(factors['mean']['k_space_real'])
        sigma_r = np.sqrt(np.float32(factors['variance']['k_space_real']))

        mu_i = np.float32(factors['mean']['k_space_imag'])
        sigma_i = np.sqrt(np.float32(factors['variance']['k_space_imag']))

        complex_k_space_label = tf.complex(real=(tf.squeeze(real) - mu_r) / sigma_r,
                                     imag=(tf.squeeze(imag) - mu_i) / sigma_i, name=name+"_complex_k_space")
        rec_image_complex = tf.expand_dims(tf.ifft2d(complex_k_space_label), axis=3)
        # import pdb
        # pdb.set_trace()
        rec_image = tf.reshape(tf.abs(rec_image_complex), shape=[-1, 256, 256, 1])
        return rec_image
Example #35
File: losses.py Project: nlaanait/stemdl
def thin_object(psi_k_re, psi_k_im, potential, summarize=True):
    # mask = np.zeros(psi_k_re.shape.as_list(), dtype=np.float32)
    # ratio = 0
    # if ratio == 0:
    #     center = slice(None, None) 
    # else:
    #     center = slice(int(ratio * mask.shape[-1]), int((1-ratio)* mask.shape[-1]))
    # mask[:,:,center,center] = 1.
    # mask = tf.constant(mask, dtype=tf.complex64)
    psi_x = fftshift(tf.ifft2d(tf.cast(psi_k_re, tf.complex64) * tf.exp( 1.j * tf.cast(psi_k_im, tf.complex64))))
    scan_range = psi_x.shape.as_list()[-1]//2
    vx, vy = np.linspace(-scan_range, scan_range, num=4), np.linspace(-scan_range, scan_range, num=4)
    X, Y = np.meshgrid(vx.astype(np.int), vy.astype(np.int))
    psi_x_stack = [tf.roll(psi_x, shift=[x,y], axis=[1,2]) for (x,y) in zip(X.flatten(), Y.flatten())]
    psi_x_stack = tf.concat(psi_x_stack, axis=1)
    pot_frac = tf.exp(1.j * tf.cast(potential, tf.complex64))
    psi_out = tf.fft2d(psi_x_stack * pot_frac / np.prod(psi_x.shape.as_list()))
    psi_out_mod = tf.cast(tf.abs(psi_out), tf.float32) ** 2
    psi_out_mod = tf.reduce_mean(psi_out_mod, axis=1, keep_dims=True)
    if summarize:
        tf.summary.image('Psi_k_out', tf.transpose(tf.abs(psi_out_mod)**0.25, perm=[0,2,3,1]), max_outputs=1)
        tf.summary.image('Psi_x_in', tf.transpose(tf.abs(psi_x)**0.25, perm=[0,2,3,1]), max_outputs=1)
    return psi_out_mod 
Example #36
 def _tfIFFT2D(self, x, use_gpu=False):
   with self.test_session(use_gpu=use_gpu):
     return tf.ifft2d(x).eval()
Example #37
File: ops.py Project: kestrelm/tfdeploy
 def test_IFFT2D(self):
     # only defined for gpu
     if DEVICE == GPU:
         t = tf.ifft2d(self.random(3, 4, complex=True))
         self.check(t)