def create_losses(self):
    """Build the generator / discriminator loss functions and cache them on self.

    Creates:
      - self.g_criterion: mean NLL over log-softmaxed generator logits.
      - self.d_criterion: BCE for the discriminator's real/fake outputs.
      - self.pg_criterion: policy-gradient loss wrapper (see NOTE below).
    """
    # define loss function
    self._g_criterion = torch.nn.NLLLoss(reduction='mean')
    self.d_criterion = torch.nn.BCELoss()  #torch.nn.SoftMarginLoss()
    # self._pg_criterion = PGLoss(ignore_index=self.dataset.dst_dict.pad(), size_average=True, reduce=True)
    # NOTE(review): the PGLoss assignment above is commented out, yet the
    # pg_criterion lambda below still calls self._pg_criterion. Invoking
    # self.pg_criterion will raise AttributeError unless _pg_criterion is
    # assigned elsewhere -- TODO confirm before relying on it.
    self._logsoftmax = torch.nn.LogSoftmax(dim=-1)
    # g_criterion takes raw logits: log-softmax is applied here so callers
    # pass unnormalized scores.
    self.g_criterion = lambda pred, true: self._g_criterion(
        self._logsoftmax(pred), true)
    # pg_criterion mirrors that contract; modified_logits may be None, in
    # which case None is forwarded unchanged.
    self.pg_criterion = lambda pred, true, reward, modified_logits, predicted_tokens: \
        self._pg_criterion(
            self._logsoftmax(pred), true, reward,
            self._logsoftmax(modified_logits) if modified_logits is not None else None,
            predicted_tokens,
        )
def main(args):
    """Adversarial (SeqGAN-style) NMT training, warm-started from checkpoints.

    Loads a pre-trained generator and discriminator from
    'checkpoints/zhenwarm/', then alternates per batch between policy-gradient
    and MLE generator updates (coin flip), trains the discriminator every 5
    generator updates on either a real or a fake batch, and every 10000
    updates runs validation and snapshots the generator to
    'checkpoints/myzhencli5/'.

    Args:
        args: parsed command-line namespace (gpuid, data, src_lang, trg_lang,
            fixed_max_len, optimizer/learning-rate settings, epochs, ...).
            Model-size fields on args are overwritten below to match the
            warm-start checkpoints.
    """
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))

    # Load dataset: binary pre-processed files if present, raw text otherwise.
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(args.data, splits, args.src_lang,
                                    args.trg_lang, args.fixed_max_len)
    else:
        dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang,
                                             args.trg_lang, args.fixed_max_len)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
    for split in splits:
        print('| {} {} {} examples'.format(args.data, split,
                                           len(dataset.splits[split])))

    # Running-average meters for generator / discriminator statistics.
    g_logging_meters = OrderedDict()
    g_logging_meters['train_loss'] = AverageMeter()
    g_logging_meters['valid_loss'] = AverageMeter()
    g_logging_meters['train_acc'] = AverageMeter()
    g_logging_meters['valid_acc'] = AverageMeter()
    g_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    d_logging_meters = OrderedDict()
    d_logging_meters['train_loss'] = AverageMeter()
    d_logging_meters['valid_loss'] = AverageMeter()
    d_logging_meters['train_acc'] = AverageMeter()
    d_logging_meters['valid_acc'] = AverageMeter()
    d_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    # Set model parameters; these must match the warm-start checkpoints below.
    args.encoder_embed_dim = 1000
    args.encoder_layers = 2  # 4
    args.encoder_dropout_out = 0
    args.decoder_embed_dim = 1000
    args.decoder_layers = 2  # 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0
    args.bidirectional = False

    generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict,
                          use_cuda=use_cuda)
    print("Generator loaded successfully!")
    discriminator = Discriminator(args, dataset.src_dict, dataset.dst_dict,
                                  use_cuda=use_cuda)
    print("Discriminator loaded successfully!")

    # Warm-start the generator: the checkpoint is a full pickled model, so we
    # take its state_dict, keep only keys our model knows, and load the merge.
    g_model_path = 'checkpoints/zhenwarm/generator.pt'
    assert os.path.exists(g_model_path)
    model_dict = generator.state_dict()
    model = torch.load(g_model_path)
    pretrained_dict = model.state_dict()
    # 1. filter out unnecessary keys
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    generator.load_state_dict(model_dict)
    print("pre-trained Generator loaded successfully!")

    # Warm-start the discriminator the same way.
    d_model_path = 'checkpoints/zhenwarm/discri.pt'
    assert os.path.exists(d_model_path)
    d_model_dict = discriminator.state_dict()
    d_model = torch.load(d_model_path)
    d_pretrained_dict = d_model.state_dict()
    d_pretrained_dict = {
        k: v
        for k, v in d_pretrained_dict.items() if k in d_model_dict
    }
    d_model_dict.update(d_pretrained_dict)
    discriminator.load_state_dict(d_model_dict)
    print("pre-trained Discriminator loaded successfully!")

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator = torch.nn.DataParallel(discriminator).cuda()
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
            discriminator.cuda()
    else:
        discriminator.cpu()
        generator.cpu()

    # adversarial training checkpoints saving path
    if not os.path.exists('checkpoints/myzhencli5'):
        os.makedirs('checkpoints/myzhencli5')
    checkpoints_path = 'checkpoints/myzhencli5/'

    # define loss functions
    g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(),
                                   reduction='sum')
    d_criterion = torch.nn.BCELoss()
    pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(),
                          size_average=True, reduce=True)

    # fix discriminator word embedding (as Wu et al. do)
    for p in discriminator.embed_src_tokens.parameters():
        p.requires_grad = False
    for p in discriminator.embed_trg_tokens.parameters():
        p.requires_grad = False

    # define optimizers; getattr avoids eval() on a user-supplied string
    g_optimizer = getattr(torch.optim, args.g_optimizer)(
        filter(lambda x: x.requires_grad, generator.parameters()),
        args.g_learning_rate)
    d_optimizer = getattr(torch.optim, args.d_optimizer)(
        filter(lambda x: x.requires_grad, discriminator.parameters()),
        args.d_learning_rate,
        momentum=args.momentum,
        nesterov=True)

    # start joint training
    best_dev_loss = math.inf
    num_update = 0

    # main training loop
    for epoch_i in range(1, args.epochs + 1):
        logging.info("At {0}-th epoch.".format(epoch_i))

        seed = args.seed + epoch_i
        torch.manual_seed(seed)

        max_positions_train = (args.fixed_max_len, args.fixed_max_len)

        # Initialize dataloader, starting at batch_offset
        trainloader = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_train,
            # seed=seed,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(trainloader):
            # set training mode (validation below switches to eval, so this
            # must happen every iteration)
            generator.train()
            discriminator.train()
            update_learning_rate(num_update, 8e4, args.g_learning_rate,
                                 args.lr_shrink, g_optimizer)

            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=cuda)

            ## part I: use gradient policy method to train the generator
            # use policy gradient training when random.random() > 50%
            if random.random() >= 0.5:
                print("Policy Gradient Training")
                sys_out_batch = generator(sample)  # 64 X 50 X 6632
                out_batch = sys_out_batch.contiguous().view(
                    -1, sys_out_batch.size(-1))  # (64 * 50) X 6632
                _, prediction = out_batch.topk(1)
                prediction = prediction.squeeze(1)  # 64*50 = 3200
                prediction = torch.reshape(
                    prediction,
                    sample['net_input']['src_tokens'].shape)  # 64 X 50
                # Reward comes from the discriminator; no grads through it.
                with torch.no_grad():
                    reward = discriminator(sample['net_input']['src_tokens'],
                                           prediction)  # 64 X 1
                train_trg_batch = sample['target']  # 64 x 50
                pg_loss = pg_criterion(sys_out_batch, train_trg_batch, reward,
                                       use_cuda)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']  # 64
                logging_loss = pg_loss / math.log(2)
                g_logging_meters['train_loss'].update(logging_loss.item(),
                                                      sample_size)
                logging.debug(
                    f"G policy gradient loss at batch {i}: {pg_loss.item():.3f}, lr={g_optimizer.param_groups[0]['lr']}"
                )
                g_optimizer.zero_grad()
                pg_loss.backward()
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()
            else:
                # MLE training
                print("MLE Training")
                sys_out_batch = generator(sample)
                out_batch = sys_out_batch.contiguous().view(
                    -1, sys_out_batch.size(-1))  # (64 X 50) X 6632
                train_trg_batch = sample['target'].view(-1)  # 64*50 = 3200
                loss = g_criterion(out_batch, train_trg_batch)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                nsentences = sample['target'].size(0)
                logging_loss = loss.data / sample_size / math.log(2)
                g_logging_meters['bsz'].update(nsentences)
                g_logging_meters['train_loss'].update(logging_loss,
                                                      sample_size)
                logging.debug(
                    f"G MLE loss at batch {i}: {g_logging_meters['train_loss'].avg:.3f}, lr={g_optimizer.param_groups[0]['lr']}"
                )
                g_optimizer.zero_grad()
                loss.backward()
                # all-reduce grads and rescale by grad_denom
                for p in generator.parameters():
                    if p.requires_grad:
                        p.grad.data.div_(sample_size)
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()

            num_update += 1

            # part II: train the discriminator (every 5 generator updates)
            if num_update % 5 == 0:
                bsz = sample['target'].size(0)  # batch_size = 64
                src_sentence = sample['net_input'][
                    'src_tokens']  # 64 x max-len i.e 64 X 50
                # now train with machine translation output i.e generator output
                true_sentence = sample['target'].view(-1)  # 64*50 = 3200
                true_labels = Variable(
                    torch.ones(
                        sample['target'].size(0)).float())  # 64 length vector
                with torch.no_grad():
                    sys_out_batch = generator(sample)  # 64 X 50 X 6632
                    out_batch = sys_out_batch.contiguous().view(
                        -1, sys_out_batch.size(-1))  # (64 X 50) X 6632
                    _, prediction = out_batch.topk(1)
                    prediction = prediction.squeeze(1)  # 64 * 50
                fake_labels = Variable(
                    torch.zeros(
                        sample['target'].size(0)).float())  # 64 length vector
                fake_sentence = torch.reshape(prediction,
                                              src_sentence.shape)  # 64 X 50
                true_sentence = torch.reshape(true_sentence,
                                              src_sentence.shape)
                if use_cuda:
                    fake_labels = fake_labels.cuda()
                    true_labels = true_labels.cuda()

                # Train on EITHER a fake (generated) batch or a real batch,
                # chosen by coin flip -- never both in the same step.
                if random.random() > 0.5:
                    fake_disc_out = discriminator(src_sentence, fake_sentence)
                    fake_d_loss = d_criterion(fake_disc_out.squeeze(1),
                                              fake_labels)
                    fake_acc = torch.sum(
                        torch.round(fake_disc_out).squeeze(1) ==
                        fake_labels).float() / len(fake_labels)
                    d_loss = fake_d_loss
                    acc = fake_acc
                else:
                    true_disc_out = discriminator(src_sentence, true_sentence)
                    true_d_loss = d_criterion(true_disc_out.squeeze(1),
                                              true_labels)
                    true_acc = torch.sum(
                        torch.round(true_disc_out).squeeze(1) ==
                        true_labels).float() / len(true_labels)
                    d_loss = true_d_loss
                    acc = true_acc

                d_logging_meters['train_acc'].update(acc)
                d_logging_meters['train_loss'].update(d_loss)
                logging.debug(
                    f"D training loss {d_logging_meters['train_loss'].avg:.3f}, acc {d_logging_meters['train_acc'].avg:.3f} at batch {i}"
                )
                d_optimizer.zero_grad()
                d_loss.backward()
                d_optimizer.step()

            # periodic validation + generator snapshot
            if num_update % 10000 == 0:
                # set validation mode
                generator.eval()
                discriminator.eval()

                # Initialize dataloader
                max_positions_valid = (args.fixed_max_len, args.fixed_max_len)
                valloader = dataset.eval_dataloader(
                    'valid',
                    max_tokens=args.max_tokens,
                    max_sentences=args.joint_batch_size,
                    max_positions=max_positions_valid,
                    skip_invalid_size_inputs_valid_test=True,
                    descending=True,  # largest batch first to warm the caching allocator
                    shard_id=args.distributed_rank,
                    num_shards=args.distributed_world_size,
                )

                # reset meters
                for key, val in g_logging_meters.items():
                    if val is not None:
                        val.reset()
                for key, val in d_logging_meters.items():
                    if val is not None:
                        val.reset()

                # NOTE(review): this inner loop shadows the outer `i`/`sample`;
                # harmless because the outer loop re-binds both each iteration.
                for i, sample in enumerate(valloader):
                    with torch.no_grad():
                        if use_cuda:
                            # wrap input tensors in cuda tensors
                            sample = utils.make_variable(sample, cuda=cuda)

                        # generator validation
                        sys_out_batch = generator(sample)
                        out_batch = sys_out_batch.contiguous().view(
                            -1, sys_out_batch.size(-1))  # (64 X 50) X 6632
                        dev_trg_batch = sample['target'].view(
                            -1)  # 64*50 = 3200
                        loss = g_criterion(out_batch, dev_trg_batch)
                        sample_size = sample['target'].size(
                            0) if args.sentence_avg else sample['ntokens']
                        loss = loss / sample_size / math.log(2)
                        g_logging_meters['valid_loss'].update(
                            loss, sample_size)
                        logging.debug(
                            f"G dev loss at batch {i}: {g_logging_meters['valid_loss'].avg:.3f}"
                        )

                        # discriminator validation
                        bsz = sample['target'].size(0)
                        src_sentence = sample['net_input']['src_tokens']
                        # evaluate on both human translation and machine translation
                        true_sentence = sample['target']
                        true_labels = Variable(
                            torch.ones(sample['target'].size(0)).float())
                        sys_out_batch = generator(sample)
                        out_batch = sys_out_batch.contiguous().view(
                            -1, sys_out_batch.size(-1))  # (64 X 50) X 6632
                        _, prediction = out_batch.topk(1)
                        prediction = prediction.squeeze(1)  # 64 * 50
                        fake_labels = Variable(
                            torch.zeros(sample['target'].size(0)).float())
                        fake_sentence = torch.reshape(
                            prediction, src_sentence.shape)  # 64 X 50
                        true_sentence = torch.reshape(true_sentence,
                                                      src_sentence.shape)
                        if use_cuda:
                            fake_labels = fake_labels.cuda()
                            true_labels = true_labels.cuda()

                        fake_disc_out = discriminator(src_sentence,
                                                      fake_sentence)  # 64 X 1
                        true_disc_out = discriminator(src_sentence,
                                                      true_sentence)
                        fake_d_loss = d_criterion(fake_disc_out.squeeze(1),
                                                  fake_labels)
                        true_d_loss = d_criterion(true_disc_out.squeeze(1),
                                                  true_labels)
                        d_loss = fake_d_loss + true_d_loss
                        fake_acc = torch.sum(
                            torch.round(fake_disc_out).squeeze(1) ==
                            fake_labels).float() / len(fake_labels)
                        true_acc = torch.sum(
                            torch.round(true_disc_out).squeeze(1) ==
                            true_labels).float() / len(true_labels)
                        acc = (fake_acc + true_acc) / 2
                        d_logging_meters['valid_acc'].update(acc)
                        d_logging_meters['valid_loss'].update(d_loss)
                        logging.debug(
                            f"D dev loss {d_logging_meters['valid_loss'].avg:.3f}, acc {d_logging_meters['valid_acc'].avg:.3f} at batch {i}"
                        )

                # snapshot the generator; `with` closes the file handle even
                # if torch.save raises (the original leaked it)
                with open(
                        checkpoints_path +
                        f"numupdate_{num_update/10000}k.joint_{g_logging_meters['valid_loss'].avg:.3f}.pt",
                        'wb') as f:
                    torch.save(generator, f, pickle_module=dill)
def main(args):
    """Joint adversarial training of generator and discriminator.

    Pre-trains each model via train_g / train_d if its best checkpoint is
    missing, then runs epochs that mix policy-gradient and MLE generator
    updates (coin flip per batch) with a discriminator update every batch on
    a shuffled half real / half generated batch. Validates and checkpoints
    the generator at the end of every epoch under 'checkpoints/joint/'.

    Args:
        args: parsed command-line namespace (gpuid, data, languages,
            optimizer settings, epochs, ...). Model-size fields on args are
            overwritten below.
    """
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))
    print("======printing args========")
    print(args)
    print("=================================")

    # Load dataset: binary pre-processed files if present, raw text otherwise.
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        print("Loading bin dataset")
        dataset = data.load_dataset(args.data, splits, args.src_lang,
                                    args.trg_lang, args.fixed_max_len)
        #args.data, splits, args.src_lang, args.trg_lang)
    else:
        print(f"Loading raw text dataset {args.data}")
        dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang,
                                             args.trg_lang, args.fixed_max_len)
        #args.data, splits, args.src_lang, args.trg_lang)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
    for split in splits:
        print('| {} {} {} examples'.format(args.data, split,
                                           len(dataset.splits[split])))

    # Running-average meters for generator / discriminator statistics.
    g_logging_meters = OrderedDict()
    g_logging_meters['train_loss'] = AverageMeter()
    g_logging_meters['valid_loss'] = AverageMeter()
    g_logging_meters['train_acc'] = AverageMeter()
    g_logging_meters['valid_acc'] = AverageMeter()
    g_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    d_logging_meters = OrderedDict()
    d_logging_meters['train_loss'] = AverageMeter()
    d_logging_meters['valid_loss'] = AverageMeter()
    d_logging_meters['train_acc'] = AverageMeter()
    d_logging_meters['valid_acc'] = AverageMeter()
    d_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    # Set model parameters (must match any checkpoints loaded below).
    args.encoder_embed_dim = 1000
    args.encoder_layers = 4
    args.encoder_dropout_out = 0
    args.decoder_embed_dim = 1000
    args.decoder_layers = 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0
    args.bidirectional = False

    # try to load generator model; pre-train it first if no checkpoint exists
    g_model_path = 'checkpoints/generator/best_gmodel.pt'
    if not os.path.exists(g_model_path):
        print("Start training generator!")
        train_g(args, dataset)
    assert os.path.exists(g_model_path)
    generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict,
                          use_cuda=use_cuda)
    model_dict = generator.state_dict()
    # here the checkpoint is a plain state dict (unlike a pickled model)
    pretrained_dict = torch.load(g_model_path)
    # 1. filter out unnecessary keys
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    generator.load_state_dict(model_dict)
    print("Generator has successfully loaded!")

    # try to load discriminator model; pre-train it first if missing
    d_model_path = 'checkpoints/discriminator/best_dmodel.pt'
    if not os.path.exists(d_model_path):
        print("Start training discriminator!")
        train_d(args, dataset)
    assert os.path.exists(d_model_path)
    discriminator = Discriminator(args, dataset.src_dict, dataset.dst_dict,
                                  use_cuda=use_cuda)
    model_dict = discriminator.state_dict()
    pretrained_dict = torch.load(d_model_path)
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    model_dict.update(pretrained_dict)
    discriminator.load_state_dict(model_dict)
    print("Discriminator has successfully loaded!")

    #return
    print("starting main training loop")

    # NOTE: anomaly detection slows training; useful while debugging backward
    torch.autograd.set_detect_anomaly(True)

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator = torch.nn.DataParallel(discriminator).cuda()
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
            discriminator.cuda()
    else:
        discriminator.cpu()
        generator.cpu()

    # adversarial training checkpoints saving path
    if not os.path.exists('checkpoints/joint'):
        os.makedirs('checkpoints/joint')
    checkpoints_path = 'checkpoints/joint/'

    # define loss functions
    # reduction='sum' replaces the deprecated size_average=False/reduce=True
    g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(),
                                   reduction='sum')
    d_criterion = torch.nn.BCEWithLogitsLoss()
    pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(),
                          size_average=True, reduce=True)

    # fix discriminator word embedding (as Wu et al. do)
    for p in discriminator.embed_src_tokens.parameters():
        p.requires_grad = False
    for p in discriminator.embed_trg_tokens.parameters():
        p.requires_grad = False

    # define optimizers; getattr avoids eval() on a user-supplied string
    g_optimizer = getattr(torch.optim, args.g_optimizer)(
        filter(lambda x: x.requires_grad, generator.parameters()),
        args.g_learning_rate)
    d_optimizer = getattr(torch.optim, args.d_optimizer)(
        filter(lambda x: x.requires_grad, discriminator.parameters()),
        args.d_learning_rate,
        momentum=args.momentum,
        nesterov=True)

    # start joint training
    best_dev_loss = math.inf
    num_update = 0

    # main training loop
    for epoch_i in range(1, args.epochs + 1):
        logging.info("At {0}-th epoch.".format(epoch_i))

        # seed = args.seed + epoch_i
        # torch.manual_seed(seed)

        max_positions_train = (args.fixed_max_len, args.fixed_max_len)

        # Initialize dataloader, starting at batch_offset
        itr = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_train,
            # seed=seed,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        # set training mode (per-epoch; validation below flips to eval)
        generator.train()
        discriminator.train()
        update_learning_rate(num_update, 8e4, args.g_learning_rate,
                             args.lr_shrink, g_optimizer)

        for i, sample in enumerate(itr):
            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=cuda)

            ## part I: use gradient policy method to train the generator
            # use policy gradient training when rand > 50%
            rand = random.random()
            if rand >= 0.5:
                # policy gradient training
                generator.decoder.is_testing = True
                sys_out_batch, prediction, _ = generator(sample)
                generator.decoder.is_testing = False
                # Reward from the discriminator; no grads through it.
                with torch.no_grad():
                    n_i = sample['net_input']['src_tokens']
                    #print(f"net input:\n{n_i}, pred: \n{prediction}")
                    reward = discriminator(sample['net_input']['src_tokens'],
                                           prediction)
                train_trg_batch = sample['target']
                #print(f"sys_out_batch: {sys_out_batch.shape}:\n{sys_out_batch}")
                pg_loss = pg_criterion(sys_out_batch, train_trg_batch, reward,
                                       use_cuda)
                # logging.debug("G policy gradient loss at batch {0}: {1:.3f}, lr={2}".format(i, pg_loss.item(), g_optimizer.param_groups[0]['lr']))
                g_optimizer.zero_grad()
                pg_loss.backward()
                # clip_grad_norm_ replaces the deprecated clip_grad_norm
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()

                # oracle valid: report teacher-forced MLE loss for logging
                _, _, loss = generator(sample)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                logging_loss = loss.data / sample_size / math.log(2)
                g_logging_meters['train_loss'].update(logging_loss,
                                                      sample_size)
                logging.debug("G MLE loss at batch {0}: {1:.3f}, lr={2}".format(
                    i, g_logging_meters['train_loss'].avg,
                    g_optimizer.param_groups[0]['lr']))
            else:
                # MLE training
                #print(f"printing sample: \n{sample}")
                _, _, loss = generator(sample)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                nsentences = sample['target'].size(0)
                logging_loss = loss.data / sample_size / math.log(2)
                g_logging_meters['bsz'].update(nsentences)
                g_logging_meters['train_loss'].update(logging_loss,
                                                      sample_size)
                logging.debug("G MLE loss at batch {0}: {1:.3f}, lr={2}".format(
                    i, g_logging_meters['train_loss'].avg,
                    g_optimizer.param_groups[0]['lr']))
                g_optimizer.zero_grad()
                loss.backward()
                # all-reduce grads and rescale by grad_denom
                for p in generator.parameters():
                    if p.requires_grad:
                        p.grad.data.div_(sample_size)
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()

            num_update += 1

            # part II: train the discriminator
            bsz = sample['target'].size(0)
            src_sentence = sample['net_input']['src_tokens']
            # train with half human-translation and half machine translation
            true_sentence = sample['target']
            true_labels = Variable(
                torch.ones(sample['target'].size(0)).float())
            with torch.no_grad():
                generator.decoder.is_testing = True
                _, prediction, _ = generator(sample)
                generator.decoder.is_testing = False
            fake_sentence = prediction
            fake_labels = Variable(
                torch.zeros(sample['target'].size(0)).float())

            # Shuffle real+fake together, then keep a half-size batch so the
            # discriminator sees a random real/fake mix each step.
            trg_sentence = torch.cat([true_sentence, fake_sentence], dim=0)
            labels = torch.cat([true_labels, fake_labels], dim=0)
            indices = np.random.permutation(2 * bsz)
            trg_sentence = trg_sentence[indices][:bsz]
            labels = labels[indices][:bsz]

            if use_cuda:
                labels = labels.cuda()

            disc_out = discriminator(src_sentence, trg_sentence)
            #, dataset.dst_dict.pad())
            # BCEWithLogitsLoss requires FLOAT targets -- the original passed
            # labels.long(), which raises at runtime.
            d_loss = d_criterion(disc_out, labels)
            # disc_out is raw logits, so apply sigmoid before thresholding.
            # (torch.Sigmoid does not exist; the original raised AttributeError.)
            acc = torch.sum(torch.sigmoid(disc_out).round() ==
                            labels).float() / len(labels)
            d_logging_meters['train_acc'].update(acc)
            d_logging_meters['train_loss'].update(d_loss)
            # logging.debug("D training loss {0:.3f}, acc {1:.3f} at batch {2}: ".format(d_logging_meters['train_loss'].avg,
            #                                                                            d_logging_meters['train_acc'].avg,
            #                                                                            i))
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

        # validation -- set validation mode
        generator.eval()
        discriminator.eval()

        # Initialize dataloader
        max_positions_valid = (args.fixed_max_len, args.fixed_max_len)
        itr = dataset.eval_dataloader(
            'valid',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_valid,
            skip_invalid_size_inputs_valid_test=True,
            descending=True,  # largest batch first to warm the caching allocator
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(itr):
            with torch.no_grad():
                if use_cuda:
                    sample['id'] = sample['id'].cuda()
                    sample['net_input']['src_tokens'] = sample['net_input'][
                        'src_tokens'].cuda()
                    sample['net_input']['src_lengths'] = sample['net_input'][
                        'src_lengths'].cuda()
                    sample['net_input']['prev_output_tokens'] = sample[
                        'net_input']['prev_output_tokens'].cuda()
                    sample['target'] = sample['target'].cuda()

                # generator validation
                _, _, loss = generator(sample)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                loss = loss / sample_size / math.log(2)
                g_logging_meters['valid_loss'].update(loss, sample_size)
                logging.debug("G dev loss at batch {0}: {1:.3f}".format(
                    i, g_logging_meters['valid_loss'].avg))

                # discriminator validation
                bsz = sample['target'].size(0)
                src_sentence = sample['net_input']['src_tokens']
                # evaluate with half human-translation and half machine translation
                true_sentence = sample['target']
                true_labels = Variable(
                    torch.ones(sample['target'].size(0)).float())
                generator.decoder.is_testing = True
                _, prediction, _ = generator(sample)
                generator.decoder.is_testing = False
                fake_sentence = prediction
                fake_labels = Variable(
                    torch.zeros(sample['target'].size(0)).float())
                trg_sentence = torch.cat([true_sentence, fake_sentence],
                                         dim=0)
                labels = torch.cat([true_labels, fake_labels], dim=0)
                indices = np.random.permutation(2 * bsz)
                trg_sentence = trg_sentence[indices][:bsz]
                labels = labels[indices][:bsz]
                if use_cuda:
                    labels = labels.cuda()
                # NOTE(review): this call passes a pad argument, unlike the
                # training call above -- confirm which Discriminator.forward
                # signature is intended.
                disc_out = discriminator(src_sentence, trg_sentence,
                                         dataset.dst_dict.pad())
                d_loss = d_criterion(disc_out, labels)
                # sigmoid over logits (torch.Sigmoid does not exist)
                acc = torch.sum(torch.sigmoid(disc_out).round() ==
                                labels).float() / len(labels)
                d_logging_meters['valid_acc'].update(acc)
                d_logging_meters['valid_loss'].update(d_loss)
                # logging.debug("D dev loss {0:.3f}, acc {1:.3f} at batch {2}".format(d_logging_meters['valid_loss'].avg,
                #                                                                     d_logging_meters['valid_acc'].avg, i))

        # per-epoch snapshot; `with` closes the file even if torch.save
        # raises (the original leaked the handle)
        with open(
                checkpoints_path + "joint_{0:.3f}.epoch_{1}.pt".format(
                    g_logging_meters['valid_loss'].avg, epoch_i),
                'wb') as f:
            torch.save(generator, f, pickle_module=dill)

        if g_logging_meters['valid_loss'].avg < best_dev_loss:
            best_dev_loss = g_logging_meters['valid_loss'].avg
            with open(checkpoints_path + "best_gmodel.pt", 'wb') as f:
                torch.save(generator, f, pickle_module=dill)
def main(args): use_cuda = (len(args.gpuid) >= 1) print("{0} GPU(s) are available".format(cuda.device_count())) # Load dataset splits = ['train', 'valid'] if data.has_binary_files(args.data, splits): dataset = data.load_dataset(args.data, splits, args.src_lang, args.trg_lang, args.fixed_max_len) else: dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang, args.trg_lang, args.fixed_max_len) if args.src_lang is None or args.trg_lang is None: # record inferred languages in args, so that it's saved in checkpoints args.src_lang, args.trg_lang = dataset.src, dataset.dst print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict))) print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict))) for split in splits: print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split]))) g_logging_meters = OrderedDict() g_logging_meters['train_loss'] = AverageMeter() g_logging_meters['valid_loss'] = AverageMeter() g_logging_meters['train_acc'] = AverageMeter() g_logging_meters['valid_acc'] = AverageMeter() g_logging_meters['bsz'] = AverageMeter() # sentences per batch d_logging_meters = OrderedDict() d_logging_meters['train_loss'] = AverageMeter() d_logging_meters['valid_loss'] = AverageMeter() d_logging_meters['train_acc'] = AverageMeter() d_logging_meters['valid_acc'] = AverageMeter() d_logging_meters['bsz'] = AverageMeter() # sentences per batch # Set model parameters args.encoder_embed_dim = 1000 args.encoder_layers = 2 # 4 args.encoder_dropout_out = 0 args.decoder_embed_dim = 1000 args.decoder_layers = 2 # 4 args.decoder_out_embed_dim = 1000 args.decoder_dropout_out = 0 args.bidirectional = False generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda) print("Generator loaded successfully!") discriminator = Discriminator(args.decoder_embed_dim, args.discriminator_hidden_size, args.discriminator_linear_size, args.discriminator_lin_dropout, use_cuda=use_cuda) print("Discriminator loaded 
successfully!") def _calcualte_discriminator_loss(tf_scores, ar_scores): tf_loss = torch.log(tf_scores + 1e-6) * (-1) ar_loss = torch.log(1 - ar_scores + 1e-6) * (-1) return tf_loss + ar_loss if use_cuda: if torch.cuda.device_count() > 1: discriminator = torch.nn.DataParallel(discriminator).cuda() generator = torch.nn.DataParallel(generator).cuda() else: generator.cuda() discriminator.cuda() else: discriminator.cpu() generator.cpu() # adversarial training checkpoints saving path if not os.path.exists('checkpoints/professorjp'): os.makedirs('checkpoints/professorjp') checkpoints_path = 'checkpoints/professorjp/' # define loss function g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(), reduction='sum') # d_criterion = torch.nn.BCELoss() pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(), size_average=True, reduce=True) # fix discriminator word embedding (as Wu et al. do) # for p in discriminator.embed_src_tokens.parameters(): # p.requires_grad = False # for p in discriminator.embed_trg_tokens.parameters(): # p.requires_grad = False # define optimizer g_optimizer = eval("torch.optim." + args.g_optimizer)(filter( lambda x: x.requires_grad, generator.parameters()), args.g_learning_rate) d_optimizer = eval("torch.optim." 
+ args.d_optimizer)( filter(lambda x: x.requires_grad, discriminator.parameters()), args.d_learning_rate, momentum=args.momentum, nesterov=True) # start joint training best_dev_loss = math.inf num_update = 0 # main training loop for epoch_i in range(1, args.epochs + 1): logging.info("At {0}-th epoch.".format(epoch_i)) seed = args.seed + epoch_i torch.manual_seed(seed) max_positions_train = (args.fixed_max_len, args.fixed_max_len) # Initialize dataloader, starting at batch_offset trainloader = dataset.train_dataloader( 'train', max_tokens=args.max_tokens, max_sentences=args.joint_batch_size, max_positions=max_positions_train, # seed=seed, epoch=epoch_i, sample_without_replacement=args.sample_without_replacement, sort_by_source_size=(epoch_i <= args.curriculum), shard_id=args.distributed_rank, num_shards=args.distributed_world_size, ) # reset meters for key, val in g_logging_meters.items(): if val is not None: val.reset() for key, val in d_logging_meters.items(): if val is not None: val.reset() # set training mode generator.train() discriminator.train() update_learning_rate(num_update, 8e4, args.g_learning_rate, args.lr_shrink, g_optimizer) for i, sample in enumerate(trainloader): if use_cuda: # wrap input tensors in cuda tensors sample = utils.make_variable(sample, cuda=cuda) ## part I: use gradient policy method to train the generator # print("Policy Gradient Training") sys_out_batch_PG, p_PG, hidden_list_PG = generator( 'PG', epoch_i, sample) # 64 X 50 X 6632 out_batch_PG = sys_out_batch_PG.contiguous().view( -1, sys_out_batch_PG.size(-1)) # (64 * 50) X 6632 _, prediction = out_batch_PG.topk(1) # prediction = prediction.squeeze(1) # 64*50 = 3200 # prediction = torch.reshape(prediction, sample['net_input']['src_tokens'].shape) # 64 X 50 with torch.no_grad(): reward = discriminator(hidden_list_PG) # 64 X 1 train_trg_batch_PG = sample['target'] # 64 x 50 pg_loss_PG = pg_criterion(sys_out_batch_PG, train_trg_batch_PG, reward, use_cuda) sample_size_PG = 
sample['target'].size( 0) if args.sentence_avg else sample['ntokens'] # 64 logging_loss_PG = pg_loss_PG / math.log(2) g_logging_meters['train_loss'].update(logging_loss_PG.item(), sample_size_PG) logging.debug( f"G policy gradient loss at batch {i}: {pg_loss_PG.item():.3f}, lr={g_optimizer.param_groups[0]['lr']}" ) g_optimizer.zero_grad() pg_loss_PG.backward(retain_graph=True) torch.nn.utils.clip_grad_norm_(generator.parameters(), args.clip_norm) g_optimizer.step() # print("MLE Training") sys_out_batch_MLE, p_MLE, hidden_list_MLE = generator( "MLE", epoch_i, sample) out_batch_MLE = sys_out_batch_MLE.contiguous().view( -1, sys_out_batch_MLE.size(-1)) # (64 X 50) X 6632 train_trg_batch_MLE = sample['target'].view(-1) # 64*50 = 3200 loss_MLE = g_criterion(out_batch_MLE, train_trg_batch_MLE) sample_size_MLE = sample['target'].size( 0) if args.sentence_avg else sample['ntokens'] nsentences = sample['target'].size(0) logging_loss_MLE = loss_MLE.data / sample_size_MLE / math.log(2) g_logging_meters['bsz'].update(nsentences) g_logging_meters['train_loss'].update(logging_loss_MLE, sample_size_MLE) logging.debug( f"G MLE loss at batch {i}: {g_logging_meters['train_loss'].avg:.3f}, lr={g_optimizer.param_groups[0]['lr']}" ) # g_optimizer.zero_grad() loss_MLE.backward(retain_graph=True) # all-reduce grads and rescale by grad_denom for p in generator.parameters(): # print(p.size()) if p.requires_grad: p.grad.data.div_(sample_size_MLE) torch.nn.utils.clip_grad_norm_(generator.parameters(), args.clip_norm) g_optimizer.step() num_update += 1 # part II: train the discriminator d_MLE = discriminator(hidden_list_MLE) d_PG = discriminator(hidden_list_PG) d_loss = _calcualte_discriminator_loss(d_MLE, d_PG).sum() logging.debug(f"D training loss {d_loss} at batch {i}") d_optimizer.zero_grad() d_loss.backward() torch.nn.utils.clip_grad_norm_(discriminator.parameters(), args.clip_norm) d_optimizer.step() # validation # set validation mode generator.eval() discriminator.eval() # Initialize 
dataloader max_positions_valid = (args.fixed_max_len, args.fixed_max_len) valloader = dataset.eval_dataloader( 'valid', max_tokens=args.max_tokens, max_sentences=args.joint_batch_size, max_positions=max_positions_valid, skip_invalid_size_inputs_valid_test=True, descending=True, # largest batch first to warm the caching allocator shard_id=args.distributed_rank, num_shards=args.distributed_world_size, ) # reset meters for key, val in g_logging_meters.items(): if val is not None: val.reset() for key, val in d_logging_meters.items(): if val is not None: val.reset() for i, sample in enumerate(valloader): with torch.no_grad(): if use_cuda: # wrap input tensors in cuda tensors sample = utils.make_variable(sample, cuda=cuda) # generator validation sys_out_batch_test, p_test, hidden_list_test = generator( 'test', epoch_i, sample) out_batch_test = sys_out_batch_test.contiguous().view( -1, sys_out_batch_test.size(-1)) # (64 X 50) X 6632 dev_trg_batch = sample['target'].view(-1) # 64*50 = 3200 loss_test = g_criterion(out_batch_test, dev_trg_batch) sample_size_test = sample['target'].size( 0) if args.sentence_avg else sample['ntokens'] loss_test = loss_test / sample_size_test / math.log(2) g_logging_meters['valid_loss'].update(loss_test, sample_size_test) logging.debug( f"G dev loss at batch {i}: {g_logging_meters['valid_loss'].avg:.3f}" ) # # discriminator validation # bsz = sample['target'].size(0) # src_sentence = sample['net_input']['src_tokens'] # # train with half human-translation and half machine translation # # true_sentence = sample['target'] # true_labels = torch.ones(sample['target'].size(0)).float() # # with torch.no_grad(): # sys_out_batch_PG, p, hidden_list = generator('test', epoch_i, sample) # # out_batch = sys_out_batch_PG.contiguous().view(-1, sys_out_batch_PG.size(-1)) # (64 X 50) X 6632 # # _, prediction = out_batch.topk(1) # prediction = prediction.squeeze(1) # 64 * 50 = 6632 # # fake_labels = torch.zeros(sample['target'].size(0)).float() # # fake_sentence 
= torch.reshape(prediction, src_sentence.shape) # 64 X 50 # # if use_cuda: # fake_labels = fake_labels.cuda() # # disc_out = discriminator(src_sentence, fake_sentence) # d_loss = d_criterion(disc_out.squeeze(1), fake_labels) # acc = torch.sum(torch.round(disc_out).squeeze(1) == fake_labels).float() / len(fake_labels) # d_logging_meters['valid_acc'].update(acc) # d_logging_meters['valid_loss'].update(d_loss) # logging.debug( # f"D dev loss {d_logging_meters['valid_loss'].avg:.3f}, acc {d_logging_meters['valid_acc'].avg:.3f} at batch {i}") torch.save( generator, open( checkpoints_path + f"sampling_{g_logging_meters['valid_loss'].avg:.3f}.epoch_{epoch_i}.pt", 'wb'), pickle_module=dill) if g_logging_meters['valid_loss'].avg < best_dev_loss: best_dev_loss = g_logging_meters['valid_loss'].avg torch.save(generator, open(checkpoints_path + "best_gmodel.pt", 'wb'), pickle_module=dill)
def main(args):
    """Jointly train the generator with MLE + policy-gradient and two
    discriminators (hidden-state D_h and sentence-level D_s) for NMT GAN
    training. Loads pre-trained generator/discriminator checkpoints,
    alternates MLE/PG generator updates per batch, periodically updates the
    discriminators, and saves a checkpoint every 10k updates.
    """
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))

    # Load dataset (binary fairseq format if present, else raw text).
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(args.data, splits, args.src_lang,
                                    args.trg_lang, args.fixed_max_len)
    else:
        dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang,
                                             args.trg_lang, args.fixed_max_len)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))
    for split in splits:
        print('| {} {} {} examples'.format(args.data, split,
                                           len(dataset.splits[split])))

    # Running-average meters for generator-side losses/accuracies.
    g_logging_meters = OrderedDict()
    g_logging_meters['MLE_train_loss'] = AverageMeter()
    g_logging_meters['valid_loss'] = AverageMeter()
    g_logging_meters['PG_train_loss'] = AverageMeter()
    g_logging_meters['MLE_train_acc'] = AverageMeter()
    g_logging_meters['PG_train_acc'] = AverageMeter()
    g_logging_meters['valid_acc'] = AverageMeter()
    g_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    # Running-average meters for discriminator-side losses/accuracies.
    d_logging_meters = OrderedDict()
    d_logging_meters['D_h_train_loss'] = AverageMeter()
    d_logging_meters['valid_loss'] = AverageMeter()
    d_logging_meters['D_s_train_loss'] = AverageMeter()
    d_logging_meters['D_h_train_acc'] = AverageMeter()
    d_logging_meters['D_s_train_acc'] = AverageMeter()
    d_logging_meters['valid_acc'] = AverageMeter()
    d_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    # Set model parameters (overrides whatever came in on the CLI).
    args.encoder_embed_dim = 1000
    args.encoder_layers = 2  # 4
    args.encoder_dropout_out = 0.1
    args.decoder_embed_dim = 1000
    args.decoder_layers = 2  # 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0.1
    args.bidirectional = False

    generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict,
                          use_cuda=use_cuda)
    print("Generator loaded successfully!")
    # D_h scores decoder hidden states; D_s scores (src, translation) pairs.
    discriminator_h = Discriminator_h(args.decoder_embed_dim,
                                      args.discriminator_hidden_size,
                                      args.discriminator_linear_size,
                                      args.discriminator_lin_dropout,
                                      use_cuda=use_cuda)
    print("Discriminator_h loaded successfully!")
    discriminator_s = Discriminator_s(args, dataset.src_dict, dataset.dst_dict,
                                      use_cuda=use_cuda)
    print("Discriminator_s loaded successfully!")

    # Load pre-trained generator weights: keep only keys that exist in the
    # freshly-built model, then load the merged state dict.
    g_model_path = 'checkpoints/zhenwarm/genev.pt'
    assert os.path.exists(g_model_path)
    model_dict = generator.state_dict()
    model = torch.load(g_model_path)
    pretrained_dict = model.state_dict()
    # 1. filter out unnecessary keys
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    generator.load_state_dict(model_dict)
    print("pre-trained Generator loaded successfully!")

    # Load pre-trained sentence discriminator weights the same way.
    d_model_path = 'checkpoints/zhenwarm/discri.pt'
    assert os.path.exists(d_model_path)
    d_model_dict = discriminator_s.state_dict()
    d_model = torch.load(d_model_path)
    d_pretrained_dict = d_model.state_dict()
    d_pretrained_dict = {
        k: v
        for k, v in d_pretrained_dict.items() if k in d_model_dict
    }
    d_model_dict.update(d_pretrained_dict)
    discriminator_s.load_state_dict(d_model_dict)
    print("pre-trained Discriminator loaded successfully!")

    def _calculate_discriminator_loss(tf_scores, ar_scores):
        # Standard GAN discriminator loss: -log D(real) - log(1 - D(fake)),
        # with 1e-9 for numerical stability. NOTE(review): currently unused —
        # the training loop below computes the two halves separately.
        tf_loss = torch.log(tf_scores + 1e-9) * (-1)
        ar_loss = torch.log(1 - ar_scores + 1e-9) * (-1)
        return tf_loss + ar_loss

    # Freeze discriminator_s word embeddings (as Wu et al. do). This MUST
    # happen before the DataParallel wrap below: a DataParallel wrapper does
    # not expose the wrapped module's attributes directly, so accessing
    # `discriminator_s.embed_src_tokens` after wrapping raises AttributeError
    # on multi-GPU runs. (Moved up from after the wrap — bug fix.)
    for p in discriminator_s.embed_src_tokens.parameters():
        p.requires_grad = False
    for p in discriminator_s.embed_trg_tokens.parameters():
        p.requires_grad = False

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator_h = torch.nn.DataParallel(discriminator_h).cuda()
            discriminator_s = torch.nn.DataParallel(discriminator_s).cuda()
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
            discriminator_h.cuda()
            discriminator_s.cuda()
    else:
        discriminator_h.cpu()
        discriminator_s.cpu()
        generator.cpu()

    # adversarial training checkpoints saving path
    if not os.path.exists('checkpoints/realmyzhenup10shrink07drop01new'):
        os.makedirs('checkpoints/realmyzhenup10shrink07drop01new')
    checkpoints_path = 'checkpoints/realmyzhenup10shrink07drop01new/'

    # define loss functions
    g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(),
                                   reduction='sum')
    d_criterion = torch.nn.BCELoss()
    pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(),
                          size_average=True,
                          reduce=True)

    # Define optimizers. getattr() replaces the previous eval() on the
    # optimizer-name string: same result for any valid torch.optim class
    # name, without executing arbitrary configuration text.
    g_optimizer = getattr(torch.optim, args.g_optimizer)(
        filter(lambda x: x.requires_grad, generator.parameters()),
        args.g_learning_rate)
    d_optimizer_h = getattr(torch.optim, args.d_optimizer)(
        filter(lambda x: x.requires_grad, discriminator_h.parameters()),
        args.d_learning_rate,
        momentum=args.momentum,
        nesterov=True)
    d_optimizer_s = getattr(torch.optim, args.d_optimizer)(
        filter(lambda x: x.requires_grad, discriminator_s.parameters()),
        args.d_learning_rate,
        momentum=args.momentum,
        nesterov=True)

    # start joint training
    best_dev_loss = math.inf  # NOTE(review): tracked nowhere below — the
    # best-model save seen in the sibling training script is absent here.
    num_update = 0

    # main training loop
    for epoch_i in range(1, args.epochs + 1):
        logging.info("At {0}-th epoch.".format(epoch_i))

        seed = args.seed + epoch_i
        torch.manual_seed(seed)

        max_positions_train = (args.fixed_max_len, args.fixed_max_len)

        # Initialize dataloader, starting at batch_offset
        trainloader = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_train,
            # seed=seed, epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(trainloader):
            # Re-enter training mode every batch (validation below switches
            # the models to eval mode mid-epoch).
            generator.train()
            discriminator_h.train()
            discriminator_s.train()
            update_learning_rate(num_update, 8e4, args.g_learning_rate,
                                 args.lr_shrink, g_optimizer)

            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=cuda)

            # Flip a coin: train the generator with MLE or policy gradient.
            # Batch 0 is always PG (i != 0 guard).
            train_MLE = 0
            train_PG = 0
            if random.random() > 0.5 and i != 0:
                train_MLE = 1
                # --- MLE training ---
                sys_out_batch_MLE, p_MLE, hidden_list_MLE = generator(
                    "MLE", epoch_i, sample)
                out_batch_MLE = sys_out_batch_MLE.contiguous().view(
                    -1, sys_out_batch_MLE.size(-1))  # (bsz*len) x vocab
                train_trg_batch_MLE = sample['target'].view(-1)
                loss_MLE = g_criterion(out_batch_MLE, train_trg_batch_MLE)
                sample_size_MLE = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                nsentences = sample['target'].size(0)
                # Per-token (or per-sentence) loss in bits.
                logging_loss_MLE = loss_MLE.item() / sample_size_MLE / math.log(2)
                g_logging_meters['bsz'].update(nsentences)
                g_logging_meters['MLE_train_loss'].update(
                    logging_loss_MLE, sample_size_MLE)
                logging.debug(
                    f"G MLE loss at batch {i}: {g_logging_meters['MLE_train_loss'].avg:.3f}, lr={g_optimizer.param_groups[0]['lr']}"
                )
                g_optimizer.zero_grad()
                loss_MLE.backward()
                # all-reduce grads and rescale by grad_denom
                for p in generator.parameters():
                    if p.requires_grad:
                        p.grad.data.div_(sample_size_MLE)
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()
            else:
                train_PG = 1
                # --- Policy-gradient training ---
                sys_out_batch_PG, p_PG, hidden_list_PG = generator(
                    'PG', epoch_i, sample)  # bsz x len x vocab
                out_batch_PG = sys_out_batch_PG.contiguous().view(
                    -1, sys_out_batch_PG.size(-1))
                _, prediction = out_batch_PG.topk(1)
                prediction = prediction.squeeze(1)
                prediction = torch.reshape(
                    prediction,
                    sample['net_input']['src_tokens'].shape)  # bsz x len
                # Reward from the (frozen-for-this-step) sentence
                # discriminator; no grad flows into D_s here.
                with torch.no_grad():
                    reward = discriminator_s(sample['net_input']['src_tokens'],
                                             prediction)  # bsz x 1
                train_trg_batch_PG = sample['target']
                pg_loss_PG = pg_criterion(sys_out_batch_PG,
                                          train_trg_batch_PG, reward,
                                          use_cuda)
                sample_size_PG = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                logging_loss_PG = pg_loss_PG / math.log(2)
                g_logging_meters['PG_train_loss'].update(
                    logging_loss_PG.item(), sample_size_PG)
                logging.debug(
                    f"G policy gradient loss at batch {i}: {pg_loss_PG.item():.3f}, lr={g_optimizer.param_groups[0]['lr']}"
                )
                g_optimizer.zero_grad()
                # Only step the generator while the PG loss is small; a large
                # PG loss is treated as an unreliable reward signal.
                if g_logging_meters['PG_train_loss'].val < 0.5:
                    pg_loss_PG.backward()
                    torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                                   args.clip_norm)
                    g_optimizer.step()

            # NOTE(review): counted per batch regardless of whether the PG
            # branch actually stepped — confirm against the original layout.
            num_update += 1

            # part II: train the discriminators every 10th update.
            if num_update % 10 == 0:
                # --- discriminator_h: scores generator hidden states ---
                if g_logging_meters["PG_train_loss"].val < 2:
                    # Exactly one of the two branches ran this batch.
                    assert (train_MLE == 1) != (train_PG == 1)
                    if train_MLE == 1:
                        # Teacher-forced hidden states are "real".
                        d_MLE = discriminator_h(hidden_list_MLE.detach())
                        M_loss = torch.log(d_MLE + 1e-9) * (-1)
                        h_d_loss = M_loss.sum()
                    elif train_PG == 1:
                        # Free-running (sampled) hidden states are "fake".
                        d_PG = discriminator_h(hidden_list_PG.detach())
                        P_loss = torch.log(1 - d_PG + 1e-9) * (-1)
                        h_d_loss = P_loss.sum()
                    logging.debug(f"D_h training loss {h_d_loss} at batch {i}")
                    d_optimizer_h.zero_grad()
                    h_d_loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        discriminator_h.parameters(), args.clip_norm)
                    d_optimizer_h.step()

                # --- discriminator_s: scores (source, translation) pairs ---
                # NOTE(review): placed inside the %10 gate to match the
                # collapsed layout — confirm intended nesting.
                src_sentence = sample['net_input']['src_tokens']
                # Generate a machine translation without tracking gradients;
                # only the discriminator learns from this pass.
                with torch.no_grad():
                    sys_out_batch, p, hidden_list = generator(
                        'PG', epoch_i, sample)
                out_batch = sys_out_batch.contiguous().view(
                    -1, sys_out_batch.size(-1))
                _, prediction = out_batch.topk(1)
                prediction = prediction.squeeze(1)
                fake_labels = torch.zeros(sample['target'].size(0)).float()
                fake_sentence = torch.reshape(prediction, src_sentence.shape)
                if use_cuda:
                    fake_labels = fake_labels.cuda()
                fake_disc_out = discriminator_s(src_sentence,
                                                fake_sentence)  # bsz x 1
                fake_d_loss = d_criterion(fake_disc_out.squeeze(1),
                                          fake_labels)
                acc = torch.sum(
                    torch.round(fake_disc_out).squeeze(1) ==
                    fake_labels).float() / len(fake_labels)
                s_d_loss = fake_d_loss
                # .item() so the meters hold plain floats, not live graph
                # tensors (taken before backward(), so training is unchanged).
                d_logging_meters['D_s_train_acc'].update(acc.item())
                d_logging_meters['D_s_train_loss'].update(s_d_loss.item())
                logging.debug(
                    f"D_s training loss {d_logging_meters['D_s_train_loss'].avg:.3f}, acc {d_logging_meters['D_s_train_acc'].avg:.3f} at batch {i}"
                )
                d_optimizer_s.zero_grad()
                s_d_loss.backward()
                d_optimizer_s.step()

            # Validate and checkpoint every 10k updates.
            if num_update % 10000 == 0:
                print(
                    'validation and save+++++++++++++++++++++++++++++++++++++++++++++++'
                )
                generator.eval()
                discriminator_h.eval()
                discriminator_s.eval()

                # Initialize dataloader
                max_positions_valid = (args.fixed_max_len, args.fixed_max_len)
                valloader = dataset.eval_dataloader(
                    'valid',
                    max_tokens=args.max_tokens,
                    max_sentences=args.joint_batch_size,
                    max_positions=max_positions_valid,
                    skip_invalid_size_inputs_valid_test=True,
                    descending=True,  # largest batch first to warm the caching allocator
                    shard_id=args.distributed_rank,
                    num_shards=args.distributed_world_size,
                )

                # reset meters (also clears the epoch's training averages)
                for key, val in g_logging_meters.items():
                    if val is not None:
                        val.reset()
                for key, val in d_logging_meters.items():
                    if val is not None:
                        val.reset()

                # Distinct loop variables so the outer loop's i/sample are
                # not shadowed (the originals were reused here).
                for val_i, val_sample in enumerate(valloader):
                    with torch.no_grad():
                        if use_cuda:
                            val_sample = utils.make_variable(val_sample,
                                                             cuda=cuda)
                        sys_out_batch_test, p_test, hidden_list_test = generator(
                            'test', epoch_i, val_sample)
                        out_batch_test = sys_out_batch_test.contiguous().view(
                            -1, sys_out_batch_test.size(-1))
                        dev_trg_batch = val_sample['target'].view(-1)
                        loss_test = g_criterion(out_batch_test, dev_trg_batch)
                        sample_size_test = val_sample['target'].size(
                            0) if args.sentence_avg else val_sample['ntokens']
                        loss_test = loss_test / sample_size_test / math.log(2)
                        g_logging_meters['valid_loss'].update(
                            loss_test.item(), sample_size_test)
                        logging.debug(
                            f"G dev loss at batch {val_i}: {g_logging_meters['valid_loss'].avg:.3f}"
                        )

                # Save via a context manager so the checkpoint file handle is
                # flushed and closed (the original leaked an open() handle).
                ckpt_name = (
                    checkpoints_path +
                    f"numupdate_{num_update/10000}w.sampling_{g_logging_meters['valid_loss'].avg:.3f}.pt"
                )
                with open(ckpt_name, 'wb') as ckpt_file:
                    torch.save(generator, ckpt_file, pickle_module=dill)