def run_pd(im, samp_patt, wav, levels, n_iter, eta, sigma=0.5, tau=0.5, theta=1):
    """Perform one primal-dual reconstruction experiment.

    Builds the primal-dual graph via ``build_pd_graph``, maps the resulting
    wavelet coefficients back to image space, feeds the experiment
    parameters through the named placeholders and evaluates the graph in a
    fresh ``tf.Session``. The wall-clock time of the session block is
    printed.

    Args:
        im: 2-D image; ``im.shape[0]`` is used as the graph size N
            (assumes ``im`` is square N x N -- TODO confirm against
            ``build_pd_graph``).
        samp_patt: 2-D sampling mask fed to ``'sampling_pattern:0'``.
        wav: wavelet passed to ``build_pd_graph`` and ``idwt2d``.
        levels: number of wavelet decomposition levels.
        n_iter, eta, sigma, tau, theta: values fed to the placeholders
            ``'n_iter:0'``, ``'eta:0'``, ``'sigma:0'``, ``'tau:0'`` and
            ``'theta:0'`` respectively.

    Returns:
        ``np.abs`` of the reconstructed image with singleton axes squeezed.
    """
    N = im.shape[0]
    result_coeffs = build_pd_graph(N, wav, levels)
    # idwt2d is applied to the real and imaginary parts separately and the
    # two results are recombined into one complex image.
    real_idwt = idwt2d(tf.math.real(result_coeffs), wav, levels)
    imag_idwt = idwt2d(tf.math.imag(result_coeffs), wav, levels)
    node = tf.complex(real_idwt, imag_idwt)

    # Add a trailing channel axis. ``np.complex`` / ``np.bool`` were mere
    # aliases of the builtins and were removed in NumPy 1.24; the explicit
    # NumPy scalar types give the same dtypes on every NumPy version.
    im = np.expand_dims(im, -1).astype(np.complex128)
    samp_patt = np.expand_dims(samp_patt, -1).astype(np.bool_)

    start = time.time()
    with tf.Session() as sess:
        result = sess.run(node, feed_dict={
            'image:0': im,
            'sampling_pattern:0': samp_patt,
            'sigma:0': sigma,
            'eta:0': eta,
            'tau:0': tau,
            'theta:0': theta,
            'n_iter:0': n_iter,
        })
    end = time.time()
    print(end - start)

    return np.abs(np.squeeze(result))
def build_graph(tau):
    """Build a FISTA graph whose input is an image placeholder plus a perturbation variable.

    The perturbation ``rr = complex(rr_real, rr_imag)`` consists of two
    trainable variables, each initialised with ``tau * tf.random_uniform``
    noise of the image's shape. The perturbed image is sampled by an
    ``MRIOperator`` and reconstructed with FISTA on a LASSO gradient.

    Placeholders created (fed by name): ``'stab_lambda'`` (float32 scalar),
    ``'image'`` (complex64, [N, N, 1]), ``'sampling_pattern'`` (bool,
    [N, N, 1]) and ``'actual'`` (complex64, [N, N, 1]).

    Args:
        tau: scale of the random initial perturbation.

    Returns:
        Tuple ``(tf_rr_real, tf_rr_imag, tf_input, result_image, tf_obj,
        tf_adjoint)`` where ``tf_obj`` is
        ``l2_loss(|result_image - actual|) - stab_lambda * l2_loss(|rr|)``.
    """
    tf_lam = tf.placeholder(tf.float32, shape=(), name='stab_lambda')
    N = 128
    wav = db4
    levels = 5

    # Build FISTA graph
    tf_im = tf.placeholder(tf.complex64, shape=[N, N, 1], name='image')
    tf_samp_patt = tf.placeholder(tf.bool, shape=[N, N, 1], name='sampling_pattern')

    # Trainable perturbation added to the input image.
    tf_rr_real = tf.Variable(tau * tf.random_uniform(tf_im.shape), name='rr_real', trainable=True)
    tf_rr_imag = tf.Variable(tau * tf.random_uniform(tf_im.shape), name='rr_imag', trainable=True)
    tf_rr = tf.complex(tf_rr_real, tf_rr_imag, name='rr')
    tf_input = tf_im + tf_rr

    op = MRIOperator(tf_samp_patt, wav, levels)
    measurements = op.sample(tf_input)

    # Adjoint of the measurements. It is both returned (in image space) and
    # used as FISTA's initial iterate, so it is built only once.
    tf_adjoint_coeffs = op(measurements, adjoint=True)
    adj_real_idwt = idwt2d(tf.real(tf_adjoint_coeffs), wav, levels)
    adj_imag_idwt = idwt2d(tf.imag(tf_adjoint_coeffs), wav, levels)
    tf_adjoint = tf.complex(adj_real_idwt, adj_imag_idwt)

    gradient = LASSOGradient(op, measurements)
    alg = FISTA(gradient)
    initial_x = tf_adjoint_coeffs
    result_coeffs = alg.run(initial_x)

    # idwt2d works on real tensors, so transform real/imag parts separately.
    real_idwt = idwt2d(tf.real(result_coeffs), wav, levels)
    imag_idwt = idwt2d(tf.imag(result_coeffs), wav, levels)
    result_image = tf.complex(real_idwt, imag_idwt)
    # End build reconstruction graph

    # Start building objective function for adv noise.
    # 'actual' is meant to hold the reconstruction obtained without
    # adversarial noise.
    tf_solution = tf.placeholder(tf.complex64, shape=[N, N, 1], name='actual')
    tf_obj = (tf.nn.l2_loss(tf.abs(result_image - tf_solution))
              - tf_lam * tf.nn.l2_loss(tf.abs(tf_rr)))
    # End building objective function for adv noise

    return tf_rr_real, tf_rr_imag, tf_input, result_image, tf_obj, tf_adjoint
def build_full_pd_graph(N, wav, levels):
    """Return the primal-dual reconstruction graph mapped back to image space.

    ``build_pd_graph`` yields complex wavelet coefficients; the inverse 2-D
    wavelet transform is applied to their real and imaginary parts
    separately and the two halves are recombined into a complex tensor.

    Args:
        N: graph size forwarded to ``build_pd_graph``.
        wav: wavelet forwarded to ``build_pd_graph`` and ``idwt2d``.
        levels: number of wavelet decomposition levels.

    Returns:
        Complex image-domain output node.
    """
    coeffs = build_pd_graph(N, wav, levels)
    re_img, im_img = (
        idwt2d(component(coeffs), wav, levels)
        for component in (tf.math.real, tf.math.imag)
    )
    return tf.complex(re_img, im_img)
def build_graph(tau):
    """Build a square-root-LASSO graph whose input is an image plus a perturbation variable.

    The perturbation ``rr = complex(rr_real, rr_imag)`` consists of two
    trainable variables, each initialised with ``tau * tf.random_uniform``
    noise of the image's shape. The perturbed image is sampled by an
    ``MRIOperator`` and reconstructed with ``SquareRootLASSO`` using the
    ``SQLassoProx1`` / ``SQLassoProx2`` proximal operators.

    Placeholders created (fed by name): ``'stab_lambda'`` (float32 scalar),
    ``'image'`` (complex64, [N, N, 1]), ``'sampling_pattern'`` (bool,
    [N, N, 1]) and ``'actual'`` (complex64, [N, N, 1]).

    Args:
        tau: scale of the random initial perturbation.

    Returns:
        Tuple ``(tf_rr_real, tf_rr_imag, tf_input, result_image, tf_obj,
        tf_adjoint)`` where ``tf_obj`` is
        ``l2_loss(|result_image - actual|) - stab_lambda * l2_loss(|rr|)``.
    """
    N = 128
    wav = db4
    levels = 4

    tf_lam = tf.placeholder(tf.float32, shape=(), name='stab_lambda')

    # Build Primal-dual graph
    tf_im = tf.placeholder(tf.complex64, shape=[N, N, 1], name='image')
    tf_samp_patt = tf.placeholder(tf.bool, shape=[N, N, 1], name='sampling_pattern')

    # Trainable perturbation added to the input image.
    tf_rr_real = tf.Variable(tau * tf.random_uniform(tf_im.shape), name='rr_real', trainable=True)
    tf_rr_imag = tf.Variable(tau * tf.random_uniform(tf_im.shape), name='rr_imag', trainable=True)
    tf_rr = tf.complex(tf_rr_real, tf_rr_imag, name='rr')
    tf_input = tf_im + tf_rr

    op = MRIOperator(tf_samp_patt, wav, levels)
    measurements = op.sample(tf_input)

    # Adjoint of the measurements. It is both returned (in image space) and
    # used as the solver's initial iterate, so it is built only once.
    tf_adjoint_coeffs = op(measurements, adjoint=True)
    adj_real_idwt = idwt2d(tf.real(tf_adjoint_coeffs), wav, levels)
    adj_imag_idwt = idwt2d(tf.imag(tf_adjoint_coeffs), wav, levels)
    tf_adjoint = tf.complex(adj_real_idwt, adj_imag_idwt)

    prox1 = SQLassoProx1()
    prox2 = SQLassoProx2()
    alg = SquareRootLASSO(op, prox1, prox2, measurements)
    initial_x = tf_adjoint_coeffs
    result_coeffs = alg.run(initial_x)

    # idwt2d works on real tensors, so transform real/imag parts separately.
    real_idwt = idwt2d(tf.real(result_coeffs), wav, levels)
    imag_idwt = idwt2d(tf.imag(result_coeffs), wav, levels)
    result_image = tf.complex(real_idwt, imag_idwt)

    # Start building objective function for adv noise.
    # 'actual' is meant to hold the reconstruction obtained without
    # adversarial noise.
    tf_solution = tf.placeholder(tf.complex64, shape=[N, N, 1], name='actual')
    tf_obj = (tf.nn.l2_loss(tf.abs(result_image - tf_solution))
              - tf_lam * tf.nn.l2_loss(tf.abs(tf_rr)))
    # End building objective function for adv noise

    return tf_rr_real, tf_rr_imag, tf_input, result_image, tf_obj, tf_adjoint
def forward(self, x):
    """Map complex wavelet coefficients to subsampled, scaled 2-D Fourier data.

    Arguments:
        x: Tensor of complex wavelet coefficients.

    Returns:
        Tensor with the FFT of the reconstructed image, scaled by
        ``1/sqrt(tf.size(...))``, with entries outside ``self.samp_patt``
        set to zero.
    """
    wavelet, levels, dtype = self.wavelet, self.levels, self.dtype

    # The inverse wavelet transform is applied to the real and imaginary
    # parts separately, then recombined.
    image = tf.dtypes.complex(
        idwt2d(tf.math.real(x), wavelet, levels),
        idwt2d(tf.math.imag(x), wavelet, levels),
    )

    # fft2d acts on the two innermost axes, so move the last axis to the front.
    channels_first = tf.transpose(image, [2, 0, 1])
    n_elems = tf.cast(tf.size(channels_first), dtype)
    scale = tf.dtypes.complex(1.0 / tf.sqrt(n_elems), tf.constant(0.0, dtype=dtype))
    kspace = tf.transpose(scale * tf.signal.fft2d(channels_first), [1, 2, 0])

    # Subsampling: keep only the sampled frequencies.
    return tf.compat.v1.where_v2(self.samp_patt, kspace, tf.zeros_like(kspace))
def forward(x):
    """Map complex wavelet coefficients to masked, scaled 2-D Fourier data.

    Relies on the enclosing-scope names ``wavelet``, ``levels``, ``nsqrt``
    and ``mask``.
    """
    # The IDWT is computed separately for the real and imaginary parts:
    # tf.conv1d does not support complex numbers, and tfWavelets only
    # supports 3D-tensors.
    image = tf.complex(
        idwt2d(tf.real(x), wavelet, levels),
        idwt2d(tf.imag(x), wavelet, levels),
    )

    # FFT2 uses the two last dimensions: [height, width, channels] ->
    # [channels, height, width] before transforming.
    spectrum = tf.fft2d(tf.transpose(image, [2, 0, 1]))

    # TODO: Scaling? (currently a plain 1/nsqrt factor)
    kspace = tf.transpose(1. / nsqrt * spectrum, [1, 2, 0])  # back to [height, width, channels]

    return tf.where(mask, kspace, tf.zeros_like(kspace))
def build_fista_graph(N, wav, levels):
    """Build a FISTA/LASSO reconstruction graph and return its image-domain output.

    Creates the placeholders ``'image'`` (complex64, [N, N, 1]) and
    ``'sampling_pattern'`` (bool, [N, N, 1]), samples the image with an
    ``MRIOperator``, runs FISTA on a ``LASSOGradient`` starting from the
    operator applied in adjoint mode to the measurements, and maps the
    resulting wavelet coefficients back to a complex image.

    Args:
        N: spatial size of the square input.
        wav: wavelet used by the operator and by ``idwt2d``.
        levels: number of wavelet decomposition levels.

    Returns:
        Complex image-domain output node.
    """
    tf_im = tf.compat.v1.placeholder(tf.complex64, shape=[N, N, 1], name='image')
    tf_samp_patt = tf.compat.v1.placeholder(tf.bool, shape=[N, N, 1], name='sampling_pattern')

    op = MRIOperator(tf_samp_patt, wav, levels)
    measurements = op.sample(tf_im)

    # Initial iterate: the operator in adjoint mode applied to the
    # measurements. Built once and passed to the solver.
    initial_x = op(measurements, adjoint=True)

    gradient = LASSOGradient(op, measurements)
    alg = FISTA(gradient)
    result_coeffs = alg.run(initial_x)

    # idwt2d works on real tensors, so transform real/imag parts separately.
    real_idwt = idwt2d(tf.math.real(result_coeffs), wav, levels)
    imag_idwt = idwt2d(tf.math.imag(result_coeffs), wav, levels)
    output_node = tf.complex(real_idwt, imag_idwt)

    return output_node
wav = db4 wav_name = 'db4' eta = 0.00 # def run(eta): # result = run_pd(im, samp_patt, wav, levels, n_iter, eta=eta) # with open('../data/results/data.csv', 'a') as outfile: # outfile.write('{start_time}.png,{n_iter},{levels},{wav_name},{eta}\n'.format( # start_time = start_time, # n_iter = n_iter, # levels = levels, # wav_name = wav_name, # eta = eta)) result_coeffs = build_pd_graph(im.shape[0], wav, levels) real_idwt = idwt2d(tf.real(result_coeffs), wav, levels) imag_idwt = idwt2d(tf.imag(result_coeffs), wav, levels) node = tf.complex(real_idwt, imag_idwt) im = np.expand_dims(im, -1) samp_patt = np.expand_dims(samp_patt, -1) with tf.Session() as sess: start_time = datetime.now().strftime('%F_%T') start = time.time() result = sess.run(node, feed_dict={ 'image:0': im, 'sampling_pattern:0': samp_patt, 'sigma:0': 0.5,
# Build Primal-dual graph tf_im = tf.compat.v1.placeholder(cdtype, shape=[N, N, 1], name='image') tf_samp_patt = tf.compat.v1.placeholder(tf.bool, shape=[N, N, 1], name='sampling_pattern') # For the weighted l^1-norm pl_weights = tf.compat.v1.placeholder(dtype, shape=[N, N, 1], name='weights') tf_input = tf_im op = MRIOperator(tf_samp_patt, wav, levels, dtype=dtype) measurements = op.sample(tf_input) tf_adjoint_coeffs = op(measurements, adjoint=True) adj_real_idwt = idwt2d(tf.math.real(tf_adjoint_coeffs), wav, levels) adj_imag_idwt = idwt2d(tf.math.imag(tf_adjoint_coeffs), wav, levels) tf_adjoint = tf.complex(adj_real_idwt, adj_imag_idwt) prox1 = WeightedL1Prox(pl_weights, pl_lam * pl_tau, dtype=dtype) prox2 = SQLassoProx2(dtype=dtype) alg = SquareRootLASSO(op, prox1, prox2, measurements, sigma=pl_sigma, tau=pl_tau, lam=pl_lam, dtype=dtype)