def __init__(self, device='cuda:0', log_dir='logs', gpu_ids=0, lr=0.0002,
             beta1=0.5, lambda_idt=5, lambda_A=10.0, lambda_B=10.0):
    """Build the CycleGAN networks, image pools, losses and optimizers.

    Args:
        device: torch device string the networks are moved to.
        log_dir: directory for checkpoints/samples (created if missing).
        gpu_ids: int or list of ints handed to ``DataParallel``.
        lr, beta1: Adam learning rate and first-moment decay.
        lambda_idt, lambda_A, lambda_B: loss weights (identity, forward
            cycle, backward cycle).
    """
    self.lr = lr
    self.beta1 = beta1
    self.device = device

    self.netG_A = Generator().to(self.device)
    self.netG_B = Generator().to(self.device)
    self.netD_A = Discriminator().to(self.device)
    self.netD_B = Discriminator().to(self.device)
    print(torch.cuda.is_available())

    # multi-GPUs
    # BUGFIX: DataParallel's device_ids must be an iterable; the old default
    # (gpu_ids=0, a bare int) was rejected.  Accept an int and wrap it.
    if isinstance(gpu_ids, int):
        gpu_ids = [gpu_ids]
    self.netG_A = torch.nn.DataParallel(self.netG_A, gpu_ids)
    self.netG_B = torch.nn.DataParallel(self.netG_B, gpu_ids)
    self.netD_A = torch.nn.DataParallel(self.netD_A, gpu_ids)
    self.netD_B = torch.nn.DataParallel(self.netD_B, gpu_ids)
    print('will use gpus: {}'.format(gpu_ids))

    # history buffers of previously generated images, used when updating D
    self.fake_A_pool = ImagePool(50)
    self.fake_B_pool = ImagePool(50)

    # set losses
    self.criterionGAN = GANLoss(self.device)
    self.criterionCycle = torch.nn.L1Loss()
    self.criterionIdt = torch.nn.L1Loss()

    # weights of loss function
    self.lambda_idt = lambda_idt
    self.lambda_A = lambda_A
    self.lambda_B = lambda_B

    # optimization: one optimizer for both generators, one per discriminator
    self.optimizer_G = torch.optim.Adam(
        itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
        lr=self.lr, betas=(self.beta1, 0.999))
    self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                          lr=self.lr,
                                          betas=(self.beta1, 0.999))
    self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                          lr=self.lr,
                                          betas=(self.beta1, 0.999))
    self.optimizers = [self.optimizer_G, self.optimizer_D_A,
                       self.optimizer_D_B]

    self.log_dir = log_dir
    if not os.path.exists(self.log_dir):
        os.makedirs(self.log_dir)
def __init__(self, config):
    """Build generator/discriminator, losses, optimizers and data loader.

    Args:
        config: namespace with lr, beta1, beta2, dataset, dataset_dir,
            is_train, batch_size, num_thread, cuda, device_ids,
            load_model, start_epoch, pretrained_dir, pretrained_epoch.
    """
    self.generator = Generator()
    self.discriminator = Discriminator()
    print(self.generator)
    print(self.discriminator)

    self.bce_loss_fn = nn.BCELoss()
    self.mse_loss_fn = nn.MSELoss()

    self.opt_g = torch.optim.Adam(
        filter(lambda p: p.requires_grad, self.generator.parameters()),
        lr=config.lr, betas=(config.beta1, config.beta2))
    self.opt_d = torch.optim.Adam(self.discriminator.parameters(),
                                  lr=config.lr,
                                  betas=(config.beta1, config.beta2))

    if config.dataset == 'grid':
        self.dataset = VaganDataset(config.dataset_dir,
                                    train=config.is_train)
    elif config.dataset == 'lrw':
        self.dataset = LRWdataset(config.dataset_dir,
                                  train=config.is_train)
    self.data_loader = DataLoader(self.dataset,
                                  batch_size=config.batch_size,
                                  num_workers=config.num_thread,
                                  shuffle=True,
                                  drop_last=True)
    # Pull one batch up-front so dataset errors surface at start-up.
    # BUGFIX: iterator.next() is Python-2 only; built-in next() works on
    # both Python 2 and 3.
    data_iter = iter(self.data_loader)
    next(data_iter)

    # constant real/fake label vectors for the BCE GAN loss
    self.ones = Variable(torch.ones(config.batch_size),
                         requires_grad=False)
    self.zeros = Variable(torch.zeros(config.batch_size),
                          requires_grad=False)

    if config.cuda:
        device_ids = [int(i) for i in config.device_ids.split(',')]
        self.generator = nn.DataParallel(self.generator.cuda(),
                                         device_ids=device_ids)
        self.discriminator = nn.DataParallel(self.discriminator.cuda(),
                                             device_ids=device_ids)
        self.bce_loss_fn = self.bce_loss_fn.cuda()
        self.mse_loss_fn = self.mse_loss_fn.cuda()
        self.ones = self.ones.cuda()
        self.zeros = self.zeros.cuda()

    self.config = config
    self.start_epoch = 0
    if config.load_model:
        self.start_epoch = config.start_epoch
        self.load(config.pretrained_dir, config.pretrained_epoch)
class CycleGAN(object):
    """CycleGAN with an additional mask loss on the A->B mapping.

    Owns both generators (netG_A: A->B, netG_B: B->A), both
    discriminators, the image history pools, all loss criteria and the
    optimizers; provides train/optimize/save/load/sampling helpers.
    """

    def __init__(self, device='cuda:0', log_dir='logs', gpu_ids=0,
                 lr=0.0002, beta1=0.5, lambda_idt=5, lambda_A=10.0,
                 lambda_B=10.0, lambda_mask=10.0):
        """Build networks, pools, losses and optimizers.

        Args:
            device: torch device string the networks are moved to.
            log_dir: directory for checkpoints/samples (created if missing).
            gpu_ids: int or list of ints handed to ``DataParallel``.
            lr, beta1: Adam learning rate and first-moment decay.
            lambda_idt, lambda_A, lambda_B, lambda_mask: loss weights
                (identity, forward cycle, backward cycle, mask).
        """
        self.lr = lr
        self.beta1 = beta1
        self.device = device

        self.netG_A = Generator().to(self.device)
        self.netG_B = Generator().to(self.device)
        self.netD_A = Discriminator().to(self.device)
        self.netD_B = Discriminator().to(self.device)
        print(torch.cuda.is_available())

        # multi-GPUs
        # BUGFIX: DataParallel's device_ids must be an iterable; the old
        # default (gpu_ids=0, a bare int) was rejected.  Wrap ints.
        if isinstance(gpu_ids, int):
            gpu_ids = [gpu_ids]
        self.netG_A = torch.nn.DataParallel(self.netG_A, gpu_ids)
        self.netG_B = torch.nn.DataParallel(self.netG_B, gpu_ids)
        self.netD_A = torch.nn.DataParallel(self.netD_A, gpu_ids)
        self.netD_B = torch.nn.DataParallel(self.netD_B, gpu_ids)
        print('will use gpus: {}'.format(gpu_ids))

        # history buffers of generated images used for the D updates
        self.fake_A_pool = ImagePool(50)
        self.fake_B_pool = ImagePool(50)

        # set losses
        self.criterionGAN = GANLoss(self.device)
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()
        self.criterionMask = MASKLoss(self.device)

        # weights of loss function
        self.lambda_idt = lambda_idt
        self.lambda_A = lambda_A
        self.lambda_B = lambda_B
        self.lambda_mask = lambda_mask

        # optimization: one optimizer for both generators, one per D
        self.optimizer_G = torch.optim.Adam(
            itertools.chain(self.netG_A.parameters(),
                            self.netG_B.parameters()),
            lr=self.lr, betas=(self.beta1, 0.999))
        self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                              lr=self.lr,
                                              betas=(self.beta1, 0.999))
        self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                              lr=self.lr,
                                              betas=(self.beta1, 0.999))
        self.optimizers = [self.optimizer_G, self.optimizer_D_A,
                           self.optimizer_D_B]

        self.log_dir = log_dir
        if not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)

    def set_input(self, input):
        """Stash one batch on the device: images from both domains plus
        the mask that accompanies domain-A images."""
        self.real_A = input['A'].to(self.device)
        self.real_B = input['B'].to(self.device)
        self.real_A_mask = input['A_mask'].to(self.device)

    def backward_G(self, real_A, real_B, real_A_mask):
        """Compute and back-propagate every generator loss.

        Returns the detached loss values plus the generated ``fake_A`` /
        ``fake_B`` tensors (reused for the discriminator updates).
        """
        # identity loss: each generator should be near-identity on images
        # that are already in its target domain
        idt_A = self.netG_A(real_B)
        loss_idt_A = self.criterionIdt(idt_A, real_B) * self.lambda_idt
        idt_B = self.netG_B(real_A)
        loss_idt_B = self.criterionIdt(idt_B, real_A) * self.lambda_idt

        # GAN loss D_A(G_A(A)): G_A tries to fool D_A as real
        fake_B = self.netG_A(real_A)
        pred_fake_B = self.netD_A(fake_B)
        loss_G_A = self.criterionGAN(pred_fake_B, True)

        # GAN loss D_B(G_B(B)): G_B tries to fool D_B as real
        fake_A = self.netG_B(real_B)
        pred_fake_A = self.netD_B(fake_A)
        loss_G_B = self.criterionGAN(pred_fake_A, True)

        # forward cycle loss: real_A => fake_B => rec_A close to real_A
        rec_A = self.netG_B(fake_B)
        loss_cycle_A = self.criterionCycle(rec_A, real_A) * self.lambda_A

        # backward cycle loss: real_B => fake_A => rec_B close to real_B
        rec_B = self.netG_A(fake_A)
        loss_cycle_B = self.criterionCycle(rec_B, real_B) * self.lambda_B

        # mask loss on the A->B mapping
        if self.lambda_mask == 0:
            # BUGFIX: use a float zero so it combines with the other float
            # losses (torch.tensor(0) is an int tensor).
            loss_mask = torch.tensor(0.0).to(self.device)
        else:
            loss_mask = self.criterionMask(real_A, fake_B,
                                           real_A_mask) * self.lambda_mask

        # combined loss
        loss_G = (loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B
                  + loss_idt_A + loss_idt_B + loss_mask)
        loss_G.backward()
        return loss_G_A.data, loss_G_B.data, loss_cycle_A.data, \
            loss_cycle_B.data, loss_idt_A.data, loss_idt_B.data, \
            loss_mask.data, fake_A.data, fake_B.data

    def backward_D_A(self, real_B, fake_B):
        """Discriminator A update on real_B vs. pooled fake_B."""
        # draw from the image history pool
        fake_B = self.fake_B_pool.query(fake_B)
        # real image is real
        pred_real = self.netD_A(real_B)
        loss_D_real = self.criterionGAN(pred_real, True)
        # fake image is fake; detach() keeps G out of this backward pass
        pred_fake = self.netD_A(fake_B.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # combined loss (0.5 slows D relative to G, as in the paper)
        loss_D_A = (loss_D_real + loss_D_fake) * 0.5
        loss_D_A.backward()
        return loss_D_A.data

    def backward_D_B(self, real_A, fake_A):
        """Discriminator B update on real_A vs. pooled fake_A."""
        fake_A = self.fake_A_pool.query(fake_A)
        # real image is real
        pred_real = self.netD_B(real_A)
        loss_D_real = self.criterionGAN(pred_real, True)
        # fake image is fake; detach() keeps G out of this backward pass
        pred_fake = self.netD_B(fake_A.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # combined loss
        loss_D_B = (loss_D_real + loss_D_fake) * 0.5
        loss_D_B.backward()
        return loss_D_B.data

    def optimize(self):
        """One optimization step: G_A+G_B first, then D_A, then D_B.

        Returns a length-9 numpy array of the step's loss values.
        """
        # update Generator (G_A and G_B)
        self.optimizer_G.zero_grad()
        loss_G_A, loss_G_B, loss_cycle_A, loss_cycle_B, loss_idt_A, \
            loss_idt_B, loss_mask, fake_A, fake_B = self.backward_G(
                self.real_A, self.real_B, self.real_A_mask)
        self.optimizer_G.step()
        # update D_A
        self.optimizer_D_A.zero_grad()
        loss_D_A = self.backward_D_A(self.real_B, fake_B)
        self.optimizer_D_A.step()
        # update D_B
        self.optimizer_D_B.zero_grad()
        loss_D_B = self.backward_D_B(self.real_A, fake_A)
        self.optimizer_D_B.step()

        ret_loss = [loss_G_A, loss_D_A, loss_G_B, loss_D_B,
                    loss_cycle_A, loss_cycle_B, loss_idt_A, loss_idt_B,
                    loss_mask]
        return np.array(ret_loss)

    def train(self, data_loader):
        """Run one epoch over *data_loader*; return the mean loss array."""
        running_loss = np.zeros(9)
        time_list = []
        for batch_idx, data in enumerate(data_loader):
            t1 = time.perf_counter()
            self.set_input(data)
            losses = self.optimize()
            losses = losses.astype(np.float32)
            running_loss += losses
            t2 = time.perf_counter()
            get_processing_time = t2 - t1
            time_list.append(get_processing_time)
            if batch_idx % 50 == 0:
                print('batch: {} / {}, elapsed_time: {} sec'.format(
                    batch_idx, len(data_loader), sum(time_list)))
                time_list = []
        running_loss /= len(data_loader)
        return running_loss

    def save_network(self, network, network_label, epoch_label):
        """Save *network*'s state dict as '<epoch>_net_<label>.pth' in
        log_dir, then move the network back to the training device."""
        save_filename = '{}_net_{}.pth'.format(epoch_label, network_label)
        save_path = os.path.join(self.log_dir, save_filename)
        torch.save(network.cpu().state_dict(), save_path)
        network.to(self.device)

    def load_network(self, network, network_label, epoch_label):
        """Load '<epoch>_net_<label>.pth' from log_dir into *network*."""
        load_filename = '{}_net_{}.pth'.format(epoch_label, network_label)
        load_path = os.path.join(self.log_dir, load_filename)
        network.load_state_dict(torch.load(load_path))

    def save(self, label):
        """Checkpoint all four networks under the given epoch label."""
        self.save_network(self.netG_A, 'G_A', label)
        self.save_network(self.netD_A, 'D_A', label)
        self.save_network(self.netG_B, 'G_B', label)
        self.save_network(self.netD_B, 'D_B', label)

    def load(self, label):
        """Restore all four networks from the given epoch label."""
        self.load_network(self.netG_A, 'G_A', label)
        self.load_network(self.netD_A, 'D_A', label)
        self.load_network(self.netG_B, 'G_B', label)
        self.load_network(self.netD_B, 'D_B', label)

    def save_imgs(self, imgs, name_imgs, batch_size, epoch_label):
        """Save an image grid; the grid is capped at 16 images."""
        img_table_name = '{}_'.format(epoch_label) + name_imgs + '.png'
        save_path = os.path.join(self.log_dir, img_table_name)
        # The two former branches differed only in nrow; collapsed.
        utils.save_image(
            imgs,
            save_path,
            nrow=int(min(batch_size, 16) ** 0.5),
            normalize=True,
            range=(-1, 1)  # NOTE: renamed value_range in newer torchvision
        )

    def generate_imgs(self, epoch_label, batch_size):
        """Dump real/fake sample grids for the currently stored batch."""
        real_A = self.real_A
        real_B = self.real_B
        # inference only: no autograd graph needed
        with torch.no_grad():
            fake_B = self.netG_A(real_A)
            fake_A = self.netG_B(real_B)
        self.save_imgs(real_A, 'real_A', batch_size, epoch_label)
        self.save_imgs(real_B, 'real_B', batch_size, epoch_label)
        self.save_imgs(fake_B, 'fake_B', batch_size, epoch_label)
        self.save_imgs(fake_A, 'fake_A', batch_size, epoch_label)
def __init__(self, config):
    """Build models (including a frozen pretrained encoder), losses,
    optimizers and the data loader.

    Args:
        config: namespace with lr, beta1, beta2, dataset, dataset_dir,
            is_train, batch_size, cuda, device_ids (and optionally
            encoder_path to override the encoder checkpoint).
    """
    self.generator = Generator()
    self.discriminator = Discriminator()

    # Frozen embedding network, used only for the perceptual loss.
    self.encoder = Encoder()
    # The checkpoint location can be overridden via config.encoder_path;
    # the old hard-coded path remains the backward-compatible default.
    encoder_path = getattr(
        config, 'encoder_path',
        '/mnt/disk1/dat/lchen63/grid/model/model_embedding/encoder_6.pth')
    self.encoder.load_state_dict(torch.load(encoder_path))
    for param in self.encoder.parameters():
        param.requires_grad = False

    print(self.generator)
    print(self.discriminator)

    self.l1_loss_fn = nn.L1Loss()
    self.bce_loss_fn = nn.BCELoss()
    self.mse_loss_fn = nn.MSELoss()

    self.opt_g = torch.optim.Adam(
        filter(lambda p: p.requires_grad, self.generator.parameters()),
        lr=config.lr, betas=(config.beta1, config.beta2))
    self.opt_d = torch.optim.Adam(self.discriminator.parameters(),
                                  lr=config.lr,
                                  betas=(config.beta1, config.beta2))

    if config.dataset == 'grid':
        self.dataset = VaganDataset(config.dataset_dir,
                                    train=config.is_train)
    elif config.dataset == 'lrw':
        self.dataset = LRWdataset(config.dataset_dir,
                                  train=config.is_train)
    self.data_loader = DataLoader(self.dataset,
                                  batch_size=config.batch_size,
                                  num_workers=4,
                                  shuffle=True,
                                  drop_last=True)
    # Pull one batch up-front so dataset errors surface at start-up.
    # BUGFIX: iterator.next() is Python-2 only; built-in next() works on
    # both Python 2 and 3.
    data_iter = iter(self.data_loader)
    next(data_iter)

    # constant real/fake label vectors for the BCE GAN loss
    self.ones = Variable(torch.ones(config.batch_size),
                         requires_grad=False)
    self.zeros = Variable(torch.zeros(config.batch_size),
                          requires_grad=False)

    if config.cuda:
        device_ids = [int(i) for i in config.device_ids.split(',')]
        self.generator = nn.DataParallel(self.generator.cuda(),
                                         device_ids=device_ids)
        self.discriminator = nn.DataParallel(self.discriminator.cuda(),
                                             device_ids=device_ids)
        self.encoder = nn.DataParallel(self.encoder.cuda(),
                                       device_ids=device_ids)
        self.bce_loss_fn = self.bce_loss_fn.cuda()
        self.mse_loss_fn = self.mse_loss_fn.cuda()
        self.l1_loss_fn = self.l1_loss_fn.cuda()
        self.ones = self.ones.cuda()
        self.zeros = self.zeros.cuda()

    self.config = config
    self.start_epoch = 0
class Trainer():
    """Adversarial trainer for the audio-to-face generator, with an extra
    perceptual loss computed from a frozen, pretrained face encoder."""

    def __init__(self, config):
        """Build models, losses, optimizers and the data pipeline.

        Args:
            config: namespace with lr, beta1, beta2, dataset, dataset_dir,
                is_train, batch_size, cuda, device_ids, log_dir,
                sample_dir, model_dir, max_epochs (and optionally
                encoder_path to override the encoder checkpoint).
        """
        self.generator = Generator()
        self.discriminator = Discriminator()

        # Frozen embedding network, used only for the perceptual loss.
        self.encoder = Encoder()
        # Overridable via config.encoder_path; old hard-coded path kept as
        # the backward-compatible default.
        encoder_path = getattr(
            config, 'encoder_path',
            '/mnt/disk1/dat/lchen63/grid/model/model_embedding/encoder_6.pth')
        self.encoder.load_state_dict(torch.load(encoder_path))
        for param in self.encoder.parameters():
            param.requires_grad = False

        print(self.generator)
        print(self.discriminator)

        self.l1_loss_fn = nn.L1Loss()
        self.bce_loss_fn = nn.BCELoss()
        self.mse_loss_fn = nn.MSELoss()

        self.opt_g = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.generator.parameters()),
            lr=config.lr, betas=(config.beta1, config.beta2))
        self.opt_d = torch.optim.Adam(self.discriminator.parameters(),
                                      lr=config.lr,
                                      betas=(config.beta1, config.beta2))

        if config.dataset == 'grid':
            self.dataset = VaganDataset(config.dataset_dir,
                                        train=config.is_train)
        elif config.dataset == 'lrw':
            self.dataset = LRWdataset(config.dataset_dir,
                                      train=config.is_train)
        self.data_loader = DataLoader(self.dataset,
                                      batch_size=config.batch_size,
                                      num_workers=4,
                                      shuffle=True,
                                      drop_last=True)
        # Pull one batch up-front so dataset errors surface at start-up.
        # BUGFIX: iterator.next() is Python-2 only; next() works everywhere.
        data_iter = iter(self.data_loader)
        next(data_iter)

        # constant real/fake label vectors for the BCE GAN loss
        self.ones = Variable(torch.ones(config.batch_size),
                             requires_grad=False)
        self.zeros = Variable(torch.zeros(config.batch_size),
                              requires_grad=False)

        if config.cuda:
            device_ids = [int(i) for i in config.device_ids.split(',')]
            self.generator = nn.DataParallel(self.generator.cuda(),
                                             device_ids=device_ids)
            self.discriminator = nn.DataParallel(self.discriminator.cuda(),
                                                 device_ids=device_ids)
            self.encoder = nn.DataParallel(self.encoder.cuda(),
                                           device_ids=device_ids)
            self.bce_loss_fn = self.bce_loss_fn.cuda()
            self.mse_loss_fn = self.mse_loss_fn.cuda()
            self.l1_loss_fn = self.l1_loss_fn.cuda()
            self.ones = self.ones.cuda()
            self.zeros = self.zeros.cuda()

        self.config = config
        self.start_epoch = 0
        # self.load(config.model_dir)

    def fit(self):
        """Run the adversarial training loop for config.max_epochs epochs,
        logging losses, saving sample grids and per-epoch checkpoints."""
        config = self.config
        configure("{}/".format(config.log_dir), flush_secs=5)

        num_steps_per_epoch = len(self.data_loader)
        cc = 0
        for epoch in range(self.start_epoch, config.max_epochs):
            for step, (example, real_im, landmarks, right_audio,
                       wrong_audio) in enumerate(self.data_loader):
                t1 = time.time()

                if config.cuda:
                    example = Variable(example).cuda()
                    landmarks = Variable(landmarks).cuda()
                    real_im = Variable(real_im).cuda()
                    right_audio = Variable(right_audio).cuda()
                    wrong_audio = Variable(wrong_audio).cuda()
                else:
                    example = Variable(example)
                    landmarks = Variable(landmarks)
                    real_im = Variable(real_im)
                    right_audio = Variable(right_audio)
                    wrong_audio = Variable(wrong_audio)

                fake_im = self.generator(example, right_audio)

                # --- discriminator update: real/right-audio -> 1,
                # real/wrong-audio and fake/right-audio -> 0.  fake_im is
                # detached so no gradient reaches the generator here.
                D_real = self.discriminator(real_im, right_audio)
                D_wrong = self.discriminator(real_im, wrong_audio)
                D_fake = self.discriminator(fake_im.detach(), right_audio)

                loss_real = self.bce_loss_fn(D_real, self.ones)
                loss_wrong = self.bce_loss_fn(D_wrong, self.zeros)
                loss_fake = self.bce_loss_fn(D_fake, self.zeros)
                loss_disc = loss_real + 0.5 * (loss_fake + loss_wrong)

                loss_disc.backward()
                self.opt_d.step()
                self._reset_gradients()

                # --- generator update: fool D and match the frozen
                # encoder's features of real vs. generated frames.
                fake_im = self.generator(example, right_audio)
                fea_r = self.encoder(real_im)[1]
                fea_f = self.encoder(fake_im)[1]
                D_fake = self.discriminator(fake_im, right_audio)

                # GAN loss
                loss_gen1 = self.bce_loss_fn(D_fake, self.ones)
                # perceptual loss
                loss_perceptual = self.mse_loss_fn(fea_f, fea_r)
                loss_gen = loss_gen1 + loss_perceptual

                loss_gen.backward()
                self.opt_g.step()
                self._reset_gradients()

                t2 = time.time()

                if (step + 1) % 1 == 0 or (step + 1) == num_steps_per_epoch:
                    # BUGFIX: remaining work counts steps *after* this one;
                    # the old "-step+1"/"-epoch+1" over-counted by two each.
                    steps_remain = (num_steps_per_epoch - step - 1) + \
                        (config.max_epochs - epoch - 1) * num_steps_per_epoch
                    eta = int((t2 - t1) * steps_remain)

                    print(
                        "[{}/{}][{}/{}] Loss_D: {:.4f} Loss_G: {:.4f}, loss_perceptual: {: .4f}, ETA: {} second"
                        .format(epoch + 1, config.max_epochs, step + 1,
                                num_steps_per_epoch, loss_disc.data[0],
                                loss_gen1.data[0], loss_perceptual.data[0],
                                eta))
                    log_value('discriminator_loss', loss_disc.data[0],
                              step + num_steps_per_epoch * epoch)
                    log_value('generator_loss', loss_gen1.data[0],
                              step + num_steps_per_epoch * epoch)
                    log_value('perceptual_loss',
                              0.5 * loss_perceptual.data[0],
                              step + num_steps_per_epoch * epoch)

                # BUGFIX: integer interval.  The old float division made the
                # modulo test true only at step 0 under Python 3; max(1, ..)
                # also guards against very short epochs.
                if step % max(1, num_steps_per_epoch // 3) == 0:
                    fake_store = fake_im.data.permute(
                        0, 2, 1, 3, 4).contiguous().view(
                            config.batch_size * 16, 3, 64, 64)
                    torchvision.utils.save_image(
                        fake_store,
                        "{}fake_{}.png".format(config.sample_dir, cc),
                        nrow=16, normalize=True)
                    real_store = real_im.data.permute(
                        0, 2, 1, 3, 4).contiguous().view(
                            config.batch_size * 16, 3, 64, 64)
                    torchvision.utils.save_image(
                        real_store,
                        "{}real_{}.png".format(config.sample_dir, cc),
                        nrow=16, normalize=True)
                    cc += 1
            if epoch % 1 == 0:
                torch.save(
                    self.generator.state_dict(),
                    "{}/generator_{}.pth".format(config.model_dir, epoch))
                torch.save(
                    self.discriminator.state_dict(),
                    "{}/discriminator_{}.pth".format(config.model_dir,
                                                     epoch))

    def load(self, directory):
        """Load a generator/discriminator checkpoint pair from *directory*
        and resume start_epoch from the checkpoint file name."""
        paths = glob.glob(os.path.join(directory, "*.pth"))
        gen_path = [path for path in paths if "generator" in path][0]
        disc_path = [path for path in paths if "discriminator" in path][0]

        self.generator.load_state_dict(torch.load(gen_path))
        self.discriminator.load_state_dict(torch.load(disc_path))

        # BUGFIX: parse the epoch from the file *name*; splitting the full
        # path on "." broke whenever the directory contained a dot.
        base = os.path.splitext(os.path.basename(gen_path))[0]
        self.start_epoch = int(base.split("_")[-1])
        print("Load pretrained [{}, {}]".format(gen_path, disc_path))

    def _reset_gradients(self):
        """Zero the gradients of all three networks."""
        self.generator.zero_grad()
        self.discriminator.zero_grad()
        self.encoder.zero_grad()
class Trainer():
    """Plain adversarial trainer for the audio-to-face generator (GAN loss
    only — no encoder/perceptual term)."""

    def __init__(self, config):
        """Build models, losses, optimizers and the data pipeline.

        Args:
            config: namespace with lr, beta1, beta2, dataset, dataset_dir,
                is_train, batch_size, num_thread, cuda, device_ids,
                log_dir, sample_dir, model_dir, max_epochs, load_model,
                start_epoch, pretrained_dir, pretrained_epoch.
        """
        self.generator = Generator()
        self.discriminator = Discriminator()
        print(self.generator)
        print(self.discriminator)

        self.bce_loss_fn = nn.BCELoss()
        self.mse_loss_fn = nn.MSELoss()

        self.opt_g = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.generator.parameters()),
            lr=config.lr, betas=(config.beta1, config.beta2))
        self.opt_d = torch.optim.Adam(self.discriminator.parameters(),
                                      lr=config.lr,
                                      betas=(config.beta1, config.beta2))

        if config.dataset == 'grid':
            self.dataset = VaganDataset(config.dataset_dir,
                                        train=config.is_train)
        elif config.dataset == 'lrw':
            self.dataset = LRWdataset(config.dataset_dir,
                                      train=config.is_train)
        self.data_loader = DataLoader(self.dataset,
                                      batch_size=config.batch_size,
                                      num_workers=config.num_thread,
                                      shuffle=True,
                                      drop_last=True)
        # Pull one batch up-front so dataset errors surface at start-up.
        # BUGFIX: iterator.next() is Python-2 only; next() works everywhere.
        data_iter = iter(self.data_loader)
        next(data_iter)

        # constant real/fake label vectors for the BCE GAN loss
        self.ones = Variable(torch.ones(config.batch_size),
                             requires_grad=False)
        self.zeros = Variable(torch.zeros(config.batch_size),
                              requires_grad=False)

        if config.cuda:
            device_ids = [int(i) for i in config.device_ids.split(',')]
            self.generator = nn.DataParallel(self.generator.cuda(),
                                             device_ids=device_ids)
            self.discriminator = nn.DataParallel(self.discriminator.cuda(),
                                                 device_ids=device_ids)
            self.bce_loss_fn = self.bce_loss_fn.cuda()
            self.mse_loss_fn = self.mse_loss_fn.cuda()
            self.ones = self.ones.cuda()
            self.zeros = self.zeros.cuda()

        self.config = config
        self.start_epoch = 0
        if config.load_model:
            self.start_epoch = config.start_epoch
            self.load(config.pretrained_dir, config.pretrained_epoch)

    def fit(self):
        """Run the adversarial training loop for config.max_epochs epochs,
        logging losses, saving sample grids and per-epoch checkpoints."""
        config = self.config
        configure("{}".format(config.log_dir), flush_secs=5)

        num_steps_per_epoch = len(self.data_loader)
        cc = 0
        for epoch in range(self.start_epoch, config.max_epochs):
            for step, (example, real_im, landmarks, right_audio,
                       wrong_audio) in enumerate(self.data_loader):
                t1 = time.time()

                if config.cuda:
                    example = Variable(example).cuda()
                    real_im = Variable(real_im).cuda()
                    right_audio = Variable(right_audio).cuda()
                    wrong_audio = Variable(wrong_audio).cuda()
                else:
                    example = Variable(example)
                    real_im = Variable(real_im)
                    right_audio = Variable(right_audio)
                    wrong_audio = Variable(wrong_audio)

                fake_im = self.generator(example, right_audio)

                # --- discriminator update: real/right-audio -> 1,
                # real/wrong-audio and fake/right-audio -> 0.  fake_im is
                # detached so no gradient reaches the generator here.
                D_real = self.discriminator(real_im, right_audio)
                D_wrong = self.discriminator(real_im, wrong_audio)
                D_fake = self.discriminator(fake_im.detach(), right_audio)

                loss_real = self.bce_loss_fn(D_real, self.ones)
                loss_wrong = self.bce_loss_fn(D_wrong, self.zeros)
                loss_fake = self.bce_loss_fn(D_fake, self.zeros)
                loss_disc = loss_real + 0.5 * (loss_fake + loss_wrong)

                loss_disc.backward()
                self.opt_d.step()
                self._reset_gradients()

                # --- generator update: fool D into outputting "real"
                fake_im = self.generator(example, right_audio)
                D_fake = self.discriminator(fake_im, right_audio)
                loss_gen = self.bce_loss_fn(D_fake, self.ones)

                loss_gen.backward()
                self.opt_g.step()
                self._reset_gradients()

                t2 = time.time()

                if (step + 1) % 1 == 0 or (step + 1) == num_steps_per_epoch:
                    # BUGFIX: remaining work counts steps *after* this one;
                    # the old "-step+1"/"-epoch+1" over-counted by two each.
                    steps_remain = (num_steps_per_epoch - step - 1) + \
                        (config.max_epochs - epoch - 1) * num_steps_per_epoch
                    eta = int((t2 - t1) * steps_remain)

                    print(
                        "[{}/{}][{}/{}] Loss_D: {:.4f} Loss_G: {:.4f}, ETA: {} second"
                        .format(epoch + 1, config.max_epochs, step + 1,
                                num_steps_per_epoch, loss_disc.data[0],
                                loss_gen.data[0], eta))
                    log_value('discriminator_loss', loss_disc.data[0],
                              step + num_steps_per_epoch * epoch)
                    log_value('generator_loss', loss_gen.data[0],
                              step + num_steps_per_epoch * epoch)

                # BUGFIX: integer interval.  The old float division made the
                # modulo test true only at step 0 under Python 3; max(1, ..)
                # also guards against very short epochs.
                if step % max(1, num_steps_per_epoch // 10) == 0:
                    fake_store = fake_im.data.permute(
                        0, 2, 1, 3, 4).contiguous().view(
                            config.batch_size * 16, 3, 64, 64)
                    torchvision.utils.save_image(
                        fake_store,
                        "{}fake_{}.png".format(config.sample_dir, cc),
                        nrow=16, normalize=True)
                    real_store = real_im.data.permute(
                        0, 2, 1, 3, 4).contiguous().view(
                            config.batch_size * 16, 3, 64, 64)
                    torchvision.utils.save_image(
                        real_store,
                        "{}real_{}.png".format(config.sample_dir, cc),
                        nrow=16, normalize=True)
                    cc += 1
            if epoch % 1 == 0:
                torch.save(
                    self.generator.state_dict(),
                    "{}/generator_{}.pth".format(config.model_dir, epoch))
                torch.save(
                    self.discriminator.state_dict(),
                    "{}/discriminator_{}.pth".format(config.model_dir,
                                                     epoch))

    def load(self, directory, epoch):
        """Load 'generator_<epoch>.pth' / 'discriminator_<epoch>.pth'
        state dicts from *directory* into the current models."""
        gen_path = os.path.join(directory,
                                'generator_{}.pth'.format(epoch))
        disc_path = os.path.join(directory,
                                 'discriminator_{}.pth'.format(epoch))
        self.generator.load_state_dict(torch.load(gen_path))
        self.discriminator.load_state_dict(torch.load(disc_path))
        print("Load pretrained [{}, {}]".format(gen_path, disc_path))

    def _reset_gradients(self):
        """Zero the gradients of both networks."""
        self.generator.zero_grad()
        self.discriminator.zero_grad()