class ModuleTrain:
    """DCGAN-style training wrapper.

    Builds the discriminator/generator pair, restores checkpoints when they
    exist, runs the adversarial training loop and periodically saves the best
    generated samples. (Original comments translated from Chinese.)
    """

    def __init__(self, opt, best_loss=0.2):
        self.opt = opt
        self.best_loss = best_loss  # model is only saved when the avg loss beats this
        self.netd = Discriminator(self.opt)
        self.netg = Generator(self.opt)
        self.use_gpu = False

        # Restore pre-trained weights if checkpoint files exist.
        if os.path.exists(self.opt.netd_path):
            self.load_netd(self.opt.netd_path)
        else:
            print('[Load model] error: %s not exist !!!' % self.opt.netd_path)
        if os.path.exists(self.opt.netg_path):
            self.load_netg(self.opt.netg_path)
        else:
            print('[Load model] error: %s not exist !!!' % self.opt.netg_path)

        # DataLoader initialisation: images scaled to [-1, 1] to match tanh output.
        self.transform_train = T.Compose([
            T.Resize((self.opt.img_size, self.opt.img_size)),
            T.ToTensor(),
            T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]),
        ])
        train_dataset = ImageFolder(root=self.opt.data_path,
                                    transform=self.transform_train)
        self.train_loader = DataLoader(dataset=train_dataset,
                                       batch_size=self.opt.batch_size,
                                       shuffle=True,
                                       num_workers=self.opt.num_workers,
                                       drop_last=True)

        # Optimizers and loss.
        self.optimizer_g = optim.Adam(self.netg.parameters(),
                                      lr=self.opt.lr1,
                                      betas=(self.opt.beta1, 0.999))
        self.optimizer_d = optim.Adam(self.netd.parameters(),
                                      lr=self.opt.lr2,
                                      betas=(self.opt.beta1, 0.999))
        self.criterion = torch.nn.BCELoss()

        # Fixed real/fake label vectors and noise buffers (batch-sized).
        self.true_labels = Variable(torch.ones(self.opt.batch_size))
        self.fake_labels = Variable(torch.zeros(self.opt.batch_size))
        self.fix_noises = Variable(
            torch.randn(self.opt.batch_size, self.opt.nz, 1, 1))
        self.noises = Variable(
            torch.randn(self.opt.batch_size, self.opt.nz, 1, 1))

        # gpu or cpu — only use the GPU when requested AND actually available.
        self.use_gpu = bool(self.opt.use_gpu and torch.cuda.is_available())
        if self.use_gpu:
            print('[use gpu] ...')
            self.netd.cuda()
            self.netg.cuda()
            self.criterion.cuda()
            self.true_labels = self.true_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()
            self.fix_noises = self.fix_noises.cuda()
            self.noises = self.noises.cuda()
        else:
            print('[use cpu] ...')

    def train(self, save_best=True):
        """Run the adversarial training loop for ``opt.max_epoch`` epochs.

        :param save_best: when True, snapshot both nets whenever the averaged
            D/G loss improves on ``self.best_loss``.
        """
        print('[train] epoch: %d' % self.opt.max_epoch)
        for epoch_i in range(self.opt.max_epoch):
            loss_netd = 0.0
            loss_netg = 0.0
            print('================================================')
            for ii, (img, target) in enumerate(self.train_loader):
                real_img = Variable(img)
                # Fixed: check the computed self.use_gpu flag; the original
                # tested opt.use_gpu and crashed when CUDA was unavailable.
                if self.use_gpu:
                    real_img = real_img.cuda()

                # Train the discriminator every d_every batches.
                if (ii + 1) % self.opt.d_every == 0:
                    self.optimizer_d.zero_grad()
                    # Push real images toward label 1.
                    output = self.netd(real_img)
                    error_d_real = self.criterion(output, self.true_labels)
                    error_d_real.backward()
                    # Push (detached) fake images toward label 0.
                    self.noises.data.copy_(
                        torch.randn(self.opt.batch_size, self.opt.nz, 1, 1))
                    fake_img = self.netg(self.noises).detach()  # fakes from fresh noise
                    fake_output = self.netd(fake_img)
                    error_d_fake = self.criterion(fake_output,
                                                  self.fake_labels)
                    error_d_fake.backward()
                    self.optimizer_d.step()
                    loss_netd += (error_d_real.item() + error_d_fake.item())

                # Train the generator every g_every batches.
                if (ii + 1) % self.opt.g_every == 0:
                    self.optimizer_g.zero_grad()
                    self.noises.data.copy_(
                        torch.randn(self.opt.batch_size, self.opt.nz, 1, 1))
                    fake_img = self.netg(self.noises)
                    fake_output = self.netd(fake_img)
                    # Try to make the discriminator label fakes as 1.
                    error_g = self.criterion(fake_output, self.true_labels)
                    error_g.backward()
                    self.optimizer_g.step()
                    # Fixed: accumulate the Python float via .item(); the
                    # original summed the tensor itself, keeping the autograd
                    # graph alive and turning loss_netg into a tensor.
                    loss_netg += error_g.item()

            loss_netd /= (len(self.train_loader) * 2)
            loss_netg /= len(self.train_loader)
            print('[Train] Epoch: {} \tNetD Loss: {:.6f} \tNetG Loss: {:.6f}'.
                  format(epoch_i, loss_netd, loss_netg))
            if save_best is True:
                if (loss_netg + loss_netd) / 2 < self.best_loss:
                    self.best_loss = (loss_netg + loss_netd) / 2
                    self.save(self.netd, self.opt.best_netd_path)  # snapshot best model
                    self.save(self.netg, self.opt.best_netg_path)  # snapshot best model
                    print('[save best] ...')
            # self.vis()
            if (epoch_i + 1) % 5 == 0:
                self.image_gan()

        self.save(self.netd, self.opt.netd_path)  # final snapshot
        self.save(self.netg, self.opt.netg_path)  # final snapshot

    def vis(self):
        """Render fixed-noise fakes to visdom.

        NOTE(review): this reads ``self.opt.fix_noises`` but the noise tensor
        is stored as ``self.fix_noises`` in __init__ — confirm which is
        intended (the call site in train() is commented out).
        """
        fix_fake_imgs = self.netg(self.opt.fix_noises)
        visdom.images(fix_fake_imgs.data.cpu().numpy()[:64] * 0.5 + 0.5,
                      win='fixfake')

    def image_gan(self):
        """Sample ``gen_search_num`` noises, keep the ``gen_num`` fakes the
        discriminator scores highest, and save them to ``opt.gen_img``."""
        noises = torch.randn(self.opt.gen_search_num, self.opt.nz, 1,
                             1).normal_(self.opt.gen_mean, self.opt.gen_std)
        with torch.no_grad():
            noises = Variable(noises)
            if self.use_gpu:
                noises = noises.cuda()
            fake_img = self.netg(noises)
            scores = self.netd(fake_img).data
        indexs = scores.topk(self.opt.gen_num)[1]
        result = [fake_img.data[ii] for ii in indexs]
        # NOTE(review): `range=` was renamed `value_range` in newer
        # torchvision releases — confirm the installed version.
        torchvision.utils.save_image(torch.stack(result),
                                     self.opt.gen_img,
                                     normalize=True,
                                     range=(-1, 1))

    def test(self):
        """Evaluate on the test set and return the accuracy.

        NOTE(review): relies on ``self.model``, ``self.loss`` and
        ``self.test_loader``, none of which are created in __init__ — this
        looks like leftover code from a classifier trainer and will fail if
        called as-is.
        """
        test_loss = 0.0
        correct = 0
        time_start = time.time()
        for data, target in self.test_loader:
            data, target = Variable(data), Variable(target)
            if self.use_gpu:
                data = data.cuda()
                target = target.cuda()
            output = self.model(data)
            # Sum up the batch loss. (The original had identical gpu/cpu
            # branches here; collapsed into one call.)
            loss = self.loss(output, target)
            test_loss += loss.item()
            predict = torch.argmax(output, 1)
            correct += (predict == target).sum().data
        time_end = time.time()
        time_avg = float(time_end - time_start) / float(
            len(self.test_loader.dataset))
        test_loss /= len(self.test_loader)
        acc = float(correct) / float(len(self.test_loader.dataset))
        print('[Test] set: Test loss: {:.6f}\t Acc: {:.6f}\t time: {:.6f} \n'.
              format(test_loss, acc, time_avg))
        return acc

    def load_netd(self, name):
        """Load discriminator weights from ``name``."""
        print('[Load model netd] %s ...' % name)
        self.netd.load_state_dict(torch.load(name))

    def load_netg(self, name):
        """Load generator weights from ``name``."""
        print('[Load model netg] %s ...' % name)
        self.netg.load_state_dict(torch.load(name))

    def save(self, model, name):
        """Save ``model``'s state_dict to ``name``."""
        print('[Save model] %s ...' % name)
        torch.save(model.state_dict(), name)
class Solver(object):
    """GAN solver: builds G/D, trains them alternately and logs/saves results."""

    def __init__(self, data_loader, config):
        self.data_loader = data_loader
        self.noise_n = config.noise_n
        self.G_last_act = last_act(config.G_last_act)
        self.D_out_n = config.D_out_n
        self.D_last_act = last_act(config.D_last_act)
        self.G_lr = config.G_lr
        self.D_lr = config.D_lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.epoch = config.epoch
        self.batch_size = config.batch_size
        self.D_train_step = config.D_train_step
        self.save_image_step = config.save_image_step
        self.log_step = config.log_step
        self.model_save_step = config.model_save_step
        self.model_save_path = config.model_save_path
        self.log_save_path = config.log_save_path
        self.image_save_path = config.image_save_path
        self.use_tensorboard = config.use_tensorboard
        self.pretrained_model = config.pretrained_model

        self.build_model()
        # NOTE(review): "is not None" means use_tensorboard=False still builds
        # the logger (later checks use truthiness) — confirm intent.
        if self.use_tensorboard is not None:
            self.build_tensorboard()
        if self.pretrained_model is not None:
            if len(self.pretrained_model) != 2:
                # Fixed: raising a plain string is a TypeError in Python 3;
                # raise a proper exception instead.
                raise ValueError(
                    "must have both G and D pretrained parameters, and G is first, D is second"
                )
            self.load_pretrained_model()

    def build_model(self):
        """Instantiate G/D and their Adam optimizers; move to GPU if available."""
        self.G = Generator(self.noise_n, self.G_last_act)
        self.D = Discriminator(self.D_out_n, self.D_last_act)
        self.G_optimizer = torch.optim.Adam(self.G.parameters(), self.G_lr,
                                            [self.beta1, self.beta2])
        self.D_optimizer = torch.optim.Adam(self.D.parameters(), self.D_lr,
                                            [self.beta1, self.beta2])
        if torch.cuda.is_available():
            self.G.cuda()
            self.D.cuda()

    def build_tensorboard(self):
        """Create the tensorboard-style logger."""
        from commons.logger import Logger
        self.logger = Logger(self.log_save_path)

    def load_pretrained_model(self):
        """Restore G (first path) and D (second path) from disk."""
        self.G.load_state_dict(torch.load(self.pretrained_model[0]))
        self.D.load_state_dict(torch.load(self.pretrained_model[1]))

    def reset_grad(self):
        """Clear gradients of both optimizers."""
        self.G_optimizer.zero_grad()
        self.D_optimizer.zero_grad()

    def to_var(self, x, volatile=False):
        """Move ``x`` to GPU when available and wrap it in a Variable."""
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x, volatile=volatile)

    def train(self):
        """Alternating D/G training loop with periodic logging and saving."""
        bce_loss = nn.BCELoss()
        print(len(self.data_loader))
        for e in range(self.epoch):
            for i, batch_images in enumerate(self.data_loader):
                batch_size = batch_images.size(0)
                real_x = self.to_var(batch_images)
                noise_x = self.to_var(
                    torch.FloatTensor(noise_vector(batch_size, self.noise_n)))
                real_label = self.to_var(
                    torch.FloatTensor(batch_size).fill_(1.))
                fake_label = self.to_var(
                    torch.FloatTensor(batch_size).fill_(0.))

                # Train D: real images toward 1, detached fakes toward 0.
                fake_x = self.G(noise_x)
                real_out = self.D(real_x)
                fake_out = self.D(fake_x.detach())
                D_real = bce_loss(real_out, real_label)
                D_fake = bce_loss(fake_out, fake_label)
                D_loss = D_real + D_fake
                self.reset_grad()
                D_loss.backward()
                self.D_optimizer.step()

                # Log. Fixed: `.data[0]` is pre-0.4 PyTorch and fails on
                # modern versions — use .item().
                loss = {}
                loss['D/loss_real'] = D_real.item()
                loss['D/loss_fake'] = D_fake.item()
                loss['D/loss'] = D_loss.item()

                # Train G every D_train_step iterations: make D call fakes real.
                if (i + 1) % self.D_train_step == 0:
                    fake_out = self.D(self.G(noise_x))
                    G_loss = bce_loss(fake_out, real_label)
                    self.reset_grad()
                    G_loss.backward()
                    self.G_optimizer.step()
                    loss['G/loss'] = G_loss.item()

                # Print log.
                if (i + 1) % self.log_step == 0:
                    log = "Epoch: {}/{}, Iter: {}/{}".format(
                        e + 1, self.epoch, i + 1, len(self.data_loader))
                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)
                    if self.use_tensorboard:
                        for tag, value in loss.items():
                            self.logger.scalar_summary(
                                tag, value, e * len(self.data_loader) + i + 1)

            # Save sample images.
            if (e + 1) % self.save_image_step == 0:
                noise_x = self.to_var(
                    torch.FloatTensor(noise_vector(32, self.noise_n)))
                fake_image = self.G(noise_x)
                save_image(
                    fake_image.data,
                    os.path.join(self.image_save_path,
                                 "{}_fake.png".format(e + 1)))

            # Save model weights.
            if (e + 1) % self.model_save_step == 0:
                torch.save(
                    self.G.state_dict(),
                    os.path.join(self.model_save_path,
                                 "{}_G.pth".format(e + 1)))
                torch.save(
                    self.D.state_dict(),
                    os.path.join(self.model_save_path,
                                 "{}_D.pth".format(e + 1)))
def main_worker(args):
    """SinGAN training entry point (MindSpore port, single process).

    Builds the multi-scale generator/discriminator, optionally restores a
    checkpoint, then trains and validates stage by stage.
    """
    ################
    # Define model #
    ################
    # 4/3 : scale factor in the paper
    scale_factor = 4 / 3
    tmp_scale = args.img_size_max / args.img_size_min
    args.num_scale = int(np.round(np.log(tmp_scale) / np.log(scale_factor)))
    args.size_list = [
        int(args.img_size_min * scale_factor**i)
        for i in range(args.num_scale + 1)
    ]

    discriminator = Discriminator()
    generator = Generator(args.img_size_min, args.num_scale, scale_factor)
    # Fixed: `networks` was referenced in the validation/test branches below
    # before it was assigned (NameError); define it right after the models.
    networks = [discriminator, generator]

    ######################
    # Loss and Optimizer #
    ######################
    d_opt = mindspore.nn.Adam(
        discriminator.sub_discriminators[0].get_parameters(), 5e-4, 0.5,
        0.999)
    g_opt = mindspore.nn.Adam(generator.sub_generators[0].get_parameters(),
                              5e-4, 0.5, 0.999)

    ##############
    # Load model #
    ##############
    args.stage = 0
    if args.load_model is not None:
        check_load = open(os.path.join(args.log_dir, "checkpoint.txt"), 'r')
        to_restore = check_load.readlines()[-1].strip()
        load_file = os.path.join(args.log_dir, to_restore)
        if os.path.isfile(load_file):
            print("=> loading checkpoint '{}'".format(load_file))
            # MPS: torch used map_location='cpu' here.
            checkpoint = mindspore.load_checkpoint(load_file)
            # Grow both pyramids back to the checkpointed stage.
            for _ in range(int(checkpoint['stage'])):
                generator.progress()
                discriminator.progress()
            args.stage = checkpoint['stage']
            args.img_to_use = checkpoint['img_to_use']
            discriminator.load_state_dict(checkpoint['D_state_dict'])
            generator.load_state_dict(checkpoint['G_state_dict'])
            # MPS: confirm the MindSpore optimizers expose load_state_dict.
            d_opt.load_state_dict(checkpoint['d_optimizer'])
            g_opt.load_state_dict(checkpoint['g_optimizer'])
            print("=> loaded checkpoint '{}' (stage {})".format(
                load_file, checkpoint['stage']))
        else:
            print("=> no checkpoint found at '{}'".format(args.log_dir))

    ###########
    # Dataset #
    ###########
    train_dataset, _ = get_dataset(args.dataset, args)
    train_sampler = None
    train_loader = mindspore.DatasetHelper(train_dataset)  # MPS: may need tuning

    ######################
    # Validate and Train #
    ######################
    op1 = mindspore.ops.Pad(((5, 5), (5, 5)))
    op2 = mindspore.ops.Pad(((5, 5), (5, 5)))
    # NOTE(review): StandardNormal/Zeros are operator classes that are normally
    # instantiated first and then called with a shape tuple; these calls pass
    # sizes to the constructors instead — confirm against the MindSpore ops API.
    z_fix_list = [op1(mindspore.ops.StandardNormal(3, args.size_list[0]))]
    zero_list = [
        op2(mindspore.ops.Zeros(3, args.size_list[zeros_idx]))
        for zeros_idx in range(1, args.num_scale + 1)
    ]
    z_fix_list = z_fix_list + zero_list

    if args.validation:
        validateSinGAN(train_loader, networks, args.stage, args,
                       {"z_rec": z_fix_list})
        return
    elif args.test:
        validateSinGAN(train_loader, networks, args.stage, args,
                       {"z_rec": z_fix_list})
        return

    check_list = open(os.path.join(args.log_dir, "checkpoint.txt"), "a+")
    record_txt = open(os.path.join(args.log_dir, "record.txt"), "a+")
    record_txt.write('DATASET\t:\t{}\n'.format(args.dataset))
    record_txt.write('GANTYPE\t:\t{}\n'.format(args.gantype))
    record_txt.write('IMGTOUSE\t:\t{}\n'.format(args.img_to_use))
    record_txt.close()

    for stage in range(args.stage, args.num_scale + 1):
        trainSinGAN(train_loader, networks, {
            "d_opt": d_opt,
            "g_opt": g_opt
        }, stage, args, {"z_rec": z_fix_list})
        validateSinGAN(train_loader, networks, stage, args,
                       {"z_rec": z_fix_list})

        # Add the next (finer) scale to both pyramids.
        discriminator.progress()
        generator.progress()

        # Re-create the optimizers targeting the networks at the finest scale.
        # Fixed: use get_parameters() consistently — the original mixed
        # torch-style .parameters() with MindSpore's get_parameters().
        d_opt = mindspore.nn.Adam(
            discriminator.sub_discriminators[
                discriminator.current_scale].get_parameters(), 5e-4, 0.5,
            0.999)
        g_opt = mindspore.nn.Adam(
            generator.sub_generators[
                generator.current_scale].get_parameters(), 5e-4, 0.5, 0.999)

        ##############
        # Save model #
        ##############
        if stage == 0:
            check_list = open(os.path.join(args.log_dir, "checkpoint.txt"),
                              "a+")
        save_checkpoint(
            {
                'stage': stage + 1,
                'D_state_dict': discriminator.state_dict(),
                'G_state_dict': generator.state_dict(),
                'd_optimizer': d_opt.state_dict(),
                'g_optimizer': g_opt.state_dict(),
                'img_to_use': args.img_to_use
            }, check_list, args.log_dir, stage + 1)
        if stage == args.num_scale:
            check_list.close()
def _wrap_networks(networks, args, ngpus_per_node, banner='Distributed to'):
    """Move the networks onto their device and wrap them for (distributed)
    data-parallel training; returns the wrapped [discriminator, generator].

    NOTE(review): the batch_size/workers division below runs on every call
    (initial setup, checkpoint restore, and again each stage in the training
    loop, where already-wrapped networks get re-wrapped) — this reproduces the
    original behavior; confirm it is intended.
    """
    if args.distributed:
        if args.gpu is not None:
            print(banner, args.gpu)
            torch.cuda.set_device(args.gpu)
            networks = [x.cuda(args.gpu) for x in networks]
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(args.workers / ngpus_per_node)
            networks = [
                torch.nn.parallel.DistributedDataParallel(
                    x, device_ids=[args.gpu], output_device=args.gpu)
                for x in networks
            ]
        else:
            networks = [x.cuda() for x in networks]
            networks = [
                torch.nn.parallel.DistributedDataParallel(x)
                for x in networks
            ]
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        networks = [x.cuda(args.gpu) for x in networks]
    else:
        networks = [torch.nn.DataParallel(x).cuda() for x in networks]
    return networks


def main_worker(gpu, ngpus_per_node, args):
    """Per-process SinGAN training worker (PyTorch, optionally distributed).

    :param gpu: GPU index assigned to this process.
    :param ngpus_per_node: number of GPUs on this node.
    :param args: parsed command-line namespace (mutated in place).
    """
    if len(args.gpu) == 1:
        args.gpu = 0
    else:
        args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend='nccl',
                                init_method='tcp://127.0.0.1:' + args.port,
                                world_size=args.world_size,
                                rank=args.rank)

    ################
    # Define model #
    ################
    # 4/3 : scale factor in the paper
    scale_factor = 4 / 3
    tmp_scale = args.img_size_max / args.img_size_min
    args.num_scale = int(np.round(np.log(tmp_scale) / np.log(scale_factor)))
    args.size_list = [
        int(args.img_size_min * scale_factor**i)
        for i in range(args.num_scale + 1)
    ]

    discriminator = Discriminator()
    generator = Generator(args.img_size_min, args.num_scale, scale_factor)
    networks = _wrap_networks([discriminator, generator], args,
                              ngpus_per_node)
    discriminator, generator = networks

    ######################
    # Loss and Optimizer #
    ######################
    if args.distributed:
        d_opt = torch.optim.Adam(
            discriminator.module.sub_discriminators[0].parameters(), 5e-4,
            (0.5, 0.999))
        g_opt = torch.optim.Adam(
            generator.module.sub_generators[0].parameters(), 5e-4,
            (0.5, 0.999))
    else:
        d_opt = torch.optim.Adam(
            discriminator.sub_discriminators[0].parameters(), 5e-4,
            (0.5, 0.999))
        g_opt = torch.optim.Adam(generator.sub_generators[0].parameters(),
                                 5e-4, (0.5, 0.999))

    ##############
    # Load model #
    ##############
    args.stage = 0
    if args.load_model is not None:
        check_load = open(os.path.join(args.log_dir, "checkpoint.txt"), 'r')
        to_restore = check_load.readlines()[-1].strip()
        load_file = os.path.join(args.log_dir, to_restore)
        if os.path.isfile(load_file):
            print("=> loading checkpoint '{}'".format(load_file))
            checkpoint = torch.load(load_file, map_location='cpu')
            # NOTE(review): progress() is called on the wrapped networks here,
            # as in the original — confirm the wrappers forward the attribute.
            for _ in range(int(checkpoint['stage'])):
                generator.progress()
                discriminator.progress()
            networks = _wrap_networks([discriminator, generator], args,
                                      ngpus_per_node)
            discriminator, generator = networks
            args.stage = checkpoint['stage']
            args.img_to_use = checkpoint['img_to_use']
            discriminator.load_state_dict(checkpoint['D_state_dict'])
            generator.load_state_dict(checkpoint['G_state_dict'])
            d_opt.load_state_dict(checkpoint['d_optimizer'])
            g_opt.load_state_dict(checkpoint['g_optimizer'])
            print("=> loaded checkpoint '{}' (stage {})".format(
                load_file, checkpoint['stage']))
        else:
            print("=> no checkpoint found at '{}'".format(args.log_dir))

    cudnn.benchmark = True

    ###########
    # Dataset #
    ###########
    train_dataset, _ = get_dataset(args.dataset, args)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        sampler=train_sampler)

    ######################
    # Validate and Train #
    ######################
    # Fixed padded noise for the reconstruction term, plus zero tensors for
    # every finer scale.
    z_fix_list = [
        F.pad(torch.randn(args.batch_size, 3, args.size_list[0],
                          args.size_list[0]), [5, 5, 5, 5],
              value=0)
    ]
    zero_list = [
        F.pad(torch.zeros(args.batch_size, 3, args.size_list[zeros_idx],
                          args.size_list[zeros_idx]), [5, 5, 5, 5],
              value=0) for zeros_idx in range(1, args.num_scale + 1)
    ]
    z_fix_list = z_fix_list + zero_list

    if args.validation:
        validateSinGAN(train_loader, networks, args.stage, args,
                       {"z_rec": z_fix_list})
        return
    elif args.test:
        validateSinGAN(train_loader, networks, args.stage, args,
                       {"z_rec": z_fix_list})
        return

    # Only the rank-0 process of each node writes bookkeeping files.
    if not args.multiprocessing_distributed or (
            args.multiprocessing_distributed
            and args.rank % ngpus_per_node == 0):
        check_list = open(os.path.join(args.log_dir, "checkpoint.txt"), "a+")
        record_txt = open(os.path.join(args.log_dir, "record.txt"), "a+")
        record_txt.write('DATASET\t:\t{}\n'.format(args.dataset))
        record_txt.write('GANTYPE\t:\t{}\n'.format(args.gantype))
        record_txt.write('IMGTOUSE\t:\t{}\n'.format(args.img_to_use))
        record_txt.close()

    for stage in range(args.stage, args.num_scale + 1):
        if args.distributed:
            train_sampler.set_epoch(stage)

        trainSinGAN(train_loader, networks, {
            "d_opt": d_opt,
            "g_opt": g_opt
        }, stage, args, {"z_rec": z_fix_list})
        validateSinGAN(train_loader, networks, stage, args,
                       {"z_rec": z_fix_list})

        # Add the next (finer) scale to both pyramids.
        if args.distributed:
            discriminator.module.progress()
            generator.module.progress()
        else:
            discriminator.progress()
            generator.progress()

        networks = _wrap_networks([discriminator, generator], args,
                                  ngpus_per_node, banner='Distributed')
        discriminator, generator = networks

        # Update the optimizers for the finest scale; freeze coarser scales.
        if args.distributed:
            for net_idx in range(generator.module.current_scale):
                for param in generator.module.sub_generators[
                        net_idx].parameters():
                    param.requires_grad = False
                for param in discriminator.module.sub_discriminators[
                        net_idx].parameters():
                    param.requires_grad = False

            # Fixed: current_scale lives on the wrapped module; the original
            # read it from the DDP wrapper itself (AttributeError).
            d_opt = torch.optim.Adam(
                discriminator.module.sub_discriminators[
                    discriminator.module.current_scale].parameters(), 5e-4,
                (0.5, 0.999))
            g_opt = torch.optim.Adam(
                generator.module.sub_generators[
                    generator.module.current_scale].parameters(), 5e-4,
                (0.5, 0.999))
        else:
            for net_idx in range(generator.current_scale):
                for param in generator.sub_generators[net_idx].parameters():
                    param.requires_grad = False
                for param in discriminator.sub_discriminators[
                        net_idx].parameters():
                    param.requires_grad = False

            d_opt = torch.optim.Adam(
                discriminator.sub_discriminators[
                    discriminator.current_scale].parameters(), 5e-4,
                (0.5, 0.999))
            g_opt = torch.optim.Adam(
                generator.sub_generators[generator.current_scale].parameters(),
                5e-4, (0.5, 0.999))

        ##############
        # Save model #
        ##############
        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            if stage == 0:
                check_list = open(
                    os.path.join(args.log_dir, "checkpoint.txt"), "a+")
            save_checkpoint(
                {
                    'stage': stage + 1,
                    'D_state_dict': discriminator.state_dict(),
                    'G_state_dict': generator.state_dict(),
                    'd_optimizer': d_opt.state_dict(),
                    'g_optimizer': g_opt.state_dict(),
                    'img_to_use': args.img_to_use
                }, check_list, args.log_dir, stage + 1)
            if stage == args.num_scale:
                check_list.close()
def main():
    """GAIL-style training loop for a dialog environment: roll out the actor,
    score transitions with the discriminator-derived reward, then update the
    discriminator and actor/critic."""
    env = DialogEnvironment()
    experiment_name = args.logdir.split('/')[1]  # model name

    torch.manual_seed(args.seed)

    # TODO
    actor = Actor(hidden_size=args.hidden_size,
                  num_layers=args.num_layers,
                  device='cuda',
                  input_size=args.input_size,
                  output_size=args.input_size)
    critic = Critic(hidden_size=args.hidden_size,
                    num_layers=args.num_layers,
                    input_size=args.input_size,
                    seq_len=args.seq_len)
    # Fixed: num_layers was accidentally set to args.hidden_size (copy/paste);
    # use args.num_layers like Actor and Critic.
    discrim = Discriminator(hidden_size=args.hidden_size,
                            num_layers=args.num_layers,
                            input_size=args.input_size,
                            seq_len=args.seq_len)

    actor.to(device), critic.to(device), discrim.to(device)

    actor_optim = optim.Adam(actor.parameters(), lr=args.learning_rate)
    critic_optim = optim.Adam(critic.parameters(),
                              lr=args.learning_rate,
                              weight_decay=args.l2_rate)
    discrim_optim = optim.Adam(discrim.parameters(), lr=args.learning_rate)

    # load demonstrations
    writer = SummaryWriter(args.logdir)

    if args.load_model is not None:  # TODO
        saved_ckpt_path = os.path.join(os.getcwd(), 'save_model',
                                       str(args.load_model))
        ckpt = torch.load(saved_ckpt_path)
        actor.load_state_dict(ckpt['actor'])
        critic.load_state_dict(ckpt['critic'])
        discrim.load_state_dict(ckpt['discrim'])

    episodes = 0
    train_discrim_flag = True

    for iter in range(args.max_iter_num):
        actor.eval(), critic.eval()
        memory = deque()

        steps = 0
        scores = []
        similarity_scores = []

        while steps < args.total_sample_size:
            # NOTE(review): these resets discard earlier episodes' scores
            # within the same iteration, so the averages below cover only the
            # last episode — confirm intent.
            scores = []
            similarity_scores = []
            state, expert_action, raw_state, raw_expert_action = env.reset()
            score = 0
            similarity_score = 0

            state = state[:args.seq_len, :]
            expert_action = expert_action[:args.seq_len, :]
            state = state.to(device)
            expert_action = expert_action.to(device)

            for _ in range(10000):
                steps += 1
                # TODO: gotta be a better way to resize.
                mu, std = actor(
                    state.resize(1, args.seq_len, args.input_size))
                action = get_action(mu.cpu(), std.cpu())[0]

                # Zero out the action from the first all-zero (padding) row of
                # the expert sequence onward.
                for i in range(5):
                    emb_sum = expert_action[i, :].sum().cpu().item()
                    if emb_sum == 0:
                        action[i:, :] = 0  # manual padding
                        break

                done = env.step(action)
                irl_reward = get_reward(discrim, state, action, args)
                if done:
                    mask = 0
                else:
                    mask = 1

                memory.append([
                    state,
                    torch.from_numpy(action).to(device), irl_reward, mask,
                    expert_action
                ])
                score += irl_reward
                similarity_score += get_cosine_sim(expert=expert_action,
                                                   action=action.squeeze(),
                                                   seq_len=5)
                if done:
                    break

            episodes += 1
            scores.append(score)
            similarity_scores.append(similarity_score)

        score_avg = np.mean(scores)
        similarity_score_avg = np.mean(similarity_scores)
        print('{}:: {} episode score is {:.2f}'.format(
            iter, episodes, score_avg))
        print('{}:: {} episode similarity score is {:.2f}'.format(
            iter, episodes, similarity_score_avg))

        actor.train(), critic.train(), discrim.train()
        if train_discrim_flag:
            expert_acc, learner_acc = train_discrim(discrim, memory,
                                                    discrim_optim, args)
            print("Expert: %.2f%% | Learner: %.2f%%" %
                  (expert_acc * 100, learner_acc * 100))
            writer.add_scalar('log/expert_acc', float(expert_acc), iter)
            writer.add_scalar('log/learner_acc', float(learner_acc), iter)
            writer.add_scalar('log/avg_acc',
                              float(learner_acc + expert_acc) / 2, iter)
            # Stop training the discriminator once both accuracies pass their
            # thresholds (only checked when a threshold is configured).
            if args.suspend_accu_exp is not None:
                if expert_acc > args.suspend_accu_exp and \
                        learner_acc > args.suspend_accu_gen:
                    train_discrim_flag = False

        train_actor_critic(actor, critic, memory, actor_optim, critic_optim,
                           args)
        writer.add_scalar('log/score', float(score_avg), iter)
        writer.add_scalar('log/similarity_score',
                          float(similarity_score_avg), iter)
        writer.add_text('log/raw_state', raw_state[0], iter)
        raw_action = get_raw_action(action)  # TODO
        writer.add_text('log/raw_action', raw_action, iter)
        writer.add_text('log/raw_expert_action', raw_expert_action, iter)

        # Fixed: the original `if iter % 100:` fired on every iteration EXCEPT
        # multiples of 100; use `== 0` so logging/checkpointing is periodic.
        if iter % 100 == 0:
            score_avg = int(score_avg)

            # Append this iteration's sample to the experiment log file.
            file_object = open(experiment_name + '.txt', 'a')
            result_str = str(iter) + '|' + raw_state[0] + '|' + raw_action \
                + '|' + raw_expert_action + '\n'
            file_object.write(result_str)
            file_object.close()

            model_path = os.path.join(os.getcwd(), 'save_model')
            if not os.path.isdir(model_path):
                os.makedirs(model_path)

            ckpt_path = os.path.join(
                model_path,
                experiment_name + '_ckpt_' + str(score_avg) + '.pth.tar')
            save_checkpoint(
                {
                    'actor': actor.state_dict(),
                    'critic': critic.state_dict(),
                    'discrim': discrim.state_dict(),
                    'args': args,
                    'score': score_avg,
                },
                filename=ckpt_path)
class GanTrainer(Trainer):
    """Adversarial trainer for an audio super-resolution generator.

    Combines an adversarial loss with time-domain, frequency-domain and
    (optionally) auto-encoder embedding L2 losses.
    """

    def __init__(self, train_loader, test_loader, valid_loader, general_args,
                 trainer_args):
        super(GanTrainer, self).__init__(train_loader, test_loader,
                                         valid_loader, general_args)

        # Paths
        self.loadpath = trainer_args.loadpath
        self.savepath = trainer_args.savepath

        # Load the auto-encoder (optional; enables the embedding-space loss).
        self.use_autoencoder = False
        if trainer_args.autoencoder_path and os.path.exists(
                trainer_args.autoencoder_path):
            self.use_autoencoder = True
            self.autoencoder = AutoEncoder(general_args=general_args).to(
                self.device)
            self.load_pretrained_autoencoder(trainer_args.autoencoder_path)
            self.autoencoder.eval()

        # Load the generator (optionally warm-started from a checkpoint).
        self.generator = Generator(general_args=general_args).to(self.device)
        if trainer_args.generator_path and os.path.exists(
                trainer_args.generator_path):
            self.load_pretrained_generator(trainer_args.generator_path)

        self.discriminator = Discriminator(general_args=general_args).to(
            self.device)

        # Optimizers and schedulers
        self.generator_optimizer = torch.optim.Adam(
            params=self.generator.parameters(), lr=trainer_args.generator_lr)
        self.discriminator_optimizer = torch.optim.Adam(
            params=self.discriminator.parameters(),
            lr=trainer_args.discriminator_lr)
        self.generator_scheduler = lr_scheduler.StepLR(
            optimizer=self.generator_optimizer,
            step_size=trainer_args.generator_scheduler_step,
            gamma=trainer_args.generator_scheduler_gamma)
        self.discriminator_scheduler = lr_scheduler.StepLR(
            optimizer=self.discriminator_optimizer,
            step_size=trainer_args.discriminator_scheduler_step,
            gamma=trainer_args.discriminator_scheduler_gamma)

        # Load saved states
        if os.path.exists(self.loadpath):
            self.load()

        # Loss function and stored losses
        self.adversarial_criterion = nn.BCEWithLogitsLoss()
        self.generator_time_criterion = nn.MSELoss()
        self.generator_frequency_criterion = nn.MSELoss()
        self.generator_autoencoder_criterion = nn.MSELoss()

        # Define labels
        self.real_label = 1
        self.generated_label = 0

        # Loss scaling factors
        self.lambda_adv = trainer_args.lambda_adversarial
        self.lambda_freq = trainer_args.lambda_freq
        self.lambda_autoencoder = trainer_args.lambda_autoencoder

        # Spectrogram converter
        self.spectrogram = Spectrogram(normalized=True).to(self.device)

        # Boolean indicating if the model needs to be saved
        self.need_saving = True

        # Boolean if the generator receives the feedback from the discriminator
        self.use_adversarial = trainer_args.use_adversarial

    def load_pretrained_generator(self, generator_path):
        """
        Loads a pre-trained generator. Can be used to stabilize the training.
        :param generator_path: location of the pre-trained generator (string).
        :return: None
        """
        checkpoint = torch.load(generator_path, map_location=self.device)
        self.generator.load_state_dict(checkpoint['generator_state_dict'])

    def load_pretrained_autoencoder(self, autoencoder_path):
        """
        Loads a pre-trained auto-encoder. Can be used to infer
        :param autoencoder_path: location of the pre-trained auto-encoder (string).
        :return: None
        """
        checkpoint = torch.load(autoencoder_path, map_location=self.device)
        self.autoencoder.load_state_dict(checkpoint['autoencoder_state_dict'])

    def train(self, epochs):
        """
        Trains the GAN for a given number of pseudo-epochs.
        :param epochs: Number of time to iterate over a part of the dataset (int).
        :return: None
        """
        for epoch in range(epochs):
            for i in range(self.train_batches_per_epoch):
                self.generator.train()
                self.discriminator.train()

                # Transfer to GPU
                local_batch = next(self.train_loader_iter)
                input_batch, target_batch = local_batch[0].to(
                    self.device), local_batch[1].to(self.device)
                batch_size = input_batch.shape[0]

                ############################
                # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
                ###########################
                # Train the discriminator with real data.
                self.discriminator_optimizer.zero_grad()
                # Fixed: force a float label tensor — with an integer fill
                # value, torch.full infers int64 on modern PyTorch and
                # BCEWithLogitsLoss rejects the dtype.
                label = torch.full((batch_size, ),
                                   self.real_label,
                                   dtype=torch.float32,
                                   device=self.device)
                output = self.discriminator(target_batch)

                # Compute and store the discriminator loss on real data
                loss_discriminator_real = self.adversarial_criterion(
                    output, torch.unsqueeze(label, dim=1))
                self.train_losses['discriminator_adversarial']['real'].append(
                    loss_discriminator_real.item())
                loss_discriminator_real.backward()

                # Train the discriminator with fake data.
                generated_batch = self.generator(input_batch)
                label.fill_(self.generated_label)
                output = self.discriminator(generated_batch.detach())

                # Compute and store the discriminator loss on fake data
                loss_discriminator_generated = self.adversarial_criterion(
                    output, torch.unsqueeze(label, dim=1))
                self.train_losses['discriminator_adversarial']['fake'].append(
                    loss_discriminator_generated.item())
                loss_discriminator_generated.backward()

                # Update the discriminator weights
                self.discriminator_optimizer.step()

                ############################
                # Update G network: maximize log(D(G(z)))
                ###########################
                self.generator_optimizer.zero_grad()

                # Get the spectrogram
                specgram_target_batch = self.spectrogram(target_batch)
                specgram_fake_batch = self.spectrogram(generated_batch)

                # Fake labels are real for the generator cost
                label.fill_(self.real_label)
                output = self.discriminator(generated_batch)

                # Get the adversarial loss (zero when disabled).
                loss_generator_adversarial = torch.zeros(size=[1],
                                                         device=self.device)
                if self.use_adversarial:
                    loss_generator_adversarial = self.adversarial_criterion(
                        output, torch.unsqueeze(label, dim=1))
                self.train_losses['generator_adversarial'].append(
                    loss_generator_adversarial.item())

                # Get the L2 loss in time domain
                loss_generator_time = self.generator_time_criterion(
                    generated_batch, target_batch)
                self.train_losses['time_l2'].append(
                    loss_generator_time.item())

                # Get the L2 loss in frequency domain
                loss_generator_frequency = self.generator_frequency_criterion(
                    specgram_fake_batch, specgram_target_batch)
                self.train_losses['freq_l2'].append(
                    loss_generator_frequency.item())

                # Get the L2 loss in embedding space (zero when disabled).
                loss_generator_autoencoder = torch.zeros(size=[1],
                                                         device=self.device,
                                                         requires_grad=True)
                if self.use_autoencoder:
                    # Get the embeddings
                    _, embedding_target_batch = self.autoencoder(target_batch)
                    _, embedding_generated_batch = self.autoencoder(
                        generated_batch)
                    loss_generator_autoencoder = \
                        self.generator_autoencoder_criterion(
                            embedding_generated_batch, embedding_target_batch)
                    self.train_losses['autoencoder_l2'].append(
                        loss_generator_autoencoder.item())

                # Combine the different losses
                loss_generator = \
                    self.lambda_adv * loss_generator_adversarial + \
                    loss_generator_time + \
                    self.lambda_freq * loss_generator_frequency + \
                    self.lambda_autoencoder * loss_generator_autoencoder

                # Back-propagate and update the generator weights
                loss_generator.backward()
                self.generator_optimizer.step()

                # Print message
                if not (i % 10):
                    message = 'Batch {}: \n' \
                              '\t Generator: \n' \
                              '\t\t Time: {} \n' \
                              '\t\t Frequency: {} \n' \
                              '\t\t Autoencoder {} \n' \
                              '\t\t Adversarial: {} \n' \
                              '\t Discriminator: \n' \
                              '\t\t Real {} \n' \
                              '\t\t Fake {} \n'.format(
                                  i, loss_generator_time.item(),
                                  loss_generator_frequency.item(),
                                  loss_generator_autoencoder.item(),
                                  loss_generator_adversarial.item(),
                                  loss_discriminator_real.item(),
                                  loss_discriminator_generated.item())
                    print(message)

            # Evaluate the model
            with torch.no_grad():
                self.eval()

            # Save the trainer state
            self.save()
            # if self.need_saving:
            #     self.save()

            # Increment epoch counter
            self.epoch += 1
            self.generator_scheduler.step()
            self.discriminator_scheduler.step()

    def eval(self):
        """Run one validation pass and record time/frequency L2 losses."""
        self.generator.eval()
        self.discriminator.eval()
        batch_losses = {'time_l2': [], 'freq_l2': []}
        for i in range(self.valid_batches_per_epoch):
            # Transfer to GPU
            local_batch = next(self.valid_loader_iter)
            input_batch, target_batch = local_batch[0].to(
                self.device), local_batch[1].to(self.device)

            generated_batch = self.generator(input_batch)

            # Get the spectrogram
            specgram_target_batch = self.spectrogram(target_batch)
            specgram_generated_batch = self.spectrogram(generated_batch)

            loss_generator_time = self.generator_time_criterion(
                generated_batch, target_batch)
            batch_losses['time_l2'].append(loss_generator_time.item())
            loss_generator_frequency = self.generator_frequency_criterion(
                specgram_generated_batch, specgram_target_batch)
            batch_losses['freq_l2'].append(loss_generator_frequency.item())

        # Store the validation losses
        self.valid_losses['time_l2'].append(np.mean(batch_losses['time_l2']))
        self.valid_losses['freq_l2'].append(np.mean(batch_losses['freq_l2']))

        # Display validation losses (the original wrapped np.mean twice; one
        # call is equivalent).
        message = 'Epoch {}: \n' \
                  '\t Time: {} \n' \
                  '\t Frequency: {} \n'.format(
                      self.epoch,
                      np.mean(batch_losses['time_l2']),
                      np.mean(batch_losses['freq_l2']))
        print(message)

        # Check if the loss is decreasing
        self.check_improvement()

    def save(self):
        """
        Saves the model(s), optimizer(s), scheduler(s) and losses
        :return: None
        """
        torch.save(
            {
                'epoch':
                self.epoch,
                'generator_state_dict':
                self.generator.state_dict(),
                'discriminator_state_dict':
                self.discriminator.state_dict(),
                'generator_optimizer_state_dict':
                self.generator_optimizer.state_dict(),
                'discriminator_optimizer_state_dict':
                self.discriminator_optimizer.state_dict(),
                'generator_scheduler_state_dict':
                self.generator_scheduler.state_dict(),
                'discriminator_scheduler_state_dict':
                self.discriminator_scheduler.state_dict(),
                'train_losses':
                self.train_losses,
                'test_losses':
                self.test_losses,
                'valid_losses':
                self.valid_losses
            }, self.savepath)

    def load(self):
        """
        Loads the model(s), optimizer(s), scheduler(s) and losses
        :return: None
        """
        checkpoint = torch.load(self.loadpath, map_location=self.device)
        self.epoch = checkpoint['epoch']
        self.generator.load_state_dict(checkpoint['generator_state_dict'])
        self.discriminator.load_state_dict(
            checkpoint['discriminator_state_dict'])
        self.generator_optimizer.load_state_dict(
            checkpoint['generator_optimizer_state_dict'])
        self.discriminator_optimizer.load_state_dict(
            checkpoint['discriminator_optimizer_state_dict'])
        self.generator_scheduler.load_state_dict(
            checkpoint['generator_scheduler_state_dict'])
        self.discriminator_scheduler.load_state_dict(
            checkpoint['discriminator_scheduler_state_dict'])
        self.train_losses = checkpoint['train_losses']
        self.test_losses = checkpoint['test_losses']
        self.valid_losses = checkpoint['valid_losses']

    def evaluate_metrics(self, n_batches):
        """
        Evaluates the quality of the reconstruction with the SNR and LSD
        metrics on a specified number of batches
        :param: n_batches: number of batches to process
        :return: mean and std for each metric
        """
        with torch.no_grad():
            snrs = []
            lsds = []
            generator = self.generator.eval()
            for k in range(n_batches):
                # Transfer to GPU
                local_batch = next(self.test_loader_iter)
                input_batch, target_batch = local_batch[0].to(
                    self.device), local_batch[1].to(self.device)

                # Generates a batch
                generated_batch = generator(input_batch)

                # Get the metrics
                snrs.append(
                    snr(x=generated_batch.squeeze(),
                        x_ref=target_batch.squeeze()))
                lsds.append(
                    lsd(x=generated_batch.squeeze(),
                        x_ref=target_batch.squeeze()))

            snrs = torch.cat(snrs).cpu().numpy()
            lsds = torch.cat(lsds).cpu().numpy()

        # Some signals corresponding to silence will be all zeroes and cause
        # troubles due to the logarithm
        snrs[np.isinf(snrs)] = np.nan
        lsds[np.isinf(lsds)] = np.nan
        return np.nanmean(snrs), np.nanstd(snrs), np.nanmean(lsds), np.nanstd(
            lsds)
class Model(object):
    """Conditional GAN wrapper: holds the generator, discriminator, an
    optional pretrained regressor used for an MSE-style loss, and the
    checkpoint save/load logic."""

    def __init__(self, opt):
        super(Model, self).__init__()

        # Generator
        self.gen = Generator(opt).cuda(opt.gpu_id)
        self.gen_params = self.gen.parameters()
        print(self.gen)
        # Idiomatic parameter count (replaces a manual accumulation loop)
        print(sum(p.numel() for p in self.gen.parameters()))

        # Discriminator
        self.dis = Discriminator(opt).cuda(opt.gpu_id)
        self.dis_params = self.dis.parameters()
        print(self.dis)
        print(sum(p.numel() for p in self.dis.parameters()))

        # Regressor: pretrained classifier, only needed when the MSE term
        # is enabled (opt.mse_weight truthy)
        if opt.mse_weight:
            self.reg = torch.load('data/utils/classifier.pth').cuda(
                opt.gpu_id).eval()
        else:
            self.reg = None

        # Losses
        self.criterion_gan = GANLoss(opt, self.dis)
        # NOTE(review): despite the "mse" name this uses l1_loss — confirm
        self.criterion_mse = lambda x, y: l1_loss(x, y) * opt.mse_weight
        self.loss_mse = Variable(torch.zeros(1).cuda())
        self.loss_adv = Variable(torch.zeros(1).cuda())
        self.loss = Variable(torch.zeros(1).cuda())
        self.path = opt.experiments_dir + opt.experiment_name + '/checkpoints/'
        self.gpu_id = opt.gpu_id
        # Channels left for noise after the conditioning input channels
        self.noise_channels = opt.in_channels - len(opt.input_idx.split(','))

    def forward(self, inputs):
        """Unpack (input, input_orig, target), append noise channels to the
        conditioning input, and run the generator."""
        input, input_orig, target = inputs
        self.input = Variable(input.cuda(self.gpu_id))
        self.input_orig = Variable(input_orig.cuda(self.gpu_id))
        self.target = Variable(target.cuda(self.gpu_id))
        noise = Variable(
            torch.randn(self.input.size(0),
                        self.noise_channels).cuda(self.gpu_id))
        self.fake = self.gen(torch.cat([self.input, noise], 1))

    def backward_G(self):
        """Generator update: optional regressor loss plus adversarial loss."""
        # Regressor loss
        if self.reg is not None:
            fake_input = self.reg(self.fake)
            self.loss_mse = self.criterion_mse(fake_input, self.input_orig)
        # GAN loss
        loss_adv, _ = self.criterion_gan(self.fake)
        loss_G = self.loss_mse + loss_adv
        loss_G.backward()

    def backward_D(self):
        """Discriminator update on (real target, fake) pair."""
        loss_adv, self.loss_adv = self.criterion_gan(self.target, self.fake)
        loss_D = loss_adv
        loss_D.backward()

    def train(self):
        self.gen.train()
        self.dis.train()

    def eval(self):
        self.gen.eval()
        self.dis.eval()

    def save_checkpoint(self, epoch):
        """Save generator and discriminator weights as <epoch>.pkl."""
        torch.save(
            {
                'epoch': epoch,
                'gen_state_dict': self.gen.state_dict(),
                'dis_state_dict': self.dis.state_dict()
            }, self.path + '%d.pkl' % epoch)

    def load_checkpoint(self, path, pretrained=True):
        """Restore generator and discriminator weights from `path`."""
        # NOTE(review): `pretrained` is currently unused — confirm intent
        weights = torch.load(path)
        self.gen.load_state_dict(weights['gen_state_dict'])
        self.dis.load_state_dict(weights['dis_state_dict'])
class Solver(object):
    """WGAN-GP style solver: builds the generator/discriminator pair, runs
    the adversarial loop with a gradient penalty, and periodically logs
    losses and saves sample images and model checkpoints."""

    def __init__(self, data_loader, config):
        self.data_loader = data_loader
        self.noise_n = config.noise_n
        self.G_last_act = last_act(config.G_last_act)
        self.D_out_n = config.D_out_n
        self.D_last_act = last_act(config.D_last_act)
        self.G_lr = config.G_lr
        self.D_lr = config.D_lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.epoch = config.epoch
        self.batch_size = config.batch_size
        self.D_train_step = config.D_train_step
        self.save_image_step = config.save_image_step
        self.log_step = config.log_step
        self.model_save_step = config.model_save_step
        self.clip_value = config.clip_value
        self.lambda_gp = config.lambda_gp
        self.model_save_path = config.model_save_path
        self.log_save_path = config.log_save_path
        self.image_save_path = config.image_save_path
        self.use_tensorboard = config.use_tensorboard
        self.pretrained_model = config.pretrained_model

        self.build_model()
        if self.use_tensorboard is not None:
            self.build_tensorboard()
        if self.pretrained_model is not None:
            if len(self.pretrained_model) != 2:
                # BUG FIX: `raise "<string>"` is itself a TypeError in
                # Python 3 — exceptions must derive from BaseException.
                raise ValueError(
                    "must have both G and D pretrained parameters, and G is first, D is second"
                )
            self.load_pretrained_model()

    def build_model(self):
        """Instantiate G/D and their Adam optimizers; move to GPU if any."""
        self.G = Generator(self.noise_n, self.G_last_act)
        self.D = Discriminator(self.D_out_n, self.D_last_act)
        self.G_optimizer = torch.optim.Adam(self.G.parameters(), self.G_lr,
                                            [self.beta1, self.beta2])
        self.D_optimizer = torch.optim.Adam(self.D.parameters(), self.D_lr,
                                            [self.beta1, self.beta2])
        if torch.cuda.is_available():
            self.G.cuda()
            self.D.cuda()

    def build_tensorboard(self):
        from commons.logger import Logger
        self.logger = Logger(self.log_save_path)

    def load_pretrained_model(self):
        # pretrained_model is (G_weights_path, D_weights_path)
        self.G.load_state_dict(torch.load(self.pretrained_model[0]))
        self.D.load_state_dict(torch.load(self.pretrained_model[1]))

    def denorm(self, x):
        """Map images from [-1, 1] back to [0, 1]."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def reset_grad(self):
        self.G_optimizer.zero_grad()
        self.D_optimizer.zero_grad()

    def to_var(self, x, volatile=False):
        # NOTE(review): `volatile` is deprecated (no-op) in PyTorch >= 0.4;
        # kept only for interface compatibility.
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x, volatile=volatile)

    def train(self):
        """Main WGAN-GP training loop."""
        print(len(self.data_loader))
        for e in range(self.epoch):
            for i, batch_images in enumerate(self.data_loader):
                batch_size = batch_images.size(0)
                real_x = self.to_var(batch_images)
                noise_x = self.to_var(
                    torch.FloatTensor(noise_vector(batch_size, self.noise_n)))

                # train D: minimize -D(real) + D(fake) (Wasserstein critic)
                fake_x = self.G(noise_x)
                real_out = self.D(real_x)
                fake_out = self.D(fake_x.detach())
                D_real = -torch.mean(real_out)
                D_fake = torch.mean(fake_out)
                D_loss = D_real + D_fake
                self.reset_grad()
                D_loss.backward()
                self.D_optimizer.step()

                # Log (BUG FIX: `.data[0]` was removed from PyTorch; .item())
                loss = {}
                loss['D/loss_real'] = D_real.item()
                loss['D/loss_fake'] = D_fake.item()
                loss['D/loss'] = D_loss.item()

                # choose one in below two
                # Clip weights of D
                # for p in self.D.parameters():
                #     p.data.clamp_(-self.clip_value, self.clip_value)

                # Gradients penalty, WGAN-GP: penalize the critic's gradient
                # norm on random interpolates between real and fake samples.
                alpha = torch.rand(real_x.size(0), 1, 1,
                                   1).cuda().expand_as(real_x)
                interpolated = Variable(alpha * real_x.data +
                                        (1 - alpha) * fake_x.data,
                                        requires_grad=True)
                gp_out = self.D(interpolated)
                grad = torch.autograd.grad(outputs=gp_out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(
                                               gp_out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]
                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                # Backward + Optimize (penalty is applied in a second D step)
                d_loss = self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.D_optimizer.step()

                # Train G once every D_train_step critic updates
                if (i + 1) % self.D_train_step == 0:
                    fake_out = self.D(self.G(noise_x))
                    G_loss = -torch.mean(fake_out)
                    self.reset_grad()
                    G_loss.backward()
                    self.G_optimizer.step()
                    loss['G/loss'] = G_loss.item()

                # Print log
                if (i + 1) % self.log_step == 0:
                    log = "Epoch: {}/{}, Iter: {}/{}".format(
                        e + 1, self.epoch, i + 1, len(self.data_loader))
                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)
                    if self.use_tensorboard:
                        for tag, value in loss.items():
                            self.logger.scalar_summary(
                                tag, value, e * len(self.data_loader) + i + 1)

            # Save images
            if (e + 1) % self.save_image_step == 0:
                noise_x = self.to_var(
                    torch.FloatTensor(noise_vector(16, self.noise_n)))
                fake_image = self.G(noise_x)
                save_image(
                    self.denorm(fake_image.data),
                    os.path.join(self.image_save_path,
                                 "{}_fake.png".format(e + 1)))

            # Save model checkpoints
            if (e + 1) % self.model_save_step == 0:
                torch.save(
                    self.G.state_dict(),
                    os.path.join(self.model_save_path,
                                 "{}_G.pth".format(e + 1)))
                torch.save(
                    self.D.state_dict(),
                    os.path.join(self.model_save_path,
                                 "{}_D.pth".format(e + 1)))
class Trainer():
    """CycleGAN trainer: G_X maps Y->X and G_Y maps X->Y, each with its own
    least-squares GAN discriminator, plus optional cycle-consistency and
    identity losses."""

    def __init__(self, config):
        self.batch_size = config.batchSize
        self.epochs = config.epochs
        self.use_cycle_loss = config.cycleLoss
        self.cycle_multiplier = config.cycleMultiplier
        self.use_identity_loss = config.identityLoss
        self.identity_multiplier = config.identityMultiplier
        self.load_models = config.loadModels
        self.data_x_loc = config.dataX
        self.data_y_loc = config.dataY
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # BUG FIX: these attributes are read by init_models() (output_path)
        # and init_data_loaders() (img_width/img_height) and therefore must
        # be assigned BEFORE those calls; they used to be assigned after,
        # causing an AttributeError.
        self.output_path = "./outputs/"
        self.img_width = 256
        self.img_height = 256
        self.init_models()
        self.init_data_loaders()
        self.g_optimizer = torch.optim.Adam(list(self.G_X.parameters()) +
                                            list(self.G_Y.parameters()),
                                            lr=config.lr)
        self.d_optimizer = torch.optim.Adam(list(self.D_X.parameters()) +
                                            list(self.D_Y.parameters()),
                                            lr=config.lr)
        self.scheduler_g = torch.optim.lr_scheduler.StepLR(self.g_optimizer,
                                                           step_size=1,
                                                           gamma=0.95)

    # Load/Construct the models
    def init_models(self):
        self.G_X = Generator(3, 3, nn.InstanceNorm2d)
        self.D_X = Discriminator(3)
        self.G_Y = Generator(3, 3, nn.InstanceNorm2d)
        self.D_Y = Discriminator(3)
        if self.load_models:
            self.G_X.load_state_dict(
                torch.load(self.output_path + "models/G_X",
                           map_location='cpu'))
            self.G_Y.load_state_dict(
                torch.load(self.output_path + "models/G_Y",
                           map_location='cpu'))
            self.D_X.load_state_dict(
                torch.load(self.output_path + "models/D_X",
                           map_location='cpu'))
            self.D_Y.load_state_dict(
                torch.load(self.output_path + "models/D_Y",
                           map_location='cpu'))
        else:
            self.G_X.apply(init_func)
            self.G_Y.apply(init_func)
            self.D_X.apply(init_func)
            self.D_Y.apply(init_func)
        self.G_X.to(self.device)
        self.G_Y.to(self.device)
        self.D_X.to(self.device)
        self.D_Y.to(self.device)

    # Initialize data loaders and image transformer
    def init_data_loaders(self):
        transform = transforms.Compose([
            transforms.Resize((self.img_width, self.img_height)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        X_folder = torchvision.datasets.ImageFolder(self.data_x_loc,
                                                    transform)
        self.X_loader = torch.utils.data.DataLoader(
            X_folder, batch_size=self.batch_size, shuffle=True)
        Y_folder = torchvision.datasets.ImageFolder(self.data_y_loc,
                                                    transform)
        self.Y_loader = torch.utils.data.DataLoader(
            Y_folder, batch_size=self.batch_size, shuffle=True)

    def save_models(self):
        torch.save(self.G_X.state_dict(), self.output_path + "models/G_X")
        torch.save(self.D_X.state_dict(), self.output_path + "models/D_X")
        torch.save(self.G_Y.state_dict(), self.output_path + "models/G_Y")
        torch.save(self.D_Y.state_dict(), self.output_path + "models/D_Y")

    # Reset gradients for all models, needed for between every training
    def reset_gradients(self):
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    # Sample image from training data every %x epoch and save them for judging
    def save_samples(self, epoch):
        x_iter = iter(self.X_loader)
        y_iter = iter(self.Y_loader)
        img_data_x, _ = next(x_iter)
        img_data_y, _ = next(y_iter)
        original_x = np.array(img_data_x[0])
        generated_y = np.array(
            self.G_Y(img_data_x[0].view(1, 3, self.img_width,
                                        self.img_height).to(
                                            self.device)).cpu().detach())[0]
        original_y = np.array(img_data_y[0])
        generated_x = np.array(
            self.G_X(img_data_y[0].view(1, 3, self.img_width,
                                        self.img_height).to(
                                            self.device)).cpu().detach())[0]

        def prepare_image(img):
            # CHW -> HWC and undo the [-1, 1] normalization
            img = img.transpose((1, 2, 0))
            return img / 2 + 0.5

        original_x = prepare_image(original_x)
        generated_y = prepare_image(generated_y)
        original_y = prepare_image(original_y)
        generated_x = prepare_image(generated_x)
        plt.imsave('./outputs/samples/original_X_{}.png'.format(epoch),
                   original_x)
        plt.imsave('./outputs/samples/original_Y_{}.png'.format(epoch),
                   original_y)
        plt.imsave('./outputs/samples/generated_X_{}.png'.format(epoch),
                   generated_x)
        plt.imsave('./outputs/samples/generated_Y_{}.png'.format(epoch),
                   generated_y)

    # Training loop
    def train(self):
        D_X_losses = []
        D_Y_losses = []
        G_X_losses = []
        G_Y_losses = []
        for epoch in range(self.epochs):
            print("======")
            print("Epoch {}!".format(epoch + 1))
            # Track progress
            if epoch % 5 == 0:
                self.save_samples(epoch)
            # Paper reduces lr after 100 epochs
            if epoch > 100:
                self.scheduler_g.step()
            for (data_X, _), (data_Y, _) in zip(self.X_loader, self.Y_loader):
                data_X = data_X.to(self.device)
                data_Y = data_Y.to(self.device)

                # =====================================
                # Train Discriminators
                # =====================================
                # Train fake X
                self.reset_gradients()
                fake_X = self.G_X(data_Y)
                out_fake_X = self.D_X(fake_X)
                d_x_f_loss = torch.mean(out_fake_X**2)
                d_x_f_loss.backward()
                self.d_optimizer.step()

                # Train fake Y
                self.reset_gradients()
                fake_Y = self.G_Y(data_X)
                out_fake_Y = self.D_Y(fake_Y)
                d_y_f_loss = torch.mean(out_fake_Y**2)
                d_y_f_loss.backward()
                self.d_optimizer.step()

                # Train true X
                self.reset_gradients()
                out_true_X = self.D_X(data_X)
                d_x_t_loss = torch.mean((out_true_X - 1)**2)
                d_x_t_loss.backward()
                self.d_optimizer.step()

                # Train true Y
                self.reset_gradients()
                out_true_Y = self.D_Y(data_Y)
                d_y_t_loss = torch.mean((out_true_Y - 1)**2)
                d_y_t_loss.backward()
                self.d_optimizer.step()

                D_X_losses.append([
                    d_x_t_loss.cpu().detach().numpy(),
                    d_x_f_loss.cpu().detach().numpy()
                ])
                D_Y_losses.append([
                    d_y_t_loss.cpu().detach().numpy(),
                    d_y_f_loss.cpu().detach().numpy()
                ])

                # =====================================
                # Train GENERATORS
                # =====================================
                # Cycle X -> Y -> X
                self.reset_gradients()
                fake_Y = self.G_Y(data_X)
                out_fake_Y = self.D_Y(fake_Y)
                g_loss1 = torch.mean((out_fake_Y - 1)**2)
                # BUG FIX: g_loss2 was undefined (NameError) whenever
                # use_cycle_loss was False.
                g_loss2 = 0
                if self.use_cycle_loss:
                    reconst_X = self.G_X(fake_Y)
                    g_loss2 = self.cycle_multiplier * torch.mean(
                        (data_X - reconst_X)**2)
                    G_Y_losses.append([
                        g_loss1.cpu().detach().numpy(),
                        g_loss2.cpu().detach().numpy()
                    ])
                g_loss = g_loss1 + g_loss2
                g_loss.backward()
                self.g_optimizer.step()

                # Cycle Y -> X -> Y
                self.reset_gradients()
                fake_X = self.G_X(data_Y)
                out_fake_X = self.D_X(fake_X)
                g_loss1 = torch.mean((out_fake_X - 1)**2)
                g_loss2 = 0
                if self.use_cycle_loss:
                    reconst_Y = self.G_Y(fake_X)
                    g_loss2 = self.cycle_multiplier * torch.mean(
                        (data_Y - reconst_Y)**2)
                    G_X_losses.append([
                        g_loss1.cpu().detach().numpy(),
                        g_loss2.cpu().detach().numpy()
                    ])
                g_loss = g_loss1 + g_loss2
                g_loss.backward()
                self.g_optimizer.step()

                # =====================================
                # Train image IDENTITY
                # =====================================
                if self.use_identity_loss:
                    self.reset_gradients()
                    # X should be same after G(X)
                    same_X = self.G_X(data_X)
                    g_loss = self.identity_multiplier * torch.mean(
                        (data_X - same_X)**2)
                    g_loss.backward()
                    self.g_optimizer.step()
                    # Y should be same after G(Y)
                    # BUG FIX: used self.G_X here; the identity term for the
                    # Y domain must use G_Y (matches the comment above).
                    same_Y = self.G_Y(data_Y)
                    g_loss = self.identity_multiplier * torch.mean(
                        (data_Y - same_Y)**2)
                    g_loss.backward()
                    self.g_optimizer.step()

            # Epoch done, save models
            self.save_models()
        # Save losses for analysis
        np.save(self.output_path + 'losses/G_X_losses.npy',
                np.array(G_X_losses))
        np.save(self.output_path + 'losses/G_Y_losses.npy',
                np.array(G_Y_losses))
        np.save(self.output_path + 'losses/D_X_losses.npy',
                np.array(D_X_losses))
        np.save(self.output_path + 'losses/D_Y_losses.npy',
                np.array(D_Y_losses))
class Seq2SeqCycleGAN:
    """CycleGAN over text: two seq2seq generators translate between style A
    (printed as "Shakespeare") and style B (printed as "Modern"), with one
    discriminator per domain, cycle-consistency and identity losses. A single
    embedding layer is shared by all modules."""

    def __init__(self, model_config, train_config, vocab, max_len,
                 mode='train'):
        self.mode = mode
        self.model_config = model_config
        self.train_config = train_config
        self.vocab = vocab
        self.vocab_size = self.vocab.num_words
        self.max_len = max_len
        # self.embedding_layer = nn.Embedding(vocab_size, model_config['embedding_size'], padding_idx=PAD_token)
        # Soft embedding: maps one-hot/softmax word distributions to a dense
        # vector, so generator outputs stay differentiable end to end.
        self.embedding_layer = nn.Sequential(
            nn.Linear(self.vocab_size, self.model_config['embedding_size']),
            nn.Sigmoid())
        self.G_AtoB = Generator(self.embedding_layer,
                                self.model_config,
                                self.train_config,
                                self.vocab_size,
                                self.max_len,
                                mode=self.mode).cuda()
        self.G_BtoA = Generator(self.embedding_layer,
                                self.model_config,
                                self.train_config,
                                self.vocab_size,
                                self.max_len,
                                mode=self.mode).cuda()
        if self.mode == 'train':
            self.D_B = Discriminator(self.embedding_layer, self.model_config,
                                     self.train_config).cuda()
            self.D_A = Discriminator(self.embedding_layer, self.model_config,
                                     self.train_config).cuda()
            if self.train_config['continue_train']:
                # Resume all five modules from the 'which_epoch' prefix
                self.embedding_layer.load_state_dict(
                    torch.load(self.train_config['which_epoch'] +
                               '_embedding_layer.pth'))
                self.G_AtoB.load_state_dict(
                    torch.load(self.train_config['which_epoch'] +
                               '_G_AtoB.pth'))
                self.G_BtoA.load_state_dict(
                    torch.load(self.train_config['which_epoch'] +
                               '_G_BtoA.pth'))
                self.D_B.load_state_dict(
                    torch.load(self.train_config['which_epoch'] + '_D_B.pth'))
                self.D_A.load_state_dict(
                    torch.load(self.train_config['which_epoch'] + '_D_A.pth'))
            self.embedding_layer.train()
            self.G_AtoB.train()
            self.G_BtoA.train()
            self.D_B.train()
            self.D_A.train()
            self.criterionBCE = nn.BCELoss().cuda()
            self.criterionCE = nn.CrossEntropyLoss().cuda()
            # NOTE: the shared embedding layer is optimized by BOTH
            # optimizers below.
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.embedding_layer.parameters(), self.G_AtoB.parameters(),
                self.G_BtoA.parameters()),
                                                lr=train_config['base_lr'],
                                                betas=(0.9, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(
                self.embedding_layer.parameters(), self.D_A.parameters(),
                self.D_B.parameters()),
                                                lr=train_config['base_lr'],
                                                betas=(0.9, 0.999))
            self.real_label = torch.ones(
                (train_config['batch_size'], 1)).cuda()
            self.fake_label = torch.zeros(
                (train_config['batch_size'], 1)).cuda()
        else:
            # Inference mode: only the generators and embedding are needed
            self.embedding_layer.load_state_dict(
                torch.load(self.train_config['which_epoch'] +
                           '_embedding_layer.pth'))
            self.G_AtoB.load_state_dict(
                torch.load(self.train_config['which_epoch'] + '_G_AtoB.pth'))
            self.G_BtoA.load_state_dict(
                torch.load(self.train_config['which_epoch'] + '_G_BtoA.pth'))
            self.embedding_layer.eval()
            self.G_AtoB.eval()
            self.G_BtoA.eval()

    def backward_D_basic(self, netD, real, real_addn_feats, fake,
                         fake_addn_feats):
        """BCE discriminator update on real vs detached fake; back-propagates
        and clips gradients, then returns the combined loss."""
        netD.hidden = netD.init_hidden()
        pred_real = netD(real, real_addn_feats)
        loss_D_real = self.criterionBCE(pred_real, self.real_label)
        netD.hidden = netD.init_hidden()
        # detach(): the D update must not back-propagate into the generator
        pred_fake = netD(fake.detach(), fake_addn_feats)
        loss_D_fake = self.criterionBCE(pred_fake, self.fake_label)
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        self.clip_gradient(self.embedding_layer)
        self.clip_gradient(netD)
        return loss_D

    def backward_D_A(self):
        # The * 10 only scales the stored value: backward() already ran
        # inside backward_D_basic.
        self.loss_D_A = self.backward_D_basic(
            self.D_A, self.real_A, self.real_A_addn_feats, self.fake_A,
            self.fake_A_addn_feats) * 10

    def backward_D_B(self):
        self.loss_D_B = self.backward_D_basic(
            self.D_B, self.real_B, self.real_B_addn_feats, self.fake_B,
            self.fake_B_addn_feats) * 10

    def backward_G(self):
        """Generator update: adversarial (x10) + cycle + identity losses."""
        self.D_B.hidden = self.D_B.init_hidden()
        self.fake_B_addn_feats = get_addn_feats(self.fake_B,
                                                self.vocab).cuda()
        self.loss_G_AtoB = self.criterionBCE(
            self.D_B(self.fake_B, self.fake_B_addn_feats),
            self.real_label) * 10
        self.D_A.hidden = self.D_A.init_hidden()
        self.fake_A_addn_feats = get_addn_feats(self.fake_A,
                                                self.vocab).cuda()
        self.loss_G_BtoA = self.criterionBCE(
            self.D_A(self.fake_A, self.fake_A_addn_feats),
            self.real_label) * 10
        # Cycle losses; generated sequence lengths can differ from the
        # reference, so pad/truncate before CrossEntropy.
        if self.rec_A.size(0) != self.real_A_label.size(0):
            self.real_A, self.rec_A, self.real_A_label = self.update_label_sizes(
                self.real_A, self.rec_A, self.real_A_label)
        self.loss_cycle_A = self.criterionCE(self.rec_A,
                                             self.real_A_label)  #* lambda_A
        if self.rec_B.size(0) != self.real_B_label.size(0):
            self.real_B, self.rec_B, self.real_B_label = self.update_label_sizes(
                self.real_B, self.rec_B, self.real_B_label)
        self.loss_cycle_B = self.criterionCE(self.rec_B,
                                             self.real_B_label)  #* lambda_B
        # Identity terms: a sentence already in the target style should be
        # reproduced by the corresponding generator.
        self.idt_B = self.G_AtoB(self.real_B)
        if self.idt_B.size(0) != self.real_B_label.size(0):
            self.real_B, self.idt_B, self.real_B_label = self.update_label_sizes(
                self.real_B, self.idt_B, self.real_B_label)
        self.loss_idt_B = self.criterionCE(
            self.idt_B, self.real_B_label)  #* lambda_B * lambda_idt
        self.idt_A = self.G_BtoA(self.real_A)
        if self.idt_A.size(0) != self.real_A_label.size(0):
            self.real_A, self.idt_A, self.real_A_label = self.update_label_sizes(
                self.real_A, self.idt_A, self.real_A_label)
        self.loss_idt_A = self.criterionCE(
            self.idt_A, self.real_A_label)  #* lambda_A * lambda_idt
        self.loss_G = self.loss_G_AtoB + self.loss_G_BtoA + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()
        self.clip_gradient(self.embedding_layer)
        self.clip_gradient(self.G_AtoB)
        self.clip_gradient(self.G_BtoA)

    def forward(self, real_A, real_A_addn_feats, real_B, real_B_addn_feats):
        """Translate in both directions; in train mode also compute cycle
        reconstructions, otherwise print the decoded translations."""
        self.real_A = real_A
        self.real_A_addn_feats = real_A_addn_feats
        # Labels are the argmax word indices of the (one-hot-like) input
        self.real_A_label = self.real_A.max(dim=1)[1]
        self.real_B = real_B
        self.real_B_addn_feats = real_B_addn_feats
        self.real_B_label = self.real_B.max(dim=1)[1]
        self.fake_B = F.softmax(self.G_AtoB.forward(self.real_A), dim=1)
        self.fake_A = F.softmax(self.G_BtoA.forward(self.real_B), dim=1)
        if self.mode == 'train':
            self.rec_A = self.G_BtoA.forward(self.fake_B)
            self.rec_B = self.G_AtoB.forward(self.fake_A)
        else:
            real_A_list = self.real_A.max(dim=1)[1].tolist()
            real_B_list = self.real_B.max(dim=1)[1].tolist()
            fake_B_list = self.fake_B.max(dim=1)[1].tolist()
            fake_A_list = self.fake_A.max(dim=1)[1].tolist()
            print('Input (Shakespeare):', idx_to_sent(real_A_list,
                                                      self.vocab))
            print('Output (Modern):', idx_to_sent(fake_B_list, self.vocab))
            print('\n')
            print('Input (Modern):', idx_to_sent(real_B_list, self.vocab))
            print('Output (Shakespeare):',
                  idx_to_sent(fake_A_list, self.vocab))
            print('\n')

    def optimize_parameters(self):
        """One optimization step: generators first (Ds frozen), then both
        discriminators."""
        self.set_requires_grad([self.D_A, self.D_B], False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        self.set_requires_grad([self.D_A, self.D_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_B()
        self.backward_D_A()
        self.optimizer_D.step()

    def update_label_sizes(self, real, rec, real_label):
        """Pad either the label vector or the reconstruction so their first
        dimensions match (index 0 is used as the padding class)."""
        if rec.size(0) > real.size(0):
            real_label = torch.cat(
                (real_label, torch.zeros((rec.size(0) - real.size(0))).type(
                    torch.LongTensor).cuda()), 0)
        elif rec.size(0) < real.size(0):
            diff = real.size(0) - rec.size(0)
            to_concat = torch.zeros((diff, self.vocab_size)).cuda()
            to_concat[:, 0] = 1
            rec = torch.cat((rec, to_concat), 0)
        return real, rec, real_label

    def indices_to_one_hot(self, idx_tensor):
        """Convert a 1-D tensor of vocabulary indices into one-hot rows of
        width vocab_size."""
        one_hot_tensor = torch.empty((idx_tensor.size(0), self.vocab_size))
        for idx in range(idx_tensor.size(0)):
            zeros = torch.zeros((self.vocab_size))
            zeros[idx_tensor[idx].item()] = 1.0
            one_hot_tensor[idx] = zeros
        return one_hot_tensor

    def set_requires_grad(self, nets, requires_grad=False):
        """Enable/disable gradient tracking for a network or list of them."""
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    def clip_gradient(self, model):
        # Global norm clipping at 0.25 (RNN stability)
        nn.utils.clip_grad_norm_(model.parameters(), 0.25)
def train(config):
    """Pix2pix-style training: trains a UNet generator against a patch
    discriminator with a softplus GAN loss plus a weighted L1 term,
    validating and checkpointing every epoch.

    :param config: namespace with dataset paths, optimizer and loop settings
    :return: None
    """
    gpu_manage(config)
    train_dataset = Dataset(config.train_dir)
    val_dataset = Dataset(config.val_dir)
    training_data_loader = DataLoader(dataset=train_dataset,
                                      num_workers=config.threads,
                                      batch_size=config.batchsize,
                                      shuffle=True)
    val_data_loader = DataLoader(dataset=val_dataset,
                                 num_workers=config.threads,
                                 batch_size=config.test_batchsize,
                                 shuffle=False)

    gen = UNet(in_ch=config.in_ch,
               out_ch=config.out_ch,
               gpu_ids=config.gpu_ids)
    if config.gen_init is not None:
        param = torch.load(config.gen_init)
        gen.load_state_dict(param)
        print('load {} as pretrained model'.format(config.gen_init))

    dis = Discriminator(in_ch=config.in_ch,
                        out_ch=config.out_ch,
                        gpu_ids=config.gpu_ids)
    if config.dis_init is not None:
        param = torch.load(config.dis_init)
        dis.load_state_dict(param)
        print('load {} as pretrained model'.format(config.dis_init))

    opt_gen = optim.Adam(gen.parameters(),
                         lr=config.lr,
                         betas=(config.beta1, 0.999),
                         weight_decay=0.00001)
    opt_dis = optim.Adam(dis.parameters(),
                         lr=config.lr,
                         betas=(config.beta1, 0.999),
                         weight_decay=0.00001)

    criterionL1 = nn.L1Loss()
    criterionMSE = nn.MSELoss()
    criterionSoftplus = nn.Softplus()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if config.cuda:
        gen = gen.cuda(0)
        dis = dis.cuda(0)
        criterionL1 = criterionL1.cuda(0)
        criterionMSE = criterionMSE.cuda(0)
        criterionSoftplus = criterionSoftplus.cuda(0)
    # NOTE: the original pre-allocated real_a/real_b tensors (and their
    # Variable wrappers) were dead code — both names are rebound from the
    # batch on every iteration below — so they have been removed.

    logreport = LogReport(log_dir=config.out_dir)
    testreport = TestReport(log_dir=config.out_dir)

    for epoch in range(1, config.epoch + 1):
        print('Epoch', epoch, datetime.now())
        for iteration, batch in enumerate(tqdm(training_data_loader)):
            real_a, real_b = batch[0], batch[1]
            real_a = F.interpolate(real_a, size=256).to(device)
            real_b = F.interpolate(real_b, size=256).to(device)
            fake_b = gen.forward(real_a)

            # Update D: softplus GAN loss, normalized per batch and pixel
            opt_dis.zero_grad()
            fake_ab = torch.cat((real_a, fake_b), 1)
            # detach(): D update must not back-propagate into the generator
            pred_fake = dis.forward(fake_ab.detach())
            batchsize, _, w, h = pred_fake.size()
            real_ab = torch.cat((real_a, real_b), 1)
            pred_real = dis.forward(real_ab)
            loss_d_fake = torch.sum(
                criterionSoftplus(pred_fake)) / batchsize / w / h
            loss_d_real = torch.sum(
                criterionSoftplus(-pred_real)) / batchsize / w / h
            loss_d = loss_d_fake + loss_d_real
            loss_d.backward()
            # D only steps every `minimax` epochs to keep G/D balanced
            if epoch % config.minimax == 0:
                opt_dis.step()

            # Update G: adversarial term + weighted L1 reconstruction
            opt_gen.zero_grad()
            fake_ab = torch.cat((real_a, fake_b), 1)
            pred_fake = dis.forward(fake_ab)
            loss_g_gan = torch.sum(
                criterionSoftplus(-pred_fake)) / batchsize / w / h
            loss_g = loss_g_gan + criterionL1(fake_b, real_b) * config.lamb
            loss_g.backward()
            opt_gen.step()

            if iteration % 100 == 0:
                logreport({
                    'epoch':
                    epoch,
                    'iteration':
                    len(training_data_loader) * (epoch - 1) + iteration,
                    'gen/loss':
                    loss_g.item(),
                    'dis/loss':
                    loss_d.item(),
                })

        # Validation pass (no gradients needed)
        with torch.no_grad():
            log_test = test(config, val_data_loader, gen, criterionMSE,
                            epoch)
            testreport(log_test)

        if epoch % config.snapshot_interval == 0:
            checkpoint(config, epoch, gen, dis)

        logreport.save_lossgraph()
        testreport.save_lossgraph()
    print('Done', datetime.now())