def main(args):
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(args.data, splits, args.src_lang,
                                    args.trg_lang, args.fixed_max_len)
    else:
        dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang,
                                             args.trg_lang, args.fixed_max_len)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that they are saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
    for split in splits:
        print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))

    g_logging_meters = OrderedDict()
    g_logging_meters['train_loss'] = AverageMeter()
    g_logging_meters['valid_loss'] = AverageMeter()
    g_logging_meters['train_acc'] = AverageMeter()
    g_logging_meters['valid_acc'] = AverageMeter()
    g_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    d_logging_meters = OrderedDict()
    d_logging_meters['train_loss'] = AverageMeter()
    d_logging_meters['valid_loss'] = AverageMeter()
    d_logging_meters['train_acc'] = AverageMeter()
    d_logging_meters['valid_acc'] = AverageMeter()
    d_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    # Set model parameters
    args.encoder_embed_dim = 1000
    args.encoder_layers = 2  # 4
    args.encoder_dropout_out = 0
    args.decoder_embed_dim = 1000
    args.decoder_layers = 2  # 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0
    args.bidirectional = False

    generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)
    print("Generator loaded successfully!")
    discriminator = Discriminator(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)
    print("Discriminator loaded successfully!")

    # Load pre-trained generator weights
    g_model_path = 'checkpoints/zhenwarm/generator.pt'
    assert os.path.exists(g_model_path)
    model_dict = generator.state_dict()
    model = torch.load(g_model_path)
    pretrained_dict = model.state_dict()
    # 1. filter out unnecessary keys
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    generator.load_state_dict(model_dict)
    print("pre-trained Generator loaded successfully!")

    # Load pre-trained discriminator weights
    d_model_path = 'checkpoints/zhenwarm/discri.pt'
    assert os.path.exists(d_model_path)
    d_model_dict = discriminator.state_dict()
    d_model = torch.load(d_model_path)
    d_pretrained_dict = d_model.state_dict()
    # 1. filter out unnecessary keys
    d_pretrained_dict = {k: v for k, v in d_pretrained_dict.items() if k in d_model_dict}
    # 2. overwrite entries in the existing state dict
    d_model_dict.update(d_pretrained_dict)
    # 3. load the new state dict
    discriminator.load_state_dict(d_model_dict)
    print("pre-trained Discriminator loaded successfully!")

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator = torch.nn.DataParallel(discriminator).cuda()
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
            discriminator.cuda()
    else:
        discriminator.cpu()
        generator.cpu()

    # adversarial training checkpoints saving path
    if not os.path.exists('checkpoints/myzhencli5'):
        os.makedirs('checkpoints/myzhencli5')
    checkpoints_path = 'checkpoints/myzhencli5/'

    # define loss functions
    g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(), reduction='sum')
    d_criterion = torch.nn.BCELoss()
    pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(), size_average=True, reduce=True)

    # fix discriminator word embeddings (as Wu et al. do)
    for p in discriminator.embed_src_tokens.parameters():
        p.requires_grad = False
    for p in discriminator.embed_trg_tokens.parameters():
        p.requires_grad = False

    # define optimizers
    g_optimizer = eval("torch.optim." + args.g_optimizer)(
        filter(lambda x: x.requires_grad, generator.parameters()),
        args.g_learning_rate)
    d_optimizer = eval("torch.optim." + args.d_optimizer)(
        filter(lambda x: x.requires_grad, discriminator.parameters()),
        args.d_learning_rate,
        momentum=args.momentum,
        nesterov=True)

    # start joint training
    best_dev_loss = math.inf
    num_update = 0

    # main training loop
    for epoch_i in range(1, args.epochs + 1):
        logging.info("At {0}-th epoch.".format(epoch_i))

        seed = args.seed + epoch_i
        torch.manual_seed(seed)

        max_positions_train = (args.fixed_max_len, args.fixed_max_len)

        # Initialize dataloader, starting at batch_offset
        trainloader = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_train,
            # seed=seed,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(trainloader):

            # set training mode
            generator.train()
            discriminator.train()
            update_learning_rate(num_update, 8e4, args.g_learning_rate,
                                 args.lr_shrink, g_optimizer)

            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=cuda)

            ## part I: use the policy gradient method to train the generator
            # use policy gradient training in roughly 50% of the iterations
            if random.random() >= 0.5:
                print("Policy Gradient Training")

                sys_out_batch = generator(sample)  # 64 x 50 x 6632
                out_batch = sys_out_batch.contiguous().view(-1, sys_out_batch.size(-1))  # (64 * 50) x 6632

                _, prediction = out_batch.topk(1)
                prediction = prediction.squeeze(1)  # 64 * 50 = 3200
                prediction = torch.reshape(prediction, sample['net_input']['src_tokens'].shape)  # 64 x 50

                with torch.no_grad():
                    reward = discriminator(sample['net_input']['src_tokens'], prediction)  # 64 x 1

                train_trg_batch = sample['target']  # 64 x 50

                pg_loss = pg_criterion(sys_out_batch, train_trg_batch, reward, use_cuda)
                sample_size = sample['target'].size(0) if args.sentence_avg else sample['ntokens']  # 64
                logging_loss = pg_loss / math.log(2)
                g_logging_meters['train_loss'].update(logging_loss.item(), sample_size)
                logging.debug(
                    f"G policy gradient loss at batch {i}: {pg_loss.item():.3f}, "
                    f"lr={g_optimizer.param_groups[0]['lr']}")

                g_optimizer.zero_grad()
                pg_loss.backward()
                torch.nn.utils.clip_grad_norm_(generator.parameters(), args.clip_norm)
                g_optimizer.step()

            else:
                # MLE training
                print("MLE Training")

                sys_out_batch = generator(sample)
                out_batch = sys_out_batch.contiguous().view(-1, sys_out_batch.size(-1))  # (64 * 50) x 6632
                train_trg_batch = sample['target'].view(-1)  # 64 * 50 = 3200

                loss = g_criterion(out_batch, train_trg_batch)

                sample_size = sample['target'].size(0) if args.sentence_avg else sample['ntokens']
                nsentences = sample['target'].size(0)
                logging_loss = loss.data / sample_size / math.log(2)
                g_logging_meters['bsz'].update(nsentences)
                g_logging_meters['train_loss'].update(logging_loss, sample_size)
                logging.debug(
                    f"G MLE loss at batch {i}: {g_logging_meters['train_loss'].avg:.3f}, "
                    f"lr={g_optimizer.param_groups[0]['lr']}")

                g_optimizer.zero_grad()
                loss.backward()
                # all-reduce grads and rescale by grad_denom
                for p in generator.parameters():
                    if p.requires_grad:
                        p.grad.data.div_(sample_size)
                torch.nn.utils.clip_grad_norm_(generator.parameters(), args.clip_norm)
                g_optimizer.step()

            num_update += 1

            # part II: train the discriminator
            if num_update % 5 == 0:
                bsz = sample['target'].size(0)  # batch_size = 64
                src_sentence = sample['net_input']['src_tokens']  # 64 x max-len, i.e. 64 x 50

                # now train with machine-translation output, i.e. the generator output
                true_sentence = sample['target'].view(-1)  # 64 * 50 = 3200
                true_labels = Variable(torch.ones(sample['target'].size(0)).float())  # 64-length vector

                with torch.no_grad():
                    sys_out_batch = generator(sample)  # 64 x 50 x 6632

                out_batch = sys_out_batch.contiguous().view(-1, sys_out_batch.size(-1))  # (64 * 50) x 6632

                _, prediction = out_batch.topk(1)
                prediction = prediction.squeeze(1)  # 64 * 50 = 3200
                fake_labels = Variable(torch.zeros(sample['target'].size(0)).float())  # 64-length vector

                fake_sentence = torch.reshape(prediction, src_sentence.shape)  # 64 x 50
                true_sentence = torch.reshape(true_sentence, src_sentence.shape)
                if use_cuda:
                    fake_labels = fake_labels.cuda()
                    true_labels = true_labels.cuda()

                # fake_disc_out = discriminator(src_sentence, fake_sentence)  # 64 x 1
                # true_disc_out = discriminator(src_sentence, true_sentence)
                #
                # fake_d_loss = d_criterion(fake_disc_out.squeeze(1), fake_labels)
                # true_d_loss = d_criterion(true_disc_out.squeeze(1), true_labels)
                #
                # fake_acc = torch.sum(torch.round(fake_disc_out).squeeze(1) == fake_labels).float() / len(fake_labels)
                # true_acc = torch.sum(torch.round(true_disc_out).squeeze(1) == true_labels).float() / len(true_labels)
                # acc = (fake_acc + true_acc) / 2
                #
                # d_loss = fake_d_loss + true_d_loss

                # train on either the fake pair or the real pair, chosen at random
                if random.random() > 0.5:
                    fake_disc_out = discriminator(src_sentence, fake_sentence)
                    fake_d_loss = d_criterion(fake_disc_out.squeeze(1), fake_labels)
                    fake_acc = torch.sum(torch.round(fake_disc_out).squeeze(1) == fake_labels).float() / len(fake_labels)
                    d_loss = fake_d_loss
                    acc = fake_acc
                else:
                    true_disc_out = discriminator(src_sentence, true_sentence)
                    true_d_loss = d_criterion(true_disc_out.squeeze(1), true_labels)
                    true_acc = torch.sum(torch.round(true_disc_out).squeeze(1) == true_labels).float() / len(true_labels)
                    d_loss = true_d_loss
                    acc = true_acc

                d_logging_meters['train_acc'].update(acc)
                d_logging_meters['train_loss'].update(d_loss)
                logging.debug(
                    f"D training loss {d_logging_meters['train_loss'].avg:.3f}, "
                    f"acc {d_logging_meters['train_acc'].avg:.3f} at batch {i}")

                d_optimizer.zero_grad()
                d_loss.backward()
                d_optimizer.step()

            if num_update % 10000 == 0:
                # validation
                # set validation mode
                generator.eval()
                discriminator.eval()

                # Initialize dataloader
                max_positions_valid = (args.fixed_max_len, args.fixed_max_len)
                valloader = dataset.eval_dataloader(
                    'valid',
                    max_tokens=args.max_tokens,
                    max_sentences=args.joint_batch_size,
                    max_positions=max_positions_valid,
                    skip_invalid_size_inputs_valid_test=True,
                    descending=True,  # largest batch first to warm the caching allocator
                    shard_id=args.distributed_rank,
                    num_shards=args.distributed_world_size,
                )

                # reset meters
                for key, val in g_logging_meters.items():
                    if val is not None:
                        val.reset()
                for key, val in d_logging_meters.items():
                    if val is not None:
                        val.reset()

                for i, sample in enumerate(valloader):

                    with torch.no_grad():
                        if use_cuda:
                            # wrap input tensors in cuda tensors
                            sample = utils.make_variable(sample, cuda=cuda)

                        # generator validation
                        sys_out_batch = generator(sample)
                        out_batch = sys_out_batch.contiguous().view(-1, sys_out_batch.size(-1))  # (64 * 50) x 6632
                        dev_trg_batch = sample['target'].view(-1)  # 64 * 50 = 3200

                        loss = g_criterion(out_batch, dev_trg_batch)
                        sample_size = sample['target'].size(0) if args.sentence_avg else sample['ntokens']
                        loss = loss / sample_size / math.log(2)
                        g_logging_meters['valid_loss'].update(loss, sample_size)
                        logging.debug(f"G dev loss at batch {i}: {g_logging_meters['valid_loss'].avg:.3f}")

                        # discriminator validation
                        bsz = sample['target'].size(0)
                        src_sentence = sample['net_input']['src_tokens']
                        # validate with half human translation and half machine translation
                        true_sentence = sample['target']
                        true_labels = Variable(torch.ones(sample['target'].size(0)).float())

                        with torch.no_grad():
                            sys_out_batch = generator(sample)

                        out_batch = sys_out_batch.contiguous().view(-1, sys_out_batch.size(-1))  # (64 * 50) x 6632

                        _, prediction = out_batch.topk(1)
                        prediction = prediction.squeeze(1)  # 64 * 50 = 3200
                        fake_labels = Variable(torch.zeros(sample['target'].size(0)).float())

                        fake_sentence = torch.reshape(prediction, src_sentence.shape)  # 64 x 50
                        true_sentence = torch.reshape(true_sentence, src_sentence.shape)
                        if use_cuda:
                            fake_labels = fake_labels.cuda()
                            true_labels = true_labels.cuda()

                        fake_disc_out = discriminator(src_sentence, fake_sentence)  # 64 x 1
                        true_disc_out = discriminator(src_sentence, true_sentence)

                        fake_d_loss = d_criterion(fake_disc_out.squeeze(1), fake_labels)
                        true_d_loss = d_criterion(true_disc_out.squeeze(1), true_labels)
                        d_loss = fake_d_loss + true_d_loss

                        fake_acc = torch.sum(torch.round(fake_disc_out).squeeze(1) == fake_labels).float() / len(fake_labels)
                        true_acc = torch.sum(torch.round(true_disc_out).squeeze(1) == true_labels).float() / len(true_labels)
                        acc = (fake_acc + true_acc) / 2
                        d_logging_meters['valid_acc'].update(acc)
                        d_logging_meters['valid_loss'].update(d_loss)
                        logging.debug(
                            f"D dev loss {d_logging_meters['valid_loss'].avg:.3f}, "
                            f"acc {d_logging_meters['valid_acc'].avg:.3f} at batch {i}")

                # torch.save(discriminator,
                #            open(checkpoints_path + f"numupdate_{num_update/10000}k.discri_{d_logging_meters['valid_loss'].avg:.3f}.pt", 'wb'),
                #            pickle_module=dill)
                # if d_logging_meters['valid_loss'].avg < best_dev_loss:
                #     best_dev_loss = d_logging_meters['valid_loss'].avg
                #     torch.save(discriminator, open(checkpoints_path + "best_dmodel.pt", 'wb'), pickle_module=dill)

                torch.save(
                    generator,
                    open(checkpoints_path +
                         f"numupdate_{num_update/10000}k.joint_{g_logging_meters['valid_loss'].avg:.3f}.pt",
                         'wb'),
                    pickle_module=dill)
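
# NOTE: `update_learning_rate` is called in the joint-training loop above but is
# not defined in this file. The helper below is a hypothetical sketch, assuming a
# simple one-step decay of the generator learning rate by `lr_shrink` once the
# update count passes `threshold`; the repo's actual helper may differ.
def update_learning_rate(num_updates, threshold, base_lr, lr_shrink, optimizer):
    # keep the base learning rate until `threshold` updates, then decay it once
    lr = base_lr if num_updates < threshold else base_lr * lr_shrink
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr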
def main(args):
    use_cuda = (len(args.gpuid) >= 1)
    if args.gpuid:
        cuda.set_device(args.gpuid[0])

    print(args.replace_unk)  # None

    # Load dataset
    if args.replace_unk is None:
        dataset = data.load_dataset(
            args.data,
            ['test'],
            args.src_lang,
            args.trg_lang,
        )
    else:
        dataset = data.load_raw_text_dataset(
            args.data,
            ['test'],
            args.src_lang,
            args.trg_lang,
        )
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that they are saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
    print('| {} {} {} examples'.format(args.data, 'test', len(dataset.splits['test'])))

    # Set model parameters
    args.encoder_embed_dim = 1000
    args.encoder_layers = 2
    args.encoder_dropout_out = 0
    args.decoder_embed_dim = 1000
    args.decoder_layers = 2
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0
    args.bidirectional = False

    # Load model
    g_model_path = args.model_dir  # e.g. 'checkpoints/generator/numupdate2.997465464368014.data.nll_270000.0.pt'
    assert os.path.exists(g_model_path)
    generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)
    model_dict = generator.state_dict()
    model = torch.load(g_model_path)
    pretrained_dict = model.state_dict()
    # 1. filter out unnecessary keys
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    generator.load_state_dict(model_dict)
    generator.eval()
    print("Generator loaded successfully!")

    if use_cuda:
        generator.cuda()
    else:
        generator.cpu()

    max_positions = generator.encoder.max_positions()

    testloader = dataset.eval_dataloader(
        'test',
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
    )

    translator = SequenceGenerator(generator,
                                   beam_size=args.beam,
                                   stop_early=(not args.no_early_stop),
                                   normalize_scores=(not args.unnormalized),
                                   len_penalty=args.lenpen,
                                   unk_penalty=args.unkpen)

    if use_cuda:
        translator.cuda()

    with open('predictions.txt', 'wb') as translation_writer, \
         open('real.txt', 'wb') as ground_truth_writer, \
         open('src.txt', 'wb') as src_writer:

        translations = translator.generate_batched_itr(
            testloader,
            maxlen_a=args.max_len_a,
            maxlen_b=args.max_len_b,
            cuda=use_cuda)

        for sample_id, src_tokens, target_tokens, hypos in translations:
            # Process input and ground truth
            target_tokens = target_tokens.int().cpu()
            src_str = dataset.src_dict.string(src_tokens, args.remove_bpe)
            target_str = dataset.dst_dict.string(target_tokens, args.remove_bpe, escape_unk=True)

            # Process top predictions
            for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
                hypo_tokens = hypo['tokens'].int().cpu()
                hypo_str = dataset.dst_dict.string(hypo_tokens, args.remove_bpe)

                hypo_str += '\n'
                target_str += '\n'
                src_str += '\n'

                translation_writer.write(hypo_str.encode('utf-8'))
                ground_truth_writer.write(target_str.encode('utf-8'))
                src_writer.write(src_str.encode('utf-8'))
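
# NOTE: the "filter keys / update / load_state_dict" pattern above is repeated for
# every checkpoint load in this repo. Below is a small hypothetical helper that
# captures the same three steps; the name `load_matching_weights` is ours and is
# not part of the codebase.
import torch


def load_matching_weights(model, checkpoint_path):
    """Copy into `model` only the parameters whose names also exist in it.

    `checkpoint_path` may hold either a whole pickled module
    (saved with torch.save(model, ...)) or a plain state dict.
    """
    model_dict = model.state_dict()
    loaded = torch.load(checkpoint_path, map_location='cpu')
    pretrained_dict = loaded.state_dict() if hasattr(loaded, 'state_dict') else loaded
    # 1. filter out unnecessary keys
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    model.load_state_dict(model_dict)
    return model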
def main(args):
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(args.data, splits, args.src_lang,
                                    args.trg_lang, args.fixed_max_len)
    else:
        dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang,
                                             args.trg_lang, args.fixed_max_len)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that they are saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
    for split in splits:
        print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))

    # check checkpoints saving path
    if not os.path.exists('checkpoints/generator'):
        os.makedirs('checkpoints/generator')
    checkpoints_path = 'checkpoints/generator/'

    logging_meters = OrderedDict()
    logging_meters['train_loss'] = AverageMeter()
    logging_meters['valid_loss'] = AverageMeter()
    logging_meters['bsz'] = AverageMeter()  # sentences per batch
    logging_meters['update_times'] = AverageMeter()

    # Set model parameters
    args.encoder_embed_dim = 1000
    args.encoder_layers = 2  # 4
    args.encoder_dropout_out = 0
    args.decoder_embed_dim = 1000
    args.decoder_layers = 2  # 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0
    args.bidirectional = False

    # Build model
    generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)

    # Optionally warm-start from a previous checkpoint:
    # g_model_path = 'checkpoints/generator/numupdate1.4180668458302803.data.nll_105000.000.pt'
    # assert os.path.exists(g_model_path)
    # model_dict = generator.state_dict()
    # model = torch.load(g_model_path)
    # pretrained_dict = model.state_dict()
    # # 1. filter out unnecessary keys
    # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # # 2. overwrite entries in the existing state dict
    # model_dict.update(pretrained_dict)
    # # 3. load the new state dict
    # generator.load_state_dict(model_dict)
    # print("pre-trained Generator loaded successfully!")

    if use_cuda:
        if len(args.gpuid) > 1:
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
    else:
        generator.cpu()

    print("Training generator...")

    g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(), reduction='sum')

    optimizer = eval("torch.optim." + args.optimizer)(generator.parameters(),
                                                      args.learning_rate)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                              patience=0,
                                                              factor=args.lr_shrink)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    epoch_i = 1
    best_dev_loss = math.inf
    lr = optimizer.param_groups[0]['lr']
    num_update = 0

    # main training loop
    while lr > args.min_g_lr and epoch_i <= max_epoch:
        logging.info("At {0}-th epoch.".format(epoch_i))

        seed = args.seed + epoch_i
        torch.manual_seed(seed)

        max_positions_train = (min(args.max_source_positions, generator.encoder.max_positions()),
                               min(args.max_target_positions, generator.decoder.max_positions()))

        # Initialize dataloader, starting at batch_offset
        itr = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=max_positions_train,
            seed=seed,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(itr):

            # set training mode
            generator.train()

            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=cuda)

            sys_out_batch = generator(sample)
            out_batch = sys_out_batch.contiguous().view(-1, sys_out_batch.size(-1))
            train_trg_batch = sample['target'].view(-1)

            loss = g_criterion(out_batch, train_trg_batch)

            sample_size = sample['target'].size(0) if args.sentence_avg else sample['ntokens']
            nsentences = sample['target'].size(0)
            logging_loss = loss.item() / sample_size / math.log(2)
            logging_meters['bsz'].update(nsentences)
            logging_meters['train_loss'].update(logging_loss, sample_size)
            logging.debug(
                "g loss at batch {0}: {1:.3f}, batch size: {2}, lr={3}".format(
                    i, logging_meters['train_loss'].avg,
                    round(logging_meters['bsz'].avg),
                    optimizer.param_groups[0]['lr']))

            optimizer.zero_grad()
            loss.backward()

            # all-reduce grads and rescale by grad_denom
            for p in generator.parameters():
                if p.requires_grad:
                    p.grad.data.div_(sample_size)

            torch.nn.utils.clip_grad_norm_(generator.parameters(), args.clip_norm)
            optimizer.step()

            num_update = num_update + 1

            if num_update % 5000 == 0:
                # validation -- this is a crude estimation because there might be some padding at the end
                max_positions_valid = (
                    generator.encoder.max_positions(),
                    generator.decoder.max_positions(),
                )

                # Initialize dataloader
                itr_valid = dataset.eval_dataloader(
                    'valid',
                    max_tokens=args.max_tokens,
                    max_sentences=args.max_sentences,
                    max_positions=max_positions_valid,
                    skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
                    descending=True,  # largest batch first to warm the caching allocator
                    shard_id=args.distributed_rank,
                    num_shards=args.distributed_world_size,
                )

                # set validation mode
                generator.eval()

                # reset meters
                for key, val in logging_meters.items():
                    if val is not None:
                        val.reset()

                with torch.no_grad():
                    for j, sample in enumerate(itr_valid):
                        if use_cuda:
                            # wrap input tensors in cuda tensors
                            sample = utils.make_variable(sample, cuda=cuda)

                        sys_out_batch = generator(sample)
                        out_batch = sys_out_batch.contiguous().view(-1, sys_out_batch.size(-1))
                        val_trg_batch = sample['target'].view(-1)

                        loss = g_criterion(out_batch, val_trg_batch)
                        sample_size = sample['target'].size(0) if args.sentence_avg else sample['ntokens']
                        loss = loss.item() / sample_size / math.log(2)
                        logging_meters['valid_loss'].update(loss, sample_size)
                        logging.debug("g dev loss at batch {0}: {1:.3f}".format(
                            j, logging_meters['valid_loss'].avg))

                # update learning rate
                lr_scheduler.step(logging_meters['valid_loss'].avg)
                lr = optimizer.param_groups[0]['lr']

                logging.info(
                    "Average g loss value per instance is {0} at the end of epoch {1}".format(
                        logging_meters['valid_loss'].avg, epoch_i))

                torch.save(
                    generator,
                    open(checkpoints_path +
                         "numupdate{1}.data.nll_{0:.1f}.pt".format(
                             num_update, logging_meters['valid_loss'].avg),
                         'wb'))

                # if logging_meters['valid_loss'].avg < best_dev_loss:
                #     best_dev_loss = logging_meters['valid_loss'].avg
                #     torch.save(generator.state_dict(),
                #                open(checkpoints_path + "best_gmodel.pt", 'wb'))

        epoch_i += 1
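
# NOTE: `AverageMeter` is imported from the repo's utilities and is not defined in
# this file. The class below is a minimal sketch of the running-average meter the
# logging code assumes (latest value, running sum, count, and mean); it is shown
# for reference only and may differ from the repo's implementation.
class AverageMeter(object):

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        # `val` is the latest measurement, `n` its weight (e.g. sample_size)
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count if self.count > 0 else 0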
def main(args):
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))
    print("======printing args========")
    print(args)
    print("=================================")

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        print("Loading bin dataset")
        dataset = data.load_dataset(args.data, splits, args.src_lang,
                                    args.trg_lang, args.fixed_max_len)
    else:
        print(f"Loading raw text dataset {args.data}")
        dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang,
                                             args.trg_lang, args.fixed_max_len)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that they are saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
    for split in splits:
        print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))

    g_logging_meters = OrderedDict()
    g_logging_meters['train_loss'] = AverageMeter()
    g_logging_meters['valid_loss'] = AverageMeter()
    g_logging_meters['train_acc'] = AverageMeter()
    g_logging_meters['valid_acc'] = AverageMeter()
    g_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    d_logging_meters = OrderedDict()
    d_logging_meters['train_loss'] = AverageMeter()
    d_logging_meters['valid_loss'] = AverageMeter()
    d_logging_meters['train_acc'] = AverageMeter()
    d_logging_meters['valid_acc'] = AverageMeter()
    d_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    # Set model parameters
    args.encoder_embed_dim = 1000
    args.encoder_layers = 4
    args.encoder_dropout_out = 0
    args.decoder_embed_dim = 1000
    args.decoder_layers = 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0
    args.bidirectional = False

    # try to load generator model
    g_model_path = 'checkpoints/generator/best_gmodel.pt'
    if not os.path.exists(g_model_path):
        print("Start training generator!")
        train_g(args, dataset)
    assert os.path.exists(g_model_path)
    generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)
    model_dict = generator.state_dict()
    pretrained_dict = torch.load(g_model_path)
    # 1. filter out unnecessary keys
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    generator.load_state_dict(model_dict)
    print("Generator has successfully loaded!")

    # try to load discriminator model
    d_model_path = 'checkpoints/discriminator/best_dmodel.pt'
    if not os.path.exists(d_model_path):
        print("Start training discriminator!")
        train_d(args, dataset)
    assert os.path.exists(d_model_path)
    discriminator = Discriminator(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)
    model_dict = discriminator.state_dict()
    pretrained_dict = torch.load(d_model_path)
    # 1. filter out unnecessary keys
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    discriminator.load_state_dict(model_dict)
    print("Discriminator has successfully loaded!")

    # return
    print("starting main training loop")

    torch.autograd.set_detect_anomaly(True)

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator = torch.nn.DataParallel(discriminator).cuda()
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
            discriminator.cuda()
    else:
        discriminator.cpu()
        generator.cpu()

    # adversarial training checkpoints saving path
    if not os.path.exists('checkpoints/joint'):
        os.makedirs('checkpoints/joint')
    checkpoints_path = 'checkpoints/joint/'

    # define loss functions
    g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(), reduction='sum')
    d_criterion = torch.nn.BCEWithLogitsLoss()
    pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(), size_average=True, reduce=True)

    # fix discriminator word embeddings (as Wu et al. do)
    for p in discriminator.embed_src_tokens.parameters():
        p.requires_grad = False
    for p in discriminator.embed_trg_tokens.parameters():
        p.requires_grad = False

    # define optimizers
    g_optimizer = eval("torch.optim." + args.g_optimizer)(
        filter(lambda x: x.requires_grad, generator.parameters()),
        args.g_learning_rate)
    d_optimizer = eval("torch.optim." + args.d_optimizer)(
        filter(lambda x: x.requires_grad, discriminator.parameters()),
        args.d_learning_rate,
        momentum=args.momentum,
        nesterov=True)

    # start joint training
    best_dev_loss = math.inf
    num_update = 0

    # main training loop
    for epoch_i in range(1, args.epochs + 1):
        logging.info("At {0}-th epoch.".format(epoch_i))

        # seed = args.seed + epoch_i
        # torch.manual_seed(seed)

        max_positions_train = (args.fixed_max_len, args.fixed_max_len)

        # Initialize dataloader, starting at batch_offset
        itr = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_train,
            # seed=seed,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        # set training mode
        generator.train()
        discriminator.train()
        update_learning_rate(num_update, 8e4, args.g_learning_rate,
                             args.lr_shrink, g_optimizer)

        for i, sample in enumerate(itr):

            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=cuda)

            ## part I: use the policy gradient method to train the generator
            # use policy gradient training when rand >= 50%
            rand = random.random()
            if rand >= 0.5:
                # policy gradient training
                generator.decoder.is_testing = True
                sys_out_batch, prediction, _ = generator(sample)
                generator.decoder.is_testing = False

                with torch.no_grad():
                    n_i = sample['net_input']['src_tokens']
                    # print(f"net input:\n{n_i}, pred: \n{prediction}")
                    reward = discriminator(sample['net_input']['src_tokens'], prediction)  # dataset.dst_dict.pad())

                train_trg_batch = sample['target']
                # print(f"sys_out_batch: {sys_out_batch.shape}:\n{sys_out_batch}")
                pg_loss = pg_criterion(sys_out_batch, train_trg_batch, reward, use_cuda)
                # logging.debug("G policy gradient loss at batch {0}: {1:.3f}, lr={2}".format(
                #     i, pg_loss.item(), g_optimizer.param_groups[0]['lr']))

                g_optimizer.zero_grad()
                pg_loss.backward()
                torch.nn.utils.clip_grad_norm_(generator.parameters(), args.clip_norm)
                g_optimizer.step()

                # oracle valid
                _, _, loss = generator(sample)
                sample_size = sample['target'].size(0) if args.sentence_avg else sample['ntokens']
                logging_loss = loss.data / sample_size / math.log(2)
                g_logging_meters['train_loss'].update(logging_loss, sample_size)
                logging.debug("G MLE loss at batch {0}: {1:.3f}, lr={2}".format(
                    i, g_logging_meters['train_loss'].avg,
                    g_optimizer.param_groups[0]['lr']))

            else:
                # MLE training
                # print(f"printing sample: \n{sample}")
                _, _, loss = generator(sample)

                sample_size = sample['target'].size(0) if args.sentence_avg else sample['ntokens']
                nsentences = sample['target'].size(0)
                logging_loss = loss.data / sample_size / math.log(2)
                g_logging_meters['bsz'].update(nsentences)
                g_logging_meters['train_loss'].update(logging_loss, sample_size)
                logging.debug("G MLE loss at batch {0}: {1:.3f}, lr={2}".format(
                    i, g_logging_meters['train_loss'].avg,
                    g_optimizer.param_groups[0]['lr']))

                g_optimizer.zero_grad()
                loss.backward()
                # all-reduce grads and rescale by grad_denom
                for p in generator.parameters():
                    if p.requires_grad:
                        p.grad.data.div_(sample_size)
                torch.nn.utils.clip_grad_norm_(generator.parameters(), args.clip_norm)
                g_optimizer.step()

            num_update += 1

            # part II: train the discriminator
            bsz = sample['target'].size(0)
            src_sentence = sample['net_input']['src_tokens']

            # train with half human translation and half machine translation
            true_sentence = sample['target']
            true_labels = Variable(torch.ones(sample['target'].size(0)).float())

            with torch.no_grad():
                generator.decoder.is_testing = True
                _, prediction, _ = generator(sample)
                generator.decoder.is_testing = False
            fake_sentence = prediction
            fake_labels = Variable(torch.zeros(sample['target'].size(0)).float())

            trg_sentence = torch.cat([true_sentence, fake_sentence], dim=0)
            labels = torch.cat([true_labels, fake_labels], dim=0)

            indices = np.random.permutation(2 * bsz)
            trg_sentence = trg_sentence[indices][:bsz]
            labels = labels[indices][:bsz]

            if use_cuda:
                labels = labels.cuda()

            disc_out = discriminator(src_sentence, trg_sentence)  # , dataset.dst_dict.pad())
            # print(f"disc out: {disc_out.shape}, labels: {labels.shape}")
            # print(f"labels: {labels}")

            # BCEWithLogitsLoss expects float targets
            d_loss = d_criterion(disc_out, labels)
            acc = torch.sum(torch.sigmoid(disc_out).round() == labels).float() / len(labels)
            d_logging_meters['train_acc'].update(acc)
            d_logging_meters['train_loss'].update(d_loss)
            # logging.debug("D training loss {0:.3f}, acc {1:.3f} at batch {2}".format(
            #     d_logging_meters['train_loss'].avg, d_logging_meters['train_acc'].avg, i))

            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

        # validation
        # set validation mode
        generator.eval()
        discriminator.eval()

        # Initialize dataloader
        max_positions_valid = (args.fixed_max_len, args.fixed_max_len)
        itr = dataset.eval_dataloader(
            'valid',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_valid,
            skip_invalid_size_inputs_valid_test=True,
            descending=True,  # largest batch first to warm the caching allocator
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(itr):

            with torch.no_grad():
                if use_cuda:
                    sample['id'] = sample['id'].cuda()
                    sample['net_input']['src_tokens'] = sample['net_input']['src_tokens'].cuda()
                    sample['net_input']['src_lengths'] = sample['net_input']['src_lengths'].cuda()
                    sample['net_input']['prev_output_tokens'] = sample['net_input']['prev_output_tokens'].cuda()
                    sample['target'] = sample['target'].cuda()

                # generator validation
                _, _, loss = generator(sample)
                sample_size = sample['target'].size(0) if args.sentence_avg else sample['ntokens']
                loss = loss / sample_size / math.log(2)
                g_logging_meters['valid_loss'].update(loss, sample_size)
                logging.debug("G dev loss at batch {0}: {1:.3f}".format(
                    i, g_logging_meters['valid_loss'].avg))

                # discriminator validation
                bsz = sample['target'].size(0)
                src_sentence = sample['net_input']['src_tokens']
                # validate with half human translation and half machine translation
                true_sentence = sample['target']
                true_labels = Variable(torch.ones(sample['target'].size(0)).float())

                with torch.no_grad():
                    generator.decoder.is_testing = True
                    _, prediction, _ = generator(sample)
                    generator.decoder.is_testing = False
                fake_sentence = prediction
                fake_labels = Variable(torch.zeros(sample['target'].size(0)).float())

                trg_sentence = torch.cat([true_sentence, fake_sentence], dim=0)
                labels = torch.cat([true_labels, fake_labels], dim=0)

                indices = np.random.permutation(2 * bsz)
                trg_sentence = trg_sentence[indices][:bsz]
                labels = labels[indices][:bsz]

                if use_cuda:
                    labels = labels.cuda()

                disc_out = discriminator(src_sentence, trg_sentence, dataset.dst_dict.pad())
                d_loss = d_criterion(disc_out, labels)
                acc = torch.sum(torch.sigmoid(disc_out).round() == labels).float() / len(labels)
                d_logging_meters['valid_acc'].update(acc)
                d_logging_meters['valid_loss'].update(d_loss)
                # logging.debug("D dev loss {0:.3f}, acc {1:.3f} at batch {2}".format(
                #     d_logging_meters['valid_loss'].avg, d_logging_meters['valid_acc'].avg, i))

        torch.save(generator,
                   open(checkpoints_path + "joint_{0:.3f}.epoch_{1}.pt".format(
                       g_logging_meters['valid_loss'].avg, epoch_i), 'wb'),
                   pickle_module=dill)

        if g_logging_meters['valid_loss'].avg < best_dev_loss:
            best_dev_loss = g_logging_meters['valid_loss'].avg
            torch.save(generator,
                       open(checkpoints_path + "best_gmodel.pt", 'wb'),
                       pickle_module=dill)
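
# NOTE: `PGLoss` is imported from elsewhere in the repo. The class below is only a
# minimal sketch of the sequence-level policy-gradient objective it is assumed to
# compute -- a reward-weighted negative log-likelihood with padding masked out --
# and is not the repo's actual implementation; the class name `PGLossSketch` is ours.
import torch
import torch.nn as nn


class PGLossSketch(nn.Module):

    def __init__(self, ignore_index, size_average=True):
        super().__init__()
        self.ignore_index = ignore_index
        self.size_average = size_average

    def forward(self, log_probs, targets, reward, use_cuda=False):
        # log_probs: (batch, seq_len, vocab) generator log-probabilities
        # targets:   (batch, seq_len) token ids treated as the sampled actions
        # reward:    (batch, 1) per-sentence discriminator score
        batch = targets.size(0)
        token_logp = log_probs.gather(2, targets.unsqueeze(2)).squeeze(2)  # (batch, seq_len)
        mask = targets.ne(self.ignore_index).float()
        # scale every token log-likelihood in a sentence by that sentence's reward
        loss = -(token_logp * mask * reward.view(batch, 1)).sum()
        if self.size_average:
            loss = loss / mask.sum().clamp(min=1)
        return loss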
def train_g(args, dataset):
    logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        level=logging.DEBUG)

    use_cuda = (cuda.device_count() >= 1)

    # check checkpoints saving path
    if not os.path.exists('checkpoints/generator'):
        os.makedirs('checkpoints/generator')
    checkpoints_path = 'checkpoints/generator/'

    logging_meters = OrderedDict()
    logging_meters['train_loss'] = AverageMeter()
    logging_meters['valid_loss'] = AverageMeter()
    logging_meters['bsz'] = AverageMeter()  # sentences per batch
    logging_meters['update_times'] = AverageMeter()

    # Build model
    generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)

    if use_cuda:
        if len(args.gpuid) > 1:
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
    else:
        generator.cpu()

    optimizer = eval("torch.optim." + args.optimizer)(generator.parameters(),
                                                      args.learning_rate)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                              patience=0,
                                                              factor=args.lr_shrink)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    epoch_i = 1
    best_dev_loss = math.inf
    lr = optimizer.param_groups[0]['lr']

    # main training loop
    while lr > args.min_g_lr and epoch_i <= max_epoch:
        logging.info("At {0}-th epoch.".format(epoch_i))

        seed = args.seed + epoch_i
        torch.manual_seed(seed)

        max_positions_train = (min(args.max_source_positions, generator.encoder.max_positions()),
                               min(args.max_target_positions, generator.decoder.max_positions()))

        # Initialize dataloader, starting at batch_offset
        itr = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=max_positions_train,
            seed=seed,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # set training mode
        generator.train()

        # reset meters
        for key, val in logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(itr):
            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=cuda)

            loss = generator(sample)

            sample_size = sample['target'].size(0) if args.sentence_avg else sample['ntokens']
            nsentences = sample['target'].size(0)
            logging_loss = loss.item() / sample_size / math.log(2)
            logging_meters['bsz'].update(nsentences)
            logging_meters['train_loss'].update(logging_loss, sample_size)
            logging.debug(
                "g loss at batch {0}: {1:.3f}, batch size: {2}, lr={3}".format(
                    i, logging_meters['train_loss'].avg,
                    round(logging_meters['bsz'].avg),
                    optimizer.param_groups[0]['lr']))

            optimizer.zero_grad()
            loss.backward()

            # all-reduce grads and rescale by grad_denom
            for p in generator.parameters():
                if p.requires_grad:
                    p.grad.data.div_(sample_size)

            torch.nn.utils.clip_grad_norm_(generator.parameters(), args.clip_norm)
            optimizer.step()

        # validation -- this is a crude estimation because there might be some padding at the end
        max_positions_valid = (
            generator.encoder.max_positions(),
            generator.decoder.max_positions(),
        )

        # Initialize dataloader
        itr = dataset.eval_dataloader(
            'valid',
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=max_positions_valid,
            skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
            descending=True,  # largest batch first to warm the caching allocator
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # set validation mode
        generator.eval()

        # reset meters
        for key, val in logging_meters.items():
            if val is not None:
                val.reset()

        with torch.no_grad():
            for i, sample in enumerate(itr):
                if use_cuda:
                    # wrap input tensors in cuda tensors
                    sample = utils.make_variable(sample, cuda=cuda)

                loss = generator(sample)
                sample_size = sample['target'].size(0) if args.sentence_avg else sample['ntokens']
                loss = loss.item() / sample_size / math.log(2)
                logging_meters['valid_loss'].update(loss, sample_size)
                logging.debug("g dev loss at batch {0}: {1:.3f}".format(
                    i, logging_meters['valid_loss'].avg))

        # update learning rate
        lr_scheduler.step(logging_meters['valid_loss'].avg)
        lr = optimizer.param_groups[0]['lr']

        logging.info(
            "Average g loss value per instance is {0} at the end of epoch {1}".format(
                logging_meters['valid_loss'].avg, epoch_i))

        torch.save(
            generator.state_dict(),
            open(checkpoints_path + "data.nll_{0:.3f}.epoch_{1}.pt".format(
                logging_meters['valid_loss'].avg, epoch_i), 'wb'))

        if logging_meters['valid_loss'].avg < best_dev_loss:
            best_dev_loss = logging_meters['valid_loss'].avg
            torch.save(generator.state_dict(),
                       open(checkpoints_path + "best_gmodel.pt", 'wb'))

        epoch_i += 1
def train_d(args, dataset):
    logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        level=logging.DEBUG)

    use_cuda = (torch.cuda.device_count() >= 1)

    # check checkpoints saving path
    if not os.path.exists('checkpoints/discriminator'):
        os.makedirs('checkpoints/discriminator')
    checkpoints_path = 'checkpoints/discriminator/'

    logging_meters = OrderedDict()
    logging_meters['train_loss'] = AverageMeter()
    logging_meters['train_acc'] = AverageMeter()
    logging_meters['valid_loss'] = AverageMeter()
    logging_meters['valid_acc'] = AverageMeter()
    logging_meters['update_times'] = AverageMeter()

    # Build model
    discriminator = Discriminator(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)

    # Load generator
    assert os.path.exists('checkpoints/generator/best_gmodel.pt')
    generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)
    model_dict = generator.state_dict()
    pretrained_dict = torch.load('checkpoints/generator/best_gmodel.pt')
    # 1. filter out unnecessary keys
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    generator.load_state_dict(model_dict)

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator = torch.nn.DataParallel(discriminator).cuda()
            # generator = torch.nn.DataParallel(generator).cuda()
            generator.cuda()
        else:
            generator.cuda()
            discriminator.cuda()
    else:
        discriminator.cpu()
        generator.cpu()

    criterion = torch.nn.CrossEntropyLoss()

    # optimizer = eval("torch.optim." + args.d_optimizer)(
    #     filter(lambda x: x.requires_grad, discriminator.parameters()),
    #     args.d_learning_rate, momentum=args.momentum, nesterov=True)
    optimizer = torch.optim.RMSprop(
        filter(lambda x: x.requires_grad, discriminator.parameters()), 1e-4)

    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                              patience=0,
                                                              factor=args.lr_shrink)

    # Train until the accuracy reaches the target value
    max_epoch = args.max_epoch or math.inf
    epoch_i = 1
    trg_acc = 0.82
    best_dev_loss = math.inf
    lr = optimizer.param_groups[0]['lr']

    # prepare training and validation data (the validation set is prepared only once)
    train = prepare_training_data(args, dataset, 'train', generator, epoch_i, use_cuda)
    valid = prepare_training_data(args, dataset, 'valid', generator, epoch_i, use_cuda)
    data_train = DatasetProcessing(data=train, maxlen=args.fixed_max_len)
    data_valid = DatasetProcessing(data=valid, maxlen=args.fixed_max_len)

    # main training loop
    while lr > args.min_d_lr and epoch_i <= max_epoch:
        logging.info("At {0}-th epoch.".format(epoch_i))

        seed = args.seed + epoch_i
        torch.manual_seed(seed)

        if args.sample_without_replacement > 0 and epoch_i > 1:
            train = prepare_training_data(args, dataset, 'train', generator, epoch_i, use_cuda)
            data_train = DatasetProcessing(data=train, maxlen=args.fixed_max_len)

        # discriminator training dataloaders
        train_loader = train_dataloader(data_train,
                                        batch_size=args.joint_batch_size,
                                        seed=seed,
                                        epoch=epoch_i,
                                        sort_by_source_size=False)
        valid_loader = eval_dataloader(data_valid,
                                       num_workers=4,
                                       batch_size=args.joint_batch_size)

        # set training mode
        discriminator.train()

        # reset meters
        for key, val in logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(train_loader):
            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=use_cuda)

            disc_out = discriminator(sample['src_tokens'], sample['trg_tokens'])
            loss = criterion(disc_out, sample['labels'])
            _, prediction = F.softmax(disc_out, dim=1).topk(1)
            acc = torch.sum(prediction == sample['labels'].unsqueeze(1)).float() / len(sample['labels'])

            logging_meters['train_acc'].update(acc.item())
            logging_meters['train_loss'].update(loss.item())
            logging.debug("D training loss {0:.3f}, acc {1:.3f}, avgAcc {2:.3f}, lr={3} at batch {4}".format(
                logging_meters['train_loss'].avg, acc,
                logging_meters['train_acc'].avg,
                optimizer.param_groups[0]['lr'], i))

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(discriminator.parameters(), args.clip_norm)
            optimizer.step()

            # del src_tokens, trg_tokens, loss, disc_out, labels, prediction, acc
            del disc_out, loss, prediction, acc

        # set validation mode
        discriminator.eval()

        for i, sample in enumerate(valid_loader):
            with torch.no_grad():
                if use_cuda:
                    # wrap input tensors in cuda tensors
                    sample = utils.make_variable(sample, cuda=use_cuda)

                disc_out = discriminator(sample['src_tokens'], sample['trg_tokens'])
                loss = criterion(disc_out, sample['labels'])
                _, prediction = F.softmax(disc_out, dim=1).topk(1)
                acc = torch.sum(prediction == sample['labels'].unsqueeze(1)).float() / len(sample['labels'])

                logging_meters['valid_acc'].update(acc.item())
                logging_meters['valid_loss'].update(loss.item())
                logging.debug("D eval loss {0:.3f}, acc {1:.3f}, avgAcc {2:.3f}, lr={3} at batch {4}".format(
                    logging_meters['valid_loss'].avg, acc,
                    logging_meters['valid_acc'].avg,
                    optimizer.param_groups[0]['lr'], i))

                del disc_out, loss, prediction, acc

        lr_scheduler.step(logging_meters['valid_loss'].avg)

        if logging_meters['valid_acc'].avg >= 0.70:
            torch.save(discriminator.state_dict(),
                       checkpoints_path + "ce_{0:.3f}_acc_{1:.3f}.epoch_{2}.pt".format(
                           logging_meters['valid_loss'].avg,
                           logging_meters['valid_acc'].avg, epoch_i))

            if logging_meters['valid_loss'].avg < best_dev_loss:
                best_dev_loss = logging_meters['valid_loss'].avg
                torch.save(discriminator.state_dict(), checkpoints_path + "best_dmodel.pt")

        # stop pre-training once the discriminator reaches the target accuracy (82%)
        if logging_meters['valid_acc'].avg >= trg_acc:
            return

        epoch_i += 1
def create_generator(self, args):
    self.generator = LSTMModel(args,
                               self.dataset.src_dict,
                               self.dataset.dst_dict,
                               use_cuda=self.use_cuda)
    print("Generator loaded successfully!")