def __init__(self, num_channels=3, ngf=100, cg=0.05, lr=1e-4, train_BR=False):
    """Set up the Blade_runner generator, discriminator, loss and optimizers.

    The generator is a fully convolutional stack: a 3x3 conv from
    ``num_channels`` to ``ngf`` maps, five more 3x3 ``ngf -> ngf`` convs, a
    1x1 ``ngf -> ngf`` conv, and a final 1x1 conv back to ``num_channels``
    followed by ``Tanh``. Every hidden conv is followed by LeakyReLU(0.2).
    All convs are stride 1 with padding that preserves spatial size.

    Args:
        num_channels: channels of the generator's input and output.
        ngf: number of feature maps in every hidden generator conv.
        cg: stored unchanged as ``self.cg``.
        lr: Adam learning rate for the generator.
        train_BR: stored unchanged as ``self.train_BR``.
    """
    super(Blade_runner, self).__init__()
    self.discriminator = GAN.define_D(input_nc=3, ndf=64)

    def conv_lrelu(in_ch, out_ch, kernel, padding):
        # Stride-1 convolution followed by an in-place LeakyReLU(0.2).
        return [
            nn.Conv2d(in_ch, out_ch, kernel, 1, padding, bias=True),
            nn.LeakyReLU(0.2, inplace=True),
        ]

    # Layers are appended in the same order as the original literal list so
    # nn.Sequential indices (and therefore state_dict keys) are unchanged.
    layers = conv_lrelu(num_channels, ngf, 3, 1)
    for _ in range(5):
        layers += conv_lrelu(ngf, ngf, 3, 1)
    layers += conv_lrelu(ngf, ngf, 1, 0)
    layers += [nn.Conv2d(ngf, num_channels, 1, 1, 0, bias=True), nn.Tanh()]
    self.generator = nn.Sequential(*layers)

    # NOTE(review): if Blade_runner is an nn.Module, this bool attribute
    # shadows nn.Module.cuda(); kept because other code may read it.
    self.cuda = torch.cuda.is_available()
    if self.cuda:
        self.generator.cuda()
        device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.generator = torch.nn.DataParallel(self.generator).to(device)
        cudnn.benchmark = True
        # NOTE(review): the discriminator is not moved here; presumably
        # GAN.define_D handles its placement -- confirm.

    self.cg = cg
    self.criterionGAN = GAN.GANLoss()
    self.optimizer = optim.Adam(self.generator.parameters(), lr=lr)
    self.optimizer_D = torch.optim.Adam(self.discriminator.parameters(),
                                        lr=0.0002, betas=(0.5, 0.999))
    self.train_BR = train_BR
    self.max_iter = 20
    self.c_feature_weights = [0.00010, 0.00020]
    self.c_feature_confidence = 0
    self.c_misclassify = 1
    self.confidence = 0
def main():
    """Train the 1-D GAN on ``gaussian_data`` and plot what ``sample`` returns."""
    data = gaussian_data
    with tf.Session() as session:
        gan = GAN(session, data, [8, 8, 8], [6, 6, 6],
                  x_dim=1, z_dim=1, lr=0.0001, k=1, std=3.0, mean=3.0)
        gan.train(100000, 100)
        decision_vals, gd_vals, dd_vals, x_vals = gan.sample(
            num_points=num_points)
        plot(decision_vals, gd_vals, dd_vals, x_vals)
def main(args):
    """Train the GAN, or sample images from a saved model.

    Args:
        args: parsed CLI arguments. ``args.training == 1`` trains on
            ``args.dataset``; ``args.training == 0`` restores
            ``model_save_path`` and saves one batch of generated images via
            ``save_generated_examples``. Other values do nothing.
    """
    with tf.Session() as sess:
        training = args.training
        dataset_name = args.dataset
        if training == 1:
            if os.path.exists(summary_path):
                # Start from an empty summary directory so old event files
                # do not mix with this run's.
                for fname in os.listdir(summary_path):
                    os.remove(os.path.join(summary_path, fname))
            else:
                # Create the same directory that was checked above.
                os.makedirs(summary_path)
            dataset = load_data(dataset_name, bs=batch_size)
            gan = GAN()
            trainer = Trainer(gan, sess, dataset, 200)
            sess.run(tf.global_variables_initializer())
            trainer.train()
        elif training == 0:
            gan = GAN()
            saver = tf.train.Saver()
            # restore() returns None; do not rebind `saver` to its result.
            saver.restore(sess, model_save_path)
            noise = np.random.uniform(-1, 1, (batch_size, num_z))
            image = sess.run(gan.x_hat, feed_dict={gan.z: noise})
            # Rescale from [-1, 1] (assumed generator output range) to [0, 1].
            image = (image + 1) / 2
            save_generated_examples(image, 'examples')
def main():
    """Build a GAN from ``Config`` and the data in ``config.data_dir`` and train it."""
    session = tf.Session()
    cfg = Config()
    dataset = input_data.read_data_sets(cfg.data_dir, one_hot=True)
    model = GAN(session, cfg, dataset)
    show_all_variables()
    model.train()
def __init__(self):
    """Set fixed training settings and unpack the three models from ``GAN()``."""
    self.epochs = 11000
    self.batch_size = 128
    self.noise_size = 128
    self.sample_interval = 250
    self.loss_interval = 10
    # Half of the batch (64 for the batch size above).
    self.half_batch = self.batch_size // 2
    self.generator, self.discriminator, self.combined = GAN()
    # NOTE(review): `start_epoch` is a local nothing reads; possibly meant
    # to be `self.start_epoch` -- confirm.
    start_epoch = 19999
def main():
    """Restore the model at hw3_1/model_file/WGAN_v2 and save a 5x5 sample grid."""
    grid = 5
    session = tf.Session()
    model = GAN(sess=session, init=False, gf_dim=128)
    model.restore(model_path='hw3_1/model_file/WGAN_v2')
    noise = np.random.uniform(-1., 1., size=[grid * grid, 100])
    images = model.generate(noise)
    plot_samples(images, save=True, h=grid, w=grid, filename='gan',
                 folder_path='samples/')
def test(args):
    """Restore the latest checkpoint and render five samples per test condition.

    For each condition row ``j`` returned by ``gen_testdata()``:

    * five images are generated and saved as
      ``samples/sample_<j+1>_<i+1>.jpg``;
    * the images are pasted into one combined grid, saved as
      ``samples/1.jpg``;
    * the noise used for that row is pickled to ``samples_<j>``.

    Args:
        args: parsed CLI arguments; reads ``bs``, ``noise_dim``, ``lr`` and
            ``loadpath``.
    """
    import os

    tile = 64
    per_row = 5
    model = GAN(batch_size=args.bs, noise_dim=args.noise_dim,
                learning_rate=args.lr, trainable=True)
    model.build()
    saver = tf.train.Saver(max_to_keep=20)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    os.makedirs('samples', exist_ok=True)
    # Context manager closes the session (it was previously leaked).
    with tf.Session(config=config) as sess:
        latest_ckpt = tf.train.latest_checkpoint(args.loadpath)
        saver.restore(sess, latest_ckpt)
        print('restore from', latest_ckpt)
        batch_y_tot = gen_testdata()
        # batch_y_tot = gen_fromfile(args.testfile)
        new_im = Image.new('RGB', (tile * per_row, tile * len(batch_y_tot)))
        for j, condition in enumerate(batch_y_tot):
            batch_y = np.tile(condition, (per_row, 1))
            noise = np.random.uniform(-1, 1,
                                      [batch_y.shape[0], args.noise_dim])
            generated_test = sess.run(model.sampler, feed_dict={
                model.noises: noise,
                model.labels: batch_y
            })
            for i in range(per_row):
                # scale from [-1., 1.] to [0., 255.]
                generated = (generated_test[i] + 1) * 127.5
                generated = np.clip(generated, 0., 255.).astype(np.uint8)
                # scipy.misc.imresize/imsave no longer exist in SciPy; use
                # PIL directly (bilinear matches imresize's default).
                img = Image.fromarray(generated, "RGB").resize(
                    (tile, tile), Image.BILINEAR)
                img.save('samples/sample_' + str(j + 1) + '_' + str(i + 1)
                         + '.jpg')
                new_im.paste(img, (tile * i, tile * j))
            with open('samples_' + str(j), 'wb') as fh:
                pickle.dump(noise, fh)
    new_im.save('samples/' + '1' + '.jpg')
    print('gen results:', len(batch_y_tot), '* 5 in samples')
def __init__(self, flags):
    """Create the TF session, dataset and ``GAN`` model, then initialise variables.

    Args:
        flags: run configuration; ``flags.dataset`` selects the dataset and
            the whole object is forwarded to ``Dataset`` and ``GAN``.
    """
    session_config = tf.ConfigProto()
    # Let TensorFlow grow GPU memory on demand.
    session_config.gpu_options.allow_growth = True
    self.sess = tf.Session(config=session_config)
    self.flags = flags
    self.dataset = Dataset(flags.dataset, flags)
    self.model = GAN(self.sess, flags, self.dataset.image_size)
    self.best_auc_sum = 0.
    self._make_folders()
    self.saver = tf.train.Saver()
    self.sess.run(tf.global_variables_initializer())
    tf_utils.show_all_variables()
def run():
    """Build discriminator, generator and combined GAN models and train them."""
    latent_dim = 256
    X, input_shape = data()
    disc = Discriminator(input_shape).model
    gen = Generator(latent_dim, input_shape).model
    combined = GAN(gen, disc).model
    train(X, gen, disc, combined, latent_dim)
def train():
    """Train ``GAN`` on ``mnist.train`` and checkpoint it to ``model_path``.

    Discriminator and generator are updated in the same ``sess.run`` call.
    Each uses its own RMSProp optimizer restricted to the variables under
    the ``discriminator`` / ``generator`` scopes.

    Summaries for epoch ``e`` are written to ``summaries/<e>``. The summary
    step is ``e * n_iters + i``, so it stays monotonic across epochs. A
    checkpoint is saved after every epoch.
    """
    model = GAN()
    vars_d = slim.get_variables(scope='discriminator')
    d_optimizer = tf.train.RMSPropOptimizer(8e-4, decay=6e-8)
    d_train_step = d_optimizer.minimize(model.d_loss, var_list=vars_d)
    vars_g = slim.get_variables(scope='generator')
    g_optimizer = tf.train.RMSPropOptimizer(4e-4, decay=3e-8)
    g_train_step = g_optimizer.minimize(model.g_loss, var_list=vars_g)
    saver = tf.train.Saver()
    # Batches per epoch; loop-invariant, so computed once.
    n_iters = int(np.ceil(len(mnist.train.labels) / batch_size))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in trange(n_epochs):
            writer = tf.summary.FileWriter(f'summaries/{epoch}', sess.graph)
            try:
                for i in trange(n_iters, desc=f'Epoch {epoch}', leave=False):
                    x, _ = mnist.train.next_batch(batch_size)
                    z = np.random.uniform(-1, 1, size=[batch_size, 100])
                    _, _, summary = sess.run(
                        [d_train_step, g_train_step, model.summary],
                        feed_dict={
                            model.x: x,
                            model.z: z,
                            model.keep_prob: 0.5,
                        })
                    writer.add_summary(summary, epoch * n_iters + i)
            finally:
                # One writer is opened per epoch; close it so pending events
                # are flushed and the event file is released.
                writer.close()
            saver.save(sess, model_path)
def __main__():
    """Create ./images and ./dataset if missing, then build and run ``GAN``."""
    for folder in ('./images', './dataset'):
        if not os.path.exists(folder):
            os.makedirs(folder)
    model = GAN(image_path="./dataset/train.tfrecords", image_num=17720)
    model()
def main(_):
    """Create output folders, build ``GAN`` from FLAGS and train it.

    The trained model is published as the module-level global ``gan``.
    """
    # A `global` declaration must precede every use of the name in the
    # function; placing it after the assignment is a SyntaxError on Python 3.
    global gan
    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)
    with tf.Session() as sess:
        gan = GAN(
            sess,
            config=FLAGS,
            batch_size=FLAGS.batch_size,
            sample_size=FLAGS.sample_size,
            logdir=FLAGS.log_dir,
        )
        # Replaces the deprecated tf.initialize_all_variables alias.
        tf.global_variables_initializer().run()
        gan.train(FLAGS)
def train(X, generator, discriminator, GAN, size_latent, batch=256, epoch=25):
    """Alternate discriminator and combined-model updates for ``epoch`` epochs.

    Each step does two updates:

    * ``discriminator`` is trained on half a batch of real samples plus half
      a batch of generated samples;
    * ``GAN`` is trained on a full batch of latent points, all labelled 1.

    ``show_accuracy`` runs after every epoch. The generator is saved to
    ``gan.h5`` at the end.

    Args:
        X: training data; ``X.shape[0]`` is the number of samples.
        generator, discriminator, GAN: models exposing ``train_on_batch``;
            ``generator`` also exposes ``save``.
        size_latent: latent dimensionality passed to the samplers.
        batch: samples per step.
        epoch: number of passes.
    """
    half = int(batch / 2)
    steps_per_epoch = int(X.shape[0] / batch)
    for ep in range(epoch):
        print("EPOCH: %d" % (ep + 1))
        for _ in range(steps_per_epoch):
            real_x, real_y = sample_real(X, half)
            fake_x, fake_y = sample_fake(generator, half, size_latent)
            z = sample_latent_point(batch, size_latent)
            z_labels = np.ones((batch, 1))
            discriminator.train_on_batch(np.vstack((real_x, fake_x)),
                                         np.vstack((real_y, fake_y)))
            GAN.train_on_batch(z, z_labels)
        show_accuracy(X, generator, discriminator, size_latent)
    generator.save("gan.h5")
def main(unuse_args):
    """Configure the GPU session from FLAGS and either train or test the GAN.

    When ``FLAGS.is_train`` is false, ``gan.test_sidd`` is run.
    """
    pp.pprint(flags.FLAGS.__flags)
    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)
    # Restrict visible GPUs; this only takes effect if set before TensorFlow
    # initialises its devices.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(FLAGS.gpu_id)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    with tf.Session(config=config) as sess:
        K.set_session(sess)
        gan = GAN(sess,
                  model_name=FLAGS.model_name,
                  image_size=FLAGS.image_size,
                  batch_size=FLAGS.batch_size,
                  is_crop=False,
                  c_dim=FLAGS.c_dim,
                  checkpoint_dir=FLAGS.checkpoint_dir)
        if not FLAGS.is_train:
            print('==========Test the model!==========')
            gan.test_sidd(FLAGS)
        else:
            print('==========Training the model!==========')
            gan.train(FLAGS)
def main():
    """Parse CLI args, build the GAN over the loaded images, then train or test it.

    Raises:
        Exception: in ``test`` mode when no ``--checkpoint`` is given.
    """
    args = parse_args()
    check_folder(args.output_dir)
    if args.mode == "test" and args.checkpoint is None:
        raise Exception("Checkpoint is required for test mode")
    paths, inputs, targets, steps_per_epoch = load_images(
        args.input_dir, FLAGS.batch_size, args.mode)
    model = GAN(args.input_dir, args.output_dir, args.checkpoint, paths,
                inputs, targets, FLAGS.batch_size, steps_per_epoch,
                FLAGS.ngf, FLAGS.ndf, FLAGS.lr, FLAGS.beta1,
                FLAGS.l1_weight, FLAGS.gan_weight)

    def encode_pngs(images, name):
        # de_process -> rgbxy_to_rgb -> convert, then PNG-encode each element.
        return tf.map_fn(tf.image.encode_png,
                         convert(rgbxy_to_rgb(de_process(images))),
                         dtype=tf.string, name=name)

    display_images = {
        "paths": paths,
        "inputs": encode_pngs(inputs, "inputs_pngs"),
        "targets": encode_pngs(targets, "target_pngs"),
        "outputs": encode_pngs(model.get_outputs(), "output_pngs"),
    }
    sv = tf.train.Supervisor(logdir=args.output_dir, save_summaries_secs=0,
                             saver=None)
    with sv.managed_session() as sess:
        if args.mode == "train":
            model.train(sv, sess, FLAGS.max_epochs, FLAGS.progress_freq,
                        FLAGS.save_freq)
        else:
            model.test(sess, display_images)
def train(FLAGS):
    """Train ``GAN`` on ``input_data.read_data_sets(FLAGS.data_dir)``.

    After each epoch, prints the discriminator and generator losses summed
    over that epoch's batches.

    Args:
        FLAGS: run configuration; reads ``data_dir``, ``checkpoint_dir``,
            ``epoch``, ``batch_num`` and ``batch_size`` and passes the whole
            object to ``GAN``.

    Raises:
        AssertionError: if ``FLAGS.data_dir`` does not exist. The check is
            raised explicitly rather than with ``assert``, so it is not
            stripped under ``python -O``.
    """
    if not os.path.exists(FLAGS.data_dir):
        raise AssertionError('data_dir does not exist: %s' % FLAGS.data_dir)
    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
    sess = tf.Session()
    try:
        model = GAN(sess, FLAGS)
        for i in range(FLAGS.epoch):
            D_loss, G_loss = 0, 0
            for _ in range(FLAGS.batch_num):
                # presumably (batch_size, 784) flattened images -- confirm
                X_mb, _ = mnist.train.next_batch(FLAGS.batch_size)
                dis_loss, gen_loss = model.train(X_mb)
                D_loss += dis_loss
                G_loss += gen_loss
            print('Epoch {}, Discriminator loss {}, generator loss {}'.format(
                i, D_loss, G_loss))
    finally:
        # The session was previously never closed.
        sess.close()
def main(_):
    """Print flags, ensure output folders exist and train when ``FLAGS.is_train``."""
    pp.pprint(flags.FLAGS.__flags)
    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)
    with tf.Session() as sess:
        model = GAN(sess,
                    dataset_dir=FLAGS.dataset_dir,
                    checkpoint_dir=FLAGS.checkpoint_dir,
                    model_dir=FLAGS.model_dir,
                    sample_dir=FLAGS.sample_dir,
                    epoch=FLAGS.epoch,
                    batch_size=FLAGS.batch_size,
                    z_dim=FLAGS.z_dim,
                    epoch_to_sample=FLAGS.epoch_to_sample)
        if FLAGS.is_train:
            model.train()
def main():
    """Parse CLI arguments, train the GAN, then visualise results for ``args.epoch``.

    Exits without doing anything when ``parse_args()`` returns ``None``.
    """
    import sys

    args = parse_args()
    if args is None:
        # `exit()` is a helper injected by the `site` module for interactive
        # use and is absent under `python -S`; sys.exit() is the
        # program-safe equivalent.
        sys.exit()
    if args.benchmark_mode:
        torch.backends.cudnn.benchmark = True
    gan = GAN(args)
    gan.train()
    print('Training finished')
    gan.visualize_results(args.epoch)
    print('Testing finished')
def _filter_background(x):
    """Zero, in every channel, each pixel where any channel is below 0.3.

    Mutates and returns *x* (indexed as channel-first: the mask is reduced
    over dim 0). Defined at module level, not inside ``main``, so DataLoader
    worker processes can pickle it under the ``spawn`` start method.
    """
    x[:, (x < 0.3).any(dim=0)] = 0.0
    return x


def _to_single_channel(x):
    """Average over dim 0 and keep the result as a size-1 leading dim."""
    return x.mean(dim=0)[None, :, :]


def _dilate(x):
    """Apply a 5x5 grey dilation to a tensor with a size-1 leading dim.

    Returns a new tensor with the same ``(1, H, W)`` layout.
    """
    x = x.squeeze(0).numpy()
    x = grey_dilation(x, size=5)
    x = x[None, :, :]
    return torch.from_numpy(x)


def main():
    """Train ``GAN(32, 1)`` on preprocessed images from ``args.root``.

    Images are resized and centre-cropped to 64, then background pixels are
    zeroed. Channels are averaged to one, a 5x5 grey dilation is applied,
    and values are normalised with mean 0.5 and std 0.5.

    The transforms are module-level functions rather than nested functions
    or lambdas, which cannot be pickled. That keeps
    ``DataLoader(num_workers=2)`` working on platforms whose default
    multiprocessing start method is ``spawn``.
    """
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    # The optional pretrained FCN_mse model is disabled; `train` gets None.
    fcn = None
    transform = transforms.Compose([
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        _filter_background,
        _to_single_channel,
        _dilate,
        transforms.Normalize((0.5, ), (0.5, )),
    ])
    dataset = ImageFolder(args.root, transform=transform)
    loader = data.DataLoader(dataset, batch_size=args.batch_size,
                             shuffle=True, pin_memory=True, num_workers=2)
    model = GAN(32, 1).cuda()
    train(model, fcn, loader)
def _strip_checkpoints_prefix(name):
    """Return *name* without one leading ``"checkpoints/"`` prefix.

    ``str.lstrip("checkpoints/")`` strips a *set of characters*, not a
    prefix, so it turns e.g. ``"checkpoints/test"`` into ``""``. This helper
    only removes the literal prefix.
    """
    prefix = "checkpoints/"
    return name[len(prefix):] if name.startswith(prefix) else name


def _step_from_checkpoint_path(path):
    """Return the global step from a ``.../model.ckpt-<step>`` checkpoint prefix.

    Splitting on the last ``-`` keeps dashes elsewhere in the path (such as
    the ``%Y%m%d-%H%M`` run directory) from shifting the index.
    """
    return int(path.rsplit("-", 1)[1])


def train():
    """Train ``GAN`` with gradients averaged over four GPU towers.

    Each step reads ``FLAGS.batch_size`` samples. Every sample has seven
    inputs, read with ``read_file`` from ``FLAGS.X/Y/Z/W/S/M/L``. The batch
    is split into four equal slices, one fed to each tower on ``/gpu:0`` ..
    ``/gpu:3``. Per-tower generator and discriminator gradients are combined
    with ``average_gradients`` and applied once.

    Checkpoints live in ``[FLAGS.savefile/]checkpoints/<run>``, where
    ``<run>`` is ``FLAGS.load_model`` when resuming and otherwise the
    current ``%Y%m%d-%H%M`` time. A checkpoint is written whenever the
    training loop ends, including when it stops on an exception.
    """
    num_towers = 4
    # Per-tower input names, in the order their placeholders are created.
    input_keys = ("x", "y", "z", "w", "s", "m", "l")
    with tf.device("/cpu:0"):
        if FLAGS.load_model is not None:
            run_name = _strip_checkpoints_prefix(FLAGS.load_model)
            if FLAGS.savefile is not None:
                checkpoints_dir = FLAGS.savefile + "/checkpoints/" + run_name
            else:
                checkpoints_dir = "checkpoints/" + run_name
        else:
            current_time = datetime.now().strftime("%Y%m%d-%H%M")
            if FLAGS.savefile is not None:
                checkpoints_dir = FLAGS.savefile + "/checkpoints/{}".format(
                    current_time)
            else:
                checkpoints_dir = "checkpoints/{}".format(current_time)
        try:
            os.makedirs(checkpoints_dir + "/samples")
        except os.error:
            # Best effort: the directory usually already exists.
            pass
        for attr, value in FLAGS.flag_values_dict().items():
            logging.info("%s\t:\t%s" % (attr, str(value)))

        graph = tf.Graph()
        with graph.as_default():
            gan = GAN(FLAGS.image_size, FLAGS.learning_rate, FLAGS.batch_size,
                      FLAGS.ngf)
            # Samples per tower: the global batch is split evenly.
            tower_size = int(FLAGS.batch_size / num_towers)
            input_shape = [
                tower_size, FLAGS.image_size[0], FLAGS.image_size[1],
                FLAGS.image_size[2]
            ]
            G_optimizer, D_optimizer = gan.optimize()
            G_grad_list = []
            D_grad_list = []
            tower_inputs = []  # one {input name: placeholder} dict per tower
            with tf.variable_scope(tf.get_variable_scope()):
                for gpu in range(num_towers):
                    with tf.device("/gpu:%d" % gpu):
                        with tf.name_scope("GPU_%d" % gpu):
                            ph = {
                                key: tf.placeholder(tf.float32,
                                                    shape=input_shape)
                                for key in input_keys
                            }
                            # Each tower must consume only its own
                            # placeholders.
                            loss_list, _, _ = gan.model(
                                ph["l"], ph["m"], ph["s"], ph["x"], ph["y"],
                                ph["z"], ph["w"])
                            tower_vars = gan.get_variables()
                            G_grad_list.append(G_optimizer.compute_gradients(
                                loss_list[0], var_list=tower_vars[0]))
                            D_grad_list.append(D_optimizer.compute_gradients(
                                loss_list[1], var_list=tower_vars[1]))
                            tower_inputs.append(ph)
            # NOTE(review): attribute spelling `tenaor_name` kept as-is;
            # confirm it matches the GAN class.
            tensor_name_dirct = gan.tenaor_name
            G_ave_grad = average_gradients(G_grad_list)
            D_ave_grad = average_gradients(D_grad_list)
            G_optimizer_op = G_optimizer.apply_gradients(G_ave_grad)
            D_optimizer_op = D_optimizer.apply_gradients(D_ave_grad)
            optimizers = [G_optimizer_op, D_optimizer_op]
            saver = tf.train.Saver()
            variables_list = gan.get_variables()

            with tf.Session(
                    graph=graph,
                    config=tf.ConfigProto(allow_soft_placement=True)) as sess:
                if FLAGS.load_model is not None:
                    logging.info("restore model:" + FLAGS.load_model)
                    if FLAGS.checkpoint is not None:
                        model_checkpoint_path = (checkpoints_dir +
                                                 "/model.ckpt-" +
                                                 FLAGS.checkpoint)
                        latest_checkpoint = model_checkpoint_path
                    else:
                        checkpoint = tf.train.get_checkpoint_state(
                            checkpoints_dir)
                        model_checkpoint_path = checkpoint.model_checkpoint_path
                        latest_checkpoint = tf.train.latest_checkpoint(
                            checkpoints_dir)
                    logging.info("model checkpoint path:" +
                                 model_checkpoint_path)
                    meta_graph_path = model_checkpoint_path + ".meta"
                    restore = tf.train.import_meta_graph(meta_graph_path)
                    restore.restore(sess, latest_checkpoint)
                    if FLAGS.step_clear == True:
                        step = 0
                    else:
                        step = _step_from_checkpoint_path(
                            model_checkpoint_path)
                else:
                    sess.run(tf.global_variables_initializer())
                    step = 0
                # NOTE(review): confirm these warm-starts are also meant to
                # apply when resuming from FLAGS.load_model.
                if FLAGS.load_GT_model is not None:
                    trans_latest_checkpoint = tf.train.latest_checkpoint(
                        "checkpoints/" + FLAGS.load_GT_model)
                    trans_saver = tf.train.Saver(variables_list[2])
                    trans_saver.restore(sess, trans_latest_checkpoint)
                if FLAGS.load_GL_model is not None:
                    lesion_latest_checkpoint = tf.train.latest_checkpoint(
                        "checkpoints/" + FLAGS.load_GL_model)
                    lesion_saver = tf.train.Saver(variables_list[3])
                    lesion_saver.restore(sess, lesion_latest_checkpoint)
                sess.graph.finalize()
                logging.info("start step:" + str(step))
                try:
                    logging.info("tensor_name_dirct:\n" +
                                 str(tensor_name_dirct))
                    s_train_files = read_filename(FLAGS.S)
                    l_train_files = read_filename(FLAGS.L)
                    # (input name, source passed to read_file, file list),
                    # in read order.
                    sources = [
                        ("x", FLAGS.X, s_train_files),
                        ("y", FLAGS.Y, s_train_files),
                        ("z", FLAGS.Z, s_train_files),
                        ("w", FLAGS.W, s_train_files),
                        ("s", FLAGS.S, s_train_files),
                        ("m", FLAGS.M, s_train_files),
                        ("l", FLAGS.L, l_train_files),
                    ]
                    index = 0
                    epoch = 0
                    while epoch <= FLAGS.epoch:
                        batch = {key: [] for key in input_keys}
                        for _ in range(FLAGS.batch_size):
                            for key, source, files in sources:
                                batch[key].append(
                                    read_file(source, files, index).reshape(
                                        FLAGS.image_size))
                            epoch = int(index / len(s_train_files))
                            index = index + 1
                        logging.info("-----------train epoch " + str(epoch) +
                                     ", step " + str(step) +
                                     ": start-------------")
                        arrays = {
                            key: np.asarray(values)
                            for key, values in batch.items()
                        }
                        feed_dict = {}
                        # Tower k gets rows [k*tower_size, (k+1)*tower_size).
                        for k, placeholders in enumerate(tower_inputs):
                            lo, hi = k * tower_size, (k + 1) * tower_size
                            for key, placeholder in placeholders.items():
                                feed_dict[placeholder] = arrays[key][lo:hi]
                        sess.run(optimizers, feed_dict=feed_dict)
                        logging.info("-----------train epoch " + str(epoch) +
                                     ", step " + str(step) +
                                     ": end-------------")
                        step += 1
                except Exception as e:
                    # Logged and swallowed; the `finally` below saves once.
                    logging.info("ERROR:" + str(e))
                finally:
                    save_path = saver.save(sess,
                                           checkpoints_dir + "/model.ckpt",
                                           global_step=step)
                    logging.info("Model saved in file: %s" % save_path)
type=bool, default=True) parser.add_argument("--sample_image_dir", help="Directory containing sample images (Used only if save_samples is True; Default = samples)", default='samples') parser.add_argument("--A_dir", help="Directory containing the input images for training, testing or inference (Default = A)", default='A') parser.add_argument("--B_dir", help="Directory containing the target images for training or testing. In inference mode, this is used to store results (Default = B)", default='B') parser.add_argument("--custom_data", help="Using your own data as input and target (Default = True)", type=bool, default=True) parser.add_argument("--val_fraction", help="Fraction of dataset to be split for validation (Default = 0.15)", type=float, default=0.15) parser.add_argument("--val_threshold", help="Number of steps to wait before validation is enabled. (Default = 0)", type=int, default=0) parser.add_argument("--val_frequency", help="Number of batches to wait before perfoming the next validation run (Default = 20)", type=int, default=20) parser.add_argument("--logger_frequency", help="Number of batches to wait before logging the next set of loss values (Default = 20)", type=int, default=20) parser.add_argument("--mode", help="Select between train, test or inference modes", default='train', choices=['train', 'test', 'inference']) if __name__ == '__main__': args = parser.parse_args() net = GAN(args) if args.mode == 'train': net.train() if args.mode == 'test': net.test(args.A_dir, args.B_dir) if args.mode == 'inference': net.inference(args.A_dir, args.B_dir)
def _strip_checkpoints_prefix(name):
    """Return *name* without one leading ``"checkpoints/"`` prefix.

    ``str.lstrip("checkpoints/")`` strips a *set of characters*, not a
    prefix, so it turns e.g. ``"checkpoints/test"`` into ``""``. This helper
    only removes the literal prefix.
    """
    prefix = "checkpoints/"
    return name[len(prefix):] if name.startswith(prefix) else name


def _step_from_checkpoint_path(path):
    """Return the global step from a ``.../model.ckpt-<step>`` checkpoint prefix.

    Splitting on the last ``-`` keeps dashes elsewhere in the path (such as
    the ``%Y%m%d-%H%M`` run directory) from shifting the index.
    """
    return int(path.rsplit("-", 1)[1])


def train():
    """Train ``GAN`` with gradients averaged over four GPU towers.

    Each tower takes eight inputs. ``l_x``/``l_y``/``l_z``/``l_w`` are read
    from ``FLAGS.L``; ``x``/``y``/``z``/``w`` are read from
    ``FLAGS.X/Y/Z/W``, using the same file list and index as their ``l_*``
    partner. The ``FLAGS.batch_size`` batch is split into four equal slices,
    one per tower on ``/gpu:0`` .. ``/gpu:3``.

    Tower ``k`` is fed rows ``[k*s, (k+1)*s)`` with ``s = int(batch_size/4)``,
    matching its placeholder shape. A hard-coded one-row slice only worked
    when ``s == 1``.

    Checkpoints live in ``[FLAGS.savefile/]checkpoints/<run>``. One is saved
    when the loop ends, but only when ``FLAGS.stage == 'train'``.
    """
    num_towers = 4
    # Placeholder-creation order, which is also `gan.model`'s argument order.
    input_keys = ("l_x", "l_y", "l_z", "l_w", "x", "y", "z", "w")
    with tf.device("/cpu:0"):
        if FLAGS.load_model is not None:
            run_name = _strip_checkpoints_prefix(FLAGS.load_model)
            if FLAGS.savefile is not None:
                checkpoints_dir = FLAGS.savefile + "/checkpoints/" + run_name
            else:
                checkpoints_dir = "checkpoints/" + run_name
        else:
            current_time = datetime.now().strftime("%Y%m%d-%H%M")
            if FLAGS.savefile is not None:
                checkpoints_dir = FLAGS.savefile + "/checkpoints/{}".format(
                    current_time)
            else:
                checkpoints_dir = "checkpoints/{}".format(current_time)
        try:
            os.makedirs(checkpoints_dir + "/samples")
        except os.error:
            # Best effort: the directory usually already exists.
            pass
        for attr, value in FLAGS.flag_values_dict().items():
            logging.info("%s\t:\t%s" % (attr, str(value)))

        graph = tf.get_default_graph()
        gan = GAN(FLAGS.image_size, FLAGS.learning_rate, FLAGS.batch_size,
                  FLAGS.ngf)
        # Samples per tower: the global batch is split evenly.
        tower_size = int(FLAGS.batch_size / num_towers)
        input_shape = [
            tower_size, FLAGS.image_size[0], FLAGS.image_size[1],
            FLAGS.image_size[2]
        ]
        G_optimizer, D_optimizer = gan.optimize()
        G_grad_list = []
        D_grad_list = []
        tower_inputs = []  # one {input name: placeholder} dict per tower
        tensor_name_dirct_0 = None
        with tf.variable_scope(tf.get_variable_scope()):
            for gpu in range(num_towers):
                with tf.device("/gpu:%d" % gpu):
                    with tf.name_scope("GPU_%d" % gpu):
                        ph = {
                            key: tf.placeholder(tf.float32, shape=input_shape)
                            for key in input_keys
                        }
                        loss_list, _, _, _ = gan.model(
                            *[ph[key] for key in input_keys])
                        if gpu == 0:
                            # Tensor names as reported after the first tower.
                            tensor_name_dirct_0 = gan.tensor_name
                        tower_vars = gan.get_variables()
                        G_grad_list.append(G_optimizer.compute_gradients(
                            loss_list[0], var_list=tower_vars[0]))
                        D_grad_list.append(D_optimizer.compute_gradients(
                            loss_list[1], var_list=tower_vars[1]))
                        tower_inputs.append(ph)
        G_ave_grad = average_gradients(G_grad_list)
        D_ave_grad = average_gradients(D_grad_list)
        G_optimizer_op = G_optimizer.apply_gradients(G_ave_grad)
        D_optimizer_op = D_optimizer.apply_gradients(D_ave_grad)
        optimizers = [G_optimizer_op, D_optimizer_op]
        saver = tf.train.Saver()
        variables_list = gan.get_variables()

        with tf.Session(
                graph=graph,
                config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            sess.run(tf.global_variables_initializer())
            if FLAGS.load_model is not None:
                logging.info("restore model:" + FLAGS.load_model)
                if FLAGS.checkpoint is not None:
                    model_checkpoint_path = (checkpoints_dir + "/model.ckpt-" +
                                             FLAGS.checkpoint)
                    latest_checkpoint = model_checkpoint_path
                else:
                    checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
                    model_checkpoint_path = checkpoint.model_checkpoint_path
                    latest_checkpoint = tf.train.latest_checkpoint(
                        checkpoints_dir)
                logging.info("model checkpoint path:" + model_checkpoint_path)
                meta_graph_path = model_checkpoint_path + ".meta"
                restore = tf.train.import_meta_graph(meta_graph_path)
                restore.restore(sess, latest_checkpoint)
                step = _step_from_checkpoint_path(model_checkpoint_path)
            else:
                step = 0
            if FLAGS.load_trans_model is not None:
                trans_checkpoints_dir = "checkpoints/" + _strip_checkpoints_prefix(
                    FLAGS.load_trans_model)
                trans_latest_checkpoint = tf.train.latest_checkpoint(
                    trans_checkpoints_dir)
                trans_saver = tf.train.Saver(variables_list[0] +
                                             variables_list[1])
                trans_saver.restore(sess, trans_latest_checkpoint)
            if FLAGS.load_seg_model is not None:
                seg_checkpoints_dir = "checkpoints/" + _strip_checkpoints_prefix(
                    FLAGS.load_seg_model)
                seg_latest_checkpoint = tf.train.latest_checkpoint(
                    seg_checkpoints_dir)
                seg_saver = tf.train.Saver(variables_list[2])
                seg_saver.restore(sess, seg_latest_checkpoint)
            sess.graph.finalize()
            logging.info("start step:" + str(step))
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                logging.info("tensor_name_dirct:\n" + str(tensor_name_dirct_0))
                # Read separately for each pair; kept as four calls in case
                # read_filename is not deterministic -- confirm.
                l_x_train_files = read_filename(FLAGS.L)
                l_y_train_files = read_filename(FLAGS.L)
                l_z_train_files = read_filename(FLAGS.L)
                l_w_train_files = read_filename(FLAGS.L)
                # (input name, source passed to read_file, file list), in read
                # order.
                sources = [
                    ("l_x", FLAGS.L, l_x_train_files),
                    ("x", FLAGS.X, l_x_train_files),
                    ("l_y", FLAGS.L, l_y_train_files),
                    ("y", FLAGS.Y, l_y_train_files),
                    ("l_z", FLAGS.L, l_z_train_files),
                    ("z", FLAGS.Z, l_z_train_files),
                    ("l_w", FLAGS.L, l_w_train_files),
                    ("w", FLAGS.W, l_w_train_files),
                ]
                index = 0
                epoch = 0
                while not coord.should_stop() and epoch <= FLAGS.epoch:
                    batch = {key: [] for key in input_keys}
                    for _ in range(FLAGS.batch_size):
                        for key, source, files in sources:
                            batch[key].append(
                                read_file(source, files, index).reshape(
                                    FLAGS.image_size))
                        epoch = int(index / len(l_x_train_files))
                        index = index + 1
                    arrays = {
                        key: np.asarray(values)
                        for key, values in batch.items()
                    }
                    feed_dict = {}
                    # Tower k gets rows [k*tower_size, (k+1)*tower_size).
                    for k, placeholders in enumerate(tower_inputs):
                        lo, hi = k * tower_size, (k + 1) * tower_size
                        for key, placeholder in placeholders.items():
                            feed_dict[placeholder] = arrays[key][lo:hi]
                    sess.run([optimizers], feed_dict=feed_dict)
                    step += 1
            except KeyboardInterrupt:
                logging.info('Interrupted')
                coord.request_stop()
            except Exception as e:
                coord.request_stop(e)
            finally:
                if FLAGS.stage == 'train':
                    save_path = saver.save(sess,
                                           checkpoints_dir + "/model.ckpt",
                                           global_step=step)
                    logging.info("Model saved in file: %s" % save_path)
                else:
                    logging.info("finish test")
                coord.request_stop()
                coord.join(threads)
if os.path.exists(args.save_img_dir): shutil.rmtree(args.save_img_dir) if os.path.exists(args.log_dir): shutil.rmtree(args.log_dir) if not os.path.exists(args.save_img_dir): os.mkdir(args.save_img_dir) if not os.path.exists(args.log_dir): os.mkdir(args.log_dir) with tf.Graph().as_default() as graph: initializer = tf.random_uniform_initializer(-args.init_scale, args.init_scale) with tf.variable_scope('model', reuse=None, initializer=initializer) as scope: model = GAN(args) scope.reuse_variables() config = tf.ConfigProto() config.gpu_options.allow_growth = True config.graph_options.optimizer_options.global_jit_level =\ tf.OptimizerOptions.ON_1 sv = tf.train.Supervisor(logdir=args.log_dir, save_model_secs=args.save_model_secs) saver = sv.saver saver._max_to_keep = 1000 with sv.managed_session(config=config) as sess: save_noise = np.random.uniform(-1., 1., [10, args.noise_dim]) save_feat = to_categorical(np.arange(10), num_classes=10) save_noise_fill = np.concatenate(
# Minimal training entry point: build a Config, construct the GAN from it and
# start its training loop.
import torch
import torch.nn as nn
from torch import optim

from config import Config
from model import GAN

# NOTE(review): torch, nn and optim are not referenced in this snippet.
config = Config()
model = GAN(config)
model.train()
# Train a GAN on an image collection that was previously pickled to the
# relative path 'animelist'.
# Import order is kept as-is: `preprocessing` is a project module that may
# have import-time side effects relevant to matplotlib.
import preprocessing
import matplotlib.pyplot as plt
from model import GAN
import pickle

# NOTE(review): pickle.load can execute arbitrary code from the file; this is
# only acceptable if 'animelist' is a trusted, locally produced file.
with open('animelist', 'rb') as pickle_file:
    images = pickle.load(pickle_file)

gan = GAN()
# Training hyper-parameters are passed straight through to GAN.train.
gan.train(
    images,
    epochs=15001,
    batch_size=256,
    save_images=1000,
    save_model=15000,
)
) gan = GAN( discriminator=sagan.discriminator, generator=sagan.generator, real_input_fn=functools.partial( celeba_input_fn, filenames=args.filenames, batch_size=args.batch_size, num_epochs=None, shuffle=True, image_size=[256, 256] ), fake_input_fn=lambda: ( tf.random_normal([args.batch_size, 512]), tf.one_hot(tf.reshape(tf.random.multinomial( logits=tf.log([tf.cast(attr_counts, tf.float32)]), num_samples=args.batch_size ), [args.batch_size]), len(attr_counts)) ), hyper_params=Param( discriminator_learning_rate=4e-4, discriminator_beta1=0.0, discriminator_beta2=0.9, generator_learning_rate=1e-4, generator_beta1=0.0, generator_beta2=0.9 ), name=args.model_dir ) config = tf.ConfigProto(
def train(
    root_path,
    batch_size,
    epochs=1,
    lr=.001,
    max_filters=256,
    min_filters=64,
    upsample_layers=3,
    noise_dim=64,
    blocks=8,
    device_ids=None,  # list of device ids to use (default [0]); if multiple, DataParallel is used
    image_size=256,
    batch_shuffle=True,
    num_workers=0,
    wass_target=1,
    mse_weight=10,
    ttur=4,
    models_dir='./models/',
    results_dir='./results/',
    log_dir='./log/',
    model_name_prefix='test',
    save_every=20000,
    print_every=500,
    log_every=1000,
    conv_type='scaled',
    grad_mean=True,
    upsample_type='nearest',
    resume=True,
    checkpoint=-1,  # only used if resume is True: which save to load; -1 (default) is the latest
    straight_through_round=True,
    reset=False,
):
    """Build the dataset, loader, log writer and GAN, then train and save the model.

    Args:
        root_path: If this is a directory, images are read with
            ``FolderDatasetDownsample``; otherwise the path is handed to
            ``LargeImageDataset``.
        batch_size: Batch size for the DataLoader and for the fixed sample
            batch that is given to the GAN as ``samples``.
        upsample_layers: Passed to the GAN. It also sets the dataset
            downsample factor to ``2 ** upsample_layers``.
        device_ids: List of device ids forwarded to the GAN. ``None`` means ``[0]``.
        models_dir: Directory for saved models. It is joined to the model name
            by plain string concatenation, so it should end with a path
            separator (the default ``'./models/'`` does).
        log_dir: Root log directory. Logs go to ``log_dir/<model_name>``.
        resume: If True, ``model.load(checkpoint)`` runs before training.
        reset: If True, files from an earlier run with the same model name are
            removed from the model and log directories before training.
        Remaining arguments are hyper-parameters forwarded unchanged to the
        ``GAN`` constructor or to ``GAN.train``.

    Raises:
        ValueError: If the dataset contains no samples.
    """
    if device_ids is None:
        device_ids = [0]

    # Choose the dataset implementation: a directory of images or a single path.
    downsample = 2 ** upsample_layers
    if os.path.isdir(root_path):
        dataset = FolderDatasetDownsample(root_path, downsample=downsample, size=image_size)
        datasetname = 'folder'
    else:
        dataset = LargeImageDataset(root_path, downsample=downsample, size=image_size)
        datasetname = 'largeimage'

    if len(dataset) == 0:
        raise ValueError(f'no training samples found at {root_path!r}')

    # Draw one random batch up front. It is passed to the GAN as `samples`.
    # Indices are converted to plain ints, the canonical Dataset index type.
    sample_indices = torch.randint(0, len(dataset), (batch_size,)).tolist()
    lores_list, hires_list = zip(*(dataset[i] for i in sample_indices))
    samples = torch.stack(lores_list), torch.stack(hires_list)

    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,
                            shuffle=batch_shuffle, pin_memory=False)

    # The run name encodes the key hyper-parameters, so runs do not collide.
    model_name = (f'{model_name_prefix}.{datasetname}.{batch_size}.{lr}.{max_filters}.'
                  f'{min_filters}.{upsample_layers}.{blocks}.{noise_dim}.{image_size}.'
                  f'{wass_target}.{ttur}.{mse_weight}')

    # makedirs(exist_ok=True) avoids the exists/mkdir race and also creates
    # missing parent directories.
    model_path = models_dir + model_name
    os.makedirs(models_dir, exist_ok=True)
    os.makedirs(model_path, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)
    log_dir = Path(log_dir) / model_name
    os.makedirs(log_dir, exist_ok=True)

    if reset:
        # Erase saved models and logs from a previous run with this model name.
        # NOTE(review): the second argument presumably filters which files are
        # removed; confirm against remove_files_in_path.
        print('resetting', model_path, log_dir)
        remove_files_in_path(model_path, 'pt')
        remove_files_in_path(log_dir, '0')

    writer = SummaryWriter(log_dir=log_dir)
    try:
        model = GAN(max_filters=max_filters, min_filters=min_filters,
                    upsample_layers=upsample_layers, noise_dim=noise_dim, blocks=blocks,
                    device_ids=device_ids, models_dir=models_dir, results_dir=results_dir,
                    log_writer=writer, model_name=model_name, save_every=save_every,
                    print_every=print_every, log_every=log_every, conv_type=conv_type,
                    grad_mean=grad_mean, upsample_type=upsample_type,
                    straight_through_round=straight_through_round, samples=samples)
        if resume:
            model.load(checkpoint)
        n_iter = model.train(dataloader, epochs, lr, wass_target, mse_weight, ttur)
        model.save(n_iter)
    finally:
        # Close the log writer even if construction, loading or training raises.
        writer.close()
# create networks g = generator(img_size, n_filters_g) if FLAGS.discriminator=='pixel': d, d_out_shape = discriminator_pixel(img_size, n_filters_d,init_lr) elif FLAGS.discriminator=='patch1': d, d_out_shape = discriminator_patch1(img_size, n_filters_d,init_lr) elif FLAGS.discriminator=='patch2': d, d_out_shape = discriminator_patch2(img_size, n_filters_d,init_lr) elif FLAGS.discriminator=='image': d, d_out_shape = discriminator_image(img_size, n_filters_d,init_lr) else: d, d_out_shape = discriminator_dummy(img_size, n_filters_d,init_lr) utils.make_trainable(d, False) gan=GAN(g,d,img_size, n_filters_g, n_filters_d,alpha_recip, init_lr) generator=pretrain_g(g, img_size, n_filters_g, init_lr) g.summary() d.summary() gan.summary() with open(os.path.join(model_out_dir,"g_{}_{}.json".format(FLAGS.discriminator,FLAGS.ratio_gan2seg)),'w') as f: f.write(g.to_json()) # start training scheduler=utils.Scheduler(n_train_imgs//batch_size, n_train_imgs//batch_size, schedules, init_lr) if alpha_recip>0 else utils.Scheduler(0, n_train_imgs//batch_size, schedules, init_lr) print "training {} images :".format(n_train_imgs) for n_round in range(n_rounds): # train D utils.make_trainable(d, True) for i in range(scheduler.get_dsteps()):
if steps % show_every == 0: gen_samples = sess.run( generator(net.input_z, 3, reuse=True, training=False), feed_dict={net.input_z: sample_z}) samples.append(gen_samples) _ = view_samples(-1, samples, 6, 12, figsize=figsize) plt.show() saver.save(sess, './checkpoints/generator.ckpt') with open('samples.pkl', 'wb') as f: pkl.dump(samples, f) return losses, samples # Create the network net = GAN(real_size, z_size, learning_rate, alpha=alpha, beta1=beta1) dataset = Dataset(trainset, testset) losses, samples = train(net, dataset, epochs, batch_size, figsize=(10,5)) _ = view_samples(0, samples, 4, 4, figsize=(10,5))
dcgan = DCGAN(min_resolution=[4, 4], max_resolution=[32, 32], min_channels=128, max_channels=512) gan = GAN(generator=dcgan.generator, discriminator=dcgan.discriminator, real_input_fn=functools.partial( cifar10_input_fn, filenames=glob.glob(args.filenames), batch_size=args.batch_size, num_epochs=args.num_epochs if args.train else 1, shuffle=True if args.train else False, ), fake_input_fn=lambda: (tf.random_normal([args.batch_size, 100])), hyper_params=Struct( generator_learning_rate=2e-4, generator_beta1=0.5, generator_beta2=0.999, discriminator_learning_rate=2e-4, discriminator_beta1=0.5, discriminator_beta2=0.999, mode_seeking_loss_weight=0.1, )) config = tf.ConfigProto(gpu_options=tf.GPUOptions( visible_device_list=args.gpu, allow_growth=True)) if args.train: gan.train(model_dir=args.model_dir,