Ejemplo n.º 1
0
def get_interval_summary(data1, data2, batch_size):
    """Return one merged summary op showing generated and real image grids.

    Args:
        data1: batch of generated images, tagged "generated".
        data2: batch of real images, tagged "real".
        batch_size: number of images per batch; the grid edge is
            floor(sqrt(batch_size)), capped at 4.

    Returns:
        The op produced by ``tf.summary.merge`` over the two image summaries.
    """
    # At most a 4x4 grid is logged, even for large batches.
    edge = min(int(np.sqrt(batch_size)), 4)
    tagged_grids = [
        ("generated", ops.get_grid_image_summary(data1, edge)),
        ("real", ops.get_grid_image_summary(data2, edge)),
    ]
    # Values are shifted by +1 before logging (presumably mapping a [-1, 1]
    # image range to non-negative values; confirm against ops).
    return tf.summary.merge(
        [tf.summary.image(tag, grid + 1) for tag, grid in tagged_grids])
Ejemplo n.º 2
0
def _build_rep_dataset(TFLAGS):
    """Create the training dataset selected by ``TFLAGS['dataset_kind']``.

    Args:
        TFLAGS: training-config dict returned by ``cfg.get_train_config``.

    Returns:
        A ``dataloader.CustomDataset`` for ``"file"``, or a
        ``dataloader.NumpyArrayDataset`` for ``"numpy"``.

    Raises:
        ValueError: for any other ``dataset_kind``.
    """
    # raw image is 0~255
    kind = TFLAGS['dataset_kind']
    if kind == "file":
        return dataloader.CustomDataset(
            root_dir=TFLAGS['data_dir'],
            npy_dir=TFLAGS['npy_dir'],
            preproc_kind=TFLAGS['preprocess'],
            img_size=TFLAGS['input_shape'][:2],
            filter_data=TFLAGS['filter_data'],
            class_num=TFLAGS['c_len'],
            disturb=TFLAGS['disturb'],
            flip=TFLAGS['preproc_flip'])

    if kind == "numpy":
        # Only the training split is used.
        # TODO: validation in training
        train_data, _, train_label, _ = utils.load_mnist_4d(TFLAGS['data_dir'])
        train_data = train_data.reshape([-1] + TFLAGS['input_shape'])
        return dataloader.NumpyArrayDataset(
            data_npy=train_data,
            label_npy=train_label,
            preproc_kind=TFLAGS['preprocess'],
            img_size=TFLAGS['input_shape'][:2],
            filter_data=TFLAGS['filter_data'],
            class_num=TFLAGS['c_len'],
            flip=TFLAGS['preproc_flip'])

    raise ValueError(
        "Unknown TFLAGS['dataset_kind']: %r (expected 'file' or 'numpy')"
        % (kind,))


def main():
    """Build and train the encoder-augmented conditional GAN.

    Steps:
        - Load the config for ``FLAGS.model_name`` and build the data pipeline.
        - Run ``GoodGeneratorEncoder`` three times:
          1. on (noise, real image);
          2. on its own outputs (noise_rec, x_fake);
          3. on its translated outputs (noise_trans, x_trans).
        - Add adversarial, ACGAN-classifier, DRAGAN and reconstruction losses.
        - Create the Adam train ops.
        - Hand everything to ``trainer.base_gantrainer_rep.BaseGANTrainer``.

    Raises:
        ValueError: if ``FLAGS.cgan`` is false (the graph always needs the
            ``c_noise``/``c_label`` inputs), or if the configured
            ``dataset_kind`` is not supported.
    """
    # The graph below unconditionally concatenates class noise into the
    # generator input and trains an ACGAN classifier, so it cannot run
    # without conditioning. Fail fast instead of with a NameError later.
    if not FLAGS.cgan:
        raise ValueError(
            "This model requires FLAGS.cgan to be set: the graph always uses "
            "class-noise and class-label inputs.")

    # get configuration
    print("Get configuration")
    TFLAGS = cfg.get_train_config(FLAGS.model_name)
    gen_model, gen_config, disc_model, disc_config = model.get_model(FLAGS.model_name, TFLAGS)
    # The configured generator is replaced by the encoder variant; only its
    # constructor arguments (gen_config) are reused.
    gen_model = model.good_generator.GoodGeneratorEncoder(**gen_config)
    print("Common length: %d" % gen_model.common_length)
    TFLAGS['AE_weight'] = 1.0

    # Data Preparation
    print("Get dataset")
    dataset = _build_rep_dataset(TFLAGS)
    dl = dataloader.CustomDataLoader(dataset, batch_size=TFLAGS['batch_size'], num_threads=TFLAGS['data_threads'])

    # TF Input
    x_real = tf.placeholder(tf.float32, [None] + TFLAGS['input_shape'], name="x_real")
    # NOTE(review): s_real is not used in this function; kept so the graph
    # still contains the 's_real' placeholder.
    s_real = tf.placeholder(tf.float32, [None] + TFLAGS['input_shape'], name='s_real')
    z_noise = tf.placeholder(tf.float32, [None, TFLAGS['z_len']], name="z_noise")
    c_noise = tf.placeholder(tf.float32, [None, TFLAGS['c_len']], name="c_noise")
    c_label = tf.placeholder(tf.float32, [None, TFLAGS['c_len']], name="c_label")

    # control variables (fed by the trainer through ctrl_weight)
    real_cls_weight = tf.placeholder(tf.float32, [], name="real_cls_weight")
    fake_cls_weight = tf.placeholder(tf.float32, [], name="fake_cls_weight")
    adv_weight = tf.placeholder(tf.float32, [], name="adv_weight")
    inc_length = tf.placeholder(tf.int32, [], name="inc_length")
    lr = tf.placeholder(tf.float32, [], name="lr")

    # build inference network
    #
    # noise -> noise_feat -> x_fake | noise_rec
    # image -> trans_feat -> x_trans | noise_trans
    # noise_rec -> noise_rec_feat -> x_fake_fake | noise_rec_rec
    # x_fake -> trans_fake_feat -> x_fake_trans | noise_fake_trans
    # noise_trans -> noise_trans_feat -> x_trans_fake | noise_trans_rec
    # x_trans -> trans_trans_feat -> x_trans_trans | noise_trans_trans
    with tf.variable_scope(gen_model.field_name):
        noise_input = tf.concat([z_noise, c_noise], axis=1)
        gen_model.image_input = x_real
        gen_model.noise_input = noise_input

        x_fake, x_trans, noise_rec, noise_trans = gen_model.build_inference()
        noise_feat = tf.identity(gen_model.noise_feat, "noise_feat")
        trans_feat = tf.identity(gen_model.image_feat, "trans_feat")

        gen_model.noise_input = noise_rec
        gen_model.image_input = x_fake

        x_fake_fake, x_fake_trans, noise_rec_rec, noise_x_fake_trans = gen_model.build_inference()
        noise_rec_feat = tf.identity(gen_model.noise_feat, "noise_rec_feat")
        trans_fake_feat = tf.identity(gen_model.image_feat, "trans_fake_feat")

        gen_model.noise_input = noise_trans
        gen_model.image_input = x_trans

        x_trans_fake, x_trans_trans, noise_trans_rec, noise_trans_trans = gen_model.build_inference()
        noise_trans_feat = tf.identity(gen_model.noise_feat, "noise_trans_feat")
        trans_trans_feat = tf.identity(gen_model.image_feat, "trans_trans_feat")

    # The first call creates the discriminator variables; later calls reuse them.
    disc_real_out, real_cls_logits = disc_model(x_real)[:2]
    disc_model.set_reuse()
    disc_fake_out, fake_cls_logits = disc_model(x_fake)[:2]
    disc_trans_out, _ = disc_model(x_trans)[:2]

    gen_model.cost = 0
    disc_model.cost = 0
    # Two distinct lists: a chained `a = b = []` would make both models share
    # (and mutate) one list, mixing generator and discriminator summaries.
    gen_model.sum_op = []
    disc_model.sum_op = []
    inter_sum_op = []

    # Select loss builder and model trainer
    print("Build training graph")

    # Naive GAN
    loss.naive_ganloss.func_gen_loss(disc_fake_out, adv_weight, name="Gen", model=gen_model)
    loss.naive_ganloss.func_disc_fake_loss(disc_fake_out, adv_weight / 2, name="DiscFake", model=disc_model)
    loss.naive_ganloss.func_disc_fake_loss(disc_trans_out, adv_weight / 2, name="DiscTrans", model=disc_model)
    loss.naive_ganloss.func_disc_real_loss(disc_real_out, adv_weight, name="DiscGen", model=disc_model)

    # ACGAN
    real_cls_cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=real_cls_logits,
        labels=c_label), name="real_cls_reduce_mean")
    real_cls_cost_sum = tf.summary.scalar("real_cls", real_cls_cost)

    fake_cls_cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=fake_cls_logits,
        labels=c_noise), name='fake_cls_reduce_mean')
    fake_cls_cost_sum = tf.summary.scalar("fake_cls", fake_cls_cost)

    disc_model.cost += real_cls_cost * real_cls_weight + fake_cls_cost * fake_cls_weight
    disc_model.sum_op.extend([real_cls_cost_sum, fake_cls_cost_sum])

    gen_model.cost += fake_cls_cost * fake_cls_weight
    gen_model.sum_op.append(fake_cls_cost_sum)

    # DRAGAN gradient penalty
    disc_cost_, disc_sum_ = loss.dragan_loss.get_dragan_loss(disc_model, x_real, TFLAGS['gp_weight'])
    disc_model.cost += disc_cost_
    disc_model.sum_op.append(disc_sum_)

    gen_cost_, gen_sum_ = loss.naive_ganloss.func_gen_loss(disc_trans_out, adv_weight, name="GenTrans")
    gen_model.cost += gen_cost_

    # Image Feat Recs:
    # trans_feat == noise_trans_feat : feat_noise + noise_feat 1
    # trans_feat == trans_trans_feat : feat_image + image_feat 2
    # noise_feat == noise_rec_feat : feat_noise + noise_feat 1
    # noise_feat == trans_fake_feat : feat_image + image_feat 2
    trans_feat_cost1_, _ = loss.common_loss.reconstruction_loss(
        noise_trans_feat, trans_feat, TFLAGS['AE_weight'] / 4, name="RecTransFeat1")
    trans_feat_cost2_, _ = loss.common_loss.reconstruction_loss(
        trans_trans_feat, trans_feat, TFLAGS['AE_weight'] / 4, name="RecTransFeat2")
    trans_feat_cost_ = trans_feat_cost1_ + trans_feat_cost2_
    trans_feat_sum_ = tf.summary.scalar("RecTransFeat", trans_feat_cost_)

    noise_feat_cost1_, _ = loss.common_loss.reconstruction_loss(
        noise_rec_feat, noise_feat, TFLAGS['AE_weight'] / 4, name="RecNoiseFeat1")
    noise_feat_cost2_, _ = loss.common_loss.reconstruction_loss(
        trans_fake_feat, noise_feat, TFLAGS['AE_weight'] / 4, name="RecNoiseFeat2")
    noise_feat_cost_ = noise_feat_cost1_ + noise_feat_cost2_
    noise_feat_sum_ = tf.summary.scalar("RecNoiseFeat", noise_feat_cost_)

    # Noise Recs: noise_input -[REC]- noise_rec (noise_feat + feat_noise).
    # Only the trailing `inc_length` noise dimensions are compared.
    noise_cost_, noise_sum_ = loss.common_loss.reconstruction_loss(
        noise_rec[:, -inc_length:], noise_input[:, -inc_length:], TFLAGS['AE_weight'], name="RecNoise")

    # Image Recs: x_real -[REC]- x_trans (image_feat + feat_image)
    rec_cost_, rec_sum_ = loss.common_loss.reconstruction_loss(
        x_trans, x_real, TFLAGS['AE_weight'], name="RecPix")

    gen_model.extra_sum_op = [gen_sum_, trans_feat_sum_, noise_feat_sum_, rec_sum_, noise_sum_]
    gen_model.extra_loss = gen_cost_ + trans_feat_cost_ + noise_feat_cost_ + rec_cost_ + noise_cost_

    gen_model.total_cost = tf.identity(gen_model.extra_loss) + tf.identity(gen_model.cost)

    gen_model.cost = tf.identity(gen_model.cost, "TotalGenCost")
    disc_model.cost = tf.identity(disc_model.cost, "TotalDiscCost")

    # total summary
    gen_model.sum_op.append(tf.summary.scalar("GenCost", gen_model.cost))
    disc_model.sum_op.append(tf.summary.scalar("DiscCost", disc_model.cost))

    # add interval summary: image grids of at most 4x4
    edge_num = min(int(np.sqrt(TFLAGS['batch_size'])), 4)
    grid_x_fake = ops.get_grid_image_summary(x_fake, edge_num)
    inter_sum_op.append(tf.summary.image("generated image", grid_x_fake))
    grid_x_trans = ops.get_grid_image_summary(x_trans, edge_num)
    inter_sum_op.append(tf.summary.image("trans image", grid_x_trans))
    grid_x_real = ops.get_grid_image_summary(x_real, edge_num)
    inter_sum_op.append(tf.summary.image("real image", grid_x_real))

    # merge summary op
    gen_model.extra_sum_op = tf.summary.merge(gen_model.extra_sum_op)
    gen_model.sum_op = tf.summary.merge(gen_model.sum_op)
    disc_model.sum_op = tf.summary.merge(disc_model.sum_op)
    inter_sum_op = tf.summary.merge(inter_sum_op)

    print("=> Compute gradient")

    # get train op
    gen_model.get_trainable_variables()
    disc_model.get_trainable_variables()

    # Not applying update op will result in failure
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        # normal GAN train op
        gen_model.train_op = tf.train.AdamOptimizer(
            learning_rate=lr,
            beta1=0.5,
            beta2=0.9).minimize(gen_model.cost, var_list=gen_model.vars)

        # train reconstruction branch: first image->feat and feat->noise only
        partial_vars = gen_model.branch_vars['image_feat'] + gen_model.branch_vars['feat_noise']
        gen_model.extra_train_op_partial1 = tf.train.AdamOptimizer(
            learning_rate=lr/5,
            beta1=0.5,
            beta2=0.9).minimize(gen_model.extra_loss, var_list=partial_vars)

        # ...then all four branches. A new list is built rather than mutating
        # the one handed to the previous optimizer.
        full_vars = (partial_vars
                     + gen_model.branch_vars['noise_feat']
                     + gen_model.branch_vars['feat_image'])
        gen_model.extra_train_op_full = tf.train.AdamOptimizer(
            learning_rate=lr/5,
            beta1=0.5,
            beta2=0.9).minimize(gen_model.extra_loss, var_list=full_vars)

        gen_model.full_train_op = tf.train.AdamOptimizer(
            learning_rate=lr/5,
            beta1=0.5,
            beta2=0.9).minimize(gen_model.total_cost, var_list=gen_model.vars)

        disc_model.train_op = tf.train.AdamOptimizer(
            learning_rate=lr,
            beta1=0.5,
            beta2=0.9).minimize(disc_model.cost, var_list=disc_model.vars)

    print("=> ##### Generator Variable #####")
    gen_model.print_trainble_vairables()
    print("=> ##### Discriminator Variable #####")
    disc_model.print_trainble_vairables()
    print("=> ##### All Variable #####")
    for v in tf.trainable_variables():
        print("%s" % v.name)

    ctrl_weight = {
        "real_cls": real_cls_weight,
        "fake_cls": fake_cls_weight,
        "adv": adv_weight,
        "lr": lr,
        "inc_length": inc_length,
    }

    trainer_feed = {
        "gen_model": gen_model,
        "disc_model": disc_model,
        "ctrl_weight": ctrl_weight,
        "dataloader": dl,
        "int_sum_op": inter_sum_op,
        "gpu_mem": TFLAGS['gpu_mem'],
        "FLAGS": FLAGS,
        "TFLAGS": TFLAGS,
        "inputs": [z_noise, c_noise, x_real, c_label],
    }

    Trainer = trainer.base_gantrainer_rep.BaseGANTrainer
    ModelTrainer = Trainer(**trainer_feed)

    command_controller = trainer.cmd_ctrl.CMDControl(ModelTrainer)
    command_controller.start_thread()

    ModelTrainer.init_training()

    ModelTrainer.train()
Ejemplo n.º 3
0
def main():
    """Build a (conditional) GAN with hinge loss and train it.

    Steps:
        - Pick a dataset by substring of ``FLAGS.data_dir``.
        - Build the generator/discriminator via ``config.<FLAGS.model_name>``.
        - Wire up hinge and (when ``FLAGS.cgan``) classifier losses plus
          summaries.
        - Hand everything to ``trainer.base_gantrainer.BaseGANTrainer``.

    When ``FLAGS.cgan`` is false, ``c_noise`` and ``c_label`` are ``None`` and
    the classifier losses and summaries are omitted.
    """
    size = FLAGS.img_size

    # debug: derive a log directory from model / BN mode / phases when unset
    if not FLAGS.train_dir:
        bn_name = ["nobn", "caffebn", "simplebn", "defaultbn", "cbn"]
        FLAGS.train_dir = os.path.join(
            "logs", FLAGS.model_name + "_" + bn_name[FLAGS.bn] + "_" +
            str(FLAGS.phases))

    if FLAGS.cgan:
        # the label file is npy format
        npy_dir = FLAGS.data_dir.replace(".zip", "") + '.npy'
    else:
        npy_dir = None

    if "celeb" in FLAGS.data_dir:
        dataset = dataloader.CelebADataset(FLAGS.data_dir,
                                           img_size=(size, size),
                                           npy_dir=npy_dir)
    elif "cityscapes" in FLAGS.data_dir:
        augmentations = Compose([
            RandomCrop(size * 4),
            Scale(size * 2),
            RandomRotate(10),
            RandomHorizontallyFlip(),
            RandomSizedCrop(size)
        ])
        dataset = dataloader.cityscapesLoader(FLAGS.data_dir,
                                              is_transform=True,
                                              augmentations=augmentations,
                                              img_size=(size, size))
        # Integer division keeps batch_size an int (`/=` gives a float on
        # Python 3); clamp to 1 so the `//` below cannot divide by zero.
        FLAGS.batch_size = max(1, FLAGS.batch_size // 64)
    else:
        dataset = dataloader.FileDataset(FLAGS.data_dir,
                                         npy_dir=npy_dir,
                                         img_size=(size, size))

    dl = dataloader.TFDataloader(dataset, FLAGS.batch_size,
                                 dataset.file_num // FLAGS.batch_size)

    # TF Input
    x_fake_sample = tf.placeholder(tf.float32, [None, size, size, 3],
                                   name="x_fake_sample")
    x_real = tf.placeholder(tf.float32, [None, size, size, 3], name="x_real")
    # NOTE(review): s_real is not used in this function; kept so the graph
    # still contains the 's_real' placeholder.
    s_real = tf.placeholder(tf.float32, [None, size, size, 3], name='s_real')
    z_noise = tf.placeholder(tf.float32, [None, 128], name="z_noise")

    if FLAGS.cgan:
        c_noise = tf.placeholder(tf.float32, [None, dataset.class_num],
                                 name="c_noise")
        c_label = tf.placeholder(tf.float32, [None, dataset.class_num],
                                 name="c_label")
        gen_input = [z_noise, c_noise]
    else:
        gen_input = z_noise
        # Defined so the label plumbing below does not raise NameError.
        c_label = c_noise = None

    # look up the config function from lib.config module
    gen_model, disc_model = getattr(config,
                                    FLAGS.model_name)(FLAGS.img_size,
                                                      dataset.class_num)
    disc_model.norm_mtd = FLAGS.bn

    x_fake = gen_model(gen_input, update_collection=None)
    gen_model.set_reuse()
    gen_model.x_fake = x_fake

    # With more than one phase, the discriminator is told whether it is
    # looking at fake or real data.
    disc_model.set_label(c_noise)
    disc_model.set_phase("fake" if FLAGS.phases > 1 else "default")
    disc_fake, fake_cls_logits = disc_model(x_fake, update_collection=None)
    disc_model.set_reuse()

    disc_model.set_label(c_label)
    disc_model.set_phase("real" if FLAGS.phases > 1 else "default")
    disc_real, real_cls_logits = disc_model(x_real, update_collection=None)
    disc_model.disc_real = disc_real
    disc_model.disc_fake = disc_fake
    disc_model.real_cls_logits = real_cls_logits
    disc_model.fake_cls_logits = fake_cls_logits

    int_sum_op = []

    if FLAGS.use_cache:
        disc_fake_sample = disc_model(x_fake_sample)[0]
        disc_cost_sample = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=disc_fake_sample,
                labels=tf.zeros_like(disc_fake_sample)),
            name="cost_disc_fake_sample")
        disc_cost_sample_sum = tf.summary.scalar("disc_sample",
                                                 disc_cost_sample)

        fake_sample_grid = ops.get_grid_image_summary(x_fake_sample, 4)
        int_sum_op.append(tf.summary.image("fake sample", fake_sample_grid))

        sample_method = [disc_cost_sample, disc_cost_sample_sum, x_fake_sample]
    else:
        sample_method = None

    grid_x_fake = ops.get_grid_image_summary(gen_model.x_fake, 4)
    int_sum_op.append(tf.summary.image("generated image", grid_x_fake))

    grid_x_real = ops.get_grid_image_summary(x_real, 4)
    int_sum_op.append(tf.summary.image("real image", grid_x_real))

    int_sum_op = tf.summary.merge(int_sum_op)

    raw_gen_cost, raw_disc_real, raw_disc_fake = loss.hinge_loss(
        gen_model, disc_model, adv_weight=1.0, summary=False)
    disc_model.disc_real_loss = raw_disc_real
    disc_model.disc_fake_loss = raw_disc_fake

    if FLAGS.cgan:
        real_cls_cost, fake_cls_cost = loss.classifier_loss(gen_model,
                                                            disc_model,
                                                            x_real,
                                                            c_label,
                                                            c_noise,
                                                            weight=1.0 /
                                                            dataset.class_num,
                                                            summary=False)
        subloss_names = ["fake_cls", "real_cls", "gen", "disc_real", "disc_fake"]
        sublosses = [
            fake_cls_cost, real_cls_cost, raw_gen_cost, raw_disc_real,
            raw_disc_fake
        ]
    else:
        # No classifier losses exist without conditioning.
        subloss_names = ["gen", "disc_real", "disc_fake"]
        sublosses = [raw_gen_cost, raw_disc_real, raw_disc_fake]

    step_sum_op = tf.summary.merge(
        [tf.summary.scalar(n, l) for n, l in zip(subloss_names, sublosses)])

    # trainer.separated_gantrainer.SeparatedGANTrainer was referenced here as a
    # commented-out alternative trainer.
    ModelTrainer = trainer.base_gantrainer.BaseGANTrainer(
        int_sum_op=int_sum_op,
        step_sum_op=step_sum_op,
        dataloader=dl,
        FLAGS=FLAGS,
        gen_model=gen_model,
        disc_model=disc_model,
        gen_input=gen_input,
        x_real=x_real,
        label=c_label,
        sample_method=sample_method)

    #command_controller = trainer.cmd_ctrl.CMDControl(ModelTrainer)
    #command_controller.start_thread()

    print("=> Build train op")
    ModelTrainer.build_train_op()

    print("=> ##### Generator Variable #####")
    gen_model.print_trainble_vairables()
    print("=> ##### Discriminator Variable #####")
    disc_model.print_trainble_vairables()
    print("=> ##### All Variable #####")
    for v in tf.trainable_variables():
        print("%s\t\t\t\t%s" % (v.name, str(v.get_shape().as_list())))
    print("=> #### Moving Variable ####")
    for v in tf.global_variables():
        if "moving" in v.name:
            print("%s\t\t\t\t%s" % (v.name, str(v.get_shape().as_list())))

    ModelTrainer.init_training()
    ModelTrainer.train()
Ejemplo n.º 4
0
def main():
    """Build a (conditional) GAN with hinge loss and train it.

    Behavior:
        - If ``FLAGS.train_dir`` is empty, it becomes
          ``logs/<model_name><img_size>``, with ``_cgan`` appended for
          conditional runs.
        - The dataset is chosen by substring of ``FLAGS.data_dir``.
        - The models come from ``config.<FLAGS.model_name>``.
        - Training is delegated to ``trainer.base_gantrainer.BaseGANTrainer``.

    When ``FLAGS.cgan`` is false, ``c_noise``/``c_label`` are ``None`` and only
    the adversarial sub-losses are summarised.
    """
    size = FLAGS.img_size

    if not FLAGS.train_dir:
        # if the train dir is not set, then automatically decide one
        FLAGS.train_dir = os.path.join("logs",
                                       FLAGS.model_name + str(FLAGS.img_size))
        if FLAGS.cgan:
            FLAGS.train_dir += "_cgan"

    if FLAGS.cgan:
        # the label file should be npy format
        npy_dir = FLAGS.data_dir.replace(".zip", "") + '.npy'
    else:
        npy_dir = None

    if "celeb" in FLAGS.data_dir:
        dataset = dataloader.CelebADataset(FLAGS.data_dir,
                                           img_size=(size, size),
                                           npy_dir=npy_dir)
    elif "cityscapes" in FLAGS.data_dir:
        # outdated
        augmentations = Compose([
            RandomCrop(size * 4),
            Scale(size * 2),
            RandomRotate(10),
            RandomHorizontallyFlip(),
            RandomSizedCrop(size)
        ])
        dataset = dataloader.cityscapesLoader(FLAGS.data_dir,
                                              is_transform=True,
                                              augmentations=augmentations,
                                              img_size=(size, size))
        # Integer division keeps batch_size an int (`/=` gives a float on
        # Python 3); clamp to 1 so the `//` below cannot divide by zero.
        FLAGS.batch_size = max(1, FLAGS.batch_size // 64)
    else:
        dataset = dataloader.FileDataset(FLAGS.data_dir,
                                         npy_dir=npy_dir,
                                         img_size=(size, size))

    dl = dataloader.TFDataloader(dataset, FLAGS.batch_size,
                                 dataset.file_num // FLAGS.batch_size)

    # TF Input
    # NOTE(review): x_fake_sample and s_real are not used in this function;
    # kept so the graph still contains these named placeholders.
    x_fake_sample = tf.placeholder(tf.float32, [None, size, size, 3],
                                   name="x_fake_sample")
    x_real = tf.placeholder(tf.float32, [None, size, size, 3], name="x_real")
    s_real = tf.placeholder(tf.float32, [None, size, size, 3], name='s_real')
    z_noise = tf.placeholder(tf.float32, [None, 128], name="z_noise")

    if FLAGS.cgan:
        c_noise = tf.placeholder(tf.float32, [None, dataset.class_num],
                                 name="c_noise")
        c_label = tf.placeholder(tf.float32, [None, dataset.class_num],
                                 name="c_label")
        gen_input = [z_noise, c_noise]
    else:
        gen_input = z_noise
        c_label = c_noise = None

    # look up the config function from lib.config module
    gen_model, disc_model = getattr(config,
                                    FLAGS.model_name)(FLAGS.img_size,
                                                      dataset.class_num)

    gen_model.label = c_noise
    x_fake = gen_model(gen_input)
    gen_model.set_reuse()
    gen_model.x_fake = x_fake

    # The fake pass is conditioned on the sampled class noise and the real
    # pass on the true labels; the second call reuses the variables.
    disc_model.label = c_noise
    disc_fake, fake_cls_logits = disc_model(x_fake)
    disc_model.set_reuse()

    disc_model.label = c_label
    disc_real, real_cls_logits = disc_model(x_real)
    disc_model.disc_real = disc_real
    disc_model.disc_fake = disc_fake
    disc_model.real_cls_logits = real_cls_logits
    disc_model.fake_cls_logits = fake_cls_logits

    raw_gen_cost, raw_disc_real, raw_disc_fake = loss.hinge_loss(
        gen_model, disc_model, adv_weight=1.0, summary=False)
    disc_model.disc_real_loss = raw_disc_real
    disc_model.disc_fake_loss = raw_disc_fake

    if FLAGS.cgan:
        real_cls_cost, fake_cls_cost = loss.classifier_loss(gen_model,
                                                            disc_model,
                                                            x_real,
                                                            c_label,
                                                            c_noise,
                                                            weight=1.0 /
                                                            dataset.class_num,
                                                            summary=False)
        subloss_names = [
            "fake_cls", "real_cls", "gen", "disc_real", "disc_fake"
        ]
        sublosses = [
            fake_cls_cost, real_cls_cost, raw_gen_cost, raw_disc_real,
            raw_disc_fake
        ]
    else:
        subloss_names = ["gen", "disc_real", "disc_fake"]
        sublosses = [raw_gen_cost, raw_disc_real, raw_disc_fake]

    step_sum_op = []  # summary at every step

    for n, l in zip(subloss_names, sublosses):
        step_sum_op.append(tf.summary.scalar(n, l))
    if gen_model.debug or disc_model.debug:
        # In debug mode, also log a histogram of every variable whose op name
        # contains either model's name.
        for model_var in tf.global_variables():
            if gen_model.name in model_var.op.name or disc_model.name in model_var.op.name:
                step_sum_op.append(
                    tf.summary.histogram(model_var.op.name, model_var))
    step_sum_op = tf.summary.merge(step_sum_op)

    int_sum_op = []  # summary at some interval

    grid_x_fake = ops.get_grid_image_summary(gen_model.x_fake, 4)
    int_sum_op.append(tf.summary.image("generated image", grid_x_fake))

    grid_x_real = ops.get_grid_image_summary(x_real, 4)
    int_sum_op.append(tf.summary.image("real image", grid_x_real))

    int_sum_op = tf.summary.merge(int_sum_op)

    ModelTrainer = trainer.base_gantrainer.BaseGANTrainer(
        int_sum_op=int_sum_op,
        step_sum_op=step_sum_op,
        dataloader=dl,
        FLAGS=FLAGS,
        gen_model=gen_model,
        disc_model=disc_model,
        gen_input=gen_input,
        x_real=x_real,
        label=c_label)

    print("=> Build train op")
    ModelTrainer.build_train_op()

    print("=> ##### Generator Variable #####")
    gen_model.print_trainble_vairables()
    print("=> ##### Discriminator Variable #####")
    disc_model.print_trainble_vairables()
    print("=> #### Moving Variable ####")
    for v in tf.global_variables():
        if "moving" in v.name:
            print("%s\t\t\t\t%s" % (v.name, str(v.get_shape().as_list())))
    print("=> #### Generator update dependency ####")
    for v in gen_model.update_ops:
        print("%s" % (v.name))
    print("=> #### Discriminator update dependency ####")
    for v in disc_model.update_ops:
        print("%s" % (v.name))
    ModelTrainer.init_training()
    ModelTrainer.train()
Ejemplo n.º 5
0
def main():
    """Build a (conditional) GAN with hinge loss and train it.

    Behavior:
        - The dataset is chosen by substring of ``FLAGS.data_dir`` and fed
          through a ``DataLoader``.
        - The models come from ``config.<FLAGS.model_name>``.
        - Training is delegated to ``trainer.base_gantrainer.BaseGANTrainer``,
          with a command-controller thread attached.

    When ``FLAGS.cgan`` is false, ``c_noise``/``c_label`` are ``None`` and the
    classifier loss is not added.
    """
    size = FLAGS.img_size

    if FLAGS.cgan:
        npy_dir = FLAGS.data_dir.replace(".zip", "") + '.npy'
    else:
        npy_dir = None

    if "celeb" in FLAGS.data_dir:
        dataset = dataloader.CelebADataset(FLAGS.data_dir,
                                           img_size=(size, size),
                                           npy_dir=npy_dir)
    elif "cityscapes" in FLAGS.data_dir:
        augmentations = Compose([
            RandomCrop(size * 4),
            Scale(size * 2),
            RandomRotate(10),
            RandomHorizontallyFlip(),
            RandomSizedCrop(size)
        ])
        dataset = dataloader.cityscapesLoader(FLAGS.data_dir,
                                              is_transform=True,
                                              augmentations=augmentations,
                                              img_size=(size, size))
    else:
        dataset = dataloader.FileDataset(FLAGS.data_dir,
                                         npy_dir=npy_dir,
                                         img_size=(size, size),
                                         shuffle=True)

    # NOTE(review): the loader batch size is FLAGS.batch_size // 64 for every
    # dataset; this is 0 when FLAGS.batch_size < 64 -- confirm intended.
    dl = DataLoader(dataset,
                    batch_size=FLAGS.batch_size // 64,
                    shuffle=True,
                    num_workers=NUM_WORKER,
                    collate_fn=dataloader.default_collate)

    # TF Input
    x_fake_sample = tf.placeholder(tf.float32, [None, size, size, 3],
                                   name="x_fake_sample")
    x_real = tf.placeholder(tf.float32, [None, size, size, 3], name="x_real")
    # NOTE(review): s_real is not used in this function; kept so the graph
    # still contains the 's_real' placeholder.
    s_real = tf.placeholder(tf.float32, [None, size, size, 3], name='s_real')
    z_noise = tf.placeholder(tf.float32, [None, 128], name="z_noise")

    if FLAGS.cgan:
        c_noise = tf.placeholder(tf.float32, [None, dataset.class_num],
                                 name="c_noise")
        c_label = tf.placeholder(tf.float32, [None, dataset.class_num],
                                 name="c_label")
        gen_input = [z_noise, c_noise]
    else:
        gen_input = z_noise
        # `label=c_label` is passed to the trainer below, so it must exist.
        c_label = c_noise = None

    gen_model, disc_model = getattr(config,
                                    FLAGS.model_name)(FLAGS.img_size,
                                                      dataset.class_num)
    gen_model.cbn_project = FLAGS.cbn_project

    x_fake = gen_model(gen_input, update_collection=None)
    gen_model.set_reuse()
    gen_model.x_fake = x_fake

    # The real pass uses update_collection=None and the fake pass "no_ops";
    # the meaning of these values is defined by the model code.
    disc_real, real_cls_logits = disc_model(x_real, update_collection=None)
    disc_model.set_reuse()
    disc_fake, fake_cls_logits = disc_model(x_fake, update_collection="no_ops")
    disc_model.disc_real = disc_real
    disc_model.disc_fake = disc_fake
    disc_model.real_cls_logits = real_cls_logits
    disc_model.fake_cls_logits = fake_cls_logits

    int_sum_op = []

    if FLAGS.use_cache:
        disc_fake_sample = disc_model(x_fake_sample)[0]
        disc_cost_sample = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=disc_fake_sample,
                labels=tf.zeros_like(disc_fake_sample)),
            name="cost_disc_fake_sample")
        disc_cost_sample_sum = tf.summary.scalar("disc_sample",
                                                 disc_cost_sample)

        fake_sample_grid = ops.get_grid_image_summary(x_fake_sample, 4)
        int_sum_op.append(tf.summary.image("fake sample", fake_sample_grid))

        sample_method = [disc_cost_sample, disc_cost_sample_sum, x_fake_sample]
    else:
        sample_method = None

    grid_x_fake = ops.get_grid_image_summary(gen_model.x_fake, 4)
    int_sum_op.append(tf.summary.image("generated image", grid_x_fake))

    grid_x_real = ops.get_grid_image_summary(x_real, 4)
    int_sum_op.append(tf.summary.image("real image", grid_x_real))

    if FLAGS.cgan:
        loss.classifier_loss(gen_model,
                             disc_model,
                             x_real,
                             c_label,
                             c_noise,
                             weight=1.0)

    loss.hinge_loss(gen_model, disc_model, adv_weight=1.0)

    int_sum_op = tf.summary.merge(int_sum_op)

    ModelTrainer = trainer.base_gantrainer.BaseGANTrainer(
        int_sum_op=int_sum_op,
        dataloader=dl,
        FLAGS=FLAGS,
        gen_model=gen_model,
        disc_model=disc_model,
        gen_input=gen_input,
        x_real=x_real,
        label=c_label,
        sample_method=sample_method)

    command_controller = trainer.cmd_ctrl.CMDControl(ModelTrainer)
    command_controller.start_thread()

    print("=> Build train op")
    ModelTrainer.build_train_op()

    print("=> ##### Generator Variable #####")
    gen_model.print_trainble_vairables()
    print("=> ##### Discriminator Variable #####")
    disc_model.print_trainble_vairables()
    print("=> ##### All Variable #####")
    for v in tf.trainable_variables():
        print("%s\t\t\t\t%s" % (v.name, str(v.get_shape().as_list())))
    print("=> #### Moving Variable ####")
    for v in tf.global_variables():
        if "moving" in v.name:
            print("%s\t\t\t\t%s" % (v.name, str(v.get_shape().as_list())))

    ModelTrainer.init_training()
    ModelTrainer.train()
Ejemplo n.º 6
0
def main():
    """Build a multi-GPU hinge-loss GAN training graph and train it.

    All settings come from the module-level ``FLAGS``. The global batch is
    sliced evenly across ``NUM_GPU`` towers; each tower builds the generator
    and discriminator losses and their gradients. Per-tower gradients are
    averaged on GPU 0 and applied with Adam (beta1=0, beta2=0.9). Image and
    scalar summaries are taken from tower 0 only.
    """
    size = FLAGS.img_size

    if FLAGS.cgan:
        # the label file is npy format
        npy_dir = FLAGS.data_dir.replace(".zip", "") + '.npy'
    else:
        npy_dir = None

    if "celeb" in FLAGS.data_dir:
        dataset = dataloader.CelebADataset(FLAGS.data_dir,
                                           img_size=(size, size),
                                           npy_dir=npy_dir)
    elif "cityscapes" in FLAGS.data_dir:
        augmentations = Compose([
            RandomCrop(size * 4),
            Scale(size * 2),
            RandomRotate(10),
            RandomHorizontallyFlip(),
            RandomSizedCrop(size)
        ])
        dataset = dataloader.cityscapesLoader(FLAGS.data_dir,
                                              is_transform=True,
                                              augmentations=augmentations,
                                              img_size=(size, size))
        # The batch size must stay an int: it is used with `//` below and as
        # a tensor slice bound in the tower loop (`/=` produced a float on
        # Python 3). Clamp to 1 so the steps-per-epoch division below cannot
        # divide by zero for batch sizes under 64.
        FLAGS.batch_size = max(1, FLAGS.batch_size // 64)
    else:
        dataset = dataloader.FileDataset(FLAGS.data_dir,
                                         npy_dir=npy_dir,
                                         img_size=(size, size))

    dl = dataloader.TFDataloader(dataset, FLAGS.batch_size,
                                 dataset.file_num // FLAGS.batch_size)

    # TF Input
    x_fake_sample = tf.placeholder(tf.float32, [None, size, size, 3],
                                   name="x_fake_sample")
    x_real = tf.placeholder(tf.float32, [None, size, size, 3], name="x_real")
    s_real = tf.placeholder(tf.float32, [None, size, size, 3], name='s_real')
    z_noise = tf.placeholder(tf.float32, [None, 128], name="z_noise")

    if FLAGS.cgan:
        c_noise = tf.placeholder(tf.float32, [None, dataset.class_num],
                                 name="c_noise")
        c_label = tf.placeholder(tf.float32, [None, dataset.class_num],
                                 name="c_label")
        gen_input = [z_noise, c_noise]
    else:
        # Unconditional GAN: there are no class inputs. Bind both names to
        # None so the trainer constructor below (label=c_label) does not
        # raise NameError.
        c_noise = c_label = None
        gen_input = z_noise

    # look up the config function from lib.config module
    gen_model, disc_model = getattr(config,
                                    FLAGS.model_name)(FLAGS.img_size,
                                                      dataset.class_num)
    gen_model.cbn_project = FLAGS.cbn_project
    gen_model.spectral_norm = FLAGS.sn
    disc_model.cbn_project = FLAGS.cbn_project
    disc_model.spectral_norm = FLAGS.sn

    ModelTrainer = trainer.base_gantrainer.BaseGANTrainer(
        step_sum_op=None,
        int_sum_op=None,
        dataloader=dl,
        FLAGS=FLAGS,
        gen_model=gen_model,
        disc_model=disc_model,
        gen_input=gen_input,
        x_real=x_real,
        label=c_label)

    g_tower_grads = []
    d_tower_grads = []
    g_optim = tf.train.AdamOptimizer(learning_rate=ModelTrainer.g_lr,
                                     beta1=0.,
                                     beta2=0.9)
    d_optim = tf.train.AdamOptimizer(learning_rate=ModelTrainer.d_lr,
                                     beta1=0.,
                                     beta2=0.9)

    # Per-tower recorded tensors (xs) and gradients of disc_fake w.r.t. them
    # (grad_x); the matching names are collected from tower 0 only.
    grad_x = []
    grad_x_name = []
    xs = []
    x_name = []

    def tower(gpu_id,
              gen_input,
              x_real,
              c_label=None,
              c_noise=None,
              update_collection=None,
              loss_collection=None):
        """
        The loss function builder of gen and disc.

        Builds one tower's generator/discriminator graph and losses, then
        appends its gradients to ``g_tower_grads`` / ``d_tower_grads``, its
        recorded tensors to ``xs``, and the gradients of ``disc_fake`` with
        respect to those tensors to ``grad_x``.

        Returns:
            (gen cost, disc cost, [fake_cls_cost, real_cls_cost,
            raw_gen_cost, raw_disc_real, raw_disc_fake]). The two classifier
            costs are zero constants when FLAGS.cgan is off.
        """
        gen_model.cost = disc_model.cost = 0

        gen_model.set_phase("gpu%d" % gpu_id)
        x_fake = gen_model(gen_input, update_collection=update_collection)
        gen_model.set_reuse()
        gen_model.x_fake = x_fake

        disc_model.set_phase("gpu%d" % gpu_id)
        disc_real, real_cls_logits = disc_model(
            x_real, update_collection=update_collection)
        disc_model.set_reuse()
        # Only the fake-branch tensors are kept for the grad_x diagnostics.
        disc_model.recorded_tensors = []
        disc_model.recorded_names = []
        disc_fake, fake_cls_logits = disc_model(
            x_fake, update_collection=update_collection)
        disc_model.disc_real = disc_real
        disc_model.disc_fake = disc_fake
        disc_model.real_cls_logits = real_cls_logits
        disc_model.fake_cls_logits = fake_cls_logits

        # Defaults for the unconditional case; previously these names were
        # unbound when FLAGS.cgan was off and the return below raised.
        fake_cls_cost = real_cls_cost = tf.constant(0.0)
        if FLAGS.cgan:
            fake_cls_cost, real_cls_cost = loss.classifier_loss(
                gen_model,
                disc_model,
                x_real,
                c_label,
                c_noise,
                weight=1.0 / dataset.class_num,
                summary=False)

        raw_gen_cost, raw_disc_real, raw_disc_fake = loss.hinge_loss(
            gen_model, disc_model, adv_weight=1.0, summary=False)

        gen_model.vars = [
            v for v in tf.trainable_variables() if gen_model.name in v.name
        ]
        disc_model.vars = [
            v for v in tf.trainable_variables() if disc_model.name in v.name
        ]
        g_grads = tf.gradients(gen_model.cost,
                               gen_model.vars,
                               colocate_gradients_with_ops=True)
        d_grads = tf.gradients(disc_model.cost,
                               disc_model.vars,
                               colocate_gradients_with_ops=True)
        g_grads = [
            tf.check_numerics(g, "G grad nan: " + str(g)) for g in g_grads
        ]
        d_grads = [
            tf.check_numerics(g, "D grad nan: " + str(g)) for g in d_grads
        ]
        g_tower_grads.append(g_grads)
        d_tower_grads.append(d_grads)

        tensors = gen_model.recorded_tensors + disc_model.recorded_tensors
        names = gen_model.recorded_names + disc_model.recorded_names
        if gpu_id == 0: x_name.extend(names)
        xs.append(tensors)
        # grad_x / grad_x_name are stored in reverse recording order.
        names = names[::-1]
        tensors = tensors[::-1]
        grads = tf.gradients(disc_fake,
                             tensors,
                             colocate_gradients_with_ops=True)
        for n, g in zip(names, grads):
            print(n, g)
        grad_x.append(
            [tf.check_numerics(g, "BP nan: " + str(g)) for g in grads])
        if gpu_id == 0: grad_x_name.extend(names)
        disc_model.recorded_tensors = []
        disc_model.recorded_names = []
        gen_model.recorded_tensors = []
        gen_model.recorded_names = []

        return gen_model.cost, disc_model.cost, [
            fake_cls_cost, real_cls_cost, raw_gen_cost, raw_disc_real,
            raw_disc_fake
        ]

    def average_gradients(tower_grads):
        """Element-wise mean of per-tower gradient lists (same order per tower)."""
        average_grads = []
        num_gpus = len(tower_grads)
        num_items = len(tower_grads[0])
        for i in range(num_items):
            average_grads.append(0.0)
            for j in range(num_gpus):
                average_grads[i] += tower_grads[j][i]
            average_grads[i] /= num_gpus
        return average_grads

    # Per-tower slice of the global batch.
    sbs = FLAGS.batch_size // NUM_GPU
    for i in range(NUM_GPU):
        # Only tower 0 runs the update collection; the others use "no_ops".
        if i == 0:
            update_collection = None
        else:
            update_collection = "no_ops"

        with tf.device(tf.DeviceSpec(device_type="GPU", device_index=i)):
            if FLAGS.cgan:
                l1, l2, ot1 = tower(i, [
                    z_noise[sbs * i:sbs * (i + 1)], c_noise[sbs * i:sbs *
                                                            (i + 1)]
                ],
                                    x_real[sbs * i:sbs * (i + 1)],
                                    c_label[sbs * i:sbs * (i + 1)],
                                    c_noise[sbs * i:sbs * (i + 1)],
                                    update_collection=update_collection)
            else:
                l1, l2, ot1 = tower(i,
                                    z_noise[sbs * i:sbs * (i + 1)],
                                    x_real[sbs * i:sbs * (i + 1)],
                                    update_collection=update_collection)

        if i == 0:
            int_sum_op = []

            grid_x_fake = ops.get_grid_image_summary(gen_model.x_fake, 4)
            int_sum_op.append(tf.summary.image("generated image", grid_x_fake))

            grid_x_real = ops.get_grid_image_summary(x_real, 4)
            int_sum_op.append(tf.summary.image("real image", grid_x_real))

            step_sum_op = []
            sub_loss_names = [
                "fake_cls", "real_cls", "gen", "disc_real", "disc_fake"
            ]
            for n, l in zip(sub_loss_names, ot1):
                step_sum_op.append(tf.summary.scalar(n, l))
            step_sum_op.append(tf.summary.scalar("gen", gen_model.cost))
            step_sum_op.append(tf.summary.scalar("disc", disc_model.cost))

    with tf.device(tf.DeviceSpec(device_type="GPU", device_index=0)):
        g_grads = average_gradients(g_tower_grads)
        d_grads = average_gradients(d_tower_grads)

        gen_model.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                                 scope=gen_model.name + "/")
        disc_model.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                                  scope=disc_model.name + "/")

        print(gen_model.update_ops)
        print(disc_model.update_ops)

        def merge_list(l1, l2):
            """Pair gradients with variables as apply_gradients expects."""
            return [[l1[i], l2[i]] for i in range(len(l1))]

        with tf.control_dependencies(gen_model.update_ops):
            gen_model.train_op = g_optim.apply_gradients(
                merge_list(g_grads, gen_model.vars))
        with tf.control_dependencies(disc_model.update_ops):
            disc_model.train_op = d_optim.apply_gradients(
                merge_list(d_grads, disc_model.vars))

    if FLAGS.use_cache:
        disc_fake_sample = disc_model(x_fake_sample)[0]
        disc_cost_sample = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=disc_fake_sample,
                labels=tf.zeros_like(disc_fake_sample)),
            name="cost_disc_fake_sample")
        disc_cost_sample_sum = tf.summary.scalar("disc_sample",
                                                 disc_cost_sample)

        fake_sample_grid = ops.get_grid_image_summary(x_fake_sample, 4)
        int_sum_op.append(tf.summary.image("fake sample", fake_sample_grid))

        sample_method = [disc_cost_sample, disc_cost_sample_sum, x_fake_sample]
    else:
        sample_method = None
    # NOTE(review): sample_method is never handed to ModelTrainer here (the
    # trainer was constructed earlier without it); confirm whether the cache
    # path is meant to be wired in.

    ModelTrainer.int_sum_op = tf.summary.merge(int_sum_op)
    ModelTrainer.step_sum_op = tf.summary.merge(step_sum_op)
    ModelTrainer.grad_x = grad_x
    ModelTrainer.grad_x_name = grad_x_name
    ModelTrainer.xs = xs
    ModelTrainer.x_name = x_name
    #command_controller = trainer.cmd_ctrl.CMDControl(ModelTrainer)
    #command_controller.start_thread()

    print("=> Build train op")
    # ModelTrainer.build_train_op()

    print("=> ##### Generator Variable #####")
    gen_model.print_variables()
    print("=> ##### Discriminator Variable #####")
    disc_model.print_variables()

    print("=> ##### All Variable #####")
    for v in tf.trainable_variables():
        print("%s\t\t\t\t%s" % (v.name, str(v.get_shape().as_list())))

    print("=> #### Moving Variable ####")
    for v in tf.global_variables():
        if "moving" in v.name:
            print("%s\t\t\t\t%s" % (v.name, str(v.get_shape().as_list())))

    ModelTrainer.init_training()
    ModelTrainer.train()
Ejemplo n.º 7
0
def main():
    """Train an image/sketch-conditional GAN on the getchu character data.

    Samples (character image, sketch, label row) triples through a
    ListSampler. It builds an ImageConditionalEncoder generator and an
    ImageConditionalDeepDiscriminator that scores [image, sketch] pairs. It
    then trains with naive GAN losses plus a sketch reconstruction loss via
    ``CGANTrainer``. Batch size and input shape are overridden here
    (1 and 256x256x3).
    """
    # get configuration
    print("Get configuration")
    TFLAGS = cfg.get_train_config(FLAGS.model_name)
    # NOTE(review): the models/configs returned here are replaced just below.
    gen_model, gen_config, disc_model, disc_config = model.get_model(
        FLAGS.model_name, TFLAGS)
    gen_model = model.conditional_generator.ImageConditionalEncoder()
    disc_model = model.conditional_generator.ImageConditionalDeepDiscriminator(
    )
    gen_model.name = "ImageConditionalEncoder"
    disc_model.name = "ImageConditionalDeepDiscriminator"
    print("Common length: %d" % gen_model.common_length)
    TFLAGS['AE_weight'] = 1.0
    TFLAGS['batch_size'] = 1
    TFLAGS['input_shape'] = [256, 256, 3]

    face_dataset = dataloader.CustomDataset(
        root_dir="/data/datasets/getchu/crop_character",
        npy_dir="/data/datasets/getchu/true_character.npy",
        preproc_kind="tanh",
        img_size=TFLAGS['input_shape'][:2],
        filter_data=TFLAGS['filter_data'],
        class_num=TFLAGS['c_len'],
        has_gray=False,
        disturb=False,
        flip=False)

    sketch_dataset = dataloader.CustomDataset(
        root_dir="/data/datasets/getchu/crop_sketch_character/",
        npy_dir=None,
        preproc_kind="tanh",
        img_size=TFLAGS['input_shape'][:2],
        filter_data=TFLAGS['filter_data'],
        class_num=TFLAGS['c_len'],
        has_gray=True,
        disturb=False,
        flip=False)

    # Label rows: column 0 of the .npy file is dropped.
    label_dataset = dataloader.NumpyArrayDataset(
        np.load("/data/datasets/getchu/true_face.npy")[:, 1:])

    # Sampler order must match trainer "inputs": [x_real, s_real, c_label].
    listsampler = dataloader.ListSampler(
        [face_dataset, sketch_dataset, label_dataset])
    dl = dataloader.CustomDataLoader(listsampler, TFLAGS['batch_size'], 4)

    # TF Input
    x_real = tf.placeholder(tf.float32, [None] + TFLAGS['input_shape'],
                            name="x_real")
    # Cached fake samples for the FLAGS.cache discriminator path. It was
    # previously also named "x_real", which TF silently renamed to "x_real_1".
    fake_x_sample = tf.placeholder(tf.float32, [None] + TFLAGS['input_shape'],
                                   name="fake_x_sample")
    # semantic segmentation (single-channel sketch input)
    s_real = tf.placeholder(tf.float32, [None] + TFLAGS['input_shape'][:-1] + [
        1,
    ],
                            name='s_real')
    z_noise = tf.placeholder(tf.float32, [None, TFLAGS['z_len']],
                             name="z_noise")
    c_noise = tf.placeholder(tf.float32, [None, TFLAGS['c_len']],
                             name="c_noise")
    c_label = tf.placeholder(tf.float32, [None, TFLAGS['c_len']],
                             name="c_label")

    # control variables (fed by the trainer via ctrl_weight)
    real_cls_weight = tf.placeholder(tf.float32, [], name="real_cls_weight")
    fake_cls_weight = tf.placeholder(tf.float32, [], name="fake_cls_weight")
    adv_weight = tf.placeholder(tf.float32, [], name="adv_weight")
    inc_length = tf.placeholder(tf.int32, [], name="inc_length")
    lr = tf.placeholder(tf.float32, [], name="lr")

    with tf.variable_scope(gen_model.name):
        gen_model.image_input = x_real
        gen_model.seg_input = s_real
        gen_model.noise_input = tf.concat([z_noise, c_noise], axis=1)

        seg_image, image_image, seg_seg, image_seg = gen_model.build_inference(
        )
        seg_feat = tf.identity(gen_model.seg_feat, "seg_feat")
        image_feat = tf.identity(gen_model.image_feat, "image_feat")

        gen_model.x_fake = tf.identity(seg_image)

    # Discriminator inputs are [image, sketch] pairs.
    disc_real_out = disc_model([x_real, s_real])
    disc_model.set_reuse()
    disc_fake_out = disc_model([seg_image, s_real])
    disc_fake_sketch_out = disc_model([x_real, image_seg])
    #disc_rec_out  = disc_model([image_image, s_real])
    disc_fake_sample_out = disc_model([fake_x_sample, s_real])

    gen_model.cost = disc_model.cost = 0
    # Separate lists: a chained `= []` made both models share one list, so
    # each merged summary op contained every gen AND disc summary.
    gen_model.sum_op = []
    disc_model.sum_op = []
    inter_sum_op = []

    # Select loss builder and model trainer
    print("Build training graph")

    # Naive GAN; adversarial weight split 0.9 (image) / 0.1 (sketch)
    loss.naive_ganloss.func_gen_loss(disc_fake_out,
                                     adv_weight * 0.9,
                                     name="GenSeg",
                                     model=gen_model)
    loss.naive_ganloss.func_gen_loss(disc_fake_sketch_out,
                                     adv_weight * 0.1,
                                     name="GenSketch",
                                     model=gen_model)
    if FLAGS.cache:
        loss.naive_ganloss.func_disc_fake_loss(disc_fake_sample_out,
                                               adv_weight * 0.9,
                                               name="DiscFake",
                                               model=disc_model)
        loss.naive_ganloss.func_disc_fake_loss(disc_fake_sample_out,
                                               adv_weight * 0.1,
                                               name="DiscFakeSketch",
                                               model=disc_model)
    else:
        loss.naive_ganloss.func_disc_fake_loss(disc_fake_out,
                                               adv_weight * 0.9,
                                               name="DiscFake",
                                               model=disc_model)
        loss.naive_ganloss.func_disc_fake_loss(disc_fake_sketch_out,
                                               adv_weight * 0.1,
                                               name="DiscFakeSketch",
                                               model=disc_model)

    loss.naive_ganloss.func_disc_real_loss(disc_real_out,
                                           adv_weight,
                                           name="DiscGen",
                                           model=disc_model)

    # reconstruction of sketch
    rec_sketch_cost_, rec_sketch_sum_ = loss.common_loss.reconstruction_loss(
        image_seg, s_real, TFLAGS['AE_weight'], name="RecSketch")
    gen_model.cost += rec_sketch_cost_
    gen_model.sum_op.append(rec_sketch_sum_)

    gen_model.cost = tf.identity(gen_model.cost, "TotalGenCost")
    disc_model.cost = tf.identity(disc_model.cost, "TotalDiscCost")

    # total summary
    gen_model.sum_op.append(tf.summary.scalar("GenCost", gen_model.cost))
    disc_model.sum_op.append(tf.summary.scalar("DiscCost", disc_model.cost))

    # add interval summary (grid edge capped at 4)
    edge_num = int(np.sqrt(TFLAGS['batch_size']))
    if edge_num > 4:
        edge_num = 4
    grid_x_fake = ops.get_grid_image_summary(seg_image, edge_num)
    inter_sum_op.append(tf.summary.image("generated image", grid_x_fake))
    grid_x_seg = ops.get_grid_image_summary(image_seg, edge_num)
    inter_sum_op.append(tf.summary.image("inv image", grid_x_seg))
    grid_x_real = ops.get_grid_image_summary(x_real, edge_num)
    inter_sum_op.append(tf.summary.image("real image", grid_x_real))
    grid_s_real = ops.get_grid_image_summary(s_real, edge_num)
    inter_sum_op.append(tf.summary.image("sketch image", grid_s_real))

    # merge summary op
    gen_model.sum_op = tf.summary.merge(gen_model.sum_op)
    disc_model.sum_op = tf.summary.merge(disc_model.sum_op)
    inter_sum_op = tf.summary.merge(inter_sum_op)

    print("=> Compute gradient")

    # get train op
    gen_model.get_trainable_variables()
    disc_model.get_trainable_variables()

    # Not applying update op will result in failure
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        # normal GAN train op
        gen_model.train_op = tf.train.AdamOptimizer(
            learning_rate=lr, beta1=0.5,
            beta2=0.9).minimize(gen_model.cost, var_list=gen_model.vars)

        disc_model.train_op = tf.train.AdamOptimizer(
            learning_rate=lr, beta1=0.5,
            beta2=0.9).minimize(disc_model.cost, var_list=disc_model.vars)

    print("=> ##### Generator Variable #####")
    gen_model.print_trainble_vairables()
    print("=> ##### Discriminator Variable #####")
    disc_model.print_trainble_vairables()
    print("=> ##### All Variable #####")
    for v in tf.trainable_variables():
        print("%s" % v.name)

    ctrl_weight = {
        "real_cls": real_cls_weight,
        "fake_cls": fake_cls_weight,
        "adv": adv_weight,
        "lr": lr,
        "inc_length": inc_length
    }

    trainer_feed = {
        "gen_model": gen_model,
        "disc_model": disc_model,
        "ctrl_weight": ctrl_weight,
        "dataloader": dl,
        "int_sum_op": inter_sum_op,
        "gpu_mem": TFLAGS['gpu_mem'],
        "FLAGS": FLAGS,
        "TFLAGS": TFLAGS
    }

    Trainer = trainer.cgantrainer.CGANTrainer
    Trainer.use_cache = False
    trainer_feed.update({
        "inputs": [x_real, s_real, c_label],
        "noise": z_noise,
        "cond": c_noise
    })

    ModelTrainer = Trainer(**trainer_feed)
    ModelTrainer.fake_x_sample = fake_x_sample

    command_controller = trainer.cmd_ctrl.CMDControl(ModelTrainer)
    command_controller.start_thread()

    ModelTrainer.init_training()

    ModelTrainer.train()
Ejemplo n.º 8
0
def main():
    """Train a sketch-to-face generator on the getchu face data.

    The generator (ImageConditionalDeepGenerator2) maps a single-channel
    sketch ``s_real`` to an image. A GoodDiscriminator scores real and
    generated images. Losses are:

    - the GAN loss selected by ``TFLAGS['gan_loss']``;
    - an optional real-image classification loss when FLAGS.cgan is set;
    - a reconstruction loss against ``x_real``, weighted by
      ``TFLAGS['AE_weight']``.

    Discriminator weights are loaded from an .npz file after
    initialization, before training starts.
    """
    # get configuration
    print("Get configuration")
    TFLAGS = cfg.get_train_config(FLAGS.model_name)

    gen_config = cfg.good_generator.goodmodel_gen({})
    gen_config['name'] = "CondDeepGenerator"
    gen_model = model.conditional_generator.ImageConditionalDeepGenerator2(
        **gen_config)
    disc_config = cfg.good_generator.goodmodel_disc({})
    disc_config['norm_mtd'] = None
    disc_model = model.good_generator.GoodDiscriminator(**disc_config)

    TFLAGS['batch_size'] = 1
    TFLAGS['AE_weight'] = 10.0
    TFLAGS['side_noise'] = FLAGS.side_noise
    TFLAGS['input_shape'] = [128, 128, 3]
    # Data Preparation
    # raw image is 0~255
    print("Get dataset")

    face_dataset = dataloader.CustomDataset(
        root_dir="/data/datasets/getchu/true_face",
        npy_dir=None,
        preproc_kind="tanh",
        img_size=TFLAGS['input_shape'][:2],
        filter_data=TFLAGS['filter_data'],
        class_num=TFLAGS['c_len'],
        disturb=False,
        flip=False)

    sketch_dataset = dataloader.CustomDataset(
        root_dir="/data/datasets/getchu/sketch_face/",
        npy_dir=None,
        preproc_kind="tanh",
        img_size=TFLAGS['input_shape'][:2],
        filter_data=TFLAGS['filter_data'],
        class_num=TFLAGS['c_len'],
        disturb=False,
        flip=False)

    # Label rows: column 0 of the .npy file is dropped.
    label_dataset = dataloader.NumpyArrayDataset(
        np.load("/data/datasets/getchu/true_face.npy")[:, 1:])

    # Sampler order must match trainer "inputs": [s_real, x_real, c_label].
    listsampler = dataloader.ListSampler(
        [sketch_dataset, face_dataset, label_dataset])
    dl = dataloader.CustomDataLoader(listsampler, TFLAGS['batch_size'], 4)

    # TF Input
    x_real = tf.placeholder(tf.float32, [None] + TFLAGS['input_shape'],
                            name="x_real")
    s_real = tf.placeholder(tf.float32,
                            [None] + TFLAGS['input_shape'][:-1] + [1],
                            name='s_real')
    # It was previously also named "x_real", which TF silently renamed to
    # "x_real_1".
    fake_x_sample = tf.placeholder(tf.float32, [None] + TFLAGS['input_shape'],
                                   name="fake_x_sample")
    noise_A = tf.placeholder(tf.float32, [None, TFLAGS['z_len']],
                             name="noise_A")
    noise_B = tf.placeholder(tf.float32, [None, TFLAGS['z_len']],
                             name="noise_B")
    label_A = tf.placeholder(tf.float32, [None, TFLAGS['c_len']],
                             name="label_A")
    label_B = tf.placeholder(tf.float32, [None, TFLAGS['c_len']],
                             name="label_B")

    # c label is for real samples
    c_label = tf.placeholder(tf.float32, [None, TFLAGS['c_len']],
                             name="c_label")
    c_noise = tf.placeholder(tf.float32, [None, TFLAGS['c_len']],
                             name="c_noise")

    # control variables (fed by the trainer via ctrl_weight)
    real_cls_weight = tf.placeholder(tf.float32, [], name="real_cls_weight")
    fake_cls_weight = tf.placeholder(tf.float32, [], name="fake_cls_weight")
    adv_weight = tf.placeholder(tf.float32, [], name="adv_weight")
    lr = tf.placeholder(tf.float32, [], name="lr")

    side_noise_A = tf.concat([noise_A, c_label], axis=1, name='side_noise_A')

    if TFLAGS['side_noise']:
        gen_model.side_noise = side_noise_A
    x_fake = gen_model([s_real])
    gen_model.set_reuse()
    gen_model.x_fake = x_fake

    if FLAGS.cgan:
        disc_model.is_cgan = True
    disc_model.disc_real_out, disc_model.real_cls_logits = disc_model(
        x_real)[:2]
    disc_model.set_reuse()
    disc_model.disc_fake_out, disc_model.fake_cls_logits = disc_model(
        x_fake)[:2]

    gen_model.cost = disc_model.cost = 0
    # Separate lists: a chained `= []` made both models share one list, so
    # each merged summary op contained every gen AND disc summary.
    gen_model.sum_op = []
    disc_model.sum_op = []
    inter_sum_op = []

    # NOTE(review): any other TFLAGS['gan_loss'] value adds no adversarial
    # loss at all; only the reconstruction (and cls) terms remain.
    if TFLAGS['gan_loss'] == "dra":
        # naive GAN loss
        gen_cost_, disc_cost_, gen_sum_, disc_sum_ = loss.naive_ganloss.get_naive_ganloss(
            gen_model, disc_model, adv_weight)
        gen_model.cost += gen_cost_
        disc_model.cost += disc_cost_
        gen_model.sum_op.extend(gen_sum_)
        disc_model.sum_op.extend(disc_sum_)

        gen_model.gen_cost = gen_cost_
        disc_model.disc_cost = disc_cost_

        # dragan loss
        disc_cost_, disc_sum_ = loss.dragan_loss.get_dragan_loss(
            disc_model, x_real, TFLAGS['gp_weight'])
        disc_model.cost += disc_cost_
        disc_model.sum_op.append(disc_sum_)

    elif TFLAGS['gan_loss'] == "wass":
        gen_cost_, disc_cost_, gen_sum_, disc_sum_ = loss.wass_ganloss.wass_gan_loss(
            gen_model, disc_model, x_real, x_fake)
        gen_model.cost += gen_cost_
        disc_model.cost += disc_cost_
        gen_model.sum_op.extend(gen_sum_)
        disc_model.sum_op.extend(disc_sum_)

    elif TFLAGS['gan_loss'] == "naive":
        # naive GAN loss (called with lsgan=True)
        gen_cost_, disc_cost_, gen_sum_, disc_sum_ = loss.naive_ganloss.get_naive_ganloss(
            gen_model, disc_model, adv_weight, lsgan=True)
        gen_model.cost += gen_cost_
        disc_model.cost += disc_cost_
        gen_model.sum_op.extend(gen_sum_)
        disc_model.sum_op.extend(disc_sum_)

        gen_model.gen_cost = gen_cost_
        disc_model.disc_cost = disc_cost_

    if FLAGS.cgan:
        real_cls = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=disc_model.real_cls_logits, labels=c_label),
            name="real_cls_reduce_mean") * real_cls_weight

        #fake_cls = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        #    logits=disc_model.fake_cls_logits,
        #    labels=c_label), name="fake_cls_reduce_mean") * fake_cls_weight

        disc_model.cost += real_cls  # + fake_cls
        disc_model.sum_op.extend([tf.summary.scalar("RealCls", real_cls)])  #,
        #tf.summary.scalar("FakeCls", fake_cls)])

    gen_cost_, ae_loss_sum_ = loss.common_loss.reconstruction_loss(
        x_fake, x_real, TFLAGS['AE_weight'])
    gen_model.sum_op.append(ae_loss_sum_)
    gen_model.cost += gen_cost_

    gen_model.cost = tf.identity(gen_model.cost, "TotalGenCost")
    disc_model.cost = tf.identity(disc_model.cost, "TotalDiscCost")

    # total summary
    gen_model.sum_op.append(tf.summary.scalar("GenCost", gen_model.cost))
    disc_model.sum_op.append(tf.summary.scalar("DiscCost", disc_model.cost))

    # add interval summary (grid edge capped at 4)
    edge_num = int(np.sqrt(TFLAGS['batch_size']))
    if edge_num > 4:
        edge_num = 4
    grid_x_fake = ops.get_grid_image_summary(x_fake, edge_num)
    inter_sum_op.append(tf.summary.image("generated image", grid_x_fake))
    grid_x_real = ops.get_grid_image_summary(x_real, edge_num)
    # Debug: prints the max/min of the real-image grid whenever evaluated.
    grid_x_real = tf.Print(
        grid_x_real, [tf.reduce_max(grid_x_real),
                      tf.reduce_min(grid_x_real)], "Real")
    inter_sum_op.append(tf.summary.image("real image", grid_x_real))
    inter_sum_op.append(
        tf.summary.image("sketch image",
                         ops.get_grid_image_summary(s_real, edge_num)))

    # merge summary op
    gen_model.sum_op = tf.summary.merge(gen_model.sum_op)
    disc_model.sum_op = tf.summary.merge(disc_model.sum_op)
    inter_sum_op = tf.summary.merge(inter_sum_op)

    # get train op
    gen_model.get_trainable_variables()
    disc_model.get_trainable_variables()

    # Not applying update op will result in failure
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        gen_model.train_op = tf.train.AdamOptimizer(
            learning_rate=lr, beta1=0.5,
            beta2=0.9).minimize(gen_model.cost, var_list=gen_model.vars)

        disc_model.train_op = tf.train.AdamOptimizer(
            learning_rate=lr, beta1=0.5,
            beta2=0.9).minimize(disc_model.cost, var_list=disc_model.vars)

    print("=> ##### Generator Variable #####")
    gen_model.print_trainble_vairables()
    print("=> ##### Discriminator Variable #####")
    disc_model.print_trainble_vairables()
    print("=> ##### All Variable #####")
    for v in tf.trainable_variables():
        print("%s" % v.name)

    ctrl_weight = {
        "real_cls": real_cls_weight,
        "fake_cls": fake_cls_weight,
        "adv": adv_weight,
        "lr": lr
    }

    trainer_feed = {
        "gen_model": gen_model,
        "disc_model": disc_model,
        "noise": noise_A,
        "ctrl_weight": ctrl_weight,
        "dataloader": dl,
        "int_sum_op": inter_sum_op,
        "gpu_mem": TFLAGS['gpu_mem'],
        "FLAGS": FLAGS,
        "TFLAGS": TFLAGS
    }

    Trainer = trainer.cgantrainer.CGANTrainer
    trainer_feed.update({
        "inputs": [s_real, x_real, c_label],
    })

    ModelTrainer = Trainer(**trainer_feed)

    command_controller = trainer.cmd_ctrl.CMDControl(ModelTrainer)
    command_controller.start_thread()

    ModelTrainer.init_training()

    # Load pretrained discriminator weights into the initialized session.
    disc_model.load_from_npz('success/goodmodel_dragan_anime1_disc.npz',
                             ModelTrainer.sess)

    ModelTrainer.train()
Ejemplo n.º 9
0
def main():
    """Build and train a (optionally class-conditional) GAN with mask losses.

    All settings come from the module-level ``FLAGS``. The generator's
    ``overlapped_mask`` is regularized with a cosine diversity loss, a
    mask-area restriction loss and a total-variation smoothness loss; the
    adversarial objective is the hinge loss. When ``FLAGS.cgan`` is set,
    label placeholders are created and a classifier loss is added; otherwise
    the label placeholders are ``None``.
    """
    size = FLAGS.img_size

    if FLAGS.cgan:
        npy_dir = FLAGS.data_dir.replace(".zip", "") + '.npy'
    else:
        npy_dir = None

    if "celeb" in FLAGS.data_dir:
        dataset = dataloader.CelebADataset(FLAGS.data_dir,
            img_size=(size, size),
            npy_dir=npy_dir)
    else:
        dataset = dataloader.FileDataset(FLAGS.data_dir,
            npy_dir=npy_dir,
            img_size=(size, size),
            shuffle=True)
    dl = DataLoader(dataset, batch_size=FLAGS.batch_size, shuffle=True, num_workers=NUM_WORKER)

    # TF Input
    x_fake_sample = tf.placeholder(tf.float32, [None, size, size, 3], name="x_fake_sample")
    x_real = tf.placeholder(tf.float32, [None, size, size, 3], name="x_real")
    s_real = tf.placeholder(tf.float32, [None, size, size, 3], name='s_real')
    z_noise = tf.placeholder(tf.float32, [None, 128], name="z_noise")

    if FLAGS.cgan:
        c_noise = tf.placeholder(tf.float32, [None, dataset.class_num], name="c_noise")
        c_label = tf.placeholder(tf.float32, [None, dataset.class_num], name="c_label")
        gen_input = [z_noise, c_noise]
    else:
        # Unconditional mode: the trainer below still receives ``label``,
        # so these names must exist (previously a NameError).
        c_noise = None
        c_label = None
        gen_input = z_noise

    gen_model, disc_model = getattr(config, FLAGS.model_name)(FLAGS.img_size, dataset.class_num)
    gen_model.mask_num = FLAGS.mask_num
    gen_model.cbn_project = FLAGS.cbn_project

    x_fake = gen_model(gen_input, update_collection=None)
    gen_model.set_reuse()
    gen_model.x_fake = x_fake

    disc_real, real_cls_logits = disc_model(x_real, update_collection="no_ops")
    disc_model.set_reuse()
    disc_fake, fake_cls_logits = disc_model(x_fake, update_collection=None)
    disc_model.disc_real = disc_real
    disc_model.disc_fake = disc_fake
    disc_model.real_cls_logits = real_cls_logits
    disc_model.fake_cls_logits = fake_cls_logits

    int_sum_op = []

    if FLAGS.use_cache:
        # Extra discriminator pass over cached fake samples fed from outside.
        disc_fake_sample = disc_model(x_fake_sample)[0]
        disc_cost_sample = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                logits=disc_fake_sample,
                labels=tf.zeros_like(disc_fake_sample)), name="cost_disc_fake_sample")
        disc_cost_sample_sum = tf.summary.scalar("disc_sample", disc_cost_sample)

        fake_sample_grid = ops.get_grid_image_summary(x_fake_sample, 4)
        int_sum_op.append(tf.summary.image("fake sample", fake_sample_grid))

        sample_method = [disc_cost_sample, disc_cost_sample_sum, x_fake_sample]
    else:
        sample_method = None

    print("=> Mask num " + str(gen_model.overlapped_mask.get_shape()))

    # diverse mask
    diverse_loss, diverse_loss_sum = loss.cosine_diverse_distribution(gen_model.overlapped_mask)

    # make sure mask is not eliminated: penalize masks whose total area
    # deviates from the even share by more than 2x that share
    mask_weight = tf.reduce_sum(gen_model.overlapped_mask, [1, 2])
    mask_num = mask_weight.get_shape().as_list()[-1]
    avg_map_weight = (size ** 2) / float(mask_num)
    diff_map = tf.abs(mask_weight - avg_map_weight)
    restricted_diff_map = tf.nn.relu(diff_map - 2 * avg_map_weight)
    restricted_var_loss = 1e-3 * tf.reduce_mean(restricted_diff_map)
    var_loss_sum = tf.summary.scalar("variance loss", restricted_var_loss)

    # semantic
    """
    uniform_loss = 0
    vgg_net = model.classifier.MyVGG16("lib/tensorflowvgg/vgg16.npy")
    vgg_net.build(tf.image.resize_bilinear(x_fake, (224, 224)))
    sf = vgg_net.conv3_3
    mask_shape = sf.get_shape().as_list()[1:3]
    print("=> VGG feature shape: " + str(mask_shape))
    diff_maps = []
    for i in range(mask_num):
        mask = tf.image.resize_bilinear(gen_model.overlapped_mask[:, :, :, i:i+1], mask_shape) # (batch, size, size, 1)
        mask = mask / tf.reduce_sum(mask, [1, 2], keepdims=True)
        expected_feature = tf.reduce_sum(mask * sf, [1, 2], keepdims=True) # (batch, 1, 1, 256)
        diff_map = tf.reduce_mean(tf.abs(mask * (sf - expected_feature)), [3]) # (batch, size, size)
        diff_maps.append(diff_map[0] / tf.reduce_max(diff_map[0]))
        restricted_diff_map = diff_map # TODO: add margin
        uniform_loss += 1e-3 * tf.reduce_mean(tf.reduce_sum(diff_map, [1, 2]))
    uniform_loss_sum = tf.summary.scalar("uniform loss", uniform_loss)
    """

    # smooth mask
    tv_loss = tf.reduce_mean(tf.image.total_variation(gen_model.overlapped_mask)) / (size ** 2)
    tv_sum = tf.summary.scalar("TV loss", tv_loss)

    gen_model.cost += diverse_loss + tv_loss + restricted_var_loss

    gen_model.sum_op.extend([tv_sum, diverse_loss_sum, var_loss_sum])

    # Show the masks of the first sample as a grid image.
    edge_num = int(np.sqrt(gen_model.overlapped_mask.get_shape().as_list()[-1]))
    mask_seq = tf.transpose(gen_model.overlapped_mask[0], [2, 0, 1])
    grid_mask = tf.expand_dims(ops.get_grid_image_summary(mask_seq, edge_num), -1)
    int_sum_op.append(tf.summary.image("stroke mask", grid_mask))

    #uniform_diff_map = tf.expand_dims(ops.get_grid_image_summary(tf.stack(diff_maps, 0), edge_num), -1)
    #int_sum_op.append(tf.summary.image("uniform diff map", uniform_diff_map))

    grid_x_fake = ops.get_grid_image_summary(gen_model.x_fake, 4)
    int_sum_op.append(tf.summary.image("generated image", grid_x_fake))

    grid_x_real = ops.get_grid_image_summary(x_real, 4)
    int_sum_op.append(tf.summary.image("real image", grid_x_real))

    if FLAGS.cgan:
        loss.classifier_loss(gen_model, disc_model, x_real, c_label, c_noise,
        weight=1.0)

    loss.hinge_loss(gen_model, disc_model, adv_weight=1.0)

    int_sum_op = tf.summary.merge(int_sum_op)

    ModelTrainer = trainer.base_gantrainer.BaseGANTrainer(
        int_sum_op=int_sum_op,
        dataloader=dl,
        FLAGS=FLAGS,
        gen_model=gen_model,
        disc_model=disc_model,
        gen_input=gen_input,
        x_real=x_real,
        label=c_label,
        sample_method=sample_method)

    #command_controller = trainer.cmd_ctrl.CMDControl(ModelTrainer)
    #command_controller.start_thread()

    print("=> Build train op")
    ModelTrainer.build_train_op()

    print("=> ##### Generator Variable #####")
    gen_model.print_trainble_vairables()
    print("=> ##### Discriminator Variable #####")
    disc_model.print_trainble_vairables()
    print("=> ##### All Variable #####")
    for v in tf.trainable_variables():
        print("%s" % v.name)

    ModelTrainer.init_training()
    ModelTrainer.train()
Ejemplo n.º 10
0
    def build_train(self, input):
        """Build discriminator outputs, losses and summaries for training.

        Args:
            input: ``(fake_data, real_data, label)``. ``label`` is used as the
                sigmoid cross-entropy target for the classification head on
                both the fake and the real batch.

        Sets on ``self``: ``disc_fake``/``disc_fake_cls``,
        ``disc_real``/``disc_real_cls``, ``vars``, ``gen_cost``,
        ``disc_cost``, ``cls_cost``, ``tot_loss``, ``sum_op`` and
        ``sum_interval_op``. It also sets ``gradient_penalty``/``disc_interp``
        when ``self.is_wgan``, and ``disc_cost_real``/``disc_cost_fake``
        otherwise.
        """
        # ADD label
        fake_data, real_data, label = input
        with tf.variable_scope(self.name):
            self.disc_fake, self.disc_fake_cls = self.build_inference(
                fake_data)
            self.reuse = True
            self.disc_real, self.disc_real_cls = self.build_inference(
                real_data)
        self.vars = [
            v for v in tf.trainable_variables() if self.name in v.name
        ]

        self.tot_loss = 0

        if self.is_wgan:
            self.gen_cost = -tf.reduce_mean(self.disc_fake)
            self.disc_cost = tf.reduce_mean(self.disc_fake) - tf.reduce_mean(
                self.disc_real)

            # WGAN-GP: penalize the critic's gradient norm at random points
            # on the lines between real and fake samples.
            alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1],
                                      minval=0.,
                                      maxval=1.)
            differences = real_data - fake_data
            interpolates = fake_data + alpha * differences
            with tf.variable_scope(self.name):
                # build_inference returns (adv_logits, cls_logits); only the
                # adversarial critic output may enter the gradient penalty.
                self.disc_interp = self.build_inference(interpolates)[0]
            gradients = tf.gradients(self.disc_interp, [interpolates])[0]
            # Per-sample L2 norm over every non-batch axis. Reducing only
            # axis 1 of an image gradient gives per-row norms instead.
            flat_gradients = tf.reshape(gradients,
                                        [tf.shape(gradients)[0], -1])
            slopes = tf.sqrt(
                tf.reduce_sum(tf.square(flat_gradients), axis=1))
            self.gradient_penalty = self.lambda_gp * tf.reduce_mean(
                (slopes - 1.)**2)
            self.tot_loss += self.gradient_penalty

        else:
            # normal disc
            self.gen_cost = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(logits=self.disc_fake,
                                                        labels=tf.ones_like(
                                                            self.disc_fake)))

            self.disc_cost_fake = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(logits=self.disc_fake,
                                                        labels=tf.zeros_like(
                                                            self.disc_fake)))

            self.disc_cost_real = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(logits=self.disc_real,
                                                        labels=tf.ones_like(
                                                            self.disc_real)))

            self.disc_cost = self.disc_cost_fake + self.disc_cost_real

        self.cls_cost = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=self.disc_fake_cls,
                                                    labels=label))
        self.cls_cost += tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=self.disc_real_cls,
                                                    labels=label))

        # NOTE(review): this collection is graph-wide, so it may include
        # other models' regularizers; var_list in the optimizer limits the
        # effect to this model's variables.
        reg_losses = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))

        self.tot_loss += self.disc_cost + reg_losses + self.cls_cost

        with tf.name_scope(self.name + "_Loss"):
            loss_sums = [
                tf.summary.scalar("L2 regularize", reg_losses),
                tf.summary.scalar("classification", self.cls_cost),
            ]
            if self.is_wgan:
                loss_sums.append(
                    tf.summary.scalar("gradient penalty",
                                      self.gradient_penalty))
                loss_sums.append(
                    tf.summary.scalar("discriminator loss", self.disc_cost))
            else:
                loss_sums.append(
                    tf.summary.scalar("discriminator loss real",
                                      self.disc_cost_real))
                loss_sums.append(
                    tf.summary.scalar("discriminator loss fake",
                                      self.disc_cost_fake))
            loss_sums.append(
                tf.summary.scalar("generator loss", self.gen_cost))
            # Merge only this model's scalars. merge_all() would also pull
            # in every other summary already defined anywhere in the graph.
            self.sum_op = tf.summary.merge(loss_sums)
            # grid summary 4x4; +1 shifts the images to be non-negative
            show_img = ops.get_grid_image_summary(fake_data, 4)
            gen_image_sum = tf.summary.image("generated", show_img + 1)
            self.sum_interval_op = tf.summary.merge([gen_image_sum])
Ejemplo n.º 11
0
def main():
    """Train a CycleGAN-style sketch <-> anime-face translator.

    Domain A is sketch faces (1 channel) and domain B is anime faces
    (3 channels). Two translators are trained against two discriminators
    (``DiscA``, ``DiscB``):

    * ``TransF`` maps A -> B.
    * ``TransB`` maps B -> A.

    The losses are a least-squares adversarial loss, an L1
    cycle-consistency loss on grayscale images, and a classification loss on
    real domain-B samples. Both discriminators are loaded from pretrained
    ``.npz`` files after the session is initialised.
    """
    # get configuration
    print("Get configuration")
    TFLAGS = cfg.get_train_config(FLAGS.model_name)
    TFLAGS['AE_weight'] = 10.0
    TFLAGS['side_noise'] = False

    # A: sketch face, B: anime face
    # Gen and Disc usually init from pretrained model
    print("=> Building discriminator")
    disc_config = cfg.good_generator.goodmodel_disc({})
    disc_config['name'] = "DiscA"
    DiscA = model.good_generator.GoodDiscriminator(**disc_config)
    disc_config['name'] = "DiscB"
    DiscB = model.good_generator.GoodDiscriminator(**disc_config)

    gen_config = cfg.good_generator.goodmodel_gen({})
    gen_config['name'] = 'GenA'
    gen_config['out_dim'] = 3
    GenA = model.good_generator.GoodGenerator(**gen_config)
    gen_config['name'] = 'GenB'
    gen_config['out_dim'] = 1
    GenB = model.good_generator.GoodGenerator(**gen_config)

    gen_config = {}
    gen_config['name'] = 'TransForward'
    gen_config['out_dim'] = 3
    TransF = model.conditional_generator.ImageConditionalDeepGenerator2(
        **gen_config)
    gen_config['name'] = 'TransBackward'
    gen_config['out_dim'] = 1
    TransB = model.conditional_generator.ImageConditionalDeepGenerator2(
        **gen_config)

    models = [DiscA, DiscB, TransF, TransB]
    model_names = ["DiscA", "DiscB", "TransF", "TransB"]

    TFLAGS['batch_size'] = 1
    TFLAGS['input_shape'] = [128, 128, 3]
    # Data Preparation
    # raw image is 0~255
    print("Get dataset")

    face_dataset = dataloader.CustomDataset(
        root_dir="/data/datasets/getchu/true_face",
        npy_dir=None,
        preproc_kind="tanh",
        img_size=TFLAGS['input_shape'][:2],
        filter_data=TFLAGS['filter_data'],
        class_num=TFLAGS['c_len'],
        disturb=False,
        flip=False)

    sketch_dataset = dataloader.CustomDataset(
        root_dir="/data/datasets/getchu/sketch_face/",
        npy_dir=None,
        preproc_kind="tanh",
        img_size=TFLAGS['input_shape'][:2],
        filter_data=TFLAGS['filter_data'],
        class_num=TFLAGS['c_len'],
        disturb=False,
        flip=False)

    label_dataset = dataloader.NumpyArrayDataset(
        np.load("/data/datasets/getchu/true_face.npy")[:, 1:])

    listsampler = dataloader.ListSampler(
        [sketch_dataset, face_dataset, label_dataset])
    dl = dataloader.CustomDataLoader(listsampler, TFLAGS['batch_size'], 4)

    # TF Input
    real_A_sample = tf.placeholder(tf.float32,
                                   [None] + TFLAGS['input_shape'][:-1] + [1],
                                   name='real_A_sample')
    real_B_sample = tf.placeholder(tf.float32, [None] + TFLAGS['input_shape'],
                                   name="real_B_sample")
    fake_A_sample = tf.placeholder(tf.float32,
                                   [None] + TFLAGS['input_shape'][:-1] + [1],
                                   name='fake_A_sample')
    fake_B_sample = tf.placeholder(tf.float32, [None] + TFLAGS['input_shape'],
                                   name="fake_B_sample")

    noise_A = tf.placeholder(tf.float32, [None, TFLAGS['z_len']],
                             name="noise_A")
    noise_B = tf.placeholder(tf.float32, [None, TFLAGS['z_len']],
                             name="noise_B")
    label_A = tf.placeholder(tf.float32, [None, TFLAGS['c_len']],
                             name="label_A")
    label_B = tf.placeholder(tf.float32, [None, TFLAGS['c_len']],
                             name="label_B")
    # c label is for real samples
    c_label = tf.placeholder(tf.float32, [None, TFLAGS['c_len']],
                             name="c_label")

    # control variables
    real_cls_weight = tf.placeholder(tf.float32, [], name="real_cls_weight")
    fake_cls_weight = tf.placeholder(tf.float32, [], name="fake_cls_weight")
    adv_weight = tf.placeholder(tf.float32, [], name="adv_weight")
    lr = tf.placeholder(tf.float32, [], name="lr")

    inter_sum_op = []

    ### build graph
    """
    fake_A = GenA(noise_A); GenA.set_reuse() # GenA and GenB only used once
    fake_B = GenB(noise_B); GenB.set_reuse()

    # transF and TransB used three times
    trans_real_A = TransF(real_A); TransF.set_reuse() # domain B
    trans_fake_A = TransF(fake_A) # domain B
    trans_real_B = TransB(real_B); TransB.set_reuse() # domain A
    trans_fake_B = TransB(fake_B) # domain A
    rec_real_A = TransB(trans_real_A)
    rec_fake_A = TransB(trans_fake_A)
    rec_real_B = TransF(trans_real_B)
    rec_fake_B = TransF(trans_fake_B)

    # DiscA and DiscB reused many times
    disc_real_A, cls_real_A = DiscA(x_real)[:2]; DiscA.set_reuse() 
    disc_fake_A, cls_fake_A = DiscA(x_fake)[:2]; 
    disc_trans_real_B, cls_trans_real_B = DiscA(trans_real_B)[:2]
    disc_trans_fake_B, cls_trans_fake_B = DiscA(trans_real_B)[:2]
    disc_rec_real_A, cls_rec_real_A = DiscA(rec_real_A)[:2]
    disc_rec_fake_A, cls_rec_fake_A = DiscA(rec_real_A)[:2]

    disc_real_B, cls_real_B = DiscB(x_real)[:2]; DiscB.set_reuse() 
    disc_fake_B, cls_fake_B = DiscB(x_fake)[:2]; 
    disc_trans_real_A, cls_trans_real_A = DiscB(trans_real_A)[:2]
    disc_trans_fake_A, cls_trans_fake_A = DiscB(trans_real_A)[:2]
    disc_rec_real_B, cls_rec_real_B = DiscB(rec_real_B)[:2]
    disc_rec_fake_B, cls_rec_fake_B = DiscB(rec_real_B)[:2]
    """

    side_noise_A = tf.concat([noise_A, label_A], axis=1, name='side_noise_A')
    side_noise_B = tf.concat([noise_B, label_B], axis=1, name='side_noise_B')

    trans_real_A = TransF(real_A_sample)
    TransF.set_reuse()  # domain B
    trans_real_B = TransB(real_B_sample)
    TransB.set_reuse()  # domain A
    TransF.trans_real_A = trans_real_A
    TransB.trans_real_B = trans_real_B
    rec_real_A = TransB(trans_real_A)
    rec_real_B = TransF(trans_real_B)

    # start fake building
    if TFLAGS['side_noise']:
        TransF.side_noise = side_noise_A
        TransB.side_noise = side_noise_B
        trans_fake_A = TransF(real_A_sample)
        trans_fake_B = TransB(real_B_sample)

    DiscA.fake_sample = fake_A_sample
    DiscB.fake_sample = fake_B_sample
    disc_fake_A_sample, cls_fake_A_sample = DiscA(fake_A_sample)[:2]
    DiscA.set_reuse()
    disc_fake_B_sample, cls_fake_B_sample = DiscB(fake_B_sample)[:2]
    DiscB.set_reuse()
    disc_real_A_sample, cls_real_A_sample = DiscA(real_A_sample)[:2]
    disc_real_B_sample, cls_real_B_sample = DiscB(real_B_sample)[:2]
    disc_trans_real_A, cls_trans_real_A = DiscB(trans_real_A)[:2]
    disc_trans_real_B, cls_trans_real_B = DiscA(trans_real_B)[:2]

    if TFLAGS['side_noise']:
        disc_trans_fake_A, cls_trans_fake_A = DiscB(trans_real_A)[:2]
        disc_trans_fake_B, cls_trans_fake_B = DiscA(trans_real_B)[:2]

    def disc_loss(disc_fake_out,
                  disc_real_out,
                  disc_model,
                  adv_weight=1.0,
                  name="NaiveDisc",
                  acc=True):
        """Least-squares discriminator loss (soft real target, 0 for fake).

        With ``acc=True`` the loss and summaries are accumulated into
        ``disc_model.cost`` / ``disc_model.sum_op`` and nothing is returned;
        otherwise ``(disc_cost, summaries)`` is returned.
        """
        softL_c = 0.05
        with tf.name_scope(name):
            # NOTE(review): the soft real label is drawn once with numpy at
            # graph-build time, so it is a constant for the whole run, not
            # resampled per step.
            raw_disc_cost_real = tf.reduce_mean(
                tf.square(disc_real_out - tf.ones_like(disc_real_out) *
                          np.abs(np.random.normal(1.0, softL_c))),
                name="raw_disc_cost_real")

            raw_disc_cost_fake = tf.reduce_mean(
                tf.square(disc_fake_out - tf.zeros_like(disc_fake_out)),
                name="raw_disc_cost_fake")

            disc_cost = tf.multiply(
                adv_weight, (raw_disc_cost_fake + raw_disc_cost_real) / 2,
                name="disc_cost")

            disc_fake_sum = [
                tf.summary.scalar("DiscFakeRaw", raw_disc_cost_fake),
                tf.summary.scalar("DiscRealRaw", raw_disc_cost_real)
            ]

            if acc:
                disc_model.cost += disc_cost
                disc_model.sum_op.extend(disc_fake_sum)
            else:
                return disc_cost, disc_fake_sum

    def gen_loss(disc_fake_out,
                 gen_model,
                 adv_weight=1.0,
                 name="Naive",
                 acc=True):
        """Least-squares generator loss pushing ``disc_fake_out`` to ~1.

        With ``acc=True`` the loss is added to ``gen_model.cost``; otherwise
        ``(gen_cost, [])`` is returned.
        """
        softL_c = 0.05
        with tf.name_scope(name):
            raw_gen_cost = tf.reduce_mean(
                tf.square(disc_fake_out - tf.ones_like(disc_fake_out) *
                          np.abs(np.random.normal(1.0, softL_c))),
                name="raw_gen_cost")

            gen_cost = tf.multiply(raw_gen_cost, adv_weight, name="gen_cost")

        if acc:
            gen_model.cost += gen_cost
        else:
            return gen_cost, []

    disc_loss(disc_fake_A_sample,
              disc_real_A_sample,
              DiscA,
              adv_weight=TFLAGS['adv_weight'],
              name="DiscA_Loss")
    # Previously also scoped "DiscA_Loss" (copy-paste), which mislabelled
    # DiscB's ops and summaries.
    disc_loss(disc_fake_B_sample,
              disc_real_B_sample,
              DiscB,
              adv_weight=TFLAGS['adv_weight'],
              name="DiscB_Loss")

    gen_loss(disc_trans_real_A,
             TransF,
             adv_weight=TFLAGS['adv_weight'],
             name="NaiveGenA")
    gen_loss(disc_trans_real_B,
             TransB,
             adv_weight=TFLAGS['adv_weight'],
             name="NaiveGenB")

    if TFLAGS['side_noise']:
        # extra loss is for noise not equal to zero; it starts from a copy
        # of each translator's current cost
        TransF.extra_loss = tf.identity(TransF.cost)
        TransB.extra_loss = tf.identity(TransB.cost)

        cost_, _ = gen_loss(disc_trans_fake_A,
                            TransF,
                            adv_weight=TFLAGS['adv_weight'],
                            name="NaiveGenAFake",
                            acc=False)
        TransF.extra_loss += cost_
        cost_, _ = gen_loss(disc_trans_fake_B,
                            TransB,
                            adv_weight=TFLAGS['adv_weight'],
                            name="NaiveGenBFake",
                            acc=False)
        TransB.extra_loss += cost_

    #GANLoss(disc_rec_A, DiscA, TransB, adv_weight=TFLAGS['adv_weight'], name="NaiveGenA")
    #GANLoss(disc_rec_B, DiscA, TransF, adv_weight=TFLAGS['adv_weight'], name="NaiveTransA")

    # cycle consistent loss
    def cycle_loss(trans, origin, weight=10.0, name="cycle"):
        """Weighted L1 distance between channel-averaged (gray) images.

        Returns ``(cost, scalar_summary)``.
        """
        with tf.name_scope(name):
            # using gray
            trans = tf.reduce_mean(trans, axis=[3])
            origin = tf.reduce_mean(origin, axis=[3])
            cost_ = tf.reduce_mean(tf.abs(trans - origin)) * weight
            sum_ = tf.summary.scalar("Rec", cost_)

        return cost_, sum_

    cost_, sum_ = cycle_loss(rec_real_A,
                             real_A_sample,
                             TFLAGS['AE_weight'],
                             name="cycleA")
    TransF.cost += cost_
    TransB.cost += cost_
    TransF.sum_op.append(sum_)

    cost_, sum_ = cycle_loss(rec_real_B,
                             real_B_sample,
                             TFLAGS['AE_weight'],
                             name="cycleB")
    TransF.cost += cost_
    TransB.cost += cost_
    TransB.sum_op.append(sum_)

    clsB_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=cls_real_B_sample, labels=c_label),
                               name="clsB_real")

    if TFLAGS['side_noise']:
        clsB_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=cls_trans_fake_A, labels=c_label),
                                   name="clsB_fake")

    DiscB.cost += clsB_real * real_cls_weight
    DiscB.sum_op.extend([tf.summary.scalar("RealCls", clsB_real)])

    if TFLAGS['side_noise']:
        DiscB.extra_loss = clsB_fake * fake_cls_weight
        TransF.extra_loss += clsB_fake * fake_cls_weight

        # extra loss for integrate stochastic
        cost_, sum_ = cycle_loss(rec_real_A,
                                 real_A_sample,
                                 TFLAGS['AE_weight'],
                                 name="cycleADisturb")
        TransF.extra_loss += cost_
        TransF.sum_op.append(sum_)
        # NOTE(review): cls_trans_real_A is raw classifier logits, not a
        # loss; adding it makes extra_loss non-scalar. Confirm the intent.
        TransF.extra_loss += cls_trans_real_A * fake_cls_weight

    # add interval summary
    edge_num = int(np.sqrt(TFLAGS['batch_size']))
    if edge_num > 4:
        edge_num = 4

    grid_real_A = ops.get_grid_image_summary(real_A_sample, edge_num)
    inter_sum_op.append(tf.summary.image("Real A", grid_real_A))
    grid_real_B = ops.get_grid_image_summary(real_B_sample, edge_num)
    inter_sum_op.append(tf.summary.image("Real B", grid_real_B))

    grid_trans_A = ops.get_grid_image_summary(trans_real_A, edge_num)
    inter_sum_op.append(tf.summary.image("Trans A", grid_trans_A))
    grid_trans_B = ops.get_grid_image_summary(trans_real_B, edge_num)
    inter_sum_op.append(tf.summary.image("Trans B", grid_trans_B))

    if TFLAGS['side_noise']:
        grid_fake_A = ops.get_grid_image_summary(trans_fake_A, edge_num)
        inter_sum_op.append(tf.summary.image("Fake A", grid_fake_A))
        grid_fake_B = ops.get_grid_image_summary(trans_fake_B, edge_num)
        inter_sum_op.append(tf.summary.image("Fake B", grid_fake_B))

    # merge summary op
    for m, n in zip(models, model_names):
        m.cost = tf.identity(m.cost, "Total" + n)
        m.sum_op.append(tf.summary.scalar("Total" + n, m.cost))
        m.sum_op = tf.summary.merge(m.sum_op)
        m.get_trainable_variables()

    inter_sum_op = tf.summary.merge(inter_sum_op)

    # Not applying update op will result in failure
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        for m, n in zip(models, model_names):
            m.train_op = tf.train.AdamOptimizer(learning_rate=lr,
                                                beta1=0.5,
                                                beta2=0.9).minimize(
                                                    m.cost, var_list=m.vars)

            # A plain numeric 0 marks "no extra objective". Compare by
            # value: `is not 0` relied on CPython small-int caching and is a
            # SyntaxWarning on Python 3.8+.
            extra_loss = getattr(m, "extra_loss", 0)
            if not (isinstance(extra_loss, (int, float)) and extra_loss == 0):
                m.extra_train_op = tf.train.AdamOptimizer(learning_rate=lr,
                                                          beta1=0.5,
                                                          beta2=0.9).minimize(
                                                              extra_loss,
                                                              var_list=m.vars)

            print("=> ##### %s Variable #####" % n)
            m.print_trainble_vairables()

    print("=> ##### All Variable #####")
    for v in tf.trainable_variables():
        print("%s" % v.name)

    ctrl_weight = {
        "real_cls": real_cls_weight,
        "fake_cls": fake_cls_weight,
        "adv": adv_weight,
        "lr": lr
    }

    # basic settings
    trainer_feed = {
        "ctrl_weight": ctrl_weight,
        "dataloader": dl,
        "int_sum_op": inter_sum_op,
        "gpu_mem": TFLAGS['gpu_mem'],
        "FLAGS": FLAGS,
        "TFLAGS": TFLAGS
    }

    for m, n in zip(models, model_names):
        trainer_feed.update({n: m})

    Trainer = trainer.cycle_trainer.CycleTrainer
    if TFLAGS['side_noise']:
        Trainer.noises = [noise_A, noise_B]
        Trainer.labels = [label_A, label_B]
    # input
    trainer_feed.update({"inputs": [real_A_sample, real_B_sample, c_label]})

    ModelTrainer = Trainer(**trainer_feed)

    command_controller = trainer.cmd_ctrl.CMDControl(ModelTrainer)
    command_controller.start_thread()

    ModelTrainer.init_training()

    # Pretrained discriminator weights, loaded after variable initialisation.
    DISC_PATH = [
        'success/goodmodel_dragan_sketch2_disc.npz',
        'success/goodmodel_dragan_anime1_disc.npz'
    ]

    DiscA.load_from_npz(DISC_PATH[0], ModelTrainer.sess)
    DiscB.load_from_npz(DISC_PATH[1], ModelTrainer.sess)

    ModelTrainer.train()