def build_graph(args, a_r, b_r, a2b_s, b2a_s):
    with tf.device('/gpu:{}'.format(args.gpu)):
        # Generators: translations and cycle reconstructions
        a2b = g_net(a_r, 'a2b')
        b2a = g_net(b_r, 'b2a')
        a2b2a = g_net(a2b, 'b2a', reuse=True)
        b2a2b = g_net(b2a, 'a2b', reuse=True)
        cvt = (a2b, b2a, a2b2a, b2a2b)

        # Discriminators on real images, fresh translations, and pooled samples
        a_d = d_net(a_r, 'a')
        b2a_d = d_net(b2a, 'a', reuse=True)
        b2a_s_d = d_net(b2a_s, 'a', reuse=True)
        b_d = d_net(b_r, 'b')
        a2b_d = d_net(a2b, 'b', reuse=True)
        a2b_s_d = d_net(a2b_s, 'b', reuse=True)

        # Generator losses: LSGAN adversarial terms plus cycle-consistency
        g_loss_a2b = tf.identity(ops.l2_loss(a2b_d, tf.ones_like(a2b_d)), name='g_loss_a2b')
        g_loss_b2a = tf.identity(ops.l2_loss(b2a_d, tf.ones_like(b2a_d)), name='g_loss_b2a')
        cyc_loss_a = tf.identity(ops.l1_loss(a_r, a2b2a) * 10.0, name='cyc_loss_a')
        cyc_loss_b = tf.identity(ops.l1_loss(b_r, b2a2b) * 10.0, name='cyc_loss_b')
        g_loss = g_loss_a2b + g_loss_b2a + cyc_loss_a + cyc_loss_b

        # Discriminator losses: real vs. pooled fake samples
        d_loss_a_r = ops.l2_loss(a_d, tf.ones_like(a_d))
        d_loss_b2a_s = ops.l2_loss(b2a_s_d, tf.zeros_like(b2a_s_d))
        d_loss_a = tf.identity((d_loss_a_r + d_loss_b2a_s) / 2., name='d_loss_a')
        d_loss_b_r = ops.l2_loss(b_d, tf.ones_like(b_d))
        d_loss_a2b_s = ops.l2_loss(a2b_s_d, tf.zeros_like(a2b_s_d))
        d_loss_b = tf.identity((d_loss_b_r + d_loss_a2b_s) / 2., name='d_loss_b')

        # Summaries
        g_sum = ops.summary_tensors([g_loss_a2b, g_loss_b2a, cyc_loss_a, cyc_loss_b])
        d_sum_a = ops.summary(d_loss_a)
        d_sum_b = ops.summary(d_loss_b)
        sum_ = (g_sum, d_sum_a, d_sum_b)

        # Optimizers, each restricted to its own variable scope
        all_var = tf.trainable_variables()
        g_var = [var for var in all_var if 'a2b_g' in var.name or 'b2a_g' in var.name]
        d_a_var = [var for var in all_var if 'a_d' in var.name]
        d_b_var = [var for var in all_var if 'b_d' in var.name]
        g_tr_op = tf.train.AdamOptimizer(args.lr, beta1=args.beta1).minimize(g_loss, var_list=g_var)
        d_tr_op_a = tf.train.AdamOptimizer(args.lr, beta1=args.beta1).minimize(d_loss_a, var_list=d_a_var)
        d_tr_op_b = tf.train.AdamOptimizer(args.lr, beta1=args.beta1).minimize(d_loss_b, var_list=d_b_var)
        tr_op = (g_tr_op, d_tr_op_a, d_tr_op_b)

    return cvt, sum_, tr_op
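# The pooled inputs a2b_s / b2a_s above are meant to be fed with previously
# generated translations rather than only the freshest batch, the usual
# history-buffer trick for stabilizing CycleGAN discriminators. Below is a
# minimal sketch of such a buffer; the class name HistoryPool and its 50-item
# default are assumptions, not part of this repository.
import copy
import random

import numpy as np


class HistoryPool(object):
    """Keep up to max_size generated images and return a mix of old and new ones."""

    def __init__(self, max_size=50):
        self.max_size = max_size
        self.items = []

    def __call__(self, in_items):
        # in_items: a batch of generated images (anything indexable per sample).
        if self.max_size <= 0:
            return in_items
        out_items = []
        for item in in_items:
            if len(self.items) < self.max_size:
                # Pool not full yet: store the new item and return it as-is.
                self.items.append(item)
                out_items.append(item)
            elif random.random() > 0.5:
                # Return a stored item and replace it with the new one.
                idx = random.randint(0, self.max_size - 1)
                out_items.append(copy.copy(self.items[idx]))
                self.items[idx] = item
            else:
                out_items.append(item)
        return np.asarray(out_items)


# Hypothetical usage: one pool per translation direction, e.g.
# a2b_pool, b2a_pool = HistoryPool(50), HistoryPool(50); then feed
# a2b_s with a2b_pool(fresh_a2b) and b2a_s with b2a_pool(fresh_b2a).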
def build_networks():
    with tf.device('/gpu:%d' % args.gpu_id):
        # Nodes
        a_real = tf.placeholder(tf.float32, shape=[None, args.crop_size, args.crop_size, 3])
        b_real = tf.placeholder(tf.float32, shape=[None, args.crop_size, args.crop_size, 3])
        a2b_sample = tf.placeholder(tf.float32, shape=[None, args.crop_size, args.crop_size, 3])
        b2a_sample = tf.placeholder(tf.float32, shape=[None, args.crop_size, args.crop_size, 3])

        a2b1 = models.generator(a_real, 'a2b')
        b2a1 = models.generator(b_real, 'b2a')
        if args.transform_twice:  # a-b-c
            a2b = models.generator(a2b1, 'a2b', reuse=True)
            b2a = models.generator(b2a1, 'b2a', reuse=True)
        else:
            a2b = a2b1
            b2a = b2a1
        b2a2b = models.generator(b2a, 'a2b', reuse=True)
        a2b2a = models.generator(a2b, 'b2a', reuse=True)
        if args.transform_twice:  # a-b-c
            b2a2b = models.generator(b2a2b, 'a2b', reuse=True)
            a2b2a = models.generator(a2b2a, 'b2a', reuse=True)

        # mod1: extra discriminator outputs so each discriminator also learns to
        # reject real samples from the other domain.
        a_dis = models.discriminator(a_real, 'a')
        a_from_b_dis = models.discriminator(b_real, 'a', reuse=True)  # mod1
        b2a_dis = models.discriminator(b2a, 'a', reuse=True)
        b2a_sample_dis = models.discriminator(b2a_sample, 'a', reuse=True)
        b_dis = models.discriminator(b_real, 'b')
        b_from_a_dis = models.discriminator(a_real, 'b', reuse=True)  # mod1
        a2b_dis = models.discriminator(a2b, 'b', reuse=True)
        a2b_sample_dis = models.discriminator(a2b_sample, 'b', reuse=True)

        double_cycle_loss = 0.0
        if args.double_cycle:
            # Treat double-translated samples as belonging to the same domain as
            # once-translated ones, i.e. the domains are "reflexive".
            a2b_sample_dis2 = models.discriminator(models.generator(a2b_sample, 'a2b', reuse=True), 'b', reuse=True)
            b2a_sample_dis2 = models.discriminator(models.generator(b2a_sample, 'b2a', reuse=True), 'a', reuse=True)
            a2b2b = models.generator(a2b, 'a2b', reuse=True)
            a2b2b2a = models.generator(a2b2b, 'b2a', reuse=True)
            a2b2b2a2a = models.generator(a2b2b2a, 'b2a', reuse=True)
            b2a2a = models.generator(b2a, 'b2a', reuse=True)
            b2a2a2b = models.generator(b2a2a, 'a2b', reuse=True)
            b2a2a2b2b = models.generator(b2a2a2b, 'a2b', reuse=True)
            cyc_loss_a2 = tf.identity(ops.l1_loss(a_real, a2b2b2a2a) * 10.0, name='cyc_loss_a2')
            cyc_loss_b2 = tf.identity(ops.l1_loss(b_real, b2a2a2b2b) * 10.0, name='cyc_loss_b2')
            double_cycle_loss = cyc_loss_a2 + cyc_loss_b2

        # Losses
        g_loss_a2b = tf.identity(ops.l2_loss(a2b_dis, tf.ones_like(a2b_dis)), name='g_loss_a2b')
        g_loss_b2a = tf.identity(ops.l2_loss(b2a_dis, tf.ones_like(b2a_dis)), name='g_loss_b2a')
        cyc_loss_a = tf.identity(ops.l1_loss(a_real, a2b2a) * 10.0, name='cyc_loss_a')
        cyc_loss_b = tf.identity(ops.l1_loss(b_real, b2a2b) * 10.0, name='cyc_loss_b')
        g_loss = g_loss_a2b + g_loss_b2a + cyc_loss_a + cyc_loss_b + double_cycle_loss

        d_loss_b2a_sample2 = d_loss_a2b_sample2 = 0.0
        d_loss_a_real = ops.l2_loss(a_dis, tf.ones_like(a_dis))
        d_loss_a_from_b_real = tf.identity(ops.l2_loss(a_from_b_dis, tf.zeros_like(a_from_b_dis)), name='d_loss_a_from_b')  # mod1
        d_loss_b2a_sample = ops.l2_loss(b2a_sample_dis, tf.zeros_like(b2a_sample_dis))
        if args.double_cycle:
            d_loss_b2a_sample2 = ops.l2_loss(b2a_sample_dis2, tf.zeros_like(b2a_sample_dis2))
        d_loss_a = tf.identity((d_loss_a_real + d_loss_b2a_sample + d_loss_b2a_sample2 + d_loss_a_from_b_real) / 3.0, name='d_loss_a')

        d_loss_b_real = ops.l2_loss(b_dis, tf.ones_like(b_dis))
        d_loss_b_from_a_real = tf.identity(ops.l2_loss(b_from_a_dis, tf.zeros_like(b_from_a_dis)), name='d_loss_b_from_a')  # mod1
        d_loss_a2b_sample = ops.l2_loss(a2b_sample_dis, tf.zeros_like(a2b_sample_dis))
        if args.double_cycle:
            d_loss_a2b_sample2 = ops.l2_loss(a2b_sample_dis2, tf.zeros_like(a2b_sample_dis2))
        d_loss_b = tf.identity((d_loss_b_real + d_loss_a2b_sample + d_loss_a2b_sample2 + d_loss_b_from_a_real) / 3.0, name='d_loss_b')

        # Summaries
        g_summary = ops.summary_tensors([g_loss_a2b, g_loss_b2a, cyc_loss_a, cyc_loss_b])
        d_summary_a = ops.summary_tensors([d_loss_a, d_loss_a_from_b_real])
        d_summary_b = ops.summary_tensors([d_loss_b, d_loss_b_from_a_real])

        # Optim
        t_var = tf.trainable_variables()
        d_a_var = [var for var in t_var if 'a_discriminator' in var.name]
        d_b_var = [var for var in t_var if 'b_discriminator' in var.name]
        g_var = [var for var in t_var if 'a2b_generator' in var.name or 'b2a_generator' in var.name]
        d_a_train_op = tf.train.AdamOptimizer(args.lr, beta1=0.5).minimize(d_loss_a, var_list=d_a_var)
        d_b_train_op = tf.train.AdamOptimizer(args.lr, beta1=0.5).minimize(d_loss_b, var_list=d_b_var)
        g_train_op = tf.train.AdamOptimizer(args.lr, beta1=0.5).minimize(g_loss, var_list=g_var)

    return (g_train_op, d_a_train_op, d_b_train_op,
            g_summary, d_summary_a, d_summary_b,
            a2b, a2b2a, b2a, b2a2b,
            a_real, b_real, a2b_sample, b2a_sample, a2b1, b2a1)
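# A sketch of how the tensors returned by build_networks() could be wired into
# a TF1 training step. read_batch() and the a2b_pool / b2a_pool history buffers
# (see the HistoryPool sketch above) are assumptions, not code from this
# repository.
(g_train_op, d_a_train_op, d_b_train_op,
 g_summary, d_summary_a, d_summary_b,
 a2b, a2b2a, b2a, b2a2b,
 a_real, b_real, a2b_sample, b2a_sample, a2b1, b2a1) = build_networks()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    a_batch, b_batch = read_batch()  # hypothetical data loader
    # Generator step; also fetch the fresh translations for the pools.
    a2b_out, b2a_out, _ = sess.run(
        [a2b, b2a, g_train_op],
        feed_dict={a_real: a_batch, b_real: b_batch})
    # Discriminator steps: both real batches are needed because of the
    # cross-domain (mod1) terms, and the fake inputs come from the pools.
    sess.run(d_a_train_op,
             feed_dict={a_real: a_batch, b_real: b_batch,
                        b2a_sample: b2a_pool(b2a_out)})
    sess.run(d_b_train_op,
             feed_dict={a_real: a_batch, b_real: b_batch,
                        a2b_sample: a2b_pool(a2b_out)})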
def __init__(self, sess, config, name, is_train):
    self.sess = sess
    self.name = name
    self.is_train = is_train

    # Moving image x and fixed image y, concatenated channel-wise as network input
    im_shape = [config['batch_size']] + config['image_size'] + [1]
    self.x = tf.placeholder(tf.float32, im_shape)
    self.y = tf.placeholder(tf.float32, im_shape)
    self.xy = tf.concat([self.x, self.y], axis=3)

    if self.is_train:
        lb_shape = [config['batch_size']] + \
            config['image_size'] + [config['n_labels'] + 1]
        self.xlabel = tf.placeholder(tf.float32, lb_shape)
        self.ylabel = tf.placeholder(tf.float32, lb_shape)
    else:
        # At inference time the network expects 64x64 input; remember the
        # scale factor so the predicted field can be rescaled afterwards.
        if im_shape[1:-1] != [64, 64]:
            dsfac = im_shape[1] / 64.
            x_reshaped = tf.image.resize_images(self.x, size=[64, 64])
            y_reshaped = tf.image.resize_images(self.y, size=[64, 64])
            self.xy = tf.concat([x_reshaped, y_reshaped], axis=3)

    # Predict a dense displacement (vector) field from the image pair
    self.VectorCNN = VectorCNN('VectorCNN', is_train=self.is_train)
    self.v = self.VectorCNN(self.xy)

    if self.is_train:
        # Warp the moving image and its label map with the predicted field
        self.z = batch_displacement_warp2d(
            self.x, self.v, vector_fields_in_pixel_space=True)
        self.zlabel = batch_displacement_warp2d(
            self.xlabel, self.v, vector_fields_in_pixel_space=True)

        # Frozen label autoencoder providing a code-space consistency term
        self.CNN_AE = CNN_AE('CNN_AE', is_train=False)
        with tf.name_scope('AE_1'):
            h1, _ = self.CNN_AE(self.ylabel, lb_shape[-1], config['n_codes'])
        with tf.name_scope('AE_2'):
            h2, _ = self.CNN_AE(self.zlabel, lb_shape[-1], config['n_codes'])

        # Loss: image similarity (NCC) + field smoothness (total variation)
        # + label cross-entropy + autoencoder code consistency
        self.loss = -ncc(self.y, self.z) + \
            config['tv_reg'] * total_variation(self.v) + \
            config['ce_reg'] * softmax_cross_entropy(
                labels=self.ylabel, logits=self.zlabel) + \
            config['ae_reg'] * l2_loss(h1, h2)

        self.optim = tf.train.AdamOptimizer(config['learning_rate'])
        self.train = self.optim.minimize(
            self.loss, var_list=self.VectorCNN.var_list)

        self.sess.run(tf.global_variables_initializer())
        self.CNN_AE.restore(
            self.sess, os.path.join(config['aenet_dir'], 'model.ckpt'))
    else:
        # Upsample the field back to full resolution and rescale its magnitude
        if list(im_shape[1:-1]) != [64, 64]:
            self.v = tf.image.resize_images(self.v, size=im_shape[1:-1])
            self.v = self.v * dsfac
        self.z = batch_displacement_warp2d(
            self.x, self.v, vector_fields_in_pixel_space=True)
        self.sess.run(tf.global_variables_initializer())
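# The similarity term -ncc(self.y, self.z) rewards correlation between the
# fixed image y and the warped moving image z. The actual ncc comes from the
# project's own ops; purely as an illustration (my assumption, a global rather
# than windowed NCC), such a term could look like this:
import tensorflow as tf


def ncc_global(a, b, eps=1e-8):
    # Flatten each image in the batch, zero-center it, and compute the
    # normalized cross-correlation, averaged over the batch.
    a_flat = tf.reshape(a, [tf.shape(a)[0], -1])
    b_flat = tf.reshape(b, [tf.shape(b)[0], -1])
    a_centered = a_flat - tf.reduce_mean(a_flat, axis=1, keepdims=True)
    b_centered = b_flat - tf.reduce_mean(b_flat, axis=1, keepdims=True)
    numerator = tf.reduce_sum(a_centered * b_centered, axis=1)
    denominator = tf.norm(a_centered, axis=1) * tf.norm(b_centered, axis=1) + eps
    return tf.reduce_mean(numerator / denominator)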
b2a_sample = tf.placeholder(tf.float32, shape=[None, crop_size, crop_size, 3])

a2b = models.generator(a_real, 'a2b')
b2a = models.generator(b_real, 'b2a')
b2a2b = models.generator(b2a, 'a2b', reuse=True)
a2b2a = models.generator(a2b, 'b2a', reuse=True)

a_dis = models.discriminator(a_real, 'a')
b2a_dis = models.discriminator(b2a, 'a', reuse=True)
b2a_sample_dis = models.discriminator(b2a_sample, 'a', reuse=True)
b_dis = models.discriminator(b_real, 'b')
a2b_dis = models.discriminator(a2b, 'b', reuse=True)
a2b_sample_dis = models.discriminator(a2b_sample, 'b', reuse=True)

# losses
g_loss_a2b = tf.identity(ops.l2_loss(a2b_dis, tf.ones_like(a2b_dis)), name='g_loss_a2b')
g_loss_b2a = tf.identity(ops.l2_loss(b2a_dis, tf.ones_like(b2a_dis)), name='g_loss_b2a')
cyc_loss_a = tf.identity(ops.l1_loss(a_real, a2b2a) * 10.0, name='cyc_loss_a')
cyc_loss_b = tf.identity(ops.l1_loss(b_real, b2a2b) * 10.0, name='cyc_loss_b')
g_loss = g_loss_a2b + g_loss_b2a + cyc_loss_a + cyc_loss_b

d_loss_a_real = ops.l2_loss(a_dis, tf.ones_like(a_dis))
d_loss_b2a_sample = ops.l2_loss(b2a_sample_dis, tf.zeros_like(b2a_sample_dis))
d_loss_a = tf.identity((d_loss_a_real + d_loss_b2a_sample) / 2.0, name='d_loss_a')
d_loss_b_real = ops.l2_loss(b_dis, tf.ones_like(b_dis))
d_loss_a2b_sample = ops.l2_loss(a2b_sample_dis, tf.zeros_like(a2b_sample_dis))
d_loss_b = tf.identity((d_loss_b_real + d_loss_a2b_sample) / 2.0, name='d_loss_b')

# summaries
g_summary = ops.summary_tensors([g_loss_a2b, g_loss_b2a, cyc_loss_a, cyc_loss_b])
d_summary_a = ops.summary(d_loss_a)
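# Sketch of how these summary tensors are typically logged, assuming
# ops.summary_tensors returns a merged summary op and that sess, a_batch,
# b_batch, and step exist in the surrounding training loop; the log directory
# name is arbitrary.
writer = tf.summary.FileWriter('./summaries', tf.get_default_graph())
summary_str = sess.run(g_summary, feed_dict={a_real: a_batch, b_real: b_batch})
writer.add_summary(summary_str, global_step=step)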