Example #1

# Assumed imports: cfg, model, loss, trainer, dataloader, utils and ops are
# project-local modules, and FLAGS is assumed to come from tf.app.flags.
import numpy as np
import tensorflow as tf

import cfg, dataloader, loss, model, ops, trainer, utils

def main():
    # get configuration
    print("Get configuration")
    TFLAGS = cfg.get_train_config(FLAGS.model_name)
    gen_model, gen_config, disc_model, disc_config = model.get_model(FLAGS.model_name, TFLAGS)
    # replace the default generator with the encoder-equipped variant
    gen_model = model.good_generator.GoodGeneratorEncoder(**gen_config)
    print("Common length: %d" % gen_model.common_length)
    TFLAGS['AE_weight'] = 1.0

    # Data Preparation
    # raw image is 0~255
    print("Get dataset")
    if TFLAGS['dataset_kind'] == "file":
        dataset = dataloader.CustomDataset(
            root_dir=TFLAGS['data_dir'],
            npy_dir=TFLAGS['npy_dir'],
            preproc_kind=TFLAGS['preprocess'],
            img_size=TFLAGS['input_shape'][:2],
            filter_data=TFLAGS['filter_data'],
            class_num=TFLAGS['c_len'],
            disturb=TFLAGS['disturb'],
            flip=TFLAGS['preproc_flip'])

    elif TFLAGS['dataset_kind'] == "numpy":
        train_data, test_data, train_label, test_label = utils.load_mnist_4d(TFLAGS['data_dir'])
        train_data = train_data.reshape([-1] + TFLAGS['input_shape'])

        # TODO: validation in training
        dataset = dataloader.NumpyArrayDataset(
            data_npy=train_data,
            label_npy=train_label,
            preproc_kind=TFLAGS['preprocess'],
            img_size=TFLAGS['input_shape'][:2],
            filter_data=TFLAGS['filter_data'],
            class_num=TFLAGS['c_len'],
            flip=TFLAGS['preproc_flip'])
    
    dl = dataloader.CustomDataLoader(dataset, batch_size=TFLAGS['batch_size'], num_threads=TFLAGS['data_threads'])

    # TF Input
    x_real = tf.placeholder(tf.float32, [None] + TFLAGS['input_shape'], name="x_real")
    s_real = tf.placeholder(tf.float32, [None] + TFLAGS['input_shape'], name='s_real')
    z_noise = tf.placeholder(tf.float32, [None, TFLAGS['z_len']], name="z_noise")
    
    if FLAGS.cgan:
        c_noise = tf.placeholder(tf.float32, [None, TFLAGS['c_len']], name="c_noise")
        c_label = tf.placeholder(tf.float32, [None, TFLAGS['c_len']], name="c_label")

    # control variables
    real_cls_weight = tf.placeholder(tf.float32, [], name="real_cls_weight")
    fake_cls_weight = tf.placeholder(tf.float32, [], name="fake_cls_weight")
    adv_weight = tf.placeholder(tf.float32, [], name="adv_weight")
    inc_length = tf.placeholder(tf.int32, [], name="inc_length")
    lr = tf.placeholder(tf.float32, [], name="lr")

    # build inference network

    # image_inputs : x_real, x_fake, x_trans
    # noise_input : noise, noise_rec, noise_trans

    # feat_image + image_feat
    # x_real -> trans_feat -> x_trans -> trans_x_trans_feat[REC] -> x_trans_trans_

    # feat_noise + noise_feat
    # x_real -> trans_feat -> noise_trans -> noise_x_trans_feat[REC] -> x_trans_fake

    with tf.variable_scope(gen_model.field_name):
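        # assumes FLAGS.cgan is set: c_noise below is only defined in that case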
        noise_input = tf.concat([z_noise, c_noise], axis=1)
        gen_model.image_input = x_real
        gen_model.noise_input = noise_input
        
        x_fake, x_trans, noise_rec, noise_trans = gen_model.build_inference()
        noise_feat = tf.identity(gen_model.noise_feat, "noise_feat")
        trans_feat = tf.identity(gen_model.image_feat, "trans_feat")
        
        gen_model.noise_input = noise_rec
        gen_model.image_input = x_fake

        x_fake_fake, x_fake_trans, noise_rec_rec, noise_x_fake_trans = gen_model.build_inference()
        noise_rec_feat = tf.identity(gen_model.noise_feat, "noise_rec_feat")
        trans_fake_feat = tf.identity(gen_model.image_feat, "trans_fake_feat")
        
        gen_model.noise_input = noise_trans
        gen_model.image_input = x_trans

        x_trans_fake, x_trans_trans, noise_trans_rec, noise_trans_trans = gen_model.build_inference()
        noise_trans_feat = tf.identity(gen_model.noise_feat, "noise_trans_feat")
        trans_trans_feat = tf.identity(gen_model.image_feat, "trans_trans_feat")

    # noise -> noise_feat -> x_fake | noise_rec
    # image -> trans_feat -> x_trans | noise_trans
    # noise_rec -> noise_rec_feat -> x_fake_fake | noise_rec_rec
    # x_fake -> trans_fake_feat -> x_fake_trans | noise_x_fake_trans
    # noise_trans -> noise_trans_feat -> x_trans_fake | noise_trans_rec    
    # x_trans -> trans_trans_feat -> x_trans_trans | noise_trans_trans

    # Image Feat Recs:
    # trans_feat == noise_trans_feat : feat_noise + noise_feat
    # trans_feat == trans_trans_feat : feat_image + image_feat
    # noise_feat == noise_rec_feat : feat_noise + noise_feat
    # noise_feat == trans_fake_feat : feat_image + image_feat

    # Noise Recs:
    # noise_input == noise_rec : noise_feat + feat_noise
    
    # Image Recs:
    # x_real -[REC]- x_trans : image_feat + feat_image

    disc_real_out, real_cls_logits = disc_model(x_real)[:2]
    disc_model.set_reuse()
    disc_fake_out, fake_cls_logits = disc_model(x_fake)[:2]
    disc_trans_out, trans_cls_logits = disc_model(x_trans)[:2]

    gen_model.cost = disc_model.cost = 0
    gen_model.sum_op = disc_model.sum_op = []
    inter_sum_op = []

    # Select loss builder and model trainer
    print("Build training graph")

    # Naive GAN
    loss.naive_ganloss.func_gen_loss(disc_fake_out, adv_weight, name="Gen", model=gen_model)
    loss.naive_ganloss.func_disc_fake_loss(disc_fake_out, adv_weight / 2, name="DiscFake", model=disc_model)
    loss.naive_ganloss.func_disc_fake_loss(disc_trans_out, adv_weight / 2, name="DiscTrans", model=disc_model)
    loss.naive_ganloss.func_disc_real_loss(disc_real_out, adv_weight, name="DiscGen", model=disc_model)

    # ACGAN
    real_cls_cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=real_cls_logits,
        labels=c_label), name="real_cls_reduce_mean")
    real_cls_cost_sum = tf.summary.scalar("real_cls", real_cls_cost)
    
    fake_cls_cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=fake_cls_logits,
        labels=c_noise), name='fake_cls_reduce_mean')
    fake_cls_cost_sum = tf.summary.scalar("fake_cls", fake_cls_cost)
    
    disc_model.cost += real_cls_cost * real_cls_weight + fake_cls_cost * fake_cls_weight
    disc_model.sum_op.extend([real_cls_cost_sum, fake_cls_cost_sum])

    gen_model.cost += fake_cls_cost * fake_cls_weight
    gen_model.sum_op.append(fake_cls_cost_sum)

    # DRAGAN gradient penalty
    disc_cost_, disc_sum_ = loss.dragan_loss.get_dragan_loss(disc_model, x_real, TFLAGS['gp_weight'])
    disc_model.cost += disc_cost_
    disc_model.sum_op.append(disc_sum_)

    gen_cost_, gen_sum_ = loss.naive_ganloss.func_gen_loss(disc_trans_out, adv_weight, name="GenTrans")
    gen_model.cost += gen_cost_

    # Image Feat Recs:
    # trans_feat == noise_trans_feat : feat_noise + noise_feat 1
    # trans_feat == trans_trans_feat : feat_image + image_feat 2
    # noise_feat == noise_rec_feat : feat_noise + noise_feat 1
    # noise_feat == trans_fake_feat : feat_image + image_feat 2
    trans_feat_cost1_, _ = loss.common_loss.reconstruction_loss(
        noise_trans_feat, trans_feat, TFLAGS['AE_weight'] / 4, name="RecTransFeat1")
    trans_feat_cost2_, _ = loss.common_loss.reconstruction_loss(
        trans_trans_feat, trans_feat, TFLAGS['AE_weight'] / 4, name="RecTransFeat2")
    trans_feat_cost_ = trans_feat_cost1_ + trans_feat_cost2_
    trans_feat_sum_ = tf.summary.scalar("RecTransFeat", trans_feat_cost_)

    noise_feat_cost1_, _ = loss.common_loss.reconstruction_loss(
        noise_rec_feat, noise_feat, TFLAGS['AE_weight'] / 4, name="RecNoiseFeat1")
    noise_feat_cost2_, _ = loss.common_loss.reconstruction_loss(
        trans_fake_feat, noise_feat, TFLAGS['AE_weight'] / 4, name="RecNoiseFeat2")
    noise_feat_cost_ = noise_feat_cost1_ + noise_feat_cost2_
    noise_feat_sum_ = tf.summary.scalar("RecNoiseFeat", noise_feat_cost_)

    # Noise Recs:
    # noise_input -[REC]- noise_rec : noise_feat + feat_noise
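    # the slice below penalizes only the last inc_length noise dimensions,
    # presumably so the reconstructed portion can be grown during training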
    noise_cost_, noise_sum_ = loss.common_loss.reconstruction_loss(
        noise_rec[:, -inc_length:], noise_input[:, -inc_length:], TFLAGS['AE_weight'], name="RecNoise")

    # Image Recs:
    # x_real -[REC]- x_trans : image_feat + feat_image
    rec_cost_, rec_sum_ = loss.common_loss.reconstruction_loss(
        x_trans, x_real, TFLAGS['AE_weight'], name="RecPix")

    gen_model.extra_sum_op = [gen_sum_, trans_feat_sum_, noise_feat_sum_, rec_sum_, noise_sum_]
    gen_model.extra_loss = gen_cost_ + trans_feat_cost_ + noise_feat_cost_ + rec_cost_ + noise_cost_

    gen_model.total_cost = tf.identity(gen_model.extra_loss) + tf.identity(gen_model.cost)

    gen_model.cost = tf.identity(gen_model.cost, "TotalGenCost")
    disc_model.cost = tf.identity(disc_model.cost, "TotalDiscCost")
    
    # total summary
    gen_model.sum_op.append(tf.summary.scalar("GenCost", gen_model.cost))
    disc_model.sum_op.append(tf.summary.scalar("DiscCost", disc_model.cost))

    # add interval summary
    edge_num = min(int(np.sqrt(TFLAGS['batch_size'])), 4)
    grid_x_fake = ops.get_grid_image_summary(x_fake, edge_num)
    inter_sum_op.append(tf.summary.image("generated image", grid_x_fake))
    grid_x_trans = ops.get_grid_image_summary(x_trans, edge_num)
    inter_sum_op.append(tf.summary.image("trans image", grid_x_trans))
    grid_x_real = ops.get_grid_image_summary(x_real, edge_num)
    inter_sum_op.append(tf.summary.image("real image", grid_x_real))

    # merge summary op
    gen_model.extra_sum_op = tf.summary.merge(gen_model.extra_sum_op)
    gen_model.sum_op = tf.summary.merge(gen_model.sum_op)
    disc_model.sum_op = tf.summary.merge(disc_model.sum_op)
    inter_sum_op = tf.summary.merge(inter_sum_op)

    print("=> Compute gradient")

    # get train op
    gen_model.get_trainable_variables()
    disc_model.get_trainable_variables()
    
    # UPDATE_OPS (e.g. batch-norm moving averages) must run together with the
    # train ops; skipping them silently breaks training
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        # normal GAN train op
        gen_model.train_op = tf.train.AdamOptimizer(
            learning_rate=lr,
            beta1=0.5,
            beta2=0.9).minimize(gen_model.cost, var_list=gen_model.vars)

        # train reconstruction branch
        var_list = gen_model.branch_vars['image_feat'] + gen_model.branch_vars['feat_noise']
        gen_model.extra_train_op_partial1 = tf.train.AdamOptimizer(
            learning_rate=lr/5,
            beta1=0.5,
            beta2=0.9).minimize(gen_model.extra_loss, var_list=var_list)

        var_list += gen_model.branch_vars['noise_feat'] + gen_model.branch_vars['feat_image']
        gen_model.extra_train_op_full = tf.train.AdamOptimizer(
            learning_rate=lr/5,
            beta1=0.5,
            beta2=0.9).minimize(gen_model.extra_loss, var_list=var_list)

        gen_model.full_train_op = tf.train.AdamOptimizer(
            learning_rate=lr/5,
            beta1=0.5,
            beta2=0.9).minimize(gen_model.total_cost, var_list=gen_model.vars)
        
        disc_model.train_op = tf.train.AdamOptimizer(
            learning_rate=lr,
            beta1=0.5,
            beta2=0.9).minimize(disc_model.cost, var_list=disc_model.vars)

    print("=> ##### Generator Variable #####")
    gen_model.print_trainble_vairables()
    print("=> ##### Discriminator Variable #####")
    disc_model.print_trainble_vairables()
    print("=> ##### All Variable #####")
    for v in tf.trainable_variables():
        print("%s" % v.name)

    ctrl_weight = {
        "real_cls": real_cls_weight,
        "fake_cls": fake_cls_weight,
        "adv": adv_weight,
        "lr": lr,
        "inc_length": inc_length
    }

    trainer_feed = {
        "gen_model": gen_model,
        "disc_model" : disc_model,
        "ctrl_weight": ctrl_weight,
        "dataloader": dl,
        "int_sum_op": inter_sum_op,
        "gpu_mem": TFLAGS['gpu_mem'],
        "FLAGS": FLAGS,
        "TFLAGS": TFLAGS
    }

    Trainer = trainer.base_gantrainer_rep.BaseGANTrainer
    trainer_feed.update({
        "inputs": [z_noise, c_noise, x_real, c_label],
    })
    
    ModelTrainer = Trainer(**trainer_feed)

    command_controller = trainer.cmd_ctrl.CMDControl(ModelTrainer)
    command_controller.start_thread()

    ModelTrainer.init_training()

    ModelTrainer.train()
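
# For reference: loss.common_loss.reconstruction_loss is not shown in this
# example. A minimal sketch consistent with how it is called above (returning
# a weighted cost tensor plus a scalar summary) might be:
def reconstruction_loss(x, target, weight, name="Rec"):
    """Weighted mean-squared reconstruction error and its scalar summary."""
    cost = tf.multiply(weight, tf.reduce_mean(tf.square(x - target)), name=name)
    summary = tf.summary.scalar(name, cost)
    return cost, summary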
Example #2

# Assumed module-level context: DTYPE, N_SAMPLE, NGT, GT, global_iter,
# gen_data, logfile, base_dir and the rec_loss_* lists are defined elsewhere
# in the script, as are the plot/hist helpers and the get_vars/build_net
# functions (see the sketch after this example).
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from utils import load_mnist_4d  # assumed location of the MNIST loader

def main_convex():
    print("Convex game")

    lr = tf.placeholder(DTYPE, shape=[], name="learning_rate")
    raw_real_data = tf.placeholder(tf.float64, [None, 784], "real_data")
    real_data = tf.cast(raw_real_data, DTYPE)
    raw_fake_data = tf.placeholder(tf.float64, [None, 784], "fake_data")
    fake_data = tf.cast(raw_fake_data, DTYPE)

    trX, teX, trY, teY = load_mnist_4d("data/MNIST")
    del teX, teY
    trX = trX.reshape(trX.shape[0], -1) / 127.5 - 1

    # build net
    with tf.variable_scope("disc", reuse=tf.AUTO_REUSE):
        vars = get_vars([784, 1])
    flt_vars = []
    for vp in vars:
        flt_vars.extend(vp)
    disc_real = build_net(real_data, vars)
    disc_fake = build_net(fake_data, vars)

    # build loss
    raw_gen_cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=disc_fake, labels=tf.ones_like(disc_fake)),
                                  name="raw_gen_cost")

    raw_disc_cost_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=disc_fake, labels=tf.zeros_like(disc_fake)),
        name="raw_disc_cost_fake")

    raw_disc_cost_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=disc_real, labels=tf.ones_like(disc_real)),
        name="raw_disc_cost_real")

    # sigmoid cross-entropy (non-saturating GAN) losses, kept for reference:
    #disc_cost = raw_disc_cost_real + raw_disc_cost_fake
    #gen_cost = raw_gen_cost

    # the active objectives are the WGAN-style critic/generator losses
    disc_cost = tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real)
    gen_cost = -tf.reduce_mean(disc_fake)

    # compute gradient

    grad_input = tf.gradients(gen_cost, raw_fake_data)[0]
    grad_disc = tf.gradients(disc_cost * 1e4, flt_vars)
    # note: grad_disc is scaled up by 1e4, presumably to keep the small
    # gradients visible in the histograms below

    # update op

    #update_op = get_update_op(flt_vars, grad_disc, lr)
    #train_disc_op = tf.train.AdamOptimizer(lr, 0, 0).minimize(disc_cost, var_list=flt_vars)
    train_disc_op = tf.train.GradientDescentOptimizer(lr).minimize(
        disc_cost, var_list=flt_vars)

    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())

    ## initial status

    def basic_statistic(ifhist=False):
        global rec_loss_fake, rec_loss_real, rec_loss_gen

        fetches = [disc_cost, gen_cost]
        sample_disc_cost, sample_gen_cost = sess.run(
            fetches, {
                raw_real_data: trX[:N_SAMPLE],
                raw_fake_data: gen_data
            })
        rec_loss_fake.append(sample_disc_cost)  # list reused for the critic cost
        rec_loss_gen.append(sample_gen_cost)

        print("iter %d, disc_cost %.3f, gen_cost %.3f" %
              (global_iter, sample_disc_cost, sample_gen_cost))
        logfile.write("iter %d, disc_cost %.3f, gen_cost %.3f" %
                      (global_iter, sample_disc_cost, sample_gen_cost))
        """
        fetches = [disc_real, disc_fake, raw_disc_cost_real, raw_disc_cost_fake, raw_gen_cost]
        if ifhist:
            fetches.extend([grad_input * 1e4, grad_disc])
            sample_disc_real, sample_disc_fake, sample_real_cost, sample_fake_cost, sample_gen_cost, sample_grad_input, sample_grad_disc = sess.run(fetches, {raw_real_data : trX[:N_SAMPLE], raw_fake_data : gen_data})
            hist(sample_disc_real, base_dir + "hist_real_%d"        %   global_iter)
            hist(sample_disc_fake, base_dir + "hist_fake_%d"        %   global_iter)
            hist(flatten_list(sample_grad_input),base_dir + "hist_grad_input_%d"%global_iter, (-1, 1))
            hist(flatten_list(sample_grad_disc),base_dir + "hist_grad_disc_%d"%global_iter, (-1, 1))

        else:
            sample_disc_real, sample_disc_fake, sample_real_cost, sample_fake_cost, sample_gen_cost = sess.run(fetches, {raw_real_data : trX[:N_SAMPLE], raw_fake_data : gen_data})

        rec_loss_real.append(sample_real_cost)
        rec_loss_fake.append(sample_fake_cost)
        rec_loss_gen.append(sample_gen_cost)

        print("iter %d, real_cost %.3f, fake_cost %.3f, gen_cost %.3f" % (global_iter, sample_real_cost, sample_fake_cost, sample_gen_cost))
        logfile.write("iter %d, real_cost %.3f, fake_cost %.3f, gen_cost %.3f\n" % (global_iter, sample_real_cost, sample_fake_cost, sample_gen_cost))    
        """

    def train_op(learning_rate=0.001):
        global global_iter
        sess.run(
            train_disc_op, {
                raw_real_data: trX[:N_SAMPLE],
                raw_fake_data: gen_data,
                lr: learning_rate
            })
        global_iter += 1

    def update_gen_data(learning_rate=1):
        global gen_data
        rg = 0
        for i in range(NGT):
            g = sess.run([grad_input],
                         {raw_fake_data: gen_data})[0] * learning_rate

            gen_data -= g

            rg = g[0, :] + rg

        return rg

    def one_iter(learning_rate=0.01, do_stats=False):
        train_op(learning_rate)
        basic_statistic(False)
        g = update_gen_data(GT)  #.sum(0).reshape(28, 28)
        logfile.write("Iter%d:\n" % global_iter)
        logfile.write("grad input %.4f %.4f %.4f\n" %
                      (g.max(), g.min(), g.mean()))
        logfile.write("gen data %.4f %.4f %.4f\n" %
                      (gen_data[:512].max(), gen_data[:512].min(),
                       gen_data[:512].mean()))

        if do_stats:
            basic_statistic(False)
            #save_batch_img(trX.reshape(-1, 28, 28, 1), base_dir + "real_%d.png" % global_iter)
            #save_batch_img(gen_data[:16].reshape(-1, 28, 28, 1), base_dir + "sample_%d.png" % global_iter)

        logfile.flush()

    basic_statistic()
    try:
        for i in range(500):
            one_iter(0.001, True)
            if i % 10 == 0:
                var_np = sess.run(flt_vars)
                np.save("test/expr/weight/%05d" % i, var_np)

            if i % 100 == 0:
                plot(rec_loss_fake, "test/expr/loss_fake_%d" % NGT)
                plot(rec_loss_real, "test/expr/loss_real_%d" % NGT)
                plot(rec_loss_gen, "test/expr/loss_gen_%d" % NGT)

    except KeyboardInterrupt:
        print("Get Interrupt")
        plt.close()

    plot(rec_loss_fake, "test/expr/loss_fake_%d" % NGT)
    plot(rec_loss_real, "test/expr/loss_real_%d" % NGT)
    plot(rec_loss_gen, "test/expr/loss_gen_%d" % NGT)
    logfile.close()
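
# For reference: get_vars and build_net are assumed helpers for example #2.
# A minimal sketch consistent with their usage above (get_vars returns (W, b)
# pairs for the given layer sizes; build_net applies them as an MLP) might be:
def get_vars(layer_sizes):
    """Return a list of (W, b) variable pairs, one per consecutive size pair."""
    pairs = []
    for i in range(len(layer_sizes) - 1):
        W = tf.get_variable("W%d" % i, [layer_sizes[i], layer_sizes[i + 1]],
                            dtype=DTYPE,
                            initializer=tf.random_normal_initializer(stddev=0.02))
        b = tf.get_variable("b%d" % i, [layer_sizes[i + 1]], dtype=DTYPE,
                            initializer=tf.zeros_initializer())
        pairs.append((W, b))
    return pairs


def build_net(x, var_pairs):
    """Apply the (W, b) stack with ReLU between layers and a linear output."""
    h = x
    for i, (W, b) in enumerate(var_pairs):
        h = tf.matmul(h, W) + b
        if i < len(var_pairs) - 1:
            h = tf.nn.relu(h)
    return h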
Example #3

# Imports are the same as in example #1.
def main():
    # get configuration
    print("Get configuration")
    TFLAGS = cfg.get_train_config(FLAGS.model_name)

    gen_config = cfg.good_generator.deepmodel_gen({})
    disc_config = cfg.good_generator.deepmodel_disc({})
    gen_model = model.good_generator.DeepGenerator(**gen_config)
    disc_model = model.good_generator.DeepDiscriminator(**disc_config)

    # Data Preparation
    # raw image is 0~255
    print("Get dataset")
    if TFLAGS['dataset_kind'] == "file":
        dataset = dataloader.CustomDataset(root_dir=TFLAGS['data_dir'],
                                           npy_dir=TFLAGS['npy_dir'],
                                           preproc_kind=TFLAGS['preprocess'],
                                           img_size=TFLAGS['input_shape'][:2],
                                           filter_data=TFLAGS['filter_data'],
                                           class_num=TFLAGS['c_len'],
                                           disturb=TFLAGS['disturb'],
                                           flip=TFLAGS['preproc_flip'])

    elif TFLAGS['dataset_kind'] == "numpy":
        train_data, test_data, train_label, test_label = utils.load_mnist_4d(
            TFLAGS['data_dir'])
        train_data = train_data.reshape([-1] + TFLAGS['input_shape'])

        # TODO: validation in training
        dataset = dataloader.NumpyArrayDataset(
            data_npy=train_data,
            label_npy=train_label,
            preproc_kind=TFLAGS['preprocess'],
            img_size=TFLAGS['input_shape'][:2],
            filter_data=TFLAGS['filter_data'],
            class_num=TFLAGS['c_len'],
            flip=TFLAGS['preproc_flip'])

    elif TFLAGS['dataset_kind'] == "fuel":
        dataset = dataloader.FuelDataset(hdfname=TFLAGS['data_dir'],
                                         npy_dir=TFLAGS['npy_dir'],
                                         preproc_kind=TFLAGS['preprocess'],
                                         img_size=TFLAGS['input_shape'][:2],
                                         filter_data=TFLAGS['filter_data'],
                                         class_num=TFLAGS['c_len'],
                                         flip=TFLAGS['preproc_flip'])

    dl = dataloader.CustomDataLoader(dataset,
                                     batch_size=TFLAGS['batch_size'],
                                     num_threads=TFLAGS['data_threads'])

    # TF Input
    x_real = tf.placeholder(tf.float32, [None] + TFLAGS['input_shape'],
                            name="x_real")
    s_real = tf.placeholder(tf.float32, [None] + TFLAGS['input_shape'],
                            name='s_real')
    z_noise = tf.placeholder(tf.float32, [None, TFLAGS['z_len']],
                             name="z_noise")

    if FLAGS.cgan:
        c_noise = tf.placeholder(tf.float32, [None, TFLAGS['c_len']],
                                 name="c_noise")
        c_label = tf.placeholder(tf.float32, [None, TFLAGS['c_len']],
                                 name="c_label")

    # Select loss builder and model trainer
    print("Build training graph")
    if TFLAGS['gan_loss'].find("naive") > -1:
        loss_builder_ = loss.naive_ganloss.NaiveGANLoss
        Trainer = trainer.base_gantrainer.BaseGANTrainer
    elif TFLAGS['gan_loss'].find("wass") > -1:
        loss_builder_ = loss.wass_ganloss.WGANLoss
        Trainer = trainer.base_gantrainer.BaseGANTrainer
    elif TFLAGS['gan_loss'].find("dra") > -1:
        loss_builder_ = loss.dragan_loss.DRAGANLoss
        Trainer = trainer.base_gantrainer.BaseGANTrainer
    elif TFLAGS['gan_loss'].find("ian") > -1:
        loss_builder_ = loss.ian_loss.IANLoss
        Trainer = trainer.ian_trainer.IANTrainer

    if FLAGS.cgan:
        loss_builder = loss_builder_(gen_model=gen_model,
                                     disc_model=disc_model,
                                     gen_inputs=[z_noise, c_noise],
                                     real_inputs=[x_real, c_label],
                                     has_ac=True,
                                     **TFLAGS)
    else:
        loss_builder = loss_builder_(gen_model=gen_model,
                                     disc_model=disc_model,
                                     gen_inputs=[z_noise],
                                     real_inputs=[x_real],
                                     has_ac=False,
                                     **TFLAGS)

    loss_builder.build()
    int_sum_op = tf.summary.merge(loss_builder.inter_sum_op)
    loss_builder.get_trainable_variables()

    print("=> ##### Generator Variable #####")
    gen_model.print_trainble_vairables()
    print("=> ##### Discriminator Variable #####")
    disc_model.print_trainble_vairables()
    print("=> ##### All Variable #####")
    for v in tf.trainable_variables():
        print("%s" % v.name)

    if FLAGS.cgan:
        inputs = [z_noise, c_noise, x_real, c_label]
        if TFLAGS['gan_loss'].find("ian") > -1:
            ctrl_weight = {
                "real_cls": loss_builder.real_cls_weight,
                "fake_cls": loss_builder.fake_cls_weight,
                "rec_weight": loss_builder.rec_weight,
                "adv": loss_builder.adv_weight,
                "recadv_weight": loss_builder.recadv_weight
            }
        else:
            ctrl_weight = {
                "real_cls": loss_builder.real_cls_weight,
                "fake_cls": loss_builder.fake_cls_weight,
                "adv": loss_builder.adv_weight
            }
    else:
        inputs = [z_noise, x_real]
        ctrl_weight = {
            "real_cls": loss_builder.real_cls_weight,
            "fake_cls": loss_builder.fake_cls_weight,
            "adv": loss_builder.adv_weight
        }

    ModelTrainer = Trainer(gen_model=gen_model,
                           disc_model=disc_model,
                           inputs=inputs,
                           dataloader=dl,
                           int_sum_op=int_sum_op,
                           ctrl_weight=ctrl_weight,
                           gpu_mem=TFLAGS['gpu_mem'],
                           FLAGS=FLAGS,
                           TFLAGS=TFLAGS)

    command_controller = trainer.cmd_ctrl.CMDControl(ModelTrainer)
    command_controller.start_thread()

    ModelTrainer.init_training()
    ModelTrainer.train()
Example #4

# Assumed module-level context: DTYPE, MODEL_TYPENAME and the load_mnist_4d,
# fltlist2net (see the sketch after this example) and visualize_fc helpers
# are defined elsewhere.
import numpy as np
import tensorflow as tf

def analysis_snap(model_typename, model_index):
    np_vars = np.load(model_typename % model_index)

    visualize_fc(np_vars[0])


model_list = [300, 310]
net_out = []

# set up tf
lr = tf.placeholder(DTYPE, shape=[], name="learning_rate")
X = tf.placeholder(DTYPE, [None, 784], "real_data")

trX, teX, trY, teY = load_mnist_4d("data/MNIST")
del teX, teY
trX = trX.reshape(trX.shape[0], -1) / 127.5 - 1

for idx in model_list:
    print("Model %d" % idx)

    # load model weight
    np_vars = np.load(MODEL_TYPENAME % idx)
    # reconstruct the model
    net_out.append(fltlist2net(X, np_vars))

    #visualize_fc(np_vars[0])

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
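
# For reference: fltlist2net is assumed to rebuild the network from the flat
# [W0, b0, W1, b1, ...] list of numpy weights saved in example #2. A minimal
# sketch consistent with that layout (the actual helper may differ):
def fltlist2net(x, flat_vars):
    """Rebuild an MLP from alternating weight/bias numpy arrays."""
    h = x
    n_layers = len(flat_vars) // 2
    for i in range(n_layers):
        W = tf.constant(flat_vars[2 * i], dtype=DTYPE)
        b = tf.constant(flat_vars[2 * i + 1], dtype=DTYPE)
        h = tf.matmul(h, W) + b
        if i < n_layers - 1:
            h = tf.nn.relu(h)
    return h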