Example #1
def apply_from_weight():
    # get configuration
    print("=> Get configuration")
    TFLAGS = cfg.get_train_config(FLAGS.model_name)
    gen_model, gen_config, disc_model, disc_config = model.get_model(FLAGS.model_name, TFLAGS)

    z_noise = tf.placeholder(tf.float32, [None, TFLAGS['z_len']], name="z_noise")
    x_real = tf.placeholder(tf.float32, [None] + TFLAGS['input_shape'], name="x_real")

    if FLAGS.cgan:
        c_noise = tf.placeholder(tf.float32, [None, TFLAGS['c_len']], name="c_noise")
        c_label = tf.placeholder(tf.float32, [None, TFLAGS['c_len']], name="c_label")

    # build model
    with tf.variable_scope(gen_model.field_name):
        if FLAGS.cgan:
            gen_model.is_cgan = True
            x_fake = gen_model.build_inference([z_noise, c_noise])
        else:
            x_fake = gen_model.build_inference(z_noise)

    with tf.variable_scope(disc_model.field_name):
        if FLAGS.cgan:
            disc_model.is_cgan = True
            disc_model.disc_real_out = disc_model.build_inference([x_fake, c_label])[0]
        else:
            disc_model.disc_real_out = disc_model.build_inference(x_fake)[0]

    gen_vars = files.load_npz(BEST_MODEL, "goodmodel_dragan_anime1_gen.npz")
    disc_vars = files.load_npz(BEST_MODEL, "goodmodel_dragan_anime1_disc.npz")

    feed_dict = {
        z_noise: ops.random_truncate_normal((16, 128), 1, 0),
        c_noise: ops.random_boolean((16, 34), True),
        gen_model.training: False,
        disc_model.training: False,
        gen_model.keep_prob: 1.0,
        disc_model.keep_prob: 1.0
    }

    gen_model.get_trainable_variables()
    disc_model.get_trainable_variables()
    gen_moving_vars = [v for v in tf.global_variables() if v.name.find("Gen") > -1 and v.name.find("moving") > -1]
    disc_moving_vars = [v for v in tf.global_variables() if v.name.find("Disc") > -1 and v.name.find("moving") > -1]

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.InteractiveSession(config=config)
    sess.run([tf.global_variables_initializer()])

    gen_img = (sess.run([x_fake], feed_dict)[0] + 1) / 2
    utils.save_batch_img(gen_img, "ex_rand.png", 4)

    files.assign_params(sess, gen_vars, gen_model.vars + gen_moving_vars)
    files.assign_params(sess, disc_vars, disc_model.vars + disc_moving_vars)

    gen_img = (sess.run([x_fake], feed_dict)[0] + 1) / 2
    utils.save_batch_img(gen_img, "ex_load.png", 4)

    return x_fake, feed_dict, sess
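
Example #1 restores pretrained weights through the repo's files.load_npz / files.assign_params helpers. As a rough sketch of what the assignment step can look like in plain TensorFlow 1.x, assuming the .npz archive stores one array per variable keyed by the variable name without the ':0' suffix (an assumption, not the repo's actual API):

import numpy as np
import tensorflow as tf

def assign_params_from_npz(sess, npz_path, variables):
    # Hypothetical helper: copy arrays from an .npz archive into existing
    # TF variables by name. Keys missing from the archive are skipped.
    weights = np.load(npz_path)
    for var in variables:
        key = var.name.split(':')[0]
        if key in weights:
            var.load(weights[key], session=sess)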
Example #2
def main():
    parser = argparse.ArgumentParser(
        description='Train a text classifier in the target language, '
                    'using training data in the source language and '
                    'parallel data between the two languages')
    parser.add_argument('-src_train_path',
                        default='data/amazon_review/en/book/train')
    parser.add_argument('-src_emb_path',
                        default='data/amazon_review/en/all.review.vec.txt')
    parser.add_argument('-tgt_test_path',
                        default='data/amazon_review/de/book/train')
    parser.add_argument('-tgt_emb_path',
                        default='data/amazon_review/de/all.review.vec.txt')
    parser.add_argument('-parl_data_path',
                        default='data/amazon_review/de/book/parl')
    parser.add_argument('-save_path', default='experiments/en-de/book')
    parser.add_argument('-dataset', default='amazon_review')
    args = parser.parse_args()

    # load the configuration file
    config = get_train_config(dataset=args.dataset)

    # save the configuration file
    mkdir_p(args.save_path)
    config['save_path'] = args.save_path
    with open(join(args.save_path, 'config.json'), 'w') as outfile:
        json.dump(config, outfile)

    # initialize a model object
    model = CLD(config)

    # read src training data
    model.read_src(train_path=args.src_train_path, emb_path=args.src_emb_path)
    # read tgt testing data
    model.read_tgt(train_path=args.tgt_test_path, emb_path=args.tgt_emb_path)
    # read parallel data
    model.read_parl(parl_path=args.parl_data_path)
    # let's start the cross-lingual training
    model.train()
    # let's make prediction on test data and evaluate
    model.eval(join(args.save_path, 'acc.txt'))
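
mkdir_p above comes from the project's utilities; a minimal sketch with the commonly assumed semantics (create the whole directory tree and tolerate it already existing, i.e. `mkdir -p`) would be:

import os

def mkdir_p(path):
    # Assumed behaviour of the mkdir_p helper used above; not the repo's exact code.
    os.makedirs(path, exist_ok=True)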
Example #3
    parser.add_argument('-c',
                        '--conf',
                        default='cfgs/meshnet_train.yaml',
                        help='path of config file')
    parser.add_argument('-trnd',
                        '--train_data',
                        default='pcs_mesh_mask_vols_train_set_1.csv',
                        help='path where the downloaded data is stored')
    parser.add_argument('-tstd',
                        '--test_data',
                        default='pcs_mesh_mask_vols_test_set_1.csv',
                        help='path where the downloaded data is stored')

    args = parser.parse_args()

    cfg = get_train_config(args.conf)
    os.environ['CUDA_VISIBLE_DEVICES'] = cfg['cuda_devices']

    datasets = {
        'train':
        MeshNetDataset(cfg=cfg['dataset'],
                       datapath=args.train_data,
                       part='train'),
        'test':
        MeshNetDataset(cfg=cfg['dataset'],
                       datapath=args.test_data,
                       part='test')
    }
    dataloaders = {
        'train':
        data.DataLoader(datasets['train'],
Example #4
def main():
    # get configuration
    print("Get configuration")
    TFLAGS = cfg.get_train_config(FLAGS.model_name)
    gen_model, gen_config, disc_model, disc_config = model.get_model(FLAGS.model_name, TFLAGS)
    gen_model = model.good_generator.GoodGeneratorEncoder(**gen_config)
    print("Common length: %d" % gen_model.common_length)
    TFLAGS['AE_weight'] = 1.0

    # Data Preparation
    # raw image is 0~255
    print("Get dataset")
    if TFLAGS['dataset_kind'] == "file":
        dataset = dataloader.CustomDataset(
            root_dir=TFLAGS['data_dir'],
            npy_dir=TFLAGS['npy_dir'],
            preproc_kind=TFLAGS['preprocess'],
            img_size=TFLAGS['input_shape'][:2],
            filter_data=TFLAGS['filter_data'],
            class_num=TFLAGS['c_len'],
            disturb=TFLAGS['disturb'],
            flip=TFLAGS['preproc_flip'])

    elif TFLAGS['dataset_kind'] == "numpy":
        train_data, test_data, train_label, test_label = utils.load_mnist_4d(TFLAGS['data_dir'])
        train_data = train_data.reshape([-1] + TFLAGS['input_shape'])

        # TODO: validation in training
        dataset = dataloader.NumpyArrayDataset(
            data_npy=train_data,
            label_npy=train_label,
            preproc_kind=TFLAGS['preprocess'],
            img_size=TFLAGS['input_shape'][:2],
            filter_data=TFLAGS['filter_data'],
            class_num=TFLAGS['c_len'],
            flip=TFLAGS['preproc_flip'])
    
    dl = dataloader.CustomDataLoader(dataset, batch_size=TFLAGS['batch_size'], num_threads=TFLAGS['data_threads'])

    # TF Input
    x_real = tf.placeholder(tf.float32, [None] + TFLAGS['input_shape'], name="x_real")
    s_real = tf.placeholder(tf.float32, [None] + TFLAGS['input_shape'], name='s_real')
    z_noise = tf.placeholder(tf.float32, [None, TFLAGS['z_len']], name="z_noise")
    
    if FLAGS.cgan:
        c_noise = tf.placeholder(tf.float32, [None, TFLAGS['c_len']], name="c_noise")
        c_label = tf.placeholder(tf.float32, [None, TFLAGS['c_len']], name="c_label")

    # control variables
    real_cls_weight = tf.placeholder(tf.float32, [], name="real_cls_weight")
    fake_cls_weight = tf.placeholder(tf.float32, [], name="fake_cls_weight")
    adv_weight = tf.placeholder(tf.float32, [], name="adv_weight")
    inc_length = tf.placeholder(tf.int32, [], name="inc_length")
    lr = tf.placeholder(tf.float32, [], name="lr")

    # build inference network

    # image_inputs : x_real, x_fake, x_trans
    # noise_input : noise, noise_rec, noise_trans

    # feat_image + image_feat
    # x_real -> trans_feat -> x_trans -> trans_x_trans_feat[REC] -> x_trans_trans_

    # feat_noise + noise_feat
    # x_real -> trans_feat -> noise_trans -> noise_x_trans_feat[REC] -> x_trans_fake

    with tf.variable_scope(gen_model.field_name):
        noise_input = tf.concat([z_noise, c_noise], axis=1)
        gen_model.image_input = x_real
        gen_model.noise_input = noise_input
        
        x_fake, x_trans, noise_rec, noise_trans = gen_model.build_inference()
        noise_feat = tf.identity(gen_model.noise_feat, "noise_feat")
        trans_feat = tf.identity(gen_model.image_feat, "trans_feat")
        
        gen_model.noise_input = noise_rec
        gen_model.image_input = x_fake

        x_fake_fake, x_fake_trans, noise_rec_rec, noise_x_fake_trans = gen_model.build_inference()
        noise_rec_feat = tf.identity(gen_model.noise_feat, "noise_rec_feat")
        trans_fake_feat = tf.identity(gen_model.image_feat, "trans_fake_feat")
        
        gen_model.noise_input = noise_trans
        gen_model.image_input = x_trans

        x_trans_fake, x_trans_trans, noise_trans_rec, noise_trans_trans = gen_model.build_inference()
        noise_trans_feat = tf.identity(gen_model.noise_feat, "noise_trans_feat")
        trans_trans_feat = tf.identity(gen_model.image_feat, "trans_trans_feat")

    # noise -> noise_feat -> x_fake | noise_rec
    # image -> trans_feat -> x_trans | noise_trans
    # noise_rec -> noise_rec_feat -> x_fake_fake | noise_rec_rec
    # x_fake -> trans_fake_feat -> x_fake_trans | noise_fake_trans
    # noise_trans -> noise_trans_feat -> x_trans_fake | noise_trans_rec    
    # x_trans -> trans_trans_feat -> x_trans_trans | noise_trans_trans

    # Image Feat Recs:
    # trans_feat == noise_trans_feat : feat_noise + noise_feat
    # trans_feat == trans_trans_feat : feat_image + image_feat
    # noise_feat == noise_rec_feat : feat_noise + noise_feat
    # noise_feat == trans_fake_feat : feat_image + image_feat

    # Noise Recs:
    # noise_input == noise_rec : noise_feat + feat_noise
    
    # Image Recs:
    # x_real -[REC]- x_trans : image_feat + feat_image

    disc_real_out, cls_real_logits = disc_model(x_real)[:2]; disc_model.set_reuse()
    disc_fake_out, cls_fake_logits = disc_model(x_fake)[:2]
    disc_trans_out, cls_trans_logits = disc_model(x_trans)[:2]

    gen_model.cost = disc_model.cost = 0
    gen_model.sum_op = disc_model.sum_op = []
    inter_sum_op = []

    # Select loss builder and model trainer
    print("Build training graph")

    # Naive GAN
    loss.naive_ganloss.func_gen_loss(disc_fake_out, adv_weight, name="Gen", model=gen_model)
    loss.naive_ganloss.func_disc_fake_loss(disc_fake_out, adv_weight / 2, name="DiscFake", model=disc_model)
    loss.naive_ganloss.func_disc_fake_loss(disc_trans_out, adv_weight / 2, name="DiscTrans", model=disc_model)
    loss.naive_ganloss.func_disc_real_loss(disc_real_out, adv_weight, name="DiscGen", model=disc_model)

    # ACGAN
    real_cls_cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=cls_real_logits,
        labels=c_label), name="real_cls_reduce_mean")
    real_cls_cost_sum = tf.summary.scalar("real_cls", real_cls_cost)
    
    fake_cls_cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=cls_fake_logits,
        labels=c_noise), name='fake_cls_reduce_mean')
    fake_cls_cost_sum = tf.summary.scalar("fake_cls", fake_cls_cost)
    
    disc_model.cost += real_cls_cost * real_cls_weight + fake_cls_cost * fake_cls_weight
    disc_model.sum_op.extend([real_cls_cost_sum, fake_cls_cost_sum])

    gen_model.cost += fake_cls_cost * fake_cls_weight
    gen_model.sum_op.append(fake_cls_cost_sum)

    # DRAGAN gradient penalty
    disc_cost_, disc_sum_ = loss.dragan_loss.get_dragan_loss(disc_model, x_real, TFLAGS['gp_weight'])
    disc_model.cost += disc_cost_
    disc_model.sum_op.append(disc_sum_)

    gen_cost_, gen_sum_ = loss.naive_ganloss.func_gen_loss(disc_trans_out, adv_weight, name="GenTrans")
    gen_model.cost += gen_cost_

    # Image Feat Recs:
    # trans_feat == noise_trans_feat : feat_noise + noise_feat 1
    # trans_feat == trans_trans_feat : feat_image + image_feat 2
    # noise_feat == noise_rec_feat : feat_noise + noise_feat 1
    # noise_feat == trans_fake_feat : feat_image + image_feat 2
    trans_feat_cost1_, _ = loss.common_loss.reconstruction_loss(
        noise_trans_feat, trans_feat, TFLAGS['AE_weight'] / 4, name="RecTransFeat1")
    trans_feat_cost2_, _ = loss.common_loss.reconstruction_loss(
        trans_trans_feat, trans_feat, TFLAGS['AE_weight'] / 4, name="RecTransFeat2")
    trans_feat_cost_ = trans_feat_cost1_ + trans_feat_cost2_
    trans_feat_sum_ = tf.summary.scalar("RecTransFeat", trans_feat_cost_)

    noise_feat_cost1_, _ = loss.common_loss.reconstruction_loss(
        noise_rec_feat, noise_feat, TFLAGS['AE_weight'] / 4, name="RecNoiseFeat1")
    noise_feat_cost2_, _ = loss.common_loss.reconstruction_loss(
        trans_fake_feat, noise_feat, TFLAGS['AE_weight'] / 4, name="RecNoiseFeat2")
    noise_feat_cost_ = noise_feat_cost1_ + noise_feat_cost2_
    noise_feat_sum_ = tf.summary.scalar("RecNoiseFeat", noise_feat_cost_)

    # Noise Recs:
    # noise_input -[REC]- noise_rec : noise_feat + feat_noise
    noise_cost_, noise_sum_ = loss.common_loss.reconstruction_loss(
        noise_rec[:, -inc_length:], noise_input[:, -inc_length:], TFLAGS['AE_weight'], name="RecNoise")

    # Image Recs:
    # x_real -[REC]- x_trans : image_feat + feat_image
    rec_cost_, rec_sum_ = loss.common_loss.reconstruction_loss(
        x_trans, x_real, TFLAGS['AE_weight'], name="RecPix")

    gen_model.extra_sum_op = [gen_sum_, trans_feat_sum_, noise_feat_sum_, rec_sum_, noise_sum_]
    gen_model.extra_loss = gen_cost_ + trans_feat_cost_ + noise_feat_cost_ + rec_cost_ + noise_cost_

    gen_model.total_cost = tf.identity(gen_model.extra_loss) + tf.identity(gen_model.cost)

    gen_model.cost = tf.identity(gen_model.cost, "TotalGenCost")
    disc_model.cost = tf.identity(disc_model.cost, "TotalDiscCost")
    
    # total summary
    gen_model.sum_op.append(tf.summary.scalar("GenCost", gen_model.cost))
    disc_model.sum_op.append(tf.summary.scalar("DiscCost", disc_model.cost))

    # add interval summary
    edge_num = int(np.sqrt(TFLAGS['batch_size']))
    if edge_num > 4:
        edge_num = 4
    grid_x_fake = ops.get_grid_image_summary(x_fake, edge_num)
    inter_sum_op.append(tf.summary.image("generated image", grid_x_fake))
    grid_x_trans = ops.get_grid_image_summary(x_trans, edge_num)
    inter_sum_op.append(tf.summary.image("trans image", grid_x_trans))
    grid_x_real = ops.get_grid_image_summary(x_real, edge_num)
    inter_sum_op.append(tf.summary.image("real image", grid_x_real))

    # merge summary op
    gen_model.extra_sum_op = tf.summary.merge(gen_model.extra_sum_op)
    gen_model.sum_op = tf.summary.merge(gen_model.sum_op)
    disc_model.sum_op = tf.summary.merge(disc_model.sum_op)
    inter_sum_op = tf.summary.merge(inter_sum_op)

    print("=> Compute gradient")

    # get train op
    gen_model.get_trainable_variables()
    disc_model.get_trainable_variables()
    
    # The UPDATE_OPS collection (e.g. batch-norm moving averages) must run, or training fails
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        # normal GAN train op
        gen_model.train_op = tf.train.AdamOptimizer(
            learning_rate=lr,
            beta1=0.5,
            beta2=0.9).minimize(gen_model.cost, var_list=gen_model.vars)

        # train reconstruction branch
        var_list = gen_model.branch_vars['image_feat'] + gen_model.branch_vars['feat_noise']
        gen_model.extra_train_op_partial1 = tf.train.AdamOptimizer(
            learning_rate=lr/5,
            beta1=0.5,
            beta2=0.9).minimize(gen_model.extra_loss, var_list=var_list)

        var_list += gen_model.branch_vars['noise_feat'] + gen_model.branch_vars['feat_image']
        gen_model.extra_train_op_full = tf.train.AdamOptimizer(
            learning_rate=lr/5,
            beta1=0.5,
            beta2=0.9).minimize(gen_model.extra_loss, var_list=var_list)

        gen_model.full_train_op = tf.train.AdamOptimizer(
            learning_rate=lr/5,
            beta1=0.5,
            beta2=0.9).minimize(gen_model.total_cost, var_list=gen_model.vars)
        
        disc_model.train_op = tf.train.AdamOptimizer(
            learning_rate=lr,
            beta1=0.5,
            beta2=0.9).minimize(disc_model.cost, var_list=disc_model.vars)

    print("=> ##### Generator Variable #####")
    gen_model.print_trainble_vairables()
    print("=> ##### Discriminator Variable #####")
    disc_model.print_trainble_vairables()
    print("=> ##### All Variable #####")
    for v in tf.trainable_variables():
        print("%s" % v.name)

    ctrl_weight = {
            "real_cls"  : real_cls_weight,
            "fake_cls"  : fake_cls_weight,
            "adv"       : adv_weight,
            "lr"        : lr,
            "inc_length": inc_length
        }

    trainer_feed = {
        "gen_model": gen_model,
        "disc_model" : disc_model,
        "ctrl_weight": ctrl_weight,
        "dataloader": dl,
        "int_sum_op": inter_sum_op,
        "gpu_mem": TFLAGS['gpu_mem'],
        "FLAGS": FLAGS,
        "TFLAGS": TFLAGS
    }

    Trainer = trainer.base_gantrainer_rep.BaseGANTrainer
    trainer_feed.update({
        "inputs": [z_noise, c_noise, x_real, c_label],
    })
    
    ModelTrainer = Trainer(**trainer_feed)

    command_controller = trainer.cmd_ctrl.CMDControl(ModelTrainer)
    command_controller.start_thread()

    ModelTrainer.init_training()

    ModelTrainer.train()
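
The reconstruction terms in Example #4 all come from loss.common_loss.reconstruction_loss and are consumed as (cost, summary) pairs. A minimal sketch of such a weighted reconstruction term, an assumption about its shape rather than the repo's exact code, is:

import tensorflow as tf

def reconstruction_loss_sketch(x, target, weight, name="Rec"):
    # Weighted mean-squared reconstruction cost plus a scalar summary,
    # mirroring how the (cost_, sum_) pairs are added to the model costs above.
    with tf.name_scope(name):
        cost = weight * tf.reduce_mean(tf.square(x - target))
        summary = tf.summary.scalar(name, cost)
    return cost, summary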
Example #5
    elif len(img.shape) == 3 and img.shape[2] == 1:
        return np.delete(cmap(img[:, :, 0]), 3, 2)
    else:
        return np.delete(cmap(img), 3, 2)


# buzz things
rng = np.random.RandomState(3)
CONFIG = {}
with open('nim_server/config.json', 'r') as f:
    CONFIG = json.load(f)
    for x in CONFIG['models'].values():
        CONFIG = x
        break
CONFIG['model_dir'] = "success/"
TFLAGS = cfg.get_train_config(CONFIG['model_name'])
TFLAGS['batch_before'] = False
input_dim = CONFIG['input_dim']
model_dir = CONFIG['model_dir'] + CONFIG['model_name'] + CONFIG['sery']
using_cgan = CONFIG['cgan']
gen_model, gen_config, disc_model, disc_config = model.get_model(
    CONFIG['model_name'], TFLAGS)
gen_model.name = CONFIG['field_name'] + "/" + gen_model.name
disc_model.name = CONFIG['field_name'] + "/" + disc_model.name
delta = 1. - (1. / (1. + np.exp(-5.)) - 1. / (1. + np.exp(5.)))
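# Sanity check on delta (illustrative arithmetic only, values rounded):
#   1 / (1 + e**-5) ≈ 0.99331 and 1 / (1 + e**5) ≈ 0.00669,
#   so delta = 1 - (0.99331 - 0.00669) ≈ 0.0134,
# i.e. the fraction of the unit interval that a sigmoid clipped to [-5, 5] misses.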
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.InteractiveSession(config=config)

# building graph
def main():
    # get configuration
    print("Get configuration")
    TFLAGS = cfg.get_train_config(FLAGS.model_name)

    gen_config = cfg.good_generator.deepmodel_gen({})
    disc_config = cfg.good_generator.deepmodel_disc({})
    gen_model = model.good_generator.DeepGenerator(**gen_config)
    disc_model = model.good_generator.DeepDiscriminator(**disc_config)

    # Data Preparation
    # raw image is 0~255
    print("Get dataset")
    if TFLAGS['dataset_kind'] == "file":
        dataset = dataloader.CustomDataset(root_dir=TFLAGS['data_dir'],
                                           npy_dir=TFLAGS['npy_dir'],
                                           preproc_kind=TFLAGS['preprocess'],
                                           img_size=TFLAGS['input_shape'][:2],
                                           filter_data=TFLAGS['filter_data'],
                                           class_num=TFLAGS['c_len'],
                                           disturb=TFLAGS['disturb'],
                                           flip=TFLAGS['preproc_flip'])

    elif TFLAGS['dataset_kind'] == "numpy":
        train_data, test_data, train_label, test_label = utils.load_mnist_4d(
            TFLAGS['data_dir'])
        train_data = train_data.reshape([-1] + TFLAGS['input_shape'])

        # TODO: validation in training
        dataset = dataloader.NumpyArrayDataset(
            data_npy=train_data,
            label_npy=train_label,
            preproc_kind=TFLAGS['preprocess'],
            img_size=TFLAGS['input_shape'][:2],
            filter_data=TFLAGS['filter_data'],
            class_num=TFLAGS['c_len'],
            flip=TFLAGS['preproc_flip'])

    elif TFLAGS['dataset_kind'] == "fuel":
        dataset = dataloader.FuelDataset(hdfname=TFLAGS['data_dir'],
                                         npy_dir=TFLAGS['npy_dir'],
                                         preproc_kind=TFLAGS['preprocess'],
                                         img_size=TFLAGS['input_shape'][:2],
                                         filter_data=TFLAGS['filter_data'],
                                         class_num=TFLAGS['c_len'],
                                         flip=TFLAGS['preproc_flip'])

    dl = dataloader.CustomDataLoader(dataset,
                                     batch_size=TFLAGS['batch_size'],
                                     num_threads=TFLAGS['data_threads'])

    # TF Input
    x_real = tf.placeholder(tf.float32, [None] + TFLAGS['input_shape'],
                            name="x_real")
    s_real = tf.placeholder(tf.float32, [None] + TFLAGS['input_shape'],
                            name='s_real')
    z_noise = tf.placeholder(tf.float32, [None, TFLAGS['z_len']],
                             name="z_noise")

    if FLAGS.cgan:
        c_noise = tf.placeholder(tf.float32, [None, TFLAGS['c_len']],
                                 name="c_noise")
        c_label = tf.placeholder(tf.float32, [None, TFLAGS['c_len']],
                                 name="c_label")

    # Select loss builder and model trainer
    print("Build training graph")
    if TFLAGS['gan_loss'].find("naive") > -1:
        loss_builder_ = loss.naive_ganloss.NaiveGANLoss
        Trainer = trainer.base_gantrainer.BaseGANTrainer
    elif TFLAGS['gan_loss'].find("wass") > -1:
        loss_builder_ = loss.wass_ganloss.WGANLoss
        Trainer = trainer.base_gantrainer.BaseGANTrainer
    elif TFLAGS['gan_loss'].find("dra") > -1:
        loss_builder_ = loss.dragan_loss.DRAGANLoss
        Trainer = trainer.base_gantrainer.BaseGANTrainer
    elif TFLAGS['gan_loss'].find("ian") > -1:
        loss_builder_ = loss.ian_loss.IANLoss
        Trainer = trainer.ian_trainer.IANTrainer

    if FLAGS.cgan:
        loss_builder = loss_builder_(gen_model=gen_model,
                                     disc_model=disc_model,
                                     gen_inputs=[z_noise, c_noise],
                                     real_inputs=[x_real, c_label],
                                     has_ac=True,
                                     **TFLAGS)
    else:
        loss_builder = loss_builder_(gen_model=gen_model,
                                     disc_model=disc_model,
                                     gen_inputs=[z_noise],
                                     real_inputs=[x_real],
                                     has_ac=False,
                                     **TFLAGS)

    loss_builder.build()
    int_sum_op = tf.summary.merge(loss_builder.inter_sum_op)
    loss_builder.get_trainable_variables()

    print("=> ##### Generator Variable #####")
    gen_model.print_trainble_vairables()
    print("=> ##### Discriminator Variable #####")
    disc_model.print_trainble_vairables()
    print("=> ##### All Variable #####")
    for v in tf.trainable_variables():
        print("%s" % v.name)

    if FLAGS.cgan:
        inputs = [z_noise, c_noise, x_real, c_label]
        if TFLAGS['gan_loss'].find("ian") > -1:
            ctrl_weight = {
                "real_cls": loss_builder.real_cls_weight,
                "fake_cls": loss_builder.fake_cls_weight,
                "rec_weight": loss_builder.rec_weight,
                "adv": loss_builder.adv_weight,
                "recadv_weight": loss_builder.recadv_weight
            }
        else:
            ctrl_weight = {
                "real_cls": loss_builder.real_cls_weight,
                "fake_cls": loss_builder.fake_cls_weight,
                "adv": loss_builder.adv_weight
            }
    else:
        inputs = [z_noise, x_real]
        ctrl_weight = {
            "real_cls": loss_builder.real_cls_weight,
            "fake_cls": loss_builder.fake_cls_weight,
            "adv": loss_builder.adv_weight
        }

    ModelTrainer = Trainer(gen_model=gen_model,
                           disc_model=disc_model,
                           inputs=inputs,
                           dataloader=dl,
                           int_sum_op=int_sum_op,
                           ctrl_weight=ctrl_weight,
                           gpu_mem=TFLAGS['gpu_mem'],
                           FLAGS=FLAGS,
                           TFLAGS=TFLAGS)

    command_controller = trainer.cmd_ctrl.CMDControl(ModelTrainer)
    command_controller.start_thread()

    ModelTrainer.init_training()
    ModelTrainer.train()
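
The find()-based chain above selects a loss builder and trainer by substring and silently leaves both undefined if TFLAGS['gan_loss'] matches none of the branches. An equivalent lookup written as a table, purely an illustrative restructuring of the same choices, is:

LOSS_TABLE = [
    ("naive", loss.naive_ganloss.NaiveGANLoss, trainer.base_gantrainer.BaseGANTrainer),
    ("wass",  loss.wass_ganloss.WGANLoss,      trainer.base_gantrainer.BaseGANTrainer),
    ("dra",   loss.dragan_loss.DRAGANLoss,     trainer.base_gantrainer.BaseGANTrainer),
    ("ian",   loss.ian_loss.IANLoss,           trainer.ian_trainer.IANTrainer),
]

def select_loss_builder(gan_loss):
    # Same substring matching as the if/elif chain, but failing loudly
    # when the configured gan_loss matches none of the known options.
    for key, builder, trainer_cls in LOSS_TABLE:
        if key in gan_loss:
            return builder, trainer_cls
    raise ValueError("unknown gan_loss: %r" % gan_loss)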
Example #7
from config import get_train_config
from data import SHREC21Dataset
from models import MeshNet
from tqdm import tqdm
from losses import FocalLoss
from metrics import *
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-r', '--data_root', type=str, help='path to MeshNet')
parser.add_argument('-t', '--task', type=str, help='[Culture|Shape]')
parser.add_argument('-s', '--saved_path', type=str, help='save checkpoint to')
parser.add_argument('--num_faces', type=int, help='number of faces')
args = parser.parse_args()

cfg = get_train_config()
cfg['dataset']['data_root'] = args.data_root
cfg['dataset']['max_faces'] = args.num_faces
cfg['saved_path'] = args.saved_path

os.environ['CUDA_VISIBLE_DEVICES'] = cfg['cuda_devices']

data_set = {
    x: SHREC21Dataset(cfg=cfg['dataset'], part=x)
    for x in ['train', 'test']
}
data_loader = {
    x: data.DataLoader(data_set[x],
                       batch_size=cfg['batch_size'],
                       num_workers=8,
                       shuffle=True,
def main():
    config = get_train_config()

    # device
    device, device_ids = setup_device(config.n_gpu)

    # tensorboard
    writer = TensorboardWriter(config.summary_dir, config.tensorboard)

    # metric tracker
    metric_names = ['loss', 'acc1', 'acc5']
    train_metrics = MetricTracker(*[metric for metric in metric_names],
                                  writer=writer)
    valid_metrics = MetricTracker(*[metric for metric in metric_names],
                                  writer=writer)

    # create model
    print("create model")
    model = VisionTransformer(image_size=(config.image_size,
                                          config.image_size),
                              patch_size=(config.patch_size,
                                          config.patch_size),
                              emb_dim=config.emb_dim,
                              mlp_dim=config.mlp_dim,
                              num_heads=config.num_heads,
                              num_layers=config.num_layers,
                              num_classes=config.num_classes,
                              attn_dropout_rate=config.attn_dropout_rate,
                              dropout_rate=config.dropout_rate)

    # load checkpoint
    if config.checkpoint_path:
        state_dict = load_checkpoint(config.checkpoint_path)
        if config.num_classes != state_dict['classifier.weight'].size(0):
            del state_dict['classifier.weight']
            del state_dict['classifier.bias']
            print("re-initialize fc layer")
            model.load_state_dict(state_dict, strict=False)
        else:
            model.load_state_dict(state_dict)
        print("Load pretrained weights from {}".format(config.checkpoint_path))

    # send model to device
    model = model.to(device)
    if len(device_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=device_ids)

    # create dataloader
    print("create dataloaders")
    train_dataloader = eval("{}DataLoader".format(config.dataset))(
        data_dir=os.path.join(config.data_dir, config.dataset),
        image_size=config.image_size,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        split='train')
    valid_dataloader = eval("{}DataLoader".format(config.dataset))(
        data_dir=os.path.join(config.data_dir, config.dataset),
        image_size=config.image_size,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        split='val')

    # training criterion
    print("create criterion and optimizer")
    criterion = nn.CrossEntropyLoss()

    # create optimizers and learning rate scheduler
    optimizer = torch.optim.SGD(params=model.parameters(),
                                lr=config.lr,
                                weight_decay=config.wd,
                                momentum=0.9)
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        max_lr=config.lr,
        pct_start=config.warmup_steps / config.train_steps,
        total_steps=config.train_steps)

    # start training
    print("start training")
    best_acc = 0.0
    epochs = config.train_steps // len(train_dataloader)
    for epoch in range(1, epochs + 1):
        log = {'epoch': epoch}

        # train the model
        model.train()
        result = train_epoch(epoch, model, train_dataloader, criterion,
                             optimizer, lr_scheduler, train_metrics, device)
        log.update(result)

        # validate the model
        model.eval()
        result = valid_epoch(epoch, model, valid_dataloader, criterion,
                             valid_metrics, device)
        log.update(**{'val_' + k: v for k, v in result.items()})

        # best acc
        best = False
        if log['val_acc1'] > best_acc:
            best_acc = log['val_acc1']
            best = True

        # save model
        save_model(config.checkpoint_dir, epoch, model, optimizer,
                   lr_scheduler, device_ids, best)

        # print logged informations to the screen
        for key, value in log.items():
            print('    {:15s}: {}'.format(str(key), value))
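
The OneCycleLR arguments in the ViT training script above are derived from step counts rather than epochs; with hypothetical numbers the relationship reads:

# Hypothetical numbers, only to illustrate how the quantities above relate.
train_steps = 10000                        # config.train_steps: total optimizer steps
warmup_steps = 500                         # config.warmup_steps
pct_start = warmup_steps / train_steps     # 0.05 -> first 5% of the cycle is warm-up
steps_per_epoch = 312                      # len(train_dataloader), assumed
epochs = train_steps // steps_per_epoch    # 32 passes over the data in this sketch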
import os
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from config import get_train_config
from data import ModelNet40
from models import MeshNet
from utils import get_unit_diamond_vertices, point_wise_L1_loss, save_loss_plot, point_wise_mse_loss  #, stochastic_loss
import numpy as np
from scipy.spatial.transform import Rotation as R

root_path = '/content/drive/MyDrive/DL_diamond_cutting/MeshNet/'

cfg = get_train_config(root_path)
os.environ['CUDA_VISIBLE_DEVICES'] = cfg['cuda_devices']
use_gpu = torch.cuda.is_available()

data_set = {
    x: ModelNet40(cfg=cfg['dataset'], root_path=root_path, part=x)
    for x in ['train', 'val']
}
data_loader = {
    x: data.DataLoader(data_set[x],
                       batch_size=cfg['batch_size'],
                       num_workers=4,
                       shuffle=True,
                       pin_memory=False)
    for x in ['train', 'val']
}
def main():
    # get configuration
    print("Get configuration")
    TFLAGS = cfg.get_train_config(FLAGS.model_name)
    gen_model, gen_config, disc_model, disc_config = model.get_model(
        FLAGS.model_name, TFLAGS)
    gen_model = model.conditional_generator.ImageConditionalEncoder()
    disc_model = model.conditional_generator.ImageConditionalDeepDiscriminator(
    )
    gen_model.name = "ImageConditionalEncoder"
    disc_model.name = "ImageConditionalDeepDiscriminator"
    print("Common length: %d" % gen_model.common_length)
    TFLAGS['AE_weight'] = 1.0
    TFLAGS['batch_size'] = 1
    TFLAGS['input_shape'] = [256, 256, 3]

    face_dataset = dataloader.CustomDataset(
        root_dir="/data/datasets/getchu/crop_character",
        npy_dir="/data/datasets/getchu/true_character.npy",
        preproc_kind="tanh",
        img_size=TFLAGS['input_shape'][:2],
        filter_data=TFLAGS['filter_data'],
        class_num=TFLAGS['c_len'],
        has_gray=False,
        disturb=False,
        flip=False)

    sketch_dataset = dataloader.CustomDataset(
        root_dir="/data/datasets/getchu/crop_sketch_character/",
        npy_dir=None,
        preproc_kind="tanh",
        img_size=TFLAGS['input_shape'][:2],
        filter_data=TFLAGS['filter_data'],
        class_num=TFLAGS['c_len'],
        has_gray=True,
        disturb=False,
        flip=False)

    label_dataset = dataloader.NumpyArrayDataset(
        np.load("/data/datasets/getchu/true_face.npy")[:, 1:])

    listsampler = dataloader.ListSampler(
        [face_dataset, sketch_dataset, label_dataset])
    dl = dataloader.CustomDataLoader(listsampler, TFLAGS['batch_size'], 4)

    # TF Input
    x_real = tf.placeholder(tf.float32, [None] + TFLAGS['input_shape'],
                            name="x_real")
    fake_x_sample = tf.placeholder(tf.float32, [None] + TFLAGS['input_shape'],
                                   name="fake_x_sample")
    # semantic segmentation
    s_real = tf.placeholder(tf.float32,
                            [None] + TFLAGS['input_shape'][:-1] + [1],
                            name='s_real')
    z_noise = tf.placeholder(tf.float32, [None, TFLAGS['z_len']],
                             name="z_noise")
    c_noise = tf.placeholder(tf.float32, [None, TFLAGS['c_len']],
                             name="c_noise")
    c_label = tf.placeholder(tf.float32, [None, TFLAGS['c_len']],
                             name="c_label")

    # control variables
    real_cls_weight = tf.placeholder(tf.float32, [], name="real_cls_weight")
    fake_cls_weight = tf.placeholder(tf.float32, [], name="fake_cls_weight")
    adv_weight = tf.placeholder(tf.float32, [], name="adv_weight")
    inc_length = tf.placeholder(tf.int32, [], name="inc_length")
    lr = tf.placeholder(tf.float32, [], name="lr")

    with tf.variable_scope(gen_model.name):
        gen_model.image_input = x_real
        gen_model.seg_input = s_real
        gen_model.noise_input = tf.concat([z_noise, c_noise], axis=1)

        seg_image, image_image, seg_seg, image_seg = gen_model.build_inference(
        )
        seg_feat = tf.identity(gen_model.seg_feat, "seg_feat")
        image_feat = tf.identity(gen_model.image_feat, "image_feat")

        gen_model.x_fake = tf.identity(seg_image)

    disc_real_out = disc_model([x_real, s_real])
    disc_model.set_reuse()
    disc_fake_out = disc_model([seg_image, s_real])
    disc_fake_sketch_out = disc_model([x_real, image_seg])
    #disc_rec_out  = disc_model([image_image, s_real])
    disc_fake_sample_out = disc_model([fake_x_sample, s_real])

    gen_model.cost = disc_model.cost = 0
    gen_model.sum_op = disc_model.sum_op = []
    inter_sum_op = []

    # Select loss builder and model trainer
    print("Build training graph")

    # Naive GAN
    loss.naive_ganloss.func_gen_loss(disc_fake_out,
                                     adv_weight * 0.9,
                                     name="GenSeg",
                                     model=gen_model)
    loss.naive_ganloss.func_gen_loss(disc_fake_sketch_out,
                                     adv_weight * 0.1,
                                     name="GenSketch",
                                     model=gen_model)
    if FLAGS.cache:
        loss.naive_ganloss.func_disc_fake_loss(disc_fake_sample_out,
                                               adv_weight * 0.9,
                                               name="DiscFake",
                                               model=disc_model)
        loss.naive_ganloss.func_disc_fake_loss(disc_fake_sample_out,
                                               adv_weight * 0.1,
                                               name="DiscFakeSketch",
                                               model=disc_model)
    else:
        loss.naive_ganloss.func_disc_fake_loss(disc_fake_out,
                                               adv_weight * 0.9,
                                               name="DiscFake",
                                               model=disc_model)
        loss.naive_ganloss.func_disc_fake_loss(disc_fake_sketch_out,
                                               adv_weight * 0.1,
                                               name="DiscFakeSketch",
                                               model=disc_model)

    loss.naive_ganloss.func_disc_real_loss(disc_real_out,
                                           adv_weight,
                                           name="DiscGen",
                                           model=disc_model)

    # recontrust of sketch
    rec_sketch_cost_, rec_sketch_sum_ = loss.common_loss.reconstruction_loss(
        image_seg, s_real, TFLAGS['AE_weight'], name="RecSketch")
    gen_model.cost += rec_sketch_cost_
    gen_model.sum_op.append(rec_sketch_sum_)

    gen_model.cost = tf.identity(gen_model.cost, "TotalGenCost")
    disc_model.cost = tf.identity(disc_model.cost, "TotalDiscCost")

    # total summary
    gen_model.sum_op.append(tf.summary.scalar("GenCost", gen_model.cost))
    disc_model.sum_op.append(tf.summary.scalar("DiscCost", disc_model.cost))

    # add interval summary
    edge_num = int(np.sqrt(TFLAGS['batch_size']))
    if edge_num > 4:
        edge_num = 4
    grid_x_fake = ops.get_grid_image_summary(seg_image, edge_num)
    inter_sum_op.append(tf.summary.image("generated image", grid_x_fake))
    grid_x_seg = ops.get_grid_image_summary(image_seg, edge_num)
    inter_sum_op.append(tf.summary.image("inv image", grid_x_seg))
    grid_x_real = ops.get_grid_image_summary(x_real, edge_num)
    inter_sum_op.append(tf.summary.image("real image", grid_x_real))
    grid_s_real = ops.get_grid_image_summary(s_real, edge_num)
    inter_sum_op.append(tf.summary.image("sketch image", grid_s_real))

    # merge summary op
    gen_model.sum_op = tf.summary.merge(gen_model.sum_op)
    disc_model.sum_op = tf.summary.merge(disc_model.sum_op)
    inter_sum_op = tf.summary.merge(inter_sum_op)

    print("=> Compute gradient")

    # get train op
    gen_model.get_trainable_variables()
    disc_model.get_trainable_variables()

    # The UPDATE_OPS collection (e.g. batch-norm moving averages) must run, or training fails
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        # normal GAN train op
        gen_model.train_op = tf.train.AdamOptimizer(
            learning_rate=lr, beta1=0.5,
            beta2=0.9).minimize(gen_model.cost, var_list=gen_model.vars)

        disc_model.train_op = tf.train.AdamOptimizer(
            learning_rate=lr, beta1=0.5,
            beta2=0.9).minimize(disc_model.cost, var_list=disc_model.vars)

    print("=> ##### Generator Variable #####")
    gen_model.print_trainble_vairables()
    print("=> ##### Discriminator Variable #####")
    disc_model.print_trainble_vairables()
    print("=> ##### All Variable #####")
    for v in tf.trainable_variables():
        print("%s" % v.name)

    ctrl_weight = {
        "real_cls": real_cls_weight,
        "fake_cls": fake_cls_weight,
        "adv": adv_weight,
        "lr": lr,
        "inc_length": inc_length
    }

    trainer_feed = {
        "gen_model": gen_model,
        "disc_model": disc_model,
        "ctrl_weight": ctrl_weight,
        "dataloader": dl,
        "int_sum_op": inter_sum_op,
        "gpu_mem": TFLAGS['gpu_mem'],
        "FLAGS": FLAGS,
        "TFLAGS": TFLAGS
    }

    Trainer = trainer.cgantrainer.CGANTrainer
    Trainer.use_cache = False
    trainer_feed.update({
        "inputs": [x_real, s_real, c_label],
        "noise": z_noise,
        "cond": c_noise
    })

    ModelTrainer = Trainer(**trainer_feed)
    ModelTrainer.fake_x_sample = fake_x_sample

    command_controller = trainer.cmd_ctrl.CMDControl(ModelTrainer)
    command_controller.start_thread()

    ModelTrainer.init_training()

    ModelTrainer.train()
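
loss.naive_ganloss.func_gen_loss and its disc counterparts are used both with model= (accumulating onto the model's cost and summaries) and, as in Example #4, without it (returning a (cost, summary) pair). A hedged sketch of the generator-side helper consistent with that usage, and with the LSGAN-style gen_loss written out explicitly in Example #12 below, might be:

import tensorflow as tf

def func_gen_loss_sketch(disc_fake_out, adv_weight, name="Gen", model=None):
    # Push D(fake) toward 1 with a squared-error term, scaled by adv_weight.
    # Sketch only: accumulate onto the model when given, else just return the pair.
    with tf.name_scope(name):
        cost = adv_weight * tf.reduce_mean(
            tf.square(disc_fake_out - tf.ones_like(disc_fake_out)))
    summaries = [tf.summary.scalar(name + "Cost", cost)]
    if model is not None:
        model.cost += cost
        model.sum_op.extend(summaries)
    return cost, summaries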
Example #11
def main():
    # get configuration
    print("Get configuration")
    TFLAGS = cfg.get_train_config(FLAGS.model_name)

    gen_config = cfg.good_generator.goodmodel_gen({})
    gen_config['name'] = "CondDeepGenerator"
    gen_model = model.conditional_generator.ImageConditionalDeepGenerator2(
        **gen_config)
    disc_config = cfg.good_generator.goodmodel_disc({})
    disc_config['norm_mtd'] = None
    disc_model = model.good_generator.GoodDiscriminator(**disc_config)

    TFLAGS['batch_size'] = 1
    TFLAGS['AE_weight'] = 10.0
    TFLAGS['side_noise'] = FLAGS.side_noise
    TFLAGS['input_shape'] = [128, 128, 3]
    # Data Preparation
    # raw image is 0~255
    print("Get dataset")

    face_dataset = dataloader.CustomDataset(
        root_dir="/data/datasets/getchu/true_face",
        npy_dir=None,
        preproc_kind="tanh",
        img_size=TFLAGS['input_shape'][:2],
        filter_data=TFLAGS['filter_data'],
        class_num=TFLAGS['c_len'],
        disturb=False,
        flip=False)

    sketch_dataset = dataloader.CustomDataset(
        root_dir="/data/datasets/getchu/sketch_face/",
        npy_dir=None,
        preproc_kind="tanh",
        img_size=TFLAGS['input_shape'][:2],
        filter_data=TFLAGS['filter_data'],
        class_num=TFLAGS['c_len'],
        disturb=False,
        flip=False)

    label_dataset = dataloader.NumpyArrayDataset(
        np.load("/data/datasets/getchu/true_face.npy")[:, 1:])

    listsampler = dataloader.ListSampler(
        [sketch_dataset, face_dataset, label_dataset])
    dl = dataloader.CustomDataLoader(listsampler, TFLAGS['batch_size'], 4)

    # TF Input
    x_real = tf.placeholder(tf.float32, [None] + TFLAGS['input_shape'],
                            name="x_real")
    s_real = tf.placeholder(tf.float32,
                            [None] + TFLAGS['input_shape'][:-1] + [1],
                            name='s_real')
    fake_x_sample = tf.placeholder(tf.float32, [None] + TFLAGS['input_shape'],
                                   name="fake_x_sample")
    noise_A = tf.placeholder(tf.float32, [None, TFLAGS['z_len']],
                             name="noise_A")
    noise_B = tf.placeholder(tf.float32, [None, TFLAGS['z_len']],
                             name="noise_B")
    label_A = tf.placeholder(tf.float32, [None, TFLAGS['c_len']],
                             name="label_A")
    label_B = tf.placeholder(tf.float32, [None, TFLAGS['c_len']],
                             name="label_B")

    # c label is for real samples
    c_label = tf.placeholder(tf.float32, [None, TFLAGS['c_len']],
                             name="c_label")
    c_noise = tf.placeholder(tf.float32, [None, TFLAGS['c_len']],
                             name="c_noise")

    # control variables
    real_cls_weight = tf.placeholder(tf.float32, [], name="real_cls_weight")
    fake_cls_weight = tf.placeholder(tf.float32, [], name="fake_cls_weight")
    adv_weight = tf.placeholder(tf.float32, [], name="adv_weight")
    lr = tf.placeholder(tf.float32, [], name="lr")

    side_noise_A = tf.concat([noise_A, c_label], axis=1, name='side_noise_A')

    if TFLAGS['side_noise']:
        gen_model.side_noise = side_noise_A
    x_fake = gen_model([s_real])
    gen_model.set_reuse()
    gen_model.x_fake = x_fake

    if FLAGS.cgan:
        disc_model.is_cgan = True
    disc_model.disc_real_out, disc_model.real_cls_logits = disc_model(
        x_real)[:2]
    disc_model.set_reuse()
    disc_model.disc_fake_out, disc_model.fake_cls_logits = disc_model(
        x_fake)[:2]

    gen_model.cost = disc_model.cost = 0
    gen_model.sum_op = disc_model.sum_op = []
    inter_sum_op = []

    if TFLAGS['gan_loss'] == "dra":
        # naive GAN loss
        gen_cost_, disc_cost_, gen_sum_, disc_sum_ = loss.naive_ganloss.get_naive_ganloss(
            gen_model, disc_model, adv_weight)
        gen_model.cost += gen_cost_
        disc_model.cost += disc_cost_
        gen_model.sum_op.extend(gen_sum_)
        disc_model.sum_op.extend(disc_sum_)

        gen_model.gen_cost = gen_cost_
        disc_model.disc_cost = disc_cost_

        # dragan loss
        disc_cost_, disc_sum_ = loss.dragan_loss.get_dragan_loss(
            disc_model, x_real, TFLAGS['gp_weight'])
        disc_model.cost += disc_cost_
        disc_model.sum_op.append(disc_sum_)

    elif TFLAGS['gan_loss'] == "wass":
        gen_cost_, disc_cost_, gen_sum_, disc_sum_ = loss.wass_ganloss.wass_gan_loss(
            gen_model, disc_model, x_real, x_fake)
        gen_model.cost += gen_cost_
        disc_model.cost += disc_cost_
        gen_model.sum_op.extend(gen_sum_)
        disc_model.sum_op.extend(disc_sum_)

    elif TFLAGS['gan_loss'] == "naive":
        # naive GAN loss
        gen_cost_, disc_cost_, gen_sum_, disc_sum_ = loss.naive_ganloss.get_naive_ganloss(
            gen_model, disc_model, adv_weight, lsgan=True)
        gen_model.cost += gen_cost_
        disc_model.cost += disc_cost_
        gen_model.sum_op.extend(gen_sum_)
        disc_model.sum_op.extend(disc_sum_)

        gen_model.gen_cost = gen_cost_
        disc_model.disc_cost = disc_cost_

    if FLAGS.cgan:
        real_cls = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=disc_model.real_cls_logits, labels=c_label),
            name="real_cls_reduce_mean") * real_cls_weight

        #fake_cls = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        #    logits=disc_model.fake_cls_logits,
        #    labels=c_label), name="fake_cls_reduce_mean") * fake_cls_weight

        disc_model.cost += real_cls  # + fake_cls
        disc_model.sum_op.extend([tf.summary.scalar("RealCls", real_cls)])  #,
        #tf.summary.scalar("FakeCls", fake_cls)])

    gen_cost_, ae_loss_sum_ = loss.common_loss.reconstruction_loss(
        x_fake, x_real, TFLAGS['AE_weight'])
    gen_model.sum_op.append(ae_loss_sum_)
    gen_model.cost += gen_cost_

    gen_model.cost = tf.identity(gen_model.cost, "TotalGenCost")
    disc_model.cost = tf.identity(disc_model.cost, "TotalDiscCost")

    # total summary
    gen_model.sum_op.append(tf.summary.scalar("GenCost", gen_model.cost))
    disc_model.sum_op.append(tf.summary.scalar("DiscCost", disc_model.cost))

    # add interval summary
    edge_num = int(np.sqrt(TFLAGS['batch_size']))
    if edge_num > 4:
        edge_num = 4
    grid_x_fake = ops.get_grid_image_summary(x_fake, edge_num)
    inter_sum_op.append(tf.summary.image("generated image", grid_x_fake))
    grid_x_real = ops.get_grid_image_summary(x_real, edge_num)
    grid_x_real = tf.Print(
        grid_x_real, [tf.reduce_max(grid_x_real),
                      tf.reduce_min(grid_x_real)], "Real")
    inter_sum_op.append(tf.summary.image("real image", grid_x_real))
    inter_sum_op.append(
        tf.summary.image("sketch image",
                         ops.get_grid_image_summary(s_real, edge_num)))

    # merge summary op
    gen_model.sum_op = tf.summary.merge(gen_model.sum_op)
    disc_model.sum_op = tf.summary.merge(disc_model.sum_op)
    inter_sum_op = tf.summary.merge(inter_sum_op)

    # get train op
    gen_model.get_trainable_variables()
    disc_model.get_trainable_variables()

    # The UPDATE_OPS collection (e.g. batch-norm moving averages) must run, or training fails
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        gen_model.train_op = tf.train.AdamOptimizer(
            learning_rate=lr, beta1=0.5,
            beta2=0.9).minimize(gen_model.cost, var_list=gen_model.vars)

        disc_model.train_op = tf.train.AdamOptimizer(
            learning_rate=lr, beta1=0.5,
            beta2=0.9).minimize(disc_model.cost, var_list=disc_model.vars)

    print("=> ##### Generator Variable #####")
    gen_model.print_trainble_vairables()
    print("=> ##### Discriminator Variable #####")
    disc_model.print_trainble_vairables()
    print("=> ##### All Variable #####")
    for v in tf.trainable_variables():
        print("%s" % v.name)

    ctrl_weight = {
        "real_cls": real_cls_weight,
        "fake_cls": fake_cls_weight,
        "adv": adv_weight,
        "lr": lr
    }

    trainer_feed = {
        "gen_model": gen_model,
        "disc_model": disc_model,
        "noise": noise_A,
        "ctrl_weight": ctrl_weight,
        "dataloader": dl,
        "int_sum_op": inter_sum_op,
        "gpu_mem": TFLAGS['gpu_mem'],
        "FLAGS": FLAGS,
        "TFLAGS": TFLAGS
    }

    Trainer = trainer.cgantrainer.CGANTrainer
    trainer_feed.update({
        "inputs": [s_real, x_real, c_label],
    })

    ModelTrainer = Trainer(**trainer_feed)

    command_controller = trainer.cmd_ctrl.CMDControl(ModelTrainer)
    command_controller.start_thread()

    ModelTrainer.init_training()

    disc_model.load_from_npz('success/goodmodel_dragan_anime1_disc.npz',
                             ModelTrainer.sess)

    ModelTrainer.train()
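
ctrl_weight bundles the scalar control placeholders so the trainer can anneal them during training; inside a training step the trainer could build a feed from it roughly like this (values purely illustrative, not taken from the repo):

# Illustrative only: feeding the ctrl_weight placeholders defined above.
feed = {
    ctrl_weight["real_cls"]: 1.0,
    ctrl_weight["fake_cls"]: 0.1,
    ctrl_weight["adv"]: 1.0,
    ctrl_weight["lr"]: 2e-4,
}
# merged with the data batch before sess.run([gen_model.train_op, ...], feed_dict=...)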
Example #12
            self.net.eval()
            bl = bl.cuda()
            db = self.net(bl, *args)
        return db

    def load_model(self, args):
        ckp = torch.load(
            args.test_ckp_dir,
            map_location=lambda storage, loc: storage.cuda(args.gpu_idx))
        self.net.load_state_dict(ckp['model'])
        return ckp


if __name__ == '__main__':

    args = get_train_config()
    log(args)
    net = vem_deblur_model(args).cuda()

    train_dset = Train_Dataset(args, args.train_sp_dir, args.sigma,
                               args.train_ker_dir)
    val_dset = {}
    for name in args.val_bl_sigma:
        val_dset[str(name)] = Test_Dataset(args.val_sp_dir,
                                           args.val_bl_dir[str(name)],
                                           args.val_ker_dir)

    # trainer
    train = Trainer(args, net, train_dset=train_dset, val_dset=val_dset)
    train()
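
The __main__ block above builds a Trainer and then calls the instance itself (train()); a hypothetical skeleton of such a callable trainer, shown only to illustrate the pattern, is:

class CallableTrainer:
    # Hypothetical skeleton, not the repo's Trainer class: __call__ runs the
    # loop, so `train = Trainer(...); train()` reads like calling a function.
    def __init__(self, args, net, train_dset, val_dset):
        self.args = args
        self.net = net
        self.train_dset = train_dset
        self.val_dset = val_dset

    def __call__(self):
        for epoch in range(getattr(self.args, "epochs", 1)):
            pass  # one pass over train_dset, then evaluation on val_dset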
def main():
    # get configuration
    print("Get configuration")
    TFLAGS = cfg.get_train_config(FLAGS.model_name)
    TFLAGS['AE_weight'] = 10.0
    TFLAGS['side_noise'] = False

    # A: sketch face, B: anime face
    # Gen and Disc usually init from pretrained model
    print("=> Building discriminator")
    disc_config = cfg.good_generator.goodmodel_disc({})
    disc_config['name'] = "DiscA"
    DiscA = model.good_generator.GoodDiscriminator(**disc_config)
    disc_config['name'] = "DiscB"
    DiscB = model.good_generator.GoodDiscriminator(**disc_config)

    gen_config = cfg.good_generator.goodmodel_gen({})
    gen_config['name'] = 'GenA'
    gen_config['out_dim'] = 3
    GenA = model.good_generator.GoodGenerator(**gen_config)
    gen_config['name'] = 'GenB'
    gen_config['out_dim'] = 1
    GenB = model.good_generator.GoodGenerator(**gen_config)

    gen_config = {}
    gen_config['name'] = 'TransForward'
    gen_config['out_dim'] = 3
    TransF = model.conditional_generator.ImageConditionalDeepGenerator2(
        **gen_config)
    gen_config['name'] = 'TransBackward'
    gen_config['out_dim'] = 1
    TransB = model.conditional_generator.ImageConditionalDeepGenerator2(
        **gen_config)

    models = [DiscA, DiscB, TransF, TransB]
    model_names = ["DiscA", "DiscB", "TransF", "TransB"]

    TFLAGS['batch_size'] = 1
    TFLAGS['input_shape'] = [128, 128, 3]
    # Data Preparation
    # raw image is 0~255
    print("Get dataset")

    face_dataset = dataloader.CustomDataset(
        root_dir="/data/datasets/getchu/true_face",
        npy_dir=None,
        preproc_kind="tanh",
        img_size=TFLAGS['input_shape'][:2],
        filter_data=TFLAGS['filter_data'],
        class_num=TFLAGS['c_len'],
        disturb=False,
        flip=False)

    sketch_dataset = dataloader.CustomDataset(
        root_dir="/data/datasets/getchu/sketch_face/",
        npy_dir=None,
        preproc_kind="tanh",
        img_size=TFLAGS['input_shape'][:2],
        filter_data=TFLAGS['filter_data'],
        class_num=TFLAGS['c_len'],
        disturb=False,
        flip=False)

    label_dataset = dataloader.NumpyArrayDataset(
        np.load("/data/datasets/getchu/true_face.npy")[:, 1:])

    listsampler = dataloader.ListSampler(
        [sketch_dataset, face_dataset, label_dataset])
    dl = dataloader.CustomDataLoader(listsampler, TFLAGS['batch_size'], 4)

    # TF Input
    real_A_sample = tf.placeholder(tf.float32,
                                   [None] + TFLAGS['input_shape'][:-1] + [1],
                                   name='real_A_sample')
    real_B_sample = tf.placeholder(tf.float32, [None] + TFLAGS['input_shape'],
                                   name="real_B_sample")
    fake_A_sample = tf.placeholder(tf.float32,
                                   [None] + TFLAGS['input_shape'][:-1] + [1],
                                   name='fake_A_sample')
    fake_B_sample = tf.placeholder(tf.float32, [None] + TFLAGS['input_shape'],
                                   name="fake_B_sample")

    noise_A = tf.placeholder(tf.float32, [None, TFLAGS['z_len']],
                             name="noise_A")
    noise_B = tf.placeholder(tf.float32, [None, TFLAGS['z_len']],
                             name="noise_B")
    label_A = tf.placeholder(tf.float32, [None, TFLAGS['c_len']],
                             name="label_A")
    label_B = tf.placeholder(tf.float32, [None, TFLAGS['c_len']],
                             name="label_B")
    # c label is for real samples
    c_label = tf.placeholder(tf.float32, [None, TFLAGS['c_len']],
                             name="c_label")

    # control variables
    real_cls_weight = tf.placeholder(tf.float32, [], name="real_cls_weight")
    fake_cls_weight = tf.placeholder(tf.float32, [], name="fake_cls_weight")
    adv_weight = tf.placeholder(tf.float32, [], name="adv_weight")
    lr = tf.placeholder(tf.float32, [], name="lr")

    inter_sum_op = []

    ### build graph
    """
    fake_A = GenA(noise_A); GenA.set_reuse() # GenA and GenB only used once
    fake_B = GenB(noise_B); GenB.set_reuse()

    # transF and TransB used three times
    trans_real_A = TransF(real_A); TransF.set_reuse() # domain B
    trans_fake_A = TransF(fake_A) # domain B
    trans_real_B = TransB(real_B); TransB.set_reuse() # domain A
    trans_fake_B = TransB(fake_B) # domain A
    rec_real_A = TransB(trans_real_A)
    rec_fake_A = TransB(trans_fake_A)
    rec_real_B = TransF(trans_real_B)
    rec_fake_B = TransF(trans_fake_B)

    # DiscA and DiscB reused many times
    disc_real_A, cls_real_A = DiscA(x_real)[:2]; DiscA.set_reuse() 
    disc_fake_A, cls_fake_A = DiscA(x_fake)[:2]; 
    disc_trans_real_B, cls_trans_real_B = DiscA(trans_real_B)[:2]
    disc_trans_fake_B, cls_trans_fake_B = DiscA(trans_real_B)[:2]
    disc_rec_real_A, cls_rec_real_A = DiscA(rec_real_A)[:2]
    disc_rec_fake_A, cls_rec_fake_A = DiscA(rec_real_A)[:2]

    disc_real_B, cls_real_B = DiscB(x_real)[:2]; DiscB.set_reuse() 
    disc_fake_B, cls_fake_B = DiscB(x_fake)[:2]; 
    disc_trans_real_A, cls_trans_real_A = DiscB(trans_real_A)[:2]
    disc_trans_fake_A, cls_trans_fake_A = DiscB(trans_real_A)[:2]
    disc_rec_real_B, cls_rec_real_B = DiscB(rec_real_B)[:2]
    disc_rec_fake_B, cls_rec_fake_B = DiscB(rec_real_B)[:2]
    """

    side_noise_A = tf.concat([noise_A, label_A], axis=1, name='side_noise_A')
    side_noise_B = tf.concat([noise_B, label_B], axis=1, name='side_noise_B')

    trans_real_A = TransF(real_A_sample)
    TransF.set_reuse()  # domain B
    trans_real_B = TransB(real_B_sample)
    TransB.set_reuse()  # domain A
    TransF.trans_real_A = trans_real_A
    TransB.trans_real_B = trans_real_B
    rec_real_A = TransB(trans_real_A)
    rec_real_B = TransF(trans_real_B)

    # start fake building
    if TFLAGS['side_noise']:
        TransF.side_noise = side_noise_A
        TransB.side_noise = side_noise_B
        trans_fake_A = TransF(real_A_sample)
        trans_fake_B = TransB(real_B_sample)

    DiscA.fake_sample = fake_A_sample
    DiscB.fake_sample = fake_B_sample
    disc_fake_A_sample, cls_fake_A_sample = DiscA(fake_A_sample)[:2]
    DiscA.set_reuse()
    disc_fake_B_sample, cls_fake_B_sample = DiscB(fake_B_sample)[:2]
    DiscB.set_reuse()
    disc_real_A_sample, cls_real_A_sample = DiscA(real_A_sample)[:2]
    disc_real_B_sample, cls_real_B_sample = DiscB(real_B_sample)[:2]
    disc_trans_real_A, cls_trans_real_A = DiscB(trans_real_A)[:2]
    disc_trans_real_B, cls_trans_real_B = DiscA(trans_real_B)[:2]

    if TFLAGS['side_noise']:
        disc_trans_fake_A, cls_trans_fake_A = DiscB(trans_fake_A)[:2]
        disc_trans_fake_B, cls_trans_fake_B = DiscA(trans_fake_B)[:2]

    def disc_loss(disc_fake_out,
                  disc_real_out,
                  disc_model,
                  adv_weight=1.0,
                  name="NaiveDisc",
                  acc=True):
        softL_c = 0.05
        with tf.name_scope(name):
            raw_disc_cost_real = tf.reduce_mean(
                tf.square(disc_real_out - tf.ones_like(disc_real_out) *
                          np.abs(np.random.normal(1.0, softL_c))),
                name="raw_disc_cost_real")

            raw_disc_cost_fake = tf.reduce_mean(
                tf.square(disc_fake_out - tf.zeros_like(disc_fake_out)),
                name="raw_disc_cost_fake")

            disc_cost = tf.multiply(
                adv_weight, (raw_disc_cost_fake + raw_disc_cost_real) / 2,
                name="disc_cost")

            disc_fake_sum = [
                tf.summary.scalar("DiscFakeRaw", raw_disc_cost_fake),
                tf.summary.scalar("DiscRealRaw", raw_disc_cost_real)
            ]

            if acc:
                disc_model.cost += disc_cost
                disc_model.sum_op.extend(disc_fake_sum)
            else:
                return disc_cost, disc_fake_sum

    def gen_loss(disc_fake_out,
                 gen_model,
                 adv_weight=1.0,
                 name="Naive",
                 acc=True):
        softL_c = 0.05
        with tf.name_scope(name):
            raw_gen_cost = tf.reduce_mean(
                tf.square(disc_fake_out - tf.ones_like(disc_fake_out) *
                          np.abs(np.random.normal(1.0, softL_c))),
                name="raw_gen_cost")

            gen_cost = tf.multiply(raw_gen_cost, adv_weight, name="gen_cost")

        if acc:
            gen_model.cost += gen_cost
        else:
            return gen_cost, []

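    # Accumulate the adversarial terms onto each model's cost: discriminator
    # losses on real vs. fake samples, translator losses on the translations.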
    disc_loss(disc_fake_A_sample,
              disc_real_A_sample,
              DiscA,
              adv_weight=TFLAGS['adv_weight'],
              name="DiscA_Loss")
    disc_loss(disc_fake_B_sample,
              disc_real_B_sample,
              DiscB,
              adv_weight=TFLAGS['adv_weight'],
              name="DiscA_Loss")

    gen_loss(disc_trans_real_A,
             TransF,
             adv_weight=TFLAGS['adv_weight'],
             name="NaiveGenA")
    gen_loss(disc_trans_real_B,
             TransB,
             adv_weight=TFLAGS['adv_weight'],
             name="NaiveTransA")

    if TFLAGS['side_noise']:
        # extra_loss covers the branch where the side noise is non-zero
        TransF.extra_loss = TransB.extra_loss = 0
        TransF.extra_loss = tf.identity(TransF.cost)
        TransB.extra_loss = tf.identity(TransB.cost)

        cost_, _ = gen_loss(disc_trans_fake_A,
                            TransF,
                            adv_weight=TFLAGS['adv_weight'],
                            name="NaiveGenAFake",
                            acc=False)
        TransF.extra_loss += cost_
        cost_, _ = gen_loss(disc_trans_fake_B,
                            TransB,
                            adv_weight=TFLAGS['adv_weight'],
                            name="NaiveGenBFake",
                            acc=False)
        TransB.extra_loss += cost_

    #GANLoss(disc_rec_A, DiscA, TransB, adv_weight=TFLAGS['adv_weight'], name="NaiveGenA")
    #GANLoss(disc_rec_B, DiscA, TransF, adv_weight=TFLAGS['adv_weight'], name="NaiveTransA")

    # cycle consistent loss
    def cycle_loss(trans, origin, weight=10.0, name="cycle"):
        with tf.name_scope(name):
            # compare grayscale versions (mean over the channel axis)
            trans = tf.reduce_mean(trans, axis=[3])
            origin = tf.reduce_mean(origin, axis=[3])
            cost_ = tf.reduce_mean(tf.abs(trans - origin)) * weight
            sum_ = tf.summary.scalar("Rec", cost_)

        return cost_, sum_

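    # L1 cycle-consistency between each sample and its reconstruction, computed
    # on grayscale images; the cost is shared by both translators since both
    # participate in each cycle.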
    cost_, sum_ = cycle_loss(rec_real_A,
                             real_A_sample,
                             TFLAGS['AE_weight'],
                             name="cycleA")
    TransF.cost += cost_
    TransB.cost += cost_
    TransF.sum_op.append(sum_)

    cost_, sum_ = cycle_loss(rec_real_B,
                             real_B_sample,
                             TFLAGS['AE_weight'],
                             name="cycleB")
    TransF.cost += cost_
    TransB.cost += cost_
    TransB.sum_op.append(sum_)

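    # Classification loss on DiscB's classifier head: sigmoid cross-entropy
    # between its logits for real domain-B samples and the c_label targets.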
    clsB_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=cls_real_B_sample, labels=c_label),
                               name="clsB_real")

    if TFLAGS['side_noise']:
        clsB_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=cls_trans_fake_A, labels=c_label),
                                   name="clsB_fake")

    DiscB.cost += clsB_real * real_cls_weight
    DiscB.sum_op.extend([tf.summary.scalar("RealCls", clsB_real)])

    if TFLAGS['side_noise']:
        DiscB.extra_loss = clsB_fake * fake_cls_weight
        TransF.extra_loss += clsB_fake * fake_cls_weight

        # extra loss terms that integrate the stochastic (side-noise) branch
        cost_, sum_ = cycle_loss(rec_real_A,
                                 real_A_sample,
                                 TFLAGS['AE_weight'],
                                 name="cycleADisturb")
        TransF.extra_loss += cost_
        TransF.sum_op.append(sum_)
        # classification loss on the deterministic translation, mirroring clsB_fake
        clsB_trans = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=cls_trans_real_A, labels=c_label),
                                    name="clsB_trans")
        TransF.extra_loss += clsB_trans * fake_cls_weight

    # add interval summary
    edge_num = int(np.sqrt(TFLAGS['batch_size']))
    if edge_num > 4:
        edge_num = 4

    grid_real_A = ops.get_grid_image_summary(real_A_sample, edge_num)
    inter_sum_op.append(tf.summary.image("Real A", grid_real_A))
    grid_real_B = ops.get_grid_image_summary(real_B_sample, edge_num)
    inter_sum_op.append(tf.summary.image("Real B", grid_real_B))

    grid_trans_A = ops.get_grid_image_summary(trans_real_A, edge_num)
    inter_sum_op.append(tf.summary.image("Trans A", grid_trans_A))
    grid_trans_B = ops.get_grid_image_summary(trans_real_B, edge_num)
    inter_sum_op.append(tf.summary.image("Trans B", grid_trans_B))

    if TFLAGS['side_noise']:
        grid_fake_A = ops.get_grid_image_summary(trans_fake_A, edge_num)
        inter_sum_op.append(tf.summary.image("Fake A", grid_fake_A))
        grid_fake_B = ops.get_grid_image_summary(trans_fake_B, edge_num)
        inter_sum_op.append(tf.summary.image("Fake B", grid_fake_B))

    # finalize each model: name its total cost, merge its summaries, collect variables
    for m, n in zip(models, model_names):
        m.cost = tf.identity(m.cost, "Total" + n)
        m.sum_op.append(tf.summary.scalar("Total" + n, m.cost))
        m.sum_op = tf.summary.merge(m.sum_op)
        m.get_trainable_variables()

    inter_sum_op = tf.summary.merge(inter_sum_op)

    # Batch-norm moving-average updates live in UPDATE_OPS; the train ops must
    # depend on them or training fails.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        for m, n in zip(models, model_names):
            m.train_op = tf.train.AdamOptimizer(learning_rate=lr,
                                                beta1=0.5,
                                                beta2=0.9).minimize(
                                                    m.cost, var_list=m.vars)

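            # Models whose extra_loss was assigned a tensor (the side-noise
            # branch) get a second optimizer over the same variable list.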
            if isinstance(m.extra_loss, tf.Tensor):
                m.extra_train_op = tf.train.AdamOptimizer(learning_rate=lr,
                                                          beta1=0.5,
                                                          beta2=0.9).minimize(
                                                              m.extra_loss,
                                                              var_list=m.vars)

            print("=> ##### %s Variable #####" % n)
            m.print_trainble_vairables()

    print("=> ##### All Variable #####")
    for v in tf.trainable_variables():
        print("%s" % v.name)

    ctrl_weight = {
        "real_cls": real_cls_weight,
        "fake_cls": fake_cls_weight,
        "adv": adv_weight,
        "lr": lr
    }

    # basic settings
    trainer_feed = {
        "ctrl_weight": ctrl_weight,
        "dataloader": dl,
        "int_sum_op": inter_sum_op,
        "gpu_mem": TFLAGS['gpu_mem'],
        "FLAGS": FLAGS,
        "TFLAGS": TFLAGS
    }

    for m, n in zip(models, model_names):
        trainer_feed.update({n: m})

    Trainer = trainer.cycle_trainer.CycleTrainer
    if TFLAGS['side_noise']:
        Trainer.noises = [noise_A, noise_B]
        Trainer.labels = [label_A, label_B]
    # input
    trainer_feed.update({"inputs": [real_A_sample, real_B_sample, c_label]})

    ModelTrainer = Trainer(**trainer_feed)

    command_controller = trainer.cmd_ctrl.CMDControl(ModelTrainer)
    command_controller.start_thread()

    ModelTrainer.init_training()

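    # Warm-start both discriminators from pretrained DRAGAN checkpoints before
    # training begins.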
    DISC_PATH = [
        'success/goodmodel_dragan_sketch2_disc.npz',
        'success/goodmodel_dragan_anime1_disc.npz'
    ]

    DiscA.load_from_npz(DISC_PATH[0], ModelTrainer.sess)
    DiscB.load_from_npz(DISC_PATH[1], ModelTrainer.sess)

    ModelTrainer.train()