def get_GAN_AB_model(folder_model, model_name, device, n_residual_blocks=9):
    """Build the G_AB generator and load trained weights into it.

    Args:
        folder_model: path prefix of the checkpoint. It is concatenated as-is
            with ``model_name``, so it should end with a path separator.
        model_name: checkpoint file name.
        device: device the weights are mapped to on load and the model is
            moved to.
        n_residual_blocks: number of residual blocks. Must match the value used
            when training the G_AB checkpoint (default 9, as before).

    Returns:
        The ``GeneratorResNet`` with the loaded state dict, placed on ``device``.
    """
    # NOTE(review): input_shape=(3, 0) looks like only the channel count matters
    # to GeneratorResNet here; confirm against its definition.
    G_AB = GeneratorResNet(input_shape=(3, 0), num_residual_blocks=n_residual_blocks)
    G_AB.load_state_dict(torch.load(folder_model + model_name, map_location=device))
    # Always place the model on the requested device (a no-op if already there)
    # instead of depending on a module-level ``cuda`` flag.
    return G_AB.to(device)
def get_generator_model():
    """Return a ``GeneratorResNet`` restored from ``PATH_G``, in eval mode.

    Construction settings come from the module-level ``img_shape``,
    ``residual_blocks`` and ``c_dim``. The checkpoint is always mapped onto
    the CPU when it is read.
    """
    model = GeneratorResNet(
        img_shape=img_shape,
        res_blocks=residual_blocks,
        c_dim=c_dim,
    )
    cpu_device = torch.device('cpu')
    weights = torch.load(PATH_G, map_location=cpu_device)
    model.load_state_dict(weights)
    model.eval()
    return model
def _prepare_generating(self):
    """Prepare generating.

    Build the TensorFlow generation graph, open a session and restore the
    pretrained weights from the latest checkpoint in ``self.src_log``.

    The latent ``z`` is the concatenation of a style part and a one-hot char
    part. Each part is a per-row linear interpolation between an "x" and a
    "y" input, weighted by the ``*_alpha`` placeholders.

    Raises:
        ValueError: if ``self.arch`` is neither ``'DCGAN'`` nor ``'ResNet'``.
        AssertionError: if no checkpoint is found in ``self.src_log``.
    """
    if self.arch not in ('DCGAN', 'ResNet'):
        raise ValueError('unsupported arch: {}'.format(self.arch))

    self.z_size = self.style_z_size + self.char_embedding_n

    if self.arch == 'DCGAN':
        generator = GeneratorDCGAN(img_size=(self.img_width, self.img_height),
                                   img_dim=self.img_dim,
                                   z_size=self.z_size,
                                   layer_n=4,
                                   k_size=3,
                                   smallest_hidden_unit_n=64,
                                   is_bn=False)
    else:  # 'ResNet'
        generator = GeneratorResNet(k_size=3, smallest_unit_n=64)

    if FLAGS.generate_walk:
        style_embedding_np = np.random.uniform(
            -1, 1, (FLAGS.char_img_n // self.walk_step, self.style_z_size)).astype(np.float32)
    else:
        style_embedding_np = np.random.uniform(
            -1, 1, (self.style_ids_n, self.style_z_size)).astype(np.float32)

    with tf.variable_scope('embeddings'):
        style_embedding = tf.Variable(style_embedding_np, name='style_embedding')

    self.style_ids_x = tf.placeholder(tf.int32, (self.batch_size, ), name='style_ids_x')
    self.style_ids_y = tf.placeholder(tf.int32, (self.batch_size, ), name='style_ids_y')
    self.style_ids_alpha = tf.placeholder(tf.float32, (self.batch_size, ), name='style_ids_alpha')
    self.char_ids_x = tf.placeholder(tf.int32, (self.batch_size, ), name='char_ids_x')
    self.char_ids_y = tf.placeholder(tf.int32, (self.batch_size, ), name='char_ids_y')
    self.char_ids_alpha = tf.placeholder(tf.float32, (self.batch_size, ), name='char_ids_alpha')

    def _style_z(style_ids):
        # If the style ids sum to a negative number (e.g. all -1), the style z
        # is sampled from U(-1, 1); otherwise it is looked up in the embedding.
        return tf.cond(
            tf.less(tf.reduce_sum(style_ids), 0),
            lambda: tf.random_uniform((self.batch_size, self.style_z_size), -1, 1),
            lambda: tf.nn.embedding_lookup(style_embedding, style_ids))

    def _lerp(a, b, alpha):
        # Row-wise (1 - alpha) * a + alpha * b, with alpha of shape (batch_size,).
        return a * tf.expand_dims(1. - alpha, 1) + b * tf.expand_dims(alpha, 1)

    style_z = _lerp(_style_z(self.style_ids_x), _style_z(self.style_ids_y),
                    self.style_ids_alpha)
    char_z = _lerp(tf.one_hot(self.char_ids_x, self.char_embedding_n),
                   tf.one_hot(self.char_ids_y, self.char_embedding_n),
                   self.char_ids_alpha)

    z = tf.concat([style_z, char_z], axis=1)
    self.generated_imgs = generator(z, is_train=False)

    if FLAGS.gpu_ids == "":
        sess_config = tf.ConfigProto(device_count={"GPU": 0}, log_device_placement=True)
    else:
        sess_config = tf.ConfigProto(gpu_options=tf.GPUOptions(
            visible_device_list=FLAGS.gpu_ids))
    self.sess = tf.Session(config=sess_config)
    self.sess.run(tf.global_variables_initializer())

    # In walk mode the style embedding has a different shape from the trained
    # one, so it keeps its random initialisation instead of being restored.
    if FLAGS.generate_walk:
        var_list = [var for var in tf.global_variables() if 'embedding' not in var.name]
    else:
        var_list = list(tf.global_variables())
    pretrained_saver = tf.train.Saver(var_list=var_list)
    checkpoint = tf.train.get_checkpoint_state(self.src_log)
    assert checkpoint, 'cannot get checkpoint: {}'.format(self.src_log)
    pretrained_saver.restore(self.sess, checkpoint.model_checkpoint_path)
def _prepare_training(self):
    """Prepare Training.

    Build the TensorFlow WGAN-GP training graph. To support multiple GPUs,
    the mini-batch is split evenly across ``FLAGS.gpu_ids`` and the gradients
    are averaged on the CPU. If a checkpoint exists in ``self.dst_log``, it is
    restored and training resumes from its step suffix.

    Raises:
        ValueError: if ``FLAGS.arch`` is neither ``'DCGAN'`` nor ``'ResNet'``.
        AssertionError: on the batch-size / empty-charset sanity checks.
    """
    assert FLAGS.batch_size >= FLAGS.style_ids_n, 'batch_size must be greater equal than style_ids_n'
    if FLAGS.arch not in ('DCGAN', 'ResNet'):
        raise ValueError('unsupported arch: {}'.format(FLAGS.arch))
    self.gpu_n = len(FLAGS.gpu_ids.split(','))
    self.embedding_chars = set_chars_type(FLAGS.chars_type)
    assert self.embedding_chars != [], 'embedding_chars is empty'
    self.char_embedding_n = len(self.embedding_chars)
    self.z_size = FLAGS.style_z_size + self.char_embedding_n

    with tf.device('/cpu:0'):
        # Set embeddings from uniform distribution
        style_embedding_np = np.random.uniform(
            -1, 1, (FLAGS.style_ids_n, FLAGS.style_z_size)).astype(np.float32)
        with tf.variable_scope('embeddings'):
            self.style_embedding = tf.Variable(style_embedding_np, name='style_embedding')

        self.style_ids = tf.placeholder(tf.int32, (FLAGS.batch_size, ), name='style_ids')
        self.char_ids = tf.placeholder(tf.int32, (FLAGS.batch_size, ), name='char_ids')
        self.is_train = tf.placeholder(tf.bool, name='is_train')
        self.real_imgs = tf.placeholder(tf.float32, (FLAGS.batch_size, FLAGS.img_width,
                                                     FLAGS.img_height, FLAGS.img_dim),
                                        name='real_imgs')
        self.labels = tf.placeholder(
            tf.float32, (FLAGS.batch_size, self.char_embedding_n), name='labels')

        d_opt = tf.train.AdamOptimizer(learning_rate=0.0001, beta1=0., beta2=0.9)
        g_opt = tf.train.AdamOptimizer(learning_rate=0.0001, beta1=0., beta2=0.9)

    # Per-GPU results, filled in by the tower loop below.
    fake_imgs = [0] * self.gpu_n
    d_loss = [0] * self.gpu_n
    g_loss = [0] * self.gpu_n
    d_grads = [0] * self.gpu_n
    g_grads = [0] * self.gpu_n

    # NOTE(review): if batch_size is not divisible by gpu_n, the trailing
    # batch_size % gpu_n samples of each batch are never used.
    divided_batch_size = FLAGS.batch_size // self.gpu_n
    is_not_first = False  # the first tower creates the variables, later ones reuse them

    # Build graph
    for i in range(self.gpu_n):
        batch_start = i * divided_batch_size
        batch_end = (i + 1) * divided_batch_size
        with tf.device('/gpu:{}'.format(i)):
            if FLAGS.arch == 'DCGAN':
                generator = GeneratorDCGAN(img_size=(FLAGS.img_width, FLAGS.img_height),
                                           img_dim=FLAGS.img_dim,
                                           z_size=self.z_size,
                                           layer_n=4,
                                           k_size=3,
                                           smallest_hidden_unit_n=64,
                                           is_bn=False)
                discriminator = DiscriminatorDCGAN(img_size=(FLAGS.img_width, FLAGS.img_height),
                                                   img_dim=FLAGS.img_dim,
                                                   layer_n=4,
                                                   k_size=3,
                                                   smallest_hidden_unit_n=64,
                                                   is_bn=False)
            else:  # 'ResNet'
                generator = GeneratorResNet(k_size=3, smallest_unit_n=64)
                discriminator = DiscriminatorResNet(k_size=3, smallest_unit_n=64)

            # If this tower's style ids sum to a negative number (e.g. all -1),
            # style z is sampled from U(-1, 1) instead of looked up.
            style_z = tf.cond(
                tf.less(tf.reduce_sum(self.style_ids[batch_start:batch_end]), 0),
                lambda: tf.random_uniform((divided_batch_size, FLAGS.style_z_size), -1, 1),
                lambda: tf.nn.embedding_lookup(self.style_embedding,
                                               self.style_ids[batch_start:batch_end]))
            char_z = tf.one_hot(self.char_ids[batch_start:batch_end], self.char_embedding_n)
            z = tf.concat([style_z, char_z], axis=1)

            # Generate fake images
            fake_imgs[i] = generator(z, is_reuse=is_not_first, is_train=self.is_train)

            # WGAN critic / generator losses
            real_slice = self.real_imgs[batch_start:batch_end]
            d_real = discriminator(real_slice, is_reuse=is_not_first, is_train=self.is_train)
            d_fake = discriminator(fake_imgs[i], is_reuse=True, is_train=self.is_train)
            d_loss[i] = -(tf.reduce_mean(d_real) - tf.reduce_mean(d_fake))
            g_loss[i] = -tf.reduce_mean(d_fake)

            # Gradient penalty on random real/fake interpolates. The gradient
            # norm is taken per sample over all image axes (1, 2, 3) of the
            # 4-D image tensor, not per pixel over the channel axis only.
            epsilon = tf.random_uniform((divided_batch_size, 1, 1, 1), minval=0., maxval=1.)
            interp = real_slice + epsilon * (fake_imgs[i] - real_slice)
            d_interp = discriminator(interp, is_reuse=True, is_train=self.is_train)
            grads = tf.gradients(d_interp, [interp])[0]
            slopes = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
            grad_penalty = tf.reduce_mean((slopes - 1.)**2)
            d_loss[i] += 10 * grad_penalty

            # Get trainable variables
            d_vars = [var for var in tf.trainable_variables() if 'discriminator' in var.name]
            g_vars = [var for var in tf.trainable_variables() if 'generator' in var.name]

            d_grads[i] = d_opt.compute_gradients(d_loss[i], var_list=d_vars)
            g_grads[i] = g_opt.compute_gradients(g_loss[i], var_list=g_vars)

        is_not_first = True

    with tf.device('/cpu:0'):
        self.fake_imgs = tf.concat(fake_imgs, axis=0)
        avg_d_grads = average_gradients(d_grads)
        avg_g_grads = average_gradients(g_grads)
        self.d_train = d_opt.apply_gradients(avg_d_grads)
        self.g_train = g_opt.apply_gradients(avg_g_grads)

    # Summaries log the negated tower-averaged losses: for d_loss that is
    # mean(D(real)) - mean(D(fake)) minus the penalty term; for g_loss it is
    # mean(D(fake)).
    tf.summary.scalar('d_loss', -(sum(d_loss) / len(d_loss)))
    tf.summary.scalar('g_loss', -(sum(g_loss) / len(g_loss)))
    self.summary = tf.summary.merge_all()

    # Setup session
    sess_config = tf.ConfigProto(gpu_options=tf.GPUOptions(
        visible_device_list=FLAGS.gpu_ids))
    self.sess = tf.Session(config=sess_config)
    self.saver = tf.train.Saver(max_to_keep=5)

    # If checkpoint is found, restart training from its step suffix
    checkpoint = tf.train.get_checkpoint_state(self.dst_log)
    if checkpoint:
        saver_resume = tf.train.Saver()
        saver_resume.restore(self.sess, checkpoint.model_checkpoint_path)
        self.epoch_start = int(checkpoint.model_checkpoint_path.split('-')[-1])
        print('restore ckpt')
    else:
        self.sess.run(tf.global_variables_initializer())
        self.epoch_start = 0

    # Setup writer for tensorboard
    self.writer = tf.summary.FileWriter(self.dst_log)
parser.add_argument( "--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation") opt = parser.parse_args() SCALE_FACTOR = opt.scale_factor MODEL_NAME = opt.model_name hr_shape = (opt.hr_height, opt.hr_width) results = {'Test': {'psnr': [], 'ssim': []}} device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") generator = GeneratorResNet() generator = nn.DataParallel(generator, device_ids=[0, 1, 2]) generator.to(device) # generator.load_state_dict(torch.load("saved_models/generator_%d_%d.pth" % (4,99))) generator.load_state_dict(torch.load("saved_models/" + MODEL_NAME)) generator.eval() test_dataloader = DataLoader( TestImageDataset("../My_dataset/single_channel_100000/%s" % opt.test_dataset_name, hr_shape=hr_shape, scale_factor=opt.scale_factor), # change batch_size=1, shuffle=False, num_workers=opt.n_cpu,
def print_network(model, name): """ Print out the network information https://github.com/yunjey/stargan/blob/master/solver.py """ num_params = 0 for p in model.parameters(): num_params += p.numel() print(model) print(name) print("The number of parameters: {}".format(num_params)) # Initialize generator and discriminator generator = GeneratorResNet(input_shape=img_shape, residual_blocks=opt.residual_blocks, c_dim=c_dim) discriminator = Discriminator(input_shape=img_shape, c_dim=c_dim) if cuda: generator = generator.cuda() discriminator = discriminator.cuda() criterion_cycle.cuda() if opt.is_print: print_network(generator, 'Generator') print_network(discriminator, 'Discriminator') # Tensor type Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor if opt.epoch != 0:
def __init__(self, opt):
    """Set up networks, losses, optimizers, schedulers, buffers and data loaders.

    Args:
        opt: parsed options. The fields read are model_name, gpu_id,
            n_residual_blocks, large_patch, lr, b1, b2, n_epochs, decay_epoch,
            lambda_id and rotate_degree.
    """
    self.config = opt

    # Make output dirs
    os.makedirs('saved_models/%s' % (opt.model_name), exist_ok=True)
    os.makedirs('images/%s' % (opt.model_name), exist_ok=True)

    self.cuda = opt.gpu_id > -1

    # Gs and Ds
    self.G_AB = GeneratorResNet(res_blocks=opt.n_residual_blocks)
    self.G_BA = GeneratorResNet(res_blocks=opt.n_residual_blocks)
    if opt.large_patch:
        self.D_A = LargePatchDiscriminator()
        self.D_B = LargePatchDiscriminator()
    else:
        self.D_A = Discriminator()
        self.D_B = Discriminator()

    # Shape of the per-sample discriminator target map
    if opt.large_patch:
        self.patch = (1, 64, 64)
    else:
        self.patch = (1, 16, 16)

    # Weight init
    self.G_AB.apply(weights_init_normal)
    self.G_BA.apply(weights_init_normal)
    self.D_A.apply(weights_init_normal)
    self.D_B.apply(weights_init_normal)

    # Loss
    self.criterion_GAN = torch.nn.MSELoss()
    self.criterion_cycle = torch.nn.L1Loss()
    self.criterion_identity = torch.nn.L1Loss()

    if self.cuda:
        self.G_AB = self.G_AB.cuda()
        self.G_BA = self.G_BA.cuda()
        self.D_A = self.D_A.cuda()
        self.D_B = self.D_B.cuda()
        self.criterion_GAN = self.criterion_GAN.cuda()
        self.criterion_cycle = self.criterion_cycle.cuda()
        self.criterion_identity = self.criterion_identity.cuda()

    # Optimizers. The single discriminator optimizer must own the parameters
    # of BOTH D_A and D_B; with only D_A's parameters, D_B is never updated.
    # NOTE(review): if the training loop calls optimizer_D.step() separately
    # for D_A and D_B, prefer one backward on the summed losses per step.
    self.optimizer_G = torch.optim.Adam(itertools.chain(
        self.G_AB.parameters(), self.G_BA.parameters()),
        lr=opt.lr, betas=(opt.b1, opt.b2))
    self.optimizer_D = torch.optim.Adam(itertools.chain(
        self.D_A.parameters(), self.D_B.parameters()),
        lr=opt.lr, betas=(opt.b1, opt.b2))

    # Learning rate update schedulers
    self.lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        self.optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, 0, opt.decay_epoch).step)
    self.lr_scheduler_D = torch.optim.lr_scheduler.LambdaLR(
        self.optimizer_D, lr_lambda=LambdaLR(opt.n_epochs, 0, opt.decay_epoch).step)

    self.Tensor = torch.cuda.FloatTensor if self.cuda else torch.Tensor

    # Loss weights (identity weight is relative to the cycle weight)
    self.lambda_cyc = 10
    self.lambda_id = opt.lambda_id * self.lambda_cyc

    # Buffers of previously generated samples
    self.fake_A_buffer = ReplayBuffer()
    self.fake_B_buffer = ReplayBuffer()

    # Image transformations
    A_transforms_ = [
        transforms.CenterCrop((178, 178)),
        transforms.Resize((300, 300)),
        transforms.RandomCrop((256, 256)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomAffine(degrees=opt.rotate_degree, fillcolor=(255, 255, 255)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    B_transforms_ = [
        transforms.Resize((360, 360)),
        transforms.RandomCrop((256, 256)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomAffine(degrees=opt.rotate_degree, fillcolor=(255, 255, 255)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]

    # Training data loader
    self.train_dataloader = DataLoader(
        ImageDataset("./data/", A_transforms_=A_transforms_, B_transforms_=B_transforms_),
        batch_size=1,
        shuffle=True,
    )
    # Test data loader (same, randomly augmenting, transforms as training)
    self.val_dataloader = DataLoader(
        ImageDataset("./data/", A_transforms_=A_transforms_, B_transforms_=B_transforms_,
                     mode='test'),
        batch_size=1)
class Trainer:
    """Train two generators (G_AB, G_BA) and two discriminators (D_A, D_B).

    The generators are trained with adversarial (MSE), cycle-consistency (L1)
    and identity (L1) losses. All settings come from the ``opt`` object passed
    to ``__init__``.
    """

    def __init__(self, opt):
        """Set up networks, losses, optimizers, schedulers, buffers and data loaders.

        Args:
            opt: parsed options. The fields read are model_name, gpu_id,
                n_residual_blocks, large_patch, lr, b1, b2, n_epochs,
                decay_epoch, lambda_id and rotate_degree; train_epoch also
                reads sample_interval and checkpoint_interval.
        """
        self.config = opt

        # Make output dirs
        os.makedirs('saved_models/%s' % (opt.model_name), exist_ok=True)
        os.makedirs('images/%s' % (opt.model_name), exist_ok=True)

        self.cuda = opt.gpu_id > -1

        # Gs and Ds
        self.G_AB = GeneratorResNet(res_blocks=opt.n_residual_blocks)
        self.G_BA = GeneratorResNet(res_blocks=opt.n_residual_blocks)
        if opt.large_patch:
            self.D_A = LargePatchDiscriminator()
            self.D_B = LargePatchDiscriminator()
        else:
            self.D_A = Discriminator()
            self.D_B = Discriminator()

        # Shape of the per-sample discriminator target map
        if opt.large_patch:
            self.patch = (1, 64, 64)
        else:
            self.patch = (1, 16, 16)

        # Weight init
        self.G_AB.apply(weights_init_normal)
        self.G_BA.apply(weights_init_normal)
        self.D_A.apply(weights_init_normal)
        self.D_B.apply(weights_init_normal)

        # Loss
        self.criterion_GAN = torch.nn.MSELoss()
        self.criterion_cycle = torch.nn.L1Loss()
        self.criterion_identity = torch.nn.L1Loss()

        if self.cuda:
            self.G_AB = self.G_AB.cuda()
            self.G_BA = self.G_BA.cuda()
            self.D_A = self.D_A.cuda()
            self.D_B = self.D_B.cuda()
            self.criterion_GAN = self.criterion_GAN.cuda()
            self.criterion_cycle = self.criterion_cycle.cuda()
            self.criterion_identity = self.criterion_identity.cuda()

        # Optimizers. optimizer_D is shared by both discriminators, so it must
        # own the parameters of D_A AND D_B (otherwise D_B is never trained).
        self.optimizer_G = torch.optim.Adam(itertools.chain(
            self.G_AB.parameters(), self.G_BA.parameters()),
            lr=opt.lr, betas=(opt.b1, opt.b2))
        self.optimizer_D = torch.optim.Adam(itertools.chain(
            self.D_A.parameters(), self.D_B.parameters()),
            lr=opt.lr, betas=(opt.b1, opt.b2))

        # Learning rate update schedulers
        self.lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, 0, opt.decay_epoch).step)
        self.lr_scheduler_D = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_D, lr_lambda=LambdaLR(opt.n_epochs, 0, opt.decay_epoch).step)

        self.Tensor = torch.cuda.FloatTensor if self.cuda else torch.Tensor

        # Loss weights (identity weight is relative to the cycle weight)
        self.lambda_cyc = 10
        self.lambda_id = opt.lambda_id * self.lambda_cyc

        # Buffers of previously generated samples
        self.fake_A_buffer = ReplayBuffer()
        self.fake_B_buffer = ReplayBuffer()

        # Image transformations
        A_transforms_ = [
            transforms.CenterCrop((178, 178)),
            transforms.Resize((300, 300)),
            transforms.RandomCrop((256, 256)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomAffine(degrees=opt.rotate_degree, fillcolor=(255, 255, 255)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
        B_transforms_ = [
            transforms.Resize((360, 360)),
            transforms.RandomCrop((256, 256)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomAffine(degrees=opt.rotate_degree, fillcolor=(255, 255, 255)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]

        # Training data loader
        self.train_dataloader = DataLoader(
            ImageDataset("./data/", A_transforms_=A_transforms_, B_transforms_=B_transforms_),
            batch_size=1,
            shuffle=True,
        )
        # Test data loader (same, randomly augmenting, transforms as training)
        self.val_dataloader = DataLoader(
            ImageDataset("./data/", A_transforms_=A_transforms_, B_transforms_=B_transforms_,
                         mode='test'),
            batch_size=1)

    def _sample_images(self, batches_done):
        """Save one validation batch as [real_A, fake_B, real_B, fake_A] to images/<model_name>/."""
        imgs = next(iter(self.val_dataloader))
        # Inference only: no autograd graph is needed for the sample.
        with torch.no_grad():
            real_A = Variable(imgs['A'].type(self.Tensor))
            fake_B = self.G_AB(real_A)
            real_B = Variable(imgs['B'].type(self.Tensor))
            fake_A = self.G_BA(real_B)
        img_sample = torch.cat((real_A.data, fake_B.data, real_B.data, fake_A.data), 0)
        save_image(img_sample,
                   'images/%s/%s.png' % (self.config.model_name, batches_done),
                   nrow=4,
                   normalize=True)

    def train_epoch(self, epoch):
        """Run one pass over the training data, then step both LR schedulers.

        Writes a progress line to stdout for every batch and saves a
        validation sample every ``config.sample_interval`` batches. When
        ``config.checkpoint_interval`` is not -1, it also saves G_AB's state
        dict every ``checkpoint_interval`` epochs.
        """
        prev_time = time.time()
        for i, batch in enumerate(self.train_dataloader):
            # Model input
            real_A = Variable(batch['A'].type(self.Tensor))
            real_B = Variable(batch['B'].type(self.Tensor))

            # Adversarial ground truths, shaped like the discriminator output
            valid = Variable(self.Tensor(np.ones((real_A.size(0), *self.patch))),
                             requires_grad=False)
            fake = Variable(self.Tensor(np.zeros((real_A.size(0), *self.patch))),
                            requires_grad=False)

            # ---- Train Generators ----
            self.optimizer_G.zero_grad()

            # GAN loss
            fake_B = self.G_AB(real_A)
            loss_GAN_AB = self.criterion_GAN(self.D_B(fake_B), valid)
            fake_A = self.G_BA(real_B)
            loss_GAN_BA = self.criterion_GAN(self.D_A(fake_A), valid)
            loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

            # Cycle loss
            recov_A = self.G_BA(fake_B)
            loss_cycle_A = self.criterion_cycle(recov_A, real_A)
            recov_B = self.G_AB(fake_A)
            loss_cycle_B = self.criterion_cycle(recov_B, real_B)
            loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

            # Identity loss
            loss_id_A = self.criterion_identity(self.G_BA(real_A), real_A)
            loss_id_B = self.criterion_identity(self.G_AB(real_B), real_B)
            loss_identity = (loss_id_A + loss_id_B) / 2

            # Total loss
            loss_G = loss_GAN + self.lambda_cyc * loss_cycle + self.lambda_id * loss_identity
            loss_G.backward()
            self.optimizer_G.step()

            # ---- Train Discriminators ----
            # D_A and D_B have disjoint parameters, so one backward on the sum
            # gives each discriminator exactly the gradient of its own loss,
            # followed by a single optimizer step.
            self.optimizer_D.zero_grad()

            # Fake losses use samples drawn from the replay buffers.
            loss_real_A = self.criterion_GAN(self.D_A(real_A), valid)
            fake_A_ = self.fake_A_buffer.push_and_pop(fake_A)
            loss_fake_A = self.criterion_GAN(self.D_A(fake_A_.detach()), fake)
            loss_D_A = (loss_real_A + loss_fake_A) / 2

            loss_real_B = self.criterion_GAN(self.D_B(real_B), valid)
            fake_B_ = self.fake_B_buffer.push_and_pop(fake_B)
            loss_fake_B = self.criterion_GAN(self.D_B(fake_B_.detach()), fake)
            loss_D_B = (loss_real_B + loss_fake_B) / 2

            (loss_D_A + loss_D_B).backward()
            self.optimizer_D.step()

            loss_D = (loss_D_A + loss_D_B) / 2

            # Determine approximate time left
            batches_done = epoch * len(self.train_dataloader) + i
            batches_left = self.config.n_epochs * len(self.train_dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
                % (epoch, self.config.n_epochs, i, len(self.train_dataloader),
                   loss_D.item(), loss_G.item(), loss_GAN.item(), loss_cycle.item(),
                   loss_identity.item(), time_left))

            if batches_done % self.config.sample_interval == 0:
                self._sample_images(batches_done)

        self.lr_scheduler_G.step()
        self.lr_scheduler_D.step()

        if self.config.checkpoint_interval != -1 and epoch % self.config.checkpoint_interval == 0:
            torch.save(self.G_AB.state_dict(),
                       'saved_models/%s/G_AB_%d.pth' % (self.config.model_name, epoch))
import torch
from models import GeneratorResNet

parser = argparse.ArgumentParser()
parser.add_argument('--check_point', type=str, default='saved_models/G_AB_10.pth', help='check point from which load trained model')
parser.add_argument('--batch_size', type=int, default=1, help='size of the batches')
parser.add_argument('--A_file', type=str, default='test.png', help='path of the data')
parser.add_argument('--img_height', type=int, default=256, help='size of image height')
parser.add_argument('--img_width', type=int, default=256, help='size of image width')
parser.add_argument('--gpu_id', type=int, default=-1, help='GPU id')
opt = parser.parse_args()

cuda = opt.gpu_id > -1
if cuda:
    # Make .cuda() and torch.cuda.FloatTensor below use the requested GPU
    # rather than always the default device.
    torch.cuda.set_device(opt.gpu_id)

# Load pretrained model G_AB. Map the stored tensors to the CPU first so a
# checkpoint saved from a GPU run also loads with --gpu_id -1.
# load_state_dict then copies them onto the model's device.
G_AB = GeneratorResNet()
if cuda:
    G_AB = G_AB.cuda()
G_AB.load_state_dict(torch.load(opt.check_point, map_location='cpu'))

Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

# Image transformations
transforms_ = [transforms.Resize((opt.img_height, opt.img_width)),
               transforms.ToTensor(),
               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
img_transformer = transforms.Compose(transforms_)

# Test data
img = img_transformer(Image.open(opt.A_file).convert("RGB"))
def main():
    """Train G_AB, G_BA, D_A and D_B with adversarial, cycle and identity losses.

    Configuration comes from the module-level ``opt``. The loss modules are
    the module-level ``criterion_GAN``, ``criterion_cycle`` and
    ``criterion_identity``.

    - If ``opt.epoch != 0``, the four networks are restored from
      ``saved_models/<dataset_name>/<net>_<epoch>.pth``. Otherwise they are
      initialised with ``weights_init_normal``.
    - Sample grids are written to ``images/<dataset_name>/`` every
      ``opt.sample_interval`` batches.
    - Checkpoints are written every ``opt.checkpoint_interval`` epochs; -1
      disables them.
    """
    cuda = torch.cuda.is_available()

    input_shape = (opt.channels, opt.img_height, opt.img_width)

    # Initialize generator and discriminator
    G_AB = GeneratorResNet(input_shape, opt.n_residual_blocks)
    G_BA = GeneratorResNet(input_shape, opt.n_residual_blocks)
    D_A = Discriminator(input_shape)
    D_B = Discriminator(input_shape)

    if cuda:
        G_AB = G_AB.cuda()
        G_BA = G_BA.cuda()
        D_A = D_A.cuda()
        D_B = D_B.cuda()
        criterion_GAN.cuda()
        criterion_cycle.cuda()
        criterion_identity.cuda()

    # Checkpoint name -> network; order is the load/save order.
    nets = {"G_AB": G_AB, "G_BA": G_BA, "D_A": D_A, "D_B": D_B}

    if opt.epoch != 0:
        # Load pretrained models
        for net_name, net in nets.items():
            net.load_state_dict(
                torch.load("saved_models/%s/%s_%d.pth" % (opt.dataset_name, net_name, opt.epoch)))
    else:
        # Initialize weights
        for net in nets.values():
            net.apply(weights_init_normal)

    # Optimizers
    optimizer_G = torch.optim.Adam(itertools.chain(G_AB.parameters(), G_BA.parameters()),
                                   lr=opt.lr, betas=(opt.b1, opt.b2))
    optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

    # Learning rate update schedulers
    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

    Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

    # Buffers of previously generated samples
    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    # Image transformations
    transforms_ = [
        transforms.Resize(int(opt.img_height * 1.12), Image.BICUBIC),
        transforms.RandomCrop((opt.img_height, opt.img_width)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]

    # Training data loader
    dataloader = DataLoader(
        ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_, unaligned=True),
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.n_cpu,
    )
    # Test data loader
    val_dataloader = DataLoader(
        ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_, unaligned=True,
                     mode="test"),
        batch_size=5,
        shuffle=True,
        num_workers=1,
    )

    def sample_images(batches_done):
        """Saves a generated sample from the test set"""
        imgs = next(iter(val_dataloader))
        G_AB.eval()
        G_BA.eval()
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            real_A = Variable(imgs["A"].type(Tensor))
            fake_B = G_AB(real_A)
            real_B = Variable(imgs["B"].type(Tensor))
            fake_A = G_BA(real_B)
        # Arrange images along x-axis
        real_A = make_grid(real_A, nrow=5, normalize=True)
        real_B = make_grid(real_B, nrow=5, normalize=True)
        fake_A = make_grid(fake_A, nrow=5, normalize=True)
        fake_B = make_grid(fake_B, nrow=5, normalize=True)
        # Arrange images along y-axis
        image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
        save_image(image_grid, "images/%s/%s.png" % (opt.dataset_name, batches_done),
                   normalize=False)

    # ----------
    #  Training
    # ----------
    prev_time = time.time()
    for epoch in range(opt.epoch, opt.n_epochs):
        for i, batch in enumerate(dataloader):
            # Set model input
            real_A = Variable(batch["A"].type(Tensor))
            real_B = Variable(batch["B"].type(Tensor))

            # Adversarial ground truths
            valid = Variable(Tensor(np.ones((real_A.size(0), *D_A.output_shape))),
                             requires_grad=False)
            fake = Variable(Tensor(np.zeros((real_A.size(0), *D_A.output_shape))),
                            requires_grad=False)

            # ------------------
            #  Train Generators
            # ------------------
            G_AB.train()
            G_BA.train()
            optimizer_G.zero_grad()

            # Identity loss
            loss_id_A = criterion_identity(G_BA(real_A), real_A)
            loss_id_B = criterion_identity(G_AB(real_B), real_B)
            loss_identity = (loss_id_A + loss_id_B) / 2

            # GAN loss
            fake_B = G_AB(real_A)
            loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
            fake_A = G_BA(real_B)
            loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)
            loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

            # Cycle loss
            recov_A = G_BA(fake_B)
            loss_cycle_A = criterion_cycle(recov_A, real_A)
            recov_B = G_AB(fake_A)
            loss_cycle_B = criterion_cycle(recov_B, real_B)
            loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

            # Total loss
            loss_G = loss_GAN + opt.lambda_cyc * loss_cycle + opt.lambda_id * loss_identity
            loss_G.backward()
            optimizer_G.step()

            # -----------------------
            #  Train Discriminator A
            # -----------------------
            optimizer_D_A.zero_grad()
            # Real loss
            loss_real = criterion_GAN(D_A(real_A), valid)
            # Fake loss (on batch of previously generated samples)
            fake_A_ = fake_A_buffer.push_and_pop(fake_A)
            loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
            # Total loss
            loss_D_A = (loss_real + loss_fake) / 2
            loss_D_A.backward()
            optimizer_D_A.step()

            # -----------------------
            #  Train Discriminator B
            # -----------------------
            optimizer_D_B.zero_grad()
            # Real loss
            loss_real = criterion_GAN(D_B(real_B), valid)
            # Fake loss (on batch of previously generated samples)
            fake_B_ = fake_B_buffer.push_and_pop(fake_B)
            loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
            # Total loss
            loss_D_B = (loss_real + loss_fake) / 2
            loss_D_B.backward()
            optimizer_D_B.step()

            loss_D = (loss_D_A + loss_D_B) / 2

            # --------------
            #  Log Progress
            # --------------
            # Determine approximate time left
            batches_done = epoch * len(dataloader) + i
            batches_left = opt.n_epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
                % (
                    epoch,
                    opt.n_epochs,
                    i,
                    len(dataloader),
                    loss_D.item(),
                    loss_G.item(),
                    loss_GAN.item(),
                    loss_cycle.item(),
                    loss_identity.item(),
                    time_left,
                ))

            # If at sample interval save image
            if batches_done % opt.sample_interval == 0:
                sample_images(batches_done)

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
            # Save model checkpoints
            for net_name, net in nets.items():
                torch.save(net.state_dict(),
                           "saved_models/%s/%s_%d.pth" % (opt.dataset_name, net_name, epoch))
help='rotate degree') parser.add_argument('--lambda_id', type=float, default=0.5, help='lambda_id') parser.add_argument('--large_patch', type=bool, default=False, help='whether use large patch') opt = parser.parse_args() # make output dirs os.makedirs('saved_models/%s' % (opt.model_name), exist_ok=True) os.makedirs('images/%s' % (opt.model_name), exist_ok=True) cuda = opt.gpu_id > -1 # Gs and Ds G_AB = GeneratorResNet(res_blocks=opt.n_residual_blocks) G_BA = GeneratorResNet(res_blocks=opt.n_residual_blocks) if opt.large_patch: D_A = LargePatchDiscriminator() D_B = LargePatchDiscriminator() else: D_A = Discriminator() D_B = Discriminator() criterion_GAN = torch.nn.MSELoss() criterion_cycle = torch.nn.L1Loss() criterion_identity = torch.nn.L1Loss() if cuda: G_AB = G_AB.cuda() G_BA = G_BA.cuda()
def _define_generator(self, path_to_model):
    """Create a default ``GeneratorResNet`` and load the state dict at ``path_to_model``.

    The checkpoint tensors are mapped to the CPU on load, so a checkpoint
    saved from a GPU run can be loaded on a machine without CUDA.
    ``load_state_dict`` copies them into the freshly built model, which lives
    on the CPU here, so the returned model's placement is unchanged.
    """
    gen = GeneratorResNet()
    state_dict = torch.load(path_to_model, map_location='cpu')
    gen.load_state_dict(state_dict)
    return gen
def _prepare_generating(self, ckpt_filename="result.ckpt-10000"):
    """Prepare generating.

    Build the TensorFlow generation graph from ``self.latent`` (style part)
    and one-hot ``char_ids`` (char part), open a session and restore the
    pretrained weights.

    Args:
        ckpt_filename: checkpoint file name inside ``self.src_log`` to restore.
            Defaults to ``"result.ckpt-10000"``.

    Raises:
        ValueError: if ``self.arch`` is neither ``'DCGAN'`` nor ``'ResNet'``.
    """
    if self.arch not in ('DCGAN', 'ResNet'):
        raise ValueError('unsupported arch: {}'.format(self.arch))

    self.z_size = self.style_z_size + self.char_embedding_n

    if self.arch == 'DCGAN':
        generator = GeneratorDCGAN(img_size=(self.img_width, self.img_height),
                                   img_dim=self.img_dim,
                                   z_size=self.z_size,
                                   layer_n=4,
                                   k_size=3,
                                   smallest_hidden_unit_n=64,
                                   is_bn=False)
    else:  # 'ResNet'
        generator = GeneratorResNet(k_size=3, smallest_unit_n=64)

    if FLAGS.generate_walk:
        style_embedding_np = np.random.uniform(
            -1, 1, (FLAGS.char_img_n // self.walk_step, self.style_z_size)).astype(np.float32)
    else:
        style_embedding_np = np.random.uniform(
            -1, 1, (self.style_ids_n, self.style_z_size)).astype(np.float32)

    # The style embedding is not used to build z below.
    # NOTE(review): presumably still created so the graph's variables match
    # the checkpoint for the saver; confirm before removing.
    with tf.variable_scope('embeddings'):
        style_embedding = tf.Variable(style_embedding_np, name='style_embedding')

    self.char_ids = tf.placeholder(tf.int32, (self.batch_size, ), name='char_ids')
    # NOTE(review): self.latent is assumed to be a (batch_size, style_z_size)
    # float tensor set up elsewhere; confirm.
    style_z = self.latent
    char_z = tf.one_hot(self.char_ids, self.char_embedding_n)
    z = tf.concat([style_z, char_z], axis=1)

    self.generated_imgs = generator(z, is_train=False)

    if FLAGS.gpu_ids == "":
        sess_config = tf.ConfigProto(device_count={"GPU": 0}, log_device_placement=True)
    else:
        sess_config = tf.ConfigProto(gpu_options=tf.GPUOptions(
            visible_device_list=FLAGS.gpu_ids))
    self.sess = tf.Session(config=sess_config)
    self.sess.run(tf.global_variables_initializer())

    # In walk mode, variables whose name contains 'embedding' are excluded
    # from the restore and keep their random initialisation.
    if FLAGS.generate_walk:
        var_list = [var for var in tf.global_variables() if 'embedding' not in var.name]
    else:
        var_list = list(tf.global_variables())
    pretrained_saver = tf.train.Saver(var_list=var_list)
    ckpt_filepath = os.path.join(self.src_log, ckpt_filename)
    pretrained_saver.restore(self.sess, ckpt_filepath)