def build_graph(tau):
    """Build a FISTA reconstruction graph fed by an image plus a trainable
    perturbation, together with an objective for adversarial noise.

    The graph samples ``image + rr`` with an ``MRIOperator`` and runs FISTA
    with a ``LASSOGradient``. It then scores the perturbation ``rr`` with::

        tf_obj = l2_loss(|result_image - actual|)
                 - stab_lambda * l2_loss(|rr|)

    ``tf.nn.l2_loss`` is half the sum of squares.

    Placeholders created here; callers feed them by name:
        'image'            -- complex64, shape [N, N, 1], the clean input.
        'sampling_pattern' -- bool, shape [N, N, 1], passed to MRIOperator.
        'stab_lambda'      -- float32 scalar weighting the perturbation norm.
        'actual'           -- complex64, shape [N, N, 1], the reference the
                              reconstruction is compared against.
                              NOTE(review): the original comment says this is
                              the reconstruction without adversarial noise.

    Args:
        tau: Scale of the random initialisation of the perturbation. Both its
            real and imaginary parts start as ``tau * random_uniform``, which
            draws from [0, 1) by default.

    Returns:
        Tuple ``(tf_rr_real, tf_rr_imag, tf_input, result_image, tf_obj,
        tf_adjoint)``:
          - the real and imaginary perturbation variables;
          - the perturbed input ``image + rr``;
          - the FISTA reconstruction as an image;
          - the objective;
          - the image obtained by applying the operator's adjoint to the
            measurements.
    """
    N = 128
    wav = db4
    levels = 5

    def _to_image(coeffs):
        # Invert the wavelet transform on the real and imaginary parts
        # separately, then recombine them into a complex image.
        real_part = idwt2d(tf.real(coeffs), wav, levels)
        imag_part = idwt2d(tf.imag(coeffs), wav, levels)
        return tf.complex(real_part, imag_part)

    tf_lam = tf.placeholder(tf.float32, shape=(), name='stab_lambda')

    # Build FISTA graph
    tf_im = tf.placeholder(tf.complex64, shape=[N, N, 1], name='image')
    tf_samp_patt = tf.placeholder(tf.bool, shape=[N, N, 1],
                                  name='sampling_pattern')

    # Trainable additive perturbation rr = rr_real + i * rr_imag.
    tf_rr_real = tf.Variable(tau * tf.random_uniform(tf_im.shape),
                             name='rr_real', trainable=True)
    tf_rr_imag = tf.Variable(tau * tf.random_uniform(tf_im.shape),
                             name='rr_imag', trainable=True)
    tf_rr = tf.complex(tf_rr_real, tf_rr_imag, name='rr')
    tf_input = tf_im + tf_rr

    op = MRIOperator(tf_samp_patt, wav, levels)
    measurements = op.sample(tf_input)

    # The adjoint of the measurements is both returned (as an image) and used
    # as FISTA's starting point. Build this sub-graph once and reuse it.
    tf_adjoint_coeffs = op(measurements, adjoint=True)
    tf_adjoint = _to_image(tf_adjoint_coeffs)

    gradient = LASSOGradient(op, measurements)
    alg = FISTA(gradient)
    result_coeffs = alg.run(tf_adjoint_coeffs)
    result_image = _to_image(result_coeffs)
    # End build FISTA graph

    # Objective for the adversarial noise: reward deviation of the
    # reconstruction from 'actual' and penalise the size of the perturbation.
    tf_solution = tf.placeholder(tf.complex64, shape=[N, N, 1], name='actual')
    tf_obj = (tf.nn.l2_loss(tf.abs(result_image - tf_solution))
              - tf_lam * tf.nn.l2_loss(tf.abs(tf_rr)))

    return tf_rr_real, tf_rr_imag, tf_input, result_image, tf_obj, tf_adjoint
def build_graph(tau):
    """Build a square-root LASSO reconstruction graph fed by an image plus a
    trainable perturbation, together with an objective for adversarial noise.

    The graph samples ``image + rr`` with an ``MRIOperator`` and reconstructs
    it with ``SquareRootLASSO`` using the ``SQLassoProx1``/``SQLassoProx2``
    proximal operators. It then scores the perturbation ``rr`` with::

        tf_obj = l2_loss(|result_image - actual|)
                 - stab_lambda * l2_loss(|rr|)

    ``tf.nn.l2_loss`` is half the sum of squares.

    Placeholders created here; callers feed them by name:
        'image'            -- complex64, shape [N, N, 1], the clean input.
        'sampling_pattern' -- bool, shape [N, N, 1], passed to MRIOperator.
        'stab_lambda'      -- float32 scalar weighting the perturbation norm.
        'actual'           -- complex64, shape [N, N, 1], the reference the
                              reconstruction is compared against.

    Args:
        tau: Scale of the random initialisation of the perturbation. Both its
            real and imaginary parts start as ``tau * random_uniform``, which
            draws from [0, 1) by default.

    Returns:
        Tuple ``(tf_rr_real, tf_rr_imag, tf_input, result_image, tf_obj,
        tf_adjoint)``:
          - the real and imaginary perturbation variables;
          - the perturbed input ``image + rr``;
          - the reconstruction as an image;
          - the objective;
          - the image obtained by applying the operator's adjoint to the
            measurements.
    """
    N = 128
    wav = db4
    levels = 4

    def _to_image(coeffs):
        # Invert the wavelet transform on the real and imaginary parts
        # separately, then recombine them into a complex image.
        real_part = idwt2d(tf.real(coeffs), wav, levels)
        imag_part = idwt2d(tf.imag(coeffs), wav, levels)
        return tf.complex(real_part, imag_part)

    tf_lam = tf.placeholder(tf.float32, shape=(), name='stab_lambda')

    # Build the square-root LASSO reconstruction graph
    tf_im = tf.placeholder(tf.complex64, shape=[N, N, 1], name='image')
    tf_samp_patt = tf.placeholder(tf.bool, shape=[N, N, 1],
                                  name='sampling_pattern')

    # Trainable additive perturbation rr = rr_real + i * rr_imag.
    tf_rr_real = tf.Variable(tau * tf.random_uniform(tf_im.shape),
                             name='rr_real', trainable=True)
    tf_rr_imag = tf.Variable(tau * tf.random_uniform(tf_im.shape),
                             name='rr_imag', trainable=True)
    tf_rr = tf.complex(tf_rr_real, tf_rr_imag, name='rr')
    tf_input = tf_im + tf_rr

    op = MRIOperator(tf_samp_patt, wav, levels)
    measurements = op.sample(tf_input)

    # The adjoint of the measurements is both returned (as an image) and used
    # as the solver's starting point. Build this sub-graph once and reuse it.
    tf_adjoint_coeffs = op(measurements, adjoint=True)
    tf_adjoint = _to_image(tf_adjoint_coeffs)

    prox1 = SQLassoProx1()
    prox2 = SQLassoProx2()
    alg = SquareRootLASSO(op, prox1, prox2, measurements)
    result_coeffs = alg.run(tf_adjoint_coeffs)
    result_image = _to_image(result_coeffs)

    # Objective for the adversarial noise: reward deviation of the
    # reconstruction from 'actual' and penalise the size of the perturbation.
    tf_solution = tf.placeholder(tf.complex64, shape=[N, N, 1], name='actual')
    tf_obj = (tf.nn.l2_loss(tf.abs(result_image - tf_solution))
              - tf_lam * tf.nn.l2_loss(tf.abs(tf_rr)))
    # End building objective function for adv noise

    return tf_rr_real, tf_rr_imag, tf_input, result_image, tf_obj, tf_adjoint
def build_fista_graph(N, wav, levels):
    """Build a FISTA reconstruction graph and return its output image node.

    Placeholders created here; callers feed them by name:
        'image'            -- complex64, shape [N, N, 1], the input image.
        'sampling_pattern' -- bool, shape [N, N, 1], passed to MRIOperator.

    Args:
        N: Side length of the square image.
        wav: Wavelet passed to ``MRIOperator`` and ``idwt2d``.
        levels: Number of wavelet levels passed to ``MRIOperator`` and
            ``idwt2d``.

    Returns:
        Complex tensor formed from the inverse wavelet transform of the real
        and imaginary parts of FISTA's output coefficients.
    """
    tf_im = tf.compat.v1.placeholder(tf.complex64, shape=[N, N, 1],
                                     name='image')
    tf_samp_patt = tf.compat.v1.placeholder(tf.bool, shape=[N, N, 1],
                                            name='sampling_pattern')

    op = MRIOperator(tf_samp_patt, wav, levels)
    measurements = op.sample(tf_im)

    # Start FISTA from the adjoint of the measurements. Build that node once
    # and pass it to run(), rather than constructing a duplicate adjoint.
    initial_x = op(measurements, adjoint=True)
    gradient = LASSOGradient(op, measurements)
    alg = FISTA(gradient)
    result_coeffs = alg.run(initial_x)

    # Invert the wavelet transform on the real and imaginary parts separately.
    real_idwt = idwt2d(tf.math.real(result_coeffs), wav, levels)
    imag_idwt = idwt2d(tf.math.imag(result_coeffs), wav, levels)
    output_node = tf.complex(real_idwt, imag_idwt)
    return output_node
def build_pd_graph(N, wav, levels):
    """Construct the primal-dual (BPDN) reconstruction graph.

    Placeholders created here; callers feed them by name:
        'image'            -- complex64, shape [N, N, 1], the input image.
        'sampling_pattern' -- bool, shape [N, N, 1], passed to MRIOperator.

    Args:
        N: Side length of the square image.
        wav: Wavelet passed to ``MRIOperator``.
        levels: Number of wavelet levels passed to ``MRIOperator``.

    Returns:
        The tensor produced by ``PrimalDual.run``, which holds the result
        coefficients. No inverse wavelet transform is applied here.
    """
    # Static shapes are required so the wavelet transform can be applied.
    image_shape = [N, N, 1]
    image_ph = tf.compat.v1.placeholder(tf.complex64, shape=image_shape,
                                        name='image')
    mask_ph = tf.compat.v1.placeholder(tf.bool, shape=image_shape,
                                       name='sampling_pattern')

    sampling_op = MRIOperator(mask_ph, wav, levels)
    samples = sampling_op.sample(image_ph)

    # Proximal operators are built in the same order as before:
    # the F* prox first, then the G prox.
    solver = PrimalDual(sampling_op, BPDNFStar(samples), BPDNG())

    # Warm-start from the adjoint of the measurements.
    return solver.run(sampling_op(samples, adjoint=True))
pl_sigma = tf.compat.v1.placeholder(dtype, shape=(), name='sigma') pl_tau = tf.compat.v1.placeholder(dtype, shape=(), name='tau') pl_lam = tf.compat.v1.placeholder(dtype, shape=(), name='lambda') # Build Primal-dual graph tf_im = tf.compat.v1.placeholder(cdtype, shape=[N, N, 1], name='image') tf_samp_patt = tf.compat.v1.placeholder(tf.bool, shape=[N, N, 1], name='sampling_pattern') # For the weighted l^1-norm pl_weights = tf.compat.v1.placeholder(dtype, shape=[N, N, 1], name='weights') tf_input = tf_im op = MRIOperator(tf_samp_patt, wav, levels, dtype=dtype) measurements = op.sample(tf_input) tf_adjoint_coeffs = op(measurements, adjoint=True) adj_real_idwt = idwt2d(tf.math.real(tf_adjoint_coeffs), wav, levels) adj_imag_idwt = idwt2d(tf.math.imag(tf_adjoint_coeffs), wav, levels) tf_adjoint = tf.complex(adj_real_idwt, adj_imag_idwt) prox1 = WeightedL1Prox(pl_weights, pl_lam * pl_tau, dtype=dtype) prox2 = SQLassoProx2(dtype=dtype) alg = SquareRootLASSO(op, prox1, prox2, measurements, sigma=pl_sigma,