def forward(img_a, img_b):
    img_a /= 255.
    img_b /= 255.

    # two generators (a->b and b->a) plus cycle reconstructions
    img_ab = generator(img_a, name='atob', reuse=False)
    img_ba = generator(img_b, name='btoa', reuse=False)
    img_aba = generator(img_ab, name='btoa', reuse=True)
    img_bab = generator(img_ba, name='atob', reuse=True)

    logit_fake_a = discriminator(img_ba, name='a', reuse=False)
    logit_fake_b = discriminator(img_ab, name='b', reuse=False)
    score_fake_a = O.sigmoid(logit_fake_a)
    score_fake_b = O.sigmoid(logit_fake_b)

    for name in ['img_a', 'img_b', 'img_ab', 'img_ba', 'img_aba', 'img_bab',
                 'score_fake_a', 'score_fake_b']:
        dpc.add_output(locals()[name], name=name)

    if env.phase is env.Phase.TRAIN:
        logit_real_a = discriminator(img_a, name='a', reuse=True)
        logit_real_b = discriminator(img_b, name='b', reuse=True)
        score_real_a = O.sigmoid(logit_real_a)
        score_real_b = O.sigmoid(logit_real_b)

        all_g_loss = 0.
        all_d_loss = 0.
        r_loss_ratio = 0.9

        for pair_name, (real, fake), (logit_real, logit_fake), (score_real, score_fake) in zip(
                ['lossa', 'lossb'],
                [(img_a, img_aba), (img_b, img_bab)],
                [(logit_real_a, logit_fake_a), (logit_real_b, logit_fake_b)],
                [(score_real_a, score_fake_a), (score_real_b, score_fake_b)]):

            with env.name_scope(pair_name):
                d_loss_real = O.sigmoid_cross_entropy_with_logits(
                    logits=logit_real, labels=O.ones_like(logit_real)).mean(name='d_loss_real')
                d_loss_fake = O.sigmoid_cross_entropy_with_logits(
                    logits=logit_fake, labels=O.zeros_like(logit_fake)).mean(name='d_loss_fake')
                g_loss = O.sigmoid_cross_entropy_with_logits(
                    logits=logit_fake, labels=O.ones_like(logit_fake)).mean(name='g_loss')

                d_acc_real = (score_real > 0.5).astype('float32').mean(name='d_acc_real')
                d_acc_fake = (score_fake < 0.5).astype('float32').mean(name='d_acc_fake')
                g_accuracy = (score_fake > 0.5).astype('float32').mean(name='g_accuracy')

                d_accuracy = O.identity(.5 * (d_acc_real + d_acc_fake), name='d_accuracy')
                d_loss = O.identity(.5 * (d_loss_real + d_loss_fake), name='d_loss')

                # r_loss = O.raw_l2_loss('raw_r_loss', real, fake).flatten2().sum(axis=1).mean(name='r_loss')
                r_loss = O.raw_l2_loss('raw_r_loss', real, fake).mean(name='r_loss')
                # r_loss = O.raw_cross_entropy_prob('raw_r_loss', real, fake).flatten2().sum(axis=1).mean(name='r_loss')

                # all_g_loss += g_loss + r_loss
                all_g_loss += (1 - r_loss_ratio) * g_loss + r_loss_ratio * r_loss
                all_d_loss += d_loss

            for v in [d_loss_real, d_loss_fake, g_loss, d_acc_real, d_acc_fake,
                      g_accuracy, d_accuracy, d_loss, r_loss]:
                # strip the 'tower/<i>/' prefix and the trailing ':0' from the tensor name
                dpc.add_output(v, name=re.sub(r'^tower/\d+/', '', v.name)[:-2], reduce_method='sum')

        dpc.add_output(all_g_loss, name='g_loss', reduce_method='sum')
        dpc.add_output(all_d_loss, name='d_loss', reduce_method='sum')
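# Illustration only (not part of the model code): a minimal NumPy sketch of the
# numerically stable sigmoid cross-entropy that O.sigmoid_cross_entropy_with_logits
# is assumed to compute (TensorFlow's formulation), and of how the generator objective
# above blends the GAN term with the cycle-reconstruction term via r_loss_ratio.
# All concrete numbers below are hypothetical.
import numpy as np

def sigmoid_cross_entropy_with_logits(logits, labels):
    # max(x, 0) - x * z + log(1 + exp(-|x|)), stable for large |x|
    return np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))

logit_fake = np.array([-0.3, 1.2])      # hypothetical discriminator outputs on fakes
g_loss = sigmoid_cross_entropy_with_logits(logit_fake, np.ones_like(logit_fake)).mean()
r_loss = 0.05                           # hypothetical mean L2 reconstruction error
r_loss_ratio = 0.9
total_g_loss = (1 - r_loss_ratio) * g_loss + r_loss_ratio * r_loss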
def forward(img):
    g_batch_size = get_env('trainer.batch_size') if env.phase is env.Phase.TRAIN else 1
    # sample z from a standard normal prior
    z = O.as_varnode(tf.random_normal([g_batch_size, code_length]))

    # MLP generator: z -> 28x28x1 image
    with env.variable_scope(GANGraphKeys.GENERATOR_VARIABLES):
        _ = z
        with O.argscope(O.fc, nonlin=O.tanh):
            _ = O.fc('fc1', _, 500)
        _ = O.fc('fc3', _, 784, nonlin=O.sigmoid)
        x_given_z = _.reshape(-1, 28, 28, 1)

    def discriminator(x):
        _ = x
        with O.argscope(O.fc, nonlin=O.tanh):
            _ = O.fc('fc1', _, 500)
        _ = O.fc('fc3', _, 1)
        logits = _
        return logits

    # the discriminator is built twice: on real images (train only) and on generated samples
    if is_train:
        with env.variable_scope(GANGraphKeys.DISCRIMINATOR_VARIABLES):
            logits_real = discriminator(img).flatten()
            score_real = O.sigmoid(logits_real)

    with env.variable_scope(GANGraphKeys.DISCRIMINATOR_VARIABLES, reuse=is_train):
        logits_fake = discriminator(x_given_z).flatten()
        score_fake = O.sigmoid(logits_fake)

    if is_train:
        # build loss
        with env.variable_scope('loss'):
            d_loss_real = O.sigmoid_cross_entropy_with_logits(
                logits=logits_real, labels=O.ones_like(logits_real)).mean()
            d_loss_fake = O.sigmoid_cross_entropy_with_logits(
                logits=logits_fake, labels=O.zeros_like(logits_fake)).mean()
            g_loss = O.sigmoid_cross_entropy_with_logits(
                logits=logits_fake, labels=O.ones_like(logits_fake)).mean()

            d_acc_real = (score_real > 0.5).astype('float32').mean()
            d_acc_fake = (score_fake < 0.5).astype('float32').mean()
            g_accuracy = (score_fake > 0.5).astype('float32').mean()

            d_accuracy = .5 * (d_acc_real + d_acc_fake)
            d_loss = .5 * (d_loss_real + d_loss_fake)

        dpc.add_output(d_loss, name='d_loss', reduce_method='sum')
        dpc.add_output(d_accuracy, name='d_accuracy', reduce_method='sum')
        dpc.add_output(d_acc_real, name='d_acc_real', reduce_method='sum')
        dpc.add_output(d_acc_fake, name='d_acc_fake', reduce_method='sum')
        dpc.add_output(g_loss, name='g_loss', reduce_method='sum')
        dpc.add_output(g_accuracy, name='g_accuracy', reduce_method='sum')

    dpc.add_output(x_given_z, name='output')
    dpc.add_output(score_fake, name='score')
def forward(x, zc):
    if env.phase is env.Phase.TRAIN:
        zc = zc_distrib.sample(g_batch_size, prior)
    zn = O.random_normal([g_batch_size, zn_size], -1, 1)
    z = O.concat([zc, zn], axis=1, name='z')

    with env.variable_scope(GANGraphKeys.GENERATOR_VARIABLES):
        x_given_z = generator(z)

    with env.variable_scope(GANGraphKeys.DISCRIMINATOR_VARIABLES):
        logits_fake, code_fake = discriminator(x_given_z)
        score_fake = O.sigmoid(logits_fake)

    dpc.add_output(x_given_z, name='output')
    dpc.add_output(score_fake, name='score')
    dpc.add_output(code_fake, name='code')

    if env.phase is env.Phase.TRAIN:
        with env.variable_scope(GANGraphKeys.DISCRIMINATOR_VARIABLES, reuse=True):
            logits_real, code_real = discriminator(x)
            score_real = O.sigmoid(logits_real)

        # build loss
        with env.variable_scope('loss'):
            d_loss_real = O.sigmoid_cross_entropy_with_logits(
                logits=logits_real, labels=O.ones_like(logits_real)).mean()
            d_loss_fake = O.sigmoid_cross_entropy_with_logits(
                logits=logits_fake, labels=O.zeros_like(logits_fake)).mean()
            g_loss = O.sigmoid_cross_entropy_with_logits(
                logits=logits_fake, labels=O.ones_like(logits_fake)).mean()

            entropy = zc_distrib.cross_entropy(zc, batch_prior)
            cond_entropy = zc_distrib.cross_entropy(zc, code_fake, process_theta=True)
            info_gain = entropy - cond_entropy

            d_acc_real = (score_real > 0.5).astype('float32').mean()
            d_acc_fake = (score_fake < 0.5).astype('float32').mean()
            g_accuracy = (score_fake > 0.5).astype('float32').mean()

            d_accuracy = .5 * (d_acc_real + d_acc_fake)
            d_loss = .5 * (d_loss_real + d_loss_fake)

            d_loss -= info_gain
            g_loss -= info_gain

        dpc.add_output(d_loss, name='d_loss', reduce_method='sum')
        dpc.add_output(d_accuracy, name='d_accuracy', reduce_method='sum')
        dpc.add_output(d_acc_real, name='d_acc_real', reduce_method='sum')
        dpc.add_output(d_acc_fake, name='d_acc_fake', reduce_method='sum')
        dpc.add_output(g_loss, name='g_loss', reduce_method='sum')
        dpc.add_output(g_accuracy, name='g_accuracy', reduce_method='sum')
        dpc.add_output(info_gain, name='g_info_gain', reduce_method='sum')
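# Illustration only: a hedged NumPy sketch of the information-gain term used above,
# assuming zc_distrib is a categorical code distribution. info_gain is the cross-entropy
# of the sampled code under the fixed prior minus the cross-entropy under the code
# predicted by the discriminator head, i.e. InfoGAN's variational lower bound on the
# mutual information between z_c and the generated sample. The shapes and values are examples.
import numpy as np

def categorical_cross_entropy(one_hot, probs, eps=1e-8):
    # -sum_i z_i * log p_i, averaged over the batch
    return -(one_hot * np.log(probs + eps)).sum(axis=1).mean()

batch, nr_classes = 4, 10
zc = np.eye(nr_classes)[np.random.randint(nr_classes, size=batch)]   # sampled one-hot codes
batch_prior = np.full((batch, nr_classes), 1. / nr_classes)          # uniform prior
code_fake = np.random.dirichlet(np.ones(nr_classes), size=batch)     # q(z_c | G(z)) from the discriminator

info_gain = categorical_cross_entropy(zc, batch_prior) - categorical_cross_entropy(zc, code_fake)
# Since info_gain is subtracted from both d_loss and g_loss, maximizing it encourages
# the generator to make z_c recoverable from the generated image.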
def forward(x):
    g_batch_size = get_env('trainer.batch_size') if env.phase is env.Phase.TRAIN else 1
    z = O.random_normal([g_batch_size, z_dim])

    with env.variable_scope(GANGraphKeys.GENERATOR_VARIABLES):
        img_gen = generator(z)
        # tf.summary.image('generated-samples', img_gen, max_outputs=30)

    with env.variable_scope(GANGraphKeys.DISCRIMINATOR_VARIABLES):
        logits_fake = discriminator(img_gen)
        score_fake = O.sigmoid(logits_fake)

    dpc.add_output(img_gen, name='output')
    dpc.add_output(score_fake, name='score')

    if env.phase is env.Phase.TRAIN:
        with env.variable_scope(GANGraphKeys.DISCRIMINATOR_VARIABLES, reuse=True):
            logits_real = discriminator(x)
            score_real = O.sigmoid(logits_real)

        # build loss
        with env.variable_scope('loss'):
            d_loss_real = O.sigmoid_cross_entropy_with_logits(
                logits=logits_real, labels=O.ones_like(logits_real)).mean()
            d_loss_fake = O.sigmoid_cross_entropy_with_logits(
                logits=logits_fake, labels=O.zeros_like(logits_fake)).mean()
            g_loss = O.sigmoid_cross_entropy_with_logits(
                logits=logits_fake, labels=O.ones_like(logits_fake)).mean()

            d_acc_real = (score_real > 0.5).astype('float32').mean()
            d_acc_fake = (score_fake < 0.5).astype('float32').mean()
            g_accuracy = (score_fake > 0.5).astype('float32').mean()

            d_accuracy = .5 * (d_acc_real + d_acc_fake)
            d_loss = .5 * (d_loss_real + d_loss_fake)

        dpc.add_output(d_loss, name='d_loss', reduce_method='sum')
        dpc.add_output(d_accuracy, name='d_accuracy', reduce_method='sum')
        dpc.add_output(d_acc_real, name='d_acc_real', reduce_method='sum')
        dpc.add_output(d_acc_fake, name='d_acc_fake', reduce_method='sum')
        dpc.add_output(g_loss, name='g_loss', reduce_method='sum')
        dpc.add_output(g_accuracy, name='g_accuracy', reduce_method='sum')
def image_diff(origin, canvas_logits):
    """
    Compute the difference between the original image and the current canvas.
    Note that the canvas is given as logits (without sigmoid activation); the
    sigmoid is applied inside this function.

    :param origin: original image, of shape (batch_size, h, w, c)
    :param canvas_logits: canvas logits, of shape (batch_size, h, w, c)
    :return: the difference: origin - sigmoid(canvas_logits)
    """
    sigmoid_canvas = O.sigmoid(canvas_logits)
    return origin - sigmoid_canvas
def generator(z):
    w_init = O.truncated_normal_initializer(stddev=0.02)
    with O.argscope(O.conv2d, O.deconv2d, kernel=4, stride=2, W=w_init), \
            O.argscope(O.fc, W=w_init):
        _ = z
        _ = O.fc('fc1', _, 1024, nonlin=O.bn_relu)
        _ = O.fc('fc2', _, 128 * 7 * 7, nonlin=O.bn_relu)
        _ = O.reshape(_, [-1, 7, 7, 128])
        _ = O.deconv2d('deconv1', _, 64, nonlin=O.bn_relu)
        _ = O.deconv2d('deconv2', _, 1)
        _ = O.sigmoid(_, 'out')
    return _
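# Sanity-check sketch (assumption: the stride-2 deconvolutions use 'SAME'-style padding,
# as is typical for this DCGAN-like generator): the 7x7 feature map produced by fc2 is
# upsampled to 14x14 by deconv1 and to 28x28 by deconv2, i.e. an MNIST-sized output.
size = 7
for layer in ('deconv1', 'deconv2'):
    size *= 2                  # each stride-2 transposed convolution doubles the spatial size
    print(layer, '->', size)   # 14, then 28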
def decoder(z):
    w_init = O.truncated_normal_initializer(stddev=0.02)
    with O.argscope(O.conv2d, O.deconv2d, kernel=4, stride=2, W=w_init), \
            O.argscope(O.fc, W=w_init):
        _ = z
        _ = O.deconv2d('deconv1', _, 256, nonlin=O.bn_relu)
        _ = O.deconv2d('deconv2', _, 128, nonlin=O.bn_relu)
        _ = O.deconv2d('deconv3', _, 64, nonlin=O.bn_relu)
        _ = O.deconv2d('deconv4', _, c)
        _ = O.sigmoid(_, name='out')
    x = _
    return x
def forward(img=None):
    encoder = O.BasicLSTMCell(256)
    decoder = O.BasicLSTMCell(256)

    batch_size = img.shape[0] if is_train else 1
    canvas = O.zeros(shape=O.canonize_sym_shape([batch_size, h, w, c]), dtype='float32')
    enc_state = encoder.zero_state(batch_size, dtype='float32')
    dec_state = decoder.zero_state(batch_size, dtype='float32')
    enc_h, dec_h = enc_state[1], dec_state[1]

    def encode(x, state, reuse):
        with env.variable_scope('read_encoder', reuse=reuse):
            return encoder(x, state)

    def decode(x, state, reuse):
        with env.variable_scope('write_decoder', reuse=reuse):
            return decoder(x, state)

    all_sqr_mus, all_vars, all_log_vars = 0., 0., 0.

    for step in range(nr_glimpse):
        reuse = (step != 0)

        if is_reconstruct or env.phase is env.Phase.TRAIN:
            img_hat = draw_opr.image_diff(img, canvas)  # eq. 3

            # Note: here the input should be dec_h
            with env.variable_scope('read', reuse=reuse):
                read_param = O.fc('fc_param', dec_h, 5)
            with env.name_scope('read_step{}'.format(step)):
                cx, cy, delta, var, gamma = draw_opr.split_att_params(h, w, att_dim, read_param)
                read_inp = O.concat([img, img_hat], axis=3)  # of shape: batch_size x h x w x (2c)
                read_out = draw_opr.att_read(att_dim, read_inp, cx, cy, delta, var)  # eq. 4
                enc_inp = O.concat([gamma * read_out.flatten2(), dec_h], axis=1)
            enc_h, enc_state = encode(enc_inp, enc_state, reuse)  # eq. 5

            with env.variable_scope('sample', reuse=reuse):
                _ = enc_h
                sample_mu = O.fc('fc_mu', _, code_length)
                sample_log_var = O.fc('fc_sigma', _, code_length)
            with env.name_scope('sample_step{}'.format(step)):
                sample_var = O.exp(sample_log_var)
                sample_std = O.sqrt(sample_var)
                sample_epsilon = O.random_normal([batch_size, code_length])
                z = sample_mu + sample_std * sample_epsilon  # eq. 6

                # accumulate for losses
                all_sqr_mus += sample_mu ** 2.
                all_vars += sample_var
                all_log_vars += sample_log_var
        else:
            z = O.random_normal([1, code_length])

        # z = O.callback_injector(z)

        dec_h, dec_state = decode(z, dec_state, reuse)  # eq. 7
        with env.variable_scope('write', reuse=reuse):
            write_param = O.fc('fc_param', dec_h, 5)
            write_in = O.fc('fc', dec_h, (att_dim * att_dim * c)).reshape(-1, att_dim, att_dim, c)
        with env.name_scope('write_step{}'.format(step)):
            cx, cy, delta, var, gamma = draw_opr.split_att_params(h, w, att_dim, write_param)
            write_out = draw_opr.att_write(h, w, write_in, cx, cy, delta, var)  # eq. 8
            canvas += write_out

        if env.phase is env.Phase.TEST:
            dpc.add_output(O.sigmoid(canvas), name='canvas_step{}'.format(step))

    canvas = O.sigmoid(canvas)

    if env.phase is env.Phase.TRAIN:
        with env.variable_scope('loss'):
            img, canvas = img.flatten2(), canvas.flatten2()
            content_loss = O.raw_cross_entropy_prob('raw_content', canvas, img)
            content_loss = content_loss.sum(axis=1).mean(name='content')
            # distrib_loss = 0.5 * (O.sqr(mu) + O.sqr(std) - 2. * O.log(std + 1e-8) - 1.0).sum(axis=1)
            distrib_loss = -0.5 * (float(nr_glimpse) + all_log_vars - all_sqr_mus - all_vars).sum(axis=1)
            distrib_loss = distrib_loss.mean(name='distrib')
            summary.scalar('content_loss', content_loss)
            summary.scalar('distrib_loss', distrib_loss)
            loss = content_loss + distrib_loss
        dpc.add_output(loss, name='loss', reduce_method='sum')

    dpc.add_output(canvas, name='output')
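# Illustration only: a NumPy check that the accumulated distrib_loss above equals the
# sum over glimpses of per-step KL divergences KL(N(mu_t, diag(var_t)) || N(0, I)),
#   sum_t 0.5 * (mu_t^2 + var_t - log var_t - 1),
# which is the latent loss of the DRAW model. T and D below are example values standing
# in for nr_glimpse and code_length.
import numpy as np

T, D = 5, 16
mu = np.random.randn(T, D)
log_var = np.random.randn(T, D)
var = np.exp(log_var)

per_step_kl = (0.5 * (mu ** 2 + var - log_var - 1.0)).sum()
accumulated = -0.5 * (T + log_var.sum(axis=0) - (mu ** 2).sum(axis=0) - var.sum(axis=0)).sum()
assert np.allclose(per_step_kl, accumulated)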