Ejemplo n.º 1
0
def build_graph(tau):
    """Build a FISTA reconstruction graph fed by an image plus a perturbation.

    The perturbation ``rr`` is held in two trainable real variables
    (``rr_real`` and ``rr_imag``), each initialised to
    ``tau * tf.random_uniform`` noise with the image's shape, and combined
    into a complex tensor. The perturbed image is sampled by an
    ``MRIOperator`` (``db4`` wavelet, 5 levels, ``[128, 128, 1]`` input),
    reconstructed with FISTA on a LASSO gradient, and mapped back to image
    space with ``idwt2d`` applied to the real and imaginary parts.

    The placeholders ``image``, ``sampling_pattern``, ``actual`` and
    ``stab_lambda`` are created here but not returned, so callers must feed
    them by name.

    Args:
        tau: Scale of the uniform random initial perturbation.

    Returns:
        Tuple ``(tf_rr_real, tf_rr_imag, tf_input, result_image, tf_obj,
        tf_adjoint)`` containing:

        - the two perturbation variables;
        - the perturbed input image;
        - the reconstructed image;
        - the adversarial objective
          ``l2_loss(|result_image - actual|) - stab_lambda * l2_loss(|rr|)``;
        - the adjoint of the measurements mapped to image space.
    """
    tf_lam = tf.placeholder(tf.float32, shape=(), name='stab_lambda')

    N = 128
    wav = db4
    levels = 5

    # Build FISTA graph
    tf_im = tf.placeholder(tf.complex64, shape=[N, N, 1], name='image')
    tf_samp_patt = tf.placeholder(tf.bool,
                                  shape=[N, N, 1],
                                  name='sampling_pattern')

    # Perturbation: separate real/imaginary trainable variables, initialised
    # to tau-scaled uniform noise and combined into one complex tensor.
    tf_rr_real = tf.Variable(tau * tf.random_uniform(tf_im.shape),
                             name='rr_real',
                             trainable=True)
    tf_rr_imag = tf.Variable(tau * tf.random_uniform(tf_im.shape),
                             name='rr_imag',
                             trainable=True)
    tf_rr = tf.complex(tf_rr_real, tf_rr_imag, name='rr')

    tf_input = tf_im + tf_rr

    op = MRIOperator(tf_samp_patt, wav, levels)
    measurements = op.sample(tf_input)

    # Adjoint of the measurements. It is returned (as an image) and also used
    # as FISTA's starting point, so it is built only once.
    tf_adjoint_coeffs = op(measurements, adjoint=True)
    adj_real_idwt = idwt2d(tf.real(tf_adjoint_coeffs), wav, levels)
    adj_imag_idwt = idwt2d(tf.imag(tf_adjoint_coeffs), wav, levels)
    tf_adjoint = tf.complex(adj_real_idwt, adj_imag_idwt)

    gradient = LASSOGradient(op, measurements)
    alg = FISTA(gradient)
    result_coeffs = alg.run(tf_adjoint_coeffs)

    real_idwt = idwt2d(tf.real(result_coeffs), wav, levels)
    imag_idwt = idwt2d(tf.imag(result_coeffs), wav, levels)
    result_image = tf.complex(real_idwt, imag_idwt)
    # End build FISTA graph

    # Start building objective function for adv noise

    # The reconstruction obtained without adversarial noise.
    tf_solution = tf.placeholder(tf.complex64, shape=[N, N, 1], name='actual')

    # Reward deviation from the clean reconstruction, penalise perturbation size.
    tf_obj = tf.nn.l2_loss(
        tf.abs(result_image -
               tf_solution)) - tf_lam * tf.nn.l2_loss(tf.abs(tf_rr))
    # End building objective function for adv noise

    return tf_rr_real, tf_rr_imag, tf_input, result_image, tf_obj, tf_adjoint
Ejemplo n.º 2
0
def build_graph(tau):
    """Build a square-root LASSO reconstruction graph fed by a perturbed image.

    The perturbation ``rr`` is held in two trainable real variables
    (``rr_real`` and ``rr_imag``), each initialised to
    ``tau * tf.random_uniform`` noise with the image's shape, and combined
    into a complex tensor. The perturbed image is sampled by an
    ``MRIOperator`` (``db4`` wavelet, 4 levels, ``[128, 128, 1]`` input),
    reconstructed with ``SquareRootLASSO`` using ``SQLassoProx1`` and
    ``SQLassoProx2``, and mapped back to image space with ``idwt2d``.

    The placeholders ``image``, ``sampling_pattern``, ``actual`` and
    ``stab_lambda`` are created here but not returned, so callers must feed
    them by name.

    Args:
        tau: Scale of the uniform random initial perturbation.

    Returns:
        Tuple ``(tf_rr_real, tf_rr_imag, tf_input, result_image, tf_obj,
        tf_adjoint)`` containing:

        - the two perturbation variables;
        - the perturbed input image;
        - the reconstructed image;
        - the adversarial objective
          ``l2_loss(|result_image - actual|) - stab_lambda * l2_loss(|rr|)``;
        - the adjoint of the measurements mapped to image space.
    """
    N = 128
    wav = db4
    levels = 4

    tf_lam = tf.placeholder(tf.float32, shape=(), name='stab_lambda')

    # Build square-root LASSO graph
    tf_im = tf.placeholder(tf.complex64, shape=[N, N, 1], name='image')
    tf_samp_patt = tf.placeholder(tf.bool,
                                  shape=[N, N, 1],
                                  name='sampling_pattern')

    # Perturbation: separate real/imaginary trainable variables, initialised
    # to tau-scaled uniform noise and combined into one complex tensor.
    tf_rr_real = tf.Variable(tau * tf.random_uniform(tf_im.shape),
                             name='rr_real',
                             trainable=True)
    tf_rr_imag = tf.Variable(tau * tf.random_uniform(tf_im.shape),
                             name='rr_imag',
                             trainable=True)
    tf_rr = tf.complex(tf_rr_real, tf_rr_imag, name='rr')

    tf_input = tf_im + tf_rr

    op = MRIOperator(tf_samp_patt, wav, levels)
    measurements = op.sample(tf_input)

    # Adjoint of the measurements. It is returned (as an image) and also used
    # as the solver's starting point, so it is built only once.
    tf_adjoint_coeffs = op(measurements, adjoint=True)
    adj_real_idwt = idwt2d(tf.real(tf_adjoint_coeffs), wav, levels)
    adj_imag_idwt = idwt2d(tf.imag(tf_adjoint_coeffs), wav, levels)
    tf_adjoint = tf.complex(adj_real_idwt, adj_imag_idwt)

    prox1 = SQLassoProx1()
    prox2 = SQLassoProx2()
    alg = SquareRootLASSO(op, prox1, prox2, measurements)

    result_coeffs = alg.run(tf_adjoint_coeffs)

    real_idwt = idwt2d(tf.real(result_coeffs), wav, levels)
    imag_idwt = idwt2d(tf.imag(result_coeffs), wav, levels)
    result_image = tf.complex(real_idwt, imag_idwt)

    # Start building objective function for adv noise
    tf_solution = tf.placeholder(tf.complex64, shape=[N, N, 1], name='actual')

    # Reward deviation from the reference output, penalise perturbation size.
    tf_obj = tf.nn.l2_loss(
        tf.abs(result_image -
               tf_solution)) - tf_lam * tf.nn.l2_loss(tf.abs(tf_rr))
    # End building objective function for adv noise

    return tf_rr_real, tf_rr_imag, tf_input, result_image, tf_obj, tf_adjoint
Ejemplo n.º 3
0
def build_fista_graph(N, wav, levels):
    """Build a FISTA (LASSO) reconstruction graph and return its image node.

    Creates the ``image`` (complex64) and ``sampling_pattern`` (bool)
    placeholders of shape ``[N, N, 1]``. The image is sampled with an
    ``MRIOperator``, FISTA is run on a ``LASSOGradient`` starting from the
    adjoint of the measurements, and the resulting coefficients are mapped
    back to image space with ``idwt2d`` on the real and imaginary parts.

    Args:
        N: Side length of the square input image.
        wav: Wavelet passed to ``MRIOperator`` and ``idwt2d``.
        levels: Number of wavelet levels.

    Returns:
        Complex tensor combining the inverse wavelet transforms of the real
        and imaginary parts of the FISTA result.
    """
    tf_im = tf.compat.v1.placeholder(tf.complex64,
                                     shape=[N, N, 1],
                                     name='image')
    tf_samp_patt = tf.compat.v1.placeholder(tf.bool,
                                            shape=[N, N, 1],
                                            name='sampling_pattern')

    op = MRIOperator(tf_samp_patt, wav, levels)
    measurements = op.sample(tf_im)

    # Starting point: adjoint of the measurements (built once and reused).
    initial_x = op(measurements, adjoint=True)

    gradient = LASSOGradient(op, measurements)
    alg = FISTA(gradient)

    result_coeffs = alg.run(initial_x)
    real_idwt = idwt2d(tf.math.real(result_coeffs), wav, levels)
    imag_idwt = idwt2d(tf.math.imag(result_coeffs), wav, levels)
    output_node = tf.complex(real_idwt, imag_idwt)

    return output_node
Ejemplo n.º 4
0
def build_pd_graph(N, wav, levels):
    """Build a primal-dual reconstruction graph and return its coefficient output.

    Creates the ``image`` (complex64) and ``sampling_pattern`` (bool)
    placeholders of shape ``[N, N, 1]``. These static shapes are needed for
    the wavelet transform to apply. The image is sampled with an
    ``MRIOperator``, and ``PrimalDual`` is run with the ``BPDNFStar`` and
    ``BPDNG`` proximal operators, starting from the adjoint of the
    measurements.

    Args:
        N: Side length of the square input image.
        wav: Wavelet passed to ``MRIOperator``.
        levels: Number of wavelet levels.

    Returns:
        The tensor produced by ``PrimalDual.run``. These are coefficients in
        the operator's domain; no inverse wavelet transform is applied here.
    """
    image = tf.compat.v1.placeholder(
        tf.complex64, shape=[N, N, 1], name='image')
    mask = tf.compat.v1.placeholder(
        tf.bool, shape=[N, N, 1], name='sampling_pattern')

    sampler = MRIOperator(mask, wav, levels)
    samples = sampler.sample(image)

    # Arguments evaluate left to right: f* prox, g prox, then the solver.
    solver = PrimalDual(sampler, BPDNFStar(samples), BPDNG())

    # Warm start from the adjoint of the measurements.
    return solver.run(sampler(samples, adjoint=True))
# Module-level setup for a weighted-l1 square-root LASSO reconstruction graph.
# NOTE(review): dtype, cdtype, N, wav and levels are defined outside this view.

# Scalar solver parameters, fed at run time by name.
pl_sigma = tf.compat.v1.placeholder(dtype, shape=(), name='sigma')
pl_tau = tf.compat.v1.placeholder(dtype, shape=(), name='tau')
pl_lam = tf.compat.v1.placeholder(dtype, shape=(), name='lambda')

# Build square-root LASSO graph: input image and sampling mask placeholders.
tf_im = tf.compat.v1.placeholder(cdtype, shape=[N, N, 1], name='image')
tf_samp_patt = tf.compat.v1.placeholder(tf.bool,
                                        shape=[N, N, 1],
                                        name='sampling_pattern')

# For the weighted l^1-norm: weights with the same [N, N, 1] shape as the image.
pl_weights = tf.compat.v1.placeholder(dtype, shape=[N, N, 1], name='weights')

# No perturbation is added here; the image is fed to the operator directly.
tf_input = tf_im

op = MRIOperator(tf_samp_patt, wav, levels, dtype=dtype)
measurements = op.sample(tf_input)

# Adjoint of the measurements, mapped to image space by applying idwt2d to
# the real and imaginary parts separately and recombining them.
tf_adjoint_coeffs = op(measurements, adjoint=True)
adj_real_idwt = idwt2d(tf.math.real(tf_adjoint_coeffs), wav, levels)
adj_imag_idwt = idwt2d(tf.math.imag(tf_adjoint_coeffs), wav, levels)
tf_adjoint = tf.complex(adj_real_idwt, adj_imag_idwt)

# prox1: weighted l1 prox, parameterised by the weights and lam * tau.
# prox2: the second proximal operator consumed by SquareRootLASSO below.
prox1 = WeightedL1Prox(pl_weights, pl_lam * pl_tau, dtype=dtype)
prox2 = SQLassoProx2(dtype=dtype)

alg = SquareRootLASSO(op,
                      prox1,
                      prox2,
                      measurements,
                      sigma=pl_sigma,