def one_im_discrim(discrim_path, im_path):
    """Score a single image with a saved discriminator."""
    discriminator = Discriminator(3, 64)
    discriminator.load_state_dict(
        torch.load(discrim_path, map_location=torch.device('cpu')))
    discriminator.eval()
    to_tensor = transforms.ToTensor()
    # Load the image once and convert it to a tensor for the discriminator.
    im = Image.open(im_path)
    result = discriminator(to_tensor(im)).view(-1)
    print(result.data.item())
def run_model(model_path, discrim_path):
    model = Deblurrer()
    model.load_state_dict(
        torch.load(model_path, map_location=torch.device('cpu')))
    model.eval()

    discriminator = Discriminator(3, 64)
    discriminator.load_state_dict(
        torch.load(discrim_path, map_location=torch.device('cpu')))
    discriminator.eval()

    dataset = LFWC(["../data/train/faces_blurred"], "../data/train/faces")
    # dataset = FakeData(size=1000, image_size=(3, 128, 128),
    #                    transform=transforms.ToTensor())
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=1,
                                              shuffle=True)

    for data in data_loader:
        blurred_img = Variable(data['blurred'])
        nonblurred = Variable(data['nonblurred'])

        # Should be near zero
        discrim_output_blurred = discriminator(blurred_img).view(-1).data.item()
        # Should be near one
        discrim_output_nonblurred = discriminator(nonblurred).view(-1).data.item()

        transformback = transforms.ToPILImage()
        plt.imshow(transformback(blurred_img[0]))
        plt.title('Blurred, Discrim value: ' + str(discrim_output_blurred))
        plt.show()

        plt.imshow(transformback(nonblurred[0]))
        plt.title('Non Blurred, Discrim value: ' + str(discrim_output_nonblurred))
        plt.show()

        out = model(blurred_img)
        discrim_output_model = discriminator(out).view(-1).data.item()
        outIm = transformback(out[0])
        plt.imshow(outIm)
        plt.title('Model out, Discrim value: ' + str(discrim_output_model))
        plt.show()
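# Minimal driver for the two helpers above; the checkpoint and image paths
# are placeholder assumptions, not files this repo necessarily ships.
if __name__ == '__main__':
    one_im_discrim('checkpoints/discriminator.pt', '../data/train/faces/0.png')
    run_model('checkpoints/deblurrer.pt', 'checkpoints/discriminator.pt')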
class CycleGAN(AlignmentModel):
    """This class implements the alignment model for GAN networks with two
    generators and two discriminators (CycleGAN). For a description of the
    implemented functions, refer to the alignment model."""

    def __init__(self, device, config, generator_a=None, generator_b=None,
                 discriminator_a=None, discriminator_b=None):
        """Initialize two new generators and two discriminators from the
        config or use pre-trained ones, and create optimizers for all
        models."""
        super().__init__(device, config)
        self.epoch_losses = [0., 0., 0., 0.]

        if generator_a is None:
            generator_a_conf = dict(
                dim_1=config['dim_b'],
                dim_2=config['dim_a'],
                layer_number=config['generator_layers'],
                layer_expansion=config['generator_expansion'],
                initialize_generator=config['initialize_generator'],
                norm=config['gen_norm'],
                batch_norm=config['gen_batch_norm'],
                activation=config['gen_activation'],
                dropout=config['gen_dropout'])
            self.generator_a = Generator(generator_a_conf, device)
            self.generator_a.to(device)
        else:
            self.generator_a = generator_a
        self.optimizer_g_a = self._build_optimizer(
            self.generator_a.parameters(), config)

        if generator_b is None:
            generator_b_conf = dict(
                dim_1=config['dim_a'],
                dim_2=config['dim_b'],
                layer_number=config['generator_layers'],
                layer_expansion=config['generator_expansion'],
                initialize_generator=config['initialize_generator'],
                norm=config['gen_norm'],
                batch_norm=config['gen_batch_norm'],
                activation=config['gen_activation'],
                dropout=config['gen_dropout'])
            self.generator_b = Generator(generator_b_conf, device)
            self.generator_b.to(device)
        else:
            self.generator_b = generator_b
        self.optimizer_g_b = self._build_optimizer(
            self.generator_b.parameters(), config)

        if discriminator_a is None:
            discriminator_a_conf = dict(
                dim=config['dim_a'],
                layer_number=config['discriminator_layers'],
                layer_expansion=config['discriminator_expansion'],
                batch_norm=config['disc_batch_norm'],
                activation=config['disc_activation'],
                dropout=config['disc_dropout'])
            self.discriminator_a = Discriminator(discriminator_a_conf, device)
            self.discriminator_a.to(device)
        else:
            self.discriminator_a = discriminator_a
        self.optimizer_d_a = self._build_optimizer(
            self.discriminator_a.parameters(), config)

        if discriminator_b is None:
            discriminator_b_conf = dict(
                dim=config['dim_b'],
                layer_number=config['discriminator_layers'],
                layer_expansion=config['discriminator_expansion'],
                batch_norm=config['disc_batch_norm'],
                activation=config['disc_activation'],
                dropout=config['disc_dropout'])
            self.discriminator_b = Discriminator(discriminator_b_conf, device)
            self.discriminator_b.to(device)
        else:
            self.discriminator_b = discriminator_b
        self.optimizer_d_b = self._build_optimizer(
            self.discriminator_b.parameters(), config)

    @staticmethod
    def _build_optimizer(parameters, config):
        """Pick an optimizer following the config: an explicit 'optimizer'
        entry wins, then 'optimizer_default' (where only SGD takes an
        explicit learning rate), with Adam as the fallback. Factored out of
        the four identical blocks the constructor repeats."""
        if 'optimizer' in config:
            return OPTIMIZERS[config['optimizer']](parameters,
                                                   config['learning_rate'])
        if 'optimizer_default' in config:
            if config['optimizer_default'] == 'sgd':
                return OPTIMIZERS[config['optimizer_default']](
                    parameters, config['learning_rate'])
            return OPTIMIZERS[config['optimizer_default']](parameters)
        return torch.optim.Adam(parameters, config['learning_rate'])

    def train(self):
        self.generator_a.train()
        self.generator_b.train()
        self.discriminator_a.train()
        self.discriminator_b.train()

    def eval(self):
        self.generator_a.eval()
        self.generator_b.eval()
        self.discriminator_a.eval()
        self.discriminator_b.eval()

    def zero_grad(self):
        self.optimizer_g_a.zero_grad()
        self.optimizer_g_b.zero_grad()
        self.optimizer_d_a.zero_grad()
        self.optimizer_d_b.zero_grad()

    def optimize_all(self):
        self.optimizer_g_a.step()
        self.optimizer_g_b.step()
        self.optimizer_d_a.step()
        self.optimizer_d_b.step()

    def optimize_generator(self):
        """Do the optimization step only for generators (e.g. when training
        generators and discriminators separately or in turns)."""
        self.optimizer_g_a.step()
        self.optimizer_g_b.step()

    def optimize_discriminator(self):
        """Do the optimization step only for discriminators (e.g. when
        training generators and discriminators separately or in turns)."""
        self.optimizer_d_a.step()
        self.optimizer_d_b.step()

    def change_lr(self, factor):
        self.current_lr = self.current_lr * factor
        for param_group in self.optimizer_g_a.param_groups:
            param_group['lr'] = self.current_lr
        for param_group in self.optimizer_g_b.param_groups:
            param_group['lr'] = self.current_lr

    def update_losses_batch(self, *losses):
        loss_g_a, loss_g_b, loss_d_a, loss_d_b = losses
        self.epoch_losses[0] += loss_g_a
        self.epoch_losses[1] += loss_g_b
        self.epoch_losses[2] += loss_d_a
        self.epoch_losses[3] += loss_d_b

    def complete_epoch(self, epoch_metrics):
        self.metrics.append(epoch_metrics + [sum(self.epoch_losses)])
        self.losses.append(self.epoch_losses)
        self.epoch_losses = [0., 0., 0., 0.]
    def print_epoch_info(self):
        print(
            f"{len(self.metrics)} ### {self.losses[-1][0]:.2f} - {self.losses[-1][1]:.2f} "
            f"- {self.losses[-1][2]:.2f} - {self.losses[-1][3]:.2f} ### {self.metrics[-1]}"
        )

    def copy_model(self):
        self.model_copy = deepcopy(self.generator_a.state_dict()), \
            deepcopy(self.generator_b.state_dict()), \
            deepcopy(self.discriminator_a.state_dict()), \
            deepcopy(self.discriminator_b.state_dict())

    def restore_model(self):
        self.generator_a.load_state_dict(self.model_copy[0])
        self.generator_b.load_state_dict(self.model_copy[1])
        self.discriminator_a.load_state_dict(self.model_copy[2])
        self.discriminator_b.load_state_dict(self.model_copy[3])

    def export_model(self, test_results, description=None):
        if description is None:
            description = f"CycleGAN_{self.config['evaluation']}_{self.config['subset']}"
        export_cyclegan_alignment(description, self.config, self.generator_a,
                                  self.generator_b, self.discriminator_a,
                                  self.discriminator_b, self.metrics)
        save_alignment_test_results(test_results, description)
        print(f"Saved model to directory {description}.")

    @classmethod
    def load_model(cls, name, device):
        generator_a, generator_b, discriminator_a, discriminator_b, config = \
            load_cyclegan_alignment(name, device)
        model = cls(device, config, generator_a, generator_b,
                    discriminator_a, discriminator_b)
        return model
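# Construction sketch: the keys mirror what __init__ reads above, but the
# concrete values (and whatever else AlignmentModel itself requires) are
# made-up assumptions.
config = {
    'dim_a': 300, 'dim_b': 300,
    'generator_layers': 2, 'generator_expansion': 2,
    'initialize_generator': True,
    'gen_norm': None, 'gen_batch_norm': False,
    'gen_activation': 'relu', 'gen_dropout': 0.1,
    'discriminator_layers': 2, 'discriminator_expansion': 2,
    'disc_batch_norm': False, 'disc_activation': 'leaky_relu',
    'disc_dropout': 0.1,
    'learning_rate': 1e-4,
}
cycle_gan = CycleGAN(torch.device('cpu'), config)
cycle_gan.train()  # puts all four sub-models into training mode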
def main(args):
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))
    print("======printing args========")
    print(args)
    print("=================================")

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        print("Loading bin dataset")
        dataset = data.load_dataset(args.data, splits, args.src_lang,
                                    args.trg_lang, args.fixed_max_len)
    else:
        print(f"Loading raw text dataset {args.data}")
        dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang,
                                             args.trg_lang, args.fixed_max_len)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
    for split in splits:
        print('| {} {} {} examples'.format(args.data, split,
                                           len(dataset.splits[split])))

    g_logging_meters = OrderedDict()
    g_logging_meters['train_loss'] = AverageMeter()
    g_logging_meters['valid_loss'] = AverageMeter()
    g_logging_meters['train_acc'] = AverageMeter()
    g_logging_meters['valid_acc'] = AverageMeter()
    g_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    d_logging_meters = OrderedDict()
    d_logging_meters['train_loss'] = AverageMeter()
    d_logging_meters['valid_loss'] = AverageMeter()
    d_logging_meters['train_acc'] = AverageMeter()
    d_logging_meters['valid_acc'] = AverageMeter()
    d_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    # Set model parameters
    args.encoder_embed_dim = 1000
    args.encoder_layers = 4
    args.encoder_dropout_out = 0
    args.decoder_embed_dim = 1000
    args.decoder_layers = 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0
    args.bidirectional = False

    # try to load generator model
    g_model_path = 'checkpoints/generator/best_gmodel.pt'
    if not os.path.exists(g_model_path):
        print("Start training generator!")
        train_g(args, dataset)
    assert os.path.exists(g_model_path)
    generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict,
                          use_cuda=use_cuda)
    model_dict = generator.state_dict()
    pretrained_dict = torch.load(g_model_path)
    # 1. filter out unnecessary keys
    pretrained_dict = {k: v for k, v in pretrained_dict.items()
                       if k in model_dict}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    generator.load_state_dict(model_dict)
    print("Generator has successfully loaded!")

    # try to load discriminator model
    d_model_path = 'checkpoints/discriminator/best_dmodel.pt'
    if not os.path.exists(d_model_path):
        print("Start training discriminator!")
        train_d(args, dataset)
    assert os.path.exists(d_model_path)
    discriminator = Discriminator(args, dataset.src_dict, dataset.dst_dict,
                                  use_cuda=use_cuda)
    model_dict = discriminator.state_dict()
    pretrained_dict = torch.load(d_model_path)
    # 1. filter out unnecessary keys
    pretrained_dict = {k: v for k, v in pretrained_dict.items()
                       if k in model_dict}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    discriminator.load_state_dict(model_dict)
    print("Discriminator has successfully loaded!")

    print("starting main training loop")
    torch.autograd.set_detect_anomaly(True)

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator = torch.nn.DataParallel(discriminator).cuda()
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
            discriminator.cuda()
    else:
        discriminator.cpu()
        generator.cpu()

    # adversarial training checkpoints saving path
    if not os.path.exists('checkpoints/joint'):
        os.makedirs('checkpoints/joint')
    checkpoints_path = 'checkpoints/joint/'

    # define loss functions; reduction='sum' sums the NLL over tokens
    g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(),
                                   reduction='sum')
    d_criterion = torch.nn.BCEWithLogitsLoss()
    pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(),
                          size_average=True, reduce=True)

    # fix discriminator word embeddings (as Wu et al. do)
    for p in discriminator.embed_src_tokens.parameters():
        p.requires_grad = False
    for p in discriminator.embed_trg_tokens.parameters():
        p.requires_grad = False

    # define optimizers; look the classes up by name on torch.optim
    g_optimizer = getattr(torch.optim, args.g_optimizer)(
        filter(lambda x: x.requires_grad, generator.parameters()),
        args.g_learning_rate)
    d_optimizer = getattr(torch.optim, args.d_optimizer)(
        filter(lambda x: x.requires_grad, discriminator.parameters()),
        args.d_learning_rate, momentum=args.momentum, nesterov=True)

    # start joint training
    best_dev_loss = math.inf
    num_update = 0

    # main training loop
    for epoch_i in range(1, args.epochs + 1):
        logging.info("At {0}-th epoch.".format(epoch_i))

        max_positions_train = (args.fixed_max_len, args.fixed_max_len)

        # Initialize dataloader, starting at batch_offset
        itr = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_train,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        # set training mode
        generator.train()
        discriminator.train()
        update_learning_rate(num_update, 8e4, args.g_learning_rate,
                             args.lr_shrink, g_optimizer)

        for i, sample in enumerate(itr):
            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=cuda)

            ## part I: use the policy gradient method to train the generator

            # use policy gradient training when rand > 50%
            rand = random.random()
            if rand >= 0.5:
                # policy gradient training
                generator.decoder.is_testing = True
                sys_out_batch, prediction, _ = generator(sample)
                generator.decoder.is_testing = False
                with torch.no_grad():
                    reward = discriminator(sample['net_input']['src_tokens'],
                                           prediction)  # , dataset.dst_dict.pad())
                train_trg_batch = sample['target']
                pg_loss = pg_criterion(sys_out_batch, train_trg_batch,
                                       reward, use_cuda)
                g_optimizer.zero_grad()
                pg_loss.backward()
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()

                # oracle valid
                _, _, loss = generator(sample)
                sample_size = sample['target'].size(0) if args.sentence_avg \
                    else sample['ntokens']
                logging_loss = loss.data / sample_size / math.log(2)
                g_logging_meters['train_loss'].update(logging_loss, sample_size)
                logging.debug("G MLE loss at batch {0}: {1:.3f}, lr={2}".format(
                    i, g_logging_meters['train_loss'].avg,
                    g_optimizer.param_groups[0]['lr']))
            else:
                # MLE training
                _, _, loss = generator(sample)
                sample_size = sample['target'].size(0) if args.sentence_avg \
                    else sample['ntokens']
                nsentences = sample['target'].size(0)
                logging_loss = loss.data / sample_size / math.log(2)
                g_logging_meters['bsz'].update(nsentences)
                g_logging_meters['train_loss'].update(logging_loss, sample_size)
                logging.debug("G MLE loss at batch {0}: {1:.3f}, lr={2}".format(
                    i, g_logging_meters['train_loss'].avg,
                    g_optimizer.param_groups[0]['lr']))
                g_optimizer.zero_grad()
                loss.backward()
                # all-reduce grads and rescale by grad_denom
                for p in generator.parameters():
                    if p.requires_grad:
                        p.grad.data.div_(sample_size)
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()
            num_update += 1

            # part II: train the discriminator
            bsz = sample['target'].size(0)
            src_sentence = sample['net_input']['src_tokens']
            # train with half human-translation and half machine translation
            true_sentence = sample['target']
            true_labels = Variable(torch.ones(sample['target'].size(0)).float())
            with torch.no_grad():
                generator.decoder.is_testing = True
                _, prediction, _ = generator(sample)
                generator.decoder.is_testing = False
            fake_sentence = prediction
            fake_labels = Variable(torch.zeros(sample['target'].size(0)).float())
            trg_sentence = torch.cat([true_sentence, fake_sentence], dim=0)
            labels = torch.cat([true_labels, fake_labels], dim=0)
            indices = np.random.permutation(2 * bsz)
            trg_sentence = trg_sentence[indices][:bsz]
            labels = labels[indices][:bsz]
            if use_cuda:
                labels = labels.cuda()

            disc_out = discriminator(src_sentence, trg_sentence)  # , dataset.dst_dict.pad())
            # BCEWithLogitsLoss expects float targets
            d_loss = d_criterion(disc_out, labels)
            acc = torch.sum(torch.sigmoid(disc_out).round() ==
                            labels).float() / len(labels)
            d_logging_meters['train_acc'].update(acc)
            d_logging_meters['train_loss'].update(d_loss)
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

        # validation: set validation mode
        generator.eval()
        discriminator.eval()

        # Initialize dataloader
        max_positions_valid = (args.fixed_max_len, args.fixed_max_len)
        itr = dataset.eval_dataloader(
            'valid',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_valid,
            skip_invalid_size_inputs_valid_test=True,
            descending=True,  # largest batch first to warm the caching allocator
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(itr):
            with torch.no_grad():
                if use_cuda:
                    sample['id'] = sample['id'].cuda()
                    sample['net_input']['src_tokens'] = \
                        sample['net_input']['src_tokens'].cuda()
                    sample['net_input']['src_lengths'] = \
                        sample['net_input']['src_lengths'].cuda()
                    sample['net_input']['prev_output_tokens'] = \
                        sample['net_input']['prev_output_tokens'].cuda()
                    sample['target'] = sample['target'].cuda()

                # generator validation
                _, _, loss = generator(sample)
                sample_size = sample['target'].size(0) if args.sentence_avg \
                    else sample['ntokens']
                loss = loss / sample_size / math.log(2)
                g_logging_meters['valid_loss'].update(loss, sample_size)
                logging.debug("G dev loss at batch {0}: {1:.3f}".format(
                    i, g_logging_meters['valid_loss'].avg))

                # discriminator validation
                bsz = sample['target'].size(0)
                src_sentence = sample['net_input']['src_tokens']
                # evaluate with half human-translation and half machine translation
                true_sentence = sample['target']
                true_labels = Variable(torch.ones(sample['target'].size(0)).float())
                with torch.no_grad():
                    generator.decoder.is_testing = True
                    _, prediction, _ = generator(sample)
                    generator.decoder.is_testing = False
                fake_sentence = prediction
                fake_labels = Variable(torch.zeros(sample['target'].size(0)).float())
                trg_sentence = torch.cat([true_sentence, fake_sentence], dim=0)
                labels = torch.cat([true_labels, fake_labels], dim=0)
                indices = np.random.permutation(2 * bsz)
                trg_sentence = trg_sentence[indices][:bsz]
                labels = labels[indices][:bsz]
                if use_cuda:
                    labels = labels.cuda()

                disc_out = discriminator(src_sentence, trg_sentence,
                                         dataset.dst_dict.pad())
                d_loss = d_criterion(disc_out, labels)
                acc = torch.sum(torch.sigmoid(disc_out).round() ==
                                labels).float() / len(labels)
                d_logging_meters['valid_acc'].update(acc)
                d_logging_meters['valid_loss'].update(d_loss)

        torch.save(generator,
                   open(checkpoints_path + "joint_{0:.3f}.epoch_{1}.pt".format(
                       g_logging_meters['valid_loss'].avg, epoch_i), 'wb'),
                   pickle_module=dill)

        if g_logging_meters['valid_loss'].avg < best_dev_loss:
            best_dev_loss = g_logging_meters['valid_loss'].avg
            torch.save(generator,
                       open(checkpoints_path + "best_gmodel.pt", 'wb'),
                       pickle_module=dill)
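# Standalone illustration of the half-real / half-fake batch construction
# used in the discriminator updates above: real and generated targets are
# concatenated, shuffled together with their labels, and one batch worth is
# kept. All sizes below are made up for the example.
import numpy as np
import torch

bsz, seq_len = 4, 7
true_sentence = torch.full((bsz, seq_len), 1)   # stand-in for sample['target']
fake_sentence = torch.full((bsz, seq_len), 2)   # stand-in for generator output
labels = torch.cat([torch.ones(bsz), torch.zeros(bsz)])
trg_sentence = torch.cat([true_sentence, fake_sentence], dim=0)
indices = np.random.permutation(2 * bsz)
trg_sentence, labels = trg_sentence[indices][:bsz], labels[indices][:bsz]
# On average, half of the kept rows are human translations (label 1).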
class SVM_Classifier:

    def __init__(self, batch_size, image_size=64):
        self.image_size = image_size
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        self.save_filename = \
            f'model_{datetime.datetime.now().strftime("%a_%H_%M")}.sav'

        transform = transforms.Compose([
            # transforms.Resize(self.image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        self.trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                                     download=True,
                                                     transform=transform)
        self.trainloader = data.DataLoader(self.trainset,
                                           batch_size=batch_size,
                                           shuffle=True, num_workers=2)
        self.testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                                    download=True,
                                                    transform=transform)
        self.testloader = data.DataLoader(self.testset, batch_size=batch_size,
                                          shuffle=False, num_workers=2)

        saved_state = torch.load(
            "C:\\Users\\ankit\\Workspaces\\CS7150\\FinalProject\\models\\imagenet\\trained_model_Tue_17_06.pth"
        )
        self.discriminator = Discriminator(ngpu=1, num_channels=3,
                                           num_features=64,
                                           data_generation_mode=1,
                                           input_size=image_size)
        self.discriminator.load_state_dict(saved_state['discriminator'])
        self.discriminator.eval()  # change the mode of the network

    def plot_training_data(self):
        # Plot some training images
        real_batch = next(iter(self.trainloader))
        images = real_batch[0][0:8]
        plt.figure(figsize=(8, 8))
        plt.axis("off")
        plt.title("Training Images")
        plt.imshow(np.transpose(
            vutils.make_grid(images.to(self.device), padding=2,
                             normalize=True).cpu(), (1, 2, 0)))
        plt.show()

    def train(self):
        train_data, train_labels = next(iter(self.trainloader))
        modified_train_data = self.discriminator(train_data)
        l2_svm = svm.LinearSVC(verbose=2, max_iter=2000)
        modified_train_data_ndarray = modified_train_data.detach().numpy()
        train_labels_ndarray = train_labels.detach().numpy()
        self.l2_svm = l2_svm.fit(modified_train_data_ndarray,
                                 train_labels_ndarray)
        # save model
        with open(self.save_filename, 'wb') as file:
            pickle.dump(self.l2_svm, file)

    def train_test_SGD_Classifier(self):
        est = make_pipeline(StandardScaler(), SGDClassifier(max_iter=200))
        # fit the scaler on one batch of discriminator features
        training_data = self.discriminator(next(iter(self.trainloader))[0])
        training_data = training_data.detach().numpy()
        est.steps[0][1].fit(training_data)
        self.est = est
        for i, batch in enumerate(self.trainloader):
            train_data, train_labels = batch
            modified_train_data = self.discriminator(train_data)
            modified_train_data_ndarray = modified_train_data.detach().numpy()
            train_labels_ndarray = train_labels.detach().numpy()
            modified_train_data_ndarray = est.steps[0][1].transform(
                modified_train_data_ndarray)
            est.steps[1][1].partial_fit(
                modified_train_data_ndarray, train_labels_ndarray,
                classes=np.unique(train_labels_ndarray))
            print(f'Batch: {i}')
        with open(self.save_filename, 'wb') as file:
            pickle.dump(est.steps[1][1], file)

    def test(self):
        l2_svm = self.est.steps[1][1]
        accuracy = []
        for i, batch in enumerate(self.testloader):
            test_data, test_labels = batch
            modified_test_data = self.discriminator(test_data)
            modified_test_data_ndarray = modified_test_data.detach().numpy()
            test_labels_ndarray = test_labels.detach().numpy()
            modified_test_data_ndarray = self.est.steps[0][1].transform(
                modified_test_data_ndarray)
            predictions = l2_svm.predict(modified_test_data_ndarray)
            accuracy.append(metrics.accuracy_score(test_labels_ndarray,
                                                   predictions))
        print(f'Accuracy: {np.mean(accuracy)}')
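# Sketch of the intended flow, under the assumption that the saved DCGAN
# discriminator (data_generation_mode=1) returns flattened features rather
# than a real/fake score:
clf = SVM_Classifier(batch_size=64)
clf.train_test_SGD_Classifier()  # fit the scaler + SGD classifier batch-wise
clf.test()                       # mean accuracy over CIFAR-10 test batches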
class trainer(object):

    def __init__(self, cfg):
        self.cfg = cfg
        self.OldLabel_generator = U_Net(in_ch=cfg.DATASET.N_CLASS,
                                        out_ch=cfg.DATASET.N_CLASS,
                                        side='out')
        self.Image_generator = U_Net(in_ch=3, out_ch=cfg.DATASET.N_CLASS,
                                     side='in')
        self.discriminator = Discriminator(cfg.DATASET.N_CLASS + 3,
                                           cfg.DATASET.IMGSIZE, patch=True)
        self.criterion_G = GeneratorLoss(cfg.LOSS.LOSS_WEIGHT[0],
                                         cfg.LOSS.LOSS_WEIGHT[1],
                                         cfg.LOSS.LOSS_WEIGHT[2],
                                         ignore_index=cfg.LOSS.IGNORE_INDEX)
        self.criterion_D = DiscriminatorLoss()

        train_dataset = BaseDataset(cfg, split='train')
        valid_dataset = BaseDataset(cfg, split='val')
        self.train_dataloader = data.DataLoader(
            train_dataset, batch_size=cfg.DATASET.BATCHSIZE,
            num_workers=8, shuffle=True, drop_last=True)
        self.valid_dataloader = data.DataLoader(
            valid_dataset, batch_size=cfg.DATASET.BATCHSIZE,
            num_workers=8, shuffle=True, drop_last=True)

        self.ckpt_outdir = os.path.join(cfg.TRAIN.OUTDIR, 'checkpoints')
        if not os.path.isdir(self.ckpt_outdir):
            os.mkdir(self.ckpt_outdir)
        self.val_outdir = os.path.join(cfg.TRAIN.OUTDIR, 'val')
        if not os.path.isdir(self.val_outdir):
            os.mkdir(self.val_outdir)
        self.start_epoch = cfg.TRAIN.RESUME
        self.n_epoch = cfg.TRAIN.N_EPOCH

        self.optimizer_G = torch.optim.Adam(
            [{'params': self.OldLabel_generator.parameters()},
             {'params': self.Image_generator.parameters()}],
            lr=cfg.OPTIMIZER.G_LR,
            betas=(cfg.OPTIMIZER.BETA1, cfg.OPTIMIZER.BETA2),
            weight_decay=cfg.OPTIMIZER.WEIGHT_DECAY)
        self.optimizer_D = torch.optim.Adam(
            [{'params': self.discriminator.parameters(),
              'initial_lr': cfg.OPTIMIZER.D_LR}],
            lr=cfg.OPTIMIZER.D_LR,
            betas=(cfg.OPTIMIZER.BETA1, cfg.OPTIMIZER.BETA2),
            weight_decay=cfg.OPTIMIZER.WEIGHT_DECAY)

        # poly learning-rate decay over the total number of iterations
        iter_per_epoch = len(train_dataset) // cfg.DATASET.BATCHSIZE
        lambda_poly = lambda iters: pow(
            (1.0 - iters / (cfg.TRAIN.N_EPOCH * iter_per_epoch)), 0.9)
        self.scheduler_G = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_G, lr_lambda=lambda_poly)
        # last_epoch=(self.start_epoch+1)*iter_per_epoch)
        self.scheduler_D = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_D, lr_lambda=lambda_poly)

        self.logger = logger(cfg.TRAIN.OUTDIR, name='train')
        self.running_metrics = runningScore(n_classes=cfg.DATASET.N_CLASS)

        if self.start_epoch >= 0:
            # load everything from one checkpoint instead of re-reading the
            # file for each state dict
            ckpt = torch.load(
                os.path.join(cfg.TRAIN.OUTDIR, 'checkpoints',
                             '{}epoch.pth'.format(self.start_epoch)))
            self.OldLabel_generator.load_state_dict(ckpt['model_G_N'])
            self.Image_generator.load_state_dict(ckpt['model_G_I'])
            self.discriminator.load_state_dict(ckpt['model_D'])
            self.optimizer_G.load_state_dict(ckpt['optimizer_G'])
            self.optimizer_D.load_state_dict(ckpt['optimizer_D'])
            log = "Using the {}th checkpoint".format(self.start_epoch)
            self.logger.info(log)

        self.Image_generator = self.Image_generator.cuda()
        self.OldLabel_generator = self.OldLabel_generator.cuda()
        self.discriminator = self.discriminator.cuda()
        self.criterion_G = self.criterion_G.cuda()
        self.criterion_D = self.criterion_D.cuda()

    def train(self):
        all_train_iter_total_loss = []
        all_train_iter_corr_loss = []
        all_train_iter_recover_loss = []
        all_train_iter_change_loss = []
        all_train_iter_gan_loss_gen = []
        all_train_iter_gan_loss_dis = []
        all_val_epo_iou = []
        all_val_epo_acc = []
        iter_num = [0]
        epoch_num = []
        num_batches = len(self.train_dataloader)

        for epoch_i in range(self.start_epoch + 1, self.n_epoch):
            iter_total_loss = AverageTracker()
            iter_corr_loss = AverageTracker()
            iter_recover_loss = AverageTracker()
            iter_change_loss = AverageTracker()
            iter_gan_loss_gen = AverageTracker()
            iter_gan_loss_dis = AverageTracker()
            batch_time = AverageTracker()
            tic = time.time()

            # train
            self.OldLabel_generator.train()
            self.Image_generator.train()
            self.discriminator.train()
            for i, meta in enumerate(self.train_dataloader):
                image, old_label, new_label = \
                    meta[0].cuda(), meta[1].cuda(), meta[2].cuda()
                recover_pred, feats = self.OldLabel_generator(
                    label2onehot(old_label, self.cfg.DATASET.N_CLASS))
                corr_pred = self.Image_generator(image, feats)

                # -------------------
                # Train Discriminator
                # -------------------
                self.discriminator.set_requires_grad(True)
                self.optimizer_D.zero_grad()
                fake_sample = torch.cat((image, corr_pred), 1).detach()
                real_sample = torch.cat(
                    (image, label2onehot(new_label, self.cfg.DATASET.N_CLASS)), 1)
                score_fake_d = self.discriminator(fake_sample)
                score_real = self.discriminator(real_sample)
                gan_loss_dis = self.criterion_D(pred_score=score_fake_d,
                                                real_score=score_real)
                gan_loss_dis.backward()
                self.optimizer_D.step()
                self.scheduler_D.step()

                # ---------------
                # Train Generator
                # ---------------
                self.discriminator.set_requires_grad(False)
                self.optimizer_G.zero_grad()
                score_fake = self.discriminator(torch.cat((image, corr_pred), 1))
                total_loss, corr_loss, recover_loss, change_loss, gan_loss_gen = \
                    self.criterion_G(corr_pred, recover_pred, score_fake,
                                     old_label, new_label)
                total_loss.backward()
                self.optimizer_G.step()
                self.scheduler_G.step()

                iter_total_loss.update(total_loss.item())
                iter_corr_loss.update(corr_loss.item())
                iter_recover_loss.update(recover_loss.item())
                iter_change_loss.update(change_loss.item())
                iter_gan_loss_gen.update(gan_loss_gen.item())
                iter_gan_loss_dis.update(gan_loss_dis.item())
                batch_time.update(time.time() - tic)
                tic = time.time()

                log = '{}: Epoch: [{}][{}/{}], Time: {:.2f}, ' \
                      'Total Loss: {:.6f}, Corr Loss: {:.6f}, Recover Loss: {:.6f}, ' \
                      'Change Loss: {:.6f}, GAN_G Loss: {:.6f}, GAN_D Loss: {:.6f}'.format(
                          datetime.now(), epoch_i, i, num_batches, batch_time.avg,
                          total_loss.item(), corr_loss.item(), recover_loss.item(),
                          change_loss.item(), gan_loss_gen.item(), gan_loss_dis.item())
                print(log)

                if (i + 1) % 10 == 0:
                    all_train_iter_total_loss.append(iter_total_loss.avg)
                    all_train_iter_corr_loss.append(iter_corr_loss.avg)
                    all_train_iter_recover_loss.append(iter_recover_loss.avg)
                    all_train_iter_change_loss.append(iter_change_loss.avg)
                    all_train_iter_gan_loss_gen.append(iter_gan_loss_gen.avg)
                    all_train_iter_gan_loss_dis.append(iter_gan_loss_dis.avg)
                    iter_total_loss.reset()
                    iter_corr_loss.reset()
                    iter_recover_loss.reset()
                    iter_change_loss.reset()
                    iter_gan_loss_gen.reset()
                    iter_gan_loss_dis.reset()

                    vis.line(
                        X=np.column_stack(
                            np.repeat(np.expand_dims(iter_num, 0), 6, axis=0)),
                        Y=np.column_stack((all_train_iter_total_loss,
                                           all_train_iter_corr_loss,
                                           all_train_iter_recover_loss,
                                           all_train_iter_change_loss,
                                           all_train_iter_gan_loss_gen,
                                           all_train_iter_gan_loss_dis)),
                        opts={
                            'legend': ['total_loss', 'corr_loss',
                                       'recover_loss', 'change_loss',
                                       'gan_loss_gen', 'gan_loss_dis'],
                            'linecolor': np.array(
                                [[255, 0, 0], [0, 255, 0], [0, 0, 255],
                                 [255, 255, 0], [0, 255, 255], [255, 0, 255]]),
                            'title': 'Train loss of generator and discriminator'
                        },
                        win='Train loss of generator and discriminator')
                    iter_num.append(iter_num[-1] + 1)

            # eval
            self.OldLabel_generator.eval()
            self.Image_generator.eval()
            self.discriminator.eval()
            with torch.no_grad():
                for j, meta in enumerate(self.valid_dataloader):
                    image, old_label, new_label = \
                        meta[0].cuda(), meta[1].cuda(), meta[2].cuda()
                    recover_pred, feats = self.OldLabel_generator(
                        label2onehot(old_label, self.cfg.DATASET.N_CLASS))
                    corr_pred = self.Image_generator(image, feats)
                    preds = np.argmax(corr_pred.cpu().detach().numpy().copy(),
                                      axis=1)
                    target = new_label.cpu().detach().numpy().copy()
                    self.running_metrics.update(target, preds)
                    if j == 0:
                        color_map1 = gen_color_map(preds[0, :]).astype(np.uint8)
                        color_map2 = gen_color_map(preds[1, :]).astype(np.uint8)
                        color_map = cv2.hconcat([color_map1, color_map2])
                        cv2.imwrite(
                            os.path.join(self.val_outdir,
                                         '{}epoch*{}*{}.png'.format(
                                             epoch_i, meta[3][0], meta[3][1])),
                            color_map)

            score = self.running_metrics.get_scores()
            oa = score['Overall Acc: \t']
            precision = score['Precision: \t'][1]
            recall = score['Recall: \t'][1]
            iou = score['Class IoU: \t'][1]
            miou = score['Mean IoU: \t']
            self.running_metrics.reset()
            epoch_num.append(epoch_i)
            all_val_epo_acc.append(oa)
            all_val_epo_iou.append(miou)
            vis.line(
                X=np.column_stack(
                    np.repeat(np.expand_dims(epoch_num, 0), 2, axis=0)),
                Y=np.column_stack((all_val_epo_acc, all_val_epo_iou)),
                opts={
                    'legend': ['val epoch Overall Acc', 'val epoch Mean IoU'],
                    'linecolor': np.array([[255, 0, 0], [0, 255, 0]]),
                    'title': 'Validate Accuracy and IoU'
                },
                win='validate Accuracy and IoU')

            log = '{}: Epoch Val: [{}], ACC: {:.2f}, Recall: {:.2f}, mIoU: {:.4f}' \
                .format(datetime.now(), epoch_i, oa, recall, miou)
            self.logger.info(log)

            state = {
                'epoch': epoch_i,
                'acc': oa,
                'recall': recall,
                'iou': iou,
                'model_G_N': self.OldLabel_generator.state_dict(),
                'model_G_I': self.Image_generator.state_dict(),
                'model_D': self.discriminator.state_dict(),
                'optimizer_G': self.optimizer_G.state_dict(),
                'optimizer_D': self.optimizer_D.state_dict()
            }
            save_path = os.path.join(self.cfg.TRAIN.OUTDIR, 'checkpoints',
                                     '{}epoch.pth'.format(epoch_i))
            torch.save(state, save_path)
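# label2onehot is used throughout the trainer but defined elsewhere in the
# repo; a minimal sketch of the assumed behaviour (integer label maps of
# shape (B, H, W) to one-hot float maps of shape (B, C, H, W)):
def label2onehot_sketch(label, n_class):
    onehot = torch.zeros(label.size(0), n_class, *label.shape[1:],
                         device=label.device)
    return onehot.scatter_(1, label.unsqueeze(1).long(), 1.0)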
class GAIL:

    def __init__(self, exp_dir, exp_thresh, state_dim, action_dim, learn_rate,
                 betas, _device, _gamma, load_weights=False):
        """
        exp_dir      : directory containing the expert episodes
        exp_thresh   : parameter to control number of episodes to load as
                       expert based on returns (lower means more episodes)
        state_dim    : dimension of state
        action_dim   : dimension of action
        learn_rate   : learning rate for optimizer
        _device      : GPU or CPU
        _gamma       : discount factor
        load_weights : load weights from directory
        """
        # storing runtime device
        self.device = _device

        # discount factor
        self.gamma = _gamma

        # Expert trajectory
        self.expert = ExpertTrajectories(exp_dir, exp_thresh, gamma=self.gamma)

        # Defining the actor and its optimizer
        self.actor = ActorNetwork(state_dim).to(self.device)
        self.optim_actor = torch.optim.Adam(self.actor.parameters(),
                                            lr=learn_rate, betas=betas)

        # Defining the discriminator and its optimizer
        self.disc = Discriminator(state_dim, action_dim).to(self.device)
        self.optim_disc = torch.optim.Adam(self.disc.parameters(),
                                           lr=learn_rate, betas=betas)

        if not load_weights:
            self.actor.apply(init_weights)
            self.disc.apply(init_weights)
        else:
            self.load()

        # Loss function criterion
        self.criterion = torch.nn.BCELoss()

    def get_action(self, state):
        """Obtain an action for a given state using the actor network."""
        state = torch.tensor(state, dtype=torch.float,
                             device=self.device).view(1, -1)
        return self.actor(state).cpu().data.numpy().flatten()

    def update(self, n_iter, batch_size=100):
        """Train discriminator and actor on mini-batches."""
        # memory to store per-iteration losses
        disc_losses = np.zeros(n_iter, dtype=float)
        act_losses = np.zeros(n_iter, dtype=float)

        for i in range(n_iter):
            # Get expert state and action batch
            exp_states, exp_actions = self.expert.sample(batch_size)
            exp_states = torch.FloatTensor(exp_states).to(self.device)
            exp_actions = torch.FloatTensor(exp_actions).to(self.device)

            # Get states, and actions using the actor
            states, _ = self.expert.sample(batch_size)
            states = torch.FloatTensor(states).to(self.device)
            actions = self.actor(states)

            # train the discriminator
            self.optim_disc.zero_grad()

            # label tensors (float, to match the BCE targets)
            exp_labels = torch.full((batch_size, 1), 1.0, device=self.device)
            policy_labels = torch.full((batch_size, 1), 0.0, device=self.device)

            # with expert transitions
            prob_exp = self.disc(exp_states, exp_actions)
            exp_loss = self.criterion(prob_exp, exp_labels)

            # with policy actor transitions
            prob_policy = self.disc(states, actions.detach())
            policy_loss = self.criterion(prob_policy, policy_labels)

            # use backprop
            disc_loss = exp_loss + policy_loss
            disc_losses[i] = disc_loss.mean().item()
            disc_loss.backward()
            self.optim_disc.step()

            # train the actor
            self.optim_actor.zero_grad()
            loss_actor = -self.disc(states, actions)
            act_losses[i] = loss_actor.mean().detach().item()
            loss_actor.mean().backward()
            self.optim_actor.step()

        print("Finished training minibatch")
        return act_losses, disc_losses

    def save(self,
             directory='/home/aman/Programming/RL-Project/Deterministic-GAIL/weights',
             name='GAIL'):
        torch.save(self.actor.state_dict(),
                   '{}/{}_actor.pth'.format(directory, name))
        torch.save(self.disc.state_dict(),
                   '{}/{}_discriminator.pth'.format(directory, name))

    def load(self,
             directory='/home/aman/Programming/RL-Project/Deterministic-GAIL/weights',
             name='GAIL'):
        print(os.getcwd())
        self.actor.load_state_dict(
            torch.load('{}/{}_actor.pth'.format(directory, name)))
        self.disc.load_state_dict(
            torch.load('{}/{}_discriminator.pth'.format(directory, name)))

    def set_mode(self, mode="train"):
        if mode == "train":
            self.actor.train()
            self.disc.train()
        else:
            self.actor.eval()
            self.disc.eval()
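# Minimal driving loop for GAIL; the expert directory, threshold, and
# dimensions below are placeholder assumptions.
gail = GAIL(exp_dir='./experts', exp_thresh=100.0, state_dim=8, action_dim=2,
            learn_rate=3e-4, betas=(0.9, 0.999),
            _device=torch.device('cpu'), _gamma=0.99)
gail.set_mode("train")
act_losses, disc_losses = gail.update(n_iter=10, batch_size=100)
gail.save(directory='./weights', name='GAIL')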
class Trainer:

    def __init__(self, corpus_data, *, params):
        self.fast_text = FastText(corpus_data.model).to(GPU)
        self.discriminator = Discriminator(
            params.emb_dim, n_layers=params.d_n_layers,
            n_units=params.d_n_units, drop_prob=params.d_drop_prob,
            drop_prob_input=params.d_drop_prob_input, leaky=params.d_leaky,
            batch_norm=params.d_bn).to(GPU)
        self.ft_optimizer = optim.SGD(self.fast_text.parameters(),
                                      lr=params.ft_lr)
        self.d_optimizer = optim.SGD(self.discriminator.parameters(),
                                     lr=params.d_lr, weight_decay=params.d_wd)
        self.a_optimizer = optim.SGD(
            [{"params": self.fast_text.u.parameters()},
             {"params": self.fast_text.v.parameters()}], lr=params.a_lr)
        self.smooth = params.smooth
        self.loss_fn = nn.BCEWithLogitsLoss(reduction="mean")
        self.corpus_data_queue = _data_queue(corpus_data,
                                             n_threads=params.n_threads,
                                             n_sentences=params.n_sentences,
                                             batch_size=params.ft_bs)
        self.vocab_size = params.vocab_size
        self.d_bs = params.d_bs
        self.split = params.split
        self.align_output = params.align_output

    def fast_text_step(self):
        self.ft_optimizer.zero_grad()
        u_b, v_b = next(self.corpus_data_queue)
        s = self.fast_text(u_b, v_b)
        loss = FastText.loss_fn(s)
        loss.backward()
        self.ft_optimizer.step()
        return loss.item()

    def get_adv_batch(self, *, reverse, fix_embedding):
        vocab_split = int(self.vocab_size * self.split)
        bs_split = int(self.d_bs * self.split)
        x = (torch.randint(0, vocab_split, size=(bs_split,),
                           dtype=torch.long).tolist() +
             torch.randint(vocab_split, self.vocab_size,
                           size=(self.d_bs - bs_split,),
                           dtype=torch.long).tolist())
        if self.align_output:
            x = torch.LongTensor(x).view(self.d_bs, 1).to(GPU)
            if fix_embedding:
                with torch.no_grad():
                    x = self.fast_text.v(x).view(self.d_bs, -1)
            else:
                x = self.fast_text.v(x).view(self.d_bs, -1)
        else:
            x = self.fast_text.model.get_bag(x, self.fast_text.u.weight.device)
            if fix_embedding:
                with torch.no_grad():
                    x = self.fast_text.u(x[0], x[1]).view(self.d_bs, -1)
            else:
                x = self.fast_text.u(x[0], x[1]).view(self.d_bs, -1)
        y = torch.FloatTensor(self.d_bs).to(GPU).uniform_(0.0, self.smooth)
        if reverse:
            y[:bs_split] = 1 - y[:bs_split]
        else:
            y[bs_split:] = 1 - y[bs_split:]
        return x, y

    def discriminator_step(self):
        self.d_optimizer.zero_grad()
        self.discriminator.train()
        with torch.no_grad():
            x, y = self.get_adv_batch(reverse=False, fix_embedding=True)
        y_hat = self.discriminator(x)
        loss = self.loss_fn(y_hat, y)
        loss.backward()
        self.d_optimizer.step()
        return loss.item()

    def adversarial_step(self):
        self.a_optimizer.zero_grad()
        self.discriminator.eval()
        x, y = self.get_adv_batch(reverse=True, fix_embedding=False)
        y_hat = self.discriminator(x)
        loss = self.loss_fn(y_hat, y)
        loss.backward()
        self.a_optimizer.step()
        return loss.item()
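# What get_adv_batch builds for the discriminator: targets smoothed into
# [0, smooth) on one side and 1 minus that on the other, with `reverse`
# flipping which side is treated as "real" so the adversarial step can pull
# the embeddings the other way. Standalone illustration with made-up sizes:
smooth, d_bs, bs_split = 0.1, 8, 4
y = torch.empty(d_bs).uniform_(0.0, smooth)  # all near 0
y[bs_split:] = 1 - y[bs_split:]              # second half near 1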
""" pixels = X.reshape((28, 28)) plt.title(str(digit)) plt.imshow(pixels, cmap='gray') plt.show() X = get_normal_shaped_arrays(60000, (1, 784)) X_train, y_train, X_test, y_test = discriminator_train_test_set( X, X_train, params.DISCRIMINATOR_TRAIN_TEST_SPLIT) discriminator = Discriminator(params.DISCRIMINATOR_BATCH_SIZE, params.DISCRIMINATOR_EPOCHS) discriminator.train(X_train, y_train) print(discriminator.eval(X_test, y_test)) generator = Generator() gan = Gan(generator, discriminator) gan.set_discriminator_trainability(False) gan.show_trainable() X = get_normal_shaped_arrays(100000, (1, 16)) y = [] for _ in range(100000): y.append([0, 1]) y = np.array(y) generator = gan.train_generator(X, y)
class Trainer:

    def __init__(self, corpus_data_0, corpus_data_1, *, params,
                 n_samples=10000000):
        self.skip_gram = [
            SkipGram(corpus_data_0.vocab_size + 1, params.emb_dim).to(GPU),
            SkipGram(corpus_data_1.vocab_size + 1, params.emb_dim).to(GPU)
        ]
        self.discriminator = Discriminator(
            params.emb_dim, n_layers=params.d_n_layers,
            n_units=params.d_n_units, drop_prob=params.d_drop_prob,
            drop_prob_input=params.d_drop_prob_input, leaky=params.d_leaky,
            batch_norm=params.d_bn).to(GPU)
        self.mapping = nn.Linear(params.emb_dim, params.emb_dim, bias=False)
        self.mapping.weight.data.copy_(torch.diag(torch.ones(params.emb_dim)))
        self.mapping = self.mapping.to(GPU)

        self.sg_optimizer, self.sg_scheduler = [], []
        for id in [0, 1]:
            optimizer, scheduler = optimizers.get_sgd_adapt(
                self.skip_gram[id].parameters(), lr=params.sg_lr, mode="max")
            self.sg_optimizer.append(optimizer)
            self.sg_scheduler.append(scheduler)

        self.a_optimizer, self.a_scheduler = [], []
        for id in [0, 1]:
            optimizer, scheduler = optimizers.get_sgd_adapt(
                [{"params": self.skip_gram[id].u.parameters()},
                 {"params": self.skip_gram[id].v.parameters()}],
                lr=params.a_lr, mode="max")
            self.a_optimizer.append(optimizer)
            self.a_scheduler.append(scheduler)

        if params.d_optimizer == "SGD":
            self.d_optimizer, self.d_scheduler = optimizers.get_sgd_adapt(
                self.discriminator.parameters(), lr=params.d_lr, mode="max",
                wd=params.d_wd)
        elif params.d_optimizer == "RMSProp":
            self.d_optimizer, self.d_scheduler = optimizers.get_rmsprop_linear(
                self.discriminator.parameters(), params.n_steps,
                lr=params.d_lr, wd=params.d_wd)
        else:
            raise Exception(f"Optimizer {params.d_optimizer} not found.")

        if params.m_optimizer == "SGD":
            self.m_optimizer, self.m_scheduler = optimizers.get_sgd_adapt(
                self.mapping.parameters(), lr=params.m_lr, mode="max",
                wd=params.m_wd)
        elif params.m_optimizer == "RMSProp":
            self.m_optimizer, self.m_scheduler = optimizers.get_rmsprop_linear(
                self.mapping.parameters(), params.n_steps, lr=params.m_lr,
                wd=params.m_wd)
        else:
            raise Exception(f"Optimizer {params.m_optimizer} not found.")

        self.m_beta = params.m_beta
        self.smooth = params.smooth
        self.loss_fn = nn.BCEWithLogitsLoss(reduction="mean")
        self.corpus_data_queue = [
            _data_queue(corpus_data_0, n_threads=(params.n_threads + 1) // 2,
                        n_sentences=params.n_sentences,
                        batch_size=params.sg_bs),
            _data_queue(corpus_data_1, n_threads=(params.n_threads + 1) // 2,
                        n_sentences=params.n_sentences,
                        batch_size=params.sg_bs)
        ]
        self.sampler = [
            WordSampler(corpus_data_0.dic, n_urns=n_samples,
                        alpha=params.a_sample_factor, top=params.a_sample_top),
            WordSampler(corpus_data_1.dic, n_urns=n_samples,
                        alpha=params.a_sample_factor, top=params.a_sample_top)
        ]
        self.d_bs = params.d_bs

    def skip_gram_step(self):
        losses = []
        for id in [0, 1]:
            self.sg_optimizer[id].zero_grad()
            pos_u_b, pos_v_b, neg_v_b = next(self.corpus_data_queue[id])
            pos_s, neg_s = self.skip_gram[id](pos_u_b, pos_v_b, neg_v_b)
            loss = SkipGram.loss_fn(pos_s, neg_s)
            loss.backward()
            self.sg_optimizer[id].step()
            losses.append(loss.item())
        return losses[0], losses[1]

    def get_adv_batch(self, *, reverse, fix_embedding=False):
        batch = [
            torch.LongTensor([self.sampler[id].sample()
                              for _ in range(self.d_bs)]).view(self.d_bs, 1).to(GPU)
            for id in [0, 1]
        ]
        if fix_embedding:
            with torch.no_grad():
                x = [self.skip_gram[id].u(batch[id]).view(self.d_bs, -1)
                     for id in [0, 1]]
        else:
            x = [self.skip_gram[id].u(batch[id]).view(self.d_bs, -1)
                 for id in [0, 1]]
        x[0] = self.mapping(x[0])
        x = torch.cat(x, 0)
        y = torch.FloatTensor(self.d_bs * 2).to(GPU).uniform_(0.0, self.smooth)
        if reverse:
            y[:self.d_bs] = 1 - y[:self.d_bs]
        else:
            y[self.d_bs:] = 1 - y[self.d_bs:]
        return x, y

    def adversarial_step(self, fix_embedding=False):
        for id in [0, 1]:
            self.a_optimizer[id].zero_grad()
        self.m_optimizer.zero_grad()
        self.discriminator.eval()
        x, y = self.get_adv_batch(reverse=True, fix_embedding=fix_embedding)
        y_hat = self.discriminator(x)
        loss = self.loss_fn(y_hat, y)
        loss.backward()
        for id in [0, 1]:
            self.a_optimizer[id].step()
        self.m_optimizer.step()
        _orthogonalize(self.mapping, self.m_beta)
        return loss.item()

    def discriminator_step(self):
        self.d_optimizer.zero_grad()
        self.discriminator.train()
        with torch.no_grad():
            x, y = self.get_adv_batch(reverse=False)
        y_hat = self.discriminator(x)
        loss = self.loss_fn(y_hat, y)
        loss.backward()
        self.d_optimizer.step()
        return loss.item()

    def scheduler_step(self, metric):
        for id in [0, 1]:
            self.sg_scheduler[id].step(metric)
            self.a_scheduler[id].step(metric)
        # self.d_scheduler.step(metric)
        self.m_scheduler.step(metric)
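# _orthogonalize is imported from elsewhere in this repo; a sketch of the
# usual MUSE-style update such a helper performs, nudging the mapping back
# toward the orthogonal manifold: W <- (1 + beta) * W - beta * (W W^T) W.
def _orthogonalize_sketch(mapping, beta):
    if beta > 0:
        W = mapping.weight.data
        W.copy_((1 + beta) * W - beta * W.mm(W.t().mm(W)))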
def main(args):
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(args.data, splits, args.src_lang,
                                    args.trg_lang, args.fixed_max_len)
    else:
        dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang,
                                             args.trg_lang, args.fixed_max_len)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
    for split in splits:
        print('| {} {} {} examples'.format(args.data, split,
                                           len(dataset.splits[split])))

    g_logging_meters = OrderedDict()
    g_logging_meters['train_loss'] = AverageMeter()
    g_logging_meters['valid_loss'] = AverageMeter()
    g_logging_meters['train_acc'] = AverageMeter()
    g_logging_meters['valid_acc'] = AverageMeter()
    g_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    d_logging_meters = OrderedDict()
    d_logging_meters['train_loss'] = AverageMeter()
    d_logging_meters['valid_loss'] = AverageMeter()
    d_logging_meters['train_acc'] = AverageMeter()
    d_logging_meters['valid_acc'] = AverageMeter()
    d_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    # Set model parameters
    args.encoder_embed_dim = 1000
    args.encoder_layers = 2  # 4
    args.encoder_dropout_out = 0.3
    args.decoder_embed_dim = 1000
    args.decoder_layers = 2  # 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0.3
    args.bidirectional = False

    generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict,
                          use_cuda=use_cuda)
    print("Generator loaded successfully!")
    discriminator_h = Discriminator_h(args.decoder_embed_dim,
                                      args.discriminator_hidden_size,
                                      args.discriminator_linear_size,
                                      args.discriminator_lin_dropout,
                                      use_cuda=use_cuda)
    print("Discriminator_h loaded successfully!")
    discriminator_s = Discriminator_s(args, dataset.src_dict, dataset.dst_dict,
                                      use_cuda=use_cuda)
    print("Discriminator_s loaded successfully!")

    def _calculate_discriminator_loss(tf_scores, ar_scores):
        # two-sided loss on the hidden-state discriminator:
        # -log D(teacher-forced) - log(1 - D(free-running))
        tf_loss = torch.log(tf_scores + 1e-9) * (-1)
        ar_loss = torch.log(1 - ar_scores + 1e-9) * (-1)
        return tf_loss + ar_loss

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator_h = torch.nn.DataParallel(discriminator_h).cuda()
            discriminator_s = torch.nn.DataParallel(discriminator_s).cuda()
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
            discriminator_h.cuda()
            discriminator_s.cuda()
    else:
        discriminator_h.cpu()
        discriminator_s.cpu()
        generator.cpu()

    # adversarial training checkpoints saving path
    if not os.path.exists('checkpoints/professor2'):
        os.makedirs('checkpoints/professor2')
    checkpoints_path = 'checkpoints/professor2/'

    # define loss functions
    g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(),
                                   reduction='sum')
    d_criterion = torch.nn.BCELoss()
    pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(),
                          size_average=True, reduce=True)

    # fix discriminator_s word embeddings (as Wu et al. do)
    for p in discriminator_s.embed_src_tokens.parameters():
        p.requires_grad = False
    for p in discriminator_s.embed_trg_tokens.parameters():
        p.requires_grad = False

    # define optimizers; look the classes up by name on torch.optim
    g_optimizer = getattr(torch.optim, args.g_optimizer)(
        filter(lambda x: x.requires_grad, generator.parameters()),
        args.g_learning_rate)
    d_optimizer_h = getattr(torch.optim, args.d_optimizer)(
        filter(lambda x: x.requires_grad, discriminator_h.parameters()),
        args.d_learning_rate, momentum=args.momentum, nesterov=True)
    d_optimizer_s = getattr(torch.optim, args.d_optimizer)(
        filter(lambda x: x.requires_grad, discriminator_s.parameters()),
        args.d_learning_rate, momentum=args.momentum, nesterov=True)

    # start joint training
    best_dev_loss = math.inf
    num_update = 0

    # main training loop
    for epoch_i in range(1, args.epochs + 1):
        logging.info("At {0}-th epoch.".format(epoch_i))

        seed = args.seed + epoch_i
        torch.manual_seed(seed)

        max_positions_train = (args.fixed_max_len, args.fixed_max_len)

        # Initialize dataloader, starting at batch_offset
        trainloader = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_train,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        # set training mode
        generator.train()
        discriminator_h.train()
        discriminator_s.train()
        update_learning_rate(num_update, 8e4, args.g_learning_rate,
                             args.lr_shrink, g_optimizer)

        for i, sample in enumerate(trainloader):
            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=cuda)

            ## part I: use the policy gradient method to train the generator

            # Policy Gradient Training
            sys_out_batch_PG, p_PG, hidden_list_PG = \
                generator('PG', epoch_i, sample)  # 64 X 50 X 6632
            out_batch_PG = sys_out_batch_PG.contiguous().view(
                -1, sys_out_batch_PG.size(-1))  # (64 * 50) X 6632
            _, prediction = out_batch_PG.topk(1)
            prediction = prediction.squeeze(1)  # 64 * 50 = 3200
            prediction = torch.reshape(
                prediction, sample['net_input']['src_tokens'].shape)  # 64 X 50

            with torch.no_grad():
                reward = discriminator_s(sample['net_input']['src_tokens'],
                                         prediction)  # 64 X 1

            train_trg_batch_PG = sample['target']  # 64 x 50
            pg_loss_PG = pg_criterion(sys_out_batch_PG, train_trg_batch_PG,
                                      reward, use_cuda)
            sample_size_PG = sample['target'].size(0) if args.sentence_avg \
                else sample['ntokens']  # 64
            logging_loss_PG = pg_loss_PG / math.log(2)
            g_logging_meters['train_loss'].update(logging_loss_PG.item(),
                                                  sample_size_PG)
            logging.debug(
                f"G policy gradient loss at batch {i}: {pg_loss_PG.item():.3f}, "
                f"lr={g_optimizer.param_groups[0]['lr']}")
            g_optimizer.zero_grad()
            pg_loss_PG.backward(retain_graph=True)
            torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                           args.clip_norm)
            g_optimizer.step()

            # MLE Training
            sys_out_batch_MLE, p_MLE, hidden_list_MLE = \
                generator('MLE', epoch_i, sample)
            out_batch_MLE = sys_out_batch_MLE.contiguous().view(
                -1, sys_out_batch_MLE.size(-1))  # (64 X 50) X 6632
            train_trg_batch_MLE = sample['target'].view(-1)  # 64*50 = 3200
            loss_MLE = g_criterion(out_batch_MLE, train_trg_batch_MLE)
            sample_size_MLE = sample['target'].size(0) if args.sentence_avg \
                else sample['ntokens']
            nsentences = sample['target'].size(0)
            logging_loss_MLE = loss_MLE.data / sample_size_MLE / math.log(2)
            g_logging_meters['bsz'].update(nsentences)
            g_logging_meters['train_loss'].update(logging_loss_MLE,
                                                  sample_size_MLE)
            logging.debug(
                f"G MLE loss at batch {i}: {g_logging_meters['train_loss'].avg:.3f}, "
                f"lr={g_optimizer.param_groups[0]['lr']}")
            g_optimizer.zero_grad()
            loss_MLE.backward(retain_graph=True)
            # all-reduce grads and rescale by grad_denom
            for p in generator.parameters():
                if p.requires_grad:
                    p.grad.data.div_(sample_size_MLE)
            torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                           args.clip_norm)
            g_optimizer.step()
            num_update += 1

            # part II: train the discriminators

            # discriminator_h
            if num_update % 5 == 0:
                d_MLE = discriminator_h(hidden_list_MLE)
                d_PG = discriminator_h(hidden_list_PG)
                d_loss = _calculate_discriminator_loss(d_MLE, d_PG).sum()
                logging.debug(f"D_h training loss {d_loss} at batch {i}")
                d_optimizer_h.zero_grad()
                d_loss.backward()
                torch.nn.utils.clip_grad_norm_(discriminator_h.parameters(),
                                               args.clip_norm)
                d_optimizer_h.step()

            # discriminator_s
            bsz = sample['target'].size(0)  # batch_size = 64
            src_sentence = sample['net_input']['src_tokens']  # 64 x max-len, i.e. 64 X 50

            # now train with machine translation output, i.e. generator output
            true_sentence = sample['target'].view(-1)  # 64*50 = 3200
            true_labels = torch.ones(sample['target'].size(0)).float()  # 64-length vector

            with torch.no_grad():
                sys_out_batch, p, hidden_list = \
                    generator('MLE', epoch_i, sample)  # 64 X 50 X 6632
            out_batch = sys_out_batch.contiguous().view(
                -1, sys_out_batch.size(-1))  # (64 X 50) X 6632
            _, prediction = out_batch.topk(1)
            prediction = prediction.squeeze(1)  # 64 * 50 = 3200
            fake_labels = torch.zeros(sample['target'].size(0)).float()  # 64-length vector
            fake_sentence = torch.reshape(prediction, src_sentence.shape)  # 64 X 50
            true_sentence = torch.reshape(true_sentence, src_sentence.shape)
            if use_cuda:
                fake_labels = fake_labels.cuda()
                true_labels = true_labels.cuda()

            fake_disc_out = discriminator_s(src_sentence, fake_sentence)  # 64 X 1
            true_disc_out = discriminator_s(src_sentence, true_sentence)
            fake_d_loss = d_criterion(fake_disc_out.squeeze(1), fake_labels)
            true_d_loss = d_criterion(true_disc_out.squeeze(1), true_labels)
            acc = torch.sum(torch.round(fake_disc_out).squeeze(1) ==
                            fake_labels).float() / len(fake_labels)
            d_loss = fake_d_loss + true_d_loss
            d_logging_meters['train_acc'].update(acc)
            d_logging_meters['train_loss'].update(d_loss)
            logging.debug(
                f"D_s training loss {d_logging_meters['train_loss'].avg:.3f}, "
                f"acc {d_logging_meters['train_acc'].avg:.3f} at batch {i}")
            d_optimizer_s.zero_grad()
            d_loss.backward()
            d_optimizer_s.step()

        # validation: set validation mode
        generator.eval()
        discriminator_h.eval()
        discriminator_s.eval()

        # Initialize dataloader
        max_positions_valid = (args.fixed_max_len, args.fixed_max_len)
        valloader = dataset.eval_dataloader(
            'valid',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_valid,
            skip_invalid_size_inputs_valid_test=True,
            descending=True,  # largest batch first to warm the caching allocator
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(valloader):
            with torch.no_grad():
                if use_cuda:
                    # wrap input tensors in cuda tensors
                    sample = utils.make_variable(sample, cuda=cuda)

                # generator validation
                sys_out_batch_test, p_test, hidden_list_test = \
                    generator('test', epoch_i, sample)
                out_batch_test = sys_out_batch_test.contiguous().view(
                    -1, sys_out_batch_test.size(-1))  # (64 X 50) X 6632
                dev_trg_batch = sample['target'].view(-1)  # 64*50 = 3200
                loss_test = g_criterion(out_batch_test, dev_trg_batch)
                sample_size_test = sample['target'].size(0) \
                    if args.sentence_avg else sample['ntokens']
                loss_test = loss_test / sample_size_test / math.log(2)
                g_logging_meters['valid_loss'].update(loss_test,
                                                      sample_size_test)
                logging.debug(f"G dev loss at batch {i}: "
                              f"{g_logging_meters['valid_loss'].avg:.3f}")

                # # discriminator_h validation
                # bsz = sample['target'].size(0)
                # src_sentence = sample['net_input']['src_tokens']
                # # train with half human-translation and half machine translation
                # true_sentence = sample['target']
                # true_labels = torch.ones(sample['target'].size(0)).float()
                # with torch.no_grad():
                #     sys_out_batch_PG, p, hidden_list = generator('test', epoch_i, sample)
                # out_batch = sys_out_batch_PG.contiguous().view(-1, sys_out_batch_PG.size(-1))  # (64 X 50) X 6632
                # _, prediction = out_batch.topk(1)
                # prediction = prediction.squeeze(1)
                # fake_labels = torch.zeros(sample['target'].size(0)).float()
                # fake_sentence = torch.reshape(prediction, src_sentence.shape)  # 64 X 50
                # if use_cuda:
                #     fake_labels = fake_labels.cuda()
                # disc_out = discriminator_h(src_sentence, fake_sentence)
                # d_loss = d_criterion(disc_out.squeeze(1), fake_labels)
                # acc = torch.sum(torch.round(disc_out).squeeze(1) == fake_labels).float() / len(fake_labels)
                # d_logging_meters['valid_acc'].update(acc)
                # d_logging_meters['valid_loss'].update(d_loss)
                # logging.debug(
                #     f"D dev loss {d_logging_meters['valid_loss'].avg:.3f}, acc {d_logging_meters['valid_acc'].avg:.3f} at batch {i}")

        torch.save(generator,
                   open(checkpoints_path +
                        f"sampling_{g_logging_meters['valid_loss'].avg:.3f}.epoch_{epoch_i}.pt",
                        'wb'),
                   pickle_module=dill)

        if g_logging_meters['valid_loss'].avg < best_dev_loss:
            best_dev_loss = g_logging_meters['valid_loss'].avg
            torch.save(generator,
                       open(checkpoints_path + "best_gmodel.pt", 'wb'),
                       pickle_module=dill)
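# Sanity check of the two-sided professor-forcing loss used above, with
# scalar stand-ins for the discriminator_h scores: it is small when
# teacher-forced states score high and free-running states score low.
tf_scores = torch.tensor([0.9])   # D_h on teacher-forced hidden states
ar_scores = torch.tensor([0.1])   # D_h on free-running hidden states
loss = -torch.log(tf_scores + 1e-9) - torch.log(1 - ar_scores + 1e-9)
print(loss)  # tensor([0.2107])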
import torch
import torch.utils.data
import torchvision.datasets as dset
from torchvision import transforms

from discriminator import Discriminator

if __name__ == "__main__":
    saved_state = torch.load("C:\\Users\\ankit\\Workspaces\\CS7150\\FinalProject\\models\\trained_model_Mon_05_45.pth")
    dis = Discriminator(ngpu=1, num_channels=3, num_features=64)
    dis.load_state_dict(saved_state['discriminator'])
    dis.eval()

    dataset = dset.ImageFolder(root="C:\\Users\\ankit\\Workspaces\\CS7150\\data\\imagenet",
                               transform=transforms.Compose([
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
    # Create the dataloader
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)

    images = next(iter(dataloader))  # (image batch, labels)
    out = dis(images[0])
    print(out)  # the original print() was empty; show the discriminator score
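# A single batch gives a noisy estimate of the discriminator's behaviour; a
# small (hypothetical) extension of the script above averages the raw output
# over several batches, reusing `dis` and `dataloader` as defined there:
n_batches = 50
scores = []
with torch.no_grad():
    for i, (imgs, _) in enumerate(dataloader):
        if i >= n_batches:
            break
        scores.append(dis(imgs).view(-1))
print("mean discriminator score over %d batches: %.4f"
      % (n_batches, torch.cat(scores).mean().item()))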
def train_d(args, dataset): logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s', datefmt='%Y-%m-%d %H:%M:%S', level=logging.DEBUG) use_cuda = (torch.cuda.device_count() >= 1) # check checkpoints saving path if not os.path.exists('checkpoints/discriminator'): os.makedirs('checkpoints/discriminator') checkpoints_path = 'checkpoints/discriminator/' logging_meters = OrderedDict() logging_meters['train_loss'] = AverageMeter() logging_meters['train_acc'] = AverageMeter() logging_meters['valid_loss'] = AverageMeter() logging_meters['valid_acc'] = AverageMeter() logging_meters['update_times'] = AverageMeter() # Build model discriminator = Discriminator(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda) # Load generator assert os.path.exists('checkpoints/generator/best_gmodel.pt') generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda) model_dict = generator.state_dict() pretrained_dict = torch.load('checkpoints/generator/best_gmodel.pt') # 1. filter out unnecessary keys pretrained_dict = { k: v for k, v in pretrained_dict.items() if k in model_dict } # 2. overwrite entries in the existing state dict model_dict.update(pretrained_dict) # 3. load the new state dict generator.load_state_dict(model_dict) if use_cuda: if torch.cuda.device_count() > 1: discriminator = torch.nn.DataParallel(discriminator).cuda() # generator = torch.nn.DataParallel(generator).cuda() generator.cuda() else: generator.cuda() discriminator.cuda() else: discriminator.cpu() generator.cpu() criterion = torch.nn.CrossEntropyLoss() # optimizer = eval("torch.optim." + args.d_optimizer)(filter(lambda x: x.requires_grad, discriminator.parameters()), # args.d_learning_rate, momentum=args.momentum, nesterov=True) optimizer = torch.optim.RMSprop( filter(lambda x: x.requires_grad, discriminator.parameters()), 1e-4) lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, patience=0, factor=args.lr_shrink) # Train until the accuracy achieve the define value max_epoch = args.max_epoch or math.inf epoch_i = 1 trg_acc = 0.82 best_dev_loss = math.inf lr = optimizer.param_groups[0]['lr'] # validation set data loader (only prepare once) train = prepare_training_data(args, dataset, 'train', generator, epoch_i, use_cuda) valid = prepare_training_data(args, dataset, 'valid', generator, epoch_i, use_cuda) data_train = DatasetProcessing(data=train, maxlen=args.fixed_max_len) data_valid = DatasetProcessing(data=valid, maxlen=args.fixed_max_len) # main training loop while lr > args.min_d_lr and epoch_i <= max_epoch: logging.info("At {0}-th epoch.".format(epoch_i)) seed = args.seed + epoch_i torch.manual_seed(seed) if args.sample_without_replacement > 0 and epoch_i > 1: train = prepare_training_data(args, dataset, 'train', generator, epoch_i, use_cuda) data_train = DatasetProcessing(data=train, maxlen=args.fixed_max_len) # discriminator training dataloader train_loader = train_dataloader(data_train, batch_size=args.joint_batch_size, seed=seed, epoch=epoch_i, sort_by_source_size=False) valid_loader = eval_dataloader(data_valid, num_workers=4, batch_size=args.joint_batch_size) # set training mode discriminator.train() # reset meters for key, val in logging_meters.items(): if val is not None: val.reset() for i, sample in enumerate(train_loader): if use_cuda: # wrap input tensors in cuda tensors sample = utils.make_variable(sample, cuda=use_cuda) disc_out = discriminator(sample['src_tokens'], sample['trg_tokens']) loss = criterion(disc_out, sample['labels']) _, prediction = 
F.softmax(disc_out, dim=1).topk(1)
            acc = torch.sum(prediction == sample['labels'].unsqueeze(1)).float() / len(sample['labels'])
            logging_meters['train_acc'].update(acc.item())
            logging_meters['train_loss'].update(loss.item())
            logging.debug("D training loss {0:.3f}, acc {1:.3f}, avgAcc {2:.3f}, lr={3} at batch {4}".format(
                logging_meters['train_loss'].avg, acc, logging_meters['train_acc'].avg,
                optimizer.param_groups[0]['lr'], i))
            optimizer.zero_grad()
            loss.backward()
            # use the in-place clip_grad_norm_ (plain clip_grad_norm is deprecated)
            torch.nn.utils.clip_grad_norm_(discriminator.parameters(), args.clip_norm)
            optimizer.step()
            # del src_tokens, trg_tokens, loss, disc_out, labels, prediction, acc
            del disc_out, loss, prediction, acc

        # set validation mode
        discriminator.eval()
        for i, sample in enumerate(valid_loader):
            with torch.no_grad():
                if use_cuda:
                    # wrap input tensors in cuda tensors
                    sample = utils.make_variable(sample, cuda=use_cuda)
                disc_out = discriminator(sample['src_tokens'], sample['trg_tokens'])
                loss = criterion(disc_out, sample['labels'])
                _, prediction = F.softmax(disc_out, dim=1).topk(1)
                acc = torch.sum(prediction == sample['labels'].unsqueeze(1)).float() / len(sample['labels'])
                logging_meters['valid_acc'].update(acc.item())
                logging_meters['valid_loss'].update(loss.item())
                logging.debug("D eval loss {0:.3f}, acc {1:.3f}, avgAcc {2:.3f}, lr={3} at batch {4}".format(
                    logging_meters['valid_loss'].avg, acc, logging_meters['valid_acc'].avg,
                    optimizer.param_groups[0]['lr'], i))
            del disc_out, loss, prediction, acc

        lr_scheduler.step(logging_meters['valid_loss'].avg)

        if logging_meters['valid_acc'].avg >= 0.70:
            torch.save(discriminator.state_dict(),
                       checkpoints_path + "ce_{0:.3f}_acc_{1:.3f}.epoch_{2}.pt".format(
                           logging_meters['valid_loss'].avg, logging_meters['valid_acc'].avg, epoch_i))
        if logging_meters['valid_loss'].avg < best_dev_loss:
            best_dev_loss = logging_meters['valid_loss'].avg
            torch.save(discriminator.state_dict(), checkpoints_path + "best_dmodel.pt")

        # stop pretraining once the discriminator reaches the target accuracy (82%)
        if logging_meters['valid_acc'].avg >= trg_acc:
            return
        epoch_i += 1
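# `train_d` (and the joint training loops elsewhere in this file) track
# statistics through AverageMeter objects whose class is not shown. A standard
# implementation consistent with the .update()/.avg/.reset() usage above:
class AverageMeter:
    """Keeps a running average of a scalar such as a loss or an accuracy."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value, n=1):
        self.val = value
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count if self.count else 0.0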
class SegPoseNet(nn.Module):
    def __init__(self, data_options):
        super(SegPoseNet, self).__init__()
        pose_arch_cfg = data_options['pose_arch_cfg']
        self.width = int(data_options['width'])
        self.height = int(data_options['height'])
        self.channels = int(data_options['channels'])
        self.domains = int(data_options['domains'])
        # note: you need to change these after modifying the network
        self.output_h = 76
        self.output_w = 76
        self.coreModel = Darknet(pose_arch_cfg, self.width, self.height, self.channels, self.domains)
        self.segLayer = PoseSegLayer(data_options)
        self.regLayer = Pose2DLayer(data_options)
        self.discLayer = Discriminator()
        self.training = False

    def forward(self, x, y=None, adapt=False, domains=None):
        outlayers = self.coreModel(x, domains=domains)
        if self.training and adapt:
            # during adaptation, only source-domain samples reach the task heads
            in1 = source_only(outlayers[0], domains)
            in2 = source_only(outlayers[1], domains)
        else:
            in1 = outlayers[0]
            in2 = outlayers[1]
        out3 = self.discLayer(outlayers[2])
        out4 = outlayers[3]
        out5 = outlayers[4]
        out1 = self.segLayer(in1)
        out2 = self.regLayer(in2)
        out_preds = [out1, out2, out3, out4, out5]
        return out_preds

    # NOTE: these override nn.Module.train(mode=True)/eval() with a different
    # signature; they simply toggle all submodules plus the `training` flag.
    def train(self):
        self.coreModel.train()
        self.segLayer.train()
        self.regLayer.train()
        self.discLayer.train()
        self.training = True

    def eval(self):
        self.coreModel.eval()
        self.segLayer.eval()
        self.regLayer.eval()
        self.discLayer.eval()
        self.training = False

    def print_network(self):
        self.coreModel.print_network()

    def load_weights(self, weightfile):
        self.coreModel.load_state_dict(torch.load(weightfile))

    def save_weights(self, weightfile):
        torch.save(self.coreModel.state_dict(), weightfile)
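# `source_only` is referenced in SegPoseNet.forward but not defined in this
# snippet. A plausible sketch, assuming `domains` is a per-sample label tensor
# in which 0 marks source-domain samples (this reading is an assumption):
def source_only(features, domains, source_id=0):
    # keep only the feature maps that belong to source-domain samples
    mask = (domains == source_id)
    return features[mask]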
# Loss and accuracy within the current epoch.
loss1.append(gradient_penalty.item())
loss2.append(disc_fake_source.item())
loss3.append(disc_real_source.item())
loss4.append(disc_real_class.item())
loss5.append(disc_fake_class.item())
acc1.append(accuracy)
if batch_idx % 50 == 0:
    print("[", epoch, batch_idx, "]",
          "%.2f" % np.mean(loss1), "%.2f" % np.mean(loss2),
          "%.2f" % np.mean(loss3), "%.2f" % np.mean(loss4),
          "%.2f" % np.mean(loss5), "%.2f" % np.mean(acc1))

# Test the model after every epoch.
aD.eval()
with torch.no_grad():
    test_accu = []
    for batch_idx, (X_test_batch, Y_test_batch) in enumerate(testloader):
        X_test_batch, Y_test_batch = Variable(X_test_batch).cuda(), Variable(Y_test_batch).cuda()
        _, output = aD(X_test_batch)  # the nested no_grad here was redundant
        prediction = output.data.max(1)[1]  # [1] holds the argmax class indices
        accuracy = (float(prediction.eq(Y_test_batch.data).sum()) / float(batch_size)) * 100.0
        test_accu.append(accuracy)
    accuracy_test = np.mean(test_accu)
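# The `gradient_penalty` term logged above comes from `calc_gradient_penalty`
# (also called in train_D_With_G further below). A typical WGAN-GP
# formulation matching this discriminator's two-headed (source, class)
# output; the interpolation scheme and lambda value are assumptions:
import torch
from torch import autograd

def calc_gradient_penalty(netD, real_data, fake_data, lambda_gp=10.0):
    batch_size = real_data.size(0)
    alpha = torch.rand(batch_size, 1, 1, 1, device=real_data.device)
    interpolates = (alpha * real_data + (1 - alpha) * fake_data).requires_grad_(True)
    disc_out, _ = netD(interpolates)  # penalize via the source head only
    grads = autograd.grad(outputs=disc_out.sum(), inputs=interpolates,
                          create_graph=True, retain_graph=True)[0]
    return lambda_gp * ((grads.view(batch_size, -1).norm(2, dim=1) - 1) ** 2).mean()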
def main(args): use_cuda = (len(args.gpuid) >= 1) print("{0} GPU(s) are available".format(cuda.device_count())) # Load dataset splits = ['train', 'valid'] if data.has_binary_files(args.data, splits): dataset = data.load_dataset(args.data, splits, args.src_lang, args.trg_lang, args.fixed_max_len) else: dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang, args.trg_lang, args.fixed_max_len) if args.src_lang is None or args.trg_lang is None: # record inferred languages in args, so that it's saved in checkpoints args.src_lang, args.trg_lang = dataset.src, dataset.dst print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict))) print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict))) for split in splits: print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split]))) g_logging_meters = OrderedDict() g_logging_meters['train_loss'] = AverageMeter() g_logging_meters['valid_loss'] = AverageMeter() g_logging_meters['train_acc'] = AverageMeter() g_logging_meters['valid_acc'] = AverageMeter() g_logging_meters['bsz'] = AverageMeter() # sentences per batch d_logging_meters = OrderedDict() d_logging_meters['train_loss'] = AverageMeter() d_logging_meters['valid_loss'] = AverageMeter() d_logging_meters['train_acc'] = AverageMeter() d_logging_meters['valid_acc'] = AverageMeter() d_logging_meters['bsz'] = AverageMeter() # sentences per batch # Set model parameters args.encoder_embed_dim = 1000 args.encoder_layers = 2 # 4 args.encoder_dropout_out = 0 args.decoder_embed_dim = 1000 args.decoder_layers = 2 # 4 args.decoder_out_embed_dim = 1000 args.decoder_dropout_out = 0 args.bidirectional = False generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda) print("Generator loaded successfully!") discriminator = Discriminator(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda) print("Discriminator loaded successfully!") g_model_path = 'checkpoints/zhenwarm/generator.pt' assert os.path.exists(g_model_path) # generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda) model_dict = generator.state_dict() model = torch.load(g_model_path) pretrained_dict = model.state_dict() # 1. filter out unnecessary keys pretrained_dict = { k: v for k, v in pretrained_dict.items() if k in model_dict } # 2. overwrite entries in the existing state dict model_dict.update(pretrained_dict) # 3. load the new state dict generator.load_state_dict(model_dict) print("pre-trained Generator loaded successfully!") # # Load discriminator model d_model_path = 'checkpoints/zhenwarm/discri.pt' assert os.path.exists(d_model_path) # generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda) d_model_dict = discriminator.state_dict() d_model = torch.load(d_model_path) d_pretrained_dict = d_model.state_dict() # 1. filter out unnecessary keys d_pretrained_dict = { k: v for k, v in d_pretrained_dict.items() if k in d_model_dict } # 2. overwrite entries in the existing state dict d_model_dict.update(d_pretrained_dict) # 3. 
load the new state dict discriminator.load_state_dict(d_model_dict) print("pre-trained Discriminator loaded successfully!") if use_cuda: if torch.cuda.device_count() > 1: discriminator = torch.nn.DataParallel(discriminator).cuda() generator = torch.nn.DataParallel(generator).cuda() else: generator.cuda() discriminator.cuda() else: discriminator.cpu() generator.cpu() # adversarial training checkpoints saving path if not os.path.exists('checkpoints/myzhencli5'): os.makedirs('checkpoints/myzhencli5') checkpoints_path = 'checkpoints/myzhencli5/' # define loss function g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(), reduction='sum') d_criterion = torch.nn.BCELoss() pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(), size_average=True, reduce=True) # fix discriminator word embedding (as Wu et al. do) for p in discriminator.embed_src_tokens.parameters(): p.requires_grad = False for p in discriminator.embed_trg_tokens.parameters(): p.requires_grad = False # define optimizer g_optimizer = eval("torch.optim." + args.g_optimizer)(filter( lambda x: x.requires_grad, generator.parameters()), args.g_learning_rate) d_optimizer = eval("torch.optim." + args.d_optimizer)( filter(lambda x: x.requires_grad, discriminator.parameters()), args.d_learning_rate, momentum=args.momentum, nesterov=True) # start joint training best_dev_loss = math.inf num_update = 0 # main training loop for epoch_i in range(1, args.epochs + 1): logging.info("At {0}-th epoch.".format(epoch_i)) seed = args.seed + epoch_i torch.manual_seed(seed) max_positions_train = (args.fixed_max_len, args.fixed_max_len) # Initialize dataloader, starting at batch_offset trainloader = dataset.train_dataloader( 'train', max_tokens=args.max_tokens, max_sentences=args.joint_batch_size, max_positions=max_positions_train, # seed=seed, epoch=epoch_i, sample_without_replacement=args.sample_without_replacement, sort_by_source_size=(epoch_i <= args.curriculum), shard_id=args.distributed_rank, num_shards=args.distributed_world_size, ) # reset meters for key, val in g_logging_meters.items(): if val is not None: val.reset() for key, val in d_logging_meters.items(): if val is not None: val.reset() for i, sample in enumerate(trainloader): # set training mode generator.train() discriminator.train() update_learning_rate(num_update, 8e4, args.g_learning_rate, args.lr_shrink, g_optimizer) if use_cuda: # wrap input tensors in cuda tensors sample = utils.make_variable(sample, cuda=cuda) ## part I: use gradient policy method to train the generator # use policy gradient training when random.random() > 50% if random.random() >= 0.5: print("Policy Gradient Training") sys_out_batch = generator(sample) # 64 X 50 X 6632 out_batch = sys_out_batch.contiguous().view( -1, sys_out_batch.size(-1)) # (64 * 50) X 6632 _, prediction = out_batch.topk(1) prediction = prediction.squeeze(1) # 64*50 = 3200 prediction = torch.reshape( prediction, sample['net_input']['src_tokens'].shape) # 64 X 50 with torch.no_grad(): reward = discriminator(sample['net_input']['src_tokens'], prediction) # 64 X 1 train_trg_batch = sample['target'] # 64 x 50 pg_loss = pg_criterion(sys_out_batch, train_trg_batch, reward, use_cuda) sample_size = sample['target'].size( 0) if args.sentence_avg else sample['ntokens'] # 64 logging_loss = pg_loss / math.log(2) g_logging_meters['train_loss'].update(logging_loss.item(), sample_size) logging.debug( f"G policy gradient loss at batch {i}: {pg_loss.item():.3f}, lr={g_optimizer.param_groups[0]['lr']}" ) g_optimizer.zero_grad() pg_loss.backward() 
torch.nn.utils.clip_grad_norm_(generator.parameters(), args.clip_norm) g_optimizer.step() else: # MLE training print("MLE Training") sys_out_batch = generator(sample) out_batch = sys_out_batch.contiguous().view( -1, sys_out_batch.size(-1)) # (64 X 50) X 6632 train_trg_batch = sample['target'].view(-1) # 64*50 = 3200 loss = g_criterion(out_batch, train_trg_batch) sample_size = sample['target'].size( 0) if args.sentence_avg else sample['ntokens'] nsentences = sample['target'].size(0) logging_loss = loss.data / sample_size / math.log(2) g_logging_meters['bsz'].update(nsentences) g_logging_meters['train_loss'].update(logging_loss, sample_size) logging.debug( f"G MLE loss at batch {i}: {g_logging_meters['train_loss'].avg:.3f}, lr={g_optimizer.param_groups[0]['lr']}" ) g_optimizer.zero_grad() loss.backward() # all-reduce grads and rescale by grad_denom for p in generator.parameters(): if p.requires_grad: p.grad.data.div_(sample_size) torch.nn.utils.clip_grad_norm_(generator.parameters(), args.clip_norm) g_optimizer.step() num_update += 1 # part II: train the discriminator if num_update % 5 == 0: bsz = sample['target'].size(0) # batch_size = 64 src_sentence = sample['net_input'][ 'src_tokens'] # 64 x max-len i.e 64 X 50 # now train with machine translation output i.e generator output true_sentence = sample['target'].view(-1) # 64*50 = 3200 true_labels = Variable( torch.ones( sample['target'].size(0)).float()) # 64 length vector with torch.no_grad(): sys_out_batch = generator(sample) # 64 X 50 X 6632 out_batch = sys_out_batch.contiguous().view( -1, sys_out_batch.size(-1)) # (64 X 50) X 6632 _, prediction = out_batch.topk(1) prediction = prediction.squeeze(1) # 64 * 50 = 6632 fake_labels = Variable( torch.zeros( sample['target'].size(0)).float()) # 64 length vector fake_sentence = torch.reshape(prediction, src_sentence.shape) # 64 X 50 true_sentence = torch.reshape(true_sentence, src_sentence.shape) if use_cuda: fake_labels = fake_labels.cuda() true_labels = true_labels.cuda() # fake_disc_out = discriminator(src_sentence, fake_sentence) # 64 X 1 # true_disc_out = discriminator(src_sentence, true_sentence) # # fake_d_loss = d_criterion(fake_disc_out.squeeze(1), fake_labels) # true_d_loss = d_criterion(true_disc_out.squeeze(1), true_labels) # # fake_acc = torch.sum(torch.round(fake_disc_out).squeeze(1) == fake_labels).float() / len(fake_labels) # true_acc = torch.sum(torch.round(true_disc_out).squeeze(1) == true_labels).float() / len(true_labels) # acc = (fake_acc + true_acc) / 2 # # d_loss = fake_d_loss + true_d_loss if random.random() > 0.5: fake_disc_out = discriminator(src_sentence, fake_sentence) fake_d_loss = d_criterion(fake_disc_out.squeeze(1), fake_labels) fake_acc = torch.sum( torch.round(fake_disc_out).squeeze(1) == fake_labels).float() / len(fake_labels) d_loss = fake_d_loss acc = fake_acc else: true_disc_out = discriminator(src_sentence, true_sentence) true_d_loss = d_criterion(true_disc_out.squeeze(1), true_labels) true_acc = torch.sum( torch.round(true_disc_out).squeeze(1) == true_labels).float() / len(true_labels) d_loss = true_d_loss acc = true_acc d_logging_meters['train_acc'].update(acc) d_logging_meters['train_loss'].update(d_loss) logging.debug( f"D training loss {d_logging_meters['train_loss'].avg:.3f}, acc {d_logging_meters['train_acc'].avg:.3f} at batch {i}" ) d_optimizer.zero_grad() d_loss.backward() d_optimizer.step() if num_update % 10000 == 0: # validation # set validation mode generator.eval() discriminator.eval() # Initialize dataloader max_positions_valid = 
(args.fixed_max_len, args.fixed_max_len) valloader = dataset.eval_dataloader( 'valid', max_tokens=args.max_tokens, max_sentences=args.joint_batch_size, max_positions=max_positions_valid, skip_invalid_size_inputs_valid_test=True, descending= True, # largest batch first to warm the caching allocator shard_id=args.distributed_rank, num_shards=args.distributed_world_size, ) # reset meters for key, val in g_logging_meters.items(): if val is not None: val.reset() for key, val in d_logging_meters.items(): if val is not None: val.reset() for i, sample in enumerate(valloader): with torch.no_grad(): if use_cuda: # wrap input tensors in cuda tensors sample = utils.make_variable(sample, cuda=cuda) # generator validation sys_out_batch = generator(sample) out_batch = sys_out_batch.contiguous().view( -1, sys_out_batch.size(-1)) # (64 X 50) X 6632 dev_trg_batch = sample['target'].view( -1) # 64*50 = 3200 loss = g_criterion(out_batch, dev_trg_batch) sample_size = sample['target'].size( 0) if args.sentence_avg else sample['ntokens'] loss = loss / sample_size / math.log(2) g_logging_meters['valid_loss'].update( loss, sample_size) logging.debug( f"G dev loss at batch {i}: {g_logging_meters['valid_loss'].avg:.3f}" ) # discriminator validation bsz = sample['target'].size(0) src_sentence = sample['net_input']['src_tokens'] # train with half human-translation and half machine translation true_sentence = sample['target'] true_labels = Variable( torch.ones(sample['target'].size(0)).float()) with torch.no_grad(): sys_out_batch = generator(sample) out_batch = sys_out_batch.contiguous().view( -1, sys_out_batch.size(-1)) # (64 X 50) X 6632 _, prediction = out_batch.topk(1) prediction = prediction.squeeze(1) # 64 * 50 = 6632 fake_labels = Variable( torch.zeros(sample['target'].size(0)).float()) fake_sentence = torch.reshape( prediction, src_sentence.shape) # 64 X 50 true_sentence = torch.reshape(true_sentence, src_sentence.shape) if use_cuda: fake_labels = fake_labels.cuda() true_labels = true_labels.cuda() fake_disc_out = discriminator(src_sentence, fake_sentence) # 64 X 1 true_disc_out = discriminator(src_sentence, true_sentence) fake_d_loss = d_criterion(fake_disc_out.squeeze(1), fake_labels) true_d_loss = d_criterion(true_disc_out.squeeze(1), true_labels) d_loss = fake_d_loss + true_d_loss fake_acc = torch.sum( torch.round(fake_disc_out).squeeze(1) == fake_labels).float() / len(fake_labels) true_acc = torch.sum( torch.round(true_disc_out).squeeze(1) == true_labels).float() / len(true_labels) acc = (fake_acc + true_acc) / 2 d_logging_meters['valid_acc'].update(acc) d_logging_meters['valid_loss'].update(d_loss) logging.debug( f"D dev loss {d_logging_meters['valid_loss'].avg:.3f}, acc {d_logging_meters['valid_acc'].avg:.3f} at batch {i}" ) # torch.save(discriminator, # open(checkpoints_path + f"numupdate_{num_update/10000}k.discri_{d_logging_meters['valid_loss'].avg:.3f}.pt",'wb'), pickle_module=dill) # if d_logging_meters['valid_loss'].avg < best_dev_loss: # best_dev_loss = d_logging_meters['valid_loss'].avg # torch.save(discriminator, open(checkpoints_path + "best_dmodel.pt", 'wb'), pickle_module=dill) torch.save( generator, open( checkpoints_path + f"numupdate_{num_update/10000}k.joint_{g_logging_meters['valid_loss'].avg:.3f}.pt", 'wb'), pickle_module=dill)
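# `update_learning_rate(num_update, 8e4, args.g_learning_rate, args.lr_shrink,
# g_optimizer)` is called once per generator update above but is not defined
# in this snippet. One plausible reading of its arguments is a step decay that
# shrinks the base rate after a fixed number of updates; this reconstruction
# is an assumption:
def update_learning_rate(num_update, threshold, base_lr, lr_shrink, optimizer):
    lr = base_lr if num_update < threshold else base_lr * lr_shrink
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr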
targeted_model = torch.nn.DataParallel(resnet_model.__dict__['resnet32']())
targeted_model.cuda()
targeted_model.load_state_dict(checkpoint['state_dict'])
targeted_model.eval()

# load the generator of the adversarial GAN
pretrained_generator_path = './models/lp_pretrained/netG_rl_epoch_20.pth'
pretrained_G = Generator().to(device)
pretrained_G.load_state_dict(torch.load(pretrained_generator_path))
pretrained_G.eval()

# load the discriminator of the adversarial GAN
pretrained_discriminator_path = './models/lp_pretrained/netDisc_rl_epoch_20.pth'
pretrained_Disc = Discriminator().to(device)
pretrained_Disc.load_state_dict(torch.load(pretrained_discriminator_path))
pretrained_Disc.eval()

# load the Pixel Valuation network
pretrained_pvrl_path = './models/lp_pretrained/netPv_rl_epoch_20.pth'
pretrained_PV = PVRL().to(device)
pretrained_PV.load_state_dict(torch.load(pretrained_pvrl_path))
pretrained_PV.eval()

# test adversarial examples on the CIFAR10 training dataset
cifar_dataset = torchvision.datasets.CIFAR10('./data', train=True, transform=transforms.ToTensor(), download=True)
train_dataloader = DataLoader(cifar_dataset, batch_size=batch_size, shuffle=False,
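# The DataLoader call above is truncated in the source. Once the loader is
# built, a common pattern with perturbation generators of this kind is to add
# a bounded residual to the clean image and measure how often the targeted
# model is fooled; the 0.3 budget and the clamping scheme below are
# assumptions, not confirmed by this codebase:
for images, labels in train_dataloader:
    images, labels = images.to(device), labels.to(device)
    perturbation = torch.clamp(pretrained_G(images), -0.3, 0.3)
    adv_images = torch.clamp(images + perturbation, 0.0, 1.0)
    preds = targeted_model(adv_images).argmax(dim=1)
    print("batch fooling rate: %.2f%%" % (100.0 * (preds != labels).float().mean().item()))
    break  # one batch, for illustration only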
def semi_main(options): print('\nSemi-Supervised Learning!\n') # 1. Make sure the options are valid argparse CLI options indeed assert isinstance(options, argparse.Namespace) # 2. Set up the logger logging.basicConfig(level=str(options.loglevel).upper()) # 3. Make sure the output dir `outf` exists _check_out_dir(options) # 4. Set the random state _set_random_state(options) # 5. Configure CUDA and Cudnn, set the global `device` for PyTorch device = _configure_cuda(options) # 6. Prepare the datasets and split it for semi-supervised learning if options.dataset != 'cifar10': raise NotImplementedError( 'Semi-supervised learning only support CIFAR10 dataset at the moment!' ) test_data_loader, semi_data_loader, train_data_loader = _prepare_semi_dataset( options) # 7. Set the parameters ngpu = int(options.ngpu) # num of GPUs nz = int( options.nz) # size of latent vector, also the number of the generators ngf = int(options.ngf) # depth of feature maps through G ndf = int(options.ndf) # depth of feature maps through D nc = int(options.nc ) # num of channels of the input images, 3 indicates color images M = int(options.mcmc) # num of SGHMC chains run concurrently nd = int(options.nd) # num of discriminators nsetz = int(options.nsetz) # num of noise batches # 8. Special preparations for Bayesian GAN for Generators # In order to inject the SGHMAC into the training process, instead of pause the gradient descent at # each training step, which can be easily defined with static computation graph(Tensorflow), in PyTorch, # we have to move the Generator Sampling to the very beginning of the whole training process, and use # a trick that initializing all of the generators explicitly for later usages. Generator_chains = [] for _ in range(nsetz): for __ in range(M): netG = Generator(ngpu, nz, ngf, nc).to(device) netG.apply(weights_init) Generator_chains.append(netG) logging.info( f'Showing the first generator of the Generator chain: \n {Generator_chains[0]}\n' ) # 9. Special preparations for Bayesian GAN for Discriminators assert options.dataset == 'cifar10', 'Semi-supervised learning only support CIFAR10 dataset at the moment!' num_class = 10 + 1 # To simplify the implementation we only consider the situation of 1 discriminator # if nd <= 1: # netD = Discriminator(ngpu, ndf, nc, num_class=num_class).to(device) # netD.apply(weights_init) # else: # Discriminator_chains = [] # for _ in range(nd): # for __ in range(M): # netD = Discriminator(ngpu, ndf, nc, num_class=num_class).to(device) # netD.apply(weights_init) # Discriminator_chains.append(netD) netD = Discriminator(ngpu, ndf, nc, num_class=num_class).to(device) netD.apply(weights_init) logging.info(f'Showing the Discriminator model: \n {netD}\n') # 10. Loss function criterion = nn.CrossEntropyLoss() all_criterion = ComplementCrossEntropyLoss(except_index=0, device=device) # 11. Set up optimizers optimizerG_chains = [ optim.Adam(netG.parameters(), lr=options.lr, betas=(options.beta1, 0.999)) for netG in Generator_chains ] # optimizerD_chains = [ # optim.Adam(netD.parameters(), lr=options.lr, betas=(options.beta1, 0.999)) for netD in Discriminator_chains # ] optimizerD = optim.Adam(netD.parameters(), lr=options.lr, betas=(options.beta1, 0.999)) import math # 12. Set up the losses for priors and noises gprior = PriorLoss(prior_std=1., total=500.) gnoise = NoiseLoss(params=Generator_chains[0].parameters(), device=device, scale=math.sqrt(2 * options.alpha / options.lr), total=500.) dprior = PriorLoss(prior_std=1., total=50000.) 
dnoise = NoiseLoss(params=netD.parameters(), device=device, scale=math.sqrt(2 * options.alpha * options.lr), total=50000.) gprior.to(device=device) gnoise.to(device=device) dprior.to(device=device) dnoise.to(device=device) # In order to let G condition on a specific noise, we attach the noise to a fixed Tensor fixed_noise = torch.FloatTensor(options.batchSize, options.nz, 1, 1).normal_(0, 1).to(device=device) inputT = torch.FloatTensor(options.batchSize, 3, options.imageSize, options.imageSize).to(device=device) noiseT = torch.FloatTensor(options.batchSize, options.nz, 1, 1).to(device=device) labelT = torch.FloatTensor(options.batchSize).to(device=device) real_label = 1 fake_label = 0 # 13. Transfer all the tensors and modules to GPU if applicable # for netD in Discriminator_chains: # netD.to(device=device) netD.to(device=device) for netG in Generator_chains: netG.to(device=device) criterion.to(device=device) all_criterion.to(device=device) # ======================== # === Training Process === # ======================== # Lists to keep track of progress img_list = [] G_losses = [] D_losses = [] stats = [] iters = 0 try: print("\nStarting Training Loop...\n") for epoch in range(options.niter): top1 = Metrics() for i, data in enumerate(train_data_loader, 0): # ################## # Train with real # ################## netD.zero_grad() real_cpu = data[0].to(device) batch_size = real_cpu.size(0) # label = torch.full((batch_size,), real_label, device=device) inputT.resize_as_(real_cpu).copy_(real_cpu) labelT.resize_(batch_size).fill_(real_label) inputv = torch.autograd.Variable(inputT) labelv = torch.autograd.Variable(labelT) output = netD(inputv) errD_real = all_criterion(output) errD_real.backward() D_x = 1 - torch.nn.functional.softmax( output).data[:, 0].mean().item() # ################## # Train with fake # ################## fake_images = [] for i_z in range(nsetz): noiseT.resize_(batch_size, nz, 1, 1).normal_( 0, 1) # prior, sample from N(0, 1) distribution noisev = torch.autograd.Variable(noiseT) for m in range(M): idx = i_z * M + m netG = Generator_chains[idx] _fake = netG(noisev) fake_images.append(_fake) # output = torch.stack(fake_images) fake = torch.cat(fake_images) output = netD(fake.detach()) labelv = torch.autograd.Variable( torch.LongTensor(fake.data.shape[0]).to( device=device).fill_(fake_label)) errD_fake = criterion(output, labelv) errD_fake.backward() D_G_z1 = 1 - torch.nn.functional.softmax( output).data[:, 0].mean().item() # ################## # Semi-supervised learning # ################## for ii, (input_sup, target_sup) in enumerate(semi_data_loader): input_sup, target_sup = input_sup.to( device=device), target_sup.to(device=device) break input_sup_v = input_sup.to(device=device) target_sup_v = (target_sup + 1).to(device=device) output_sup = netD(input_sup_v) err_sup = criterion(output_sup, target_sup_v) err_sup.backward() pred1 = accuracy(output_sup.data, target_sup + 1, topk=(1, ))[0] top1.update(value=pred1.item(), N=input_sup.size(0)) errD_prior = dprior(netD.parameters()) errD_prior.backward() errD_noise = dnoise(netD.parameters()) errD_noise.backward() errD = errD_real + errD_fake + err_sup + errD_prior + errD_noise optimizerD.step() # ################## # Sample and construct generator(s) # ################## for netG in Generator_chains: netG.zero_grad() labelv = torch.autograd.Variable( torch.FloatTensor(fake.data.shape[0]).to( device=device).fill_(real_label)) output = netD(fake) errG = all_criterion(output) for netG in Generator_chains: errG = errG + 
gprior(netG.parameters()) errG = errG + gnoise(netG.parameters()) errG.backward() D_G_z2 = 1 - torch.nn.functional.softmax( output).data[:, 0].mean().item() for optimizerG in optimizerG_chains: optimizerG.step() # ################## # Evaluate testing accuracy # ################## # Pause and compute the test accuracy after every 10 times of the notefreq if iters % 10 * int(options.notefreq) == 0: # get test accuracy on train and test netD.eval() compute_test_accuracy(discriminator=netD, testing_data_loader=test_data_loader, device=device) netD.train() # ################## # Note down # ################## # Report status for the current iteration training_status = f"[{epoch}/{options.niter}][{i}/{len(train_data_loader)}] Loss_D: {errD.item():.4f} " \ f"Loss_G: " \ f"{errG.item():.4f} D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f}/{D_G_z2:.4f}" \ f" | Acc {top1.value:.1f} / {top1.mean:.1f}" print(training_status) # Save samples to disk if i % int(options.notefreq) == 0: vutils.save_image( real_cpu, f"{options.outf}/real_samples_epoch_{epoch:{0}{3}}_{i}.png", normalize=True) for _iz in range(nsetz): for _m in range(M): gidx = _iz * M + _m netG = Generator_chains[gidx] fake = netG(fixed_noise) vutils.save_image( fake.detach(), f"{options.outf}/fake_samples_epoch_{epoch:{0}{3}}_{i}_z{_iz}_m{_m}.png", normalize=True) # Save Losses statistics for post-mortem G_losses.append(errG.item()) D_losses.append(errD.item()) stats.append(training_status) # # Check how the generator is doing by saving G's output on fixed_noise # if (iters % 500 == 0) or ((epoch == options.niter - 1) and (i == len(data_loader) - 1)): # with torch.no_grad(): # fake = netG(fixed_noise).detach().cpu() # img_list.append(vutils.make_grid(fake, padding=2, normalize=True)) iters += 1 # TODO: find an elegant way to support saving checkpoints in Bayesian GAN context except Exception as e: print(e) # save training stats no matter what kind of errors occur in the processes _save_stats(statistic=G_losses, save_name='G_losses', options=options) _save_stats(statistic=D_losses, save_name='D_losses', options=options) _save_stats(statistic=stats, save_name='Training_stats', options=options)
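# `_save_stats` persists the loss histories even when training aborts with an
# exception. The helper is not shown in this snippet; a minimal sketch
# consistent with its call signature (pickling into the CLI output directory
# `options.outf` is an assumption):
import os
import pickle

def _save_stats(statistic, save_name, options):
    path = os.path.join(options.outf, save_name + '.pkl')
    with open(path, 'wb') as f:
        pickle.dump(statistic, f)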
class AttnGAN: def __init__(self, damsm, device=DEVICE): self.gen = Generator(device) self.disc = Discriminator(device) self.damsm = damsm.to(device) self.damsm.txt_enc.eval(), self.damsm.img_enc.eval() freeze_params_(self.damsm.txt_enc), freeze_params_(self.damsm.img_enc) self.device = device self.gen.apply(init_weights), self.disc.apply(init_weights) self.gen_optimizer = torch.optim.Adam(self.gen.parameters(), lr=GENERATOR_LR, betas=(0.5, 0.999)) self.discriminators = [self.disc.d64, self.disc.d128, self.disc.d256] self.disc_optimizers = [ torch.optim.Adam(d.parameters(), lr=DISCRIMINATOR_LR, betas=(0.5, 0.999)) for d in self.discriminators ] #@torch.no_grad() def train(self, dataset, epoch, batch_size=GAN_BATCH, test_sample_every=5, hist_avg=False, evaluator=None): start_time = time.strftime("%Y-%m-%d-%H-%M", time.gmtime()) os.makedirs(f'{OUT_DIR}/{start_time}') # print('cun') # for e in tqdm(range(epoch), desc='Epochs', dynamic_ncols=True): # self.gen.eval() # generated_samples = [resolution.unsqueeze(0) for resolution in self.sample_test_set(dataset)] # self._save_generated(generated_samples, e, f'{OUT_DIR}/{start_time}') # # return if hist_avg: avg_g_params = deepcopy(list(p.data for p in self.gen.parameters())) loader_config = { 'batch_size': batch_size, 'shuffle': True, 'drop_last': True, 'collate_fn': dataset.collate_fn } train_loader = DataLoader(dataset.train, **loader_config) metrics = { 'IS': [], 'FID': [], 'loss': { 'g': [], 'd': [] }, 'accuracy': { 'real': [], 'fake': [], 'mismatched': [], 'unconditional_real': [], 'unconditional_fake': [] } } if evaluator is not None: evaluator = evaluator(dataset, self.damsm.img_enc.inception_model, batch_size, self.device) noise = torch.FloatTensor(batch_size, D_Z).to(self.device) gen_updates = 0 self.disc.train() for e in tqdm(range(epoch), desc='Epochs', dynamic_ncols=True): self.gen.train(), self.disc.train() g_loss = 0 w_loss = 0 s_loss = 0 kl_loss = 0 g_stage_loss = np.zeros(3, dtype=float) d_loss = np.zeros(3, dtype=float) real_acc = np.zeros(3, dtype=float) fake_acc = np.zeros(3, dtype=float) mismatched_acc = np.zeros(3, dtype=float) uncond_real_acc = np.zeros(3, dtype=float) uncond_fake_acc = np.zeros(3, dtype=float) disc_skips = np.zeros(3, dtype=int) train_pbar = tqdm(train_loader, desc='Training', leave=False, dynamic_ncols=True) for batch in train_pbar: real_imgs = [batch['img64'], batch['img128'], batch['img256']] with torch.no_grad(): word_embs, sent_embs = self.damsm.txt_enc(batch['caption']) attn_mask = torch.tensor(batch['caption']).to( self.device) == dataset.vocab[END_TOKEN] # Generate images noise.data.normal_(0, 1) generated, att, mu, logvar = self.gen(noise, sent_embs, word_embs, attn_mask) # Discriminator loss (with label smoothing) batch_d_loss, batch_real_acc, batch_fake_acc, batch_mismatched_acc, batch_uncond_real_acc, batch_uncond_fake_acc, batch_disc_skips = self.discriminator_step( real_imgs, generated, sent_embs, 0.1) d_grad_norm = [grad_norm(d) for d in self.discriminators] d_loss += batch_d_loss real_acc += batch_real_acc fake_acc += batch_fake_acc mismatched_acc += batch_mismatched_acc uncond_real_acc += batch_uncond_real_acc uncond_fake_acc += batch_uncond_fake_acc disc_skips += batch_disc_skips # Generator loss batch_g_losses = self.generator_step(generated, word_embs, sent_embs, mu, logvar, batch['label']) g_total, batch_g_stage_loss, batch_w_loss, batch_s_loss, batch_kl_loss = batch_g_losses g_stage_loss += batch_g_stage_loss w_loss += batch_w_loss s_loss += (batch_s_loss) kl_loss += (batch_kl_loss) 
gen_updates += 1 avg_g_loss = g_total.item() / batch_size g_loss += float(avg_g_loss) if hist_avg: for p, avg_p in zip(self.gen.parameters(), avg_g_params): avg_p.mul_(0.999).add_(0.001, p.data) if gen_updates % 1000 == 0: tqdm.write( 'Replacing generator weights with their moving average' ) for p, avg_p in zip(self.gen.parameters(), avg_g_params): p.data.copy_(avg_p) train_pbar.set_description( f'Training (G: {grad_norm(self.gen):.2f} ' f'D64: {d_grad_norm[0]:.2f} ' f'D128: {d_grad_norm[1]:.2f} ' f'D256: {d_grad_norm[2]:.2f})') batches = len(train_loader) g_loss /= batches g_stage_loss /= batches w_loss /= batches s_loss /= batches kl_loss /= batches d_loss /= batches real_acc /= batches fake_acc /= batches mismatched_acc /= batches uncond_real_acc /= batches uncond_fake_acc /= batches metrics['loss']['g'].append(g_loss) metrics['loss']['d'].append(d_loss) metrics['accuracy']['real'].append(real_acc) metrics['accuracy']['fake'].append(fake_acc) metrics['accuracy']['mismatched'].append(mismatched_acc) metrics['accuracy']['unconditional_real'].append(uncond_real_acc) metrics['accuracy']['unconditional_fake'].append(uncond_fake_acc) sep = '_' * 10 tqdm.write(f'{sep}Epoch {e}{sep}') if e % test_sample_every == 0: self.gen.eval() generated_samples = [ resolution.unsqueeze(0) for resolution in self.sample_test_set(dataset) ] self._save_generated(generated_samples, e, f'{OUT_DIR}/{start_time}') if evaluator is not None: scores = evaluator.evaluate(self) for k, v in scores.items(): metrics[k].append(v) tqdm.write(f'{k}: {v:.2f}') tqdm.write( f'Generator avg loss: total({g_loss:.3f}) ' f'stage0({g_stage_loss[0]:.3f}) stage1({g_stage_loss[1]:.3f}) stage2({g_stage_loss[2]:.3f}) ' f'w({w_loss:.3f}) s({s_loss:.3f}) kl({kl_loss:.3f})') for i, _ in enumerate(self.discriminators): tqdm.write(f'Discriminator{i} avg: ' f'loss({d_loss[i]:.3f}) ' f'r-acc({real_acc[i]:.3f}) ' f'f-acc({fake_acc[i]:.3f}) ' f'm-acc({mismatched_acc[i]:.3f}) ' f'ur-acc({uncond_real_acc[i]:.3f}) ' f'uf-acc({uncond_fake_acc[i]:.3f}) ' f'skips({disc_skips[i]})') return metrics def sample_test_set(self, dataset, nb_samples=8, nb_captions=2, noise_variations=2): subset = dataset.test sample_indices = np.random.choice(len(subset), nb_samples, replace=False) cap_indices = np.random.choice(10, nb_captions, replace=False) texts = [ subset.data[f'caption_{cap_idx}'].iloc[sample_idx] for sample_idx in sample_indices for cap_idx in cap_indices ] generated_samples = [ self.generate_from_text(texts, dataset) for _ in range(noise_variations) ] combined_img64 = torch.FloatTensor() combined_img128 = torch.FloatTensor() combined_img256 = torch.FloatTensor() for noise_variant in generated_samples: noise_var_img64 = torch.FloatTensor() noise_var_img128 = torch.FloatTensor() noise_var_img256 = torch.FloatTensor() for i in range(nb_samples): # rows: samples, columns: captions * noise variants row64 = torch.cat([ noise_variant[0][i * nb_captions + j] for j in range(nb_captions) ], dim=-1).cpu() row128 = torch.cat([ noise_variant[1][i * nb_captions + j] for j in range(nb_captions) ], dim=-1).cpu() row256 = torch.cat([ noise_variant[2][i * nb_captions + j] for j in range(nb_captions) ], dim=-1).cpu() noise_var_img64 = torch.cat([noise_var_img64, row64], dim=-2) noise_var_img128 = torch.cat([noise_var_img128, row128], dim=-2) noise_var_img256 = torch.cat([noise_var_img256, row256], dim=-2) combined_img64 = torch.cat([combined_img64, noise_var_img64], dim=-1) combined_img128 = torch.cat([combined_img128, noise_var_img128], dim=-1) combined_img256 = 
torch.cat([combined_img256, noise_var_img256], dim=-1) return combined_img64, combined_img128, combined_img256 @staticmethod def KL_loss(mu, logvar): loss = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) loss = torch.mean(loss).mul_(-0.5) return loss def generator_step(self, generated_imgs, word_embs, sent_embs, mu, logvar, class_labels): self.gen.zero_grad() avg_stage_g_loss = [0, 0, 0] local_features, global_features = self.damsm.img_enc( generated_imgs[-1]) batch_size = sent_embs.size(0) match_labels = torch.LongTensor(range(batch_size)).to(self.device) w1_loss, w2_loss, _ = self.damsm.words_loss(local_features, word_embs, class_labels, match_labels) w_loss = (w1_loss + w2_loss) * LAMBDA s1_loss, s2_loss = self.damsm.sentence_loss(global_features, sent_embs, class_labels, match_labels) s_loss = (s1_loss + s2_loss) * LAMBDA kl_loss = self.KL_loss(mu, logvar) g_total = w_loss + s_loss + kl_loss for i, d in enumerate(self.discriminators): features = d(generated_imgs[i]) fake_logits = d.logit(features, sent_embs) real_labels = torch.ones_like(fake_logits).to(self.device) disc_error = F.binary_cross_entropy_with_logits( fake_logits, real_labels) uncond_fake_logits = d.logit(features) uncond_disc_error = F.binary_cross_entropy_with_logits( uncond_fake_logits, real_labels) stage_loss = disc_error + uncond_disc_error avg_stage_g_loss[i] = stage_loss.item() / batch_size g_total += stage_loss g_total.backward() self.gen_optimizer.step() return g_total, avg_stage_g_loss, w_loss.item( ) / batch_size, s_loss.item() / batch_size, kl_loss.item() def discriminator_step(self, real_imgs, generated_imgs, sent_embs, label_smoothing, skip_acc_threshold=0.9, p_flip=0.05, halting=False): self.disc.zero_grad() batch_size = sent_embs.size(0) avg_d_loss = [0, 0, 0] real_accuracy = [0, 0, 0] fake_accuracy = [0, 0, 0] mismatched_accuracy = [0, 0, 0] uncond_real_accuracy = [0, 0, 0] uncond_fake_accuracy = [0, 0, 0] skipped = [0, 0, 0] for i, d in enumerate(self.discriminators): real_features = d(real_imgs[i].to(self.device)) fake_features = d(generated_imgs[i].detach()) real_logits = d.logit(real_features, sent_embs) real_labels = torch.full_like(real_logits, 1 - label_smoothing).to(self.device) fake_labels = torch.zeros_like(real_logits, dtype=torch.float).to(self.device) # flip_mask = torch.Tensor(real_labels.size()).bernoulli_(p_flip).type(torch.bool) # real_labels[flip_mask], fake_labels[flip_mask] = fake_labels[flip_mask], real_labels[flip_mask] real_error = F.binary_cross_entropy_with_logits( real_logits, real_labels) # Real images should be classified as real real_accuracy[i] = (real_logits >= 0).sum().item() / real_logits.numel() fake_logits = d.logit(fake_features, sent_embs) fake_error = F.binary_cross_entropy_with_logits( fake_logits, fake_labels) # Generated images should be classified as fake fake_accuracy[i] = (fake_logits < 0).sum().item() / fake_logits.numel() mismatched_logits = d.logit(real_features, rotate_tensor(sent_embs, 1)) mismatched_error = F.binary_cross_entropy_with_logits( mismatched_logits, fake_labels) # Images with mismatched descriptions should be classified as fake mismatched_accuracy[i] = (mismatched_logits < 0).sum().item( ) / mismatched_logits.numel() uncond_real_logits = d.logit(real_features) uncond_real_error = F.binary_cross_entropy_with_logits( uncond_real_logits, real_labels) uncond_real_accuracy[i] = (uncond_real_logits >= 0).sum().item( ) / uncond_real_logits.numel() uncond_fake_logits = d.logit(fake_features) uncond_fake_error = 
F.binary_cross_entropy_with_logits( uncond_fake_logits, fake_labels) uncond_fake_accuracy[i] = (uncond_fake_logits < 0).sum().item( ) / uncond_fake_logits.numel() error = (real_error + uncond_real_error) / 2 + ( fake_error + uncond_fake_error + mismatched_error) / 3 if not halting or fake_accuracy[i] + real_accuracy[ i] < skip_acc_threshold * 2: error.backward() self.disc_optimizers[i].step() else: skipped[i] = 1 avg_d_loss[i] = error.item() / batch_size return avg_d_loss, real_accuracy, fake_accuracy, mismatched_accuracy, uncond_real_accuracy, uncond_fake_accuracy, skipped def generate_from_text(self, texts, dataset, noise=None): encoded = [dataset.train.encode_text(t) for t in texts] generated = self.generate_from_encoded_text(encoded, dataset, noise) return generated def generate_from_encoded_text(self, encoded, dataset, noise=None): with torch.no_grad(): w_emb, s_emb = self.damsm.txt_enc(encoded) attn_mask = torch.tensor(encoded).to( self.device) == dataset.vocab[END_TOKEN] if noise is None: noise = torch.FloatTensor(len(encoded), D_Z).to(self.device) noise.data.normal_(0, 1) generated, att, mu, logvar = self.gen(noise, s_emb, w_emb, attn_mask) return generated def _save_generated(self, generated, epoch, out_dir=OUT_DIR): nb_samples = generated[0].size(0) save_dir = f'{out_dir}/epoch_{epoch:03}' os.makedirs(save_dir) for i in range(nb_samples): save_image(generated[0][i], f'{save_dir}/{i}_64.jpg', normalize=True, range=(-1, 1)) save_image(generated[1][i], f'{save_dir}/{i}_128.jpg', normalize=True, range=(-1, 1)) save_image(generated[2][i], f'{save_dir}/{i}_256.jpg', normalize=True, range=(-1, 1)) def save(self, name, save_dir=GAN_MODEL_DIR, metrics=None): os.makedirs(save_dir, exist_ok=True) torch.save(self.gen.state_dict(), f'{save_dir}/{name}_generator.pt') torch.save(self.disc.state_dict(), f'{save_dir}/{name}_discriminator.pt') if metrics is not None: with open(f'{save_dir}/{name}_metrics.json', 'w') as f: metrics = pre_json_metrics(metrics) json.dump(metrics, f) def load_(self, name, load_dir=GAN_MODEL_DIR): self.gen.load_state_dict(torch.load(f'{load_dir}/{name}_generator.pt')) self.disc.load_state_dict( torch.load(f'{load_dir}/{name}_discriminator.pt')) self.gen.eval(), self.disc.eval() @staticmethod def load(name, damsm, load_dir=GAN_MODEL_DIR, device=DEVICE): attngan = AttnGAN(damsm, device=device) attngan.load_(name, load_dir) return attngan def validate_test_set(self, dataset, batch_size=GAN_BATCH, save_dir=f'{OUT_DIR}/test_samples'): os.makedirs(save_dir, exist_ok=True) loader = DataLoader(dataset.test, batch_size=batch_size, shuffle=True, drop_last=False, collate_fn=dataset.collate_fn) loader = tqdm(loader, dynamic_ncols=True, leave=True, desc='Generating samples for test set') self.gen.eval() with torch.no_grad(): i = 0 for batch in loader: word_embs, sent_embs = self.damsm.txt_enc(batch['caption']) attn_mask = torch.tensor(batch['caption']).to( self.device) == dataset.vocab[END_TOKEN] noise = torch.FloatTensor(len(batch['caption']), D_Z).to(self.device) noise.data.normal_(0, 1) generated, att, mu, logvar = self.gen(noise, sent_embs, word_embs, attn_mask) for img in generated[-1]: save_image(img, f'{save_dir}/{i}.jpg', normalize=True, range=(-1, 1)) i += 1 def get_d_score(self, imgs, sent_embs): d = self.disc.d256 features = d(imgs.to(self.device)) scores = d.logit(features, sent_embs) return scores def accept_prob(self, score1, score2): return min(1, (1 / score1 - 1) / (1 / score2 - 1)) def d_scores_test(self, dataset): with torch.no_grad(): loader = 
DataLoader(dataset.test, batch_size=20, shuffle=False, drop_last=False, collate_fn=dataset.collate_fn) scores = [] d = self.disc.d256 for b in loader: img = b['img256'].to(self.device) f = d(img) l = d.logit(f) scores.append(torch.sigmoid(l)) scores = [x.item() for s in scores for x in s.reshape(-1)] return scores def z_test(self, scores, labels): labels = np.array(labels) scores = np.array(scores) num = np.sum(labels - scores) denom = np.sqrt(np.sum(scores * (1 - scores))) return num / denom def d_scores_gen(self, dataset): with torch.no_grad(): loader = DataLoader(dataset.test, batch_size=20, shuffle=False, drop_last=False, collate_fn=dataset.collate_fn) scores = [] d = self.disc.d256 for b in loader: noise = torch.FloatTensor(len(b['caption']), D_Z).to(self.device) noise.data.normal_(0, 1) word_embs, sent_embs = self.damsm.txt_enc(b['caption']) attn_mask = torch.tensor(b['caption']).to( self.device) == dataset.vocab[END_TOKEN] generated, _, _, _ = self.gen(noise, sent_embs, word_embs, attn_mask) f = d(generated[-1]) l = d.logit(f) scores.append(torch.sigmoid(l)) scores = [x.item() for s in scores for x in s.reshape(-1)] return scores def mh_sample(self, dataset, k, save_dir='test_samples', batch=GAN_BATCH): evaluator = IS_FID_Evaluator(dataset, self.damsm.img_enc.inception_model, batch, self.device) # self.disc.d256.train() with torch.no_grad(): l = len(dataset.test) score_real = self.d_scores_test(dataset) score_gen = self.d_scores_gen(dataset) print(np.mean(score_real)) print(np.mean(score_gen)) portion = -l // 5 score_test = score_real[:portion] + score_gen[:portion] label_test = [1] * (len(score_test) // 2) + [0] * (len(score_test) // 2) print('Z test before calibration: ', self.z_test(torch.tensor(score_test), label_test)) score_real_calib = score_real[portion:] score_gen_calib = score_gen[portion:] # score_calib = score_real_calib + score_gen_calib score_calib = score_gen_calib + score_real_calib label_calib = len(score_gen_calib) * [0] + len( score_real_calib) * [1] cal_clf = LogisticRegression() cal_clf.fit(np.array(score_calib).reshape(-1, 1), label_calib) score_pred = cal_clf.predict_proba( np.array(score_test).reshape(-1, 1))[:, 1] print('Score pred avg: ', np.mean(score_pred)) test_pred = cal_clf.predict(np.array(score_test).reshape(-1, 1)) print('Z test after calibration: ', self.z_test(score_pred, label_test)) print('Accuracy: ', sum((test_pred == label_test)) / len(test_pred)) os.makedirs(save_dir, exist_ok=True) loader = DataLoader(dataset.test, batch_size=1, shuffle=False, drop_last=False, collate_fn=dataset.collate_fn) loader = tqdm(loader, dynamic_ncols=True, leave=True, desc='Generating samples for test set') imgs = [] true_probs = 0 noaccept = 0 for i, sample in enumerate(loader): if i > l - (l // 10): continue word_embs, sent_embs = self.damsm.txt_enc(sample['caption']) attn_mask = torch.tensor(sample['caption']).to( self.device) == dataset.vocab[END_TOKEN] img_chain = [] while len(img_chain) < k: noise = torch.FloatTensor(batch, D_Z).to(self.device) noise.data.normal_(0, 1) generated, _, _, _ = self.gen( noise, sent_embs.repeat(batch, 1), word_embs.repeat(batch, 1, 1), attn_mask.repeat(batch, 1)) for img in generated[-1]: img_chain.append(img) img_chain = img_chain[:k] img_chain = torch.stack(img_chain).to(self.device) score_chain = [] d_loader = DataLoader(img_chain, batch_size=batch, shuffle=False, drop_last=False) for d_batch in d_loader: scores = self.get_d_score(d_batch, sent_embs.repeat(batch, 1)) scores = scores.reshape(-1, 1).cpu().numpy() scores = 
cal_clf.predict_proba(scores)[:, 1] for s in scores: score_chain.append(s) chosen = 0 for j, s in enumerate(score_chain[1:], 1): alpha = self.accept_prob(score_chain[chosen], s) if np.random.rand() < alpha: chosen = j if chosen == 0: imgs.append(img_chain[torch.tensor( score_chain[1:]).argmax()].cpu()) noaccept += 1 else: imgs.append(img_chain[chosen].cpu()) true_probs += score_chain[0] print(noaccept) print(true_probs / len(dataset.test)) mu_real, sig_real = evaluator.mu_real, evaluator.sig_real mu_fake, sig_fake = activation_statistics( self.damsm.img_enc.inception_model, imgs) print('FID: ', frechet_dist(mu_real, sig_real, mu_fake, sig_fake)) return imgs
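# `rotate_tensor(sent_embs, 1)`, used for the mismatched-caption loss in
# discriminator_step, is not defined in this snippet. A one-line sketch that
# pairs each image with another sample's sentence embedding by rotating the
# batch dimension (equivalent behaviour is assumed):
import torch

def rotate_tensor(t, shift=1):
    return torch.roll(t, shifts=shift, dims=0)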
def train_D_With_G(): aD = Discriminator() aD.cuda() aG = Generator() aG.cuda() optimizer_g = torch.optim.Adam(aG.parameters(), lr=0.0001, betas=(0, 0.9)) optimizer_d = torch.optim.Adam(aD.parameters(), lr=0.0001, betas=(0, 0.9)) criterion = nn.CrossEntropyLoss() n_z = 100 n_classes = 10 np.random.seed(352) label = np.asarray(list(range(10)) * 10) noise = np.random.normal(0, 1, (100, n_z)) label_onehot = np.zeros((100, n_classes)) label_onehot[np.arange(100), label] = 1 noise[np.arange(100), :n_classes] = label_onehot[np.arange(100)] noise = noise.astype(np.float32) save_noise = torch.from_numpy(noise) save_noise = Variable(save_noise).cuda() start_time = time.time() # Train the model num_epochs = 500 loss1 = [] loss2 = [] loss3 = [] loss4 = [] loss5 = [] acc1 = [] for epoch in range(0, num_epochs): aG.train() aD.train() avoidOverflow(optimizer_d) avoidOverflow(optimizer_g) for batch_idx, (X_train_batch, Y_train_batch) in enumerate(trainloader): if (Y_train_batch.shape[0] < batch_size): continue # train G if batch_idx % gen_train == 0: for p in aD.parameters(): p.requires_grad_(False) aG.zero_grad() label = np.random.randint(0, n_classes, batch_size) noise = np.random.normal(0, 1, (batch_size, n_z)) label_onehot = np.zeros((batch_size, n_classes)) label_onehot[np.arange(batch_size), label] = 1 noise[np.arange(batch_size), :n_classes] = label_onehot[ np.arange(batch_size)] noise = noise.astype(np.float32) noise = torch.from_numpy(noise) noise = Variable(noise).cuda() fake_label = Variable(torch.from_numpy(label)).cuda() fake_data = aG(noise) gen_source, gen_class = aD(fake_data) gen_source = gen_source.mean() gen_class = criterion(gen_class, fake_label) gen_cost = -gen_source + gen_class gen_cost.backward() optimizer_g.step() # train D for p in aD.parameters(): p.requires_grad_(True) aD.zero_grad() # train discriminator with input from generator label = np.random.randint(0, n_classes, batch_size) noise = np.random.normal(0, 1, (batch_size, n_z)) label_onehot = np.zeros((batch_size, n_classes)) label_onehot[np.arange(batch_size), label] = 1 noise[np.arange(batch_size), :n_classes] = label_onehot[np.arange( batch_size)] noise = noise.astype(np.float32) noise = torch.from_numpy(noise) noise = Variable(noise).cuda() fake_label = Variable(torch.from_numpy(label)).cuda() with torch.no_grad(): fake_data = aG(noise) disc_fake_source, disc_fake_class = aD(fake_data) disc_fake_source = disc_fake_source.mean() disc_fake_class = criterion(disc_fake_class, fake_label) # train discriminator with input from the discriminator real_data = Variable(X_train_batch).cuda() real_label = Variable(Y_train_batch).cuda() disc_real_source, disc_real_class = aD(real_data) prediction = disc_real_class.data.max(1)[1] accuracy = (float(prediction.eq(real_label.data).sum()) / float(batch_size)) * 100.0 disc_real_source = disc_real_source.mean() disc_real_class = criterion(disc_real_class, real_label) gradient_penalty = calc_gradient_penalty(aD, real_data, fake_data) disc_cost = disc_fake_source - disc_real_source + disc_real_class + disc_fake_class + gradient_penalty disc_cost.backward() optimizer_d.step() loss1.append(gradient_penalty.item()) loss2.append(disc_fake_source.item()) loss3.append(disc_real_source.item()) loss4.append(disc_real_class.item()) loss5.append(disc_fake_class.item()) acc1.append(accuracy) if batch_idx % 50 == 0: print(epoch, batch_idx, "%.2f" % np.mean(loss1), "%.2f" % np.mean(loss2), "%.2f" % np.mean(loss3), "%.2f" % np.mean(loss4), "%.2f" % np.mean(loss5), "%.2f" % np.mean(acc1)) # Test the model 
aD.eval()
with torch.no_grad():
    test_accu = []
    for batch_idx, (X_test_batch, Y_test_batch) in enumerate(testloader):
        X_test_batch, Y_test_batch = Variable(X_test_batch).cuda(), Variable(Y_test_batch).cuda()
        _, output = aD(X_test_batch)  # the nested no_grad here was redundant
        prediction = output.data.max(1)[1]  # [1] holds the argmax class indices
        accuracy = (float(prediction.eq(Y_test_batch.data).sum()) / float(batch_size)) * 100.0
        test_accu.append(accuracy)
    accuracy_test = np.mean(test_accu)
print('Testing', accuracy_test, time.time() - start_time)

# save sample outputs from the generator
with torch.no_grad():
    aG.eval()
    samples = aG(save_noise)
    samples = samples.data.cpu().numpy()
    samples = (samples + 1.0) / 2.0          # map from [-1, 1] to [0, 1]
    samples = samples.transpose(0, 2, 3, 1)  # NCHW -> NHWC for plotting
    aG.train()
fig = plot(samples)
plt.savefig('output/%s.png' % str(epoch).zfill(3), bbox_inches='tight')
plt.close(fig)

# checkpoint every epoch ("(epoch + 1) % 1 == 0" was always true in the original)
torch.save(aG, 'tempG.model')
torch.save(aD, 'tempD.model')
torch.save(aG, 'generator.model')
torch.save(aD, 'discriminator.model')
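# `plot(samples)` renders the 10x10 grid of class-conditional samples saved
# each epoch. A matplotlib sketch consistent with the 100-sample fixed-noise
# grid built above (figure layout details are assumptions):
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

def plot(samples, rows=10, cols=10):
    fig = plt.figure(figsize=(cols, rows))
    gs = gridspec.GridSpec(rows, cols, wspace=0.05, hspace=0.05)
    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        ax.axis('off')
        ax.imshow(sample)  # samples are NHWC in [0, 1] after the rescaling above
    return fig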
class Trainer: def __init__(self, params, *, n_samples=10000000): self.model = [ fastText.load_model( os.path.join(params.dataDir, params.model_path_0)), fastText.load_model( os.path.join(params.dataDir, params.model_path_1)) ] self.dic = [ list(zip(*self.model[id].get_words(include_freq=True))) for id in [0, 1] ] x = [ np.empty((params.vocab_size, params.emb_dim), dtype=np.float64) for _ in [0, 1] ] for id in [0, 1]: for i in range(params.vocab_size): x[id][i, :] = self.model[id].get_word_vector( self.dic[id][i][0]) x[id] = normalize_embeddings_np(x[id], params.normalize_pre) u0, s0, _ = scipy.linalg.svd(x[0], full_matrices=False) u1, s1, _ = scipy.linalg.svd(x[1], full_matrices=False) if params.spectral_align_pre: s = (s0 + s1) * 0.5 x[0] = u0 @ np.diag(s) x[1] = u1 @ np.diag(s) else: x[0] = u0 @ np.diag(s0) x[1] = u1 @ np.diag(s1) self.embedding = [ nn.Embedding.from_pretrained(torch.from_numpy(x[id]).to( torch.float).to(GPU), freeze=True, sparse=True) for id in [0, 1] ] self.discriminator = Discriminator( params.emb_dim, n_layers=params.d_n_layers, n_units=params.d_n_units, drop_prob=params.d_drop_prob, drop_prob_input=params.d_drop_prob_input, leaky=params.d_leaky, batch_norm=params.d_bn).to(GPU) self.mapping = Mapping(params.emb_dim).to(GPU) if params.d_optimizer == "SGD": self.d_optimizer, self.d_scheduler = optimizers.get_sgd_adapt( self.discriminator.parameters(), lr=params.d_lr, mode="max", wd=params.d_wd) elif params.d_optimizer == "RMSProp": self.d_optimizer, self.d_scheduler = optimizers.get_rmsprop_linear( self.discriminator.parameters(), params.n_steps, lr=params.d_lr, wd=params.d_wd) else: raise Exception(f"Optimizer {params.d_optimizer} not found.") if params.m_optimizer == "SGD": self.m_optimizer, self.m_scheduler = optimizers.get_sgd_adapt( self.mapping.parameters(), lr=params.m_lr, mode="max", wd=params.m_wd, factor=params.m_lr_decay, patience=params.m_lr_patience) elif params.m_optimizer == "RMSProp": self.m_optimizer, self.m_scheduler = optimizers.get_rmsprop_linear( self.mapping.parameters(), params.n_steps, lr=params.m_lr, wd=params.m_wd) else: raise Exception(f"Optimizer {params.m_optimizer} not found") self.m_beta = params.m_beta self.smooth = params.smooth self.wgan = params.wgan self.d_clip_mode = params.d_clip_mode if params.wgan: self.loss_fn = _wasserstein_distance else: self.loss_fn = nn.BCEWithLogitsLoss(reduction="elementwise_mean") self.sampler = [ WordSampler(self.dic[id], n_urns=n_samples, alpha=params.a_sample_factor, top=params.a_sample_top) for id in [0, 1] ] self.d_bs = params.d_bs self.d_gp = params.d_gp def get_adv_batch(self, *, reverse, gp=False): batch = [ torch.LongTensor( [self.sampler[id].sample() for _ in range(self.d_bs)]).view(self.d_bs, 1).to(GPU) for id in [0, 1] ] with torch.no_grad(): x = [ self.embedding[id](batch[id]).view(self.d_bs, -1) for id in [0, 1] ] y = torch.FloatTensor(self.d_bs * 2).to(GPU).uniform_(0.0, self.smooth) if reverse: y[:self.d_bs] = 1 - y[:self.d_bs] else: y[self.d_bs:] = 1 - y[self.d_bs:] x[0] = self.mapping(x[0]) if gp: t = torch.FloatTensor(self.d_bs, 1).to(GPU).uniform_(0.0, 1.0).expand_as(x[0]) z = x[0] * t + x[1] * (1.0 - t) x = torch.cat(x, 0) return x, y, z else: x = torch.cat(x, 0) return x, y def adversarial_step(self): self.m_optimizer.zero_grad() self.discriminator.eval() x, y = self.get_adv_batch(reverse=True) y_hat = self.discriminator(x) loss = self.loss_fn(y_hat, y) loss.backward() self.m_optimizer.step() self.mapping.clip_weights() return loss.item() def discriminator_step(self): 
        self.d_optimizer.zero_grad()
        self.discriminator.train()
        with torch.no_grad():
            # sample the two-sided batch without tracking embedding gradients
            if self.d_gp > 0:
                x, y, z = self.get_adv_batch(reverse=False, gp=True)
            else:
                x, y = self.get_adv_batch(reverse=False)
                z = None
        y_hat = self.discriminator(x)
        loss = self.loss_fn(y_hat, y)
        if self.d_gp > 0:
            # gradient penalty on the interpolated points (WGAN-GP)
            z.requires_grad_()
            z_out = self.discriminator(z)
            g = autograd.grad(z_out, z,
                              grad_outputs=torch.ones_like(z_out, device=GPU),
                              retain_graph=True, create_graph=True,
                              only_inputs=True)[0]
            gp = torch.mean((g.norm(p=2, dim=1) - 1.0) ** 2)
            loss += self.d_gp * gp
        loss.backward()
        self.d_optimizer.step()
        if self.wgan:
            self.discriminator.clip_weights(self.d_clip_mode)
        return loss.item()

    def scheduler_step(self, metric):
        self.m_scheduler.step(metric)
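When `params.wgan` is set, `self.loss_fn` points to `_wasserstein_distance`, which does not appear in this excerpt. A plausible sketch, assuming it must be call-compatible with `nn.BCEWithLogitsLoss` and that labels above 0.5 mark the "real" half of the batch:

import torch

def _wasserstein_distance(y_hat, y):
    # push critic scores up on real rows and down on fake rows,
    # i.e. loss = mean(D(fake)) - mean(D(real))
    sign = torch.where(y > 0.5, -torch.ones_like(y), torch.ones_like(y))
    return torch.mean(sign * y_hat)

With `reverse=True` (the mapping step) the labels flip, so minimizing the same expression yields the adversarial objective for the mapping side.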
class Trainer: def __init__(self, corpus_data_0, corpus_data_1, *, params, n_samples=10000000): self.fast_text = [FastText(corpus_data_0.model).to(GPU), FastText(corpus_data_1.model).to(GPU)] self.discriminator = Discriminator(params.emb_dim, n_layers=params.d_n_layers, n_units=params.d_n_units, drop_prob=params.d_drop_prob, drop_prob_input=params.d_drop_prob_input, leaky=params.d_leaky, batch_norm=params.d_bn).to(GPU) self.mapping = nn.Linear(params.emb_dim, params.emb_dim, bias=False) self.mapping.weight.data.copy_(torch.diag(torch.ones(params.emb_dim))) self.mapping = self.mapping.to(GPU) self.ft_optimizer, self.ft_scheduler = [], [] for id in [0, 1]: optimizer, scheduler = optimizers.get_sgd_adapt(self.fast_text[id].parameters(), lr=params.ft_lr, mode="max", factor=params.ft_lr_decay, patience=params.ft_lr_patience) self.ft_optimizer.append(optimizer) self.ft_scheduler.append(scheduler) self.a_optimizer, self.a_scheduler = [], [] for id in [0, 1]: optimizer, scheduler = optimizers.get_sgd_adapt( [{"params": self.fast_text[id].u.parameters()}, {"params": self.fast_text[id].v.parameters()}], lr=params.a_lr, mode="max", factor=params.a_lr_decay, patience=params.a_lr_patience) self.a_optimizer.append(optimizer) self.a_scheduler.append(scheduler) if params.d_optimizer == "SGD": self.d_optimizer, self.d_scheduler = optimizers.get_sgd_adapt(self.discriminator.parameters(), lr=params.d_lr, mode="max", wd=params.d_wd) elif params.d_optimizer == "RMSProp": self.d_optimizer, self.d_scheduler = optimizers.get_rmsprop_linear(self.discriminator.parameters(), params.n_steps, lr=params.d_lr, wd=params.d_wd) else: raise Exception(f"Optimizer {params.d_optimizer} not found.") if params.m_optimizer == "SGD": self.m_optimizer, self.m_scheduler = optimizers.get_sgd_adapt(self.mapping.parameters(), lr=params.m_lr, mode="max", wd=params.m_wd, factor=params.m_lr_decay, patience=params.m_lr_patience) elif params.m_optimizer == "RMSProp": self.m_optimizer, self.m_scheduler = optimizers.get_rmsprop_linear(self.mapping.parameters(), params.n_steps, lr=params.m_lr, wd=params.m_wd) else: raise Exception(f"Optimizer {params.m_optimizer} not found") self.m_beta = params.m_beta self.smooth = params.smooth self.wgan = params.wgan self.d_clip_mode = params.d_clip_mode if params.wgan: self.loss_fn = _wasserstein_distance else: self.loss_fn = nn.BCEWithLogitsLoss(reduction="elementwise_mean") self.corpus_data_queue = [ _data_queue(corpus_data_0, n_threads=(params.n_threads + 1) // 2, n_sentences=params.n_sentences, batch_size=params.ft_bs), _data_queue(corpus_data_1, n_threads=(params.n_threads + 1) // 2, n_sentences=params.n_sentences, batch_size=params.ft_bs) ] self.sampler = [ WordSampler(corpus_data_0.dic, n_urns=n_samples, alpha=params.a_sample_factor, top=params.a_sample_top), WordSampler(corpus_data_1.dic, n_urns=n_samples, alpha=params.a_sample_factor, top=params.a_sample_top)] self.d_bs = params.d_bs self.dic_0, self.dic_1 = corpus_data_0.dic, corpus_data_1.dic self.d_gp = params.d_gp def fast_text_step(self): losses = [] for id in [0, 1]: self.ft_optimizer[id].zero_grad() u_b, v_b = self.corpus_data_queue[id].__next__() s = self.fast_text[id](u_b, v_b) loss = FastText.loss_fn(s) loss.backward() self.ft_optimizer[id].step() losses.append(loss.item()) return losses[0], losses[1] def get_adv_batch(self, *, reverse, fix_embedding=False, gp=False): batch = [[self.sampler[id].sample() for _ in range(self.d_bs)] for id in [0, 1]] batch = [self.fast_text[id].model.get_bag(batch[id], self.fast_text[id].u.weight.device) 
for id in [0, 1]] if fix_embedding: with torch.no_grad(): x = [self.fast_text[id].u(batch[id][0], batch[id][1]).view(self.d_bs, -1) for id in [0, 1]] else: x = [self.fast_text[id].u(batch[id][0], batch[id][1]).view(self.d_bs, -1) for id in [0, 1]] y = torch.FloatTensor(self.d_bs * 2).to(GPU).uniform_(0.0, self.smooth) if reverse: y[: self.d_bs] = 1 - y[: self.d_bs] else: y[self.d_bs:] = 1 - y[self.d_bs:] x[0] = self.mapping(x[0]) if gp: t = torch.FloatTensor(self.d_bs, 1).to(GPU).uniform_(0.0, 1.0).expand_as(x[0]) z = x[0] * t + x[1] * (1.0 - t) x = torch.cat(x, 0) return x, y, z else: x = torch.cat(x, 0) return x, y def adversarial_step(self, fix_embedding=False): for id in [0, 1]: self.a_optimizer[id].zero_grad() self.m_optimizer.zero_grad() self.discriminator.eval() x, y = self.get_adv_batch(reverse=True, fix_embedding=fix_embedding) y_hat = self.discriminator(x) loss = self.loss_fn(y_hat, y) loss.backward() for id in [0, 1]: self.a_optimizer[id].step() self.m_optimizer.step() _orthogonalize(self.mapping, self.m_beta) return loss.item() def discriminator_step(self): self.d_optimizer.zero_grad() self.discriminator.train() with torch.no_grad(): if self.d_gp > 0: x, y, z = self.get_adv_batch(reverse=False, gp=True) else: x, y = self.get_adv_batch(reverse=False) z = None y_hat = self.discriminator(x) loss = self.loss_fn(y_hat, y) if self.d_gp > 0: z.requires_grad_() z_out = self.discriminator(z) g = autograd.grad(z_out, z, grad_outputs=torch.ones_like(z_out, device=GPU), retain_graph=True, create_graph=True, only_inputs=True)[0] gp = torch.mean((g.norm(p=2, dim=1) - 1.0) ** 2) loss += self.d_gp * gp loss.backward() self.d_optimizer.step() if self.wgan: self.discriminator.clip_weights(self.d_clip_mode) return loss.item() def scheduler_step(self, metric): for id in [0, 1]: self.ft_scheduler[id].step(metric) self.a_scheduler[id].step(metric) # self.d_scheduler.step(metric) self.m_scheduler.step(metric)
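`adversarial_step` above calls `_orthogonalize(self.mapping, self.m_beta)`, which is not shown in this excerpt. A minimal sketch, assuming the standard MUSE-style update W <- (1 + beta) W - beta (W W^T) W that keeps the linear mapping close to the orthogonal manifold (the `weight` attribute assumes the mapping is an `nn.Linear`, as in this second Trainer):

import torch

def _orthogonalize(mapping, beta):
    # one step of the iterative orthogonalization from Conneau et al. (2018)
    if beta > 0:
        W = mapping.weight.data
        W.copy_((1 + beta) * W - beta * W.mm(W.t()).mm(W))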
def main(args):
    # log hyperparameters
    print(args)

    # select device
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda:0" if args.cuda else "cpu")

    # set random seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # data loaders
    transform = transforms.Compose([
        utils.Normalize(),
        utils.ToTensor()
    ])
    train_dataset = TVDataset(
        root=args.root,
        sub_size=args.block_size,
        volume_list=args.volume_train_list,
        max_k=args.training_step,
        train=True,
        transform=transform
    )
    test_dataset = TVDataset(
        root=args.root,
        sub_size=args.block_size,
        volume_list=args.volume_test_list,
        max_k=args.training_step,
        train=False,
        transform=transform
    )
    kwargs = {"num_workers": 4, "pin_memory": True} if args.cuda else {}
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              shuffle=True, **kwargs)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
                             shuffle=False, **kwargs)

    # model
    def generator_weights_init(m):
        if isinstance(m, nn.Conv3d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def discriminator_weights_init(m):
        if isinstance(m, nn.Conv3d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    g_model = Generator(args.upsample_mode, args.forward, args.backward,
                        args.gen_sn, args.residual)
    g_model.apply(generator_weights_init)
    if args.data_parallel and torch.cuda.device_count() > 1:
        g_model = nn.DataParallel(g_model)
    g_model.to(device)

    if args.gan_loss != "none":
        d_model = Discriminator(args.dis_sn)
        d_model.apply(discriminator_weights_init)
        # if args.dis_sn:
        #     d_model = add_sn(d_model)
        if args.data_parallel and torch.cuda.device_count() > 1:
            d_model = nn.DataParallel(d_model)
        d_model.to(device)

    mse_loss = nn.MSELoss()
    adversarial_loss = nn.MSELoss()  # least-squares GAN objective
    train_losses, test_losses = [], []
    d_losses, g_losses = [], []

    # optimizers
    g_optimizer = optim.Adam(g_model.parameters(), lr=args.lr,
                             betas=(args.beta1, args.beta2))
    if args.gan_loss != "none":
        d_optimizer = optim.Adam(d_model.parameters(), lr=args.d_lr,
                                 betas=(args.beta1, args.beta2))

    Tensor = torch.cuda.FloatTensor if args.cuda else torch.FloatTensor

    # load checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint {}".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint["epoch"]
            g_model.load_state_dict(checkpoint["g_model_state_dict"])
            # g_optimizer.load_state_dict(checkpoint["g_optimizer_state_dict"])
            if args.gan_loss != "none":
                d_model.load_state_dict(checkpoint["d_model_state_dict"])
                # d_optimizer.load_state_dict(checkpoint["d_optimizer_state_dict"])
                d_losses = checkpoint["d_losses"]
                g_losses = checkpoint["g_losses"]
            train_losses = checkpoint["train_losses"]
            test_losses = checkpoint["test_losses"]
            print("=> loaded checkpoint {} (epoch {})"
                  .format(args.resume, checkpoint["epoch"]))

    # main loop
    for epoch in tqdm(range(args.start_epoch, args.epochs)):
        # training...
        g_model.train()
        if args.gan_loss != "none":
            d_model.train()
        train_loss = 0.
        volume_loss_part = np.zeros(args.training_step)
        for i, sample in enumerate(train_loader):
            params = list(g_model.named_parameters())
            # pdb.set_trace()
            # params[0][1].register_hook(lambda g: print("{}.grad: {}".format(params[0][0], g)))

            # adversarial ground truths
            real_label = Variable(Tensor(sample["v_i"].shape[0], sample["v_i"].shape[1],
                                         1, 1, 1, 1).fill_(1.0), requires_grad=False)
            fake_label = Variable(Tensor(sample["v_i"].shape[0], sample["v_i"].shape[1],
                                         1, 1, 1, 1).fill_(0.0), requires_grad=False)

            v_f = sample["v_f"].to(device)
            v_b = sample["v_b"].to(device)
            v_i = sample["v_i"].to(device)
            g_optimizer.zero_grad()
            fake_volumes = g_model(v_f, v_b, args.training_step,
                                   args.wo_ori_volume, args.norm)

            # update discriminator
            if args.gan_loss != "none":
                avg_d_loss = 0.
                avg_d_loss_real = 0.
                avg_d_loss_fake = 0.
                for k in range(args.n_d):
                    d_optimizer.zero_grad()
                    decisions = d_model(v_i)
                    d_loss_real = adversarial_loss(decisions, real_label)
                    fake_decisions = d_model(fake_volumes.detach())
                    d_loss_fake = adversarial_loss(fake_decisions, fake_label)
                    d_loss = d_loss_real + d_loss_fake
                    d_loss.backward()
                    avg_d_loss += d_loss.item() / args.n_d
                    avg_d_loss_real += d_loss_real.item() / args.n_d
                    avg_d_loss_fake += d_loss_fake.item() / args.n_d
                    d_optimizer.step()

            # update generator
            if args.gan_loss != "none":
                avg_g_loss = 0.
            avg_loss = 0.
            for k in range(args.n_g):
                loss = 0.
                g_optimizer.zero_grad()

                # adversarial loss
                if args.gan_loss != "none":
                    fake_decisions = d_model(fake_volumes)
                    g_loss = args.gan_loss_weight * adversarial_loss(fake_decisions,
                                                                     real_label)
                    loss += g_loss
                    avg_g_loss += g_loss.item() / args.n_g

                # volume loss
                if args.volume_loss:
                    volume_loss = args.volume_loss_weight * mse_loss(v_i, fake_volumes)
                    for j in range(v_i.shape[1]):
                        volume_loss_part[j] += (mse_loss(v_i[:, j, :], fake_volumes[:, j, :]).item()
                                                / args.n_g / args.log_every)
                    loss += volume_loss

                # feature loss
                if args.feature_loss:
                    feat_real = d_model.extract_features(v_i)
                    feat_fake = d_model.extract_features(fake_volumes)
                    for m in range(len(feat_real)):
                        loss += (args.feature_loss_weight / len(feat_real)
                                 * mse_loss(feat_real[m], feat_fake[m]))

                avg_loss += loss.item() / args.n_g
                loss.backward()
                g_optimizer.step()
            train_loss += avg_loss

            # log training status
            subEpoch = (i + 1) // args.log_every
            if (i + 1) % args.log_every == 0:
                print("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch, (i + 1) * args.batch_size, len(train_loader.dataset),
                    100. * (i + 1) / len(train_loader), avg_loss
                ))
                print("Volume Loss: ")
                for j in range(volume_loss_part.shape[0]):
                    print("\tintermediate {}: {:.6f}".format(j + 1, volume_loss_part[j]))
                if args.gan_loss != "none":
                    print("DLossReal: {:.6f} DLossFake: {:.6f} DLoss: {:.6f}, GLoss: {:.6f}".format(
                        avg_d_loss_real, avg_d_loss_fake, avg_d_loss, avg_g_loss
                    ))
                    d_losses.append(avg_d_loss)
                    g_losses.append(avg_g_loss)
                train_losses.append(train_loss / args.log_every)
                print("====> SubEpoch: {} Average loss: {:.6f} Time {}".format(
                    subEpoch, train_loss / args.log_every,
                    time.asctime(time.localtime(time.time()))
                ))
                train_loss = 0.
                volume_loss_part = np.zeros(args.training_step)

            # testing...
            if (i + 1) % args.test_every == 0:
                g_model.eval()
                if args.gan_loss != "none":
                    d_model.eval()
                test_loss = 0.
                with torch.no_grad():
                    # do not reuse `i` here: it would shadow the training-loop
                    # index used by the logging/saving conditions above
                    for sample in test_loader:
                        v_f = sample["v_f"].to(device)
                        v_b = sample["v_b"].to(device)
                        v_i = sample["v_i"].to(device)
                        fake_volumes = g_model(v_f, v_b, args.training_step,
                                               args.wo_ori_volume, args.norm)
                        test_loss += (args.volume_loss_weight
                                      * mse_loss(v_i, fake_volumes).item())

                test_losses.append(test_loss * args.batch_size / len(test_loader.dataset))
                print("====> SubEpoch: {} Test set loss {:.4f} Time {}".format(
                    subEpoch, test_losses[-1], time.asctime(time.localtime(time.time()))
                ))
                # restore train mode after evaluation
                g_model.train()
                if args.gan_loss != "none":
                    d_model.train()

            # saving...
            if (i + 1) % args.check_every == 0:
                print("=> saving checkpoint at epoch {}".format(epoch))
                checkpoint_name = "model_" + str(epoch) + "_" + str(subEpoch) + ".pth.tar"
                if args.gan_loss != "none":
                    torch.save({"epoch": epoch + 1,
                                "g_model_state_dict": g_model.state_dict(),
                                "g_optimizer_state_dict": g_optimizer.state_dict(),
                                "d_model_state_dict": d_model.state_dict(),
                                "d_optimizer_state_dict": d_optimizer.state_dict(),
                                "d_losses": d_losses,
                                "g_losses": g_losses,
                                "train_losses": train_losses,
                                "test_losses": test_losses},
                               os.path.join(args.save_dir, checkpoint_name))
                else:
                    torch.save({"epoch": epoch + 1,
                                "g_model_state_dict": g_model.state_dict(),
                                "g_optimizer_state_dict": g_optimizer.state_dict(),
                                "train_losses": train_losses,
                                "test_losses": test_losses},
                               os.path.join(args.save_dir, checkpoint_name))
                torch.save(g_model.state_dict(),
                           os.path.join(args.save_dir,
                                        "model_" + str(epoch) + "_" + str(subEpoch) + ".pth"))

        num_subEpoch = len(train_loader) // args.log_every
        print("====> Epoch: {} Average loss: {:.6f} Time {}".format(
            epoch, np.array(train_losses[-num_subEpoch:]).mean(),
            time.asctime(time.localtime(time.time()))
        ))
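The commented-out `add_sn` call above suggests spectral normalization was also tried by wrapping the discriminator after construction. A sketch of such a helper, assuming it should recursively wrap every Conv3d and Linear submodule with `torch.nn.utils.spectral_norm`:

import torch.nn as nn

def add_sn(module):
    # recursively replace conv/linear layers with spectrally normalized ones
    for name, child in module.named_children():
        if isinstance(child, (nn.Conv3d, nn.Linear)):
            setattr(module, name, nn.utils.spectral_norm(child))
        else:
            add_sn(child)
    return module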
def main(): random.seed(SEED) np.random.seed(SEED) torch.manual_seed(SEED) gen_data_loader = Gen_Data_loader(BATCH_SIZE) likelihood_data_loader = Gen_Data_loader(BATCH_SIZE) # For testing vocab_size = 2000 dis_data_loader = Dis_dataloader(BATCH_SIZE) generator = Generator(vocab_size, EMB_DIM, HIDDEN_DIM, 1, START_TOKEN, SEQ_LENGTH).to(device) target_lstm = Generator(vocab_size, EMB_DIM, HIDDEN_DIM, 1, START_TOKEN, SEQ_LENGTH, oracle=True).to(device) discriminator = Discriminator(vocab_size, dis_embedding_dim, dis_filter_sizes, dis_num_filters, dis_dropout).to(device) generate_samples(target_lstm, BATCH_SIZE, generated_num, positive_file) gen_data_loader.create_batches(positive_file) pre_gen_opt = torch.optim.Adam(generator.parameters(), 1e-2) adv_gen_opt = torch.optim.Adam(generator.parameters(), 1e-2) dis_opt = torch.optim.Adam(discriminator.parameters(), 1e-4) dis_criterion = nn.NLLLoss() log = open('save/experiment-log.txt', 'w') print('Start pre-training...') log.write('pre-training...\n') for epoch in range(PRE_EPOCH_NUM): loss = pre_train_epoch(generator, pre_gen_opt, gen_data_loader) if (epoch + 1) % 5 == 0: generate_samples(generator, BATCH_SIZE, generated_num, eval_file) likelihood_data_loader.create_batches(eval_file) test_loss = target_loss(target_lstm, likelihood_data_loader) print('pre-train epoch ', epoch + 1, '\tnll:\t', test_loss) buffer = 'epoch:\t' + str(epoch + 1) + '\tnll:\t' + str(test_loss) + '\n' log.write(buffer) print('Start pre-training discriminator...') # Train 3 epoch on the generated data and do this for 50 times for e in range(50): generate_samples(generator, BATCH_SIZE, generated_num, negative_file) dis_data_loader.load_train_data(positive_file, negative_file) d_total_loss = [] for _ in range(3): dis_data_loader.reset_pointer() total_loss = [] for it in range(dis_data_loader.num_batch): x_batch, y_batch = dis_data_loader.next_batch() x_batch = x_batch.to(device) y_batch = y_batch.to(device) dis_output = discriminator(x_batch.detach()) d_loss = dis_criterion(dis_output, y_batch.detach()) dis_opt.zero_grad() d_loss.backward() dis_opt.step() total_loss.append(d_loss.data.cpu().numpy()) d_total_loss.append(np.mean(total_loss)) if (e + 1) % 5 == 0: buffer = 'Epoch [{}], discriminator loss [{:.4f}]\n'.format( e + 1, np.mean(d_total_loss)) print(buffer) log.write(buffer) rollout = Rollout(generator, 0.8) print( '#########################################################################' ) print('Start Adversarial Training...') log.write('adversarial training...\n') gan_loss = GANLoss() for total_batch in range(TOTAL_BATCH): # Train the generator for one step discriminator.eval() for it in range(1): samples, _ = generator.sample(num_samples=BATCH_SIZE) rewards = rollout.get_reward(samples, 16, discriminator) prob = generator(samples.detach()) adv_loss = gan_loss(prob, samples.detach(), rewards.detach()) adv_gen_opt.zero_grad() adv_loss.backward() nn.utils.clip_grad_norm_(generator.parameters(), 5.0) adv_gen_opt.step() # Test if (total_batch + 1) % 5 == 0: generate_samples(generator, BATCH_SIZE, generated_num, eval_file) likelihood_data_loader.create_batches(eval_file) test_loss = target_loss(target_lstm, likelihood_data_loader) self_bleu_score = self_bleu(generator) buffer = 'epoch:\t' + str(total_batch + 1) + '\tnll:\t' + str( test_loss) + '\tSelf Bleu:\t' + str(self_bleu_score) + '\n' print(buffer) log.write(buffer) # Update roll-out parameters rollout.update_params() # Train the discriminator discriminator.train() for _ in range(5): generate_samples(generator, 
BATCH_SIZE, generated_num, negative_file) dis_data_loader.load_train_data(positive_file, negative_file) d_total_loss = [] for _ in range(3): dis_data_loader.reset_pointer() total_loss = [] for it in range(dis_data_loader.num_batch): x_batch, y_batch = dis_data_loader.next_batch() x_batch = x_batch.to(device) y_batch = y_batch.to(device) dis_output = discriminator(x_batch.detach()) d_loss = dis_criterion(dis_output, y_batch.detach()) dis_opt.zero_grad() d_loss.backward() dis_opt.step() total_loss.append(d_loss.data.cpu().numpy()) d_total_loss.append(np.mean(total_loss)) if (total_batch + 1) % 5 == 0: buffer = 'Epoch [{}], discriminator loss [{:.4f}]\n'.format( total_batch + 1, np.mean(d_total_loss)) print(buffer) log.write(buffer) log.close()
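`GANLoss` is instantiated before adversarial training but not defined in this excerpt. A sketch of the SeqGAN policy-gradient objective it presumably implements, assuming `prob` holds flattened per-token log-probabilities; the argument names and the mean reduction are assumptions:

import torch
import torch.nn as nn

class GANLoss(nn.Module):
    """REINFORCE-style loss: -E[ Q(s, a) * log pi(a | s) ]."""

    def forward(self, log_probs, targets, rewards):
        # log_probs: (batch * seq_len, vocab); targets, rewards: flattened
        one_hot = torch.zeros_like(log_probs)
        one_hot.scatter_(1, targets.view(-1, 1), 1)
        selected = (one_hot * log_probs).sum(dim=1)
        return -(selected * rewards.view(-1)).mean()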
class Trainer: def __init__(self): self.logger = None self.tester = None self.latent = None self.save_path = None self.epoch = 0 self.step = 0 self.loss_computer = None self.tf_prob = 0 # Models self.encoder = None self.latent_compressor = None self.latent_decompressor = None self.decoder = None self.generator = None if config["train"]["aae"]: self.discriminator = None # Optimizers self.encoder_optimizer = None self.decoder_optimizer = None self.criterion = None if config["train"]["aae"]: self.disc_optimizer = None self.gen_optimizer = None self.train_discriminator_not_generator = True self.disc_losses = [] self.gen_losses = [] self.disc_loss_init = None self.gen_loss_init = None self.beta = -0.1 # so it become 0 at first iteration self.reg_optimizer = None def test_losses(self, loss): losses = [loss] names = ["loss"] for ls, name in zip(losses, names): print("********************** Optimized by " + name) self.encoder_optimizer.zero_grad(set_to_none=True) self.decoder_optimizer.zero_grad(set_to_none=True) ls.backward(retain_graph=True) for model in [ self.encoder, self.latent_compressor, self.latent_decompressor, self.decoder, self.generator ]: # removed latent compressor for module_name, parameter in model.named_parameters(): if parameter.grad is not None: print(module_name) self.encoder_optimizer.zero_grad(set_to_none=True) self.decoder_optimizer.zero_grad(set_to_none=True) (losses[0]).backward(retain_graph=True) print("********************** NOT OPTIMIZED BY NOTHING") for model in [ self.encoder, self.latent_compressor, self.latent_decompressor, self.decoder, self.generator ]: # removed latent compressor for module_name, parameter in model.named_parameters(): if parameter.grad is None: print(module_name) def run_mb(self, batch): # SETUP VARIABLES srcs, trgs = batch srcs = torch.LongTensor(srcs.long()).to( config["train"]["device"]).transpose(0, 2) trgs = torch.LongTensor(trgs.long()).to( config["train"]["device"]).transpose(0, 2) # invert batch and bars latent = None batches = [ Batch(srcs[i], trgs[i], config["tokens"]["pad"]) for i in range(n_bars) ] ############ # ENCODING # ############ latents = [] for batch in batches: latent = self.encoder(batch.src, batch.src_mask) latents.append(latent) ############ # COMPRESS # ############ old_batches = copy.deepcopy(batches) if config["train"]["compress_latents"]: latent = self.latent_compressor( latents) # in: 3, 4, 200, 256, out: 3, 256 self.latent = latent.detach().cpu().numpy() if config["train"]["compress_latents"]: latents = self.latent_decompressor( latent) # in 3, 256, out: 3, 4, 200, 256 for i in range(n_bars): batches[i].src_mask = batches[i].src_mask.fill_( True)[:, :, :, :20] ############ # DECODING # ############ # Scheduled sampling for transformer if config["train"]["scheduled_sampling"] and self.step > config[ "train"]["after_steps_mix_sequences"]: for _ in range(1): # K self.tf_prob = 0.5 predicted = [] for batch, latent in zip(batches, latents): out = self.decoder(batch.trg, latent, batch.src_mask, batch.trg_mask) prob = self.generator(out) prob = torch.max(prob, dim=-1).indices predicted.append(prob) # add sos at beginning and cut last token for i in range(n_bars): sos = torch.full_like(predicted[i], config["tokens"]["sos"])[..., :1].to( predicted[i].device) pred = torch.cat((sos, predicted[i]), dim=-1)[..., :-1] # create mixed trg mixed_prob = torch.rand(batches[i].trg.shape, dtype=torch.float32).to( trgs.device) mixed_prob = mixed_prob < self.tf_prob batches[i].trg = batches[i].trg.where(mixed_prob, pred) outs = [] for 
batch, latent in zip(batches, latents): out = self.decoder(batch.trg, latent, batch.src_mask, batch.trg_mask) outs.append(out) # Format results outs = torch.stack(outs, dim=0) ##################### # LOSS AND ACCURACY # ##################### trg_ys = torch.stack([batch.trg_y for batch in batches]) bars, n_track, n_batch, seq_len, d_model = outs.shape outs = outs.permute(1, 2, 0, 3, 4).reshape(n_track, n_batch, bars * seq_len, d_model) # join bars trg_ys = trg_ys.permute(1, 2, 0, 3).reshape(n_track, n_batch, bars * seq_len) loss, accuracy = SimpleLossCompute(self.generator, self.criterion)( outs, trg_ys, batch.ntokens) # join instr # if self.encoder.training: # self.test_losses(loss) if self.generator.training: self.encoder_optimizer.zero_grad() self.decoder_optimizer.zero_grad() # if n_bars == 16: # torch.nn.utils.clip_grad_norm_(self.encoder.parameters(), 0.1) # torch.nn.utils.clip_grad_norm_(self.latent_compressor.parameters(), 0.1) # torch.nn.utils.clip_grad_norm_(self.latent_decompressor.parameters(), 0.1) # torch.nn.utils.clip_grad_norm_(self.decoder.parameters(), 0.1) # torch.nn.utils.clip_grad_norm_(self.generator.parameters(), 0.1) loss.backward() self.encoder_optimizer.step() self.decoder_optimizer.step() losses = (loss.item(), accuracy, 0, 0, 0, 0) # *loss_items) # LOG IMAGES if True and self.encoder.training and config["train"]["log_images"] and \ self.step % config["train"]["after_steps_log_images"] == 0 and self.step > 0: # # ENCODER SELF drums_encoder_attn = [] for layer in self.encoder.drums_encoder.layers: instrument_attn = [] for head in layer.self_attn.attn[0]: instrument_attn.append(head) drums_encoder_attn.append(instrument_attn) bass_encoder_attn = [] for layer in self.encoder.bass_encoder.layers: instrument_attn = [] for head in layer.self_attn.attn[0]: instrument_attn.append(head) bass_encoder_attn.append(instrument_attn) guitar_encoder_attn = [] for layer in self.encoder.guitar_encoder.layers: instrument_attn = [] for head in layer.self_attn.attn[0]: instrument_attn.append(head) guitar_encoder_attn.append(instrument_attn) strings_encoder_attn = [] for layer in self.encoder.strings_encoder.layers: instrument_attn = [] for head in layer.self_attn.attn[0]: instrument_attn.append(head) strings_encoder_attn.append(instrument_attn) enc_attention = [ drums_encoder_attn, guitar_encoder_attn, bass_encoder_attn, strings_encoder_attn ] # DECODER SELF drums_decoder_attn = [] for layer in self.decoder.drums_decoder.layers: instrument_attn = [] for head in layer.self_attn.attn[0]: instrument_attn.append(head) drums_decoder_attn.append(instrument_attn) bass_decoder_attn = [] for layer in self.decoder.bass_decoder.layers: instrument_attn = [] for head in layer.self_attn.attn[0]: instrument_attn.append(head) bass_decoder_attn.append(instrument_attn) guitar_decoder_attn = [] for layer in self.decoder.guitar_decoder.layers: instrument_attn = [] for head in layer.self_attn.attn[0]: instrument_attn.append(head) guitar_decoder_attn.append(instrument_attn) strings_decoder_attn = [] for layer in self.decoder.strings_decoder.layers: instrument_attn = [] for head in layer.self_attn.attn[0]: instrument_attn.append(head) strings_decoder_attn.append(instrument_attn) dec_attention = [ drums_decoder_attn, guitar_decoder_attn, bass_decoder_attn, strings_decoder_attn ] # DECODER SRC drums_src_attn = [] for layer in self.decoder.drums_decoder.layers: instrument_attn = [] for head in layer.src_attn.attn[0]: instrument_attn.append(head) drums_src_attn.append(instrument_attn) bass_src_attn = [] for 
layer in self.decoder.bass_decoder.layers: instrument_attn = [] for head in layer.src_attn.attn[0]: instrument_attn.append(head) bass_src_attn.append(instrument_attn) guitar_src_attn = [] for layer in self.decoder.guitar_decoder.layers: instrument_attn = [] for head in layer.src_attn.attn[0]: instrument_attn.append(head) guitar_src_attn.append(instrument_attn) strings_src_attn = [] for layer in self.decoder.strings_decoder.layers: instrument_attn = [] for head in layer.src_attn.attn[0]: instrument_attn.append(head) strings_src_attn.append(instrument_attn) src_attention = [ drums_src_attn, guitar_src_attn, bass_src_attn, strings_src_attn ] print("Logging images...") if config["train"]["compress_latents"]: self.logger.log_latent(self.latent) self.logger.log_attn_heatmap(enc_attention, dec_attention, src_attention) self.logger.log_examples(srcs, trgs) #################### # UPDATE GENERATOR # #################### if config["train"][ "aae"] and self.encoder.training and self.step > config[ "train"]["after_steps_train_aae"]: if self.step % config["train"][ "increase_beta_every"] == 0 and self.beta < config[ "train"]["max_beta"]: self.beta += 0.1 if self.beta > 0: # To suppress warnings D_real = 0 D_fake = 0 loss_critic = 0 ######################## # UPDATE DISCRIMINATOR # ######################## for p in self.encoder.parameters(): p.requires_grad = False for p in self.latent_compressor.parameters(): p.requires_grad = False for p in self.discriminator.parameters(): p.requires_grad = True latents = [] for batch in old_batches: latent = self.encoder(batch.src, batch.src_mask) latents.append(latent) latent = self.latent_compressor(latents) for _ in range(config["train"]["critic_iterations"]): prior = get_prior( (config["train"]["batch_size"], config["model"]["d_model"])) # autograd is intern D_real = self.discriminator(prior).reshape(-1) D_fake = self.discriminator(latent).reshape(-1) gradient_penalty = calc_gradient_penalty( self.discriminator, prior.data, latent.data) loss_critic = ( torch.mean(D_fake) - torch.mean(D_real) + config["train"]["lambda"] * gradient_penalty) loss_critic = loss_critic * self.beta self.discriminator.zero_grad() loss_critic.backward(retain_graph=True) self.disc_optimizer.step(lr=self.encoder_optimizer.lr) #################### # UPDATE GENERATOR # #################### for p in self.encoder.parameters(): p.requires_grad = True for p in self.latent_compressor.parameters(): p.requires_grad = True for p in self.discriminator.parameters(): p.requires_grad = False # to avoid computation latents = [] for batch in old_batches: latent = self.encoder(batch.src, batch.src_mask) latents.append(latent) latent = self.latent_compressor(latents) G = self.discriminator(latent).reshape(-1) loss_gen = -torch.mean(G) loss_gen = loss_gen * self.beta self.gen_optimizer.zero_grad() loss_gen.backward() self.gen_optimizer.step(lr=self.encoder_optimizer.lr) losses += (D_real.mean().cpu().data.numpy(), D_fake.mean().cpu().data.numpy(), G.mean().cpu().data.numpy(), loss_critic.cpu().data.numpy(), loss_gen.cpu().data.numpy(), D_real.mean().cpu().data.numpy() - D_fake.mean().cpu().data.numpy()) return losses def train(self): # Create checkpoint folder if not os.path.exists(config["paths"]["checkpoints"]): os.makedirs(config["paths"]["checkpoints"]) timestamp = str(datetime.now()) timestamp = timestamp[:timestamp.index('.')] timestamp = timestamp.replace(' ', '_').replace(':', '-') self.save_path = config["paths"]["checkpoints"] + os.sep + timestamp os.mkdir(self.save_path) # Create models 
self.latent_compressor = LatentCompressor( config["model"]["d_model"]).to(config["train"]["device"]) self.latent_decompressor = LatentDecompressor( config["model"]["d_model"]).to(config["train"]["device"]) voc_size = config["tokens"]["vocab_size"] device = config["train"]["device"] self.encoder, self.decoder, self.generator = make_model( voc_size, voc_size, N=config["model"]["layers"], device=device) if config["train"]["aae"]: self.discriminator = Discriminator( config["model"]["d_model"], config["model"]["discriminator_dropout"]).to( config["train"]["device"]) # Create optimizers enc_params = list(self.encoder.parameters()) + list( self.latent_compressor.parameters()) self.encoder_optimizer = CTOpt( torch.optim.Adam(enc_params, lr=0, betas=(0.9, 0.98)), config["train"]["warmup_steps"], (config["train"]["lr_min"], config["train"]["lr_max"]), config["train"]["decay_steps"], config["train"]["minimum_lr"]) dec_params = list(self.latent_decompressor.parameters()) + list( self.decoder.parameters()) + list(self.generator.parameters()) self.decoder_optimizer = CTOpt( torch.optim.Adam(dec_params, lr=0, betas=(0.9, 0.98)), config["train"]["warmup_steps"], (config["train"]["lr_min"], config["train"]["lr_max"]), config["train"]["decay_steps"], config["train"]["minimum_lr"]) if config["train"]["aae"]: self.disc_optimizer = CTOpt( torch.optim.Adam([{ "params": self.discriminator.parameters() }], lr=0, betas=(0.9, 0.98)), config["train"]["warmup_steps"], (config["train"]["lr_min"], config["train"]["lr_max"]), config["train"]["decay_steps"], config["train"]["minimum_lr"]) self.gen_optimizer = CTOpt( torch.optim.Adam(enc_params, lr=0, betas=(0.9, 0.98)), config["train"]["warmup_steps"], (config["train"]["lr_min"], config["train"]["lr_max"]), config["train"]["decay_steps"], config["train"]["minimum_lr"]) self.criterion = LabelSmoothing(size=config["tokens"]["vocab_size"], padding_idx=0, smoothing=0.1).to(device) # Load dataset tr_loader = SongIterator( dataset_path=config["paths"]["dataset"] + os.sep + "train", batch_size=config["train"]["batch_size"], n_workers=config["train"]["n_workers"]).get_loader() ts_loader = SongIterator( dataset_path=config["paths"]["dataset"] + os.sep + "eval", batch_size=config["train"]["batch_size"], n_workers=config["train"]["n_workers"]).get_loader() # Init WANDB self.logger = Logger() wandb.login() wandb.init(project="MusAE", config=config, name="r_" + timestamp if remote else "l_" + timestamp) wandb.watch(self.encoder, log_freq=1000, log="all") wandb.watch(self.latent_compressor, log_freq=1000, log="all") wandb.watch(self.latent_decompressor, log_freq=1000, log="all") wandb.watch(self.decoder, log_freq=1000, log="all") wandb.watch(self.generator, log_freq=1000, log="all") if config["train"]["aae"]: wandb.watch(self.discriminator, log_freq=1000, log="all") # Print info about training time.sleep( 1.) 
# sleep for one second to let the machine connect to wandb if config["train"]["verbose"]: print("Giving", len(tr_loader), "training samples and", len(ts_loader), "test samples") # print("Final set has size", len(dataset.final_set)) print("Model has", config["model"]["layers"], "layers") print("Batch size is", config["train"]["batch_size"]) print("d_model is", config["model"]["d_model"]) if config["train"]["aae"]: print("Imposing prior distribution on latents") print("Starting training aae after", config["train"]["train_aae_after_steps"]) print("lambda:", config["train"]["lambda"], ", critic iterations:", config["train"]["critic_iterations"]) else: print("NOT imposing prior distribution on latents") if config["train"]["log_images"]: print("Logging images") else: print("NOT logging images") if config["train"]["make_songs"]: print("Making songs every", config["train"]["after_steps_make_songs"]) else: print("NOT making songs") if config["train"]["do_eval"]: if config["train"]["eval_after_epoch"]: print("Doing evaluation after each epoch") else: print("Doing evaluation after", config["train"]["after_steps_do_eval"]) else: print("NOT DOING evaluation") if config["train"]["scheduled_sampling"]: print("Using scheduled sampling") else: print("NOT using scheduled sampling") if config["train"]["compress_latents"]: print("Compressing latents") else: print("NOT compressing latents") if config["train"]["use_rel_pos"]: print("Using relative positional encoding") else: print("NOT using relative positional encoding") print("Save model every", config["train"]["after_steps_save_model"]) if remote: wandb.save("compress_latents.py") wandb.save("train.py") wandb.save("config.py") wandb.save("test.py") wandb.save("loss_computer.py") wandb.save("utilities.py") wandb.save("discriminator.py") wandb.save("compressive_transformer.py") # Setup train self.encoder.train() self.latent_compressor.train() self.latent_decompressor.train() self.decoder.train() self.generator.train() if config["train"]["aae"]: self.discriminator.train() desc = "Train epoch " + str(self.epoch) + ", mb " + str(0) if config["train"]["eval_after_epoch"]: train_progress = tqdm(total=len(tr_loader), position=0, leave=True, desc=desc) else: train_progress = tqdm(total=config["train"]["after_steps_do_eval"], position=0, leave=True, desc=desc) self.step = 0 # -1 to do eval in first step first_batch = None # Main loop for self.epoch in range(config["train"]["n_epochs"]): # for each epoch for song_it, batch in enumerate(tr_loader): # for each song ######### # TRAIN # ######### if first_batch is None: # if training reconstruct from train, if eval reconstruct from eval first_batch = batch second_batch = batch tr_losses = self.run_mb(batch) if self.step % 10 == 0: self.logger.log_losses(tr_losses, self.encoder.training) self.logger.log_stuff( self.encoder_optimizer.lr, self.latent, self.disc_optimizer.lr if config["train"]["aae"] else None, self.gen_optimizer.lr if config["train"]["aae"] else None, self.beta if config["train"]["aae"] else None, get_prior(self.latent.shape) if config["train"]["aae"] else None, self.tf_prob) if self.step == 0: print("Latent shape is:", self.latent.shape) train_progress.update() ######## # EVAL # ######## eae = config["train"]["eval_after_epoch"] do_eval = config["train"]["do_eval"] sbe = config["train"]["after_steps_do_eval"] if ((eae and song_it == 0) or (not eae and self.step % sbe == 0)) and do_eval and self.step > 0: print("Evaluation") train_progress.close() ts_losses = [] self.encoder.eval() 
self.latent_compressor.eval() self.latent_decompressor.eval() self.decoder.eval() self.generator.eval() if config["train"]["aae"]: self.discriminator.eval() desc = "Eval epoch " + str( self.epoch) + ", mb " + str(song_it) # Compute validation score first_batch = None for test in tqdm(ts_loader, position=0, leave=True, desc=desc): # remember test losses if first_batch is None: first_batch = test second_batch = test with torch.no_grad(): ts_loss = self.run_mb(test) ts_losses.append(ts_loss) final = () # average losses for i in range(len(ts_losses[0])): # for each loss value aux = [] for loss in ts_losses: # for each computed loss aux.append(loss[i]) avg = sum(aux) / len(aux) final = final + (avg, ) self.logger.log_losses(final, self.encoder.training) # eval end self.encoder.train() self.latent_compressor.train() self.latent_decompressor.train() self.decoder.train() self.generator.train() if config["train"]["aae"]: self.discriminator.train() desc = "Train epoch " + str( self.epoch) + ", mb " + str(song_it) if config["train"]["eval_after_epoch"]: train_progress = tqdm(total=len(tr_loader), position=0, leave=True, desc=desc) else: train_progress = tqdm( total=config["train"]["after_steps_do_eval"], position=0, leave=True, desc=desc) ############## # SAVE MODEL # ############## if (self.step % config["train"]["after_steps_save_model"] ) == 0 and self.step > 0: full_path = self.save_path + os.sep + str(self.step) os.makedirs(full_path) print("Saving last model in " + full_path + ", DO NOT INTERRUPT") torch.save(self.encoder, os.path.join(full_path, "encoder.pt"), pickle_module=dill) torch.save(self.latent_compressor, os.path.join(full_path, "latent_compressor.pt"), pickle_module=dill) torch.save(self.latent_decompressor, os.path.join(full_path, "latent_decompressor.pt"), pickle_module=dill) torch.save(self.decoder, os.path.join(full_path, "decoder.pt"), pickle_module=dill) torch.save(self.generator, os.path.join(full_path, "generator.pt"), pickle_module=dill) if config["train"]["aae"]: torch.save(self.discriminator, os.path.join(full_path, "discriminator.pt"), pickle_module=dill) print("Model saved") ######## # TEST # ######## if (self.step % config["train"]["after_steps_make_songs"]) == 0 and config["train"]["make_songs"] \ and self.step > 0: print("Making songs") self.encoder.eval() self.latent_compressor.eval() self.latent_decompressor.eval() self.decoder.eval() self.generator.eval() self.tester = Tester(self.encoder, self.latent_compressor, self.latent_decompressor, self.decoder, self.generator) # RECONSTRUCTION note_manager = NoteRepresentationManager() to_reconstruct = second_batch with torch.no_grad(): original, reconstructed, acc = self.tester.reconstruct( to_reconstruct, note_manager) prefix = "epoch_" + str(self.epoch) + "_mb_" + str(song_it) self.logger.log_songs(os.path.join(wandb.run.dir, prefix), [original, reconstructed], ["original", "reconstructed"], "validation reconstruction example") self.logger.log_reconstruction_accuracy(acc) if config["train"]["aae"]: # GENERATION with torch.no_grad(): generated = self.tester.generate( note_manager) # generation self.logger.log_songs( os.path.join(wandb.run.dir, prefix), [generated], ["generated"], "generated") # INTERPOLATION with torch.no_grad(): first, interpolation, second = self.tester.interpolation( note_manager, first_batch, second_batch) self.logger.log_songs( os.path.join(wandb.run.dir, prefix), [first, interpolation, second], ["first", "interpolation", "second"], "interpolation") # end test self.encoder.train() 
                    self.latent_compressor.train()
                    self.latent_decompressor.train()
                    self.decoder.train()
                    self.generator.train()

                self.step += 1
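The AAE branch of `run_mb` relies on `get_prior` and `calc_gradient_penalty`, neither of which appears in this excerpt. Minimal sketches under common assumptions (a standard Gaussian prior over the compressed latents, and the WGAN-GP penalty of Gulrajani et al., 2017); the `device` default is an assumption:

import torch
from torch import autograd

def get_prior(shape, device="cuda"):
    # sample the imposed prior: a standard Gaussian of the latent shape
    return torch.randn(shape, device=device)

def calc_gradient_penalty(critic, real, fake):
    # penalize critic gradients whose norm deviates from 1 at random
    # interpolates between prior samples and encoder latents
    alpha = torch.rand(real.size(0), 1, device=real.device).expand_as(real)
    interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    scores = critic(interp)
    grads = autograd.grad(scores, interp,
                          grad_outputs=torch.ones_like(scores),
                          create_graph=True, retain_graph=True)[0]
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()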