def get_interval_summary(data1, data2, batch_size):
    """Create the merged image summary written at logging intervals.

    Both batches are tiled with ``ops.get_grid_image_summary`` and logged
    after adding 1 to every value.

    Args:
        data1: image batch logged under the tag "generated".
        data2: image batch logged under the tag "real".
        batch_size: batch size; the grid edge is ``int(sqrt(batch_size))``
            capped at 4.

    Returns:
        The op produced by ``tf.summary.merge`` over the two image summaries.
    """
    # Edge length of the image grid, never more than 4.
    grid_edge = min(int(np.sqrt(batch_size)), 4)

    # Build both grids first, then both summaries, so graph ops are created
    # in the same order as before.
    fake_grid = ops.get_grid_image_summary(data1, grid_edge)
    real_grid = ops.get_grid_image_summary(data2, grid_edge)

    fake_summary = tf.summary.image("generated", fake_grid + 1)
    real_summary = tf.summary.image("real", real_grid + 1)

    return tf.summary.merge([fake_summary, real_summary])
def main():
    """Build the GoodGeneratorEncoder GAN graph and run training.

    Steps, in order:
      1. Load the training configuration and the generator/discriminator.
      2. Create the dataset and data loader selected by
         ``TFLAGS['dataset_kind']``.
      3. Declare input and control placeholders.
      4. Run the generator's ``build_inference`` three times (on the raw
         inputs, on its own "fake" outputs, and on its "trans" outputs).
      5. Assemble adversarial, classifier, gradient-penalty and
         reconstruction losses plus their summaries.
      6. Create the Adam train ops and hand everything to ``BaseGANTrainer``.

    Raises:
        ValueError: if ``TFLAGS['dataset_kind']`` is neither "file" nor
            "numpy", or if ``FLAGS.cgan`` is not set (the graph below always
            consumes the condition placeholders).
    """
    # get configuration
    print("Get configuration")
    TFLAGS = cfg.get_train_config(FLAGS.model_name)
    gen_model, gen_config, disc_model, disc_config = model.get_model(
        FLAGS.model_name, TFLAGS)
    # The generator returned above is replaced by an encoder variant built
    # from the same config.
    gen_model = model.good_generator.GoodGeneratorEncoder(**gen_config)
    print("Common length: %d" % gen_model.common_length)
    TFLAGS['AE_weight'] = 1.0

    # Data Preparation
    # raw image is 0~255
    print("Get dataset")
    if TFLAGS['dataset_kind'] == "file":
        dataset = dataloader.CustomDataset(
            root_dir=TFLAGS['data_dir'],
            npy_dir=TFLAGS['npy_dir'],
            preproc_kind=TFLAGS['preprocess'],
            img_size=TFLAGS['input_shape'][:2],
            filter_data=TFLAGS['filter_data'],
            class_num=TFLAGS['c_len'],
            disturb=TFLAGS['disturb'],
            flip=TFLAGS['preproc_flip'])
    elif TFLAGS['dataset_kind'] == "numpy":
        train_data, test_data, train_label, test_label = utils.load_mnist_4d(
            TFLAGS['data_dir'])
        train_data = train_data.reshape([-1] + TFLAGS['input_shape'])
        # TODO: validation in training
        dataset = dataloader.NumpyArrayDataset(
            data_npy=train_data,
            label_npy=train_label,
            preproc_kind=TFLAGS['preprocess'],
            img_size=TFLAGS['input_shape'][:2],
            filter_data=TFLAGS['filter_data'],
            class_num=TFLAGS['c_len'],
            flip=TFLAGS['preproc_flip'])
    else:
        # Previously this fell through and failed later with a NameError on
        # `dataset`; fail early with an explicit message instead.
        raise ValueError(
            "Unsupported dataset_kind: %r" % (TFLAGS['dataset_kind'],))

    dl = dataloader.CustomDataLoader(
        dataset,
        batch_size=TFLAGS['batch_size'],
        num_threads=TFLAGS['data_threads'])

    # TF Input
    x_real = tf.placeholder(
        tf.float32, [None] + TFLAGS['input_shape'], name="x_real")
    s_real = tf.placeholder(
        tf.float32, [None] + TFLAGS['input_shape'], name='s_real')
    z_noise = tf.placeholder(
        tf.float32, [None, TFLAGS['z_len']], name="z_noise")

    if not FLAGS.cgan:
        # c_noise / c_label are used unconditionally below, so a
        # non-conditional run could only ever end in a NameError.
        raise ValueError("This training script requires FLAGS.cgan to be set")
    c_noise = tf.placeholder(
        tf.float32, [None, TFLAGS['c_len']], name="c_noise")
    c_label = tf.placeholder(
        tf.float32, [None, TFLAGS['c_len']], name="c_label")

    # control variables
    real_cls_weight = tf.placeholder(tf.float32, [], name="real_cls_weight")
    fake_cls_weight = tf.placeholder(tf.float32, [], name="fake_cls_weight")
    adv_weight = tf.placeholder(tf.float32, [], name="adv_weight")
    inc_length = tf.placeholder(tf.int32, [], name="inc_length")
    lr = tf.placeholder(tf.float32, [], name="lr")

    # build inference network
    # image_inputs : x_real, x_fake, x_trans
    # noise_input : noise, noise_rec, noise_trans
    # feat_image + image_feat
    # x_real -> trans_feat -> x_trans -> trans_x_trans_feat[REC] -> x_trans_trans_
    # feat_noise + noise_feat
    # x_real -> trans_feat -> noise_trans -> noise_x_trans_feat[REC] -> x_trans_fake
    with tf.variable_scope(gen_model.field_name):
        noise_input = tf.concat([z_noise, c_noise], axis=1)

        # Pass 1: raw noise and real images.
        gen_model.image_input = x_real
        gen_model.noise_input = noise_input
        x_fake, x_trans, noise_rec, noise_trans = gen_model.build_inference()
        noise_feat = tf.identity(gen_model.noise_feat, "noise_feat")
        trans_feat = tf.identity(gen_model.image_feat, "trans_feat")

        # Pass 2: reconstructed noise and generated images.
        gen_model.noise_input = noise_rec
        gen_model.image_input = x_fake
        (x_fake_fake, x_fake_trans,
         noise_rec_rec, noise_x_fake_trans) = gen_model.build_inference()
        noise_rec_feat = tf.identity(gen_model.noise_feat, "noise_rec_feat")
        trans_fake_feat = tf.identity(gen_model.image_feat, "trans_fake_feat")

        # Pass 3: noise encoded from the real image and the translated image.
        gen_model.noise_input = noise_trans
        gen_model.image_input = x_trans
        (x_trans_fake, x_trans_trans,
         noise_trans_rec, noise_trans_trans) = gen_model.build_inference()
        noise_trans_feat = tf.identity(
            gen_model.noise_feat, "noise_trans_feat")
        trans_trans_feat = tf.identity(
            gen_model.image_feat, "trans_trans_feat")

    # noise       -> noise_feat       -> x_fake        | noise_rec
    # image       -> trans_feat       -> x_trans       | noise_trans
    # noise_rec   -> noise_rec_feat   -> x_fake_fake   | noise_rec_rec
    # x_fake      -> trans_fake_feat  -> x_fake_trans  | noise_fake_trans
    # noise_trans -> noise_trans_feat -> x_trans_fake  | noise_trans_rec
    # x_trans     -> trans_trans_feat -> x_trans_trans | noise_trans_trans
    # Image Feat Recs:
    # trans_feat == noise_trans_feat : feat_noise + noise_feat
    # trans_feat == trans_trans_feat : feat_image + image_feat
    # noise_feat == noise_rec_feat   : feat_noise + noise_feat
    # noise_feat == trans_fake_feat  : feat_image + image_feat
    # Noise Recs:
    # noise_input == noise_rec : noise_feat + feat_noise
    # Image Recs:
    # x_real -[REC]- x_trans : image_feat + feat_image

    # The logits are named real_cls_logits / fake_cls_logits because the
    # ACGAN loss below reads them under those names (the previous
    # cls_*_logits spelling left them undefined there).
    disc_real_out, real_cls_logits = disc_model(x_real)[:2]
    disc_model.set_reuse()
    disc_fake_out, fake_cls_logits = disc_model(x_fake)[:2]
    disc_trans_out, trans_cls_logits = disc_model(x_trans)[:2]

    gen_model.cost = disc_model.cost = 0
    # Each model gets its own summary list; a single shared list would make
    # every append to one model's summaries show up in the other's as well.
    gen_model.sum_op = []
    disc_model.sum_op = []
    inter_sum_op = []

    # Select loss builder and model trainer
    print("Build training graph")

    # Naive GAN
    loss.naive_ganloss.func_gen_loss(
        disc_fake_out, adv_weight, name="Gen", model=gen_model)
    loss.naive_ganloss.func_disc_fake_loss(
        disc_fake_out, adv_weight / 2, name="DiscFake", model=disc_model)
    loss.naive_ganloss.func_disc_fake_loss(
        disc_trans_out, adv_weight / 2, name="DiscTrans", model=disc_model)
    loss.naive_ganloss.func_disc_real_loss(
        disc_real_out, adv_weight, name="DiscGen", model=disc_model)

    # ACGAN
    real_cls_cost = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=real_cls_logits, labels=c_label),
        name="real_cls_reduce_mean")
    real_cls_cost_sum = tf.summary.scalar("real_cls", real_cls_cost)

    fake_cls_cost = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=fake_cls_logits, labels=c_noise),
        name='fake_cls_reduce_mean')
    fake_cls_cost_sum = tf.summary.scalar("fake_cls", fake_cls_cost)

    disc_model.cost += (real_cls_cost * real_cls_weight
                        + fake_cls_cost * fake_cls_weight)
    disc_model.sum_op.extend([real_cls_cost_sum, fake_cls_cost_sum])

    gen_model.cost += fake_cls_cost * fake_cls_weight
    gen_model.sum_op.append(fake_cls_cost_sum)

    # DRAGAN gradient penalty
    disc_cost_, disc_sum_ = loss.dragan_loss.get_dragan_loss(
        disc_model, x_real, TFLAGS['gp_weight'])
    disc_model.cost += disc_cost_
    disc_model.sum_op.append(disc_sum_)

    # NOTE(review): gen_cost_ is added to gen_model.cost here and again to
    # extra_loss below, so total_cost counts it twice - confirm intended.
    gen_cost_, gen_sum_ = loss.naive_ganloss.func_gen_loss(
        disc_trans_out, adv_weight, name="GenTrans")
    gen_model.cost += gen_cost_

    # Image Feat Recs:
    # trans_feat == noise_trans_feat : feat_noise + noise_feat 1
    # trans_feat == trans_trans_feat : feat_image + image_feat 2
    # noise_feat == noise_rec_feat   : feat_noise + noise_feat 1
    # noise_feat == trans_fake_feat  : feat_image + image_feat 2
    trans_feat_cost1_, _ = loss.common_loss.reconstruction_loss(
        noise_trans_feat, trans_feat,
        TFLAGS['AE_weight'] / 4, name="RecTransFeat1")
    trans_feat_cost2_, _ = loss.common_loss.reconstruction_loss(
        trans_trans_feat, trans_feat,
        TFLAGS['AE_weight'] / 4, name="RecTransFeat2")
    trans_feat_cost_ = trans_feat_cost1_ + trans_feat_cost2_
    trans_feat_sum_ = tf.summary.scalar("RecTransFeat", trans_feat_cost_)

    noise_feat_cost1_, _ = loss.common_loss.reconstruction_loss(
        noise_rec_feat, noise_feat,
        TFLAGS['AE_weight'] / 4, name="RecNoiseFeat1")
    noise_feat_cost2_, _ = loss.common_loss.reconstruction_loss(
        trans_fake_feat, noise_feat,
        TFLAGS['AE_weight'] / 4, name="RecNoiseFeat2")
    noise_feat_cost_ = noise_feat_cost1_ + noise_feat_cost2_
    noise_feat_sum_ = tf.summary.scalar("RecNoiseFeat", noise_feat_cost_)

    # Noise Recs:
    # noise_input -[REC]- noise_rec : noise_feat + feat_noise
    # Only the last `inc_length` columns are compared.
    noise_cost_, noise_sum_ = loss.common_loss.reconstruction_loss(
        noise_rec[:, -inc_length:], noise_input[:, -inc_length:],
        TFLAGS['AE_weight'], name="RecNoise")

    # Image Recs:
    # x_real -[REC]- x_trans : image_feat + feat_image
    rec_cost_, rec_sum_ = loss.common_loss.reconstruction_loss(
        x_trans, x_real, TFLAGS['AE_weight'], name="RecPix")

    gen_model.extra_sum_op = [
        gen_sum_, trans_feat_sum_, noise_feat_sum_, rec_sum_, noise_sum_
    ]
    gen_model.extra_loss = (gen_cost_ + trans_feat_cost_ + noise_feat_cost_
                            + rec_cost_ + noise_cost_)

    gen_model.total_cost = (tf.identity(gen_model.extra_loss)
                            + tf.identity(gen_model.cost))

    gen_model.cost = tf.identity(gen_model.cost, "TotalGenCost")
    disc_model.cost = tf.identity(disc_model.cost, "TotalDiscCost")

    # total summary
    gen_model.sum_op.append(tf.summary.scalar("GenCost", gen_model.cost))
    disc_model.sum_op.append(tf.summary.scalar("DiscCost", disc_model.cost))

    # add interval summary
    edge_num = min(int(np.sqrt(TFLAGS['batch_size'])), 4)

    grid_x_fake = ops.get_grid_image_summary(x_fake, edge_num)
    inter_sum_op.append(tf.summary.image("generated image", grid_x_fake))
    grid_x_trans = ops.get_grid_image_summary(x_trans, edge_num)
    inter_sum_op.append(tf.summary.image("trans image", grid_x_trans))
    grid_x_real = ops.get_grid_image_summary(x_real, edge_num)
    inter_sum_op.append(tf.summary.image("real image", grid_x_real))

    # merge summary op
    gen_model.extra_sum_op = tf.summary.merge(gen_model.extra_sum_op)
    gen_model.sum_op = tf.summary.merge(gen_model.sum_op)
    disc_model.sum_op = tf.summary.merge(disc_model.sum_op)
    inter_sum_op = tf.summary.merge(inter_sum_op)

    print("=> Compute gradient")

    # get train op
    gen_model.get_trainable_variables()
    disc_model.get_trainable_variables()

    # Not applying update op will result in failure
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        # normal GAN train op
        gen_model.train_op = tf.train.AdamOptimizer(
            learning_rate=lr, beta1=0.5, beta2=0.9).minimize(
                gen_model.cost, var_list=gen_model.vars)

        # train reconstruction branch
        partial_vars = (gen_model.branch_vars['image_feat']
                        + gen_model.branch_vars['feat_noise'])
        gen_model.extra_train_op_partial1 = tf.train.AdamOptimizer(
            learning_rate=lr / 5, beta1=0.5, beta2=0.9).minimize(
                gen_model.extra_loss, var_list=partial_vars)

        full_vars = (partial_vars
                     + gen_model.branch_vars['noise_feat']
                     + gen_model.branch_vars['feat_image'])
        gen_model.extra_train_op_full = tf.train.AdamOptimizer(
            learning_rate=lr / 5, beta1=0.5, beta2=0.9).minimize(
                gen_model.extra_loss, var_list=full_vars)

        gen_model.full_train_op = tf.train.AdamOptimizer(
            learning_rate=lr / 5, beta1=0.5, beta2=0.9).minimize(
                gen_model.total_cost, var_list=gen_model.vars)

        disc_model.train_op = tf.train.AdamOptimizer(
            learning_rate=lr, beta1=0.5, beta2=0.9).minimize(
                disc_model.cost, var_list=disc_model.vars)

    print("=> ##### Generator Variable #####")
    gen_model.print_trainble_vairables()
    print("=> ##### Discriminator Variable #####")
    disc_model.print_trainble_vairables()
    print("=> ##### All Variable #####")
    for v in tf.trainable_variables():
        print("%s" % v.name)

    ctrl_weight = {
        "real_cls": real_cls_weight,
        "fake_cls": fake_cls_weight,
        "adv": adv_weight,
        "lr": lr,
        "inc_length": inc_length
    }

    trainer_feed = {
        "gen_model": gen_model,
        "disc_model": disc_model,
        "ctrl_weight": ctrl_weight,
        "dataloader": dl,
        "int_sum_op": inter_sum_op,
        "gpu_mem": TFLAGS['gpu_mem'],
        "FLAGS": FLAGS,
        "TFLAGS": TFLAGS
    }

    Trainer = trainer.base_gantrainer_rep.BaseGANTrainer
    trainer_feed.update({
        "inputs": [z_noise, c_noise, x_real, c_label],
    })

    ModelTrainer = Trainer(**trainer_feed)

    command_controller = trainer.cmd_ctrl.CMDControl(ModelTrainer)
    command_controller.start_thread()

    ModelTrainer.init_training()
    ModelTrainer.train()
def main():
    """Build a hinge-loss GAN graph with phase-aware models and train it.

    The dataset class is chosen from substrings of ``FLAGS.data_dir``. When
    ``FLAGS.cgan`` is set, condition placeholders are created and the
    classifier loss is added; otherwise the condition tensors are ``None``
    and only the adversarial sub-losses are summarised.
    """
    size = FLAGS.img_size

    # debug
    if len(FLAGS.train_dir) < 1:
        # No train dir given: derive one from model name, BN kind and phases.
        bn_name = ["nobn", "caffebn", "simplebn", "defaultbn", "cbn"]
        FLAGS.train_dir = os.path.join(
            "logs",
            FLAGS.model_name + "_" + bn_name[FLAGS.bn] + "_"
            + str(FLAGS.phases))

    if FLAGS.cgan:
        # the label file is npy format
        npy_dir = FLAGS.data_dir.replace(".zip", "") + '.npy'
    else:
        npy_dir = None

    if "celeb" in FLAGS.data_dir:
        dataset = dataloader.CelebADataset(
            FLAGS.data_dir, img_size=(size, size), npy_dir=npy_dir)
    elif "cityscapes" in FLAGS.data_dir:
        augmentations = Compose([
            RandomCrop(size * 4),
            Scale(size * 2),
            RandomRotate(10),
            RandomHorizontallyFlip(),
            RandomSizedCrop(size)
        ])
        dataset = dataloader.cityscapesLoader(
            FLAGS.data_dir,
            is_transform=True,
            augmentations=augmentations,
            img_size=(size, size))
        # Floor division keeps the batch size an integer on Python 3.
        FLAGS.batch_size //= 64
    else:
        dataset = dataloader.FileDataset(
            FLAGS.data_dir, npy_dir=npy_dir, img_size=(size, size))

    dl = dataloader.TFDataloader(
        dataset, FLAGS.batch_size, dataset.file_num // FLAGS.batch_size)

    # TF Input
    x_fake_sample = tf.placeholder(
        tf.float32, [None, size, size, 3], name="x_fake_sample")
    x_real = tf.placeholder(
        tf.float32, [None, size, size, 3], name="x_real")
    s_real = tf.placeholder(
        tf.float32, [None, size, size, 3], name='s_real')
    z_noise = tf.placeholder(tf.float32, [None, 128], name="z_noise")

    if FLAGS.cgan:
        c_noise = tf.placeholder(
            tf.float32, [None, dataset.class_num], name="c_noise")
        c_label = tf.placeholder(
            tf.float32, [None, dataset.class_num], name="c_label")
        gen_input = [z_noise, c_noise]
    else:
        gen_input = z_noise
        # Define the condition tensors so later references do not raise
        # NameError in the unconditional case.
        # NOTE(review): assumes set_label(None) is accepted - confirm.
        c_label = c_noise = None

    # look up the config function from lib.config module
    gen_model, disc_model = getattr(config, FLAGS.model_name)(
        FLAGS.img_size, dataset.class_num)
    disc_model.norm_mtd = FLAGS.bn

    x_fake = gen_model(gen_input, update_collection=None)
    gen_model.set_reuse()
    gen_model.x_fake = x_fake

    # Discriminator on generated images.
    disc_model.set_label(c_noise)
    if FLAGS.phases > 1:
        disc_model.set_phase("fake")
    else:
        disc_model.set_phase("default")
    disc_fake, fake_cls_logits = disc_model(x_fake, update_collection=None)
    disc_model.set_reuse()

    # Discriminator on real images.
    disc_model.set_label(c_label)
    if FLAGS.phases > 1:
        disc_model.set_phase("real")
    else:
        disc_model.set_phase("default")
    disc_real, real_cls_logits = disc_model(x_real, update_collection=None)

    disc_model.disc_real = disc_real
    disc_model.disc_fake = disc_fake
    disc_model.real_cls_logits = real_cls_logits
    disc_model.fake_cls_logits = fake_cls_logits

    int_sum_op = []

    if FLAGS.use_cache:
        # Extra discriminator pass over externally fed fake samples.
        disc_fake_sample = disc_model(x_fake_sample)[0]
        disc_cost_sample = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=disc_fake_sample,
                labels=tf.zeros_like(disc_fake_sample)),
            name="cost_disc_fake_sample")
        disc_cost_sample_sum = tf.summary.scalar(
            "disc_sample", disc_cost_sample)

        fake_sample_grid = ops.get_grid_image_summary(x_fake_sample, 4)
        int_sum_op.append(tf.summary.image("fake sample", fake_sample_grid))

        sample_method = [
            disc_cost_sample, disc_cost_sample_sum, x_fake_sample
        ]
    else:
        sample_method = None

    grid_x_fake = ops.get_grid_image_summary(gen_model.x_fake, 4)
    int_sum_op.append(tf.summary.image("generated image", grid_x_fake))
    grid_x_real = ops.get_grid_image_summary(x_real, 4)
    int_sum_op.append(tf.summary.image("real image", grid_x_real))
    int_sum_op = tf.summary.merge(int_sum_op)

    raw_gen_cost, raw_disc_real, raw_disc_fake = loss.hinge_loss(
        gen_model, disc_model, adv_weight=1.0, summary=False)
    disc_model.disc_real_loss = raw_disc_real
    disc_model.disc_fake_loss = raw_disc_fake

    if FLAGS.cgan:
        real_cls_cost, fake_cls_cost = loss.classifier_loss(
            gen_model, disc_model, x_real, c_label, c_noise,
            weight=1.0 / dataset.class_num, summary=False)
        subloss_names = [
            "fake_cls", "real_cls", "gen", "disc_real", "disc_fake"
        ]
        sublosses = [
            fake_cls_cost, real_cls_cost,
            raw_gen_cost, raw_disc_real, raw_disc_fake
        ]
    else:
        # The classifier costs do not exist without the conditional branch.
        subloss_names = ["gen", "disc_real", "disc_fake"]
        sublosses = [raw_gen_cost, raw_disc_real, raw_disc_fake]

    step_sum_op = tf.summary.merge([
        tf.summary.scalar(n, l) for n, l in zip(subloss_names, sublosses)
    ])

    # Alternative trainer: trainer.separated_gantrainer.SeparatedGANTrainer
    ModelTrainer = trainer.base_gantrainer.BaseGANTrainer(
        int_sum_op=int_sum_op,
        step_sum_op=step_sum_op,
        dataloader=dl,
        FLAGS=FLAGS,
        gen_model=gen_model,
        disc_model=disc_model,
        gen_input=gen_input,
        x_real=x_real,
        label=c_label,
        sample_method=sample_method)

    # command_controller = trainer.cmd_ctrl.CMDControl(ModelTrainer)
    # command_controller.start_thread()

    print("=> Build train op")
    ModelTrainer.build_train_op()

    print("=> ##### Generator Variable #####")
    gen_model.print_trainble_vairables()
    print("=> ##### Discriminator Variable #####")
    disc_model.print_trainble_vairables()
    print("=> ##### All Variable #####")
    for v in tf.trainable_variables():
        print("%s\t\t\t\t%s" % (v.name, str(v.get_shape().as_list())))
    print("=> #### Moving Variable ####")
    for v in tf.global_variables():
        if "moving" in v.name:
            print("%s\t\t\t\t%s" % (v.name, str(v.get_shape().as_list())))

    ModelTrainer.init_training()
    ModelTrainer.train()
def main():
    """Build a hinge-loss GAN graph (optionally conditional) and train it.

    The dataset class is chosen from substrings of ``FLAGS.data_dir``.
    Per-step scalar summaries are created for every sub-loss; when either
    model has ``debug`` set, histograms of the model variables are added.
    """
    size = FLAGS.img_size

    if len(FLAGS.train_dir) < 1:
        # if the train dir is not set, then automatically decide one
        FLAGS.train_dir = os.path.join(
            "logs", FLAGS.model_name + str(FLAGS.img_size))
        if FLAGS.cgan:
            FLAGS.train_dir += "_cgan"

    if FLAGS.cgan:
        # the label file should be npy format
        npy_dir = FLAGS.data_dir.replace(".zip", "") + '.npy'
    else:
        npy_dir = None

    if "celeb" in FLAGS.data_dir:
        dataset = dataloader.CelebADataset(
            FLAGS.data_dir, img_size=(size, size), npy_dir=npy_dir)
    elif "cityscapes" in FLAGS.data_dir:
        # outdated
        augmentations = Compose([
            RandomCrop(size * 4),
            Scale(size * 2),
            RandomRotate(10),
            RandomHorizontallyFlip(),
            RandomSizedCrop(size)
        ])
        dataset = dataloader.cityscapesLoader(
            FLAGS.data_dir,
            is_transform=True,
            augmentations=augmentations,
            img_size=(size, size))
        # Floor division keeps the batch size an integer on Python 3.
        FLAGS.batch_size //= 64
    else:
        dataset = dataloader.FileDataset(
            FLAGS.data_dir, npy_dir=npy_dir, img_size=(size, size))

    dl = dataloader.TFDataloader(
        dataset, FLAGS.batch_size, dataset.file_num // FLAGS.batch_size)

    # TF Input
    x_fake_sample = tf.placeholder(
        tf.float32, [None, size, size, 3], name="x_fake_sample")
    x_real = tf.placeholder(
        tf.float32, [None, size, size, 3], name="x_real")
    s_real = tf.placeholder(
        tf.float32, [None, size, size, 3], name='s_real')
    z_noise = tf.placeholder(tf.float32, [None, 128], name="z_noise")

    if FLAGS.cgan:
        c_noise = tf.placeholder(
            tf.float32, [None, dataset.class_num], name="c_noise")
        c_label = tf.placeholder(
            tf.float32, [None, dataset.class_num], name="c_label")
        gen_input = [z_noise, c_noise]
    else:
        gen_input = z_noise
        c_label = c_noise = None

    # look up the config function from lib.config module
    gen_model, disc_model = getattr(config, FLAGS.model_name)(
        FLAGS.img_size, dataset.class_num)

    gen_model.label = c_noise
    x_fake = gen_model(gen_input)
    gen_model.set_reuse()
    gen_model.x_fake = x_fake

    # The discriminator's label attribute is switched before each call:
    # the sampled condition for fakes, the dataset label for reals.
    disc_model.label = c_noise
    disc_fake, fake_cls_logits = disc_model(x_fake)
    disc_model.set_reuse()

    disc_model.label = c_label
    disc_real, real_cls_logits = disc_model(x_real)

    disc_model.disc_real = disc_real
    disc_model.disc_fake = disc_fake
    disc_model.real_cls_logits = real_cls_logits
    disc_model.fake_cls_logits = fake_cls_logits

    raw_gen_cost, raw_disc_real, raw_disc_fake = loss.hinge_loss(
        gen_model, disc_model, adv_weight=1.0, summary=False)
    disc_model.disc_real_loss = raw_disc_real
    disc_model.disc_fake_loss = raw_disc_fake

    if FLAGS.cgan:
        real_cls_cost, fake_cls_cost = loss.classifier_loss(
            gen_model, disc_model, x_real, c_label, c_noise,
            weight=1.0 / dataset.class_num, summary=False)
        subloss_names = [
            "fake_cls", "real_cls", "gen", "disc_real", "disc_fake"
        ]
        sublosses = [
            fake_cls_cost, real_cls_cost,
            raw_gen_cost, raw_disc_real, raw_disc_fake
        ]
    else:
        subloss_names = ["gen", "disc_real", "disc_fake"]
        sublosses = [raw_gen_cost, raw_disc_real, raw_disc_fake]

    # summary at every step
    step_sum_op = [
        tf.summary.scalar(n, l) for n, l in zip(subloss_names, sublosses)
    ]
    if gen_model.debug or disc_model.debug:
        # Histogram every global variable whose op name contains either
        # model's name.
        for model_var in tf.global_variables():
            var_name = model_var.op.name
            if gen_model.name in var_name or disc_model.name in var_name:
                step_sum_op.append(tf.summary.histogram(var_name, model_var))
    step_sum_op = tf.summary.merge(step_sum_op)

    # summary at some interval
    int_sum_op = []
    grid_x_fake = ops.get_grid_image_summary(gen_model.x_fake, 4)
    int_sum_op.append(tf.summary.image("generated image", grid_x_fake))
    grid_x_real = ops.get_grid_image_summary(x_real, 4)
    int_sum_op.append(tf.summary.image("real image", grid_x_real))
    int_sum_op = tf.summary.merge(int_sum_op)

    ModelTrainer = trainer.base_gantrainer.BaseGANTrainer(
        int_sum_op=int_sum_op,
        step_sum_op=step_sum_op,
        dataloader=dl,
        FLAGS=FLAGS,
        gen_model=gen_model,
        disc_model=disc_model,
        gen_input=gen_input,
        x_real=x_real,
        label=c_label)

    print("=> Build train op")
    ModelTrainer.build_train_op()

    print("=> ##### Generator Variable #####")
    gen_model.print_trainble_vairables()
    print("=> ##### Discriminator Variable #####")
    disc_model.print_trainble_vairables()
    print("=> #### Moving Variable ####")
    for v in tf.global_variables():
        if "moving" in v.name:
            print("%s\t\t\t\t%s" % (v.name, str(v.get_shape().as_list())))
    print("=> #### Generator update dependency ####")
    for v in gen_model.update_ops:
        print("%s" % (v.name))
    print("=> #### Discriminator update dependency ####")
    for v in disc_model.update_ops:
        print("%s" % (v.name))

    ModelTrainer.init_training()
    ModelTrainer.train()
def main():
    """Build a hinge-loss GAN graph fed by a ``DataLoader`` and train it.

    The dataset class is chosen from substrings of ``FLAGS.data_dir``. The
    loader batch size is ``FLAGS.batch_size // 64``. When ``FLAGS.cgan`` is
    not set the condition tensors are ``None`` and the classifier loss is
    skipped.
    """
    size = FLAGS.img_size

    if FLAGS.cgan:
        npy_dir = FLAGS.data_dir.replace(".zip", "") + '.npy'
    else:
        npy_dir = None

    if "celeb" in FLAGS.data_dir:
        dataset = dataloader.CelebADataset(
            FLAGS.data_dir, img_size=(size, size), npy_dir=npy_dir)
    elif "cityscapes" in FLAGS.data_dir:
        augmentations = Compose([
            RandomCrop(size * 4),
            Scale(size * 2),
            RandomRotate(10),
            RandomHorizontallyFlip(),
            RandomSizedCrop(size)
        ])
        dataset = dataloader.cityscapesLoader(
            FLAGS.data_dir,
            is_transform=True,
            augmentations=augmentations,
            img_size=(size, size))
    else:
        dataset = dataloader.FileDataset(
            FLAGS.data_dir, npy_dir=npy_dir, img_size=(size, size),
            shuffle=True)

    dl = DataLoader(
        dataset,
        batch_size=FLAGS.batch_size // 64,
        shuffle=True,
        num_workers=NUM_WORKER,
        collate_fn=dataloader.default_collate)

    # TF Input
    x_fake_sample = tf.placeholder(
        tf.float32, [None, size, size, 3], name="x_fake_sample")
    x_real = tf.placeholder(
        tf.float32, [None, size, size, 3], name="x_real")
    s_real = tf.placeholder(
        tf.float32, [None, size, size, 3], name='s_real')
    z_noise = tf.placeholder(tf.float32, [None, 128], name="z_noise")

    if FLAGS.cgan:
        c_noise = tf.placeholder(
            tf.float32, [None, dataset.class_num], name="c_noise")
        c_label = tf.placeholder(
            tf.float32, [None, dataset.class_num], name="c_label")
        gen_input = [z_noise, c_noise]
    else:
        gen_input = z_noise
        # Without this, `label=c_label` below raised NameError whenever the
        # script ran unconditionally.
        c_label = c_noise = None

    gen_model, disc_model = getattr(config, FLAGS.model_name)(
        FLAGS.img_size, dataset.class_num)
    gen_model.cbn_project = FLAGS.cbn_project

    x_fake = gen_model(gen_input, update_collection=None)
    gen_model.set_reuse()
    gen_model.x_fake = x_fake

    disc_real, real_cls_logits = disc_model(x_real, update_collection=None)
    disc_model.set_reuse()
    # The second discriminator pass uses a different update collection
    # ("no_ops") than the first.
    disc_fake, fake_cls_logits = disc_model(
        x_fake, update_collection="no_ops")

    disc_model.disc_real = disc_real
    disc_model.disc_fake = disc_fake
    disc_model.real_cls_logits = real_cls_logits
    disc_model.fake_cls_logits = fake_cls_logits

    int_sum_op = []

    if FLAGS.use_cache:
        # Extra discriminator pass over externally fed fake samples.
        disc_fake_sample = disc_model(x_fake_sample)[0]
        disc_cost_sample = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=disc_fake_sample,
                labels=tf.zeros_like(disc_fake_sample)),
            name="cost_disc_fake_sample")
        disc_cost_sample_sum = tf.summary.scalar(
            "disc_sample", disc_cost_sample)

        fake_sample_grid = ops.get_grid_image_summary(x_fake_sample, 4)
        int_sum_op.append(tf.summary.image("fake sample", fake_sample_grid))

        sample_method = [
            disc_cost_sample, disc_cost_sample_sum, x_fake_sample
        ]
    else:
        sample_method = None

    grid_x_fake = ops.get_grid_image_summary(gen_model.x_fake, 4)
    int_sum_op.append(tf.summary.image("generated image", grid_x_fake))
    grid_x_real = ops.get_grid_image_summary(x_real, 4)
    int_sum_op.append(tf.summary.image("real image", grid_x_real))

    # The return values of the loss builders are not used here.
    if FLAGS.cgan:
        loss.classifier_loss(
            gen_model, disc_model, x_real, c_label, c_noise, weight=1.0)
    loss.hinge_loss(gen_model, disc_model, adv_weight=1.0)

    int_sum_op = tf.summary.merge(int_sum_op)

    ModelTrainer = trainer.base_gantrainer.BaseGANTrainer(
        int_sum_op=int_sum_op,
        dataloader=dl,
        FLAGS=FLAGS,
        gen_model=gen_model,
        disc_model=disc_model,
        gen_input=gen_input,
        x_real=x_real,
        label=c_label,
        sample_method=sample_method)

    command_controller = trainer.cmd_ctrl.CMDControl(ModelTrainer)
    command_controller.start_thread()

    print("=> Build train op")
    ModelTrainer.build_train_op()

    print("=> ##### Generator Variable #####")
    gen_model.print_trainble_vairables()
    print("=> ##### Discriminator Variable #####")
    disc_model.print_trainble_vairables()
    print("=> ##### All Variable #####")
    for v in tf.trainable_variables():
        print("%s\t\t\t\t%s" % (v.name, str(v.get_shape().as_list())))
    print("=> #### Moving Variable ####")
    for v in tf.global_variables():
        if "moving" in v.name:
            print("%s\t\t\t\t%s" % (v.name, str(v.get_shape().as_list())))

    ModelTrainer.init_training()
    ModelTrainer.train()
def main():
    """Build a multi-GPU (tower) hinge-loss GAN graph and train it.

    The input batch is split into ``NUM_GPU`` equal slices; one tower is
    built per GPU, the per-tower gradients are averaged, and a single Adam
    update per model is applied. Summaries are created from tower 0 only.
    """
    size = FLAGS.img_size

    if FLAGS.cgan:
        # the label file is npy format
        npy_dir = FLAGS.data_dir.replace(".zip", "") + '.npy'
    else:
        npy_dir = None

    if "celeb" in FLAGS.data_dir:
        dataset = dataloader.CelebADataset(
            FLAGS.data_dir, img_size=(size, size), npy_dir=npy_dir)
    elif "cityscapes" in FLAGS.data_dir:
        augmentations = Compose([
            RandomCrop(size * 4),
            Scale(size * 2),
            RandomRotate(10),
            RandomHorizontallyFlip(),
            RandomSizedCrop(size)
        ])
        dataset = dataloader.cityscapesLoader(
            FLAGS.data_dir,
            is_transform=True,
            augmentations=augmentations,
            img_size=(size, size))
        # Floor division keeps the batch size an integer on Python 3.
        FLAGS.batch_size //= 64
    else:
        dataset = dataloader.FileDataset(
            FLAGS.data_dir, npy_dir=npy_dir, img_size=(size, size))

    dl = dataloader.TFDataloader(
        dataset, FLAGS.batch_size, dataset.file_num // FLAGS.batch_size)

    # TF Input
    x_fake_sample = tf.placeholder(
        tf.float32, [None, size, size, 3], name="x_fake_sample")
    x_real = tf.placeholder(
        tf.float32, [None, size, size, 3], name="x_real")
    s_real = tf.placeholder(
        tf.float32, [None, size, size, 3], name='s_real')
    z_noise = tf.placeholder(tf.float32, [None, 128], name="z_noise")

    if FLAGS.cgan:
        c_noise = tf.placeholder(
            tf.float32, [None, dataset.class_num], name="c_noise")
        c_label = tf.placeholder(
            tf.float32, [None, dataset.class_num], name="c_label")
        gen_input = [z_noise, c_noise]
    else:
        gen_input = z_noise
        # Without this, `label=c_label` below raised NameError whenever the
        # script ran unconditionally.
        c_label = c_noise = None

    # look up the config function from lib.config module
    gen_model, disc_model = getattr(config, FLAGS.model_name)(
        FLAGS.img_size, dataset.class_num)
    gen_model.cbn_project = FLAGS.cbn_project
    gen_model.spectral_norm = FLAGS.sn
    disc_model.cbn_project = FLAGS.cbn_project
    disc_model.spectral_norm = FLAGS.sn

    # The summary ops are attached to the trainer after the towers exist.
    ModelTrainer = trainer.base_gantrainer.BaseGANTrainer(
        step_sum_op=None,
        int_sum_op=None,
        dataloader=dl,
        FLAGS=FLAGS,
        gen_model=gen_model,
        disc_model=disc_model,
        gen_input=gen_input,
        x_real=x_real,
        label=c_label)

    g_tower_grads = []
    d_tower_grads = []
    g_optim = tf.train.AdamOptimizer(
        learning_rate=ModelTrainer.g_lr, beta1=0., beta2=0.9)
    d_optim = tf.train.AdamOptimizer(
        learning_rate=ModelTrainer.d_lr, beta1=0., beta2=0.9)

    # Per-tower recorded tensors and the gradients of disc_fake w.r.t. them;
    # the name lists are filled from tower 0 only.
    grad_x = []
    grad_x_name = []
    xs = []
    x_name = []

    def tower(gpu_id, gen_input, x_real, c_label=None, c_noise=None,
              update_collection=None, loss_collection=None):
        """The loss function builder of gen and disc for one GPU tower.

        Returns:
            ``(gen_cost, disc_cost, sub_losses)``. ``sub_losses`` is
            ``[fake_cls, real_cls, gen, disc_real, disc_fake]`` when
            ``FLAGS.cgan`` is set, otherwise ``[gen, disc_real, disc_fake]``.
            ``loss_collection`` is accepted for compatibility and unused.
        """
        gen_model.cost = disc_model.cost = 0

        gen_model.set_phase("gpu%d" % gpu_id)
        x_fake = gen_model(gen_input, update_collection=update_collection)
        gen_model.set_reuse()
        gen_model.x_fake = x_fake

        disc_model.set_phase("gpu%d" % gpu_id)
        disc_real, real_cls_logits = disc_model(
            x_real, update_collection=update_collection)
        disc_model.set_reuse()

        # Drop whatever the real pass recorded so that only the fake pass
        # contributes recorded tensors.
        disc_model.recorded_tensors = []
        disc_model.recorded_names = []
        disc_fake, fake_cls_logits = disc_model(
            x_fake, update_collection=update_collection)

        disc_model.disc_real = disc_real
        disc_model.disc_fake = disc_fake
        disc_model.real_cls_logits = real_cls_logits
        disc_model.fake_cls_logits = fake_cls_logits

        cls_losses = []
        if FLAGS.cgan:
            # NOTE(review): other entry points in this file unpack
            # classifier_loss as (real_cls_cost, fake_cls_cost); the order
            # here is left unchanged - confirm which one is correct.
            fake_cls_cost, real_cls_cost = loss.classifier_loss(
                gen_model, disc_model, x_real, c_label, c_noise,
                weight=1.0 / dataset.class_num, summary=False)
            cls_losses = [fake_cls_cost, real_cls_cost]

        raw_gen_cost, raw_disc_real, raw_disc_fake = loss.hinge_loss(
            gen_model, disc_model, adv_weight=1.0, summary=False)

        gen_model.vars = [
            v for v in tf.trainable_variables() if gen_model.name in v.name
        ]
        disc_model.vars = [
            v for v in tf.trainable_variables() if disc_model.name in v.name
        ]

        g_grads = tf.gradients(
            gen_model.cost, gen_model.vars,
            colocate_gradients_with_ops=True)
        d_grads = tf.gradients(
            disc_model.cost, disc_model.vars,
            colocate_gradients_with_ops=True)
        g_grads = [
            tf.check_numerics(g, "G grad nan: " + str(g)) for g in g_grads
        ]
        d_grads = [
            tf.check_numerics(g, "D grad nan: " + str(g)) for g in d_grads
        ]
        g_tower_grads.append(g_grads)
        d_tower_grads.append(d_grads)

        tensors = gen_model.recorded_tensors + disc_model.recorded_tensors
        names = gen_model.recorded_names + disc_model.recorded_names
        if gpu_id == 0:
            x_name.extend(names)
        xs.append(tensors)

        # Gradients of the fake logits w.r.t. the recorded tensors, listed
        # in reverse recording order.
        names = names[::-1]
        tensors = tensors[::-1]
        grads = tf.gradients(
            disc_fake, tensors, colocate_gradients_with_ops=True)
        for n, g in zip(names, grads):
            print(n, g)
        grad_x.append(
            [tf.check_numerics(g, "BP nan: " + str(g)) for g in grads])
        if gpu_id == 0:
            grad_x_name.extend(names)

        # Reset the recorders for the next tower.
        disc_model.recorded_tensors = []
        disc_model.recorded_names = []
        gen_model.recorded_tensors = []
        gen_model.recorded_names = []

        sub_losses = cls_losses + [raw_gen_cost, raw_disc_real, raw_disc_fake]
        return gen_model.cost, disc_model.cost, sub_losses

    def average_gradients(tower_grads):
        """Average gradients position-wise across towers."""
        num_gpus = len(tower_grads)
        average_grads = []
        for per_var_grads in zip(*tower_grads):
            total = 0.0
            for g in per_var_grads:
                total += g
            average_grads.append(total / num_gpus)
        return average_grads

    # Sub-batch size handled by each tower.
    sbs = FLAGS.batch_size // NUM_GPU
    for i in range(NUM_GPU):
        # Only the first tower registers update ops.
        update_collection = None if i == 0 else "no_ops"
        lo, hi = sbs * i, sbs * (i + 1)

        with tf.device(tf.DeviceSpec(device_type="GPU", device_index=i)):
            if FLAGS.cgan:
                l1, l2, ot1 = tower(
                    i,
                    [z_noise[lo:hi], c_noise[lo:hi]],
                    x_real[lo:hi],
                    c_label[lo:hi],
                    c_noise[lo:hi],
                    update_collection=update_collection)
            else:
                l1, l2, ot1 = tower(
                    i,
                    z_noise[lo:hi],
                    x_real[lo:hi],
                    update_collection=update_collection)

        if i == 0:
            int_sum_op = []
            grid_x_fake = ops.get_grid_image_summary(gen_model.x_fake, 4)
            int_sum_op.append(
                tf.summary.image("generated image", grid_x_fake))
            grid_x_real = ops.get_grid_image_summary(x_real, 4)
            int_sum_op.append(tf.summary.image("real image", grid_x_real))

            step_sum_op = []
            # Names must line up with the list returned by tower().
            if FLAGS.cgan:
                sub_loss_names = [
                    "fake_cls", "real_cls", "gen", "disc_real", "disc_fake"
                ]
            else:
                sub_loss_names = ["gen", "disc_real", "disc_fake"]
            for n, l in zip(sub_loss_names, ot1):
                step_sum_op.append(tf.summary.scalar(n, l))
            step_sum_op.append(tf.summary.scalar("gen", gen_model.cost))
            step_sum_op.append(tf.summary.scalar("disc", disc_model.cost))

    with tf.device(tf.DeviceSpec(device_type="GPU", device_index=0)):
        g_grads = average_gradients(g_tower_grads)
        d_grads = average_gradients(d_tower_grads)

        gen_model.update_ops = tf.get_collection(
            tf.GraphKeys.UPDATE_OPS, scope=gen_model.name + "/")
        disc_model.update_ops = tf.get_collection(
            tf.GraphKeys.UPDATE_OPS, scope=disc_model.name + "/")

        print(gen_model.update_ops)
        print(disc_model.update_ops)

        with tf.control_dependencies(gen_model.update_ops):
            gen_model.train_op = g_optim.apply_gradients(
                list(zip(g_grads, gen_model.vars)))
        with tf.control_dependencies(disc_model.update_ops):
            disc_model.train_op = d_optim.apply_gradients(
                list(zip(d_grads, disc_model.vars)))

    if FLAGS.use_cache:
        # Extra discriminator pass over externally fed fake samples.
        disc_fake_sample = disc_model(x_fake_sample)[0]
        disc_cost_sample = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=disc_fake_sample,
                labels=tf.zeros_like(disc_fake_sample)),
            name="cost_disc_fake_sample")
        disc_cost_sample_sum = tf.summary.scalar(
            "disc_sample", disc_cost_sample)

        fake_sample_grid = ops.get_grid_image_summary(x_fake_sample, 4)
        int_sum_op.append(tf.summary.image("fake sample", fake_sample_grid))

        # NOTE(review): sample_method is built but never handed to the
        # trainer in this entry point - confirm whether that is intended.
        sample_method = [
            disc_cost_sample, disc_cost_sample_sum, x_fake_sample
        ]
    else:
        sample_method = None

    ModelTrainer.int_sum_op = tf.summary.merge(int_sum_op)
    ModelTrainer.step_sum_op = tf.summary.merge(step_sum_op)
    ModelTrainer.grad_x = grad_x
    ModelTrainer.grad_x_name = grad_x_name
    ModelTrainer.xs = xs
    ModelTrainer.x_name = x_name

    # command_controller = trainer.cmd_ctrl.CMDControl(ModelTrainer)
    # command_controller.start_thread()

    print("=> Build train op")
    # ModelTrainer.build_train_op()  # train ops are created above instead

    print("=> ##### Generator Variable #####")
    gen_model.print_variables()
    print("=> ##### Discriminator Variable #####")
    disc_model.print_variables()
    print("=> ##### All Variable #####")
    for v in tf.trainable_variables():
        print("%s\t\t\t\t%s" % (v.name, str(v.get_shape().as_list())))
    print("=> #### Moving Variable ####")
    for v in tf.global_variables():
        if "moving" in v.name:
            print("%s\t\t\t\t%s" % (v.name, str(v.get_shape().as_list())))

    ModelTrainer.init_training()
    ModelTrainer.train()
def main():
    """Build the sketch/image conditional GAN graph and start training.

    Steps, in order: load the training config, create the image-conditional
    encoder/discriminator pair, build the face/sketch/label data loader,
    declare placeholders, wire the generator and discriminator, accumulate
    the adversarial and sketch-reconstruction losses, build summaries and
    Adam train ops, then hand everything to ``CGANTrainer``.

    NOTE(review): the original indentation of this function was lost; the
    nesting below is reconstructed from syntax — confirm the extent of the
    ``tf.variable_scope`` block against the upstream file.
    """
    # get configuration
    print("Get configuration")
    TFLAGS = cfg.get_train_config(FLAGS.model_name)
    gen_model, gen_config, disc_model, disc_config = model.get_model(
        FLAGS.model_name, TFLAGS)
    # The models returned above are immediately replaced by this pair.
    gen_model = model.conditional_generator.ImageConditionalEncoder()
    disc_model = model.conditional_generator.ImageConditionalDeepDiscriminator()
    gen_model.name = "ImageConditionalEncoder"
    disc_model.name = "ImageConditionalDeepDiscriminator"
    print("Common length: %d" % gen_model.common_length)
    TFLAGS['AE_weight'] = 1.0
    TFLAGS['batch_size'] = 1
    TFLAGS['input_shape'] = [256, 256, 3]

    face_dataset = dataloader.CustomDataset(
        root_dir="/data/datasets/getchu/crop_character",
        npy_dir="/data/datasets/getchu/true_character.npy",
        preproc_kind="tanh",
        img_size=TFLAGS['input_shape'][:2],
        filter_data=TFLAGS['filter_data'],
        class_num=TFLAGS['c_len'],
        has_gray=False,
        disturb=False,
        flip=False)
    sketch_dataset = dataloader.CustomDataset(
        root_dir="/data/datasets/getchu/crop_sketch_character/",
        npy_dir=None,
        preproc_kind="tanh",
        img_size=TFLAGS['input_shape'][:2],
        filter_data=TFLAGS['filter_data'],
        class_num=TFLAGS['c_len'],
        has_gray=True,
        disturb=False,
        flip=False)
    label_dataset = dataloader.NumpyArrayDataset(
        np.load("/data/datasets/getchu/true_face.npy")[:, 1:])
    listsampler = dataloader.ListSampler(
        [face_dataset, sketch_dataset, label_dataset])
    dl = dataloader.CustomDataLoader(listsampler, TFLAGS['batch_size'], 4)

    # TF Input
    x_real = tf.placeholder(tf.float32, [None] + TFLAGS['input_shape'],
                            name="x_real")
    # Fix: this placeholder used to reuse the name "x_real" (copy-paste), so
    # TensorFlow silently renamed it; give it its own name.
    fake_x_sample = tf.placeholder(tf.float32,
                                   [None] + TFLAGS['input_shape'],
                                   name="fake_x_sample")
    # semantic segmentation
    s_real = tf.placeholder(tf.float32,
                            [None] + TFLAGS['input_shape'][:-1] + [1],
                            name='s_real')
    z_noise = tf.placeholder(tf.float32, [None, TFLAGS['z_len']],
                             name="z_noise")
    c_noise = tf.placeholder(tf.float32, [None, TFLAGS['c_len']],
                             name="c_noise")
    c_label = tf.placeholder(tf.float32, [None, TFLAGS['c_len']],
                             name="c_label")

    # control variables
    real_cls_weight = tf.placeholder(tf.float32, [], name="real_cls_weight")
    fake_cls_weight = tf.placeholder(tf.float32, [], name="fake_cls_weight")
    adv_weight = tf.placeholder(tf.float32, [], name="adv_weight")
    inc_length = tf.placeholder(tf.int32, [], name="inc_length")
    lr = tf.placeholder(tf.float32, [], name="lr")

    with tf.variable_scope(gen_model.name):
        gen_model.image_input = x_real
        gen_model.seg_input = s_real
        gen_model.noise_input = tf.concat([z_noise, c_noise], axis=1)
        seg_image, image_image, seg_seg, image_seg = gen_model.build_inference()
        # Named identities expose the feature tensors under fixed op names.
        seg_feat = tf.identity(gen_model.seg_feat, "seg_feat")
        image_feat = tf.identity(gen_model.image_feat, "image_feat")
        gen_model.x_fake = tf.identity(seg_image)

    disc_real_out = disc_model([x_real, s_real])
    disc_model.set_reuse()
    disc_fake_out = disc_model([seg_image, s_real])
    disc_fake_sketch_out = disc_model([x_real, image_seg])
    #disc_rec_out = disc_model([image_image, s_real])
    disc_fake_sample_out = disc_model([fake_x_sample, s_real])

    gen_model.cost = disc_model.cost = 0
    # Fix: `a.sum_op = b.sum_op = []` bound both models to the SAME list, so
    # each merged summary op also contained the other model's summaries.
    gen_model.sum_op = []
    disc_model.sum_op = []
    inter_sum_op = []

    # Select loss builder and model trainer
    print("Build training graph")
    # Naive GAN
    loss.naive_ganloss.func_gen_loss(disc_fake_out, adv_weight * 0.9,
                                     name="GenSeg", model=gen_model)
    loss.naive_ganloss.func_gen_loss(disc_fake_sketch_out, adv_weight * 0.1,
                                     name="GenSketch", model=gen_model)
    if FLAGS.cache:
        # Cached fake samples are fed through `fake_x_sample`.
        loss.naive_ganloss.func_disc_fake_loss(disc_fake_sample_out,
                                               adv_weight * 0.9,
                                               name="DiscFake",
                                               model=disc_model)
        loss.naive_ganloss.func_disc_fake_loss(disc_fake_sample_out,
                                               adv_weight * 0.1,
                                               name="DiscFakeSketch",
                                               model=disc_model)
    else:
        loss.naive_ganloss.func_disc_fake_loss(disc_fake_out,
                                               adv_weight * 0.9,
                                               name="DiscFake",
                                               model=disc_model)
        loss.naive_ganloss.func_disc_fake_loss(disc_fake_sketch_out,
                                               adv_weight * 0.1,
                                               name="DiscFakeSketch",
                                               model=disc_model)
    loss.naive_ganloss.func_disc_real_loss(disc_real_out, adv_weight,
                                           name="DiscGen", model=disc_model)

    # reconstruction of sketch
    rec_sketch_cost_, rec_sketch_sum_ = loss.common_loss.reconstruction_loss(
        image_seg, s_real, TFLAGS['AE_weight'], name="RecSketch")
    gen_model.cost += rec_sketch_cost_
    gen_model.sum_op.append(rec_sketch_sum_)

    gen_model.cost = tf.identity(gen_model.cost, "TotalGenCost")
    disc_model.cost = tf.identity(disc_model.cost, "TotalDiscCost")

    # total summary
    gen_model.sum_op.append(tf.summary.scalar("GenCost", gen_model.cost))
    disc_model.sum_op.append(tf.summary.scalar("DiscCost", disc_model.cost))

    # add interval summary
    edge_num = min(int(np.sqrt(TFLAGS['batch_size'])), 4)
    grid_x_fake = ops.get_grid_image_summary(seg_image, edge_num)
    inter_sum_op.append(tf.summary.image("generated image", grid_x_fake))
    grid_x_seg = ops.get_grid_image_summary(image_seg, edge_num)
    inter_sum_op.append(tf.summary.image("inv image", grid_x_seg))
    grid_x_real = ops.get_grid_image_summary(x_real, edge_num)
    inter_sum_op.append(tf.summary.image("real image", grid_x_real))
    grid_s_real = ops.get_grid_image_summary(s_real, edge_num)
    inter_sum_op.append(tf.summary.image("sketch image", grid_s_real))

    # merge summary op
    gen_model.sum_op = tf.summary.merge(gen_model.sum_op)
    disc_model.sum_op = tf.summary.merge(disc_model.sum_op)
    inter_sum_op = tf.summary.merge(inter_sum_op)

    print("=> Compute gradient")
    # get train op
    gen_model.get_trainable_variables()
    disc_model.get_trainable_variables()

    # Not applying update op will result in failure
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        # normal GAN train op
        gen_model.train_op = tf.train.AdamOptimizer(
            learning_rate=lr, beta1=0.5,
            beta2=0.9).minimize(gen_model.cost, var_list=gen_model.vars)
        disc_model.train_op = tf.train.AdamOptimizer(
            learning_rate=lr, beta1=0.5,
            beta2=0.9).minimize(disc_model.cost, var_list=disc_model.vars)

    print("=> ##### Generator Variable #####")
    gen_model.print_trainble_vairables()
    print("=> ##### Discriminator Variable #####")
    disc_model.print_trainble_vairables()
    print("=> ##### All Variable #####")
    for v in tf.trainable_variables():
        print("%s" % v.name)

    ctrl_weight = {
        "real_cls": real_cls_weight,
        "fake_cls": fake_cls_weight,
        "adv": adv_weight,
        "lr": lr,
        "inc_length": inc_length
    }

    trainer_feed = {
        "gen_model": gen_model,
        "disc_model": disc_model,
        "ctrl_weight": ctrl_weight,
        "dataloader": dl,
        "int_sum_op": inter_sum_op,
        "gpu_mem": TFLAGS['gpu_mem'],
        "FLAGS": FLAGS,
        "TFLAGS": TFLAGS
    }

    Trainer = trainer.cgantrainer.CGANTrainer
    # Set on the class (not the instance), exactly as before.
    Trainer.use_cache = False
    trainer_feed.update({
        "inputs": [x_real, s_real, c_label],
        "noise": z_noise,
        "cond": c_noise
    })

    ModelTrainer = Trainer(**trainer_feed)
    ModelTrainer.fake_x_sample = fake_x_sample

    command_controller = trainer.cmd_ctrl.CMDControl(ModelTrainer)
    command_controller.start_thread()

    ModelTrainer.init_training()
    ModelTrainer.train()
def main():
    """Build the sketch-to-face conditional generator graph and train it.

    Creates an ``ImageConditionalDeepGenerator2`` and a ``GoodDiscriminator``,
    loads paired sketch/face/label data, adds the GAN loss selected by
    ``TFLAGS['gan_loss']`` ("dra", "wass" or "naive"), an optional real-label
    classification loss (``FLAGS.cgan``) and a reconstruction loss, builds
    summaries and Adam train ops, then runs ``CGANTrainer``. The
    discriminator weights are loaded from an ``.npz`` file after
    ``init_training``.

    NOTE(review): the original indentation of this function was lost; the
    nesting below is reconstructed from syntax.
    """
    # get configuration
    print("Get configuration")
    TFLAGS = cfg.get_train_config(FLAGS.model_name)
    gen_config = cfg.good_generator.goodmodel_gen({})
    gen_config['name'] = "CondDeepGenerator"
    gen_model = model.conditional_generator.ImageConditionalDeepGenerator2(
        **gen_config)
    disc_config = cfg.good_generator.goodmodel_disc({})
    disc_config['norm_mtd'] = None
    disc_model = model.good_generator.GoodDiscriminator(**disc_config)

    TFLAGS['batch_size'] = 1
    TFLAGS['AE_weight'] = 10.0
    TFLAGS['side_noise'] = FLAGS.side_noise
    TFLAGS['input_shape'] = [128, 128, 3]

    # Data Preparation
    # raw image is 0~255
    print("Get dataset")
    face_dataset = dataloader.CustomDataset(
        root_dir="/data/datasets/getchu/true_face",
        npy_dir=None,
        preproc_kind="tanh",
        img_size=TFLAGS['input_shape'][:2],
        filter_data=TFLAGS['filter_data'],
        class_num=TFLAGS['c_len'],
        disturb=False,
        flip=False)
    sketch_dataset = dataloader.CustomDataset(
        root_dir="/data/datasets/getchu/sketch_face/",
        npy_dir=None,
        preproc_kind="tanh",
        img_size=TFLAGS['input_shape'][:2],
        filter_data=TFLAGS['filter_data'],
        class_num=TFLAGS['c_len'],
        disturb=False,
        flip=False)
    label_dataset = dataloader.NumpyArrayDataset(
        np.load("/data/datasets/getchu/true_face.npy")[:, 1:])
    listsampler = dataloader.ListSampler(
        [sketch_dataset, face_dataset, label_dataset])
    dl = dataloader.CustomDataLoader(listsampler, TFLAGS['batch_size'], 4)

    # TF Input
    x_real = tf.placeholder(tf.float32, [None] + TFLAGS['input_shape'],
                            name="x_real")
    s_real = tf.placeholder(tf.float32,
                            [None] + TFLAGS['input_shape'][:-1] + [1],
                            name='s_real')
    # Fix: this placeholder used to reuse the name "x_real" (copy-paste).
    fake_x_sample = tf.placeholder(tf.float32,
                                   [None] + TFLAGS['input_shape'],
                                   name="fake_x_sample")
    noise_A = tf.placeholder(tf.float32, [None, TFLAGS['z_len']],
                             name="noise_A")
    noise_B = tf.placeholder(tf.float32, [None, TFLAGS['z_len']],
                             name="noise_B")
    label_A = tf.placeholder(tf.float32, [None, TFLAGS['c_len']],
                             name="label_A")
    label_B = tf.placeholder(tf.float32, [None, TFLAGS['c_len']],
                             name="label_B")
    # c label is for real samples
    c_label = tf.placeholder(tf.float32, [None, TFLAGS['c_len']],
                             name="c_label")
    c_noise = tf.placeholder(tf.float32, [None, TFLAGS['c_len']],
                             name="c_noise")

    # control variables
    real_cls_weight = tf.placeholder(tf.float32, [], name="real_cls_weight")
    fake_cls_weight = tf.placeholder(tf.float32, [], name="fake_cls_weight")
    adv_weight = tf.placeholder(tf.float32, [], name="adv_weight")
    lr = tf.placeholder(tf.float32, [], name="lr")

    side_noise_A = tf.concat([noise_A, c_label], axis=1, name='side_noise_A')

    if TFLAGS['side_noise']:
        gen_model.side_noise = side_noise_A
    x_fake = gen_model([s_real])
    gen_model.set_reuse()
    gen_model.x_fake = x_fake

    if FLAGS.cgan:
        disc_model.is_cgan = True
    disc_model.disc_real_out, disc_model.real_cls_logits = disc_model(
        x_real)[:2]
    disc_model.set_reuse()
    disc_model.disc_fake_out, disc_model.fake_cls_logits = disc_model(
        x_fake)[:2]

    gen_model.cost = disc_model.cost = 0
    # Fix: `a.sum_op = b.sum_op = []` bound both models to the SAME list, so
    # each merged summary op also contained the other model's summaries.
    gen_model.sum_op = []
    disc_model.sum_op = []
    inter_sum_op = []

    if TFLAGS['gan_loss'] == "dra":
        # naive GAN loss
        gen_cost_, disc_cost_, gen_sum_, disc_sum_ = \
            loss.naive_ganloss.get_naive_ganloss(
                gen_model, disc_model, adv_weight)
        gen_model.cost += gen_cost_
        disc_model.cost += disc_cost_
        gen_model.sum_op.extend(gen_sum_)
        disc_model.sum_op.extend(disc_sum_)
        gen_model.gen_cost = gen_cost_
        disc_model.disc_cost = disc_cost_

        # dragan loss
        disc_cost_, disc_sum_ = loss.dragan_loss.get_dragan_loss(
            disc_model, x_real, TFLAGS['gp_weight'])
        disc_model.cost += disc_cost_
        disc_model.sum_op.append(disc_sum_)

    elif TFLAGS['gan_loss'] == "wass":
        gen_cost_, disc_cost_, gen_sum_, disc_sum_ = \
            loss.wass_ganloss.wass_gan_loss(
                gen_model, disc_model, x_real, x_fake)
        gen_model.cost += gen_cost_
        disc_model.cost += disc_cost_
        gen_model.sum_op.extend(gen_sum_)
        disc_model.sum_op.extend(disc_sum_)

    elif TFLAGS['gan_loss'] == "naive":
        # naive GAN loss
        gen_cost_, disc_cost_, gen_sum_, disc_sum_ = \
            loss.naive_ganloss.get_naive_ganloss(
                gen_model, disc_model, adv_weight, lsgan=True)
        gen_model.cost += gen_cost_
        disc_model.cost += disc_cost_
        gen_model.sum_op.extend(gen_sum_)
        disc_model.sum_op.extend(disc_sum_)
        gen_model.gen_cost = gen_cost_
        disc_model.disc_cost = disc_cost_

    if FLAGS.cgan:
        real_cls = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=disc_model.real_cls_logits, labels=c_label),
            name="real_cls_reduce_mean") * real_cls_weight

        #fake_cls = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        #    logits=disc_model.fake_cls_logits,
        #    labels=c_label), name="fake_cls_reduce_mean") * fake_cls_weight

        disc_model.cost += real_cls  # + fake_cls
        disc_model.sum_op.extend([tf.summary.scalar("RealCls", real_cls)])  #,
        #tf.summary.scalar("FakeCls", fake_cls)])

    gen_cost_, ae_loss_sum_ = loss.common_loss.reconstruction_loss(
        x_fake, x_real, TFLAGS['AE_weight'])
    gen_model.sum_op.append(ae_loss_sum_)
    gen_model.cost += gen_cost_

    gen_model.cost = tf.identity(gen_model.cost, "TotalGenCost")
    disc_model.cost = tf.identity(disc_model.cost, "TotalDiscCost")

    # total summary
    gen_model.sum_op.append(tf.summary.scalar("GenCost", gen_model.cost))
    disc_model.sum_op.append(tf.summary.scalar("DiscCost", disc_model.cost))

    # add interval summary
    edge_num = min(int(np.sqrt(TFLAGS['batch_size'])), 4)
    grid_x_fake = ops.get_grid_image_summary(x_fake, edge_num)
    inter_sum_op.append(tf.summary.image("generated image", grid_x_fake))
    grid_x_real = ops.get_grid_image_summary(x_real, edge_num)
    # Debug print of the real-image grid's max/min each time it is evaluated.
    grid_x_real = tf.Print(
        grid_x_real,
        [tf.reduce_max(grid_x_real), tf.reduce_min(grid_x_real)],
        "Real")
    inter_sum_op.append(tf.summary.image("real image", grid_x_real))
    inter_sum_op.append(
        tf.summary.image("sketch image",
                         ops.get_grid_image_summary(s_real, edge_num)))

    # merge summary op
    gen_model.sum_op = tf.summary.merge(gen_model.sum_op)
    disc_model.sum_op = tf.summary.merge(disc_model.sum_op)
    inter_sum_op = tf.summary.merge(inter_sum_op)

    # get train op
    gen_model.get_trainable_variables()
    disc_model.get_trainable_variables()

    # Not applying update op will result in failure
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        gen_model.train_op = tf.train.AdamOptimizer(
            learning_rate=lr, beta1=0.5,
            beta2=0.9).minimize(gen_model.cost, var_list=gen_model.vars)
        disc_model.train_op = tf.train.AdamOptimizer(
            learning_rate=lr, beta1=0.5,
            beta2=0.9).minimize(disc_model.cost, var_list=disc_model.vars)

    print("=> ##### Generator Variable #####")
    gen_model.print_trainble_vairables()
    print("=> ##### Discriminator Variable #####")
    disc_model.print_trainble_vairables()
    print("=> ##### All Variable #####")
    for v in tf.trainable_variables():
        print("%s" % v.name)

    ctrl_weight = {
        "real_cls": real_cls_weight,
        "fake_cls": fake_cls_weight,
        "adv": adv_weight,
        "lr": lr
    }

    trainer_feed = {
        "gen_model": gen_model,
        "disc_model": disc_model,
        "noise": noise_A,
        "ctrl_weight": ctrl_weight,
        "dataloader": dl,
        "int_sum_op": inter_sum_op,
        "gpu_mem": TFLAGS['gpu_mem'],
        "FLAGS": FLAGS,
        "TFLAGS": TFLAGS
    }

    Trainer = trainer.cgantrainer.CGANTrainer
    trainer_feed.update({
        "inputs": [s_real, x_real, c_label],
    })

    ModelTrainer = Trainer(**trainer_feed)

    command_controller = trainer.cmd_ctrl.CMDControl(ModelTrainer)
    command_controller.start_thread()

    ModelTrainer.init_training()
    # Discriminator weights are loaded after initialisation, using the
    # trainer's session.
    disc_model.load_from_npz('success/goodmodel_dragan_anime1_disc.npz',
                             ModelTrainer.sess)
    ModelTrainer.train()
def main():
    """Build the stroke-mask GAN graph and train it with ``BaseGANTrainer``.

    Picks the dataset from ``FLAGS.data_dir``, declares placeholders, builds
    the generator/discriminator named by ``FLAGS.model_name``, adds the mask
    regularisers (cosine diversity, restricted mask-weight deviation, total
    variation), the optional classifier loss (``FLAGS.cgan``) and the hinge
    GAN loss, then builds train ops and starts training.

    NOTE(review): the original indentation of this function was lost; the
    nesting below is reconstructed from syntax (in particular,
    ``loss.hinge_loss`` is assumed to be applied regardless of
    ``FLAGS.cgan``).
    """
    size = FLAGS.img_size

    if FLAGS.cgan:
        npy_dir = FLAGS.data_dir.replace(".zip", "") + '.npy'
    else:
        npy_dir = None

    if "celeb" in FLAGS.data_dir:
        dataset = dataloader.CelebADataset(FLAGS.data_dir,
                                           img_size=(size, size),
                                           npy_dir=npy_dir)
    else:
        dataset = dataloader.FileDataset(FLAGS.data_dir,
                                         npy_dir=npy_dir,
                                         img_size=(size, size),
                                         shuffle=True)
    dl = DataLoader(dataset, batch_size=FLAGS.batch_size, shuffle=True,
                    num_workers=NUM_WORKER)

    # TF Input
    x_fake_sample = tf.placeholder(tf.float32, [None, size, size, 3],
                                   name="x_fake_sample")
    x_real = tf.placeholder(tf.float32, [None, size, size, 3], name="x_real")
    s_real = tf.placeholder(tf.float32, [None, size, size, 3], name='s_real')
    z_noise = tf.placeholder(tf.float32, [None, 128], name="z_noise")

    if FLAGS.cgan:
        c_noise = tf.placeholder(tf.float32, [None, dataset.class_num],
                                 name="c_noise")
        c_label = tf.placeholder(tf.float32, [None, dataset.class_num],
                                 name="c_label")
        gen_input = [z_noise, c_noise]
    else:
        gen_input = z_noise
        # Fix: `c_label` is passed to the trainer below unconditionally;
        # without these defaults the non-cgan path raised NameError.
        c_noise = None
        c_label = None

    gen_model, disc_model = getattr(config, FLAGS.model_name)(
        FLAGS.img_size, dataset.class_num)
    gen_model.mask_num = FLAGS.mask_num
    gen_model.cbn_project = FLAGS.cbn_project

    x_fake = gen_model(gen_input, update_collection=None)
    gen_model.set_reuse()
    gen_model.x_fake = x_fake

    disc_real, real_cls_logits = disc_model(x_real,
                                            update_collection="no_ops")
    disc_model.set_reuse()
    disc_fake, fake_cls_logits = disc_model(x_fake, update_collection=None)
    disc_model.disc_real = disc_real
    disc_model.disc_fake = disc_fake
    disc_model.real_cls_logits = real_cls_logits
    disc_model.fake_cls_logits = fake_cls_logits

    int_sum_op = []

    if FLAGS.use_cache:
        # Discriminator loss on externally supplied (cached) fake samples.
        disc_fake_sample = disc_model(x_fake_sample)[0]
        disc_cost_sample = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=disc_fake_sample,
                labels=tf.zeros_like(disc_fake_sample)),
            name="cost_disc_fake_sample")
        disc_cost_sample_sum = tf.summary.scalar("disc_sample",
                                                 disc_cost_sample)

        fake_sample_grid = ops.get_grid_image_summary(x_fake_sample, 4)
        int_sum_op.append(tf.summary.image("fake sample", fake_sample_grid))

        sample_method = [disc_cost_sample, disc_cost_sample_sum,
                         x_fake_sample]
    else:
        sample_method = None

    print("=> Mask num " + str(gen_model.overlapped_mask.get_shape()))

    # diverse mask
    diverse_loss, diverse_loss_sum = loss.cosine_diverse_distribution(
        gen_model.overlapped_mask)

    # make sure mask is not eliminated
    mask_weight = tf.reduce_sum(gen_model.overlapped_mask, [1, 2])
    mask_num = mask_weight.get_shape().as_list()[-1]
    avg_map_weight = (size ** 2) / float(mask_num)
    diff_map = tf.abs(mask_weight - avg_map_weight)
    # Only deviations larger than twice the average weight are penalised.
    restricted_diff_map = tf.nn.relu(diff_map - 2 * avg_map_weight)
    restricted_var_loss = 1e-3 * tf.reduce_mean(restricted_diff_map)
    var_loss_sum = tf.summary.scalar("variance loss", restricted_var_loss)

    # semantic
    """
    uniform_loss = 0
    vgg_net = model.classifier.MyVGG16("lib/tensorflowvgg/vgg16.npy")
    vgg_net.build(tf.image.resize_bilinear(x_fake, (224, 224)))
    sf = vgg_net.conv3_3
    mask_shape = sf.get_shape().as_list()[1:3]
    print("=> VGG feature shape: " + str(mask_shape))
    diff_maps = []
    for i in range(mask_num):
        mask = tf.image.resize_bilinear(gen_model.overlapped_mask[:, :, :, i:i+1], mask_shape) # (batch, size, size, 1)
        mask = mask / tf.reduce_sum(mask, [1, 2], keepdims=True)
        expected_feature = tf.reduce_sum(mask * sf, [1, 2], keepdims=True) # (batch, 1, 1, 256)
        diff_map = tf.reduce_mean(tf.abs(mask * (sf - expected_feature)), [3]) # (batch, size, size)
        diff_maps.append(diff_map[0] / tf.reduce_max(diff_map[0]))
        restricted_diff_map = diff_map # TODO: add margin
        uniform_loss += 1e-3 * tf.reduce_mean(tf.reduce_sum(diff_map, [1, 2]))
    uniform_loss_sum = tf.summary.scalar("uniform loss", uniform_loss)
    """

    # smooth mask
    tv_loss = tf.reduce_mean(
        tf.image.total_variation(gen_model.overlapped_mask)) / (size ** 2)
    tv_sum = tf.summary.scalar("TV loss", tv_loss)

    gen_model.cost += diverse_loss + tv_loss + restricted_var_loss

    gen_model.sum_op.extend([tv_sum, diverse_loss_sum, var_loss_sum])

    edge_num = int(np.sqrt(gen_model.overlapped_mask.get_shape().as_list()[-1]))
    # Masks of the first batch element, moved to the leading axis for the
    # grid helper.
    mask_seq = tf.transpose(gen_model.overlapped_mask[0], [2, 0, 1])
    grid_mask = tf.expand_dims(
        ops.get_grid_image_summary(mask_seq, edge_num), -1)
    int_sum_op.append(tf.summary.image("stroke mask", grid_mask))

    #uniform_diff_map = tf.expand_dims(ops.get_grid_image_summary(tf.stack(diff_maps, 0), edge_num), -1)
    #int_sum_op.append(tf.summary.image("uniform diff map", uniform_diff_map))

    grid_x_fake = ops.get_grid_image_summary(gen_model.x_fake, 4)
    int_sum_op.append(tf.summary.image("generated image", grid_x_fake))

    grid_x_real = ops.get_grid_image_summary(x_real, 4)
    int_sum_op.append(tf.summary.image("real image", grid_x_real))

    if FLAGS.cgan:
        loss.classifier_loss(gen_model, disc_model, x_real, c_label, c_noise,
                             weight=1.0)

    loss.hinge_loss(gen_model, disc_model, adv_weight=1.0)

    int_sum_op = tf.summary.merge(int_sum_op)

    ModelTrainer = trainer.base_gantrainer.BaseGANTrainer(
        int_sum_op=int_sum_op,
        dataloader=dl,
        FLAGS=FLAGS,
        gen_model=gen_model,
        disc_model=disc_model,
        gen_input=gen_input,
        x_real=x_real,
        label=c_label,
        sample_method=sample_method)

    #command_controller = trainer.cmd_ctrl.CMDControl(ModelTrainer)
    #command_controller.start_thread()

    print("=> Build train op")
    ModelTrainer.build_train_op()

    print("=> ##### Generator Variable #####")
    gen_model.print_trainble_vairables()
    print("=> ##### Discriminator Variable #####")
    disc_model.print_trainble_vairables()
    print("=> ##### All Variable #####")
    for v in tf.trainable_variables():
        print("%s" % v.name)

    ModelTrainer.init_training()
    ModelTrainer.train()
def build_train(self, input):
    """Build the discriminator's training losses and summaries.

    Args:
        input: a 3-item sequence ``(fake_data, real_data, label)``.

    Sets on ``self``: ``disc_fake``/``disc_real`` and their ``*_cls``
    outputs, ``vars``, ``gen_cost``, ``disc_cost``, ``cls_cost``,
    ``tot_loss``, ``sum_op`` and ``sum_interval_op``; in WGAN mode also
    ``disc_interp`` and ``gradient_penalty``, otherwise ``disc_cost_fake``
    and ``disc_cost_real``.

    NOTE(review): the original indentation of this method was lost; the
    nesting below is reconstructed from syntax.
    """
    # ADD label
    fake_data, real_data, label = input

    with tf.variable_scope(self.name):
        self.disc_fake, self.disc_fake_cls = self.build_inference(fake_data)
        self.reuse = True
        self.disc_real, self.disc_real_cls = self.build_inference(real_data)

    self.vars = [v for v in tf.trainable_variables() if self.name in v.name]

    self.tot_loss = 0

    if self.is_wgan:
        self.gen_cost = -tf.reduce_mean(self.disc_fake)
        self.disc_cost = tf.reduce_mean(self.disc_fake) - tf.reduce_mean(
            self.disc_real)

        alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1],
                                  minval=0.,
                                  maxval=1.)
        differences = real_data - fake_data
        interpolates = fake_data + alpha * differences  #tf.multiply(alpha, differences)
        with tf.variable_scope(self.name):
            self.disc_interp = self.build_inference(interpolates)

        # Fix: build_inference returns (critic, classifier) — it is unpacked
        # into two values above — and tf.gradients sums over every tensor it
        # is given. Penalise the critic output only.
        critic_interp = self.disc_interp
        if isinstance(critic_interp, (tuple, list)):
            critic_interp = critic_interp[0]
        gradients = tf.gradients(critic_interp, [interpolates])[0]

        # Fix: the per-sample gradient norm must sum over ALL non-batch
        # axes; the old `reduction_indices=[1]` reduced a single axis of the
        # rank-4 gradient.
        rank = gradients.get_shape().ndims
        sample_axes = list(range(1, rank)) if rank is not None else [1, 2, 3]
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(gradients), axis=sample_axes))
        self.gradient_penalty = self.lambda_gp * tf.reduce_mean(
            (slopes - 1.)**2)

        self.tot_loss += self.gradient_penalty
    else:
        # normal disc
        self.gen_cost = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=self.disc_fake,
                labels=tf.ones_like(self.disc_fake)))

        self.disc_cost_fake = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=self.disc_fake,
                labels=tf.zeros_like(self.disc_fake)))
        self.disc_cost_real = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=self.disc_real,
                labels=tf.ones_like(self.disc_real)))

        self.disc_cost = self.disc_cost_fake + self.disc_cost_real

    # Classification loss uses the same `label` for fake and real batches.
    self.cls_cost = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=self.disc_fake_cls,
                                                labels=label))
    self.cls_cost += tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=self.disc_real_cls,
                                                labels=label))

    reg_losses = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))

    self.tot_loss += self.disc_cost + reg_losses + self.cls_cost

    # The scalar summaries are collected by merge_all() below, so their
    # return values do not need to be kept.
    with tf.name_scope(self.name + "_Loss"):
        tf.summary.scalar("L2 regularize", reg_losses)
        tf.summary.scalar("classification", self.cls_cost)

        if self.is_wgan:
            tf.summary.scalar("gradient penalty", self.gradient_penalty)
            tf.summary.scalar("discriminator loss", self.disc_cost)
        else:
            tf.summary.scalar("discriminator loss real", self.disc_cost_real)
            tf.summary.scalar("discriminator loss fake", self.disc_cost_fake)

        tf.summary.scalar("generator loss", self.gen_cost)

    # merge_all() picks up every summary registered in the graph so far.
    self.sum_op = tf.summary.merge_all()

    # grid summary 4x4
    show_img = ops.get_grid_image_summary(fake_data, 4)
    gen_image_sum = tf.summary.image("generated", show_img + 1)
    self.sum_interval_op = tf.summary.merge([gen_image_sum])
def main():
    """Build the sketch<->anime cycle-translation graph and train it.

    Domain A is the sketch face (1 channel), domain B the anime face
    (3 channels). Creates two discriminators (``DiscA``, ``DiscB``) and two
    translators (``TransF``: A->B, ``TransB``: B->A), adds least-squares
    adversarial losses, cycle-consistency losses and a real-label
    classification loss on ``DiscB``, builds per-model Adam train ops and
    runs ``CycleTrainer``. Discriminator weights are loaded from ``.npz``
    files after ``init_training``.

    NOTE(review): the original indentation of this function was lost; the
    nesting below is reconstructed from syntax.
    """
    # get configuration
    print("Get configuration")
    TFLAGS = cfg.get_train_config(FLAGS.model_name)
    TFLAGS['AE_weight'] = 10.0
    TFLAGS['side_noise'] = False

    # A: sketch face, B: anime face
    # Gen and Disc usually init from pretrained model
    print("=> Building discriminator")
    disc_config = cfg.good_generator.goodmodel_disc({})
    disc_config['name'] = "DiscA"
    DiscA = model.good_generator.GoodDiscriminator(**disc_config)
    disc_config['name'] = "DiscB"
    DiscB = model.good_generator.GoodDiscriminator(**disc_config)

    gen_config = cfg.good_generator.goodmodel_gen({})
    gen_config['name'] = 'GenA'
    gen_config['out_dim'] = 3
    GenA = model.good_generator.GoodGenerator(**gen_config)
    gen_config['name'] = 'GenB'
    gen_config['out_dim'] = 1
    GenB = model.good_generator.GoodGenerator(**gen_config)

    gen_config = {}
    gen_config['name'] = 'TransForward'
    gen_config['out_dim'] = 3
    TransF = model.conditional_generator.ImageConditionalDeepGenerator2(
        **gen_config)
    gen_config['name'] = 'TransBackward'
    gen_config['out_dim'] = 1
    TransB = model.conditional_generator.ImageConditionalDeepGenerator2(
        **gen_config)

    # GenA / GenB are constructed above but are not part of the trained set.
    models = [DiscA, DiscB, TransF, TransB]
    model_names = ["DiscA", "DiscB", "TransF", "TransB"]

    TFLAGS['batch_size'] = 1
    TFLAGS['input_shape'] = [128, 128, 3]

    # Data Preparation
    # raw image is 0~255
    print("Get dataset")
    face_dataset = dataloader.CustomDataset(
        root_dir="/data/datasets/getchu/true_face",
        npy_dir=None,
        preproc_kind="tanh",
        img_size=TFLAGS['input_shape'][:2],
        filter_data=TFLAGS['filter_data'],
        class_num=TFLAGS['c_len'],
        disturb=False,
        flip=False)
    sketch_dataset = dataloader.CustomDataset(
        root_dir="/data/datasets/getchu/sketch_face/",
        npy_dir=None,
        preproc_kind="tanh",
        img_size=TFLAGS['input_shape'][:2],
        filter_data=TFLAGS['filter_data'],
        class_num=TFLAGS['c_len'],
        disturb=False,
        flip=False)
    label_dataset = dataloader.NumpyArrayDataset(
        np.load("/data/datasets/getchu/true_face.npy")[:, 1:])
    listsampler = dataloader.ListSampler(
        [sketch_dataset, face_dataset, label_dataset])
    dl = dataloader.CustomDataLoader(listsampler, TFLAGS['batch_size'], 4)

    # TF Input
    real_A_sample = tf.placeholder(tf.float32,
                                   [None] + TFLAGS['input_shape'][:-1] + [1],
                                   name='real_A_sample')
    real_B_sample = tf.placeholder(tf.float32,
                                   [None] + TFLAGS['input_shape'],
                                   name="real_B_sample")
    fake_A_sample = tf.placeholder(tf.float32,
                                   [None] + TFLAGS['input_shape'][:-1] + [1],
                                   name='fake_A_sample')
    fake_B_sample = tf.placeholder(tf.float32,
                                   [None] + TFLAGS['input_shape'],
                                   name="fake_B_sample")

    noise_A = tf.placeholder(tf.float32, [None, TFLAGS['z_len']],
                             name="noise_A")
    noise_B = tf.placeholder(tf.float32, [None, TFLAGS['z_len']],
                             name="noise_B")
    label_A = tf.placeholder(tf.float32, [None, TFLAGS['c_len']],
                             name="label_A")
    label_B = tf.placeholder(tf.float32, [None, TFLAGS['c_len']],
                             name="label_B")
    # c label is for real samples
    c_label = tf.placeholder(tf.float32, [None, TFLAGS['c_len']],
                             name="c_label")

    # control variables
    real_cls_weight = tf.placeholder(tf.float32, [], name="real_cls_weight")
    fake_cls_weight = tf.placeholder(tf.float32, [], name="fake_cls_weight")
    adv_weight = tf.placeholder(tf.float32, [], name="adv_weight")
    lr = tf.placeholder(tf.float32, [], name="lr")

    inter_sum_op = []

    ### build graph
    """
    fake_A = GenA(noise_A); GenA.set_reuse() # GenA and GenB only used once
    fake_B = GenB(noise_B); GenB.set_reuse()

    # transF and TransB used three times
    trans_real_A = TransF(real_A); TransF.set_reuse() # domain B
    trans_fake_A = TransF(fake_A) # domain B
    trans_real_B = TransB(real_B); TransB.set_reuse() # domain A
    trans_fake_B = TransB(fake_B) # domain A
    rec_real_A = TransB(trans_real_A)
    rec_fake_A = TransB(trans_fake_A)
    rec_real_B = TransF(trans_real_B)
    rec_fake_B = TransF(trans_fake_B)

    # DiscA and DiscB reused many times
    disc_real_A, cls_real_A = DiscA(x_real)[:2]; DiscA.set_reuse()
    disc_fake_A, cls_fake_A = DiscA(x_fake)[:2];
    disc_trans_real_B, cls_trans_real_B = DiscA(trans_real_B)[:2]
    disc_trans_fake_B, cls_trans_fake_B = DiscA(trans_real_B)[:2]
    disc_rec_real_A, cls_rec_real_A = DiscA(rec_real_A)[:2]
    disc_rec_fake_A, cls_rec_fake_A = DiscA(rec_real_A)[:2]
    disc_real_B, cls_real_B = DiscB(x_real)[:2]; DiscB.set_reuse()
    disc_fake_B, cls_fake_B = DiscB(x_fake)[:2];
    disc_trans_real_A, cls_trans_real_A = DiscB(trans_real_A)[:2]
    disc_trans_fake_A, cls_trans_fake_A = DiscB(trans_real_A)[:2]
    disc_rec_real_B, cls_rec_real_B = DiscB(rec_real_B)[:2]
    disc_rec_fake_B, cls_rec_fake_B = DiscB(rec_real_B)[:2]
    """

    side_noise_A = tf.concat([noise_A, label_A], axis=1, name='side_noise_A')
    side_noise_B = tf.concat([noise_B, label_B], axis=1, name='side_noise_B')

    trans_real_A = TransF(real_A_sample)
    TransF.set_reuse()  # domain B
    trans_real_B = TransB(real_B_sample)
    TransB.set_reuse()  # domain A
    TransF.trans_real_A = trans_real_A
    TransB.trans_real_B = trans_real_B
    rec_real_A = TransB(trans_real_A)
    rec_real_B = TransF(trans_real_B)

    # start fake building
    if TFLAGS['side_noise']:
        TransF.side_noise = side_noise_A
        TransB.side_noise = side_noise_B
        trans_fake_A = TransF(real_A_sample)
        trans_fake_B = TransB(real_B_sample)

    DiscA.fake_sample = fake_A_sample
    DiscB.fake_sample = fake_B_sample
    disc_fake_A_sample, cls_fake_A_sample = DiscA(fake_A_sample)[:2]
    DiscA.set_reuse()
    disc_fake_B_sample, cls_fake_B_sample = DiscB(fake_B_sample)[:2]
    DiscB.set_reuse()
    disc_real_A_sample, cls_real_A_sample = DiscA(real_A_sample)[:2]
    disc_real_B_sample, cls_real_B_sample = DiscB(real_B_sample)[:2]
    disc_trans_real_A, cls_trans_real_A = DiscB(trans_real_A)[:2]
    disc_trans_real_B, cls_trans_real_B = DiscA(trans_real_B)[:2]

    if TFLAGS['side_noise']:
        # Fix: these passes previously fed `trans_real_*` again (copy-paste),
        # so the side-noise translations were never discriminated.
        disc_trans_fake_A, cls_trans_fake_A = DiscB(trans_fake_A)[:2]
        disc_trans_fake_B, cls_trans_fake_B = DiscA(trans_fake_B)[:2]

    def disc_loss(disc_fake_out,
                  disc_real_out,
                  disc_model,
                  adv_weight=1.0,
                  name="NaiveDisc",
                  acc=True):
        """Least-squares discriminator loss.

        With ``acc`` true the cost and summaries are accumulated on
        ``disc_model`` and nothing is returned; otherwise
        ``(cost, summaries)`` is returned.
        """
        softL_c = 0.05
        with tf.name_scope(name):
            # The soft real target is sampled once, at graph-build time.
            raw_disc_cost_real = tf.reduce_mean(
                tf.square(disc_real_out - tf.ones_like(disc_real_out) *
                          np.abs(np.random.normal(1.0, softL_c))),
                name="raw_disc_cost_real")
            raw_disc_cost_fake = tf.reduce_mean(
                tf.square(disc_fake_out - tf.zeros_like(disc_fake_out)),
                name="raw_disc_cost_fake")

            disc_cost = tf.multiply(
                adv_weight, (raw_disc_cost_fake + raw_disc_cost_real) / 2,
                name="disc_cost")

            disc_fake_sum = [
                tf.summary.scalar("DiscFakeRaw", raw_disc_cost_fake),
                tf.summary.scalar("DiscRealRaw", raw_disc_cost_real)
            ]

            if acc:
                disc_model.cost += disc_cost
                disc_model.sum_op.extend(disc_fake_sum)
            else:
                return disc_cost, disc_fake_sum

    def gen_loss(disc_fake_out,
                 gen_model,
                 adv_weight=1.0,
                 name="Naive",
                 acc=True):
        """Least-squares generator loss.

        With ``acc`` true the cost is accumulated on ``gen_model`` and
        nothing is returned; otherwise ``(cost, [])`` is returned.
        """
        softL_c = 0.05
        with tf.name_scope(name):
            # The soft target is sampled once, at graph-build time.
            raw_gen_cost = tf.reduce_mean(
                tf.square(disc_fake_out - tf.ones_like(disc_fake_out) *
                          np.abs(np.random.normal(1.0, softL_c))),
                name="raw_gen_cost")

            gen_cost = tf.multiply(raw_gen_cost, adv_weight, name="gen_cost")

            if acc:
                gen_model.cost += gen_cost
            else:
                return gen_cost, []

    disc_loss(disc_fake_A_sample, disc_real_A_sample, DiscA,
              adv_weight=TFLAGS['adv_weight'], name="DiscA_Loss")
    # Fix: DiscB's loss was also scoped "DiscA_Loss" (copy-paste).
    disc_loss(disc_fake_B_sample, disc_real_B_sample, DiscB,
              adv_weight=TFLAGS['adv_weight'], name="DiscB_Loss")

    gen_loss(disc_trans_real_A, TransF,
             adv_weight=TFLAGS['adv_weight'], name="NaiveGenA")
    gen_loss(disc_trans_real_B, TransB,
             adv_weight=TFLAGS['adv_weight'], name="NaiveTransA")

    if TFLAGS['side_noise']:
        # extra loss is for noise not equal to zero
        TransF.extra_loss = TransB.extra_loss = 0
        TransF.extra_loss = tf.identity(TransF.cost)
        TransB.extra_loss = tf.identity(TransB.cost)

        cost_, _ = gen_loss(disc_trans_fake_A, TransF,
                            adv_weight=TFLAGS['adv_weight'],
                            name="NaiveGenAFake", acc=False)
        TransF.extra_loss += cost_
        cost_, _ = gen_loss(disc_trans_fake_B, TransB,
                            adv_weight=TFLAGS['adv_weight'],
                            name="NaiveGenBFake", acc=False)
        TransB.extra_loss += cost_

    #GANLoss(disc_rec_A, DiscA, TransB, adv_weight=TFLAGS['adv_weight'], name="NaiveGenA")
    #GANLoss(disc_rec_B, DiscA, TransF, adv_weight=TFLAGS['adv_weight'], name="NaiveTransA")

    # cycle consistent loss
    def cycle_loss(trans, origin, weight=10.0, name="cycle"):
        """Weighted mean absolute error between channel-averaged images."""
        with tf.name_scope(name):
            # using gray
            trans = tf.reduce_mean(trans, axis=[3])
            origin = tf.reduce_mean(origin, axis=[3])
            cost_ = tf.reduce_mean(tf.abs(trans - origin)) * weight
            sum_ = tf.summary.scalar("Rec", cost_)

        return cost_, sum_

    # Each cycle cost is charged to both translators.
    cost_, sum_ = cycle_loss(rec_real_A, real_A_sample, TFLAGS['AE_weight'],
                             name="cycleA")
    TransF.cost += cost_
    TransB.cost += cost_
    TransF.sum_op.append(sum_)

    cost_, sum_ = cycle_loss(rec_real_B, real_B_sample, TFLAGS['AE_weight'],
                             name="cycleB")
    TransF.cost += cost_
    TransB.cost += cost_
    TransB.sum_op.append(sum_)

    clsB_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=cls_real_B_sample, labels=c_label), name="clsB_real")

    if TFLAGS['side_noise']:
        clsB_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=cls_trans_fake_A, labels=c_label), name="clsB_fake")

    DiscB.cost += clsB_real * real_cls_weight
    DiscB.sum_op.extend([tf.summary.scalar("RealCls", clsB_real)])

    if TFLAGS['side_noise']:
        DiscB.extra_loss = clsB_fake * fake_cls_weight
        TransF.extra_loss += clsB_fake * fake_cls_weight

        # extra loss for integrate stochastic
        cost_, sum_ = cycle_loss(rec_real_A, real_A_sample,
                                 TFLAGS['AE_weight'], name="cycleADisturb")
        TransF.extra_loss += cost_
        TransF.sum_op.append(sum_)

        # NOTE(review): this adds the classifier output itself (not a
        # reduced loss) to extra_loss — confirm this is intended.
        TransF.extra_loss += cls_trans_real_A * fake_cls_weight

    # add interval summary
    edge_num = min(int(np.sqrt(TFLAGS['batch_size'])), 4)
    grid_real_A = ops.get_grid_image_summary(real_A_sample, edge_num)
    inter_sum_op.append(tf.summary.image("Real A", grid_real_A))
    grid_real_B = ops.get_grid_image_summary(real_B_sample, edge_num)
    inter_sum_op.append(tf.summary.image("Real B", grid_real_B))

    grid_trans_A = ops.get_grid_image_summary(trans_real_A, edge_num)
    inter_sum_op.append(tf.summary.image("Trans A", grid_trans_A))
    grid_trans_B = ops.get_grid_image_summary(trans_real_B, edge_num)
    inter_sum_op.append(tf.summary.image("Trans B", grid_trans_B))

    if TFLAGS['side_noise']:
        grid_fake_A = ops.get_grid_image_summary(trans_fake_A, edge_num)
        inter_sum_op.append(tf.summary.image("Fake A", grid_fake_A))
        grid_fake_B = ops.get_grid_image_summary(trans_fake_B, edge_num)
        inter_sum_op.append(tf.summary.image("Fake B", grid_fake_B))

    # merge summary op
    for m, n in zip(models, model_names):
        m.cost = tf.identity(m.cost, "Total" + n)
        m.sum_op.append(tf.summary.scalar("Total" + n, m.cost))
        m.sum_op = tf.summary.merge(m.sum_op)
        m.get_trainable_variables()

    inter_sum_op = tf.summary.merge(inter_sum_op)

    # Not applying update op will result in failure
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        for m, n in zip(models, model_names):
            m.train_op = tf.train.AdamOptimizer(
                learning_rate=lr, beta1=0.5,
                beta2=0.9).minimize(m.cost, var_list=m.vars)

            # Fix: `is not 0` compared identity against an int literal;
            # test for "still the plain-number zero default" explicitly.
            extra = m.extra_loss
            has_extra_loss = not (isinstance(extra, (int, float))
                                  and extra == 0)
            if has_extra_loss:
                m.extra_train_op = tf.train.AdamOptimizer(
                    learning_rate=lr, beta1=0.5,
                    beta2=0.9).minimize(m.extra_loss, var_list=m.vars)

            print("=> ##### %s Variable #####" % n)
            m.print_trainble_vairables()

    print("=> ##### All Variable #####")
    for v in tf.trainable_variables():
        print("%s" % v.name)

    ctrl_weight = {
        "real_cls": real_cls_weight,
        "fake_cls": fake_cls_weight,
        "adv": adv_weight,
        "lr": lr
    }

    # basic settings
    trainer_feed = {
        "ctrl_weight": ctrl_weight,
        "dataloader": dl,
        "int_sum_op": inter_sum_op,
        "gpu_mem": TFLAGS['gpu_mem'],
        "FLAGS": FLAGS,
        "TFLAGS": TFLAGS
    }

    # Each model is passed to the trainer under its name.
    for m, n in zip(models, model_names):
        trainer_feed.update({n: m})

    Trainer = trainer.cycle_trainer.CycleTrainer
    if TFLAGS['side_noise']:
        # Set on the class (not the instance), exactly as before.
        Trainer.noises = [noise_A, noise_B]
        Trainer.labels = [label_A, label_B]
    # input
    trainer_feed.update({"inputs": [real_A_sample, real_B_sample, c_label]})

    ModelTrainer = Trainer(**trainer_feed)

    command_controller = trainer.cmd_ctrl.CMDControl(ModelTrainer)
    command_controller.start_thread()

    ModelTrainer.init_training()

    # Discriminator weights are loaded after initialisation, using the
    # trainer's session.
    DISC_PATH = [
        'success/goodmodel_dragan_sketch2_disc.npz',
        'success/goodmodel_dragan_anime1_disc.npz'
    ]

    DiscA.load_from_npz(DISC_PATH[0], ModelTrainer.sess)
    DiscB.load_from_npz(DISC_PATH[1], ModelTrainer.sess)

    ModelTrainer.train()