Esempio n. 1
0
def D_train_graph():
    """Build the discriminator training graph and return a per-step runner.

    A batch ``(xa, a)`` is pulled from ``train_iter``. Target attributes ``b``
    are obtained by shuffling ``a`` along its first (batch) axis, and ``G``
    generates ``xb`` from ``xa`` conditioned on the rescaled attribute
    difference. ``D`` is then trained on the sum of:

    * the adversarial losses from ``d_loss_fn`` on ``xa`` / ``xb``,
    * a gradient penalty on ``D``'s first (GAN) output, weighted by
      ``args.d_gradient_penalty_weight``,
    * a sigmoid cross-entropy attribute loss on ``xa``, weighted by
      ``args.d_attribute_loss_weight``,
    * the summed ``D.func.reg_losses``.

    Only ``D.func.trainable_variables`` are updated. Scalar summaries are
    recorded every 10 global steps under
    ``./output/<experiment_name>/summaries/D``.

    Returns:
        A callable ``run(**pl_ipts)`` that runs one optimizer step plus the
        summary ops, feeding ``pl_ipts['lr']`` as the learning rate. It
        returns ``None``.
    """
    # NOTE(review): statement order below mirrors the TF op-creation order on
    # purpose; in TF1 graph mode, op names and derived op-level random seeds
    # depend on it.

    # ---------------------------- graph ----------------------------

    # Learning rate is a scalar placeholder so it can be scheduled by the caller.
    lr_ph = tf.placeholder(dtype=tf.float32, shape=[])

    xa, att_a = train_iter.get_next()
    # Target attributes: the source attribute vectors permuted across the batch.
    att_b = tf.random_shuffle(att_a)
    # Map attributes to [-1, 1]; `a` is also used as sigmoid cross-entropy
    # labels below, so presumably it lies in [0, 1] — TODO confirm.
    att_a_signed = att_a * 2 - 1
    att_b_signed = att_b * 2 - 1

    # G is conditioned on the attribute *difference*; only its first output
    # (the edited image) is needed for the discriminator step.
    xb, _, _, _ = G(xa, att_b_signed - att_a_signed)

    # D returns (gan_logit, attribute_logit); attribute logits of xb are unused.
    gan_logit_a, att_logit_a = D(xa)
    gan_logit_b, _ = D(xb)

    loss_gan_a, loss_gan_b = d_loss_fn(gan_logit_a, gan_logit_b)
    # The penalty only looks at D's GAN head (element 0 of its output).
    gp = tfprob.gradient_penalty(lambda x: D(x)[0], xa, xb,
                                 args.gradient_penalty_mode,
                                 args.gradient_penalty_sample_mode)
    loss_att_a = tf.losses.sigmoid_cross_entropy(att_a, att_logit_a)
    loss_reg = tf.reduce_sum(D.func.reg_losses)

    loss = (loss_gan_a + loss_gan_b
            + gp * args.d_gradient_penalty_weight
            + loss_att_a * args.d_attribute_loss_weight
            + loss_reg)

    # ---------------------------- optimizer ----------------------------

    step_cnt, _ = tl.counter()
    optimizer = tf.train.AdamOptimizer(lr_ph, beta1=args.beta_1)
    train_op = optimizer.minimize(loss,
                                  global_step=step_cnt,
                                  var_list=D.func.trainable_variables)

    # ---------------------------- summaries ----------------------------

    summary_dir = './output/%s/summaries/D' % args.experiment_name
    writer = tf.contrib.summary.create_file_writer(summary_dir)
    with writer.as_default():
        with tf.contrib.summary.record_summaries_every_n_global_steps(
                10, global_step=step_cnt):
            d_scalars = {
                'loss_gan': loss_gan_a + loss_gan_b,
                'gp': gp,
                'xa_loss_att': loss_att_a,
                'reg_loss': loss_reg,
            }
            summary_ops = [
                tl.summary_v2(d_scalars, step=step_cnt, name='D'),
                tl.summary_v2({'lr': lr_ph}, step=step_cnt,
                              name='learning_rate'),
            ]

    # ---------------------------- runner ----------------------------

    def run(**pl_ipts):
        """Run one D update (plus summaries); expects an ``lr`` keyword."""
        fetches = [train_op, summary_ops]
        sess.run(fetches, feed_dict={lr_ph: pl_ipts['lr']})

    return run
Esempio n. 2
0
    def graph_per_gpu(x_r, zs, eps):
        """Build one GPU tower of the discriminator loss and its gradients.

        Args:
            x_r: batch passed to ``D`` and used as the first argument of
                ``d_loss_fn`` (presumably real samples — confirm with caller).
            zs: first input to ``G`` for producing the generated batch.
            eps: second input to ``G`` for producing the generated batch.

        Returns:
            A tuple ``(grads, x_r_loss, x_f_loss, x_gp, reg_loss)``.
            ``grads`` is the result of ``optimizer.compute_gradients`` over
            ``D.func.trainable_variables``. The other items are the individual
            loss terms before their ``args.*`` weights are applied, presumably
            for logging. ``x_gp`` already includes the
            ``args.d_lazy_reg_period`` scale factor. ``reg_loss`` is the raw
            sum before ``args.weight_decay`` scaling.
        """
        # generate
        x_f = G(zs, eps)

        # discriminate
        x_r_logit = D(x_r)
        x_f_logit = D(x_f)

        # adversarial loss
        x_r_loss, x_f_loss = d_loss_fn(x_r_logit, x_f_logit)

        # Gradient penalty. When its weight is zero, skip building the
        # penalty subgraph entirely: it would never be evaluated, yet it still
        # costs a double-backprop graph per GPU tower.
        if args.d_loss_weight_x_gp == 0:
            x_gp = tf.constant(0.0)
        else:
            # Lazy regularization: the penalty is only computed on steps where
            # step_cnt is a multiple of d_lazy_reg_period. On those steps it is
            # scaled by that period, so its average contribution roughly
            # matches an every-step penalty. Other steps contribute 0.
            x_gp = tf.cond(
                tf.equal(step_cnt % args.d_lazy_reg_period, 0),
                lambda: tfprob.gradient_penalty(
                    D, x_r, x_f,
                    args.gradient_penalty_mode,
                    args.gradient_penalty_sample_mode) * args.d_lazy_reg_period,
                lambda: tf.constant(0.0))

        reg_loss = tf.reduce_sum(D.func.reg_losses)

        loss = ((x_r_loss + x_f_loss) * args.d_loss_weight_x_gan +
                x_gp * args.d_loss_weight_x_gp + reg_loss * args.weight_decay)

        # optim: gradients only; applying/averaging across towers is done by
        # the caller (presumably after collecting all towers' grads).
        grads = optimizer.compute_gradients(
            loss, var_list=D.func.trainable_variables)

        return grads, x_r_loss, x_f_loss, x_gp, reg_loss