示例#1
0
def D_train_graph():
    """Build the discriminator training graph and return a one-step runner.

    Relies on module-level names defined elsewhere in the file: ``tf``,
    ``tfprob``, ``tl``, ``args``, ``train_iter``, ``G``, ``D``, ``d_loss_fn``
    and ``sess``.

    Returns:
        A function ``run(**pl_ipts)`` that executes one discriminator
        optimisation step together with its summary ops, feeding the learning
        rate from ``pl_ipts['lr']``.
    """
    # ======================================
    # =               graph                =
    # ======================================

    # placeholders & inputs
    # The learning rate is fed on every step, hence a scalar placeholder.
    lr = tf.placeholder(dtype=tf.float32, shape=[])

    xa, a = train_iter.get_next()
    # Target attributes: the source attribute tensor shuffled along its first dimension.
    b = tf.random_shuffle(a)
    # Rescale labels with x * 2 - 1; assumes a/b hold {0, 1} attribute labels — TODO confirm.
    a_ = a * 2 - 1
    b_ = b * 2 - 1

    # generate
    # Only the edited image is needed for the discriminator update; the
    # remaining three outputs of G are discarded.
    xb, _, _, _ = G(xa, b_ - a_)

    # discriminate
    xa_logit_gan, xa_logit_att = D(xa)
    xb_logit_gan, xb_logit_att = D(xb)

    # discriminator losses
    xa_loss_gan, xb_loss_gan = d_loss_fn(xa_logit_gan, xb_logit_gan)
    # Total adversarial loss, built once and shared by the objective and the summary.
    loss_gan = xa_loss_gan + xb_loss_gan
    # Gradient penalty on the GAN head only (D(x)[0]), computed from xa and xb.
    gp = tfprob.gradient_penalty(lambda x: D(x)[0], xa, xb,
                                 args.gradient_penalty_mode,
                                 args.gradient_penalty_sample_mode)
    # Attribute classification loss uses the real images' labels only.
    xa_loss_att = tf.losses.sigmoid_cross_entropy(a, xa_logit_att)
    reg_loss = tf.reduce_sum(D.func.reg_losses)

    loss = (loss_gan + gp * args.d_gradient_penalty_weight +
            xa_loss_att * args.d_attribute_loss_weight + reg_loss)

    # optim
    step_cnt, _ = tl.counter()
    # Only D's variables are in var_list, so G is left untouched by this step.
    step = tf.train.AdamOptimizer(lr, beta1=args.beta_1).minimize(
        loss, global_step=step_cnt, var_list=D.func.trainable_variables)

    # summary
    # Scalars are recorded only every 10 global steps.
    with tf.contrib.summary.create_file_writer('./output/%s/summaries/D' % args.experiment_name).as_default(),\
            tf.contrib.summary.record_summaries_every_n_global_steps(10, global_step=step_cnt):
        summary = [
            tl.summary_v2(
                {
                    'loss_gan': loss_gan,
                    'gp': gp,
                    'xa_loss_att': xa_loss_att,
                    'reg_loss': reg_loss
                },
                step=step_cnt,
                name='D'),
            tl.summary_v2({'lr': lr}, step=step_cnt, name='learning_rate')
        ]

    # ======================================
    # =            run function            =
    # ======================================

    def run(**pl_ipts):
        """Run one D update; ``pl_ipts['lr']`` supplies the learning rate."""
        sess.run([step, summary], feed_dict={lr: pl_ipts['lr']})

    return run
示例#2
0
def G_train_graph():
    """Assemble the generator training ops and hand back a step function.

    Depends on names defined elsewhere in the file: ``tf``, ``tl``, ``args``,
    ``module``, ``train_iter``, ``G``, ``D``, ``g_loss_fn`` and ``sess``.
    Also prints the generator's parameter count and size when called.

    Returns:
        ``run(**feeds)``, which performs a single generator optimisation step
        plus its summary ops, taking the learning rate from ``feeds['lr']``.
    """
    # ---- inputs -------------------------------------------------------
    # Scalar learning-rate placeholder, fed at every step.
    lr = tf.placeholder(dtype=tf.float32, shape=[])

    xa, a = train_iter.get_next()
    # Target attributes are a shuffled copy of the source attributes.
    b = tf.random_shuffle(a)
    # x * 2 - 1 rescaling; presumably {0, 1} -> {-1, 1} labels — TODO confirm.
    a_ = a * 2 - 1
    b_ = b * 2 - 1

    # ---- generator / discriminator -----------------------------------
    # G is conditioned on the attribute difference; the second output is unused.
    xb, _, masks, masks_multi = G(xa, b_ - a_)
    fake_logit_gan, fake_logit_att = D(xb)

    # ---- losses -------------------------------------------------------
    loss_gan = g_loss_fn(fake_logit_gan)
    loss_att = tf.losses.sigmoid_cross_entropy(b, fake_logit_att)
    # Per-mask weights; zip() stops after six entries, so any mask beyond
    # the sixth does not contribute to the sparsity term.
    mask_weights = [1.0] * 6
    sparsity_loss = tf.reduce_sum(
        [tf.reduce_mean(mask) * weight for mask, weight in zip(masks, mask_weights)])
    full_overlap_loss, non_overlap_loss = module.overlap_loss_fn(masks_multi, args.att_names)
    reg_loss = tf.reduce_sum(G.func.reg_losses)

    loss = (loss_gan
            + loss_att * args.g_attribute_loss_weight
            + sparsity_loss * args.g_spasity_loss_weight
            + full_overlap_loss * args.g_full_overlap_mask_pair_loss_weight
            + non_overlap_loss * args.g_non_overlap_mask_pair_loss_weight
            + reg_loss)

    # ---- optimisation -------------------------------------------------
    step_cnt, _ = tl.counter()
    optimizer = tf.train.AdamOptimizer(lr, beta1=args.beta_1)
    # Restrict the update to G's own trainable variables.
    step = optimizer.minimize(loss, global_step=step_cnt, var_list=G.func.trainable_variables)

    # ---- summaries (written every 10 global steps) --------------------
    scalars = {
        'xb_loss_gan': loss_gan,
        'xb_loss_att': loss_att,
        'spasity_loss': sparsity_loss,
        'full_overlap_mask_pair_loss': full_overlap_loss,
        'non_overlap_mask_pair_loss': non_overlap_loss,
        'reg_loss': reg_loss,
    }
    summary_dir = './output/%s/summaries/G' % args.experiment_name
    with tf.contrib.summary.create_file_writer(summary_dir).as_default(), \
            tf.contrib.summary.record_summaries_every_n_global_steps(10, global_step=step_cnt):
        summary = tl.summary_v2(scalars, step=step_cnt, name='G')

    # ---- report generator size ----------------------------------------
    n_params, n_bytes = tl.count_parameters(G.func.variables)
    megabytes = n_bytes / 1024 / 1024
    print('Generator Size: n_parameters = %d = %.2fMB' % (n_params, megabytes))

    # ---- step function ------------------------------------------------
    def run(**feeds):
        """Execute one generator step using ``feeds['lr']`` as the learning rate."""
        sess.run([step, summary], feed_dict={lr: feeds['lr']})

    return run
示例#3
0
def G_train_graph():
    """Assemble the encoder/decoder generator training ops and return a step function.

    Depends on names defined elsewhere in the file: ``tf``, ``tl``, ``args``,
    ``train_iter``, ``Genc``, ``Gdec``, ``D``, ``g_loss_fn`` and ``sess``.
    Also prints the combined encoder+decoder parameter count and size.

    Returns:
        ``run(**feeds)``, which performs a single generator optimisation step
        plus its summary ops, taking the learning rate from ``feeds['lr']``.
    """
    # ---- inputs -------------------------------------------------------
    # Scalar learning-rate placeholder, fed at every step.
    lr = tf.placeholder(dtype=tf.float32, shape=[])

    xa, a = train_iter.get_next()
    # Target attributes are a shuffled copy of the source attributes.
    b = tf.random_shuffle(a)
    # x * 2 - 1 rescaling; presumably {0, 1} -> {-1, 1} labels — TODO confirm.
    a_ = a * 2 - 1
    b_ = b * 2 - 1

    # ---- encode once, decode with source and target attributes --------
    z = Genc(xa)
    xa_rec = Gdec(z, a_)   # reconstruction under the original attributes
    xb_fake = Gdec(z, b_)  # edit toward the shuffled target attributes

    fake_logit_gan, fake_logit_att = D(xb_fake)

    # ---- losses -------------------------------------------------------
    loss_gan = g_loss_fn(fake_logit_gan)
    loss_att = tf.losses.sigmoid_cross_entropy(b, fake_logit_att)
    loss_rec = tf.losses.absolute_difference(xa, xa_rec)
    reg_loss = tf.reduce_sum(Genc.func.reg_losses + Gdec.func.reg_losses)

    loss = (loss_gan
            + loss_att * args.g_attribute_loss_weight
            + loss_rec * args.g_reconstruction_loss_weight
            + reg_loss)

    # ---- optimisation -------------------------------------------------
    step_cnt, _ = tl.counter()
    optimizer = tf.train.AdamOptimizer(lr, beta1=args.beta_1)
    # Update only the encoder's and decoder's trainable variables.
    step = optimizer.minimize(
        loss,
        global_step=step_cnt,
        var_list=Genc.func.trainable_variables + Gdec.func.trainable_variables)

    # ---- summaries (written every 10 global steps) --------------------
    scalars = {
        'xb__loss_gan': loss_gan,
        'xb__loss_att': loss_att,
        'xa__loss_rec': loss_rec,
        'reg_loss': reg_loss,
    }
    summary_dir = './output/%s/summaries/G' % args.experiment_name
    with tf.contrib.summary.create_file_writer(summary_dir).as_default(), \
            tf.contrib.summary.record_summaries_every_n_global_steps(10, global_step=step_cnt):
        summary = tl.summary_v2(scalars, step=step_cnt, name='G')

    # ---- report generator size ----------------------------------------
    generator_vars = Genc.func.variables + Gdec.func.variables
    n_params, n_bytes = tl.count_parameters(generator_vars)
    megabytes = n_bytes / 1024 / 1024
    print('Generator Size: n_parameters = %d = %.2fMB' % (n_params, megabytes))

    # ---- step function ------------------------------------------------
    def run(**feeds):
        """Execute one generator step using ``feeds['lr']`` as the learning rate."""
        sess.run([step, summary], feed_dict={lr: feeds['lr']})

    return run