def main(args):
    """Jointly train the generator and the discriminator adversarially.

    Loads the 'train' and 'valid' splits, restores the pre-trained generator
    and discriminator from ``checkpoints/`` (training them first when the
    checkpoint file is missing), then for every epoch:

    * updates the generator on each batch, choosing policy-gradient or MLE
      training with probability 0.5 each,
    * updates the discriminator on a shuffled mix of reference and generated
      target sentences,
    * evaluates both models on the validation split and saves the generator
      under ``checkpoints/joint/``.

    Args:
        args: Parsed options namespace. The model-size fields
            (``encoder_*``, ``decoder_*``, ``bidirectional``) are overwritten
            here, and ``src_lang``/``trg_lang`` are filled in from the dataset
            when they are None.
    """
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))
    print("======printing args========")
    print(args)
    print("=================================")

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        print("Loading bin dataset")
        dataset = data.load_dataset(args.data, splits, args.src_lang,
                                    args.trg_lang, args.fixed_max_len)
    else:
        print(f"Loading raw text dataset {args.data}")
        dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang,
                                             args.trg_lang, args.fixed_max_len)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
    for split in splits:
        print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))

    # One running-average meter per tracked quantity ('bsz' = sentences per batch).
    meter_names = ('train_loss', 'valid_loss', 'train_acc', 'valid_acc', 'bsz')
    g_logging_meters = OrderedDict((name, AverageMeter()) for name in meter_names)
    d_logging_meters = OrderedDict((name, AverageMeter()) for name in meter_names)

    # Set model parameters
    args.encoder_embed_dim = 1000
    args.encoder_layers = 4
    args.encoder_dropout_out = 0
    args.decoder_embed_dim = 1000
    args.decoder_layers = 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0
    args.bidirectional = False

    def _load_pretrained(model, checkpoint_path):
        """Copy into `model` every checkpoint entry whose key the model also has."""
        model_dict = model.state_dict()
        pretrained_dict = torch.load(checkpoint_path)
        # Keys unknown to the model are dropped; keys missing from the
        # checkpoint keep the model's freshly initialised values.
        model_dict.update(
            {k: v for k, v in pretrained_dict.items() if k in model_dict})
        model.load_state_dict(model_dict)

    # try to load generator model
    g_model_path = 'checkpoints/generator/best_gmodel.pt'
    if not os.path.exists(g_model_path):
        print("Start training generator!")
        train_g(args, dataset)
    assert os.path.exists(g_model_path)
    generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)
    _load_pretrained(generator, g_model_path)
    print("Generator has successfully loaded!")

    # try to load discriminator model
    d_model_path = 'checkpoints/discriminator/best_dmodel.pt'
    if not os.path.exists(d_model_path):
        print("Start training discriminator!")
        train_d(args, dataset)
    assert os.path.exists(d_model_path)
    discriminator = Discriminator(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)
    _load_pretrained(discriminator, d_model_path)
    print("Discriminator has successfully loaded!")

    print("starting main training loop")
    torch.autograd.set_detect_anomaly(True)

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator = torch.nn.DataParallel(discriminator).cuda()
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
            discriminator.cuda()
    else:
        discriminator.cpu()
        generator.cpu()

    # DataParallel does not forward custom attributes, so sub-modules such as
    # `.decoder` / `.embed_src_tokens` must be reached through `.module`.
    g_core = generator.module if isinstance(generator, torch.nn.DataParallel) else generator
    d_core = discriminator.module if isinstance(discriminator, torch.nn.DataParallel) else discriminator

    # adversarial training checkpoints saving path
    os.makedirs('checkpoints/joint', exist_ok=True)
    checkpoints_path = 'checkpoints/joint/'

    # define loss functions (the generator returns its own MLE loss, so no
    # separate NLL criterion is needed in this function)
    d_criterion = torch.nn.BCEWithLogitsLoss()
    pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(),
                          size_average=True, reduce=True)

    # fix discriminator word embedding (as Wu et al. do)
    for p in d_core.embed_src_tokens.parameters():
        p.requires_grad = False
    for p in d_core.embed_trg_tokens.parameters():
        p.requires_grad = False

    # define optimizer (looked up by name on torch.optim)
    g_optimizer = getattr(torch.optim, args.g_optimizer)(
        filter(lambda x: x.requires_grad, generator.parameters()),
        args.g_learning_rate)
    d_optimizer = getattr(torch.optim, args.d_optimizer)(
        filter(lambda x: x.requires_grad, discriminator.parameters()),
        args.d_learning_rate,
        momentum=args.momentum,
        nesterov=True)

    def _generate(sample):
        """Run the generator with `decoder.is_testing` switched on for the call."""
        g_core.decoder.is_testing = True
        try:
            return generator(sample)
        finally:
            g_core.decoder.is_testing = False

    def _discriminator_batch(sample):
        """Return (src, trg, labels): a random batch-sized draw from real + generated pairs.

        Reference targets get label 1 and generated targets label 0.
        """
        bsz = sample['target'].size(0)
        src_tokens = sample['net_input']['src_tokens']
        with torch.no_grad():
            _, prediction, _ = _generate(sample)
        trg = torch.cat([sample['target'], prediction], dim=0)
        # Duplicate the sources so every candidate target keeps its own source
        # sentence after the shuffle below.
        src = torch.cat([src_tokens, src_tokens], dim=0)
        labels = torch.cat([torch.ones(bsz), torch.zeros(bsz)], dim=0)
        keep = np.random.permutation(2 * bsz)[:bsz]
        labels = labels[keep]
        if use_cuda:
            labels = labels.cuda()
        return src[keep], trg[keep], labels

    def _save_generator(filename):
        """Pickle the whole generator object into the joint checkpoint directory."""
        with open(checkpoints_path + filename, 'wb') as fh:
            torch.save(generator, fh, pickle_module=dill)

    # start joint training
    best_dev_loss = math.inf
    num_update = 0

    # main training loop
    for epoch_i in range(1, args.epochs + 1):
        logging.info("At {0}-th epoch.".format(epoch_i))

        max_positions_train = (args.fixed_max_len, args.fixed_max_len)

        # Initialize dataloader, starting at batch_offset
        itr = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_train,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for meters in (g_logging_meters, d_logging_meters):
            for val in meters.values():
                if val is not None:
                    val.reset()

        # set training mode
        generator.train()
        discriminator.train()
        update_learning_rate(num_update, 8e4, args.g_learning_rate,
                             args.lr_shrink, g_optimizer)

        for i, sample in enumerate(itr):
            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=cuda)

            ## part I: train the generator
            # use policy gradient training when rand >= 0.5, MLE otherwise
            if random.random() >= 0.5:
                # policy gradient training
                sys_out_batch, prediction, _ = _generate(sample)
                with torch.no_grad():
                    reward = discriminator(sample['net_input']['src_tokens'],
                                           prediction)
                train_trg_batch = sample['target']
                pg_loss = pg_criterion(sys_out_batch, train_trg_batch, reward,
                                       use_cuda)
                g_optimizer.zero_grad()
                pg_loss.backward()
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()

                # oracle valid: the MLE loss is only logged here, so no graph
                # needs to be built for it
                with torch.no_grad():
                    _, _, loss = generator(sample)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                logging_loss = loss.data / sample_size / math.log(2)
                g_logging_meters['train_loss'].update(logging_loss, sample_size)
                logging.debug(
                    "G MLE loss at batch {0}: {1:.3f}, lr={2}".format(
                        i, g_logging_meters['train_loss'].avg,
                        g_optimizer.param_groups[0]['lr']))
            else:
                # MLE training
                _, _, loss = generator(sample)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                nsentences = sample['target'].size(0)
                logging_loss = loss.data / sample_size / math.log(2)
                g_logging_meters['bsz'].update(nsentences)
                g_logging_meters['train_loss'].update(logging_loss, sample_size)
                logging.debug(
                    "G MLE loss at batch {0}: {1:.3f}, lr={2}".format(
                        i, g_logging_meters['train_loss'].avg,
                        g_optimizer.param_groups[0]['lr']))
                g_optimizer.zero_grad()
                loss.backward()
                # all-reduce grads and rescale by grad_denom
                for p in generator.parameters():
                    if p.requires_grad:
                        p.grad.data.div_(sample_size)
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()
                num_update += 1

            # part II: train the discriminator on a mix of human and machine
            # translations
            src_sentence, trg_sentence, labels = _discriminator_batch(sample)
            disc_out = discriminator(src_sentence, trg_sentence)
            # BCEWithLogitsLoss expects float targets, so labels stay float.
            d_loss = d_criterion(disc_out, labels)
            acc = torch.sum(
                torch.sigmoid(disc_out).round() == labels).float() / len(labels)
            d_logging_meters['train_acc'].update(acc)
            # Store a plain float so the meter does not keep the graph alive.
            d_logging_meters['train_loss'].update(d_loss.item())
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

        # validation
        # set validation mode
        generator.eval()
        discriminator.eval()
        # Initialize dataloader
        max_positions_valid = (args.fixed_max_len, args.fixed_max_len)
        itr = dataset.eval_dataloader(
            'valid',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_valid,
            skip_invalid_size_inputs_valid_test=True,
            descending=True,  # largest batch first to warm the caching allocator
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for meters in (g_logging_meters, d_logging_meters):
            for val in meters.values():
                if val is not None:
                    val.reset()

        for i, sample in enumerate(itr):
            with torch.no_grad():
                if use_cuda:
                    sample['id'] = sample['id'].cuda()
                    for key in ('src_tokens', 'src_lengths', 'prev_output_tokens'):
                        sample['net_input'][key] = sample['net_input'][key].cuda()
                    sample['target'] = sample['target'].cuda()

                # generator validation
                _, _, loss = generator(sample)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                loss = loss / sample_size / math.log(2)
                g_logging_meters['valid_loss'].update(loss, sample_size)
                logging.debug("G dev loss at batch {0}: {1:.3f}".format(
                    i, g_logging_meters['valid_loss'].avg))

                # discriminator validation
                src_sentence, trg_sentence, labels = _discriminator_batch(sample)
                # NOTE(review): the training calls above pass only
                # (src, trg); this call also passes the pad index. Confirm
                # which signature Discriminator.forward actually expects.
                disc_out = discriminator(src_sentence, trg_sentence,
                                         dataset.dst_dict.pad())
                d_loss = d_criterion(disc_out, labels)
                acc = torch.sum(
                    torch.sigmoid(disc_out).round() == labels).float() / len(labels)
                d_logging_meters['valid_acc'].update(acc)
                d_logging_meters['valid_loss'].update(d_loss)

        _save_generator("joint_{0:.3f}.epoch_{1}.pt".format(
            g_logging_meters['valid_loss'].avg, epoch_i))

        if g_logging_meters['valid_loss'].avg < best_dev_loss:
            best_dev_loss = g_logging_meters['valid_loss'].avg
            _save_generator("best_gmodel.pt")
def build_model(self):
    """Create the loss, the global and per-pair networks, their optimizers and the logger.

    Sets on ``self``:
        loss: ``nn.BCELoss()`` for ``cfg.VANILLA``, or the mean of the logits
            for ``cfg.WGAN``. Left unset for any other ``cfg.train.loss_type``.
        D_global, G_global and their Adam optimizers.
        D_pairs, G_pairs, D_msg_pairs and the matching ``*_optimizers`` lists,
            each holding ``cfg.train.N_pairs`` entries.
        logger: a ``Logger`` writing to ``cfg.validation.validation_dir``.
    """
    if cfg.train.loss_type == cfg.VANILLA:
        self.loss = nn.BCELoss()
    elif cfg.train.loss_type == cfg.WGAN:
        self.loss = lambda logits, labels: torch.mean(logits)

    dataset_name = cfg.dataset.dataset_name
    use_cuda = torch.cuda.is_available()

    def _adam(net):
        """Adam over `net`'s parameters with the configured learning rate and beta1."""
        return Adam(net.parameters(),
                    lr=cfg.train.learning_rate,
                    betas=(cfg.train.beta1, 0.999))

    self.D_global = Discriminator(dataset_name)
    self.G_global = Generator(dataset_name)

    # Enable cuda if available
    if use_cuda:
        self.D_global.cuda()
        self.G_global.cuda()

    # Optimizers
    self.D_global_optimizer = _adam(self.D_global)
    self.G_global_optimizer = _adam(self.G_global)

    self.D_pairs = []
    self.G_pairs = []
    self.D_pairs_optimizers = []
    self.G_pairs_optimizers = []
    self.D_msg_pairs = []
    self.D_msg_pairs_optimizers = []

    for _ in range(cfg.train.N_pairs):
        # Local generator/discriminator pair (discriminator is built first,
        # keeping the original parameter-initialisation order).
        discriminator = Discriminator(dataset_name)
        generator = Generator(dataset_name)
        if use_cuda:
            generator.cuda()
            discriminator.cuda()
        self.D_pairs.append(discriminator)
        self.G_pairs.append(generator)
        self.D_pairs_optimizers.append(_adam(discriminator))
        self.G_pairs_optimizers.append(_adam(generator))

        # create msg Discriminator pair for G_global
        msg_discriminator = Discriminator(dataset_name)
        if use_cuda:
            msg_discriminator.cuda()
        self.D_msg_pairs.append(msg_discriminator)
        self.D_msg_pairs_optimizers.append(_adam(msg_discriminator))

    self.logger = Logger(model_name='DCGAN',
                         data_name='MNIST',
                         logdir=cfg.validation.validation_dir)
# Build the two discriminators; generator_A / generator_B already exist at
# this point.
discriminator_A, discriminator_B = Discriminator(), Discriminator()

# Move all four networks to the target device.
generator_A, generator_B = generator_A.to(device), generator_B.to(device)
discriminator_A, discriminator_B = (discriminator_A.to(device),
                                    discriminator_B.to(device))

# Wrap every network in DataParallel when running on 'cuda'.
if device == 'cuda':
    generator_A, generator_B, discriminator_A, discriminator_B = [
        torch.nn.DataParallel(net)
        for net in (generator_A, generator_B, discriminator_A, discriminator_B)
    ]

# One optimizer drives both generators, another drives both discriminators.
chained_gen_params = chain(generator_A.parameters(), generator_B.parameters())
chained_dis_params = chain(discriminator_A.parameters(),
                           discriminator_B.parameters())

optim_gen = torch.optim.Adam(
    chained_gen_params,
    lr=LEARNING_RATE,
    betas=(0.5, 0.999),
    weight_decay=0.00001,
)
optim_dis = torch.optim.Adam(
    chained_dis_params,
    lr=LEARNING_RATE,
    betas=(0.5, 0.999),
    weight_decay=0.00001,
)

# Number of full batches that both datasets can supply.
data_size = min(len(data_A), len(data_B))
n_batches = data_size // BATCH_SIZE

recon_criterion = nn.MSELoss()
class GAN_CLS(object):
    """Text-conditioned GAN trainer: builds the networks, trains them and saves checkpoints."""

    def __init__(self, args, data_loader, SUPERVISED=True):
        """Store the hyper-parameters, set up file logging and build the model.

        Args:
            args: Arguments namespace providing the training, path and model
                settings copied onto ``self`` below.
            data_loader: An instance of class DataLoader for loading our
                dataset in batches.
            SUPERVISED: Stored as ``self.SUPERVISED``; not read elsewhere in
                this class.
        """
        self.data_loader = data_loader
        self.num_epochs = args.num_epochs
        self.batch_size = args.batch_size

        self.log_step = args.log_step
        self.sample_step = args.sample_step

        self.log_dir = args.log_dir
        self.checkpoint_dir = args.checkpoint_dir
        self.sample_dir = args.sample_dir
        self.final_model = args.final_model
        self.model_save_step = args.model_save_step

        self.img_size = args.img_size
        self.z_dim = args.z_dim
        self.text_embed_dim = args.text_embed_dim
        self.text_reduced_dim = args.text_reduced_dim
        self.learning_rate = args.learning_rate
        self.beta1 = args.beta1
        self.beta2 = args.beta2
        self.l1_coeff = args.l1_coeff
        self.resume_epoch = args.resume_epoch
        self.resume_idx = args.resume_idx
        self.SUPERVISED = SUPERVISED

        # Logger setting: one log file per calendar day inside log_dir.
        log_name = datetime.datetime.now().strftime('%Y-%m-%d') + '.log'
        # Use the module's real name (not the literal string '__name__') so
        # the logger fits into the package's logging hierarchy.
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)
        self.formatter = logging.Formatter(
            '%(asctime)s:%(levelname)s:%(message)s')
        self.file_handler = logging.FileHandler(
            os.path.join(self.log_dir, log_name))
        self.file_handler.setFormatter(self.formatter)
        self.logger.addHandler(self.file_handler)

        self.build_model()

    def smooth_label(self, tensor, offset):
        """Return ``tensor + offset`` (used to soften the real labels)."""
        return tensor + offset

    @staticmethod
    def dump_imgs(images_Array, name):
        """Pickle ``images_Array`` into the file ``<name>.pickle``.

        Declared static because it takes no ``self``; this keeps both
        ``GAN_CLS.dump_imgs(arr, name)`` and ``instance.dump_imgs(arr, name)``
        working.
        """
        with open('{}.pickle'.format(name), 'wb') as fh:
            dump(images_Array, fh)

    def build_model(self):
        """Create the generator, the discriminator, their optimizers and the loss.

        Defines:
            ----- Generator
            ----- Discriminator
            ----- Optimizer for Generator
            ----- Optimizer for Discriminator
            ----- A joint optimizer over both networks
            ----- The BCE loss function
        """
        # ---------------------------------------------------------------------#
        #                       1. Network Initialization                      #
        # ---------------------------------------------------------------------#
        self.gen = Generator(batch_size=self.batch_size,
                             img_size=self.img_size,
                             z_dim=self.z_dim,
                             text_embed_dim=self.text_embed_dim,
                             text_reduced_dim=self.text_reduced_dim)

        self.disc = Discriminator(batch_size=self.batch_size,
                                  img_size=self.img_size,
                                  text_embed_dim=self.text_embed_dim,
                                  text_reduced_dim=self.text_reduced_dim)

        self.gen_optim = optim.Adam(self.gen.parameters(),
                                    lr=self.learning_rate,
                                    betas=(self.beta1, self.beta2))

        self.disc_optim = optim.Adam(self.disc.parameters(),
                                     lr=self.learning_rate,
                                     betas=(self.beta1, self.beta2))

        self.cls_gan_optim = optim.Adam(itertools.chain(
            self.gen.parameters(), self.disc.parameters()),
                                        lr=self.learning_rate,
                                        betas=(self.beta1, self.beta2))

        print('-------------  Generator Model Info  ---------------')
        self.print_network(self.gen, 'G')
        print('------------------------------------------------')

        print('-------------  Discriminator Model Info  ---------------')
        self.print_network(self.disc, 'D')
        print('------------------------------------------------')

        # NOTE(review): the criterion is moved to the GPU here but self.gen /
        # self.disc are not moved in this class — presumably the network
        # classes handle their own placement; confirm.
        self.criterion = nn.BCELoss().cuda()

        self.gen.train()
        self.disc.train()

    def print_network(self, model, name):
        """Print ``model``, its ``name`` and its total number of parameters."""
        num_params = sum(p.numel() for p in model.parameters())

        print(model)
        print(name)
        print("Total number of parameters: {}".format(num_params))

    def load_checkpoints(self, resume_epoch, idx):
        """Restore the trained generator and discriminator.

        Reads ``<resume_epoch>-<idx>-G.ckpt`` and ``<resume_epoch>-<idx>-D.ckpt``
        from ``self.checkpoint_dir``; ``map_location`` keeps the loaded
        tensors on the CPU.
        """
        print('Loading the trained models from epoch {} and iteration {}...'.
              format(resume_epoch, idx))
        G_path = os.path.join(self.checkpoint_dir,
                              '{}-{}-G.ckpt'.format(resume_epoch, idx))
        D_path = os.path.join(self.checkpoint_dir,
                              '{}-{}-D.ckpt'.format(resume_epoch, idx))
        self.gen.load_state_dict(
            torch.load(G_path, map_location=lambda storage, loc: storage))
        self.disc.load_state_dict(
            torch.load(D_path, map_location=lambda storage, loc: storage))

    def train_model(self):
        """Run the training loop, log the losses and save checkpoints.

        Resumes from ``self.resume_epoch`` / ``self.resume_idx`` when
        ``resume_epoch`` is non-negative. Checkpoints are written every
        ``model_save_step`` batches, and the final weights go to
        ``self.final_model`` once all epochs are done.
        """
        data_loader = self.data_loader

        start_epoch = 0
        if self.resume_epoch >= 0:
            start_epoch = self.resume_epoch
            self.load_checkpoints(self.resume_epoch, self.resume_idx)

        print('---------------  Model Training Started  ---------------')
        start_time = time.time()

        for epoch in range(start_epoch, self.num_epochs):
            print("Epoch: {}".format(epoch + 1))
            for idx, batch in enumerate(data_loader):
                print("Index: {}".format(idx + 1), end="\t")
                true_imgs = batch['true_imgs']
                true_embed = batch['true_embds']
                false_imgs = batch['false_imgs']

                real_labels = torch.ones(true_imgs.size(0))
                fake_labels = torch.zeros(true_imgs.size(0))

                # One-sided label smoothing: the real target becomes 0.9.
                smooth_real_labels = torch.FloatTensor(
                    self.smooth_label(real_labels.numpy(), -0.1))

                true_imgs = Variable(true_imgs.float()).cuda()
                true_embed = Variable(true_embed.float()).cuda()
                false_imgs = Variable(false_imgs.float()).cuda()

                real_labels = Variable(real_labels).cuda()
                smooth_real_labels = Variable(smooth_real_labels).cuda()
                fake_labels = Variable(fake_labels).cuda()

                # ---------------------------------------------------------------#
                #                   2. Training the generator                    #
                # ---------------------------------------------------------------#
                self.gen.zero_grad()
                z = Variable(torch.randn(true_imgs.size(0), self.z_dim)).cuda()
                fake_imgs = self.gen.forward(true_embed, z)
                fake_out, fake_logit = self.disc.forward(fake_imgs, true_embed)
                # NOTE(review): wrapping `.data` in a new Variable yields a
                # fresh leaf tensor cut off from the graph, so backward() from
                # gen_loss cannot reach self.gen's parameters and the
                # optimizer step below has no gradients to apply. The same
                # applies to true_out and false_out. Confirm whether this
                # detachment is intended.
                fake_out = Variable(fake_out.data, requires_grad=True).cuda()

                true_out, true_logit = self.disc.forward(true_imgs, true_embed)
                true_out = Variable(true_out.data, requires_grad=True).cuda()

                g_sf = self.criterion(fake_out, real_labels)
                gen_loss = g_sf

                gen_loss.backward()
                self.gen_optim.step()

                # ---------------------------------------------------------------#
                #                3. Training the discriminator                   #
                # ---------------------------------------------------------------#
                self.disc.zero_grad()
                false_out, false_logit = self.disc.forward(
                    false_imgs, true_embed)
                false_out = Variable(false_out.data, requires_grad=True)

                sr = self.criterion(true_out, smooth_real_labels)
                sw = self.criterion(true_out, fake_labels)
                sf = self.criterion(false_out, smooth_real_labels)

                # NOTE(review): sr/sw/sf are BCE losses (not probabilities),
                # so 1 - sw or 1 - sf can be non-positive and torch.log then
                # yields NaN/-inf — confirm the intended loss formula.
                disc_loss = torch.log(sr) + (torch.log(1 - sw) +
                                             torch.log(1 - sf)) / 2

                disc_loss.backward()
                self.disc_optim.step()

                # NOTE(review): this joint optimizer steps both networks a
                # second time in the same iteration — confirm it is wanted.
                self.cls_gan_optim.step()

                # Logging
                loss = {}
                loss['G_loss'] = gen_loss.item()
                loss['D_loss'] = disc_loss.item()

                # ---------------------------------------------------------------#
                #                4. Logging INFO into log_dir                    #
                # ---------------------------------------------------------------#
                log = ""
                if (idx + 1) % self.log_step == 0:
                    end_time = time.time() - start_time
                    end_time = datetime.timedelta(seconds=end_time)
                    log = "Elapsed [{}], Epoch [{}/{}], Idx [{}]".format(
                        end_time, epoch + 1, self.num_epochs, idx)

                for net, loss_value in loss.items():
                    log += "{}: {:.4f}".format(net, loss_value)
                self.logger.info(log)
                print(log)

                # ---------------------------------------------------------------#
                #          5. Saving generated images (currently disabled)       #
                # ---------------------------------------------------------------#
                # if (idx + 1) % self.sample_step == 0:
                #     concat_imgs = torch.cat((true_imgs, fake_imgs), 0)
                #     concat_imgs = (concat_imgs + 1) / 2
                #     # out.clamp_(0, 1)
                #     save_path = os.path.join(self.sample_dir,
                #                              '{}-{}-images.jpg'.format(epoch, idx + 1))
                #     # concat_imgs.cpu().detach().numpy()
                #     self.dump_imgs(concat_imgs.cpu().numpy(), save_path)
                #     # save_image(concat_imgs.data.cpu(), self.sample_dir, nrow=1, padding=0)
                #     print('Saved real and fake images into {}...'.format(self.sample_dir))

                # ---------------------------------------------------------------#
                #            6. Saving the checkpoints & final model             #
                # ---------------------------------------------------------------#
                if (idx + 1) % self.model_save_step == 0:
                    G_path = os.path.join(
                        self.checkpoint_dir,
                        '{}-{}-G.ckpt'.format(epoch, idx + 1))
                    D_path = os.path.join(
                        self.checkpoint_dir,
                        '{}-{}-D.ckpt'.format(epoch, idx + 1))
                    torch.save(self.gen.state_dict(), G_path)
                    torch.save(self.disc.state_dict(), D_path)
                    print('Saved model checkpoints into {}...\n'.format(
                        self.checkpoint_dir))

        print('---------------  Model Training Completed  ---------------')
        # Saving final model into final_model directory
        G_path = os.path.join(self.final_model, '{}-G.pth'.format('final'))
        D_path = os.path.join(self.final_model, '{}-D.pth'.format('final'))
        torch.save(self.gen.state_dict(), G_path)
        torch.save(self.disc.state_dict(), D_path)
        print('Saved final model into {}...'.format(self.final_model))
def main(args):
    """Jointly train a freshly initialised generator and discriminator.

    Loads the 'train' and 'valid' splits, builds both models, then for every
    epoch iterates over the training batches:

    * the generator is updated with policy-gradient or MLE training, chosen
      per batch with probability 0.5 each;
    * the discriminator is updated on the generator's greedy output, labelled
      as fake;
    * whenever ``num_update`` is a multiple of 5000 both models are evaluated
      on the validation split and the generator is saved to ``args.save_dir``.

    Args:
        args: Parsed options namespace. The model-size fields
            (``encoder_*``, ``decoder_*``, ``bidirectional``) are overwritten
            here, and ``src_lang``/``trg_lang`` are filled in from the dataset
            when they are None.
    """
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(
            args.data, splits, args.src_lang, args.trg_lang, args.fixed_max_len)
    else:
        dataset = data.load_raw_text_dataset(
            args.data, splits, args.src_lang, args.trg_lang, args.fixed_max_len)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))

    for split in splits:
        print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))

    # One running-average meter per tracked quantity ('bsz' = sentences per batch).
    meter_names = ('train_loss', 'valid_loss', 'train_acc', 'valid_acc', 'bsz')
    g_logging_meters = OrderedDict((name, AverageMeter()) for name in meter_names)
    d_logging_meters = OrderedDict((name, AverageMeter()) for name in meter_names)

    # Set model parameters
    args.encoder_embed_dim = 1000
    args.encoder_layers = 2  # 4
    args.encoder_dropout_out = 0
    args.decoder_embed_dim = 1000
    args.decoder_layers = 2  # 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0
    args.bidirectional = False

    generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)
    print("Generator loaded successfully!")
    discriminator = Discriminator(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)
    print("Discriminator loaded successfully!")

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator = torch.nn.DataParallel(discriminator).cuda()
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
            discriminator.cuda()
    else:
        discriminator.cpu()
        generator.cpu()

    # DataParallel does not forward custom attributes, so the embedding
    # sub-modules must be reached through `.module` when wrapped.
    d_core = discriminator.module if isinstance(discriminator, torch.nn.DataParallel) else discriminator

    # adversarial training checkpoints saving path
    save_dir = args.save_dir
    os.makedirs(save_dir, exist_ok=True)
    checkpoints_path = save_dir

    # define loss function
    g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(), reduction='sum')
    d_criterion = torch.nn.BCELoss()
    pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(),
                          size_average=True, reduce=True)

    # fix discriminator word embedding (as Wu et al. do)
    for p in d_core.embed_src_tokens.parameters():
        p.requires_grad = False
    for p in d_core.embed_trg_tokens.parameters():
        p.requires_grad = False

    # define optimizer (looked up by name on torch.optim)
    g_optimizer = getattr(torch.optim, args.g_optimizer)(
        filter(lambda x: x.requires_grad, generator.parameters()),
        args.g_learning_rate)
    d_optimizer = getattr(torch.optim, args.d_optimizer)(
        filter(lambda x: x.requires_grad, discriminator.parameters()),
        args.d_learning_rate,
        momentum=args.momentum,
        nesterov=True)

    def _flatten(sys_out_batch):
        """Collapse all leading dimensions so that rows index single tokens."""
        return sys_out_batch.contiguous().view(-1, sys_out_batch.size(-1))

    def _greedy_tokens(sys_out_batch, like):
        """Top-1 token ids of `sys_out_batch`, reshaped to `like.shape`."""
        _, prediction = _flatten(sys_out_batch).topk(1)
        return torch.reshape(prediction.squeeze(1), like.shape)

    def _score_fakes(sample):
        """Return (d_loss, acc) of the discriminator on the generator's greedy output.

        Every generated sentence is labelled 0 (fake).
        """
        src_sentence = sample['net_input']['src_tokens']
        with torch.no_grad():
            sys_out_batch = generator(sample)
        fake_sentence = _greedy_tokens(sys_out_batch, src_sentence)
        fake_labels = torch.zeros(sample['target'].size(0))
        if use_cuda:
            fake_labels = fake_labels.cuda()
        disc_out = discriminator(src_sentence, fake_sentence)
        d_loss = d_criterion(disc_out.squeeze(1), fake_labels)
        acc = torch.sum(
            torch.round(disc_out).squeeze(1) == fake_labels).float() / len(fake_labels)
        return d_loss, acc

    def _save_generator(filename):
        """Pickle the whole generator object into the checkpoint directory."""
        with open(os.path.join(checkpoints_path, filename), 'wb') as fh:
            torch.save(generator, fh, pickle_module=dill)

    # start joint training
    best_dev_loss = math.inf
    num_update = 0

    # main training loop
    for epoch_i in range(1, args.epochs + 1):
        logging.info("At {0}-th epoch.".format(epoch_i))

        seed = args.seed + epoch_i
        torch.manual_seed(seed)

        max_positions_train = (args.fixed_max_len, args.fixed_max_len)

        # Initialize dataloader, starting at batch_offset
        trainloader = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_train,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for meters in (g_logging_meters, d_logging_meters):
            for val in meters.values():
                if val is not None:
                    val.reset()

        update_learning_rate(num_update, 8e4, args.g_learning_rate,
                             args.lr_shrink, g_optimizer)

        for i, sample in enumerate(trainloader):
            # Re-enter training mode every batch: the periodic validation
            # below switches both models to eval mode.
            generator.train()
            discriminator.train()

            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=cuda)

            ## part I: train the generator
            # use policy gradient training when random.random() >= 0.5
            if random.random() >= 0.5:
                print("Policy Gradient Training")

                sys_out_batch = generator(sample)
                prediction = _greedy_tokens(sys_out_batch,
                                            sample['net_input']['src_tokens'])

                with torch.no_grad():
                    reward = discriminator(sample['net_input']['src_tokens'],
                                           prediction)

                train_trg_batch = sample['target']

                pg_loss = pg_criterion(sys_out_batch, train_trg_batch, reward,
                                       use_cuda)
                sample_size = sample['target'].size(0) if args.sentence_avg else sample['ntokens']
                logging_loss = pg_loss / math.log(2)
                g_logging_meters['train_loss'].update(logging_loss.item(), sample_size)
                logging.debug(f"G policy gradient loss at batch {i}: {pg_loss.item():.3f}, lr={g_optimizer.param_groups[0]['lr']}")
                g_optimizer.zero_grad()
                pg_loss.backward()
                torch.nn.utils.clip_grad_norm_(generator.parameters(), args.clip_norm)
                g_optimizer.step()
            else:
                # MLE training
                print("MLE Training")

                sys_out_batch = generator(sample)
                out_batch = _flatten(sys_out_batch)
                train_trg_batch = sample['target'].view(-1)

                loss = g_criterion(out_batch, train_trg_batch)

                sample_size = sample['target'].size(0) if args.sentence_avg else sample['ntokens']
                nsentences = sample['target'].size(0)
                logging_loss = loss.data / sample_size / math.log(2)
                g_logging_meters['bsz'].update(nsentences)
                g_logging_meters['train_loss'].update(logging_loss, sample_size)
                logging.debug(f"G MLE loss at batch {i}: {g_logging_meters['train_loss'].avg:.3f}, lr={g_optimizer.param_groups[0]['lr']}")
                g_optimizer.zero_grad()
                loss.backward()
                # all-reduce grads and rescale by grad_denom
                for p in generator.parameters():
                    if p.requires_grad:
                        p.grad.data.div_(sample_size)
                torch.nn.utils.clip_grad_norm_(generator.parameters(), args.clip_norm)
                g_optimizer.step()

                num_update += 1

            # part II: train the discriminator
            # NOTE(review): only generator output (label 0) is fed to the
            # discriminator here; the reference translations are never used
            # as positive examples in this loop — confirm that is intended.
            d_loss, acc = _score_fakes(sample)

            d_logging_meters['train_acc'].update(acc)
            # Store a plain float so the meter does not keep the graph alive.
            d_logging_meters['train_loss'].update(d_loss.item())
            logging.debug(f"D training loss {d_logging_meters['train_loss'].avg:.3f}, acc {d_logging_meters['train_acc'].avg:.3f} at batch {i}")
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            # NOTE(review): num_update only advances on MLE batches, so this
            # check can fire on several consecutive policy-gradient batches
            # (including before the first MLE update, while num_update is 0).
            if num_update % 5000 == 0:

                # validation
                # set validation mode
                generator.eval()
                discriminator.eval()
                # Initialize dataloader
                max_positions_valid = (args.fixed_max_len, args.fixed_max_len)
                valloader = dataset.eval_dataloader(
                    'valid',
                    max_tokens=args.max_tokens,
                    max_sentences=args.joint_batch_size,
                    max_positions=max_positions_valid,
                    skip_invalid_size_inputs_valid_test=True,
                    descending=True,  # largest batch first to warm the caching allocator
                    shard_id=args.distributed_rank,
                    num_shards=args.distributed_world_size,
                )

                # reset meters (this also clears the running training meters)
                for meters in (g_logging_meters, d_logging_meters):
                    for val in meters.values():
                        if val is not None:
                            val.reset()

                # Dedicated loop names keep the training loop's `i` and
                # `sample` untouched.
                for val_i, val_sample in enumerate(valloader):
                    with torch.no_grad():
                        if use_cuda:
                            # wrap input tensors in cuda tensors
                            val_sample = utils.make_variable(val_sample, cuda=cuda)

                        # generator validation
                        sys_out_batch = generator(val_sample)
                        out_batch = _flatten(sys_out_batch)
                        dev_trg_batch = val_sample['target'].view(-1)

                        loss = g_criterion(out_batch, dev_trg_batch)
                        sample_size = val_sample['target'].size(0) if args.sentence_avg else val_sample['ntokens']
                        loss = loss / sample_size / math.log(2)
                        g_logging_meters['valid_loss'].update(loss, sample_size)
                        logging.debug(f"G dev loss at batch {val_i}: {g_logging_meters['valid_loss'].avg:.3f}")

                        # discriminator validation (generated sentences only)
                        d_loss, acc = _score_fakes(val_sample)
                        d_logging_meters['valid_acc'].update(acc)
                        d_logging_meters['valid_loss'].update(d_loss)
                        logging.debug(f"D dev loss {d_logging_meters['valid_loss'].avg:.3f}, acc {d_logging_meters['valid_acc'].avg:.3f} at batch {val_i}")

                _save_generator(f"num_update{num_update}.joint_{g_logging_meters['valid_loss'].avg:.3f}.pt")

                if g_logging_meters['valid_loss'].avg < best_dev_loss:
                    best_dev_loss = g_logging_meters['valid_loss'].avg
                    _save_generator("best_gmodel.pt")
def train(opt: Options):
    """Train the generator/discriminator pair built from ``opt`` with a BCE GAN loss.

    Builds ``Generator(opt)`` and ``Discriminator(opt)``, then performs one
    discriminator update followed by one generator update per batch for
    ``opt.iterations`` epochs.  A log line is printed for every processed
    batch, a sample grid is written every 100 batches under
    ``opt.images_root`` and both state dicts are written after every epoch
    under ``opt.models_root``.

    Args:
        opt: Run configuration.  The fields read here are ``data_root``,
            ``x_dim``, ``y_dim``, ``c_dim``, ``z_dim``, ``scale``,
            ``batch_size``, ``workers``, ``lr``, ``beta1``, ``beta2``,
            ``use_cuda``, ``iterations``, ``models_root`` and ``images_root``.
    """
    real_label = 1
    fake_label = 0

    netG = Generator(opt)
    netD = Discriminator(opt)
    print(netG)
    print(netD)
    netG.apply(weights_init_g)
    netD.apply(weights_init_d)
    # summary(netD, (opt.c_dim, opt.x_dim, opt.y_dim))

    dataloader = load_data(opt.data_root, opt.x_dim, opt.y_dim,
                           opt.batch_size, opt.workers)
    x, y, r = get_coordinates(x_dim=opt.x_dim, y_dim=opt.y_dim,
                              scale=opt.scale, batch_size=opt.batch_size)

    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr,
                            betas=(opt.beta1, opt.beta2))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr,
                            betas=(opt.beta1, opt.beta2))

    criterion = nn.BCELoss()
    # criterion = nn.L1Loss()

    # Reusable, fixed-size buffers: they are refilled in place every batch.
    noise = torch.empty(opt.batch_size, opt.z_dim)
    ones = torch.ones(opt.batch_size, opt.x_dim * opt.y_dim, 1)
    input_ = torch.empty(opt.batch_size, opt.c_dim, opt.x_dim, opt.y_dim)
    label = torch.empty(opt.batch_size, 1)

    if opt.use_cuda:
        netG = netG.cuda()
        netD = netD.cuda()
        x = x.cuda()
        y = y.cuda()
        r = r.cuda()
        ones = ones.cuda()
        criterion = criterion.cuda()
        input_ = input_.cuda()
        label = label.cuda()
        noise = noise.cuda()

    # One latent vector per sample, repeated for each of the x_dim * y_dim
    # positions; kept fixed so the periodic sample grids are comparable.
    noise.normal_()
    fixed_seed = torch.bmm(ones, noise.unsqueeze(1))

    def _update_discriminator(data):
        """Run one discriminator step; return (fake, D(G(z)), loss, D(x))."""
        # for p in netD.parameters():
        #     p.requires_grad = True  # to avoid computation
        netD.zero_grad()

        # train with real
        real_cpu, _ = data
        input_.copy_(real_cpu)
        label.fill_(real_label - 0.1)  # use smooth label for discriminator
        output = netD(input_)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.detach().mean().item()

        # train with fake
        noise.normal_()
        seed = torch.bmm(ones, noise.unsqueeze(1))
        fake = netG(x, y, r, seed)
        label.fill_(fake_label)
        output = netD(fake.detach())  # ".detach()" avoids backprop through G
        errD_fake = criterion(output, label)
        errD_fake.backward()  # gradients for fake/real are accumulated
        D_G_z1 = output.detach().mean().item()

        errD = errD_real + errD_fake
        optimizerD.step()  # .step() can be called once the gradients are computed
        return fake, D_G_z1, errD, D_x

    def _update_generator(fake):
        """Run one generator step on ``fake``; return (D(G(z)), loss)."""
        # for p in netD.parameters():
        #     p.requires_grad = False  # to avoid computation
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.detach().mean().item()
        optimizerG.step()
        return D_G_z2, errG

    def _save_model(epoch):
        """Write both state dicts for ``epoch`` under ``opt.models_root``."""
        os.makedirs(opt.models_root, exist_ok=True)
        torch.save(netG.state_dict(),
                   os.path.join(opt.models_root,
                                "G-cppn-wgan-anime_{}.pth".format(epoch)))
        torch.save(netD.state_dict(),
                   os.path.join(opt.models_root,
                                "D-cppn-wgan-anime_{}.pth".format(epoch)))

    def _log(i, epoch, errD, errG, D_x, D_G_z1, D_G_z2, delta_time):
        """Print the losses and discriminator outputs of one iteration."""
        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f Elapsed %.2f s'
              % (epoch, opt.iterations, i, len(dataloader),
                 errD.item(), errG.item(), D_x, D_G_z1, D_G_z2, delta_time))

    def _save_images(i, epoch):
        """Every 100 iterations, save up to 64 samples from the fixed seed."""
        if i % 100 != 0:
            return
        os.makedirs(opt.images_root, exist_ok=True)
        # Sampling is inference only: do not build an autograd graph for it.
        with torch.no_grad():
            fake = netG(x, y, r, fixed_seed)
        fname = os.path.join(opt.images_root,
                             "fake_samples_{:02}-{:04}.png".format(epoch, i))
        vutils.save_image(fake[0:64, :, :, :], fname, nrow=8)

    def _start():
        """Run the full training loop."""
        print("Start training")
        for epoch in range(opt.iterations):
            for i, data in enumerate(dataloader, 0):
                # The buffers above are sized for exactly opt.batch_size
                # samples, so a short (typically final) batch cannot be
                # copied into them; skip it instead of crashing.
                if data[0].size(0) != opt.batch_size:
                    continue
                start_iter = time.time()
                fake, D_G_z1, errD, D_x = _update_discriminator(data)
                D_G_z2, errG = _update_generator(fake)
                end_iter = time.time()
                _log(i, epoch, errD, errG, D_x, D_G_z1, D_G_z2,
                     end_iter - start_iter)
                _save_images(i, epoch)
            _save_model(epoch)

    _start()
####discriminator_model_conf = modelpath_d['discriminator_model_conf'] #print(generator_model) #print(generator_model) #print(discriminator_model) #modelpath_d = torch.load('train-model-16-medical-adversarial/modeladversarial_40_500.pt') #discriminator_model_pose = modelpath_d['discriminator_model_conf'] #print(generator_model) #print(generator_model) #print(discriminator_model) for params in generator_model.parameters(): params.requires_grad = False for params in discriminator_model_pose.parameters(): params.requires_grad = True #for params in discriminator_model_conf.parameters(): # params.requires_grad = True # Use dataparallel generator_model = nn.DataParallel(generator_model) #discriminator_model_conf = nn.DataParallel(discriminator_model_conf) discriminator_model_pose = nn.DataParallel(discriminator_model_pose) # Datasets if args.dataset == 'lsp': lsp_train_dataset = LSP(args) args.mode = 'val' lsp_val_dataset = LSP(args) # medical
class BigGAN():
    """Trainer that owns a conditional generator/discriminator pair.

    The models are optimised with hinge losses (``loss_hinge_dis_real``,
    ``loss_hinge_dis_fake``, ``loss_hinge_gen``) and Gaussian "instance
    noise" is added to every discriminator input; its standard deviation is
    annealed linearly to zero.
    """

    def __init__(self, device, dataloader, num_classes, configs):
        """Store the hyper-parameters, build the models and optionally resume.

        Args:
            device: Device the models and batches are moved to.
            dataloader: Iterable with ``len()`` yielding ``(images, labels)``.
            num_classes: Number of class labels passed to both networks.
            configs: Object providing the attributes read below.
        """
        self.device = device
        self.dataloader = dataloader
        self.num_classes = num_classes

        # model settings & hyperparams
        # self.total_steps = configs.total_steps
        self.epochs = configs.epochs
        self.d_iters = configs.d_iters
        self.g_iters = configs.g_iters
        self.batch_size = configs.batch_size
        self.imsize = configs.imsize
        self.nz = configs.nz
        self.ngf = configs.ngf
        self.ndf = configs.ndf
        self.g_lr = configs.g_lr
        self.d_lr = configs.d_lr
        self.beta1 = configs.beta1
        self.beta2 = configs.beta2

        # instance noise
        self.inst_noise_sigma = configs.inst_noise_sigma
        self.inst_noise_sigma_iters = configs.inst_noise_sigma_iters

        # model logging and saving
        self.log_step = configs.log_step
        self.save_epoch = configs.save_epoch
        self.model_path = configs.model_path
        self.sample_path = configs.sample_path

        # pretrained
        self.pretrained_model = configs.pretrained_model

        # building
        self.build_model()

        # archive of all losses
        self.ave_d_losses = []
        self.ave_d_losses_real = []
        self.ave_d_losses_fake = []
        self.ave_g_losses = []

        if self.pretrained_model:
            self.load_pretrained()

    def build_model(self):
        """Initiate Generator and Discriminator and their Adam optimizers."""
        self.G = Generator(self.nz, self.ngf, self.num_classes).to(self.device)
        self.D = Discriminator(self.ndf, self.num_classes).to(self.device)

        self.g_optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, self.G.parameters()),
            self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, self.D.parameters()),
            self.d_lr, [self.beta1, self.beta2])

        print("Generator Parameters: ", parameters(self.G))
        print(self.G)
        print("Discriminator Parameters: ", parameters(self.D))
        print(self.D)
        print("Number of classes: ", self.num_classes)

    def load_pretrained(self):
        """Restore models, optimizers and loss history from a checkpoint.

        The file is ``<model_path>/<pretrained_model>_biggan.pth``, i.e. the
        naming used when ``train`` saves a checkpoint.
        """
        # map_location lets a checkpoint written on one device be restored
        # on whatever device this run uses.
        checkpoint = torch.load(
            os.path.join(self.model_path,
                         "{}_biggan.pth".format(self.pretrained_model)),
            map_location=self.device)

        # load models
        self.G.load_state_dict(checkpoint["g_state_dict"])
        self.D.load_state_dict(checkpoint["d_state_dict"])

        # load optimizers
        self.g_optimizer.load_state_dict(checkpoint["g_optimizer"])
        self.d_optimizer.load_state_dict(checkpoint["d_optimizer"])

        # load losses
        self.ave_d_losses = checkpoint["ave_d_losses"]
        self.ave_d_losses_real = checkpoint["ave_d_losses_real"]
        self.ave_d_losses_fake = checkpoint["ave_d_losses_fake"]
        self.ave_g_losses = checkpoint["ave_g_losses"]

        print("Loading pretrained models (epoch: {})..!".format(
            self.pretrained_model))

    def reset_grad(self):
        """Reset gradients of both optimizers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def _current_noise_sigma(self, global_step):
        """Return the instance-noise std for ``global_step``.

        The std falls linearly from ``inst_noise_sigma`` at step 0 to 0 at
        ``inst_noise_sigma_iters`` and stays 0 afterwards.  A schedule length
        of zero (or less) disables the noise instead of dividing by zero.
        """
        iters = self.inst_noise_sigma_iters
        if not iters or global_step >= iters:
            return 0.0
        return (1 - global_step / iters) * self.inst_noise_sigma

    @staticmethod
    def _instance_noise(images, sigma):
        """Return zero-mean Gaussian noise shaped, typed and placed like ``images``.

        Sizing the noise from the batch itself keeps the addition valid when
        the last batch of an epoch is smaller than ``batch_size``.
        """
        if sigma == 0:
            return torch.zeros_like(images)
        return torch.randn_like(images) * sigma

    @staticmethod
    def _epoch_mean(values):
        """Return ``mean(values)``, or NaN when nothing was logged this epoch."""
        return mean(values) if values else float("nan")

    def train(self):
        """Train model.

        Runs from the resume epoch (``pretrained_model`` or 0) up to
        ``epochs``.  Each step performs ``d_iters`` discriminator updates and
        ``g_iters`` generator updates; losses are recorded every ``log_step``
        steps, a sample image is written every epoch and a checkpoint every
        ``save_epoch`` epochs.
        """
        step_per_epoch = len(self.dataloader)
        epochs = self.epochs
        total_steps = epochs * step_per_epoch

        # fixed z and labels for sampling generator images
        fixed_z = tensor2var(torch.randn(self.batch_size, self.nz),
                             device=self.device)
        # NOTE(review): this yields num_classes * batch_size labels while
        # fixed_z has batch_size rows -- confirm the generator expects that.
        fixed_labels = tensor2var(
            torch.from_numpy(
                np.tile(np.arange(self.num_classes), self.batch_size)).long(),
            device=self.device)

        print("Initiating Training")
        print("Epochs: {}, Total Steps: {}, Steps/Epoch: {}".format(
            epochs, total_steps, step_per_epoch))

        # The checkpoint name doubles as the epoch to resume from.
        start_epoch = int(self.pretrained_model) if self.pretrained_model else 0

        self.D.train()
        self.G.train()

        # total time
        start_time = time.time()

        for epoch in range(start_epoch, epochs):
            # local losses
            d_losses = []
            d_losses_real = []
            d_losses_fake = []
            g_losses = []

            data_iter = iter(self.dataloader)

            for step in range(step_per_epoch):
                # Anneal over the global step count so the noise does not
                # jump back to its initial std at the start of every epoch.
                sigma = self._current_noise_sigma(epoch * step_per_epoch + step)
                is_log_step = (step + 1) % self.log_step == 0

                # get real images
                real_images, real_labels = next(data_iter)
                real_images = real_images.to(self.device)
                real_labels = real_labels.to(self.device)

                # Python floats of the most recent update; only filled on
                # logging steps because .item() forces a device sync.
                d_loss_real_val = d_loss_fake_val = g_loss_val = float("nan")

                # ================== TRAIN DISCRIMINATOR ================== #
                for _ in range(self.d_iters):
                    self.reset_grad()

                    # TRAIN REAL: noisy real images
                    d_real = self.D(
                        real_images + self._instance_noise(real_images, sigma),
                        real_labels)
                    d_loss_real = loss_hinge_dis_real(d_real)
                    d_loss_real.backward()
                    if is_log_step:
                        d_loss_real_val = d_loss_real.item()
                    # free the graph before the fake pass
                    del d_real, d_loss_real

                    # TRAIN FAKE: images generated from a fresh latent vector
                    z = tensor2var(torch.randn(real_images.size(0), self.nz),
                                   device=self.device)
                    fake_images = self.G(z, real_labels).detach()
                    d_fake = self.D(
                        fake_images + self._instance_noise(fake_images, sigma),
                        real_labels)
                    d_loss_fake = loss_hinge_dis_fake(d_fake)
                    d_loss_fake.backward()
                    if is_log_step:
                        d_loss_fake_val = d_loss_fake.item()
                    del fake_images, d_fake, d_loss_fake

                    # optimize D
                    self.d_optimizer.step()

                # ================== TRAIN GENERATOR ================== #
                for _ in range(self.g_iters):
                    self.reset_grad()

                    # create new latent vector
                    z = tensor2var(torch.randn(real_images.size(0), self.nz),
                                   device=self.device)

                    # generate fake images and score their noisy version
                    fake_images = self.G(z, real_labels)
                    g_fake = self.D(
                        fake_images + self._instance_noise(fake_images, sigma),
                        real_labels)

                    # compute hinge loss for G
                    g_loss = loss_hinge_gen(g_fake)
                    g_loss.backward()
                    if is_log_step:
                        g_loss_val = g_loss.item()
                    del fake_images, g_fake, g_loss

                    # optimize G
                    self.g_optimizer.step()

                # logging step progression
                if is_log_step:
                    d_loss_val = d_loss_real_val + d_loss_fake_val

                    # logging losses
                    d_losses.append(d_loss_val)
                    d_losses_real.append(d_loss_real_val)
                    d_losses_fake.append(d_loss_fake_val)
                    g_losses.append(g_loss_val)

                    # print out
                    elapsed = str(datetime.timedelta(
                        seconds=time.time() - start_time))
                    print(
                        "Elapsed [{}], Epoch: [{}/{}], Step [{}/{}], g_loss: {:.4f}, d_loss: {:.4f},"
                        " d_loss_real: {:.4f}, d_loss_fake: {:.4f}".format(
                            elapsed, (epoch + 1), epochs, (step + 1),
                            step_per_epoch, g_loss_val, d_loss_val,
                            d_loss_real_val, d_loss_fake_val))

            # logging average losses over epoch
            self.ave_d_losses.append(self._epoch_mean(d_losses))
            self.ave_d_losses_real.append(self._epoch_mean(d_losses_real))
            self.ave_d_losses_fake.append(self._epoch_mean(d_losses_fake))
            self.ave_g_losses.append(self._epoch_mean(g_losses))

            # epoch update; elapsed is recomputed here so it is defined even
            # when no logging step happened during the epoch
            elapsed = str(datetime.timedelta(seconds=time.time() - start_time))
            print(
                "Elapsed [{}], Epoch: [{}/{}], ave_g_loss: {:.4f}, ave_d_loss: {:.4f},"
                " ave_d_loss_real: {:.4f}, ave_d_loss_fake: {:.4f},".format(
                    elapsed, epoch + 1, epochs, self.ave_g_losses[-1],
                    self.ave_d_losses[-1], self.ave_d_losses_real[-1],
                    self.ave_d_losses_fake[-1]))

            # sample images every epoch (inference only, no autograd graph)
            with torch.no_grad():
                fake_images = self.G(fixed_z, fixed_labels)
            fake_images = denorm(fake_images)
            save_image(
                fake_images,
                os.path.join(self.sample_path,
                             "Epoch {}.png".format(epoch + 1)))

            # save model
            if (epoch + 1) % self.save_epoch == 0:
                torch.save(
                    {
                        "g_state_dict": self.G.state_dict(),
                        "d_state_dict": self.D.state_dict(),
                        "g_optimizer": self.g_optimizer.state_dict(),
                        "d_optimizer": self.d_optimizer.state_dict(),
                        "ave_d_losses": self.ave_d_losses,
                        "ave_d_losses_real": self.ave_d_losses_real,
                        "ave_d_losses_fake": self.ave_d_losses_fake,
                        "ave_g_losses": self.ave_g_losses
                    },
                    os.path.join(self.model_path,
                                 "{}_biggan.pth".format(epoch + 1)))

                print("Saving models (epoch {})..!".format(epoch + 1))

    def plot(self):
        """Plot the per-epoch average losses recorded so far."""
        plt.plot(self.ave_d_losses)
        plt.plot(self.ave_d_losses_real)
        plt.plot(self.ave_d_losses_fake)
        plt.plot(self.ave_g_losses)
        plt.legend(["d loss", "d real", "d fake", "g loss"], loc="upper left")
        plt.show()
#if epoch_ % 50 == 0 and epoch_ != 0: #save_state(save_dir, epoch_, G, D) if __name__ == "__main__": args = parse_args() real_label = 1 fake_label = 0 if args.device == "gpu": device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") else: device = torch.device("cpu") data_loader = load_data(args.img_dir, args.img_size, args.batch_size) G = Generator(args.img_size, args.norm, args.up_type, device).to(device) D = Discriminator(args.img_size, args.norm, args.spectral, args.noise).to(device) G.apply(weights_init) D.apply(weights_init) optimizerD = optim.Adam(D.parameters(), lr=args.D_lr, betas=(0.5, 0.999)) optimizerG = optim.Adam(G.parameters(), lr=args.G_lr, betas=(0.5, 0.999)) loss = Loss(args.loss) main(args.epoch, device, args.batch_size, loss, G, D, optimizerG, optimizerD, data_loader)
train=True, download=True, transform=transforms.Compose( [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])] # [] means channel, 0.5,0.5 means mean & std # => img = (img - mean) / 0.5 per channel ), ), batch_size=opt.batch_size, shuffle=True, ) # Optimizers optimizer_G = torch.optim.RMSprop(generator.parameters(), lr=opt.lr) optimizer_D = torch.optim.RMSprop(discriminator.parameters(), lr=opt.lr) Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor # ---------- # Training # ---------- batches_done=0 for epoch in range(opt.n_epochs): for i, (imgs, _) in enumerate(dataloader): # batch id, (image, target) # Configure input real_imgs = imgs.type(Tensor) # -----------------
'''record basic information''' record('basic info', (TMUX + '<br>' + 'train' + '<br>' + 'wgan-gp: ' + str(if_use_wgan_gp)), 'text') if torch.cuda.is_available(): use_cuda = True discriminator.cuda() generator.cuda() one = one.cuda() mone = mone.cuda() loss_function = nn.BCELoss() if if_use_wgan_gp: d_optim = torch.optim.RMSprop(discriminator.parameters(), lr=1e-4, eps=1e-5, alpha=0.99) g_optim = torch.optim.RMSprop(generator.parameters(), lr=1e-4, eps=1e-5, alpha=0.99) else: d_optim = torch.optim.Adagrad(discriminator.parameters(), lr=lr) g_optim = torch.optim.Adagrad(generator.parameters(), lr=lr) num_epoch = 120 dataloader = DataLoader(batch_size) num_batch = int(dataloader.num_batches) # length of data / batch_size print(num_batch)
D = Discriminator(conv_size) G = Generator(z_size, conv_size) cuda = False if torch.cuda.is_available(): cuda = True D = D.cuda() G = G.cuda() lr = 0.0002 beta1 = 0.5 beta2 = 0.99 d_optim = optim.Adam(D.parameters(), lr, [beta1, beta2]) g_optim = optim.Adam(G.parameters(), lr, [beta1, beta2]) def train_discriminator(real_images, optimizer, batch_size, z_size): optimizer.zero_grad() if cuda: real_images = real_images.cuda() # Loss for real image d_real_loss = real_loss(D(real_images), cuda, smooth=True) # Loss for fake image fake_images = G(generate_z_vector(batch_size, z_size, cuda)) d_fake_loss = fake_loss(D(fake_images), cuda)
# # loss = train_generator(generators[i], label_data_iterators[i], gen_criterions[i], gen_optimizers[i]) # bleu_s = 0#bleu_4(TEXT, corpus, generators[i], g_sequence_len, count=100) # print('Epoch [{}], Generator: {}, loss: {}, Perplexity: {}'.format(epoch, generators[i].name, loss, math.exp(loss))) # print('-'*25) exit(0) d_num_class = len(label_names) + 1 discriminator = Discriminator(d_num_class, VOCAB_SIZE, d_emb_dim, d_filter_sizes, d_num_filters, d_dropout) discriminator.embedding.weight.data = TEXT.vocab.vectors if opt.cuda: discriminator = discriminator.cuda() # Pretrain Discriminator dis_criterion = nn.NLLLoss(size_average=False) dis_optimizer = optim.Adam(discriminator.parameters()) if opt.cuda: dis_criterion = dis_criterion.cuda() print('Pretrain Discriminator ...') for epoch in range(PRE_EPOCH_NUM): loss, acc = train_discriminator(discriminator, generators, real_data_iterator, dis_criterion, dis_optimizer) print('Epoch [{}], loss: {}, accuracy: {}'.format(epoch, loss, acc)) # # Adversarial Training rollouts = [Rollout(generator, 0.8) for generator in generators] print('#####################################################') print('Start Adversarial Training...') gen_gan_losses = [GANLoss() for _ in generators] gen_gan_optm = [optim.Adam(generator.parameters()) for generator in generators]
def main(args):
    """Train the volume generator, optionally adversarially, and checkpoint it.

    Builds the train/test ``TVDataset`` loaders, the ``Generator`` and (when
    ``args.gan_loss != "none"``) the ``Discriminator``, optionally restores
    model weights and loss histories from ``args.resume``, then trains from
    ``args.start_epoch`` to ``args.epochs``.  Inside each epoch it logs every
    ``args.log_every`` batches, evaluates on the test loader every
    ``args.test_every`` batches and writes checkpoints to ``args.save_dir``
    every ``args.check_every`` batches.

    Args:
        args: Parsed command-line namespace.  ``args.cuda`` is set here and
            ``args.start_epoch`` is overwritten when a checkpoint is loaded.

    Raises:
        ValueError: If ``args.feature_loss`` is set while the discriminator
            is disabled (``args.gan_loss == "none"``), since the feature loss
            is computed from discriminator features.
    """
    # log hyperparameter
    print(args)

    # select device ("cuda:0" must not contain a space to be a valid device)
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda:0" if args.cuda else "cpu")

    # set random seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # data loader
    transform = transforms.Compose([
        utils.Normalize(),
        utils.ToTensor()
    ])
    train_dataset = TVDataset(
        root=args.root,
        sub_size=args.block_size,
        volume_list=args.volume_train_list,
        max_k=args.training_step,
        train=True,
        transform=transform
    )
    test_dataset = TVDataset(
        root=args.root,
        sub_size=args.block_size,
        volume_list=args.volume_test_list,
        max_k=args.training_step,
        train=False,
        transform=transform
    )

    kwargs = {"num_workers": 4, "pin_memory": True} if args.cuda else {}
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              shuffle=True, **kwargs)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
                             shuffle=False, **kwargs)

    # model
    def generator_weights_init(m):
        if isinstance(m, nn.Conv3d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def discriminator_weights_init(m):
        if isinstance(m, nn.Conv3d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    use_gan = args.gan_loss != "none"
    if args.feature_loss and not use_gan:
        raise ValueError("feature_loss needs a discriminator; gan_loss is 'none'")

    g_model = Generator(args.upsample_mode, args.forward, args.backward,
                        args.gen_sn, args.residual)
    g_model.apply(generator_weights_init)
    if args.data_parallel and torch.cuda.device_count() > 1:
        g_model = nn.DataParallel(g_model)
    g_model.to(device)

    d_model = None
    if use_gan:
        d_model = Discriminator(args.dis_sn)
        d_model.apply(discriminator_weights_init)
        # if args.dis_sn:
        #     d_model = add_sn(d_model)
        if args.data_parallel and torch.cuda.device_count() > 1:
            d_model = nn.DataParallel(d_model)
        d_model.to(device)
    # DataParallel does not forward custom methods such as extract_features,
    # so reach through to the wrapped module when there is one.
    feature_net = getattr(d_model, "module", d_model)

    mse_loss = nn.MSELoss()
    adversarial_loss = nn.MSELoss()
    train_losses, test_losses = [], []
    d_losses, g_losses = [], []

    # optimizer
    g_optimizer = optim.Adam(g_model.parameters(), lr=args.lr,
                             betas=(args.beta1, args.beta2))
    if use_gan:
        d_optimizer = optim.Adam(d_model.parameters(), lr=args.d_lr,
                                 betas=(args.beta1, args.beta2))

    # load checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint {}".format(args.resume))
            # map_location makes a GPU-written checkpoint loadable on CPU.
            checkpoint = torch.load(args.resume, map_location=device)
            args.start_epoch = checkpoint["epoch"]
            g_model.load_state_dict(checkpoint["g_model_state_dict"])
            # g_optimizer.load_state_dict(checkpoint["g_optimizer_state_dict"])
            if use_gan:
                d_model.load_state_dict(checkpoint["d_model_state_dict"])
                # d_optimizer.load_state_dict(checkpoint["d_optimizer_state_dict"])
                d_losses = checkpoint["d_losses"]
                g_losses = checkpoint["g_losses"]
            train_losses = checkpoint["train_losses"]
            test_losses = checkpoint["test_losses"]
            print("=> load chekcpoint {} (epoch {})"
                  .format(args.resume, checkpoint["epoch"]))
        else:
            print("=> no checkpoint found at {}".format(args.resume))

    # main loop
    for epoch in tqdm(range(args.start_epoch, args.epochs)):
        # training..
        g_model.train()
        if use_gan:
            d_model.train()
        train_loss = 0.
        volume_loss_part = np.zeros(args.training_step)
        for i, sample in enumerate(train_loader):
            v_f = sample["v_f"].to(device)
            v_b = sample["v_b"].to(device)
            v_i = sample["v_i"].to(device)

            # adversarial ground truths, created directly on the target device
            label_shape = (v_i.shape[0], v_i.shape[1], 1, 1, 1, 1)
            real_label = torch.ones(label_shape, device=device)
            fake_label = torch.zeros(label_shape, device=device)

            g_optimizer.zero_grad()
            fake_volumes = g_model(v_f, v_b, args.training_step,
                                   args.wo_ori_volume, args.norm)

            # adversarial loss
            # update discriminator
            avg_d_loss = 0.
            avg_d_loss_real = 0.
            avg_d_loss_fake = 0.
            if use_gan:
                for _ in range(args.n_d):
                    d_optimizer.zero_grad()
                    decisions = d_model(v_i)
                    d_loss_real = adversarial_loss(decisions, real_label)
                    fake_decisions = d_model(fake_volumes.detach())
                    d_loss_fake = adversarial_loss(fake_decisions, fake_label)
                    d_loss = d_loss_real + d_loss_fake
                    d_loss.backward()
                    # .item() keeps only the number, so the statistics do not
                    # hold on to the autograd graph.
                    avg_d_loss += d_loss.item() / args.n_d
                    avg_d_loss_real += d_loss_real.item() / args.n_d
                    avg_d_loss_fake += d_loss_fake.item() / args.n_d
                    d_optimizer.step()

            # update generator
            avg_g_loss = 0.
            avg_loss = 0.
            for k in range(args.n_g):
                if k > 0:
                    # The previous backward() freed the graph and step()
                    # changed the weights, so a fresh forward pass is needed.
                    fake_volumes = g_model(v_f, v_b, args.training_step,
                                           args.wo_ori_volume, args.norm)
                loss = 0.
                g_optimizer.zero_grad()

                # adversarial loss
                if use_gan:
                    fake_decisions = d_model(fake_volumes)
                    g_loss = args.gan_loss_weight * adversarial_loss(
                        fake_decisions, real_label)
                    loss += g_loss
                    avg_g_loss += g_loss.item() / args.n_g

                # volume loss (prediction first, target second; MSE is
                # symmetric so the value is unchanged)
                if args.volume_loss:
                    volume_loss = args.volume_loss_weight * mse_loss(
                        fake_volumes, v_i)
                    with torch.no_grad():
                        # per-intermediate-volume breakdown, for logging only
                        for j in range(v_i.shape[1]):
                            part = mse_loss(fake_volumes[:, j, :], v_i[:, j, :])
                            volume_loss_part[j] += (
                                part.item() / args.n_g / args.log_every)
                    loss += volume_loss

                # feature loss
                if args.feature_loss:
                    feat_real = feature_net.extract_features(v_i)
                    feat_fake = feature_net.extract_features(fake_volumes)
                    for m in range(len(feat_real)):
                        loss += (args.feature_loss_weight / len(feat_real)
                                 * mse_loss(feat_fake[m], feat_real[m]))

                loss.backward()
                g_optimizer.step()
                avg_loss += loss.item() / args.n_g

            train_loss += avg_loss

            # log training status
            subEpoch = (i + 1) // args.log_every
            if (i + 1) % args.log_every == 0:
                print("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch, (i + 1) * args.batch_size, len(train_loader.dataset),
                    100. * (i + 1) / len(train_loader), avg_loss
                ))
                print("Volume Loss: ")
                for j in range(volume_loss_part.shape[0]):
                    print("\tintermediate {}: {:.6f}".format(
                        j + 1, volume_loss_part[j]
                    ))

                if use_gan:
                    print("DLossReal: {:.6f} DLossFake: {:.6f} DLoss: {:.6f}, GLoss: {:.6f}".format(
                        avg_d_loss_real, avg_d_loss_fake, avg_d_loss, avg_g_loss
                    ))
                    d_losses.append(avg_d_loss)
                    g_losses.append(avg_g_loss)
                # train_losses.append(avg_loss)
                train_losses.append(train_loss / args.log_every)
                print("====> SubEpoch: {} Average loss: {:.6f} Time {}".format(
                    subEpoch, train_loss / args.log_every,
                    time.asctime(time.localtime(time.time()))
                ))
                train_loss = 0.
                volume_loss_part = np.zeros(args.training_step)

            # testing...
            if (i + 1) % args.test_every == 0:
                g_model.eval()
                if use_gan:
                    d_model.eval()
                test_loss = 0.
                with torch.no_grad():
                    # No index here: the training index ``i`` is still needed
                    # by the checkpoint condition below.
                    for test_sample in test_loader:
                        t_f = test_sample["v_f"].to(device)
                        t_b = test_sample["v_b"].to(device)
                        t_i = test_sample["v_i"].to(device)
                        test_fake = g_model(t_f, t_b, args.training_step,
                                            args.wo_ori_volume, args.norm)
                        test_loss += (args.volume_loss_weight
                                      * mse_loss(t_i, test_fake).item())

                test_losses.append(
                    test_loss * args.batch_size / len(test_loader.dataset))
                print("====> SubEpoch: {} Test set loss {:4f} Time {}".format(
                    subEpoch, test_losses[-1],
                    time.asctime(time.localtime(time.time()))
                ))

                # Back to training mode for the rest of the epoch.
                g_model.train()
                if use_gan:
                    d_model.train()

            # saving...
            if (i + 1) % args.check_every == 0:
                print("=> saving checkpoint at epoch {}".format(epoch))
                state = {"epoch": epoch + 1,
                         "g_model_state_dict": g_model.state_dict(),
                         "g_optimizer_state_dict": g_optimizer.state_dict()}
                if use_gan:
                    state["d_model_state_dict"] = d_model.state_dict()
                    state["d_optimizer_state_dict"] = d_optimizer.state_dict()
                    state["d_losses"] = d_losses
                    state["g_losses"] = g_losses
                state["train_losses"] = train_losses
                state["test_losses"] = test_losses
                torch.save(state,
                           os.path.join(args.save_dir,
                                        "model_" + str(epoch) + "_" + str(subEpoch) + "_" + "pth.tar"))

                torch.save(g_model.state_dict(),
                           os.path.join(args.save_dir,
                                        "model_" + str(epoch) + "_" + str(subEpoch) + ".pth"))

        # Average over the sub-epochs logged during this epoch; with none
        # logged, a [-0:] slice would wrongly select the whole history.
        num_subEpoch = len(train_loader) // args.log_every
        recent = train_losses[-num_subEpoch:] if num_subEpoch > 0 else []
        epoch_average = float(np.mean(recent)) if recent else float("nan")
        print("====> Epoch: {} Average loss: {:.6f} Time {}".format(
            epoch, epoch_average,
            time.asctime(time.localtime(time.time()))
        ))
parser.add_argument( '--retrain', action='store_true', help='Whether or not to start training from a previous state.') args = parser.parse_args() print("Initializing generator model and optimizer.") g_net = Generator().cuda() g_opt = optim.RMSprop(g_net.parameters(), args.learning_rate_d, weight_decay=args.rmsprop_decay) g_losses = np.empty(0) print("Initializing discriminator model and optimizer.") d_net = Discriminator().cuda() d_opt = optim.RMSprop(d_net.parameters(), args.learning_rate_d, weight_decay=args.rmsprop_decay) d_losses = np.empty(0) if args.retrain: g_net.load_state_dict(torch.load('../data/generator_state')) d_net.load_state_dict(torch.load('../data/discriminator_state')) print("Beginning training..") loader = ETL(args.batch_size, args.image_size, args.path) for iteration in range(args.iterations): # Train discriminator for _ in range(args.k_discriminator):
class AdvGAN_Pretrain:
    """Pre-trains a perturbation generator and a discriminator against ``model``.

    The generator receives the input images concatenated with an all-ones
    mask channel and outputs a perturbation, which is clipped to
    ``[-0.3, 0.3]`` and added to the images.  The generator objective
    combines a least-squares GAN term, the L2 norm of the perturbation and a
    margin loss on the softmax outputs of ``model``.
    """

    def __init__(self, device, model, model_num_labels, box_min, box_max):
        """Create both networks, their Adam optimizers and the output folder.

        Args:
            device: Device the networks and helper tensors live on.
            model: Classifier whose predictions the perturbation targets.
            model_num_labels: Number of classes ``model`` outputs.
            box_min: Lower clamp bound for adversarial images.
            box_max: Upper clamp bound for adversarial images.
        """
        self.device = device
        self.model_num_labels = model_num_labels
        self.model = model
        self.box_min = box_min
        self.box_max = box_max

        self.netG = Generator().to(device)
        self.netDisc = Discriminator().to(device)

        # initialize all weights
        self.netG.apply(weights_init)
        self.netDisc.apply(weights_init)

        # initialize optimizers
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=1e-3)
        self.optimizer_D = torch.optim.Adam(self.netDisc.parameters(), lr=1e-3)

        # exist_ok avoids the check-then-create race of exists() + makedirs()
        os.makedirs(models_path, exist_ok=True)

    def train_batch(self, x, labels):
        """Run one discriminator step and one generator step on a batch.

        Args:
            x: Image batch; dimensions 0, 2 and 3 are used to size the mask.
            labels: Class indices used to build the one-hot targets.

        Returns:
            Tuple of floats ``(loss_D_GAN, loss_G_fake, loss_perturb, loss_adv)``.
        """
        # ---- optimize D ----
        # All-ones mask channel, created directly on the target device.
        mask = torch.ones(x.size(0), 1, x.size(2), x.size(3),
                          dtype=torch.float32, device=self.device)
        x_with_mask = torch.cat((x, mask), 1).to(self.device)
        perturbation = self.netG(x_with_mask)

        # add a clipping trick
        adv_images = torch.clamp(perturbation, -0.3, 0.3) * mask + x
        # adv_images = torch.clamp(perturbation, -0.3, 0.3) + x
        adv_images = torch.clamp(adv_images, self.box_min, self.box_max)

        self.optimizer_D.zero_grad()
        pred_real = self.netDisc(x)
        loss_D_real = F.mse_loss(
            pred_real, torch.ones_like(pred_real, device=self.device))
        loss_D_real.backward()

        # detach so the discriminator step does not backprop into G
        pred_fake = self.netDisc(adv_images.detach())
        loss_D_fake = F.mse_loss(
            pred_fake, torch.zeros_like(pred_fake, device=self.device))
        loss_D_fake.backward()
        loss_D_GAN = loss_D_fake + loss_D_real
        self.optimizer_D.step()

        # ---- optimize G ----
        self.optimizer_G.zero_grad()

        # cal G's loss in GAN
        pred_fake = self.netDisc(adv_images)
        loss_G_fake = F.mse_loss(
            pred_fake, torch.ones_like(pred_fake, device=self.device))
        # The generator graph is reused by loss_G.backward() below.
        loss_G_fake.backward(retain_graph=True)

        # calculate perturbation norm
        loss_perturb = torch.mean(
            torch.norm(perturbation.view(perturbation.shape[0], -1), 2, dim=1))
        # loss_perturb = torch.max(loss_perturb - C, torch.zeros(1, device=self.device))

        # cal adv loss
        logits_model = self.model(adv_images)
        probs_model = F.softmax(logits_model, dim=1)
        onehot_labels = torch.eye(self.model_num_labels,
                                  device=self.device)[labels]

        # C&W loss function: probability of the labelled class minus the
        # largest other-class probability (the -10000 term masks the
        # labelled class out of the max), floored at zero.
        real = torch.sum(onehot_labels * probs_model, dim=1)
        other, _ = torch.max(
            (1 - onehot_labels) * probs_model - onehot_labels * 10000, dim=1)
        zeros = torch.zeros_like(other)
        loss_adv = torch.max(real - other, zeros)
        loss_adv = torch.sum(loss_adv)

        # maximize cross_entropy loss
        # loss_adv = -F.mse_loss(logits_model, onehot_labels)
        # loss_adv = - F.cross_entropy(logits_model, labels)

        adv_lambda = 10
        pert_lambda = 1
        loss_G = adv_lambda * loss_adv + pert_lambda * loss_perturb
        loss_G.backward()
        self.optimizer_G.step()

        return (loss_D_GAN.item(), loss_G_fake.item(), loss_perturb.item(),
                loss_adv.item())

    def train(self, train_dataloader, epochs):
        """Train for ``epochs`` epochs, logging to TensorBoard and checkpointing.

        Per-epoch mean losses are written to ``visualization/pre_advgan/``
        and printed; both state dicts are saved under ``models_path`` every
        20 epochs.  The optimizers are re-created with learning rate 1e-4 at
        epoch 50 and 1e-5 at epoch 80.

        Args:
            train_dataloader: Iterable with ``len()`` yielding ``(images, labels)``.
            epochs: Number of epochs; counting starts at 1.
        """
        writer = SummaryWriter(log_dir="visualization/pre_advgan/",
                               comment='Pretrained AdvGAN stats')
        try:
            for epoch in range(1, epochs + 1):
                # Re-creating the optimizers lowers the learning rate and
                # also discards the accumulated Adam state.
                if epoch == 50:
                    self.optimizer_G = torch.optim.Adam(
                        self.netG.parameters(), lr=1e-4)
                    self.optimizer_D = torch.optim.Adam(
                        self.netDisc.parameters(), lr=1e-4)
                if epoch == 80:
                    self.optimizer_G = torch.optim.Adam(
                        self.netG.parameters(), lr=1e-5)
                    self.optimizer_D = torch.optim.Adam(
                        self.netDisc.parameters(), lr=1e-5)

                loss_D_sum = 0
                loss_G_fake_sum = 0
                loss_perturb_sum = 0
                loss_adv_sum = 0
                for images, labels in train_dataloader:
                    images = images.to(self.device)
                    labels = labels.to(self.device)

                    loss_D_batch, loss_G_fake_batch, loss_perturb_batch, loss_adv_batch = \
                        self.train_batch(images, labels)
                    loss_D_sum += loss_D_batch
                    loss_G_fake_sum += loss_G_fake_batch
                    loss_perturb_sum += loss_perturb_batch
                    loss_adv_sum += loss_adv_batch

                # print statistics; max() guards against an empty loader
                num_batch = max(len(train_dataloader), 1)
                writer.add_scalar('discriminator_loss',
                                  loss_D_sum / num_batch, epoch)
                writer.add_scalar('generator_loss',
                                  loss_G_fake_sum / num_batch, epoch)
                writer.add_scalar('perturbation_loss',
                                  loss_perturb_sum / num_batch, epoch)
                writer.add_scalar('adversarial_loss',
                                  loss_adv_sum / num_batch, epoch)

                print("epoch %d:\nloss_D: %.5f, loss_G_fake: %.5f, \nloss_perturb: %.5f, loss_adv: %.5f\n" %
                      (epoch, loss_D_sum / num_batch,
                       loss_G_fake_sum / num_batch,
                       loss_perturb_sum / num_batch,
                       loss_adv_sum / num_batch))

                # save generator and discriminator
                if epoch % 20 == 0:
                    # os.path.join works whether or not models_path ends
                    # with a path separator.
                    netG_file_name = os.path.join(
                        models_path,
                        'netG_pretrained_epoch_' + str(epoch) + '.pth')
                    torch.save(self.netG.state_dict(), netG_file_name)
                    netDisc_file_name = os.path.join(
                        models_path,
                        'netDisc_pretrained_epoch_' + str(epoch) + '.pth')
                    torch.save(self.netDisc.state_dict(), netDisc_file_name)
        finally:
            # Flush and release the event file even if training raises.
            writer.close()
fixed_labels = torch.zeros(SAMPLE_SIZE, NUM_LABELS) for i in range(NUM_LABELS): for j in range(SAMPLE_SIZE // NUM_LABELS): fixed_labels[i*(SAMPLE_SIZE // NUM_LABELS) + j, i] = 1.0 label = torch.FloatTensor(args.batch_size) one_hot_labels = torch.FloatTensor(args.batch_size, 10) if args.cuda: model_d.cuda() model_g.cuda() input, label = input.cuda(), label.cuda() noise, fixed_noise = noise.cuda(), fixed_noise.cuda() one_hot_labels = one_hot_labels.cuda() fixed_labels = fixed_labels.cuda() optim_d = optim.SGD(model_d.parameters(), lr=args.lr) optim_g = optim.SGD(model_g.parameters(), lr=args.lr) fixed_noise = Variable(fixed_noise) fixed_labels = Variable(fixed_labels) real_label = 1 fake_label = 0 for epoch_idx in range(args.epochs): model_d.train() model_g.train() d_loss = 0.0 g_loss = 0.0 for batch_idx, (train_x, train_y) in enumerate(train_loader):
def main():
    """Train a small GAN on MNIST, then save the models, samples and a loss plot.

    The function takes no arguments and returns nothing. Its visible side
    effects are:

    * MNIST is read from ``../data`` (``download=True`` is passed).
    * After every epoch the generator output for a fixed latent batch is
      appended to ``samples`` and rendered through ``view_samples`` into
      ``last_sample.png``.
    * At the end ``G.pth``, ``D.pth``, ``train_samples.pkl`` and ``loss.png``
      are written and the loss plot is shown.
    """
    # -------------------- Data --------------------
    num_workers = 8  # number of subprocesses to use for data loading
    batch_size = 64  # how many samples per batch to load

    transform = transforms.ToTensor()  # convert data to torch.FloatTensor
    train_data = datasets.MNIST(root='../data',
                                train=True,
                                download=True,
                                transform=transform)
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=batch_size,
                                               num_workers=num_workers)

    # -------------------- Discriminator and Generator --------------------
    # Discriminator hyperparams
    input_size = 784  # size of input image to discriminator (28*28)
    d_output_size = 1  # size of discriminator output (real or fake)
    d_hidden_size = 32  # size of last hidden layer in the discriminator

    # Generator hyperparams
    z_size = 100  # size of latent vector to give to generator
    g_output_size = 784  # size of generator output (generated image)
    g_hidden_size = 32  # size of first hidden layer in the generator

    # Instantiate discriminator and generator
    D = Discriminator(input_size, d_hidden_size, d_output_size)
    G = Generator(z_size, g_hidden_size, g_output_size)

    # -------------------- Optimizers and Criterion --------------------
    num_epochs = 100
    print_every = 400
    lr = 0.002

    d_optimizer = optim.Adam(D.parameters(), lr)
    g_optimizer = optim.Adam(G.parameters(), lr)

    losses = []  # (d_loss, g_loss) of the last batch of every epoch
    criterion = nn.BCEWithLogitsLoss()

    # -------------------- Training --------------------
    D.train()
    G.train()

    # Fixed latent vectors: held constant throughout training so that the
    # saved samples of different epochs are comparable.
    sample_size = 16
    fixed_z = np.random.uniform(-1, 1, size=(sample_size, z_size))
    fixed_z = torch.from_numpy(fixed_z).float()
    samples = []  # generator output for fixed_z, one entry per epoch

    for epoch in range(num_epochs):
        for batch_i, (real_images, _) in enumerate(train_loader):
            # Size of this particular batch, taken from the data itself.
            cur_batch_size = real_images.size(0)

            # Important rescaling step: [0, 1) -> [-1, 1)
            real_images = real_images * 2 - 1

            # Generate fake images, used for both discriminator and generator
            z = np.random.uniform(-1, 1, size=(cur_batch_size, z_size))
            z = torch.from_numpy(z).float()
            fake_images = G(z)

            real_labels = torch.ones(cur_batch_size)
            fake_labels = torch.zeros(cur_batch_size)

            # ============================================
            #            TRAIN THE DISCRIMINATOR
            # ============================================
            d_optimizer.zero_grad()

            # 1. Loss on real images (label smoothing requested)
            D_real = D(real_images)
            d_real_loss = real_loss(criterion, D_real, real_labels, smooth=True)

            # 2. Loss on fake images. detach() keeps the generator fixed
            #    while the discriminator is optimized.
            D_fake = D(fake_images.detach())
            d_fake_loss = fake_loss(criterion, D_fake, fake_labels)

            # 3. Add up loss and perform backprop
            d_loss = (d_real_loss + d_fake_loss) * 0.5
            d_loss.backward()
            d_optimizer.step()

            # =========================================
            #            TRAIN THE GENERATOR
            # =========================================
            g_optimizer.zero_grad()

            # Keep the discriminator fixed while optimizing the generator
            set_model_gradient(D, False)

            # 1. Discriminator loss on fake images with flipped (real) labels
            G_D_fake = D(fake_images)
            g_loss = real_loss(criterion, G_D_fake, real_labels)

            # 2. Perform backprop
            g_loss.backward()
            g_optimizer.step()

            # Re-enable discriminator gradients for the next iteration
            set_model_gradient(D, True)

            # =========================================
            #           Print some loss stats
            # =========================================
            if batch_i % print_every == 0:
                print(
                    'Epoch [{:5d}/{:5d}] | d_loss: {:6.4f} | g_loss: {:6.4f}'.
                    format(epoch + 1, num_epochs, d_loss.item(),
                           g_loss.item()))

        # AFTER EACH EPOCH
        losses.append((d_loss.item(), g_loss.item()))

        # Generate and save sample images. no_grad() keeps the stored samples
        # from holding an autograd graph for the rest of the run.
        G.eval()  # eval mode for generating samples
        with torch.no_grad():
            samples_z = G(fixed_z)
        samples.append(samples_z)
        view_samples(-1, samples, "last_sample.png")
        G.train()  # back to train mode

    # Save models and training generator samples
    torch.save(G.state_dict(), "G.pth")
    torch.save(D.state_dict(), "D.pth")
    with open('train_samples.pkl', 'wb') as f:
        pkl.dump(samples, f)

    # Plot the loss curve
    fig, ax = plt.subplots()
    losses = np.array(losses)
    plt.plot(losses.T[0], label='Discriminator')
    plt.plot(losses.T[1], label='Generator')
    plt.title("Training Losses")
    plt.legend()
    plt.savefig("loss.png")
    plt.show()
class trainer(object):
    """Adversarial training loop with two U-Net generators and a discriminator.

    ``__init__`` builds the networks, losses, data loaders, optimizers and
    polynomial learning-rate schedulers from ``cfg`` and optionally resumes
    from ``<cfg.TRAIN.OUTDIR>/checkpoints/<cfg.TRAIN.RESUME>epoch.pth``.
    ``train`` runs the train/validate/checkpoint cycle.
    """

    def __init__(self, cfg):
        """Build every training component from the config object ``cfg``.

        A value of ``cfg.TRAIN.RESUME >= 0`` is treated as the epoch number of
        the checkpoint to resume from.
        """
        self.cfg = cfg

        # The networks are moved to the GPU *before* the optimizers are built
        # and before any optimizer state is loaded: Optimizer.load_state_dict
        # casts the restored state to the device the parameters are on at that
        # moment, so moving the model afterwards would leave the Adam state on
        # the CPU while the parameters live on the GPU.
        self.OldLabel_generator = U_Net(in_ch=cfg.DATASET.N_CLASS,
                                        out_ch=cfg.DATASET.N_CLASS,
                                        side='out').cuda()
        self.Image_generator = U_Net(in_ch=3,
                                     out_ch=cfg.DATASET.N_CLASS,
                                     side='in').cuda()
        self.discriminator = Discriminator(cfg.DATASET.N_CLASS + 3,
                                           cfg.DATASET.IMGSIZE,
                                           patch=True).cuda()

        self.criterion_G = GeneratorLoss(cfg.LOSS.LOSS_WEIGHT[0],
                                         cfg.LOSS.LOSS_WEIGHT[1],
                                         cfg.LOSS.LOSS_WEIGHT[2],
                                         ignore_index=cfg.LOSS.IGNORE_INDEX)
        self.criterion_D = DiscriminatorLoss()

        train_dataset = BaseDataset(cfg, split='train')
        valid_dataset = BaseDataset(cfg, split='val')
        self.train_dataloader = data.DataLoader(
            train_dataset,
            batch_size=cfg.DATASET.BATCHSIZE,
            num_workers=8,
            shuffle=True,
            drop_last=True)
        self.valid_dataloader = data.DataLoader(
            valid_dataset,
            batch_size=cfg.DATASET.BATCHSIZE,
            num_workers=8,
            shuffle=True,
            drop_last=True)

        # makedirs also creates cfg.TRAIN.OUTDIR itself when it is missing.
        self.ckpt_outdir = os.path.join(cfg.TRAIN.OUTDIR, 'checkpoints')
        os.makedirs(self.ckpt_outdir, exist_ok=True)
        self.val_outdir = os.path.join(cfg.TRAIN.OUTDIR, 'val')
        os.makedirs(self.val_outdir, exist_ok=True)

        self.start_epoch = cfg.TRAIN.RESUME
        self.n_epoch = cfg.TRAIN.N_EPOCH

        # One optimizer drives both generators.
        self.optimizer_G = torch.optim.Adam(
            [{
                'params': self.OldLabel_generator.parameters()
            }, {
                'params': self.Image_generator.parameters()
            }],
            lr=cfg.OPTIMIZER.G_LR,
            betas=(cfg.OPTIMIZER.BETA1, cfg.OPTIMIZER.BETA2),
            weight_decay=cfg.OPTIMIZER.WEIGHT_DECAY)
        self.optimizer_D = torch.optim.Adam(
            [{
                'params': self.discriminator.parameters(),
                'initial_lr': cfg.OPTIMIZER.D_LR
            }],
            lr=cfg.OPTIMIZER.D_LR,
            betas=(cfg.OPTIMIZER.BETA1, cfg.OPTIMIZER.BETA2),
            weight_decay=cfg.OPTIMIZER.WEIGHT_DECAY)

        # Polynomial decay (power 0.9) over the total number of iterations;
        # both schedulers are stepped once per batch in train().
        iter_per_epoch = len(train_dataset) // cfg.DATASET.BATCHSIZE
        lambda_poly = lambda iters: pow(
            (1.0 - iters / (cfg.TRAIN.N_EPOCH * iter_per_epoch)), 0.9)
        self.scheduler_G = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_G,
            lr_lambda=lambda_poly,
        )
        self.scheduler_D = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_D,
            lr_lambda=lambda_poly,
        )

        self.logger = logger(cfg.TRAIN.OUTDIR, name='train')
        self.running_metrics = runningScore(n_classes=cfg.DATASET.N_CLASS)

        if self.start_epoch >= 0:
            # Read the checkpoint file once and reuse it for every component.
            checkpoint = torch.load(
                os.path.join(cfg.TRAIN.OUTDIR, 'checkpoints',
                             '{}epoch.pth'.format(self.start_epoch)))
            self.OldLabel_generator.load_state_dict(checkpoint['model_G_N'])
            self.Image_generator.load_state_dict(checkpoint['model_G_I'])
            self.discriminator.load_state_dict(checkpoint['model_D'])
            self.optimizer_G.load_state_dict(checkpoint['optimizer_G'])
            self.optimizer_D.load_state_dict(checkpoint['optimizer_D'])

            log = "Using the {}th checkpoint".format(self.start_epoch)
            self.logger.info(log)

        self.criterion_G = self.criterion_G.cuda()
        self.criterion_D = self.criterion_D.cuda()

    def train(self):
        """Run training from ``start_epoch + 1`` up to ``n_epoch``.

        Every epoch: train on all batches (discriminator step, then generator
        step), plot running losses to the global visdom handle ``vis`` every
        10 iterations, validate, log the metrics, and save a checkpoint named
        ``<epoch>epoch.pth``.
        """
        all_train_iter_total_loss = []
        all_train_iter_corr_loss = []
        all_train_iter_recover_loss = []
        all_train_iter_change_loss = []
        all_train_iter_gan_loss_gen = []
        all_train_iter_gan_loss_dis = []

        all_val_epo_iou = []
        all_val_epo_acc = []

        iter_num = [0]
        epoch_num = []
        num_batches = len(self.train_dataloader)
        n_class = self.cfg.DATASET.N_CLASS

        for epoch_i in range(self.start_epoch + 1, self.n_epoch):
            iter_total_loss = AverageTracker()
            iter_corr_loss = AverageTracker()
            iter_recover_loss = AverageTracker()
            iter_change_loss = AverageTracker()
            iter_gan_loss_gen = AverageTracker()
            iter_gan_loss_dis = AverageTracker()
            batch_time = AverageTracker()
            tic = time.time()

            # train
            self.OldLabel_generator.train()
            self.Image_generator.train()
            self.discriminator.train()
            for i, meta in enumerate(self.train_dataloader):
                image, old_label, new_label = meta[0].cuda(), meta[1].cuda(
                ), meta[2].cuda()
                recover_pred, feats = self.OldLabel_generator(
                    label2onehot(old_label, n_class))
                corr_pred = self.Image_generator(image, feats)

                # -------------------
                # Train Discriminator
                # -------------------
                self.discriminator.set_requires_grad(True)
                self.optimizer_D.zero_grad()

                # detach(): the generators receive no gradient from this step.
                fake_sample = torch.cat((image, corr_pred), 1).detach()
                # Uses self.cfg rather than a module-level ``cfg``.
                real_sample = torch.cat(
                    (image, label2onehot(new_label, n_class)), 1)

                score_fake_d = self.discriminator(fake_sample)
                score_real = self.discriminator(real_sample)

                gan_loss_dis = self.criterion_D(pred_score=score_fake_d,
                                                real_score=score_real)
                gan_loss_dis.backward()
                self.optimizer_D.step()
                self.scheduler_D.step()

                # ---------------
                # Train Generator
                # ---------------
                self.discriminator.set_requires_grad(False)
                self.optimizer_G.zero_grad()

                score_fake = self.discriminator(
                    torch.cat((image, corr_pred), 1))

                total_loss, corr_loss, recover_loss, change_loss, gan_loss_gen = self.criterion_G(
                    corr_pred, recover_pred, score_fake, old_label, new_label)

                total_loss.backward()
                self.optimizer_G.step()
                self.scheduler_G.step()

                iter_total_loss.update(total_loss.item())
                iter_corr_loss.update(corr_loss.item())
                iter_recover_loss.update(recover_loss.item())
                iter_change_loss.update(change_loss.item())
                iter_gan_loss_gen.update(gan_loss_gen.item())
                iter_gan_loss_dis.update(gan_loss_dis.item())
                batch_time.update(time.time() - tic)
                tic = time.time()

                log = ('{}: Epoch: [{}][{}/{}], Time: {:.2f}, '
                       'Total Loss: {:.6f}, Corr Loss: {:.6f}, Recover Loss: {:.6f}, Change Loss: {:.6f}, GAN_G Loss: {:.6f}, GAN_D Loss: {:.6f}').format(
                           datetime.now(), epoch_i, i, num_batches,
                           batch_time.avg, total_loss.item(),
                           corr_loss.item(), recover_loss.item(),
                           change_loss.item(), gan_loss_gen.item(),
                           gan_loss_dis.item())
                print(log)

                if (i + 1) % 10 == 0:
                    # Record the 10-iteration averages, then restart them.
                    all_train_iter_total_loss.append(iter_total_loss.avg)
                    all_train_iter_corr_loss.append(iter_corr_loss.avg)
                    all_train_iter_recover_loss.append(iter_recover_loss.avg)
                    all_train_iter_change_loss.append(iter_change_loss.avg)
                    all_train_iter_gan_loss_gen.append(iter_gan_loss_gen.avg)
                    all_train_iter_gan_loss_dis.append(iter_gan_loss_dis.avg)
                    iter_total_loss.reset()
                    iter_corr_loss.reset()
                    iter_recover_loss.reset()
                    iter_change_loss.reset()
                    iter_gan_loss_gen.reset()
                    iter_gan_loss_dis.reset()

                    vis.line(X=np.column_stack(
                        np.repeat(np.expand_dims(iter_num, 0), 6, axis=0)),
                             Y=np.column_stack((all_train_iter_total_loss,
                                                all_train_iter_corr_loss,
                                                all_train_iter_recover_loss,
                                                all_train_iter_change_loss,
                                                all_train_iter_gan_loss_gen,
                                                all_train_iter_gan_loss_dis)),
                             opts={
                                 'legend': [
                                     'total_loss', 'corr_loss', 'recover_loss',
                                     'change_loss', 'gan_loss_gen',
                                     'gan_loss_dis'
                                 ],
                                 'linecolor':
                                 np.array([[255, 0, 0], [0, 255, 0],
                                           [0, 0, 255], [255, 255, 0],
                                           [0, 255, 255], [255, 0, 255]]),
                                 'title':
                                 'Train loss of generator and discriminator'
                             },
                             win='Train loss of generator and discriminator')
                    iter_num.append(iter_num[-1] + 1)

            # eval
            self.OldLabel_generator.eval()
            self.Image_generator.eval()
            self.discriminator.eval()
            with torch.no_grad():
                for j, meta in enumerate(self.valid_dataloader):
                    image, old_label, new_label = meta[0].cuda(), meta[1].cuda(
                    ), meta[2].cuda()
                    recover_pred, feats = self.OldLabel_generator(
                        label2onehot(old_label, n_class))
                    corr_pred = self.Image_generator(image, feats)
                    preds = np.argmax(corr_pred.cpu().detach().numpy().copy(),
                                      axis=1)
                    target = new_label.cpu().detach().numpy().copy()
                    self.running_metrics.update(target, preds)

                    if j == 0:
                        # Save the first two predictions of the first batch
                        # side by side (needs a batch of at least two).
                        color_map1 = gen_color_map(preds[0, :]).astype(
                            np.uint8)
                        color_map2 = gen_color_map(preds[1, :]).astype(
                            np.uint8)
                        color_map = cv2.hconcat([color_map1, color_map2])
                        cv2.imwrite(
                            os.path.join(
                                self.val_outdir, '{}epoch*{}*{}.png'.format(
                                    epoch_i, meta[3][0], meta[3][1])),
                            color_map)

            score = self.running_metrics.get_scores()
            oa = score['Overall Acc: \t']
            recall = score['Recall: \t'][1]
            miou = score['Mean IoU: \t']
            self.running_metrics.reset()

            epoch_num.append(epoch_i)
            all_val_epo_acc.append(oa)
            all_val_epo_iou.append(miou)
            vis.line(X=np.column_stack(
                np.repeat(np.expand_dims(epoch_num, 0), 2, axis=0)),
                     Y=np.column_stack((all_val_epo_acc, all_val_epo_iou)),
                     opts={
                         'legend':
                         ['val epoch Overall Acc', 'val epoch Mean IoU'],
                         'linecolor': np.array([[255, 0, 0], [0, 255, 0]]),
                         'title': 'Validate Accuracy and IoU'
                     },
                     win='validate Accuracy and IoU')

            log = '{}: Epoch Val: [{}], ACC: {:.2f}, Recall: {:.2f}, mIoU: {:.4f}' \
                .format(datetime.now(), epoch_i, oa, recall, miou)
            self.logger.info(log)

            # Checkpoint after every epoch; __init__ resumes from these keys.
            state = {
                'epoch': epoch_i,
                "acc": oa,
                "recall": recall,
                "iou": miou,
                'model_G_N': self.OldLabel_generator.state_dict(),
                'model_G_I': self.Image_generator.state_dict(),
                'model_D': self.discriminator.state_dict(),
                'optimizer_G': self.optimizer_G.state_dict(),
                'optimizer_D': self.optimizer_D.state_dict()
            }
            save_path = os.path.join(self.cfg.TRAIN.OUTDIR, 'checkpoints',
                                     '{}epoch.pth'.format(epoch_i))
            torch.save(state, save_path)
make_dir_if_needed(args.data_output) generator = UNet(71, 3, False).to(device) discriminator = Discriminator(3, args.size).to(device) if args.dataset == 'youtube': gl_data_sampler = YoutubeFaces(args.data_dir, device=device, size=args.size) disc_data_sampler = YoutubeFaces(args.data_dir, device=device, len=3, size=args.size) elif args.dataset == 'my': gl_data_sampler = MyDatasetSampler(args.data_dir, device=device, size=args.size) disc_data_sampler = MyDatasetSampler(args.data_dir, device=device, length=3, size=args.size) compute_perceptual = PerceptualLoss().to(device) gen_optim = ranger(generator.parameters()) disc_optim = ranger(discriminator.parameters()) losses = [] for e in range(epochs): print('EPOCH {}'.format(e)) for first, second, third in tqdm(DataLoader(gl_data_sampler, batch_size=1)): generator.train(False) discriminator.train(True) for d_first, d_second, _ in DataLoader(disc_data_sampler, batch_size=1): disc_optim.zero_grad() gen_in = torch.cat([d_first[0], d_second[1]], 1)
class SEQGANs(nn.Module):
    """SeqGAN-style text GAN: a GRU generator trained with rollout rewards
    from a CNN discriminator.

    The constructor fixes the hyper-parameters, loads the Obama text dataset
    and builds the generator, the discriminator, a shared embedding table and
    three Adam optimizers. Tensors are placed on the GPU with ``.cuda()``.
    """

    def __init__(self):
        super().__init__()
        self.l2_reg_lambda = 0.2
        # Batch size; the original author warns that 1 may break because the
        # code relies on unsqueeze.
        self.batch_size = 8
        # Discriminator window sizes (number of words covered by each window)
        self.filter_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20]
        # Discriminator channel counts, one per window size
        self.num_filters = [
            100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160
        ]
        self.num_classes = 1  # number of discriminator output nodes
        self.embedding_size = 100  # word embedding size
        self.hidden_size_gru = 100  # GRU hidden size
        self.start_idx = 0  # index of the start token
        self.end_idx = 1  # index of the end token
        self.padding_idx = 2  # index of the padding token
        # Initial generator input: one start token per batch entry
        self.start_input = torch.tensor(
            self.batch_size * [self.start_idx]).cuda()
        # Initial generator hidden state
        self.start_h = torch.zeros(self.batch_size,
                                   self.hidden_size_gru).cuda()
        self.rollout_num = 10  # number of rollouts per prefix
        # Real data
        self.dataset = DataSet_Obama(root_src=r'../datas/obama/input.txt',
                                     start_idx=self.start_idx,
                                     end_idx=self.end_idx,
                                     padding_idx=self.padding_idx)
        # Longest real sentence + 1 (to count the end token)
        self.sequence_length = self.dataset.max_doclen + 1
        self.vocab_size = len(self.dataset.dictionary)  # dictionary size
        self.dataloader = DataLoader(self.dataset,
                                     batch_size=self.batch_size,
                                     shuffle=False,
                                     num_workers=2)
        self.G = Generator(self.vocab_size, self.embedding_size,
                           self.hidden_size_gru)
        self.D = Discriminator(self.sequence_length, self.num_classes,
                               self.vocab_size, self.embedding_size,
                               self.filter_sizes, self.num_filters)
        self.embeddings = nn.Embedding(num_embeddings=self.vocab_size,
                                       embedding_dim=self.embedding_size,
                                       padding_idx=self.padding_idx)
        # The embedding table is shared, so each optimizer updates it too.
        self.pre_optimizer = torch.optim.Adam([{
            'params': self.G.parameters()
        }, {
            'params': self.embeddings.parameters()
        }])
        self.G_optimizer = torch.optim.Adam([{
            'params': self.G.parameters()
        }, {
            'params': self.embeddings.parameters()
        }])
        self.D_optimizer = torch.optim.Adam([{
            'params': self.D.parameters()
        }, {
            'params': self.embeddings.parameters()
        }])

    def forward(self, input):
        """Placeholder; training goes through the dedicated methods below."""
        return -1

    def pad_data(self, samples, record, sequence_length):
        """Overwrite everything after each end token with the padding index.

        :param samples: token tensor indexed as ``samples[t][b]``
            (seq_len, batch); modified in place
        :param record: dict mapping batch index -> position of its end token
        :param sequence_length: number of time steps in ``samples``
        :return: the same ``samples`` tensor
        """
        for b, end_pos in record.items():
            for t in range(end_pos + 1, sequence_length):
                samples[t][b] = self.padding_idx
        return samples

    def generate_X_nofixedlen(self, start_input, start_h):
        """Sample sentences until every batch entry has emitted an end token.

        There is no length cap: the loop only stops once all batch entries
        have produced the end token.

        :param start_input: batch
        :param start_h: batch * hidden_size
        :return: samples (seq_len * batch), hs (seq_len * batch * hidden_size),
            predictions (seq_len * batch * vocab_size), record (dict: batch
            index -> end token position), now_len (generated length)
        """
        record = {}  # batch idx -> position of its end token in samples
        now_len = 0  # number of steps generated so far
        samples = []
        predictions = []
        hs = []
        input = self.embeddings(start_input)  # initial input: batch, input_size
        last_h = start_h  # initial state
        while len(record) != start_h.shape[0]:
            # One GRU step: next token, hidden state and prediction layer
            next_token, h, prediction = self.G(input, last_h)
            samples.append(torch.unsqueeze(next_token, dim=0))
            hs.append(torch.unsqueeze(h, dim=0))
            predictions.append(torch.unsqueeze(prediction, dim=0))
            input = self.embeddings(next_token)
            last_h = h
            # Remember the first end token of every batch entry
            for b in range(next_token.shape[0]):
                if next_token[b] == self.end_idx and b not in record:
                    record[b] = now_len
            now_len += 1
        samples = torch.cat(samples, dim=0)
        hs = torch.cat(hs, dim=0)
        predictions = torch.cat(predictions, dim=0)
        # Pad the positions after each end token. The generated length is the
        # sequence length here (pad_data requires it explicitly).
        samples = self.pad_data(samples=samples,
                                record=record,
                                sequence_length=now_len)
        return samples, hs, predictions, record, now_len

    def generate_X(self, start_input, start_h, sequence_length):
        """Sample sentences of exactly ``sequence_length`` steps.

        :param start_input: batch
        :param start_h: batch * hidden_size
        :param sequence_length: int
        :return: samples (seq_len * batch), hs (seq_len * batch * hidden_size),
            predictions (seq_len * batch * vocab_size), record (dict: batch
            index -> end token position)
        """
        record = {}  # batch idx -> position of its end token in samples
        now_len = 0  # number of steps generated so far
        samples = []
        predictions = []
        hs = []
        input = self.embeddings(start_input)  # initial input: batch, input_size
        last_h = start_h  # initial state
        for _ in range(sequence_length):
            # One GRU step: next token, hidden state and prediction layer
            next_token, h, prediction = self.G(input, last_h)
            samples.append(torch.unsqueeze(next_token, dim=0))
            hs.append(torch.unsqueeze(h, dim=0))
            predictions.append(torch.unsqueeze(prediction, dim=0))
            input = self.embeddings(next_token)
            last_h = h
            # Remember the first end token of every batch entry
            for b in range(next_token.shape[0]):
                if next_token[b] == self.end_idx and b not in record:
                    record[b] = now_len
            now_len += 1
        samples = torch.cat(samples, dim=0)
        hs = torch.cat(hs, dim=0)
        predictions = torch.cat(predictions, dim=0)
        # Pad the positions after each end token
        samples = self.pad_data(samples=samples,
                                record=record,
                                sequence_length=sequence_length)
        return samples, hs, predictions, record

    def generate_pretrained(self, start_input, start_h, sequence_length,
                            groundtrues):
        """Teacher-forced pass for pre-training: feed the true words, return
        the predictions.

        :param start_input: batch
        :param start_h: batch * hidden_size
        :param sequence_length: int
        :param groundtrues: sequence_length * batch
        :return: predictions: seq_len * batch * vocab_size
        """
        predictions = []
        input = self.embeddings(start_input)  # initial input: batch, input_size
        last_h = start_h  # initial state
        for t in range(sequence_length):
            next_token, h, prediction = self.G(input, last_h)
            predictions.append(torch.unsqueeze(prediction, dim=0))
            # Feed the embedding of the correct word, not the sampled one
            input = self.embeddings(groundtrues[t])
            last_h = h
        predictions = torch.cat(predictions, dim=0)
        return predictions

    def show_G(self):
        """Return one batch of generated samples (seq_len * batch)."""
        samples, _, _, _ = self.generate_X(
            start_input=self.start_input,
            start_h=self.start_h,
            sequence_length=self.sequence_length)
        return samples

    def _start_state(self, cur_batch):
        """Return (start_input, start_h) for a batch of ``cur_batch`` entries."""
        if cur_batch == self.batch_size:
            return self.start_input, self.start_h
        # The dataloader produced a smaller (last) batch
        start_input = torch.tensor(cur_batch * [self.start_idx]).cuda()
        start_h = torch.zeros(cur_batch, self.hidden_size_gru).cuda()
        return start_input, start_h

    def pretraining(self):
        """Run one epoch of maximum-likelihood pre-training of the generator.

        :return: mean per-batch loss of the epoch (0.0 when the dataloader
            yields no batch)
        """
        loss_func = nn.NLLLoss(ignore_index=self.padding_idx)
        for epoch in range(1):
            total_loss = 0.0
            num_batches = 0
            for x_batch in self.dataloader:  # x_batch: batch * seq_len
                num_batches += 1
                self.pre_optimizer.zero_grad()
                # x_groundtrues: seq_len * batch
                x_groundtrues = torch.transpose(x_batch, dim0=0,
                                                dim1=1).cuda()
                start_input, start_h = self._start_state(x_batch.size()[0])
                # predictions: seq_len * batch * vocab_size
                predictions = self.generate_pretrained(
                    start_input=start_input,
                    start_h=start_h,
                    sequence_length=self.sequence_length,
                    groundtrues=x_groundtrues)
                loss = 0.0
                for t in range(self.sequence_length):
                    # target * log(prediction); clamp avoids log(0)
                    loss += loss_func(
                        torch.log(
                            torch.clamp(predictions[t], min=1e-20, max=1.0)),
                        x_groundtrues[t])
                loss = loss / self.sequence_length
                total_loss += loss.item()
                loss.backward()
                self.pre_optimizer.step()
            total_loss = total_loss / max(num_batches, 1)
        return total_loss

    def rollout(self):
        """Sample a batch of sentences and roll out every proper prefix.

        :return: result_rollout ((seq_len-1) * rollout_num * seq_len * batch,
            the rollouts for prefixes 1..T-1), samples (the complete
            sentences), predictions, record
        """
        samples, hs, predictions, record = self.generate_X(
            start_input=self.start_input,
            start_h=self.start_h,
            sequence_length=self.sequence_length)
        result_rollout = []
        for given_num in range(self.sequence_length - 1):  # given < T
            result_overtimes = []  # rollout results for this time step
            for _ in range(self.rollout_num):
                sample_rollout, _, _, _ = self.generate_X(
                    start_input=samples[given_num],
                    start_h=hs[given_num],
                    sequence_length=self.sequence_length - given_num - 1,
                )
                result_overtimes.append(
                    torch.unsqueeze(
                        torch.cat([samples[0:given_num + 1], sample_rollout],
                                  0), 0))
            # result_overtimes: rollout_num * seq_len * batch
            result_overtimes = torch.cat(result_overtimes, 0)
            result_rollout.append(torch.unsqueeze(result_overtimes, 0))
        # result_rollout: (seq_len-1) * rollout_num * seq_len * batch
        result_rollout = torch.cat(result_rollout, 0)
        return result_rollout, samples, predictions, record

    def onehot(self, label):
        """One-hot encode ``label`` (seq_len * batch * 1) over the vocabulary."""
        a = torch.FloatTensor(self.sequence_length, self.batch_size,
                              self.vocab_size).zero_().cuda()
        return a.scatter_(dim=2, index=label, value=1)

    def generate_code(self, record):
        """Build the 0/1 mask selecting which rewards count for each sentence.

        :param record: dict mapping batch index -> end token position
        :return: code (seq_len * batch * 1): 1 means the reward at that
            position is counted, 0 means it is not; num_elements (batch * 1):
            ``record[b] + 1`` for entries in ``record``, ``sequence_length``
            otherwise

        Example from the original author: with seq_len 4 and end tokens at
        positions 1, 3 and 2 for three batch entries, code is (rows are time
        steps, columns are batch entries)::

            1 1 1
            0 1 1
            0 1 0
            1 1 1
        """
        num_elements = torch.zeros(self.batch_size,
                                   1).new_full(size=(self.batch_size, 1),
                                               fill_value=self.sequence_length)
        code = torch.ones(self.sequence_length, self.batch_size, 1)
        for b in record.keys():
            num_elements[b][0] = record[b] + 1
            # Positions from the end token up to, but excluding, the final
            # step are masked out; the final step stays 1.
            for t in range(record[b], self.sequence_length - 1):
                code[t][b][0] = 0
        return code, num_elements

    def backward_G(self):
        """Perform one policy-gradient update of the generator.

        :return: the scalar objective value ``J`` (float)
        """
        result_rollout, result, predictions, record = self.rollout()
        total_reward = []
        # Rewards of steps 1..T-1: mean discriminator score over the rollouts
        for t in range(self.sequence_length - 1):
            # result_rollout_trans: rollout_num * batch * seq_len
            result_rollout_trans = torch.transpose(result_rollout[t],
                                                   dim0=1,
                                                   dim1=2)
            input_D = self.embeddings(result_rollout_trans)
            # input_D: rollout_num * batch * 1 * seq_len * embedding_size
            input_D = torch.unsqueeze(input_D, 2)
            reward = 0.0
            for i in range(self.rollout_num):
                reward += self.D(input_D[i])
            reward = reward / self.rollout_num
            total_reward.append(torch.unsqueeze(reward, 0))
        # Reward of step T: discriminator score of the complete sentence
        result_trans = torch.transpose(result, dim0=0,
                                       dim1=1)  # batch * seq_len
        input_D = self.embeddings(result_trans)
        # input_D: batch * 1 * seq_len * embedding_size
        input_D = torch.unsqueeze(input_D, 1)
        total_reward.append(torch.unsqueeze(self.D(input_D), 0))
        total_reward = torch.cat(total_reward, 0)  # seq_len * batch * 1
        # Objective J
        result_onehot = self.onehot(torch.unsqueeze(result, 2))
        policy = result_onehot * predictions
        policy = torch.unsqueeze(torch.sum(policy, 2),
                                 2)  # seq_len * batch * 1
        code, num_elements = self.generate_code(record=record)
        code = code.cuda()
        num_elements = num_elements.cuda()
        # J_temp: batch * 1
        J_temp = torch.sum(
            torch.log(torch.clamp(policy, min=1e-20, max=1.0)) * total_reward *
            code, 0) / num_elements
        J = -(torch.sum(J_temp) / self.batch_size)
        self.G_optimizer.zero_grad()
        J.backward()
        self.G_optimizer.step()
        return J.item()

    def backward_D(self, update=True, loss_f='LOG', is_epoch=False):
        """Train (or just evaluate) the discriminator on real vs generated text.

        :param update: when False, gradients are computed but the optimizer
            step is skipped
        :param loss_f: ``'LOG'`` (log loss) or ``'MSE'``
        :param is_epoch: when True iterate over the whole real dataset and
            return the mean loss; otherwise return after the first batch
        :return: loss value (float)
        :raises ValueError: if ``loss_f`` is neither ``'LOG'`` nor ``'MSE'``
        """
        if loss_f not in ('LOG', 'MSE'):
            raise ValueError(
                "loss_f must be 'LOG' or 'MSE', got {!r}".format(loss_f))
        total_loss = 0.0
        num_batches = 0
        mse = nn.MSELoss()
        for x_batch_pos in self.dataloader:
            num_batches += 1
            cur_batch = x_batch_pos.size()[0]
            self.D_optimizer.zero_grad()
            start_input, start_h = self._start_state(cur_batch)
            x_batch_neg, _, _, _ = self.generate_X(
                start_input=start_input,
                start_h=start_h,
                sequence_length=self.sequence_length)
            x_batch_neg = torch.transpose(x_batch_neg, dim0=0,
                                          dim1=1)  # batch * seq_len
            input_batch_pos = self.embeddings(x_batch_pos.cuda())
            input_batch_neg = self.embeddings(x_batch_neg)
            # Both inputs: batch * 1 * seq_len * embedding_size
            input_batch_pos = torch.unsqueeze(input_batch_pos, 1)
            input_batch_neg = torch.unsqueeze(input_batch_neg, 1)
            pre_pos = self.D(input=input_batch_pos)
            pre_neg = self.D(input=input_batch_neg)
            if loss_f == 'LOG':
                loss = -torch.sum(
                    torch.log(torch.clamp(pre_pos, min=1e-20, max=1.0)) +
                    torch.log(torch.clamp(
                        (1 - pre_neg), min=1e-20, max=1.0))) / (
                            2 * pre_pos.size()[0])
            else:  # 'MSE'
                loss = (mse(pre_pos,
                            torch.ones(cur_batch, 1).cuda()) +
                        mse(pre_neg,
                            torch.zeros(cur_batch, 1).cuda())) / 2.0
            # L2 regularisation on the discriminator output layer
            l2_loss = torch.tensor(0.).cuda()
            for param in self.D.output_layer.parameters():
                l2_loss += torch.norm(param, p=2)
            loss += self.l2_reg_lambda * l2_loss
            total_loss += loss.item()
            loss.backward()
            if update:
                self.D_optimizer.step()
            if not is_epoch:
                return total_loss  # train on a single batch only
        total_loss = total_loss / max(num_batches, 1)
        return total_loss
transforms.ToPILImage(), transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize( (mid_pixel_value,) * in_channels, (mid_pixel_value,) * in_channels ), ] ) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") train_dataloader = create_dagan_dataloader( raw_data, num_training_classes, train_transform, batch_size ) g_opt = optim.Adam(g.parameters(), lr=0.0001, betas=(0.0, 0.9)) d_opt = optim.Adam(d.parameters(), lr=0.0001, betas=(0.0, 0.9)) val_data = raw_data[num_training_classes : num_training_classes + num_val_classes] flat_val_data = val_data.reshape( (val_data.shape[0] * val_data.shape[1], *val_data.shape[2:]) ) display_transform = train_transform trainer = DaganTrainer( generator=g, discriminator=d, gen_optimizer=g_opt, dis_optimizer=d_opt, batch_size=batch_size, device=device,
class CycleGAN(AlignmentModel):
    """Alignment model built from two generators and two discriminators
    (cycle GAN).

    For a description of the implemented functions, refer to the alignment
    model base class.
    """

    def __init__(self, device, config, generator_a=None, generator_b=None,
                 discriminator_a=None, discriminator_b=None):
        """Create (or adopt) the four networks and one optimizer per network.

        Any network passed in is used as is; a missing one is built from
        ``config`` and moved to ``device``. The optimizer class is chosen by
        ``config['optimizer']``, then ``config['optimizer_default']``, and
        falls back to Adam (see ``_build_optimizer``).
        """
        super().__init__(device, config)
        self.epoch_losses = [0., 0., 0., 0.]

        # The two generators are configured with dim_1/dim_2 swapped.
        if generator_a is None:
            generator_a = self._build_generator(
                config, device, config['dim_b'], config['dim_a'])
        self.generator_a = generator_a
        self.optimizer_g_a = self._build_optimizer(self.generator_a, config)

        if generator_b is None:
            generator_b = self._build_generator(
                config, device, config['dim_a'], config['dim_b'])
        self.generator_b = generator_b
        self.optimizer_g_b = self._build_optimizer(self.generator_b, config)

        if discriminator_a is None:
            discriminator_a = self._build_discriminator(
                config, device, config['dim_a'])
        self.discriminator_a = discriminator_a
        self.optimizer_d_a = self._build_optimizer(
            self.discriminator_a, config)

        if discriminator_b is None:
            discriminator_b = self._build_discriminator(
                config, device, config['dim_b'])
        self.discriminator_b = discriminator_b
        self.optimizer_d_b = self._build_optimizer(
            self.discriminator_b, config)

    @staticmethod
    def _build_generator(config, device, dim_1, dim_2):
        """Build a Generator from the shared generator settings in ``config``."""
        generator_conf = dict(
            dim_1=dim_1,
            dim_2=dim_2,
            layer_number=config['generator_layers'],
            layer_expansion=config['generator_expansion'],
            initialize_generator=config['initialize_generator'],
            norm=config['gen_norm'],
            batch_norm=config['gen_batch_norm'],
            activation=config['gen_activation'],
            dropout=config['gen_dropout'])
        generator = Generator(generator_conf, device)
        generator.to(device)
        return generator

    @staticmethod
    def _build_discriminator(config, device, dim):
        """Build a Discriminator from the shared discriminator settings in ``config``."""
        discriminator_conf = dict(
            dim=dim,
            layer_number=config['discriminator_layers'],
            layer_expansion=config['discriminator_expansion'],
            batch_norm=config['disc_batch_norm'],
            activation=config['disc_activation'],
            dropout=config['disc_dropout'])
        discriminator = Discriminator(discriminator_conf, device)
        discriminator.to(device)
        return discriminator

    @staticmethod
    def _build_optimizer(model, config):
        """Create the optimizer for ``model`` according to ``config``.

        Precedence: an explicit ``'optimizer'`` entry (always given the
        configured learning rate), then ``'optimizer_default'`` (given the
        learning rate only for ``'sgd'``; any other default keeps the
        optimizer's own default), and finally Adam with the configured
        learning rate.
        """
        if 'optimizer' in config:
            return OPTIMIZERS[config['optimizer']](
                model.parameters(), config['learning_rate'])
        if 'optimizer_default' in config:
            default_name = config['optimizer_default']
            if default_name == 'sgd':
                return OPTIMIZERS[default_name](
                    model.parameters(), config['learning_rate'])
            return OPTIMIZERS[default_name](model.parameters())
        return torch.optim.Adam(model.parameters(), config['learning_rate'])

    def train(self):
        self.generator_a.train()
        self.generator_b.train()
        self.discriminator_a.train()
        self.discriminator_b.train()

    def eval(self):
        self.generator_a.eval()
        self.generator_b.eval()
        self.discriminator_a.eval()
        self.discriminator_b.eval()

    def zero_grad(self):
        self.optimizer_g_a.zero_grad()
        self.optimizer_g_b.zero_grad()
        self.optimizer_d_a.zero_grad()
        self.optimizer_d_b.zero_grad()

    def optimize_all(self):
        self.optimizer_g_a.step()
        self.optimizer_g_b.step()
        self.optimizer_d_a.step()
        self.optimizer_d_b.step()

    def optimize_generator(self):
        """Do the optimization step only for generators (e.g. when training
        generators and discriminators separately or in turns)."""
        self.optimizer_g_a.step()
        self.optimizer_g_b.step()

    def optimize_discriminator(self):
        """Do the optimization step only for discriminators (e.g. when training
        generators and discriminators separately or in turns)."""
        self.optimizer_d_a.step()
        self.optimizer_d_b.step()

    def change_lr(self, factor):
        """Scale ``self.current_lr`` by ``factor`` and apply it to both generator optimizers."""
        # NOTE(review): the discriminator optimizers keep their previous
        # learning rate here -- confirm that this is intentional.
        self.current_lr = self.current_lr * factor
        for optimizer in (self.optimizer_g_a, self.optimizer_g_b):
            for param_group in optimizer.param_groups:
                param_group['lr'] = self.current_lr

    def update_losses_batch(self, *losses):
        """Add one batch's (g_a, g_b, d_a, d_b) losses to the running epoch totals."""
        loss_g_a, loss_g_b, loss_d_a, loss_d_b = losses
        self.epoch_losses[0] += loss_g_a
        self.epoch_losses[1] += loss_g_b
        self.epoch_losses[2] += loss_d_a
        self.epoch_losses[3] += loss_d_b

    def complete_epoch(self, epoch_metrics):
        """Record the epoch's metrics (plus total loss) and losses, then reset the totals."""
        self.metrics.append(epoch_metrics + [sum(self.epoch_losses)])
        self.losses.append(self.epoch_losses)
        # Rebind rather than clear: the list just appended must stay intact.
        self.epoch_losses = [0., 0., 0., 0.]

    def print_epoch_info(self):
        print(
            f"{len(self.metrics)} ### {self.losses[-1][0]:.2f} - {self.losses[-1][1]:.2f} "
            f"- {self.losses[-1][2]:.2f} - {self.losses[-1][3]:.2f} ### {self.metrics[-1]}"
        )

    def copy_model(self):
        """Snapshot the state dicts of all four networks into ``self.model_copy``."""
        self.model_copy = (
            deepcopy(self.generator_a.state_dict()),
            deepcopy(self.generator_b.state_dict()),
            deepcopy(self.discriminator_a.state_dict()),
            deepcopy(self.discriminator_b.state_dict()),
        )

    def restore_model(self):
        """Load the snapshot taken by ``copy_model`` back into the four networks."""
        self.generator_a.load_state_dict(self.model_copy[0])
        self.generator_b.load_state_dict(self.model_copy[1])
        self.discriminator_a.load_state_dict(self.model_copy[2])
        self.discriminator_b.load_state_dict(self.model_copy[3])

    def export_model(self, test_results, description=None):
        """Export the networks, config and metrics, and save ``test_results``.

        When ``description`` is None a name is derived from the
        ``'evaluation'`` and ``'subset'`` entries of the config.
        """
        if description is None:
            description = f"CycleGAN_{self.config['evaluation']}_{self.config['subset']}"
        export_cyclegan_alignment(description, self.config, self.generator_a,
                                  self.generator_b, self.discriminator_a,
                                  self.discriminator_b, self.metrics)
        save_alignment_test_results(test_results, description)
        print(f"Saved model to directory {description}.")

    @classmethod
    def load_model(cls, name, device):
        """Rebuild a model from the networks and config stored under ``name``."""
        generator_a, generator_b, discriminator_a, discriminator_b, config = load_cyclegan_alignment(
            name, device)
        model = cls(device, config, generator_a, generator_b,
                    discriminator_a, discriminator_b)
        return model
def train_gan():
    """Train the discriminator/generator pair and save checkpoints.

    Each outer iteration runs ``param.num_steps`` discriminator updates on
    consecutive slices of the real data (wrapping around at the end), followed
    by one generator update. Generator checkpoints and 500 sampled sequences
    are written whenever the generator loss improves, and otherwise every
    tenth iteration into the ``tens`` sub-folder.

    Returns:
        Path of the final generator checkpoint (``generator.pt``).
    """
    print("GAN training to start on", device)
    print("Now loading data . . .")
    train_smiles_data = load_real_data()
    print("Data loaded")

    # torch.save() and open() do not create missing directories.
    tens_folder = os.path.join(base_model_path, "tens")
    os.makedirs(base_model_path, exist_ok=True)
    os.makedirs(tens_folder, exist_ok=True)

    discriminator = Discriminator(vocab).to(device)
    generator = Generator(vocab).to(device)
    d_optimizer = optim.Adam(discriminator.parameters(), lr=param.lr_d)
    g_optimizer = optim.Adam(generator.parameters(), lr=param.lr_g)

    num_epochs = param.num_epochs
    batch_size = param.batch_size
    num_steps = param.num_steps
    dataset_size = train_smiles_data.sequence_list_size
    print("dataset_size =", dataset_size)

    # *************
    # THE LOGIC
    # *************
    # for num_iter
    #   1. for num_steps
    #      (a) train discriminator with samples = batch_size
    #   2. Train generator with samples = batch_size
    #
    # How to determine num_iterations?
    #   num_iter * num_steps * batch_size = dataset_size * num_epochs
    #
    # *************
    num_iterations = math.ceil(
        (dataset_size * num_epochs) / (num_steps * batch_size))
    print(f"num_iterations: {num_iterations}")

    saved_generator_loss = 10000000.0
    start_index = 0
    total_row_count = dataset_size
    print("Iter, d_error_real, d_error_fake, g_error_fake")

    accuracy_dr = 0
    accuracy_df = 0
    # Turned off for good once the discriminator exceeds 90% on both real and
    # fake data; the value is forwarded to train_discriminator.
    apply_grads = True

    for iteration in range(num_iterations):
        d_error_real, d_error_fake = 0, 0
        g_error_fake = 0

        for _k in range(num_steps):
            if start_index >= total_row_count:
                start_index = 0
            # Clamp the slice to the dataset; the last slice may be short.
            end_index = min(start_index + batch_size, total_row_count)

            real_data = train_smiles_data.sequence_tensors[
                start_index:end_index, :, :]
            real_data_l = train_smiles_data.sequence_length_data[
                start_index:end_index]

            N = end_index - start_index
            fake_data, fake_data_l = generator(N)

            real_data = real_data.to(device)
            real_data_l = real_data_l.to(device)
            # Detach so the discriminator update does not reach the generator.
            fake_data = fake_data.detach().to(device)
            fake_data_l = fake_data_l.to(device)

            if apply_grads and (accuracy_dr > 90) and (accuracy_df > 90):
                apply_grads = False
                filepath = os.path.join(
                    base_model_path,
                    "discriminator_" + str(iteration) + ".pt")
                torch.save(discriminator, filepath)

            error_real, accuracy_dr = train_discriminator(
                discriminator, d_optimizer, real_data, real_data_l, 'real',
                apply_grads)
            error_fake, accuracy_df = train_discriminator(
                discriminator, d_optimizer, fake_data, fake_data_l, 'fake',
                apply_grads)

            d_error_real += error_real.mean()
            d_error_fake += error_fake.mean()
            start_index = end_index

        d_error_real = d_error_real / num_steps
        d_error_fake = d_error_fake / num_steps

        N = batch_size
        fake_data, fake_data_l = generator(N)
        fake_data = fake_data.to(device)
        fake_data_l = fake_data_l.to(device)
        g_error, accuracy_g = train_generator(discriminator, generator,
                                              g_optimizer, fake_data,
                                              fake_data_l)
        g_error_fake += g_error.mean()

        print(
            f" Accuracy Numbers: D = {accuracy_dr}, {accuracy_df}, G = {accuracy_g}"
        )

        improved = g_error_fake < saved_generator_loss
        if improved or (iteration % 10 == 0):
            if improved:
                saved_generator_loss = g_error_fake
                pt_filepath = os.path.join(base_model_path, "generator.pt")
                gen_txt_filepath = os.path.join(base_model_path,
                                                "generatod_samples.txt")
            else:
                pt_filepath = os.path.join(
                    tens_folder, "generator_" + str(iteration) + ".pt")
                gen_txt_filepath = os.path.join(
                    tens_folder,
                    "generatod_samples" + str(iteration) + ".txt")

            torch.save(generator, pt_filepath)
            with torch.no_grad():
                y, lengths = generator(500)
                generator_results = generator.get_sequences_from_tensor(
                    y, lengths)
            with open(gen_txt_filepath, "w") as outfile:
                outfile.write(f"Iter: {iteration}, Loss: {g_error_fake}\n")
                outfile.write("\n".join(generator_results))

        print(
            f"{iteration}, {d_error_real.item()}, {d_error_fake.item()}, {g_error_fake.item()}"
        )
        print(
            "--------------------------------------------------------------------------"
        )

    filepath = os.path.join(base_model_path, "discriminator.pt")
    torch.save(discriminator, filepath)
    filepath = os.path.join(base_model_path, "generator.pt")
    torch.save(generator, filepath)
    return filepath
def _main():
    """Train the feature-conditioned generator/discriminator pair.

    Reads its settings from the module-level ``args``. Images are loaded from
    ``args.train_path``; a classifier loaded from ``args.classifier_path``
    provides the feature maps handed to the training steps. Losses are logged
    to TensorBoard, and checkpoints (plus optional sample grids) are written
    under ``wgan-gp_models/<model_name>``.
    """
    print_gpu_details()
    device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
    train_root = args.train_path

    image_size = 256
    cropped_image_size = 256
    print("set image folder")
    train_set = dset.ImageFolder(root=train_root,
                                 transform=transforms.Compose([
                                     transforms.Resize(image_size),
                                     transforms.CenterCrop(cropped_image_size),
                                     transforms.ToTensor()
                                 ]))

    # The classifier and the discriminator each get their own normalization.
    normalizer_clf = transforms.Compose([
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    normalizer_discriminator = transforms.Compose([
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])

    print('set data loader')
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4,
                                               pin_memory=True,
                                               drop_last=True)

    # Network creation
    # NOTE: torch.load unpickles a whole module here; only load trusted files.
    classifier = torch.load(args.classifier_path)
    classifier.eval()
    generator = Generator(gen_type=args.gen_type)
    discriminator = Discriminator(args.discriminator_norm,
                                  dis_type=args.gen_type)

    # init weights
    if args.generator_path is not None:
        generator.load_state_dict(torch.load(args.generator_path))
    else:
        generator.init_weights()
    if args.discriminator_path is not None:
        discriminator.load_state_dict(torch.load(args.discriminator_path))
    else:
        discriminator.init_weights()

    classifier.to(device)
    generator.to(device)
    discriminator.to(device)

    # losses + optimizers
    criterion_discriminator, criterion_generator = get_wgan_losses_fn()
    criterion_features = nn.L1Loss()
    criterion_diversity_n = nn.L1Loss()
    criterion_diversity_d = nn.L1Loss()
    generator_optimizer = optim.Adam(generator.parameters(), lr=args.lr,
                                     betas=(0.5, 0.999))
    discriminator_optimizer = optim.Adam(discriminator.parameters(),
                                         lr=args.lr, betas=(0.5, 0.999))

    num_of_epochs = args.epochs
    starting_time = time.time()
    iterations = 0

    # creating dirs for keeping models checkpoint, temp created images, and
    # loss summary
    outputs_dir = os.path.join('wgan-gp_models', args.model_name)
    temp_results_dir = os.path.join(outputs_dir, 'temp_results')
    models_dir = os.path.join(outputs_dir, 'models_checkpoint')
    for directory in (outputs_dir, temp_results_dir, models_dir):
        os.makedirs(directory, exist_ok=True)
    writer = tensorboardX.SummaryWriter(os.path.join(outputs_dir, 'summaries'))

    # a fixed noise for sampling
    z = torch.randn(args.batch_size, 128, 1, 1).to(device)
    # a fixed noise for diversity sampling
    z2 = torch.randn(args.batch_size, 128, 1, 1).to(device)
    fixed_features = 0
    fixed_masks = 0
    fixed_features_diversity = 0
    first_iter = True

    print("Starting Training Loop...")
    for epoch in range(num_of_epochs):
        for data in train_loader:
            # Choose the train type. random.choices returns a *list*, so take
            # its single element; comparing the list itself to 1 is always
            # False and would silently disable train type 1.
            train_type = random.choices(
                [1, 2], [args.train1_prob, 1 - args.train1_prob])[0]
            iterations += 1
            log_now = iterations % 30 == 1
            if log_now:
                print('epoch:', epoch, ', iter', iterations, 'start, time =',
                      time.time() - starting_time, 'seconds')
                starting_time = time.time()

            images, _ = data
            images = images.to(device)  # change to gpu tensor
            images_discriminator = normalizer_discriminator(images)
            images_clf = normalizer_clf(images)
            _, features = classifier(images_clf)
            num_layers = len(features)

            if first_iter:
                # save batch of images to keep track of the model process
                first_iter = False
                fixed_features = [torch.clone(features[x])
                                  for x in range(num_layers)]
                fixed_masks = [torch.ones(features[x].shape, device=device)
                               for x in range(num_layers)]
                fixed_features_diversity = [torch.clone(features[x])
                                            for x in range(num_layers)]
                # Repeat the first 8 samples across the whole batch.
                for layer_features in fixed_features_diversity:
                    for j in range(layer_features.shape[0]):
                        layer_features[j] = layer_features[j % 8]
                grid = vutils.make_grid(images_discriminator, padding=2,
                                        normalize=True, nrow=8)
                vutils.save_image(
                    grid,
                    os.path.join(temp_results_dir, 'original_images.jpg'))
                orig_images_diversity = torch.clone(images_discriminator)
                for i in range(orig_images_diversity.shape[0]):
                    orig_images_diversity[i] = orig_images_diversity[i % 8]
                grid = vutils.make_grid(orig_images_diversity, padding=2,
                                        normalize=True, nrow=8)
                vutils.save_image(
                    grid,
                    os.path.join(temp_results_dir,
                                 'original_images_diversity.jpg'))

            # Select a features layer to train on
            if args.fixed_layer is None:
                features_to_train = random.randint(1, num_layers - 2)
            else:
                features_to_train = args.fixed_layer

            # Set masks
            masks = [features[i].clone() for i in range(num_layers)]
            if train_type == 1:
                setMasksPart1(masks, device, features_to_train)
            else:
                setMasksPart2(masks, device, features_to_train)

            discriminator_loss_dict = train_discriminator(
                generator, discriminator, criterion_discriminator,
                discriminator_optimizer, images_discriminator, features,
                masks)
            for k, v in discriminator_loss_dict.items():
                writer.add_scalar('D/%s' % k, v.data.cpu().numpy(),
                                  global_step=iterations)
                if log_now:
                    print('{}: {:.6f}'.format(k, v))

            # One generator step per ``discriminator_steps`` iterations. The
            # remainder is compared with ``1 % steps`` so that steps == 1
            # trains the generator on every iteration (``x % 1 == 1`` is never
            # true); for steps > 1 this is the same as ``== 1``.
            generator_turn = (iterations % args.discriminator_steps
                              == 1 % args.discriminator_steps)
            if generator_turn:
                generator_loss_dict = train_generator(
                    generator, discriminator, criterion_generator,
                    generator_optimizer, images.shape[0], features,
                    criterion_features, features_to_train, classifier,
                    normalizer_clf, criterion_diversity_n,
                    criterion_diversity_d, masks, train_type)
                for k, v in generator_loss_dict.items():
                    # NOTE(review): the divisor 5 is hard-coded; presumably it
                    # should follow args.discriminator_steps -- confirm.
                    writer.add_scalar('G/%s' % k, v.data.cpu().numpy(),
                                      global_step=iterations // 5 + 1)
                    if log_now:
                        print('{}: {:.6f}'.format(k, v))

            # Save generator and discriminator weights every 1000 iterations
            if iterations % 1000 == 1:
                torch.save(generator.state_dict(),
                           models_dir + '/' + args.model_name + 'G')
                torch.save(discriminator.state_dict(),
                           models_dir + '/' + args.model_name + 'D')

            # Save temp results
            if args.keep_temp_results:
                if ((iterations < 10000 and iterations % 1000 == 1)
                        or iterations % 2000 == 1):
                    # regular sampling (batch of different images)
                    regular_samples = []
                    diversity_samples = []
                    for i in range(1, 5):
                        one_layer_mask = isolate_layer(fixed_masks, i, device)
                        regular_samples.append(
                            sample(generator, z, fixed_features,
                                   one_layer_mask))
                        # NOTE(review): the first layer's diversity sample
                        # uses ``z`` while the later ones use ``z2``; kept as
                        # is -- confirm whether ``z2`` was intended throughout.
                        diversity_noise = z if i == 1 else z2
                        diversity_samples.append(
                            sample(generator, diversity_noise,
                                   fixed_features_diversity, one_layer_mask))
                    fake_images = torch.vstack(regular_samples)
                    fake_images_diversity = torch.vstack(diversity_samples)

                    grid = vutils.make_grid(fake_images, padding=2,
                                            normalize=True, nrow=8)
                    vutils.save_image(
                        grid,
                        os.path.join(
                            temp_results_dir,
                            'res_iter_{}.jpg'.format(iterations // 1000)))

                    # diversity sampling (8 different images each with few
                    # different noises)
                    grid = vutils.make_grid(fake_images_diversity, padding=2,
                                            normalize=True, nrow=8)
                    vutils.save_image(
                        grid,
                        os.path.join(
                            temp_results_dir,
                            'div_iter_{}.jpg'.format(iterations // 1000)))

            if iterations % 20000 == 1:
                # NOTE(review): the file suffix divides by 15000 although the
                # save period is 20000; kept so existing file names stay the
                # same.
                torch.save(
                    generator.state_dict(),
                    models_dir + '/' + args.model_name + 'G_' +
                    str(iterations // 15000))
                torch.save(
                    discriminator.state_dict(),
                    models_dir + '/' + args.model_name + 'D_' +
                    str(iterations // 15000))

    # Flush pending summaries to disk.
    writer.close()
zeros_label = Variable(torch.zeros(BATCH_SIZE))

if __name__ == "__main__":
    # Call form works on both Python 2 and 3; the bare ``print 'main'``
    # statement is a syntax error on Python 3.
    print('main')

    gen_model = Tiramisu()
    disc_model = Discriminator()

    if is_gpu_mode:
        gen_model.cuda()
        disc_model.cuda()
        # gen_model = torch.nn.DataParallel(gen_model).cuda()
        # disc_model = torch.nn.DataParallel(disc_model).cuda()

    # Generator and discriminator have separate optimizers and learning rates.
    optimizer_gen = torch.optim.Adam(gen_model.parameters(),
                                     lr=LEARNING_RATE_GENERATOR)
    optimizer_disc = torch.optim.Adam(disc_model.parameters(),
                                      lr=LEARNING_RATE_DISCRIMINATOR)

    # read imgs
    image_buff_read_index = 0

    # pytorch style: channel axis second (3 channels for images, 1 for the
    # motion-vector maps)
    input_img = np.empty(shape=(BATCH_SIZE, 3,
                                data_loader.INPUT_IMAGE_WIDTH,
                                data_loader.INPUT_IMAGE_HEIGHT))
    answer_img = np.empty(shape=(BATCH_SIZE, 3,
                                 data_loader.INPUT_IMAGE_WIDTH,
                                 data_loader.INPUT_IMAGE_HEIGHT))
    motion_vec_img = np.empty(shape=(BATCH_SIZE, 1,
                                     data_loader.INPUT_IMAGE_WIDTH,
                                     data_loader.INPUT_IMAGE_HEIGHT))
    fake_motion_vec_img = np.empty(shape=(BATCH_SIZE, 1,
                                          data_loader.INPUT_IMAGE_WIDTH,
                                          data_loader.INPUT_IMAGE_HEIGHT))

    # opencv style: channel axis last
    output_img_opencv = np.empty(shape=(data_loader.INPUT_IMAGE_WIDTH,
                                        data_loader.INPUT_IMAGE_HEIGHT, 3))
# Print the model print(netD) # Initialize BCELoss function criterion = nn.BCELoss() # Create batch of latent vectors that we will use to visualize # the progression of the generator fixed_noise = torch.randn(64, cf.nz, 1, 1, device=device) # Establish convention for real and fake labels during training real_label = 1 fake_label = 0 # Setup Adam optimizers for both G and D optimizerD = optim.Adam(netD.parameters(), lr=cf.lr, betas=(cf.beta1, 0.999)) optimizerG = optim.Adam(netG.parameters(), lr=cf.lr, betas=(cf.beta1, 0.999)) # Lists to keep track of progress img_list = [] G_losses = [] D_losses = [] def train(): # Training Loop iters = 0 print("Starting Training Loop...") # For each epoch for epoch in range(cf.num_epochs):
m.weight.data.normal_(0.0, 0.02) elif classname.find('BatchNorm') != -1: m.weight.data.normal_(1.0, 0.02) m.bias.data.fill_(0) G = Generator().to(device) G.apply(weights_init) D = Discriminator().to(device) D.apply(weights_init) # Training the DCGANs criterion = nn.BCELoss() optimizerD = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999)) optimizerG = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999)) Dis_loss = [] gen_loss = [] for epoch in range(25): print("***************Epoch is *******************", epoch + 1) for i, data in enumerate(dataloader, 0): D.zero_grad() real, _ = data input = Variable(real).to(device) target = Variable(torch.ones(input.size()[0])).to(device) output = D(input) output = output.to(device) Derr_real = criterion(output, target) z = Variable(torch.randn(input.size()[0], 100, 1, 1)).to(device)
def train(config):
    """Train a two-generator / two-discriminator translation model.

    Builds two UNet generators (A->B and B->A) and two discriminators,
    optionally resumes them through ``load_if_exsists``, and trains for
    ``config.train.epochs`` epochs. Each loss term is enabled by its weight
    (``lambda_idt``, ``lambda_C``, ``lambda_D``) being positive. When
    ``config.validate`` is set, the test loader is evaluated after every
    epoch. A checkpoint is written to ``<config.name>/model.pth`` after every
    epoch, whether or not validation ran.
    """
    genAB = UNet(3, 3, bilinear=config.model.bilinear_upsample).cuda()
    init_weights(genAB, 'normal')
    genBA = UNet(3, 3, bilinear=config.model.bilinear_upsample).cuda()
    init_weights(genBA, 'normal')
    discrA = Discriminator(3).cuda()
    init_weights(discrA, 'normal')
    discrB = Discriminator(3).cuda()
    init_weights(discrB, 'normal')

    writer = SummaryWriter(config.name)
    data_train, data_test = datasets_by_name(config.dataset.name,
                                             config.dataset)
    train_dataloader = DataLoader(data_train, batch_size=config.bs,
                                  shuffle=True,
                                  num_workers=config.num_workers)
    test_dataloader = DataLoader(data_test, batch_size=config.bs,
                                 shuffle=True,
                                 num_workers=config.num_workers)

    idt_loss = nn.L1Loss()
    cycle_consistency = nn.L1Loss()
    l2_loss = nn.MSELoss()
    discriminator_loss = nn.BCELoss()
    lambda_idt, lambda_C, lambda_D = (config.loss.lambda_idt,
                                      config.loss.lambda_C,
                                      config.loss.lambda_D)

    # One optimizer for both generators, one for both discriminators.
    optG = torch.optim.Adam(
        itertools.chain(genAB.parameters(), genBA.parameters()),
        lr=config.train.lr, betas=(config.train.beta1, 0.999))
    optD = torch.optim.Adam(
        itertools.chain(discrA.parameters(), discrB.parameters()),
        lr=config.train.lr, betas=(config.train.beta1, 0.999))

    genAB, genBA, discrA, discrB, optG, optD, start_epoch = load_if_exsists(
        config, genAB, genBA, discrA, discrB, optG, optD)

    for epoch in range(start_epoch, config.train.epochs):
        # ----------------------------- training -----------------------------
        set_train([genAB, genBA, discrA, discrB])
        set_requires_grad([genAB, genBA, discrA, discrB], True)
        for i, (batch_A, batch_B) in enumerate(tqdm(train_dataloader)):
            global_step = len(train_dataloader) * epoch + i
            batch_A, batch_B = batch_A.cuda(), batch_B.cuda()

            optG.zero_grad()
            loss_G, loss_D = 0, 0
            fake_B = genAB(batch_A)
            cycle_A = genBA(fake_B)
            fake_A = genBA(batch_B)
            cycle_B = genAB(fake_A)

            if lambda_idt > 0:
                loss_G += idt_loss(fake_B, batch_B) * lambda_idt
                loss_G += idt_loss(fake_A, batch_A) * lambda_idt
            if lambda_C > 0:
                loss_G += cycle_consistency(cycle_A, batch_A) * lambda_C
                loss_G += cycle_consistency(cycle_B, batch_B) * lambda_C
            if lambda_D > 0:
                # Freeze the discriminators while they score the generators.
                set_requires_grad([discrA, discrB], False)
                discr_feedbackA = discrA(fake_A)
                discr_feedbackB = discrB(fake_B)
                loss_G += discriminator_loss(
                    discr_feedbackA,
                    torch.ones_like(discr_feedbackA)) * lambda_D
                loss_G += discriminator_loss(
                    discr_feedbackB,
                    torch.ones_like(discr_feedbackB)) * lambda_D

            loss_G.backward()
            torch.nn.utils.clip_grad_norm_(
                itertools.chain(genAB.parameters(), genBA.parameters()), 15)
            optG.step()

            if lambda_D > 0:
                set_requires_grad([discrA, discrB], True)
                loss_D_fake, loss_D_true = 0, 0

                # Discriminator step on generated images (detached so no
                # gradient reaches the generators).
                optD.zero_grad()
                logits = discrA(fake_A.detach())
                loss_D_fake += discriminator_loss(logits,
                                                  torch.zeros_like(logits))
                logits = discrB(fake_B.detach())
                loss_D_fake += discriminator_loss(logits,
                                                  torch.zeros_like(logits))
                loss_D_fake.backward()
                torch.nn.utils.clip_grad_norm_(
                    itertools.chain(discrA.parameters(),
                                    discrB.parameters()), 15)
                optD.step()

                # Separate discriminator step on real images.
                optD.zero_grad()
                logits = discrA(batch_A)
                loss_D_true += discriminator_loss(logits,
                                                  torch.ones_like(logits))
                logits = discrB(batch_B)
                loss_D_true += discriminator_loss(logits,
                                                  torch.ones_like(logits))
                loss_D_true.backward()
                torch.nn.utils.clip_grad_norm_(
                    itertools.chain(discrA.parameters(),
                                    discrB.parameters()), 15)
                optD.step()
                loss_D = loss_D_fake + loss_D_true

            if i % config.train.verbose_period == 0:
                writer.add_scalar('train/loss_G', loss_G.item(), global_step)
                writer.add_scalar('train/pixel_error_A',
                                  l2_loss(fake_A, batch_A).mean().item(),
                                  global_step)
                writer.add_scalar('train/pixel_error_B',
                                  l2_loss(fake_B, batch_B).mean().item(),
                                  global_step)
                if lambda_D > 0:
                    writer.add_scalar('train/loss_D', loss_D.item(),
                                      global_step)
                    writer.add_scalar('train/mean_D_A',
                                      discr_feedbackA.mean().item(),
                                      global_step)
                    writer.add_scalar('train/mean_D_B',
                                      discr_feedbackB.mean().item(),
                                      global_step)
                # Generated image next to the other domain's batch image,
                # shifted and halved ((x + 1) / 2) before logging.
                for batch_i in range(fake_A.shape[0]):
                    concat = (torch.cat([fake_A[batch_i], batch_B[batch_i]],
                                        dim=-1) + 1.) / 2.
                    writer.add_image('train/fake_A_' + str(batch_i), concat,
                                     global_step)
                for batch_i in range(fake_B.shape[0]):
                    concat = (torch.cat([fake_B[batch_i], batch_A[batch_i]],
                                        dim=-1) + 1.) / 2.
                    writer.add_image('train/fake_B_' + str(batch_i), concat,
                                     global_step)

        # ---------------------------- validation ----------------------------
        # Guarded with ``if`` rather than ``continue`` so that the checkpoint
        # below is still written when validation is disabled.
        if config.validate:
            set_eval([genAB, genBA, discrA, discrB])
            set_requires_grad([genAB, genBA, discrA, discrB], False)
            loss_G, loss_D = 0, 0
            discr_feedbackA_mean, discr_feedbackB_mean = 0, 0
            pixel_error_A, pixel_error_B = 0, 0
            for i, (batch_A, batch_B) in enumerate(tqdm(test_dataloader)):
                batch_A, batch_B = batch_A.cuda(), batch_B.cuda()
                fake_B = genAB(batch_A)
                cycle_A = genBA(fake_B)
                fake_A = genBA(batch_B)
                cycle_B = genAB(fake_A)

                pixel_error_A += l2_loss(fake_A, batch_A).mean()
                pixel_error_B += l2_loss(fake_B, batch_B).mean()

                if lambda_idt > 0:
                    loss_G += idt_loss(fake_B, batch_B) * lambda_idt
                    loss_G += idt_loss(fake_A, batch_A) * lambda_idt
                if lambda_C > 0:
                    loss_G += cycle_consistency(cycle_A, batch_A) * lambda_C
                    loss_G += cycle_consistency(cycle_B, batch_B) * lambda_C
                if lambda_D > 0:
                    discr_feedbackA = discrA(fake_A)
                    discr_feedbackB = discrB(fake_B)
                    loss_G += discriminator_loss(
                        discr_feedbackA,
                        torch.ones_like(discr_feedbackA)) * lambda_D
                    loss_G += discriminator_loss(
                        discr_feedbackB,
                        torch.ones_like(discr_feedbackB)) * lambda_D
                    discr_feedbackA_mean += discr_feedbackA.mean()
                    discr_feedbackB_mean += discr_feedbackB.mean()

                    loss_D_fake, loss_D_true = 0, 0
                    logits = discrA(fake_A.detach())
                    loss_D_fake += discriminator_loss(
                        logits, torch.zeros_like(logits))
                    logits = discrB(fake_B.detach())
                    loss_D_fake += discriminator_loss(
                        logits, torch.zeros_like(logits))
                    logits = discrA(batch_A)
                    loss_D_true += discriminator_loss(
                        logits, torch.ones_like(logits))
                    logits = discrB(batch_B)
                    loss_D_true += discriminator_loss(
                        logits, torch.ones_like(logits))
                    loss_D += loss_D_fake + loss_D_true

                if i == 0:
                    # Log example images for the first validation batch only.
                    for batch_i in range(fake_A.shape[0]):
                        concat = (torch.cat(
                            [fake_A[batch_i], batch_B[batch_i]],
                            dim=-1) + 1.) / 2.
                        writer.add_image('val/fake_A_' + str(batch_i),
                                         concat, epoch)
                    for batch_i in range(fake_B.shape[0]):
                        concat = (torch.cat(
                            [fake_B[batch_i], batch_A[batch_i]],
                            dim=-1) + 1.) / 2.
                        writer.add_image('val/fake_B_' + str(batch_i),
                                         concat, epoch)

            num_val_batches = len(test_dataloader)
            loss_G /= num_val_batches
            pixel_error_A /= num_val_batches
            pixel_error_B /= num_val_batches
            writer.add_scalar('val/loss_G', loss_G.item(), epoch)
            writer.add_scalar('val/pixel_error_A', pixel_error_A.item(),
                              epoch)
            writer.add_scalar('val/pixel_error_B', pixel_error_B.item(),
                              epoch)
            if lambda_D > 0:
                loss_D /= num_val_batches
                discr_feedbackA_mean /= num_val_batches
                discr_feedbackB_mean /= num_val_batches
                writer.add_scalar('val/loss_D', loss_D.item(), epoch)
                writer.add_scalar('val/mean_D_A',
                                  discr_feedbackA_mean.item(), epoch)
                writer.add_scalar('val/mean_D_B',
                                  discr_feedbackB_mean.item(), epoch)

        # ---------------------------- checkpoint ----------------------------
        torch.save(
            {
                'genAB': genAB.state_dict(),
                'genBA': genBA.state_dict(),
                'discrA': discrA.state_dict(),
                'discrB': discrB.state_dict(),
                'optG': optG.state_dict(),
                'optD': optD.state_dict(),
                'epoch': epoch
            }, os.path.join(config.name, 'model.pth'))
class SGAN:
    """Trainer coupling one global GAN with ``cfg.train.N_pairs`` local pairs.

    The instance owns:

    * ``G_global`` / ``D_global`` -- the global generator and discriminator;
    * ``G_pairs[i]`` / ``D_pairs[i]`` -- local generator/discriminator pairs;
    * ``D_msg_pairs[i]`` -- extra discriminators that receive a copy of the
      parameters of ``D_pairs[i]`` and are used as the adversary of
      ``G_global``.

    Every network is invoked as ``net(cfg.dataset.dataset_name, data, labels)``
    and all hyper-parameters are read from the module-level ``cfg`` object.
    """

    def __init__(self):
        """Load the data, dump the configuration to disk and build all networks."""
        self.read_dataset()
        os.makedirs(cfg.train.run_directory, exist_ok=True)
        # os.path.join keeps the file inside the run directory even when the
        # configured path has no trailing separator.
        params_path = os.path.join(cfg.train.run_directory, 'params.txt')
        with open(params_path, 'w') as f:
            f.write(str(vars(cfg)))
        self.build_model()

    def read_dataset(self):
        """Create ``self.train_loader`` and ``self.valid_loader`` from ``cfg``."""
        self.train_loader, self.valid_loader = get_train_valid_loader(
            data_dir=cfg.dataset.data_dir,
            dataset_type=cfg.dataset.dataset_name,
            train_batch_size=cfg.train.batch_size,
            valid_batch_size=cfg.validation.batch_size,
            # Augmentation is switched off for MNIST only.
            augment=cfg.dataset.dataset_name != 'mnist',
            random_seed=cfg.dataset.seed,
            valid_size=cfg.train.valid_part,
            shuffle=True,
            show_sample=False,
            num_workers=multiprocessing.cpu_count(),
            pin_memory=False)

    def real_data_target(self, size):
        """Return a ``(size, 1)`` tensor of ones (on the GPU when one is available)."""
        target = torch.ones(size, 1)
        if torch.cuda.is_available():
            return target.cuda()
        return target

    def fake_data_target(self, size):
        """Return a ``(size, 1)`` tensor of zeros (on the GPU when one is available)."""
        target = torch.zeros(size, 1)
        if torch.cuda.is_available():
            return target.cuda()
        return target

    def train_discriminator(self, discriminator, optimizer, real_data,
                            fake_data, labels):
        """Run one optimisation step of ``discriminator``.

        The combined loss is assembled first and back-propagated once, so the
        gradients always match the returned ``D_loss``:

        * ``cfg.VANILLA``: ``D_loss_real + D_loss_fake``;
        * ``cfg.WGAN``: ``D_loss_fake - D_loss_real``, plus the gradient
          penalty when ``cfg.train.use_GP`` is set.

        Back-propagating the two partial losses separately (as was done
        before) is only equivalent for the vanilla loss; for WGAN it pushed
        the score of real data in the wrong direction and left the gradient
        penalty out of the update entirely.

        Args:
            discriminator: network called as ``discriminator(name, data, labels)``.
            optimizer: optimizer holding the discriminator's parameters.
            real_data: batch of real samples.
            fake_data: batch of generated samples; callers pass it detached.
            labels: conditioning labels forwarded to the discriminator.

        Returns:
            Tuple ``(D_real, D_fake, D_loss, D_loss_real, D_loss_fake)``.

        Raises:
            ValueError: if ``cfg.train.loss_type`` is not a supported value.
        """
        dataset_name = cfg.dataset.dataset_name
        optimizer.zero_grad()

        # Score real and generated samples.
        D_real = discriminator(dataset_name, real_data, labels)
        D_loss_real = self.loss(D_real,
                                self.real_data_target(real_data.size(0)))
        D_fake = discriminator(dataset_name, fake_data, labels)
        D_loss_fake = self.loss(D_fake,
                                self.fake_data_target(fake_data.size(0)))

        if cfg.train.loss_type == cfg.VANILLA:
            D_loss = D_loss_real + D_loss_fake
        elif cfg.train.loss_type == cfg.WGAN:
            D_loss = D_loss_fake - D_loss_real
            if cfg.train.use_GP:
                # Assumes gradient_penalty returns a differentiable penalty
                # term as its first element -- TODO confirm against helper.
                grad_penalty, _ = gradient_penalty(
                    discriminator, real_data, fake_data, cfg.train.gp_weight,
                    labels, dataset_name)
                D_loss = D_loss + grad_penalty
        else:
            raise ValueError(
                'Unsupported loss type: %r' % (cfg.train.loss_type,))

        D_loss.backward()
        optimizer.step()
        return D_real, D_fake, D_loss, D_loss_real, D_loss_fake

    def train_generator(self, generator, discriminator, optimizer, z_noise,
                        labels):
        """Run one optimisation step of ``generator`` against ``discriminator``.

        The generator is scored against the "real" target; for the WGAN loss
        the sign is flipped so that the discriminator output is maximised.

        Returns:
            Tuple ``(G_fake_data, G_loss)``.
        """
        dataset_name = cfg.dataset.dataset_name
        optimizer.zero_grad()

        G_fake_data = generator(dataset_name, z_noise, labels)
        D_fake = discriminator(dataset_name, G_fake_data, labels)

        G_loss = self.loss(D_fake, self.real_data_target(D_fake.size(0)))
        if cfg.train.loss_type == cfg.WGAN:
            G_loss = -1 * G_loss
        G_loss.backward()

        optimizer.step()
        return G_fake_data, G_loss

    @staticmethod
    def _new_network(network_cls):
        """Instantiate ``network_cls`` (on the GPU if available) with its Adam optimizer."""
        network = network_cls(cfg.dataset.dataset_name)
        if torch.cuda.is_available():
            network.cuda()
        optimizer = Adam(network.parameters(),
                         lr=cfg.train.learning_rate,
                         betas=(cfg.train.beta1, 0.999))
        return network, optimizer

    def build_model(self):
        """Create the loss, every network, every optimizer and the logger.

        Raises:
            ValueError: if ``cfg.train.loss_type`` is not a supported value.
        """
        if cfg.train.loss_type == cfg.VANILLA:
            self.loss = nn.BCELoss()
        elif cfg.train.loss_type == cfg.WGAN:
            # The target argument is ignored: the loss is the mean score.
            self.loss = lambda logits, labels: torch.mean(logits)
        else:
            raise ValueError(
                'Unsupported loss type: %r' % (cfg.train.loss_type,))

        # Networks are created in a fixed order (D before G) so that seeded
        # weight initialisation stays reproducible.
        self.D_global, self.D_global_optimizer = self._new_network(Discriminator)
        self.G_global, self.G_global_optimizer = self._new_network(Generator)

        self.D_pairs = []
        self.G_pairs = []
        self.D_pairs_optimizers = []
        self.G_pairs_optimizers = []
        self.D_msg_pairs = []
        self.D_msg_pairs_optimizers = []

        for _ in range(cfg.train.N_pairs):
            discriminator, D_optimizer = self._new_network(Discriminator)
            generator, G_optimizer = self._new_network(Generator)
            self.D_pairs.append(discriminator)
            self.G_pairs.append(generator)
            self.D_pairs_optimizers.append(D_optimizer)
            self.G_pairs_optimizers.append(G_optimizer)

            # "msg" discriminator of this pair, used to train G_global.
            msg_discriminator, msg_optimizer = self._new_network(Discriminator)
            self.D_msg_pairs.append(msg_discriminator)
            self.D_msg_pairs_optimizers.append(msg_optimizer)

        self.logger = Logger(model_name='DCGAN',
                             data_name='MNIST',
                             logdir=cfg.validation.validation_dir)

    @staticmethod
    def _sample_noise(batch_size):
        """Return a float32 ``(batch_size, cfg.train.z_dim)`` tensor drawn uniformly from [-1, 1)."""
        noise = np.random.uniform(-1, 1, [batch_size, cfg.train.z_dim])
        return torch.from_numpy(noise.astype(np.float32))

    def run_validation(self, generator, discriminator, epoch, i, type_GAN):
        """Evaluate ``generator``/``discriminator`` on the validation loader.

        For every validation batch the losses are printed; the inception
        score is logged for full-size batches and generated images are logged
        every 15th batch.

        Args:
            generator: generator to sample from.
            discriminator: discriminator used to score the samples.
            epoch: current epoch, forwarded to the logger.
            i: unused; kept for interface compatibility.
            type_GAN: tag forwarded to the logger.
        """
        dataset_name = cfg.dataset.dataset_name
        nrof_batches = len(self.valid_loader)

        # No parameter is updated here, so skip building autograd graphs.
        with torch.no_grad():
            for batch_idx, (valid_batch_images,
                            valid_batch_labels) in enumerate(self.valid_loader):
                valid_batch_labels = valid_batch_labels.type(torch.float32)
                valid_batch_z = self._sample_noise(len(valid_batch_images))

                if torch.cuda.is_available():
                    valid_batch_images = valid_batch_images.cuda()
                    valid_batch_labels = valid_batch_labels.cuda()
                    valid_batch_z = valid_batch_z.cuda()

                G_fake_data = generator(dataset_name, valid_batch_z,
                                        valid_batch_labels)
                D_fake = discriminator(dataset_name, G_fake_data,
                                       valid_batch_labels)
                G_loss = self.loss(D_fake,
                                   self.real_data_target(D_fake.size(0)))

                D_real = discriminator(dataset_name, valid_batch_images,
                                       valid_batch_labels)
                D_loss_real = self.loss(
                    D_real,
                    self.real_data_target(valid_batch_images.size(0)))
                D_fake = discriminator(dataset_name, G_fake_data,
                                       valid_batch_labels)
                D_loss_fake = self.loss(D_fake,
                                        self.fake_data_target(D_fake.size(0)))
                # NOTE(review): for the WGAN loss this sum differs from the
                # training objective (fake - real); it is only printed.
                D_loss = D_loss_real + D_loss_fake

                if len(valid_batch_images) == cfg.validation.batch_size:
                    inception_score, _ = Score.inception_score(G_fake_data)
                    self.logger.log_score(inception_score, epoch, batch_idx,
                                          nrof_batches, type_GAN,
                                          'IS_validation')

                print("[Sample] d_loss: %.8f, g_loss: %.8f" % (D_loss, G_loss))

                if batch_idx > 0 and batch_idx % 15 == 0:
                    generated_images = G_fake_data.detach().cpu()
                    # Reorder axes 0,1,2,3 -> 0,2,3,1 before logging.
                    generated_images = generated_images.permute([0, 2, 3, 1])
                    self.logger.log_images2(generated_images, epoch,
                                            batch_idx, type_GAN=type_GAN)

    def copy_network_parameters(self, src_network, dest_network):
        """Copy every parameter of ``src_network`` into the same-named parameter of ``dest_network``.

        Parameters that exist only in one of the two networks are ignored.
        """
        dest_params = dict(dest_network.named_parameters())
        for name, src_param in src_network.named_parameters():
            dest_param = dest_params.get(name)
            if dest_param is not None:
                dest_param.data.copy_(src_param.data)

    def run_train(self):
        """Run the full training schedule for ``cfg.train.num_epochs`` epochs.

        In every epoch each local pair is trained, its discriminator is copied
        into the matching "msg" discriminator, ``G_global`` is trained against
        that copy and ``D_global`` is trained against the pair's generator.
        The global pair is then validated and checkpointed.
        """
        for epoch in range(cfg.train.num_epochs):
            for pair_idx in range(cfg.train.N_pairs):
                print('Train pairs')
                self.train_pairs_epoch(pair_idx, epoch)
                self.copy_network_parameters(self.D_pairs[pair_idx],
                                             self.D_msg_pairs[pair_idx])
                self.train_G_global_epoch(pair_idx, epoch)
                self.train_D_global_epoch(pair_idx, epoch)
            self.run_validation(self.G_global, self.D_global, epoch, None,
                                'global_pair')
            self.logger.save_models(self.G_global, self.D_global, epoch,
                                    'global_pair')

    def _train_epoch(self, generator, discriminator, G_optimizer, D_optimizer,
                     epoch, log_tag, checkpoint_tag):
        """Train ``generator`` against ``discriminator`` for one pass over the training data.

        Each batch performs a discriminator step, a generator step and a
        second discriminator step on freshly generated samples. Validation
        runs every 101st batch and the models are saved after the pass.

        Args:
            generator: generator to train.
            discriminator: discriminator to train.
            G_optimizer: optimizer of ``generator``.
            D_optimizer: optimizer of ``discriminator``.
            epoch: current epoch, used for logging and checkpoint naming.
            log_tag: tag used for loss and inception-score logging.
            checkpoint_tag: tag used for validation logging and checkpoints.
        """
        dataset_name = cfg.dataset.dataset_name
        nrof_batches = len(self.train_loader)

        for batch_idx, (batch_images,
                        batch_labels) in enumerate(self.train_loader):
            start_time = time.time()
            batch_labels = batch_labels.type(torch.float32)
            batch_z = self._sample_noise(len(batch_images))

            if torch.cuda.is_available():
                batch_images = batch_images.cuda()
                batch_labels = batch_labels.cuda()
                batch_z = batch_z.cuda()

            # 1. Discriminator step on detached generator output.
            G_fake_data = generator(dataset_name, batch_z,
                                    batch_labels).detach()
            self.train_discriminator(discriminator, D_optimizer, batch_images,
                                     G_fake_data, batch_labels)

            # 2. Generator step.
            _, G_loss = self.train_generator(generator, discriminator,
                                             G_optimizer, batch_z,
                                             batch_labels)

            # 3. Second discriminator step on samples of the updated generator.
            G_fake_data = generator(dataset_name, batch_z,
                                    batch_labels).detach()
            _, _, D_loss, _, _ = self.train_discriminator(
                discriminator, D_optimizer, batch_images, G_fake_data,
                batch_labels)

            self.logger.log(D_loss, G_loss, epoch, batch_idx, nrof_batches,
                            log_tag)

            # The inception score is only logged for full-size batches.
            if len(batch_images) == cfg.train.batch_size:
                inception_score, _ = Score.inception_score(G_fake_data)
                self.logger.log_score(inception_score, epoch, batch_idx,
                                      nrof_batches, log_tag, 'IS')

            duration = time.time() - start_time
            print("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                  % (epoch, cfg.train.num_epochs, batch_idx, nrof_batches,
                     duration, D_loss, G_loss))

            if batch_idx > 0 and batch_idx % 101 == 0:
                self.run_validation(generator, discriminator, epoch,
                                    batch_idx, checkpoint_tag)

        self.logger.save_models(generator, discriminator, epoch,
                                checkpoint_tag)

    def train_D_global_epoch(self, id, epoch):
        """Train ``D_global`` (and ``G_pairs[id]``) against each other for one epoch."""
        pair_name = str(id + 1)
        self._train_epoch(self.G_pairs[id], self.D_global,
                          self.G_pairs_optimizers[id],
                          self.D_global_optimizer, epoch,
                          'D0-' + pair_name, 'D_global_pairs-' + pair_name)

    def train_G_global_epoch(self, id, epoch):
        """Train ``G_global`` (and ``D_msg_pairs[id]``) against each other for one epoch."""
        pair_name = str(id + 1)
        self._train_epoch(self.G_global, self.D_msg_pairs[id],
                          self.G_global_optimizer,
                          self.D_msg_pairs_optimizers[id], epoch,
                          'G0-' + pair_name, 'G_global_pairs-' + pair_name)

    def train_pairs_epoch(self, id, epoch):
        """Train the local pair ``G_pairs[id]`` / ``D_pairs[id]`` for one epoch."""
        pair_name = str(id + 1)
        self._train_epoch(self.G_pairs[id], self.D_pairs[id],
                          self.G_pairs_optimizers[id],
                          self.D_pairs_optimizers[id], epoch,
                          pair_name, 'pairs-' + pair_name)
def main():
    """Run the whole pipeline: MLE pre-training, discriminator pre-training, adversarial training.

    Stages, in order:

    1. build the generator, the discriminator and the target LSTM;
    2. write toy data sampled from the target LSTM to ``POSITIVE_FILE``;
    3. pre-train the generator with an NLL loss for ``PRE_EPOCH_NUM`` epochs,
       printing the target-LSTM loss of its samples after every epoch;
    4. pre-train the discriminator on positive vs. generated samples;
    5. for ``TOTAL_BATCH`` rounds: one reward-weighted generator update using
       rollout rewards, an evaluation against the target LSTM, a rollout
       parameter update and several discriminator epochs.
    """

    def to_device(module):
        # Move a module to the GPU only when CUDA was requested.
        return module.cuda() if opt.cuda else module

    def summed_nll():
        # NLL criterion summing over elements, placed on the active device.
        return to_device(nn.NLLLoss(reduction='sum'))

    def oracle_loss(criterion):
        # Sample from the generator into EVAL_FILE and score it with the target LSTM.
        generate_samples(generator, BATCH_SIZE, GENERATED_NUM, EVAL_FILE)
        return eval_epoch(target_lstm, GenDataIter(EVAL_FILE, BATCH_SIZE),
                          criterion)

    def fresh_discriminator_data():
        # Regenerate negative samples and pair them with the positive file.
        generate_samples(generator, BATCH_SIZE, GENERATED_NUM, NEGATIVE_FILE)
        return DisDataIter(POSITIVE_FILE, NEGATIVE_FILE, BATCH_SIZE)

    random.seed(SEED)
    np.random.seed(SEED)

    # Define Networks
    generator = Generator(VOCAB_SIZE, g_emb_dim, g_hidden_dim, opt.cuda)
    discriminator = Discriminator(d_num_class, VOCAB_SIZE, d_emb_dim,
                                  d_filter_sizes, d_num_filters, d_dropout)
    target_lstm = TargetLSTM(VOCAB_SIZE, g_emb_dim, g_hidden_dim, opt.cuda)
    generator = to_device(generator)
    discriminator = to_device(discriminator)
    target_lstm = to_device(target_lstm)

    # Generate toy data using target lstm
    print('Generating data ...')
    generate_samples(target_lstm, BATCH_SIZE, GENERATED_NUM, POSITIVE_FILE)

    # Load data from file
    gen_data_iter = GenDataIter(POSITIVE_FILE, BATCH_SIZE)

    # Pretrain Generator using MLE
    gen_criterion = summed_nll()
    gen_optimizer = optim.Adam(generator.parameters())
    print('Pretrain with MLE ...')
    for epoch in range(PRE_EPOCH_NUM):
        model_loss = train_epoch(generator, gen_data_iter, gen_criterion,
                                 gen_optimizer)
        print('Epoch [%d] Model Loss: %f' % (epoch, model_loss))
        true_loss = oracle_loss(gen_criterion)
        print('Epoch [%d] True Loss: %f' % (epoch, true_loss))

    # Pretrain Discriminator
    dis_criterion = summed_nll()
    dis_optimizer = optim.Adam(discriminator.parameters())
    print('Pretrain Discriminator ...')
    for epoch in range(5):
        dis_data_iter = fresh_discriminator_data()
        for _ in range(3):
            dis_loss = train_epoch(discriminator, dis_data_iter,
                                   dis_criterion, dis_optimizer)
            print('Epoch [%d], loss: %f' % (epoch, dis_loss))

    # Adversarial Training
    rollout = Rollout(generator, 0.8)
    print('#####################################################')
    print('Start Adeversatial Training...\n')
    gen_gan_loss = to_device(GANLoss())
    gen_gan_optm = optim.Adam(generator.parameters())

    # Fresh criteria and a fresh discriminator optimizer for this phase.
    gen_criterion = summed_nll()
    dis_criterion = summed_nll()
    dis_optimizer = optim.Adam(discriminator.parameters())

    for total_batch in range(TOTAL_BATCH):
        # Train the generator for one step.
        samples = generator.sample(BATCH_SIZE, g_sequence_len)

        # Generator input: samples shifted right by one, with a zero column
        # prepended and the last column dropped.
        start_column = torch.zeros((BATCH_SIZE, 1)).type(torch.LongTensor)
        if samples.is_cuda:
            start_column = start_column.cuda()
        shifted = torch.cat([start_column, samples.data], dim=1)
        inputs = shifted[:, :-1].contiguous()
        targets = samples.data.contiguous().view((-1, ))

        # Calculate the reward.
        rewards = torch.Tensor(rollout.get_reward(samples, 16, discriminator))
        rewards = torch.exp(rewards).contiguous().view((-1, ))
        if opt.cuda:
            rewards = rewards.cuda()

        prob = generator.forward(inputs)
        adversarial_loss = gen_gan_loss(prob, targets, rewards)
        gen_gan_optm.zero_grad()
        adversarial_loss.backward()
        gen_gan_optm.step()

        # Evaluate against the target LSTM after every generator step.
        true_loss = oracle_loss(gen_criterion)
        print('Batch [%d] True Loss: %f' % (total_batch, true_loss))
        rollout.update_params()

        # Train the discriminator on freshly generated negatives.
        for _ in range(4):
            dis_data_iter = fresh_discriminator_data()
            for _ in range(2):
                train_epoch(discriminator, dis_data_iter, dis_criterion,
                            dis_optimizer)