def main():
    args = parser.parse_args()
    print("================================")
    print(args)
    print("================================")

    os.makedirs('out', exist_ok=True)

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': args.num_workers, 'pin_memory': True} if use_cuda else {}

    torch.manual_seed(args.seed)

    transform = transforms.Compose([transforms.ToTensor()])
    train_set = CelebA('./data/celebA', transform=transform)
    # Inversion attack on TRAIN data of facescrub classifier
    test1_set = FaceScrub('./data/facescrub', transform=transform, train=True)
    # Inversion attack on TEST data of facescrub classifier
    test2_set = FaceScrub('./data/facescrub', transform=transform, train=False)

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
    test1_loader = torch.utils.data.DataLoader(test1_set, batch_size=args.test_batch_size, shuffle=False, **kwargs)
    test2_loader = torch.utils.data.DataLoader(test2_set, batch_size=args.test_batch_size, shuffle=False, **kwargs)

    classifier = nn.DataParallel(Classifier(nc=args.nc, ndf=args.ndf, nz=args.nz)).to(device)
    inversion = nn.DataParallel(Inversion(nc=args.nc, ngf=args.ngf, nz=args.nz,
                                          truncation=args.truncation, c=args.c)).to(device)
    optimizer = optim.Adam(inversion.parameters(), lr=0.0002, betas=(0.5, 0.999), amsgrad=True)

    # Load classifier; abort if the checkpoint cannot be restored.
    path = 'out/classifier.pth'
    try:
        checkpoint = torch.load(path)
        classifier.load_state_dict(checkpoint['model'])
        epoch = checkpoint['epoch']
        best_cl_acc = checkpoint['best_cl_acc']
        print("=> loaded classifier checkpoint '{}' (epoch {}, acc {:.4f})".format(path, epoch, best_cl_acc))
    except Exception:
        print("=> load classifier checkpoint '{}' failed".format(path))
        return

    # Train inversion model
    best_recon_loss = 999
    for epoch in range(1, args.epochs + 1):
        train(classifier, inversion, args.log_interval, device, train_loader, optimizer, epoch)
        recon_loss = test(classifier, inversion, device, test1_loader, epoch, 'test1')
        test(classifier, inversion, device, test2_loader, epoch, 'test2')

        if recon_loss < best_recon_loss:
            best_recon_loss = recon_loss
            state = {
                'epoch': epoch,
                'model': inversion.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_recon_loss': best_recon_loss,
            }
            torch.save(state, 'out/inversion.pth')
            shutil.copyfile('out/recon_test1_{}.png'.format(epoch), 'out/best_test1.png')
            shutil.copyfile('out/recon_test2_{}.png'.format(epoch), 'out/best_test2.png')
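# Sketch (not from the source): reloading the best inversion checkpoint written by main()
# above. Assumes the same args (nc/ngf/nz/truncation/c) and device used during training.
def load_best_inversion(args, device):
    inversion = nn.DataParallel(Inversion(nc=args.nc, ngf=args.ngf, nz=args.nz,
                                          truncation=args.truncation, c=args.c)).to(device)
    checkpoint = torch.load('out/inversion.pth', map_location=device)
    inversion.load_state_dict(checkpoint['model'])
    inversion.eval()
    print("=> loaded inversion checkpoint (epoch {}, recon loss {:.4f})".format(
        checkpoint['epoch'], checkpoint['best_recon_loss']))
    return inversion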
def get_loader_labeled(image_dir, attr_path, selected_attrs, batch_size, train,
                       new_size=None, height=256, width=256, crop=True):
    """Build and return a data loader."""
    transform_list = [transforms.ToTensor(),
                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    transform_list = [transforms.RandomCrop((height, width))] + transform_list if crop else transform_list
    transform_list = [transforms.Resize(new_size)] + transform_list if new_size is not None else transform_list
    transform_list = [transforms.RandomHorizontalFlip()] + transform_list if train else transform_list
    transform_list = [transforms.ColorJitter(0.1, 0.1, 0.1, 0.1)] + transform_list if train else transform_list
    transform = transforms.Compose(transform_list)

    dataset = CelebA(image_dir, attr_path, selected_attrs, transform, train)
    data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=train)
    return data_loader
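# Sketch (paths and attribute names below are placeholders, not values from the source):
# how get_loader_labeled might be called for training and evaluation.
train_loader = get_loader_labeled('data/celeba/images', 'data/celeba/list_attr_celeba.txt',
                                  ['Male', 'Young'], batch_size=16, train=True,
                                  new_size=143, height=128, width=128, crop=True)
test_loader = get_loader_labeled('data/celeba/images', 'data/celeba/list_attr_celeba.txt',
                                 ['Male', 'Young'], batch_size=16, train=False,
                                 new_size=128, height=128, width=128, crop=False)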
def main():
    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    logger = TensorBoardOutputFormat(logdir)
    config = tf.ConfigProto()
    sess = tf.Session(config=config)
    LABEL = None
    print("Loading data...")

    if FLAGS.dataset == 'cubes':
        dataset = Cubes(cond_idx=FLAGS.cond_idx)
        test_dataset = dataset

        if FLAGS.cond_idx == 0:
            label_size = 2
        elif FLAGS.cond_idx == 1:
            label_size = 1
        elif FLAGS.cond_idx == 2:
            label_size = 3
        elif FLAGS.cond_idx == 3:
            label_size = 20

        LABEL = tf.placeholder(shape=(None, label_size), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, label_size), dtype=tf.float32)
    elif FLAGS.dataset == 'color':
        dataset = CubesColor()
        test_dataset = dataset
        LABEL = tf.placeholder(shape=(None, 301), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 301), dtype=tf.float32)
        label_size = 301
    elif FLAGS.dataset == 'pos':
        dataset = CubesPos()
        test_dataset = dataset
        LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        label_size = 2
    elif FLAGS.dataset == "pairs":
        dataset = Pairs(cond_idx=0)
        test_dataset = dataset
        LABEL = tf.placeholder(shape=(None, 6), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 6), dtype=tf.float32)
        label_size = 6
    elif FLAGS.dataset == "continual":
        dataset = CubesContinual()
        test_dataset = dataset

        if FLAGS.prelearn_model_shape:
            LABEL = tf.placeholder(shape=(None, 20), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 20), dtype=tf.float32)
            label_size = 20
        else:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            label_size = 2
    elif FLAGS.dataset == "cross":
        dataset = CubesCrossProduct(FLAGS.ratio, cond_size=FLAGS.cond_size,
                                    cond_pos=FLAGS.cond_pos, joint_baseline=FLAGS.joint_baseline)
        test_dataset = dataset

        if FLAGS.cond_size:
            LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            label_size = 1
        elif FLAGS.cond_pos:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            label_size = 2

        if FLAGS.joint_baseline:
            LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            label_size = 3
    elif FLAGS.dataset == 'celeba':
        dataset = CelebA(cond_idx=FLAGS.celeba_cond_idx)
        test_dataset = dataset
        channel_num = 3

        X_NOISE = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)

        model = ResNet128(num_channels=channel_num, num_filters=64, classes=2)

    if FLAGS.joint_baseline:
        # Other stuff for joint model
        optimizer = AdamOptimizer(FLAGS.lr, beta1=0.99, beta2=0.999)

        X = tf.placeholder(shape=(None, 64, 64, 3), dtype=tf.float32)
        X_NOISE = tf.placeholder(shape=(None, 64, 64, 3), dtype=tf.float32)
        ATTENTION_MASK = tf.placeholder(shape=(None, 64, 64, FLAGS.cond_func), dtype=tf.float32)
        NOISE = tf.placeholder(shape=(None, 128), dtype=tf.float32)
        HIER_LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)

        channel_num = 3
        model = CubesNetGen(num_channels=channel_num, label_size=label_size)
        weights = model.construct_weights('context_0')
        output = model.forward(NOISE, weights, reuse=False, label=LABEL)
        print(output.get_shape())

        mse_loss = tf.reduce_mean(tf.square(output - X))
        gvs = optimizer.compute_gradients(mse_loss)
        train_op = optimizer.apply_gradients(gvs)
        gvs = [(k, v) for (k, v) in gvs if k is not None]

        target_vars = {}
        target_vars['train_op'] = train_op
        target_vars['X'] = X
        target_vars['X_NOISE'] = X_NOISE
        target_vars['ATTENTION_MASK'] = ATTENTION_MASK
        target_vars['eps_begin'] = tf.zeros(1)
        target_vars['gvs'] = gvs
        target_vars['energy_pos'] = tf.zeros(1)
        target_vars['energy_neg'] = tf.zeros(1)
        target_vars['loss_energy'] = tf.zeros(1)
        target_vars['loss_ml'] = tf.zeros(1)
        target_vars['total_loss'] = mse_loss
        target_vars['attention_mask'] = tf.zeros(1)
        target_vars['attention_grad'] = tf.zeros(1)
        target_vars['x_off'] = tf.reduce_mean(tf.abs(output - X))
        target_vars['x_mod'] = tf.zeros(1)
        target_vars['x_grad'] = tf.zeros(1)
        target_vars['NOISE'] = NOISE
        target_vars['LABEL'] = LABEL
        target_vars['LABEL_POS'] = LABEL_POS
        target_vars['HIER_LABEL'] = HIER_LABEL

        data_loader = DataLoader(dataset, batch_size=FLAGS.batch_size,
                                 num_workers=FLAGS.data_workers, drop_last=True, shuffle=True)
    else:
        print("label size here ", label_size)
        channel_num = 3

        X_NOISE = tf.placeholder(shape=(None, 64, 64, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 64, 64, 3), dtype=tf.float32)
        HEIR_LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        ATTENTION_MASK = tf.placeholder(shape=(None, 64, 64, FLAGS.cond_func), dtype=tf.float32)

        if FLAGS.dataset != "celeba":
            model = CubesNet(num_channels=channel_num, label_size=label_size)
        heir_model = HeirNet(num_channels=FLAGS.cond_func)

        models_pretrain = []
        if FLAGS.prelearn_model:
            model_prelearn = CubesNet(num_channels=channel_num, label_size=FLAGS.prelearn_label)
            weights = model_prelearn.construct_weights('context_1')
            LABEL_PRELEARN = tf.placeholder(shape=(None, FLAGS.prelearn_label), dtype=tf.float32)
            models_pretrain.append((model_prelearn, weights, LABEL_PRELEARN))

            cubes_logdir = osp.join(FLAGS.logdir, FLAGS.prelearn_exp)
            if (FLAGS.prelearn_iter != -1 or not FLAGS.train):
                model_file = osp.join(cubes_logdir, 'model_{}'.format(FLAGS.prelearn_iter))
                resume_itr = FLAGS.resume_iter
                # saver.restore(sess, model_file)

                v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                           scope='context_{}'.format(1))
                v_map = {(v.name.replace('context_{}'.format(1), 'context_0')[:-2]): v
                         for v in v_list}
                saver = tf.train.Saver(v_map)
                saver.restore(sess, model_file)

        if FLAGS.prelearn_model_shape:
            model_prelearn = CubesNet(num_channels=channel_num, label_size=FLAGS.prelearn_label_shape)
            weights = model_prelearn.construct_weights('context_2')
            LABEL_PRELEARN = tf.placeholder(shape=(None, FLAGS.prelearn_label_shape), dtype=tf.float32)
            models_pretrain.append((model_prelearn, weights, LABEL_PRELEARN))

            cubes_logdir = osp.join(FLAGS.logdir, FLAGS.prelearn_exp_shape)
            if (FLAGS.prelearn_iter_shape != -1 or not FLAGS.train):
                model_file = osp.join(cubes_logdir, 'model_{}'.format(FLAGS.prelearn_iter_shape))
                resume_itr = FLAGS.resume_iter
                # saver.restore(sess, model_file)

                v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                           scope='context_{}'.format(2))
                v_map = {(v.name.replace('context_{}'.format(2), 'context_0')[:-2]): v
                         for v in v_list}
                saver = tf.train.Saver(v_map)
                saver.restore(sess, model_file)

        print("Done loading...")

        data_loader = DataLoader(dataset, batch_size=FLAGS.batch_size,
                                 num_workers=FLAGS.data_workers, drop_last=True, shuffle=True)

        batch_size = FLAGS.batch_size

        weights = model.construct_weights('context_0')

        if FLAGS.heir_mask:
            weights = heir_model.construct_weights('heir_0', weights=weights)

        Y = tf.placeholder(shape=(None), dtype=tf.int32)

        # Variables to run in training
        X_SPLIT = tf.split(X, FLAGS.num_gpus)
        X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus)
        LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus)
        LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus)
        LABEL_SPLIT_INIT = list(LABEL_SPLIT)

        attention_mask = ATTENTION_MASK
        tower_grads = []
        tower_gen_grads = []
        x_mod_list = []

        optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.99)

        for j in range(FLAGS.num_gpus):
            x_mod = X_SPLIT[j]

            if FLAGS.comb_mask:
                steps = tf.constant(0)
                c = lambda i, x: tf.less(i, FLAGS.num_steps)

                def langevin_attention_step(counter, attention_mask):
                    attention_mask = attention_mask + tf.random_normal(
                        tf.shape(attention_mask), mean=0.0, stddev=0.01)
                    energy_noise = energy_start = model.forward(
                        x_mod,
                        weights,
                        attention_mask,
                        label=LABEL_SPLIT[j],
                        reuse=True,
                        stop_at_grad=False,
                        stop_batch=True)

                    if FLAGS.heir_mask:
                        energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                        energy_noise = energy_noise + energy_heir

                    attention_grad = tf.gradients(
                        FLAGS.temperature * energy_noise, [attention_mask])[0]
                    energy_noise_old = energy_noise

                    # Clip gradient norm for now
                    attention_mask = attention_mask - (FLAGS.attention_lr) * attention_grad
                    attention_mask = tf.layers.average_pooling2d(attention_mask, (3, 3), 1, padding='SAME')
                    attention_mask = tf.stop_gradient(attention_mask)

                    counter = counter + 1

                    return counter, attention_mask

                steps, attention_mask = tf.while_loop(c, langevin_attention_step, (steps, attention_mask))

                # attention_mask = tf.Print(attention_mask, [attention_mask])

                energy_pos = model.forward(
                    X_SPLIT[j],
                    weights,
                    tf.stop_gradient(attention_mask),
                    label=LABEL_POS_SPLIT[j],
                    stop_at_grad=False)

                if FLAGS.heir_mask:
                    energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                    energy_pos = energy_heir + energy_pos
            else:
                energy_pos = model.forward(
                    X_SPLIT[j],
                    weights,
                    attention_mask,
                    label=LABEL_POS_SPLIT[j],
                    stop_at_grad=False)

                if FLAGS.heir_mask:
                    energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                    energy_pos = energy_heir + energy_pos

            print("Building graph...")
            x_mod = x_orig = X_NOISE_SPLIT[j]

            x_grads = []
            loss_energys = []

            eps_begin = tf.zeros(1)

            steps = tf.constant(0)
            c_cond = lambda i, x, y: tf.less(i, FLAGS.num_steps)

            def langevin_step(counter, x_mod, attention_mask):
                lr = FLAGS.step_lr

                x_mod = x_mod + tf.random_normal(
                    tf.shape(x_mod), mean=0.0,
                    stddev=0.001 * FLAGS.rescale * FLAGS.noise_scale)
                attention_mask = attention_mask + tf.random_normal(
                    tf.shape(attention_mask), mean=0.0, stddev=0.01)

                energy_noise = model.forward(
                    x_mod,
                    weights,
                    attention_mask,
                    label=LABEL_SPLIT[j],
                    reuse=True,
                    stop_at_grad=False,
                    stop_batch=True)

                if FLAGS.prelearn_model:
                    for m_i, w_i, l_i in models_pretrain:
                        energy_noise = energy_noise + m_i.forward(
                            x_mod,
                            w_i,
                            attention_mask,
                            label=l_i,
                            reuse=True,
                            stop_at_grad=False,
                            stop_batch=True)

                if FLAGS.heir_mask:
                    energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                    energy_noise = energy_heir + energy_noise

                x_grad, attention_grad = tf.gradients(
                    FLAGS.temperature * energy_noise, [x_mod, attention_mask])

                if not FLAGS.comb_mask:
                    attention_grad = tf.zeros(1)

                energy_noise_old = energy_noise

                if FLAGS.proj_norm != 0.0:
                    if FLAGS.proj_norm_type == 'l2':
                        x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm)
                    elif FLAGS.proj_norm_type == 'li':
                        x_grad = tf.clip_by_value(
                            x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
                    else:
                        print("Other types of projection are not supported!!!")
                        assert False

                # Clip gradient norm for now
                x_last = x_mod - (lr) * x_grad

                if FLAGS.comb_mask:
                    attention_mask = attention_mask - FLAGS.attention_lr * attention_grad
                    attention_mask = tf.layers.average_pooling2d(attention_mask, (3, 3), 1, padding='SAME')
                    attention_mask = tf.stop_gradient(attention_mask)

                x_mod = x_last
                x_mod = tf.clip_by_value(x_mod, 0, FLAGS.rescale)

                counter = counter + 1

                return counter, x_mod, attention_mask

            steps, x_mod, attention_mask = tf.while_loop(c_cond, langevin_step, (steps, x_mod, attention_mask))

            attention_mask = tf.stop_gradient(attention_mask)
            # attention_mask = tf.Print(attention_mask, [attention_mask])

            energy_eval = model.forward(x_mod, weights, attention_mask,
                                        label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True)
            x_grad, attention_grad = tf.gradients(FLAGS.temperature * energy_eval, [x_mod, attention_mask])
            x_grads.append(x_grad)

            energy_neg = model.forward(
                tf.stop_gradient(x_mod),
                weights,
                tf.stop_gradient(attention_mask),
                label=LABEL_SPLIT[j],
                stop_at_grad=False,
                reuse=True)

            if FLAGS.heir_mask:
                energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                energy_neg = energy_heir + energy_neg

            temp = FLAGS.temperature

            x_off = tf.reduce_mean(
                tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j]))

            loss_energy = model.forward(
                x_mod,
                weights,
                attention_mask,
                reuse=True,
                label=LABEL,
                stop_grad=True)

            print("Finished processing loop construction ...")

            target_vars = {}

            if FLAGS.antialias:
                antialias = tf.tile(stride_3, (1, 1, tf.shape(x_mod)[3], tf.shape(x_mod)[3]))
                inp = tf.nn.conv2d(x_mod, antialias, [1, 2, 2, 1], padding='SAME')

            test_x_mod = x_mod

            if FLAGS.cclass or FLAGS.model_cclass:
                label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0)
                label_prob = label_sum / tf.reduce_sum(label_sum)
                label_ent = -tf.reduce_sum(label_prob * tf.math.log(label_prob + 1e-7))
            else:
                label_ent = tf.zeros(1)

            target_vars['label_ent'] = label_ent

            if FLAGS.train:
                if FLAGS.objective == 'logsumexp':
                    pos_term = temp * energy_pos
                    energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
                    coeff = tf.stop_gradient(tf.exp(-temp * energy_neg_reduced))
                    norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
                    pos_loss = tf.reduce_mean(temp * energy_pos)
                    neg_loss = coeff * (-1 * temp * energy_neg) / norm_constant
                    loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
                elif FLAGS.objective == 'cd':
                    pos_loss = tf.reduce_mean(temp * energy_pos)
                    neg_loss = -tf.reduce_mean(temp * energy_neg)
                    loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
                elif FLAGS.objective == 'softplus':
                    loss_ml = FLAGS.ml_coeff * \
                        tf.nn.softplus(temp * (energy_pos - energy_neg))

                loss_total = tf.reduce_mean(loss_ml)

                if not FLAGS.zero_kl:
                    loss_total = loss_total + tf.reduce_mean(loss_energy)

                loss_total = loss_total + \
                    FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) +
                                      tf.reduce_mean(tf.square((energy_neg))))

                print("Started gradient computation...")
                gvs = optimizer.compute_gradients(loss_total)
                gvs = [(k, v) for (k, v) in gvs if k is not None]

                print("Applying gradients...")

                tower_grads.append(gvs)

                print("Finished applying gradients.")

                target_vars['loss_ml'] = loss_ml
                target_vars['total_loss'] = loss_total
                target_vars['loss_energy'] = loss_energy
                target_vars['weights'] = weights
                target_vars['gvs'] = gvs

            target_vars['X'] = X
            target_vars['Y'] = Y
            target_vars['LABEL'] = LABEL
            target_vars['HIER_LABEL'] = HEIR_LABEL
            target_vars['LABEL_POS'] = LABEL_POS
            target_vars['X_NOISE'] = X_NOISE
            target_vars['energy_pos'] = energy_pos
            target_vars['attention_grad'] = attention_grad

            if len(x_grads) >= 1:
                target_vars['x_grad'] = x_grads[-1]
                target_vars['x_grad_first'] = x_grads[0]
            else:
                target_vars['x_grad'] = tf.zeros(1)
                target_vars['x_grad_first'] = tf.zeros(1)

            target_vars['x_mod'] = x_mod
            target_vars['x_off'] = x_off
            target_vars['temp'] = temp
            target_vars['energy_neg'] = energy_neg
            target_vars['test_x_mod'] = test_x_mod
            target_vars['eps_begin'] = eps_begin
            target_vars['ATTENTION_MASK'] = ATTENTION_MASK
            target_vars['models_pretrain'] = models_pretrain

            if FLAGS.comb_mask:
                target_vars['attention_mask'] = tf.nn.softmax(attention_mask)
            else:
                target_vars['attention_mask'] = tf.zeros(1)

        if FLAGS.train:
            grads = average_gradients(tower_grads)
            train_op = optimizer.apply_gradients(grads)
            target_vars['train_op'] = train_op

    # sess = tf.Session(config=config)

    saver = loader = tf.train.Saver(
        max_to_keep=30, keep_checkpoint_every_n_hours=6)

    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Model has a total of {} parameters".format(total_parameters))

    sess.run(tf.global_variables_initializer())

    resume_itr = 0

    if (FLAGS.resume_iter != -1 or not FLAGS.train):
        model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
        resume_itr = FLAGS.resume_iter
        # saver.restore(sess, model_file)
        optimistic_restore(sess, model_file)

    print("Initializing variables...")

    print("Start broadcast")
    print("End broadcast")

    if FLAGS.train:
        train(target_vars, saver, sess, logger, data_loader, resume_itr, logdir)

    test(target_vars, saver, sess, logger, data_loader)
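# Note (illustration, not from the source): the 'cd' objective built above reduces to
# "mean positive energy minus mean negative energy", scaled by FLAGS.ml_coeff. A toy
# NumPy version with made-up energies, to show the sign convention:
import numpy as np

energy_pos = np.array([1.2, 0.8, 1.0])   # energies of real (positive) samples
energy_neg = np.array([2.5, 2.1, 1.9])   # energies of Langevin (negative) samples
temp, ml_coeff = 1.0, 1.0

pos_loss = np.mean(temp * energy_pos)
neg_loss = -np.mean(temp * energy_neg)
loss_ml = ml_coeff * (pos_loss + neg_loss)
print(loss_ml)  # approx -1.17: minimizing it pushes real energies down and sampled energies up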
args.custom_attr = args_.custom_attr
args.n_attrs = len(args.attrs)
args.betas = (args.beta1, args.beta2)

print(args)

if args.custom_img:
    output_path = join('output', args.experiment_name, 'custom_testing')
    from data import Custom
    test_dataset = Custom(args.custom_data, args.custom_attr, args.img_size, 'test', args.attrs)
else:
    output_path = join('output', args.experiment_name, 'sample_testing')
    if args.data == 'CelebA':
        from data import CelebA
        test_dataset = CelebA(args.data_path, args.attr_path, args.img_size, 'test', args.attrs)
    if args.data == 'CelebA-HQ':
        from data import CelebA_HQ
        test_dataset = CelebA_HQ(args.data_path, args.attr_path, args.image_list_path,
                                 args.img_size, 'test', args.attrs)
os.makedirs(output_path, exist_ok=True)
test_dataloader = data.DataLoader(
    test_dataset, batch_size=1, num_workers=args.num_workers,
    shuffle=False, drop_last=False
)
if args.num_test is None:
    print('Testing images:', len(test_dataset))
else:
    print('Testing images:', min(len(test_dataset), args.num_test))

attgan = AttGAN(args)
# Paths
checkpoint_path = join('results', args.experiment_name, 'checkpoint')
sample_path = join('results', args.experiment_name, 'sample')
summary_path = join('results', args.experiment_name, 'summary')
os.makedirs(checkpoint_path, exist_ok=True)
os.makedirs(sample_path, exist_ok=True)
os.makedirs(summary_path, exist_ok=True)
with open(join('results', args.experiment_name, 'setting.json'), 'w', encoding='utf-8') as f:
    json.dump(vars(args), f, indent=2, sort_keys=True)
writer = SummaryWriter(summary_path)

# Data
selected_attrs = [args.target_attr]
train_dset = CelebA(args.data_path, args.attr_path, args.image_size, 'train', selected_attrs)
train_data = data.DataLoader(train_dset, args.batch_size, shuffle=True, drop_last=True)
train_data = loop(train_data)
test_dset = CelebA(args.data_path, args.attr_path, args.image_size, 'test', selected_attrs)
test_data = data.DataLoader(test_dset, args.num_samples)
for fixed_reals, fixed_labels in test_data:
    # Get the first batch of images from the testing set
    fixed_reals, fixed_labels = fixed_reals.to(device), fixed_labels.type_as(fixed_reals).to(device)
    fixed_target_labels = 1 - fixed_labels
    break
del test_dset
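# Sketch (assumption, not from the source): a minimal loop() helper consistent with the
# usage above -- it wraps a finite DataLoader into an endless iterator of batches.
def loop(data_loader):
    while True:
        for batch in data_loader:
            yield batch

# With train_data = loop(train_data), the training loop can call next(train_data)
# indefinitely; the underlying loader reshuffles each time it is exhausted.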
# TODO: support conditional inputs
args = parser.parse_args()
opts = {k: v for k, v in args._get_kwargs()}

latent_size = 512
sigmoid_at_end = args.gan in ['lsgan', 'gan']
if hasattr(args, 'no_tanh'):
    tanh_at_end = False
else:
    tanh_at_end = True

G = Generator(num_channels=3, latent_size=latent_size, resolution=args.target_resol,
              fmap_max=latent_size, fmap_base=8192, tanh_at_end=tanh_at_end)
D = Discriminator(num_channels=3, resolution=args.target_resol,
                  fmap_max=latent_size, fmap_base=8192, sigmoid_at_end=sigmoid_at_end)
print(G)
print(D)

data = CelebA()
noise = RandomNoiseGenerator(latent_size, 'gaussian')
pggan = PGGAN(G, D, data, noise, opts)
pggan.train()
print(args)

args.lr_base = args.lr
args.n_attrs = len(args.attrs)
args.betas = (args.beta1, args.beta2)

os.makedirs(join('output', args.experiment_name), exist_ok=True)
os.makedirs(join('output', args.experiment_name, 'checkpoint'), exist_ok=True)
os.makedirs(join('output', args.experiment_name, 'sample_training'), exist_ok=True)
with open(join('output', args.experiment_name, 'setting.txt'), 'w') as f:
    f.write(json.dumps(vars(args), indent=4, separators=(',', ':')))

if args.data == 'CelebA':
    from data import CelebA
    train_dataset = CelebA(args.data_path, args.attr_path, args.img_size, 'train', args.attrs)
    valid_dataset = CelebA(args.data_path, args.attr_path, args.img_size, 'valid', args.attrs)
if args.data == 'CelebA-HQ':
    from data import CelebA_HQ
    train_dataset = CelebA_HQ(args.data_path, args.attr_path, args.image_list_path,
                              args.img_size, 'train', args.attrs)
    valid_dataset = CelebA_HQ(args.data_path, args.attr_path, args.image_list_path,
                              args.img_size, 'valid', args.attrs)
train_dataloader = data.DataLoader(
    train_dataset, batch_size=args.batch_size, num_workers=args.num_workers,
    shuffle=True, drop_last=True
)
test_annos[i, index] = value

tf = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
test_dataset = datasets.ImageFolder(root='test_imgs', transform=tf)
test_dataloader = data.DataLoader(test_dataset, batch_size=batch_size)

if dataset == 'celeba':
    from data import CelebA, PairedData
    train_dset = CelebA(
        data_path, image_size,
        selected_attr=selected_attributes,
        mode='train', test_num=2000
    )
    train_data = PairedData(train_dset, batch_size)
    valid_dset = CelebA(
        data_path, image_size,
        selected_attr=selected_attributes,
        mode='val', test_num=2000
    )
    valid_data = PairedData(valid_dset, batch_size)
if dataset == 'celeba-hq':
    from data import CelebAHQ, PairedData
    train_dset = CelebAHQ(
        data_path, image_size,
        selected_attr=selected_attributes,
        mode='train', test_num=2000
    )
    train_data = PairedData(train_dset, batch_size)
    valid_dset = CelebAHQ(
        data_path, image_size,
        selected_attr=selected_attributes,
        mode='val', test_num=2000
    )
attrs_default = [
    'Bald', 'Bangs', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Bushy_Eyebrows',
    'Eyeglasses', 'Male', 'Mouth_Slightly_Open', 'Mustache', 'No_Beard', 'Pale_Skin', 'Young'
]
attrs_list = attrs_default.copy()

batch_size = 32
n_samples = 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

PLL = PerceptualLoss()
model_e_d = LSA_VAE(8, len(attrs_list))

from data import CelebA
train_dataset = CelebA(data_path, attr_path, img_size, 'train', attrs_list)
valid_dataset = CelebA(data_path, attr_path, img_size, 'valid', attrs_list)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers,
                              shuffle=True, drop_last=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=n_samples, num_workers=num_workers,
                              shuffle=False, drop_last=False)

# SGD_optimizer = optim.SGD(model_e_d.parameters(), lr=0.0005, momentum=0.5, weight_decay=1e-4)
ADAM_optimizer = optim.Adam(model_e_d.parameters(), lr=0.0005,
# Logger
base_dir = 'results/gamma={:.1f}_run={:d}'.format(args.gamma, args.run)
writer = tb.FileWriter(os.path.join(base_dir, 'log.csv'), args=args, overwrite=args.run >= 999)
writer.add_var('d_real', '{:8.4f}', d_real_loss)
writer.add_var('d_fake', '{:8.4f}', d_fake_loss)
writer.add_var('k', '{:8.4f}', k * 1)
writer.add_var('M', '{:8.4f}', m_global)
writer.add_var('lr', '{:8.6f}', lr * 1)
writer.add_var('iter', '{:>8d}')
writer.initialize()

sess = tf.Session()
load_model(sess)
f_gen = tb.function(sess, [z], x_fake)
f_rec = tb.function(sess, [x_real], d_real)

celeba = CelebA(args.data)

# Alternatively try grouping d_train/g_train together
all_tensors = [d_train, g_train, d_real_loss, d_fake_loss]
# d_tensors = [d_train, d_real_loss]
# g_tensors = [g_train, d_fake_loss]

for i in xrange(args.max_iter):
    x = celeba.next_batch(args.bs)
    z = np.random.uniform(-1, 1, (args.bs, args.e_size))
    feed_dict = {'x:0': x, 'z:0': z, 'k:0': args.k, 'lr:0': args.lr, 'g:0': args.gamma}

    _, _, d_real_loss, d_fake_loss = sess.run(all_tensors, feed_dict)
    # _, d_real_loss = sess.run(d_tensors, feed_dict)
    # _, d_fake_loss = sess.run(g_tensors, feed_dict)

    args.k = np.clip(args.k + args.lambd * (args.gamma * d_real_loss - d_fake_loss), 0., 1.)
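# Note (illustration, not from the source): the last line above is the standard BEGAN
# balance update, k_{t+1} = clip(k_t + lambda_k * (gamma * L(x) - L(G(z))), 0, 1),
# which steers the real/fake reconstruction losses toward the target ratio gamma.
# A toy step with made-up numbers:
import numpy as np

k, lambd, gamma = 0.0, 0.001, 0.5
d_real_loss, d_fake_loss = 0.20, 0.05
k = np.clip(k + lambd * (gamma * d_real_loss - d_fake_loss), 0., 1.)
print(k)  # 5e-05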
from networks import create_Generator
from data import CelebA
from model import InpaintingModel
import json, time, os
from config import config

G = create_Generator(config)
celeba = CelebA(os.path.expanduser('~/datasets/celeba-hq-1024x1024.h5'),
                os.path.expanduser('~/datasets/holes_hq.hdf5'))

config['model_dir'] = config['model_dir'].replace(
    '<time>', time.strftime("%Y-%b-%d %H_%M_%S"))
os.makedirs(config['model_dir'])
with open(os.path.join(config['model_dir'], 'config.json'), 'w') as f:
    json.dump(config, f)

model = InpaintingModel(G, celeba, config)
model.run()