Example #1
def G_train_graph():
    # ======================================
    # =                 graph              =
    # ======================================

    # placeholders & inputs
    lr = tf.placeholder(dtype=tf.float32, shape=[])

    xa, a = train_iter.get_next()
    b = tf.random_shuffle(a)  # shuffle labels across the batch to form target attributes
    a_ = a * 2 - 1  # map binary labels {0, 1} to {-1, 1}
    b_ = b * 2 - 1

    # generate
    z = Genc(xa)       # encode the input image
    xa_ = Gdec(z, a_)  # decode with the source attributes (reconstruction)
    xb_ = Gdec(z, b_)  # decode with the target attributes (translation)

    # discriminate
    xb__logit_gan, xb__logit_att = D(xb_)

    # generator losses
    xb__loss_gan = g_loss_fn(xb__logit_gan)                           # adversarial loss
    xb__loss_att = tf.losses.sigmoid_cross_entropy(b, xb__logit_att)  # attribute-classification loss
    xa__loss_rec = tf.losses.absolute_difference(xa, xa_)             # L1 reconstruction loss
    reg_loss = tf.reduce_sum(Genc.func.reg_losses + Gdec.func.reg_losses)

    loss = (xb__loss_gan +
            xb__loss_att * args.g_attribute_loss_weight +
            xa__loss_rec * args.g_reconstruction_loss_weight +
            reg_loss)

    # optim
    step_cnt, _ = tl.counter()  # global-step counter
    step = tf.train.AdamOptimizer(lr, beta1=args.beta_1).minimize(
        loss, global_step=step_cnt,
        var_list=Genc.func.trainable_variables + Gdec.func.trainable_variables)

    # summary
    with tf.contrib.summary.create_file_writer('./output/%s/summaries/G' % args.experiment_name).as_default(),\
            tf.contrib.summary.record_summaries_every_n_global_steps(10, global_step=step_cnt):
        summary = tl.summary_v2({
            'xb__loss_gan': xb__loss_gan,
            'xb__loss_att': xb__loss_att,
            'xa__loss_rec': xa__loss_rec,
            'reg_loss': reg_loss
        }, step=step_cnt, name='G')

    # ======================================
    # =           generator size           =
    # ======================================

    n_params, n_bytes = tl.count_parameters(Genc.func.variables + Gdec.func.variables)
    print('Generator Size: n_parameters = %d = %.2fMB' % (n_params, n_bytes / 1024 / 1024))

    # ======================================
    # =            run function            =
    # ======================================

    def run(**pl_ipts):
        sess.run([step, summary], feed_dict={lr: pl_ipts['lr']})

    return run
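A minimal sketch of how the returned closure is typically driven, assuming sess, train_iter, and the graph above are already set up; n_epochs, it_per_epoch, and base_lr are hypothetical names standing in for the script's actual configuration:

# hypothetical training loop; n_epochs, it_per_epoch, base_lr are placeholders
G_train_step = G_train_graph()
sess.run(tf.global_variables_initializer())
for ep in range(n_epochs):
    lr_ep = base_lr * (0.1 ** (ep // 100))  # an example step decay, not the repo's schedule
    for _ in range(it_per_epoch):
        G_train_step(lr=lr_ep)  # feeds the lr placeholder, runs one optimizer + summary step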
Example #2
def G_train_graph():
    # ======================================
    # =                 graph              =
    # ======================================

    # placeholders & inputs
    lr = tf.placeholder(dtype=tf.float32, shape=[])

    xa, a = train_iter.get_next()
    b = tf.random_shuffle(a)  # shuffle labels across the batch to form target attributes
    a_ = a * 2 - 1  # map binary labels {0, 1} to {-1, 1}
    b_ = b * 2 - 1

    # generate: G returns the edited image xb plus attention masks (ms, ms_multi)
    xb, _, ms, ms_multi = G(xa, b_ - a_)

    # discriminate
    xb_logit_gan, xb_logit_att = D(xb)

    # generator losses
    xb_loss_gan = g_loss_fn(xb_logit_gan)                           # adversarial loss
    xb_loss_att = tf.losses.sigmoid_cross_entropy(b, xb_logit_att)  # attribute-classification loss
    # 'spasity' [sic] is kept to match the repo's argument names; this is a sparsity penalty on the masks
    spasity_loss = tf.reduce_sum([
        tf.reduce_mean(m) * w
        for m, w in zip(ms, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0])  # per-scale weights
    ])
    full_overlap_mask_pair_loss, non_overlap_mask_pair_loss = module.overlap_loss_fn(
        ms_multi, args.att_names)
    reg_loss = tf.reduce_sum(G.func.reg_losses)

    loss = (xb_loss_gan +
            xb_loss_att * args.g_attribute_loss_weight +
            spasity_loss * args.g_spasity_loss_weight +
            full_overlap_mask_pair_loss * args.g_full_overlap_mask_pair_loss_weight +
            non_overlap_mask_pair_loss * args.g_non_overlap_mask_pair_loss_weight +
            reg_loss)

    # optim
    step_cnt, _ = tl.counter()
    step = tf.train.AdamOptimizer(lr, beta1=args.beta_1).minimize(
        loss, global_step=step_cnt, var_list=G.func.trainable_variables)

    # summary
    with tf.contrib.summary.create_file_writer('./output/%s/summaries/G' % args.experiment_name).as_default(),\
            tf.contrib.summary.record_summaries_every_n_global_steps(10, global_step=step_cnt):
        summary = tl.summary_v2(
            {
                'xb_loss_gan': xb_loss_gan,
                'xb_loss_att': xb_loss_att,
                'spasity_loss': spasity_loss,
                'full_overlap_mask_pair_loss': full_overlap_mask_pair_loss,
                'non_overlap_mask_pair_loss': non_overlap_mask_pair_loss,
                'reg_loss': reg_loss
            },
            step=step_cnt,
            name='G')

    # ======================================
    # =           generator size           =
    # ======================================

    n_params, n_bytes = tl.count_parameters(G.func.variables)
    print('Generator Size: n_parameters = %d = %.2fMB' %
          (n_params, n_bytes / 1024 / 1024))

    # ======================================
    # =            run function            =
    # ======================================

    def run(**pl_ipts):
        sess.run([step, summary], feed_dict={lr: pl_ipts['lr']})

    return run
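For intuition, the sparsity term above is just the weighted mean activation of each attention mask, so minimizing it pushes the generator toward masks that edit as few pixels as possible. A self-contained sketch of the same computation, with made-up mask shapes for illustration only:

import tensorflow as tf  # TF 1.x, as in the examples above

# two fake attention masks in [0, 1]; shapes are arbitrary
ms = [tf.sigmoid(tf.random.normal([1, 32, 32, 1])),
      tf.sigmoid(tf.random.normal([1, 64, 64, 1]))]
weights = [1.0, 1.0]

# weighted mean activation per mask, summed: a lower value means sparser masks
sparsity_loss = tf.reduce_sum([tf.reduce_mean(m) * w for m, w in zip(ms, weights)])

with tf.Session() as s:
    print(s.run(sparsity_loss))  # ~1.0 for random masks, since each mean is ~0.5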
Example #3
def G_train_graph():
    # ======================================
    # =               graph                =
    # ======================================

    # placeholders & inputs
    lr = tf.placeholder(dtype=tf.float32, shape=[])
    zs = [tf.random.normal([args.batch_size, z_dim]) for z_dim in args.z_dims]  # per-level latent codes
    eps = tf.random.normal([args.batch_size, args.eps_dim])                     # extra noise input

    # counter
    step_cnt, _ = tl.counter()

    # optimizer
    optimizer = tf.train.AdamOptimizer(lr, beta1=args.beta_1)

    def graph_per_gpu(zs, eps):  # one tower: forward pass, losses, and gradients on a single GPU
        # generate
        x_f = G(zs, eps)

        # discriminate
        x_f_logit = D(x_f)

        # loss
        x_f_loss = g_loss_fn(x_f_logit)
        orth_loss = tf.reduce_sum(
            tl.tensors_filter(G.func.reg_losses, 'orthogonal_regularizer'))
        reg_loss = tf.reduce_sum(
            tl.tensors_filter(G.func.reg_losses, 'l2_regularizer'))

        loss = (x_f_loss * args.g_loss_weight_x_gan +
                orth_loss * args.g_loss_weight_orth_loss +
                reg_loss * args.weight_decay)

        # optim
        grads = optimizer.compute_gradients(
            loss, var_list=G.func.trainable_variables)

        return grads, x_f_loss, orth_loss, reg_loss

    # split the inputs across GPUs, build one tower per GPU, and gather the per-tower results
    split_grads, split_x_f_loss, split_orth_loss, split_reg_loss = zip(
        *tl.parellel_run(tl.gpus(), graph_per_gpu,
                         tl.split_nest((zs, eps), len(tl.gpus()))))
    # single-CPU debugging variant:
    # split_grads, split_x_f_loss, split_orth_loss, split_reg_loss = zip(*tl.parellel_run(['cpu:0'], graph_per_gpu, tl.split_nest((zs, eps), 1)))
    grads = tl.average_gradients(split_grads)
    x_f_loss, orth_loss, reg_loss = [
        tf.reduce_mean(t)
        for t in [split_x_f_loss, split_orth_loss, split_reg_loss]
    ]

    step = optimizer.apply_gradients(grads, global_step=step_cnt)

    # moving average: update the EMA shadow weights right after each optimizer step
    with tf.control_dependencies([step]):
        step = G_ema.apply(G.func.trainable_variables)

    # summary
    summary_dict = {
        'x_f_loss': x_f_loss,
        'orth_loss': orth_loss,
        'reg_loss': reg_loss
    }
    summary_dict.update({
        'L_%d' % i: t
        for i, t in enumerate(tl.tensors_filter(G.func.variables, 'L'))
    })
    summary_loss = tl.create_summary_statistic_v2(summary_dict,
                                                  './output/%s/summaries/G' %
                                                  args.experiment_name,
                                                  step=step_cnt,
                                                  n_steps_per_record=10,
                                                  name='G_loss')

    summary_image = tl.create_summary_image_v2(
        {
            'orth_U_%d' % i: t[None, :, :, None]
            for i, t in enumerate(tf.get_collection('orth', G.func.scope +
                                                    '/'))
        },
        './output/%s/summaries/G' % args.experiment_name,
        step=step_cnt,
        n_steps_per_record=10,
        name='G_image')

    # ======================================
    # =             model size             =
    # ======================================

    n_params, n_bytes = tl.count_parameters(G.func.variables)
    print('Model Size: n_parameters = %d = %.2fMB' %
          (n_params, n_bytes / 1024 / 1024))

    # ======================================
    # =            run function            =
    # ======================================

    def run(**pl_ipts):
        sess.run([step, summary_loss, summary_image],
                 feed_dict={lr: pl_ipts['lr']})

    return run
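tl.average_gradients itself is not shown in these examples; a common implementation of the tower-gradient averaging it presumably performs is the standard TF 1.x multi-GPU pattern sketched below (an assumption about the library, not its actual code):

import tensorflow as tf  # TF 1.x

def average_gradients(tower_grads):
    # tower_grads: one [(grad, var), ...] list per GPU tower, all over the same variables
    averaged = []
    for grads_and_var in zip(*tower_grads):  # regroup: one tuple per variable
        grads = [g for g, _ in grads_and_var if g is not None]
        grad = tf.reduce_mean(tf.stack(grads, axis=0), axis=0)
        _, var = grads_and_var[0]  # the variable itself is shared across towers
        averaged.append((grad, var))
    return averaged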