def run_pd(im,
           samp_patt,
           wav,
           levels,
           n_iter,
           eta,
           sigma=0.5,
           tau=0.5,
           theta=1):
    """Perform experiment: run the primal-dual reconstruction once.

    Builds the graph with ``build_pd_graph``, maps its wavelet coefficients
    back to image space, feeds inputs and parameters to the placeholders by
    tensor name, evaluates the graph in a fresh ``tf.Session`` and prints the
    elapsed wall-clock time (session creation included).

    Args:
        im: Input image; ``im.shape[0]`` is used as the graph size ``N``.
        samp_patt: Sampling pattern with the same spatial shape as ``im``.
        wav: Wavelet passed to ``build_pd_graph`` and ``idwt2d``.
        levels: Number of wavelet levels.
        n_iter: Value fed to ``'n_iter:0'``.
        eta: Value fed to ``'eta:0'``.
        sigma: Value fed to ``'sigma:0'``.
        tau: Value fed to ``'tau:0'``.
        theta: Value fed to ``'theta:0'``.

    Returns:
        ``np.ndarray``: magnitude of the reconstruction with singleton
        dimensions squeezed out.
    """
    N = im.shape[0]
    result_coeffs = build_pd_graph(N, wav, levels)

    # idwt2d is applied to the real and imaginary parts separately and the
    # two results are recombined into one complex image.
    real_idwt = idwt2d(tf.math.real(result_coeffs), wav, levels)
    imag_idwt = idwt2d(tf.math.imag(result_coeffs), wav, levels)
    node = tf.complex(real_idwt, imag_idwt)

    # Add a trailing channel axis (the placeholders are presumably [N, N, 1],
    # as in the other graph builders -- confirm in build_pd_graph).
    # np.complex / np.bool were aliases of the builtins, deprecated in NumPy
    # 1.20 and removed in 1.24; complex128 / bool_ are the dtypes they yielded.
    im = np.expand_dims(im, -1).astype(np.complex128)
    samp_patt = np.expand_dims(samp_patt, -1).astype(np.bool_)

    start = time.time()
    with tf.Session() as sess:
        result = sess.run(node,
                          feed_dict={
                              'image:0': im,
                              'sampling_pattern:0': samp_patt,
                              'sigma:0': sigma,
                              'eta:0': eta,
                              'tau:0': tau,
                              'theta:0': theta,
                              'n_iter:0': n_iter
                          })
    end = time.time()
    print(end - start)
    return np.abs(np.squeeze(result))
# Example #2
def build_graph(tau):
    """Build a FISTA graph whose input is an image placeholder + a perturbation.

    The network input is ``image + rr``, where ``rr = rr_real + 1j * rr_imag``
    is a pair of trainable variables initialised as ``tau`` times
    ``tf.random_uniform`` samples. The perturbed input is sampled with
    ``MRIOperator``, reconstructed with FISTA on the LASSO gradient (starting
    from the adjoint of the measurements) and mapped back to image space.

    Placeholders, fed by name: ``'stab_lambda'`` (float32 scalar),
    ``'image'`` (complex64 ``[128, 128, 1]``), ``'sampling_pattern'``
    (bool ``[128, 128, 1]``) and ``'actual'`` (complex64 ``[128, 128, 1]``,
    the reconstruction obtained without adversarial noise).

    Args:
        tau: Scale of the perturbation's random initialisation.

    Returns:
        Tuple ``(rr_real, rr_imag, perturbed_input, reconstruction,
        objective, adjoint)`` where ``objective`` is
        ``l2_loss(|reconstruction - actual|) - stab_lambda * l2_loss(|rr|)``
        and ``adjoint`` is the image-space adjoint of the measurements.
    """
    tf_lam = tf.placeholder(tf.float32, shape=(), name='stab_lambda')

    N = 128
    wav = db4
    levels = 5

    # Build FISTA graph
    tf_im = tf.placeholder(tf.complex64, shape=[N, N, 1], name='image')
    tf_samp_patt = tf.placeholder(tf.bool,
                                  shape=[N, N, 1],
                                  name='sampling_pattern')

    # Trainable complex perturbation added to the image.
    tf_rr_real = tf.Variable(tau * tf.random_uniform(tf_im.shape),
                             name='rr_real',
                             trainable=True)
    tf_rr_imag = tf.Variable(tau * tf.random_uniform(tf_im.shape),
                             name='rr_imag',
                             trainable=True)
    tf_rr = tf.complex(tf_rr_real, tf_rr_imag, name='rr')

    tf_input = tf_im + tf_rr

    op = MRIOperator(tf_samp_patt, wav, levels)
    measurements = op.sample(tf_input)

    # Adjoint of the measurements in coefficient space. It is returned (mapped
    # to image space) and also serves as FISTA's initial iterate, so the
    # adjoint subgraph is built only once.
    tf_adjoint_coeffs = op(measurements, adjoint=True)
    adj_real_idwt = idwt2d(tf.real(tf_adjoint_coeffs), wav, levels)
    adj_imag_idwt = idwt2d(tf.imag(tf_adjoint_coeffs), wav, levels)
    tf_adjoint = tf.complex(adj_real_idwt, adj_imag_idwt)

    gradient = LASSOGradient(op, measurements)
    alg = FISTA(gradient)
    result_coeffs = alg.run(tf_adjoint_coeffs)

    real_idwt = idwt2d(tf.real(result_coeffs), wav, levels)
    imag_idwt = idwt2d(tf.imag(result_coeffs), wav, levels)
    result_image = tf.complex(real_idwt, imag_idwt)
    # End build FISTA graph

    # Objective for the adversarial perturbation: distance of the perturbed
    # reconstruction from the unperturbed one ('actual'), minus a penalty on
    # the size of rr weighted by stab_lambda.
    tf_solution = tf.placeholder(tf.complex64, shape=[N, N, 1], name='actual')
    tf_obj = (tf.nn.l2_loss(tf.abs(result_image - tf_solution))
              - tf_lam * tf.nn.l2_loss(tf.abs(tf_rr)))

    return tf_rr_real, tf_rr_imag, tf_input, result_image, tf_obj, tf_adjoint
def build_full_pd_graph(N, wav, levels):
    """Return the primal-dual reconstruction as a complex image-space tensor.

    The wavelet coefficients from ``build_pd_graph`` are split into real and
    imaginary parts, each part is inverted with ``idwt2d`` (real part first),
    and the two results are recombined with ``tf.complex``.
    """
    coeffs = build_pd_graph(N, wav, levels)
    halves = [idwt2d(extract(coeffs), wav, levels)
              for extract in (tf.math.real, tf.math.imag)]
    return tf.complex(*halves)
def build_graph(tau):
    """Build a square-root-LASSO graph whose input is image + a perturbation.

    The network input is ``image + rr``, where ``rr = rr_real + 1j * rr_imag``
    is a pair of trainable variables initialised as ``tau`` times
    ``tf.random_uniform`` samples. The perturbed input is sampled with
    ``MRIOperator``, reconstructed with ``SquareRootLASSO`` (starting from
    the adjoint of the measurements) and mapped back to image space.

    Placeholders, fed by name: ``'stab_lambda'`` (float32 scalar),
    ``'image'`` (complex64 ``[128, 128, 1]``), ``'sampling_pattern'``
    (bool ``[128, 128, 1]``) and ``'actual'`` (complex64 ``[128, 128, 1]``).

    Args:
        tau: Scale of the perturbation's random initialisation.

    Returns:
        Tuple ``(rr_real, rr_imag, perturbed_input, reconstruction,
        objective, adjoint)`` where ``objective`` is
        ``l2_loss(|reconstruction - actual|) - stab_lambda * l2_loss(|rr|)``
        and ``adjoint`` is the image-space adjoint of the measurements.
    """
    N = 128
    wav = db4
    levels = 4

    tf_lam = tf.placeholder(tf.float32, shape=(), name='stab_lambda')

    # Build Primal-dual graph
    tf_im = tf.placeholder(tf.complex64, shape=[N, N, 1], name='image')
    tf_samp_patt = tf.placeholder(tf.bool, shape=[N, N, 1], name='sampling_pattern')

    # Trainable complex perturbation added to the image.
    tf_rr_real = tf.Variable(tau * tf.random_uniform(tf_im.shape), name='rr_real', trainable=True)
    tf_rr_imag = tf.Variable(tau * tf.random_uniform(tf_im.shape), name='rr_imag', trainable=True)
    tf_rr = tf.complex(tf_rr_real, tf_rr_imag, name='rr')

    tf_input = tf_im + tf_rr

    op = MRIOperator(tf_samp_patt, wav, levels)
    measurements = op.sample(tf_input)

    # Adjoint of the measurements in coefficient space. It is returned (mapped
    # to image space) and also used as the solver's initial iterate, so the
    # adjoint subgraph is built only once.
    tf_adjoint_coeffs = op(measurements, adjoint=True)
    adj_real_idwt = idwt2d(tf.real(tf_adjoint_coeffs), wav, levels)
    adj_imag_idwt = idwt2d(tf.imag(tf_adjoint_coeffs), wav, levels)
    tf_adjoint = tf.complex(adj_real_idwt, adj_imag_idwt)

    prox1 = SQLassoProx1()
    prox2 = SQLassoProx2()
    alg = SquareRootLASSO(op, prox1, prox2, measurements)
    result_coeffs = alg.run(tf_adjoint_coeffs)

    real_idwt = idwt2d(tf.real(result_coeffs), wav, levels)
    imag_idwt = idwt2d(tf.imag(result_coeffs), wav, levels)
    result_image = tf.complex(real_idwt, imag_idwt)
    # End build primal-dual graph

    # Objective for the adversarial perturbation: distance of the perturbed
    # reconstruction from 'actual', minus a penalty on the size of rr
    # weighted by stab_lambda.
    tf_solution = tf.placeholder(tf.complex64, shape=[N, N, 1], name='actual')
    tf_obj = tf.nn.l2_loss(tf.abs(result_image - tf_solution)) - tf_lam * tf.nn.l2_loss(tf.abs(tf_rr))

    return tf_rr_real, tf_rr_imag, tf_input, result_image, tf_obj, tf_adjoint
# Example #5
    def forward(self, x):
        """Apply the forward operator to wavelet coefficients ``x``.

        The steps are:

        1. Take the inverse 2-D DWT of the real and imaginary parts and
           recombine them into one complex image.
        2. Take a 2-D FFT over the first two axes, scaled by
           ``1 / sqrt(tf.size(image))``.
        3. Zero every entry where ``self.samp_patt`` is False.

        NOTE(review): ``tf.size`` counts all elements, channels included. With
        more than one channel the scaling is therefore not ``1/sqrt(H*W)``;
        confirm this is intended.

        Arguments:
           x: Tensor
        """
        image = tf.dtypes.complex(
            idwt2d(tf.math.real(x), self.wavelet, self.levels),
            idwt2d(tf.math.imag(x), self.wavelet, self.levels))

        # fft2d acts on the innermost two axes, so move the last axis to the front.
        channels_first = tf.transpose(image, [2,0,1])
        n_elems = tf.cast(tf.size(channels_first), self.dtype)
        scale = tf.dtypes.complex(1.0/tf.sqrt(n_elems), tf.constant(0.0, dtype=self.dtype))
        spectrum = tf.transpose(scale * tf.signal.fft2d(channels_first), [1,2,0])

        # Subsampling: keep only entries selected by the sampling pattern.
        return tf.compat.v1.where_v2(self.samp_patt, spectrum, tf.zeros_like(spectrum))
# Example #6
    def forward(x):
        """Apply the forward operator to wavelet coefficients ``x``.

        The steps are:

        1. Take the inverse 2-D DWT of the real and imaginary parts
           separately.
        2. Take a 2-D FFT over the spatial axes, scaled by ``1 / nsqrt``.
        3. Zero every entry where ``mask`` is False.

        ``wavelet``, ``levels``, ``nsqrt`` and ``mask`` come from the
        enclosing scope.
        """
        # Compute the IDWT both for real and imaginary part
        # tf.conv1d does not support complex numbers
        # tfWavelets only supports 3D-tensors
        real_idwt = idwt2d(tf.math.real(x), wavelet, levels)
        imag_idwt = idwt2d(tf.math.imag(x), wavelet, levels)
        complex_idwt = tf.complex(real_idwt, imag_idwt)

        # FFT2 uses the two last dimensions
        result = tf.transpose(complex_idwt, [2, 0, 1])  # [channels, height, width]
        # TODO: Scaling? nsqrt is defined in the enclosing scope -- confirm it
        # gives the intended FFT normalisation.
        result = 1. / nsqrt * tf.signal.fft2d(result)
        result = tf.transpose(result, [1, 2, 0])  # [height, width, channels]

        # Subsampling: keep only entries where mask is True.
        result = tf.where(mask, result, tf.zeros_like(result))
        return result
def build_fista_graph(N, wav, levels):
    """Build a FISTA reconstruction graph and return its image-space output.

    Placeholders, fed by name: ``'image'`` (complex64 ``[N, N, 1]``) and
    ``'sampling_pattern'`` (bool ``[N, N, 1]``). The image is sampled with
    ``MRIOperator`` and reconstructed with FISTA on the LASSO gradient,
    starting from the adjoint of the measurements. The resulting wavelet
    coefficients are mapped back to image space with ``idwt2d``, applied to
    the real and imaginary parts separately.

    Args:
        N: Side length of the square image placeholder.
        wav: Wavelet used by ``MRIOperator`` and ``idwt2d``.
        levels: Number of wavelet levels.

    Returns:
        Complex tensor holding the reconstructed image.
    """
    tf_im = tf.compat.v1.placeholder(tf.complex64,
                                     shape=[N, N, 1],
                                     name='image')
    tf_samp_patt = tf.compat.v1.placeholder(tf.bool,
                                            shape=[N, N, 1],
                                            name='sampling_pattern')

    op = MRIOperator(tf_samp_patt, wav, levels)
    measurements = op.sample(tf_im)

    # FISTA's initial iterate: the adjoint of the measurements (built once).
    initial_x = op(measurements, adjoint=True)

    gradient = LASSOGradient(op, measurements)
    alg = FISTA(gradient)
    result_coeffs = alg.run(initial_x)

    real_idwt = idwt2d(tf.math.real(result_coeffs), wav, levels)
    imag_idwt = idwt2d(tf.math.imag(result_coeffs), wav, levels)
    output_node = tf.complex(real_idwt, imag_idwt)

    return output_node
# Example #8
# Script excerpt: build the primal-dual reconstruction graph once and prepare
# its inputs. `im`, `samp_patt` and `levels` are defined outside this excerpt.
wav = db4
wav_name = 'db4'  # referenced by the disabled run() helper below
eta = 0.00  # referenced by the disabled run() helper below

# Disabled helper: ran run_pd() for a given eta and appended one CSV row of
# run metadata to ../data/results/data.csv.
# def run(eta):
#     result = run_pd(im, samp_patt, wav, levels, n_iter, eta=eta)
#     with open('../data/results/data.csv', 'a') as outfile:
#         outfile.write('{start_time}.png,{n_iter},{levels},{wav_name},{eta}\n'.format(
#             start_time = start_time,
#             n_iter     = n_iter,
#             levels     = levels,
#             wav_name   = wav_name,
#             eta        = eta))

# Map the coefficients returned by build_pd_graph back to image space by
# applying idwt2d to the real and imaginary parts separately.
result_coeffs = build_pd_graph(im.shape[0], wav, levels)
real_idwt = idwt2d(tf.real(result_coeffs), wav, levels)
imag_idwt = idwt2d(tf.imag(result_coeffs), wav, levels)
node = tf.complex(real_idwt, imag_idwt)

# Add a trailing singleton axis to the image and the sampling pattern.
# NOTE(review): unlike run_pd(), no dtype cast is applied here -- confirm that
# `im` is already complex and `samp_patt` already boolean.
im = np.expand_dims(im, -1)
samp_patt = np.expand_dims(samp_patt, -1)

with tf.Session() as sess:
    start_time = datetime.now().strftime('%F_%T')
    start = time.time()

    result = sess.run(node,
                      feed_dict={
                          'image:0': im,
                          'sampling_pattern:0': samp_patt,
                          'sigma:0': 0.5,
# Build Primal-dual graph
# Script excerpt: cdtype, dtype, N, wav, levels and the placeholders pl_lam,
# pl_tau and pl_sigma are defined outside this excerpt.
tf_im = tf.compat.v1.placeholder(cdtype, shape=[N, N, 1], name='image')
tf_samp_patt = tf.compat.v1.placeholder(tf.bool,
                                        shape=[N, N, 1],
                                        name='sampling_pattern')

# For the weighted l^1-norm
pl_weights = tf.compat.v1.placeholder(dtype, shape=[N, N, 1], name='weights')

# The input is the unperturbed image (no trainable perturbation is added here).
tf_input = tf_im

op = MRIOperator(tf_samp_patt, wav, levels, dtype=dtype)
measurements = op.sample(tf_input)

# Image-space adjoint of the measurements: idwt2d applied to the real and
# imaginary parts of the adjoint coefficients separately.
tf_adjoint_coeffs = op(measurements, adjoint=True)
adj_real_idwt = idwt2d(tf.math.real(tf_adjoint_coeffs), wav, levels)
adj_imag_idwt = idwt2d(tf.math.imag(tf_adjoint_coeffs), wav, levels)
tf_adjoint = tf.complex(adj_real_idwt, adj_imag_idwt)

# Weighted-l^1 prox with second argument pl_lam * pl_tau (presumably the
# soft-threshold level -- confirm in WeightedL1Prox).
prox1 = WeightedL1Prox(pl_weights, pl_lam * pl_tau, dtype=dtype)
prox2 = SQLassoProx2(dtype=dtype)

# Square-root LASSO solver; step sizes and lambda are fed via placeholders.
alg = SquareRootLASSO(op,
                      prox1,
                      prox2,
                      measurements,
                      sigma=pl_sigma,
                      tau=pl_tau,
                      lam=pl_lam,
                      dtype=dtype)