def main(args): """ Calc loss and perplexity on training and validation set """ logging.info('Commencing Validation!') torch.manual_seed(42) np.random.seed(42) utils.init_logging(args) # Load dictionaries [for each language] src_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.source_lang))) logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format( args.source_lang, len(src_dict))) tgt_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.target_lang))) logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format( args.target_lang, len(tgt_dict))) # Load datasets def load_data(split): return Seq2SeqDataset( src_file=os.path.join(args.data, '{:s}.{:s}'.format(split, args.source_lang)), tgt_file=os.path.join(args.data, '{:s}.{:s}'.format(split, args.target_lang)), src_dict=src_dict, tgt_dict=tgt_dict) train_dataset = load_data( split='train') if not args.train_on_tiny else load_data( split='tiny_train') valid_dataset = load_data(split='valid') # Build model and optimization criterion model = models.build_model(args, src_dict, tgt_dict) logging.info('Built a model with {:d} parameters'.format( sum(p.numel() for p in model.parameters()))) criterion = nn.CrossEntropyLoss(ignore_index=src_dict.pad_idx, reduction='sum') if torch.cuda.is_available() and args.cuda: model = model.cuda() # Instantiate optimizer and learning rate scheduler optimizer = torch.optim.Adam(model.parameters(), args.lr) # Load last checkpoint if one exists state_dict = utils.load_checkpoint(args, model, optimizer) # lr_scheduler train_loader = \ torch.utils.data.DataLoader(train_dataset, num_workers = 1, collate_fn = train_dataset.collater, batch_sampler = BatchSampler(train_dataset, args.max_tokens, args.batch_size, 1, 0, shuffle = True, seed = 42)) # Calculate validation loss train_perplexity = validate(args, model, criterion, train_dataset, 0) valid_perplexity = validate(args, model, criterion, valid_dataset, 0)
def main(args): """ Main training function. Trains the translation model over the course of several epochs, including dynamic learning rate adjustment and gradient clipping. """ logging.info('Commencing training!') torch.manual_seed(42) utils.init_logging(args) # Load dictionaries src_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.source_lang))) logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format( args.source_lang, len(src_dict))) tgt_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.target_lang))) logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format( args.target_lang, len(tgt_dict))) # Load datasets def load_data(split): return Seq2SeqDataset( src_file=os.path.join(args.data, '{:s}.{:s}'.format(split, args.source_lang)), tgt_file=os.path.join(args.data, '{:s}.{:s}'.format(split, args.target_lang)), src_dict=src_dict, tgt_dict=tgt_dict) valid_dataset = load_data(split='valid') # Build model and optimization criterion model = models.build_model(args, src_dict, tgt_dict) logging.info('Built a model with {:d} parameters'.format( sum(p.numel() for p in model.parameters()))) criterion = nn.CrossEntropyLoss(ignore_index=src_dict.pad_idx, reduction='sum') if args.cuda: model = model.cuda() criterion = criterion.cuda() # Instantiate optimizer and learning rate scheduler optimizer = torch.optim.Adam(model.parameters(), args.lr) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=1) # Load last checkpoint if one exists state_dict = utils.load_checkpoint(args, model, optimizer, scheduler) # lr_scheduler last_epoch = state_dict['last_epoch'] if state_dict is not None else -1 # Track validation performance for early stopping bad_epochs = 0 best_validate = float('inf') for epoch in range(last_epoch + 1, args.max_epoch): ## BPE Dropout # Set the seed to be equal to the epoch # (this way we guarantee same seeds over multiple training runs, but not for each training epoch) seed = epoch bpe_dropout_if_needed(seed, args.bpe_dropout) # Load the BPE (dropout-ed) training data train_dataset = load_data( split='train') if not args.train_on_tiny else load_data( split='tiny_train') train_loader = \ torch.utils.data.DataLoader(train_dataset, num_workers=1, collate_fn=train_dataset.collater, batch_sampler=BatchSampler(train_dataset, args.max_tokens, args.batch_size, 1, 0, shuffle=True, seed=42)) model.train() stats = OrderedDict() stats['loss'] = 0 stats['lr'] = 0 stats['num_tokens'] = 0 stats['batch_size'] = 0 stats['grad_norm'] = 0 stats['clip'] = 0 # Display progress progress_bar = tqdm(train_loader, desc='| Epoch {:03d}'.format(epoch), leave=False, disable=False) # Iterate over the training set for i, sample in enumerate(progress_bar): if args.cuda: sample = utils.move_to_cuda(sample) if len(sample) == 0: continue model.train() output, _ = model(sample['src_tokens'], sample['src_lengths'], sample['tgt_inputs']) loss = \ criterion(output.view(-1, output.size(-1)), sample['tgt_tokens'].view(-1)) / len(sample['src_lengths']) loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm) optimizer.step() optimizer.zero_grad() # Update statistics for progress bar total_loss, num_tokens, batch_size = loss.item( ), sample['num_tokens'], len(sample['src_tokens']) stats['loss'] += total_loss * len( sample['src_lengths']) / sample['num_tokens'] stats['lr'] += optimizer.param_groups[0]['lr'] stats['num_tokens'] += num_tokens / len(sample['src_tokens']) 
stats['batch_size'] += batch_size stats['grad_norm'] += grad_norm stats['clip'] += 1 if grad_norm > args.clip_norm else 0 progress_bar.set_postfix( { key: '{:.4g}'.format(value / (i + 1)) for key, value in stats.items() }, refresh=True) logging.info('Epoch {:03d}: {}'.format( epoch, ' | '.join(key + ' {:.4g}'.format(value / len(progress_bar)) for key, value in stats.items()))) # Calculate validation loss valid_perplexity, valid_loss = validate(args, model, criterion, valid_dataset, epoch) model.train() # Scheduler step if args.adaptive_lr: scheduler.step(valid_loss) # Save checkpoints if epoch % args.save_interval == 0: utils.save_checkpoint(args, model, optimizer, scheduler, epoch, valid_perplexity) # lr_scheduler # Check whether to terminate training if valid_perplexity < best_validate: best_validate = valid_perplexity bad_epochs = 0 else: bad_epochs += 1 if bad_epochs >= args.patience: logging.info( 'No validation set improvements observed for {:d} epochs. Early stop!' .format(args.patience)) break
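# NOTE: `bpe_dropout_if_needed` (called at the top of the epoch loop above) is not
# defined in this file. A rough sketch of the intended behaviour, assuming BPE-dropout
# in the style of Provilkov et al. (2020): before each epoch, the raw training text is
# re-segmented and individual BPE merges are skipped with probability `dropout`. The
# helper `resegment_corpus` mentioned below is hypothetical.
def bpe_dropout_if_needed_sketch(seed, dropout):
    import random  # local import to keep the sketch self-contained
    if dropout <= 0:
        return  # plain BPE; keep the existing train.{src,tgt} files untouched
    # Epoch-dependent seed: a different segmentation every epoch, but reproducible
    # across training runs (the seed is simply the epoch number).
    random.seed(seed)
    torch.manual_seed(seed)
    # Hypothetical step: rewrite 'train.<lang>' with a stochastic BPE segmentation
    # resegment_corpus('train', dropout=dropout, seed=seed)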
def main(args): """ Main translation function' """ # Load arguments from checkpoint torch.manual_seed(args.seed) state_dict = torch.load( args.checkpoint_path, map_location=lambda s, l: default_restore_location(s, 'cpu')) args_loaded = argparse.Namespace(**{ **vars(args), **vars(state_dict['args']) }) args_loaded.data = args.data args = args_loaded utils.init_logging(args) # Load dictionaries src_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.source_lang))) logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format( args.source_lang, len(src_dict))) tgt_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.target_lang))) logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format( args.target_lang, len(tgt_dict))) # Load dataset test_dataset = Seq2SeqDataset( src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)), tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)), src_dict=src_dict, tgt_dict=tgt_dict) test_loader = torch.utils.data.DataLoader(test_dataset, num_workers=1, collate_fn=test_dataset.collater, batch_sampler=BatchSampler( test_dataset, 9999999, args.batch_size, 1, 0, shuffle=False, seed=args.seed)) # Build model and criterion model = models.build_model(args, src_dict, tgt_dict) if args.cuda: model = model.cuda() model.eval() model.load_state_dict(state_dict['model']) logging.info('Loaded a model from checkpoint {:s}'.format( args.checkpoint_path)) progress_bar = tqdm(test_loader, desc='| Generation', leave=False) # Iterate over the test set all_hyps = {} for i, sample in enumerate(progress_bar): with torch.no_grad(): # Compute the encoder output encoder_out = model.encoder(sample['src_tokens'], sample['src_lengths']) go_slice = \ torch.ones(sample['src_tokens'].shape[0], 1).fill_(tgt_dict.eos_idx).type_as(sample['src_tokens']) if args.cuda: go_slice = utils.move_to_cuda(go_slice) prev_words = go_slice next_words = None for _ in range(args.max_len): with torch.no_grad(): # Compute the decoder output by repeatedly feeding it the decoded sentence prefix decoder_out, _ = model.decoder(prev_words, encoder_out) # Suppress <UNK>s _, next_candidates = torch.topk(decoder_out, 2, dim=-1) best_candidates = next_candidates[:, :, 0] backoff_candidates = next_candidates[:, :, 1] next_words = torch.where(best_candidates == tgt_dict.unk_idx, backoff_candidates, best_candidates) prev_words = torch.cat([go_slice, next_words], dim=1) # Segment into sentences decoded_batch = next_words.cpu().numpy() output_sentences = [ decoded_batch[row, :] for row in range(decoded_batch.shape[0]) ] assert (len(output_sentences) == len(sample['id'].data)) # Remove padding temp = list() for sent in output_sentences: first_eos = np.where(sent == tgt_dict.eos_idx)[0] if len(first_eos) > 0: temp.append(sent[:first_eos[0]]) else: temp.append([]) output_sentences = temp # Convert arrays of indices into strings of words output_sentences = [tgt_dict.string(sent) for sent in output_sentences] # Save translations assert (len(output_sentences) == len(sample['id'].data)) for ii, sent in enumerate(output_sentences): all_hyps[int(sample['id'].data[ii])] = sent # Write to file if args.output is not None: with open(args.output, 'w') as out_file: for sent_id in range(len(all_hyps.keys())): out_file.write(all_hyps[sent_id] + '\n')
def main(args): # Load arguments from checkpoint torch.manual_seed(args.seed) state_dict = torch.load(args.checkpoint_path, map_location=lambda s, l: default_restore_location(s, 'cpu')) args = argparse.Namespace(**{**vars(args), **vars(state_dict['args'])}) utils.init_logging(args) # Load dictionaries src_dict = Dictionary.load(os.path.join(args.data, 'dict.{}'.format(args.source_lang))) logging.info('Loaded a source dictionary ({}) with {} words'.format(args.source_lang, len(src_dict))) tgt_dict = Dictionary.load(os.path.join(args.data, 'dict.{}'.format(args.target_lang))) logging.info('Loaded a target dictionary ({}) with {} words'.format(args.target_lang, len(tgt_dict))) # Load dataset test_dataset = Seq2SeqDataset( src_file=os.path.join(args.data, 'test.{}'.format(args.source_lang)), tgt_file=os.path.join(args.data, 'test.{}'.format(args.target_lang)), src_dict=src_dict, tgt_dict=tgt_dict) test_loader = torch.utils.data.DataLoader( test_dataset, num_workers=args.num_workers, collate_fn=test_dataset.collater, batch_sampler=BatchSampler( test_dataset, args.max_tokens, args.batch_size, args.distributed_world_size, args.distributed_rank, shuffle=False, seed=args.seed)) # Build model and criterion model = models.build_model(args, src_dict, tgt_dict).cuda() model.load_state_dict(state_dict['model']) logging.info('Loaded a model from checkpoint {}'.format(args.checkpoint_path)) translator = SequenceGenerator( model, tgt_dict, beam_size=args.beam_size, maxlen=args.max_len, stop_early=eval(args.stop_early), normalize_scores=eval(args.normalize_scores), len_penalty=args.len_penalty, unk_penalty=args.unk_penalty, ) progress_bar = tqdm(test_loader, desc='| Generation', leave=False) for i, sample in enumerate(progress_bar): sample = utils.move_to_cuda(sample) with torch.no_grad(): hypos = translator.generate(sample['src_tokens'], sample['src_lengths']) for i, (sample_id, hypos) in enumerate(zip(sample['id'].data, hypos)): src_tokens = utils.strip_pad(sample['src_tokens'].data[i, :], tgt_dict.pad_idx) has_target = sample['tgt_tokens'] is not None target_tokens = utils.strip_pad(sample['tgt_tokens'].data[i, :], tgt_dict.pad_idx).int().cpu() if has_target else None src_str = src_dict.string(src_tokens, args.remove_bpe) target_str = tgt_dict.string(target_tokens, args.remove_bpe) if has_target else '' if not args.quiet: print('S-{}\t{}'.format(sample_id, src_str)) if has_target: print('T-{}\t{}'.format(sample_id, colored(target_str, 'green'))) # Process top predictions for i, hypo in enumerate(hypos[:min(len(hypos), args.num_hypo)]): hypo_tokens, hypo_str, alignment = utils.post_process_prediction( hypo_tokens=hypo['tokens'].int().cpu(), src_str=src_str, alignment=hypo['alignment'].int().cpu(), tgt_dict=tgt_dict, remove_bpe=args.remove_bpe, ) if not args.quiet: print('H-{}\t{}'.format(sample_id, colored(hypo_str, 'blue'))) print('P-{}\t{}'.format(sample_id, ' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist())))) print('A-{}\t{}'.format(sample_id, ' '.join(map(lambda x: str(x.item()), alignment)))) # Score only the top hypothesis if has_target and i == 0: # Convert back to tokens for evaluation with unk replacement and/or without BPE target_tokens = tgt_dict.binarize(target_str, word_tokenize, add_if_not_exist=True)
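# NOTE: `utils.strip_pad` is used above but not shown in this file. It is assumed to
# simply drop padding positions from a 1-D token tensor, roughly as sketched below;
# this is an assumption, not necessarily the project's exact implementation.
def strip_pad_sketch(tokens, pad_idx):
    # Keep only the entries that are not the padding index
    return tokens[tokens.ne(pad_idx)]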
def main(args): if not torch.cuda.is_available(): raise NotImplementedError('Training on CPU is not supported.') torch.manual_seed(args.seed) torch.cuda.set_device(args.device_id) utils.init_logging(args) if args.distributed_world_size > 1: torch.distributed.init_process_group( backend=args.distributed_backend, init_method=args.distributed_init_method, world_size=args.distributed_world_size, rank=args.distributed_rank) # Load dictionaries src_dict = Dictionary.load( os.path.join(args.data, 'dict.{}'.format(args.source_lang))) logging.info('Loaded a source dictionary ({}) with {} words'.format( args.source_lang, len(src_dict))) tgt_dict = Dictionary.load( os.path.join(args.data, 'dict.{}'.format(args.target_lang))) logging.info('Loaded a target dictionary ({}) with {} words'.format( args.target_lang, len(tgt_dict))) # Load datasets def load_data(split): return Seq2SeqDataset( src_file=os.path.join(args.data, '{}.{}'.format(split, args.source_lang)), tgt_file=os.path.join(args.data, '{}.{}'.format(split, args.target_lang)), src_dict=src_dict, tgt_dict=tgt_dict) train_dataset = load_data(split='train') valid_dataset = load_data(split='valid') # Build model and criterion model = models.build_model(args, src_dict, tgt_dict).cuda() logging.info('Built a model with {} parameters'.format( sum(p.numel() for p in model.parameters()))) criterion = nn.CrossEntropyLoss(ignore_index=src_dict.pad_idx, reduction='sum').cuda() # Build an optimizer and a learning rate schedule optimizer = torch.optim.SGD(model.parameters(), args.lr, args.momentum, weight_decay=args.weight_decay, nesterov=True) lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, patience=0, min_lr=args.min_lr, factor=args.lr_shrink) # Load last checkpoint if one exists state_dict = utils.load_checkpoint(args, model, optimizer, lr_scheduler) last_epoch = state_dict['last_epoch'] if state_dict is not None else -1 for epoch in range(last_epoch + 1, args.max_epoch): train_loader = torch.utils.data.DataLoader( train_dataset, num_workers=args.num_workers, collate_fn=train_dataset.collater, batch_sampler=BatchSampler(train_dataset, args.max_tokens, args.batch_size, args.distributed_world_size, args.distributed_rank, shuffle=True, seed=args.seed)) model.train() stats = { 'loss': 0., 'lr': 0., 'num_tokens': 0., 'batch_size': 0., 'grad_norm': 0., 'clip': 0. 
} progress_bar = tqdm(train_loader, desc='| Epoch {:03d}'.format(epoch), leave=False, disable=(args.distributed_rank != 0)) for i, sample in enumerate(progress_bar): sample = utils.move_to_cuda(sample) if len(sample) == 0: continue # Forward and backward pass output, _ = model(sample['src_tokens'], sample['src_lengths'], sample['tgt_inputs']) loss = criterion(output.view(-1, output.size(-1)), sample['tgt_tokens'].view(-1)) optimizer.zero_grad() loss.backward() # Reduce gradients across all GPUs if args.distributed_world_size > 1: utils.reduce_grads(model.parameters()) total_loss, num_tokens, batch_size = list( map( sum, zip(*utils.all_gather_list([ loss.item(), sample['num_tokens'], len(sample['src_tokens']) ])))) else: total_loss, num_tokens, batch_size = loss.item( ), sample['num_tokens'], len(sample['src_tokens']) # Normalize gradients by number of tokens and perform clipping for param in model.parameters(): param.grad.data.div_(num_tokens) grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm) optimizer.step() # Update statistics for progress bar stats['loss'] += total_loss / num_tokens / math.log(2) stats['lr'] += optimizer.param_groups[0]['lr'] stats['num_tokens'] += num_tokens / len(sample['src_tokens']) stats['batch_size'] += batch_size stats['grad_norm'] += grad_norm stats['clip'] += 1 if grad_norm > args.clip_norm else 0 progress_bar.set_postfix( { key: '{:.4g}'.format(value / (i + 1)) for key, value in stats.items() }, refresh=True) logging.info('Epoch {:03d}: {}'.format( epoch, ' | '.join(key + ' {:.4g}'.format(value / len(progress_bar)) for key, value in stats.items()))) # Adjust learning rate based on validation loss valid_loss = validate(args, model, criterion, valid_dataset, epoch) lr_scheduler.step(valid_loss) # Save checkpoints if epoch % args.save_interval == 0: utils.save_checkpoint(args, model, optimizer, lr_scheduler, epoch, valid_loss) if optimizer.param_groups[0]['lr'] <= args.min_lr: logging.info('Done training!') break
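# NOTE: `utils.reduce_grads` (used in the distributed branch above) is a project
# utility not shown here. Under the usual data-parallel recipe it sums gradients
# across workers, which is consistent with the code above dividing the gradients by
# the *global* token count afterwards. A minimal sketch with torch.distributed
# (an assumption about its behaviour, not the project's actual code):
def reduce_grads_sketch(params):
    for p in params:
        if p.grad is not None:
            torch.distributed.all_reduce(p.grad.data,
                                         op=torch.distributed.ReduceOp.SUM)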
def main(args): """ Main function. Visualizes attention weight arrays as nifty heat-maps. """ mpl.rc('font', family='VL Gothic') torch.manual_seed(42) state_dict = torch.load( args.checkpoint_path, map_location=lambda s, l: default_restore_location(s, 'cpu')) args = argparse.Namespace(**{**vars(args), **vars(state_dict['args'])}) utils.init_logging(args) # Load dictionaries src_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.source_lang))) print('Loaded a source dictionary ({:s}) with {:d} words'.format( args.source_lang, len(src_dict))) tgt_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.target_lang))) print('Loaded a target dictionary ({:s}) with {:d} words'.format( args.target_lang, len(tgt_dict))) # Load dataset test_dataset = Seq2SeqDataset( src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)), tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)), src_dict=src_dict, tgt_dict=tgt_dict) vis_loader = torch.utils.data.DataLoader(test_dataset, num_workers=1, collate_fn=test_dataset.collater, batch_sampler=BatchSampler( test_dataset, None, 1, 1, 0, shuffle=False, seed=42)) # Build model and optimization criterion model = models.build_model(args, src_dict, tgt_dict) if args.cuda: model = model.cuda() model.load_state_dict(state_dict['model']) print('Loaded a model from checkpoint {:s}'.format(args.checkpoint_path)) # Store attention weight arrays attn_records = list() # Iterate over the visualization set for i, sample in enumerate(vis_loader): if args.cuda: sample = utils.move_to_cuda(sample) if len(sample) == 0: continue # Perform forward pass output, attn_weights = model(sample['src_tokens'], sample['src_lengths'], sample['tgt_inputs']) attn_records.append((sample, attn_weights)) # Only visualize the first 10 sentence pairs if i >= 10: break # Generate heat-maps and store them at the designated location if not os.path.exists(args.vis_dir): os.makedirs(args.vis_dir) for record_id, record in enumerate(attn_records): # Unpack sample, attn_map = record src_ids = utils.strip_pad(sample['src_tokens'].data, tgt_dict.pad_idx) tgt_ids = utils.strip_pad(sample['tgt_inputs'].data, tgt_dict.pad_idx) # Convert indices into word tokens src_str = src_dict.string(src_ids).split(' ') + ['<EOS>'] tgt_str = tgt_dict.string(tgt_ids).split(' ') + ['<EOS>'] # Generate heat-maps attn_map = attn_map.squeeze(dim=0).transpose(1, 0).cpu().detach().numpy() attn_df = pd.DataFrame(attn_map, index=src_str, columns=tgt_str) sns.heatmap(attn_df, cmap='Blues', linewidths=0.25, vmin=0.0, vmax=1.0, xticklabels=True, yticklabels=True, fmt='.3f') plt.yticks(rotation=0) plot_path = os.path.join(args.vis_dir, 'sentence_{:d}.png'.format(record_id)) plt.savefig(plot_path, dpi='figure', pad_inches=1, bbox_inches='tight') plt.clf() print( 'Done! Visualized attention maps have been saved to the \'{:s}\' directory!' .format(args.vis_dir))
    # (Fragment of the enclosing binarization routine: `dictionary`, `unk_counter`,
    # `nsent`, `ntok`, `input_file`, and `output_file` come from the enclosing scope.)
    def unk_consumer(word, idx):
        if idx == dictionary.unk_idx and word != dictionary.unk_word:
            unk_counter.update([word])

    tokens_list = []
    with open(input_file, 'r') as inf:
        for line in inf:
            tokens = dictionary.binarize(line.strip(), word_tokenize, append_eos,
                                         consumer=unk_consumer)
            nsent, ntok = nsent + 1, ntok + len(tokens)
            tokens_list.append(tokens.numpy())

    with open(output_file, 'wb') as outf:
        pickle.dump(tokens_list, outf, protocol=pickle.HIGHEST_PROTOCOL)

    logging.info(
        'Built a binary dataset for {}: {} sentences, {} tokens, {:.3f}% replaced by unknown token {}'
        .format(input_file, nsent, ntok,
                100.0 * sum(unk_counter.values()) / ntok, dictionary.unk_word))


if __name__ == '__main__':
    args = get_args()
    utils.init_logging(args)
    logging.info('COMMAND: %s' % ' '.join(sys.argv))
    logging.info('Arguments: {}'.format(vars(args)))
    main(args)
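# The binarized split written above is just a pickled list of numpy index arrays, so
# it can be inspected directly; a small usage sketch (the file name is illustrative):
# with open('train.de.bin', 'rb') as f:
#     sentences = pickle.load(f)
# print(len(sentences), sentences[0][:10])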
def main(args): """ Main translation function' """ # Load arguments from checkpoint torch.manual_seed(args.seed) # sets the random seed from pytorch random number generators state_dict = torch.load(args.checkpoint_path, map_location=lambda s, l: default_restore_location(s, 'cpu')) args_loaded = argparse.Namespace(**{**vars(args), **vars(state_dict['args'])}) args_loaded.data = args.data args = args_loaded utils.init_logging(args) # Load dictionaries src_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.source_lang))) logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(args.source_lang, len(src_dict))) tgt_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.target_lang))) logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(args.target_lang, len(tgt_dict))) # Load dataset test_dataset = Seq2SeqDataset( src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)), tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)), src_dict=src_dict, tgt_dict=tgt_dict) test_loader = torch.utils.data.DataLoader(test_dataset, num_workers=1, collate_fn=test_dataset.collater, batch_sampler=BatchSampler(test_dataset, 9999999, args.batch_size, 1, 0, shuffle=False, seed=args.seed)) # Build model and criterion model = models.build_model(args, src_dict, tgt_dict) if args.cuda: model = model.cuda() model.eval() model.load_state_dict(state_dict['model']) logging.info('Loaded a model from checkpoint {:s}'.format(args.checkpoint_path)) progress_bar = tqdm(test_loader, desc='| Generation', leave=False) # Iterate over the test set all_hyps = {} for i, sample in enumerate(progress_bar): # Create a beam search object or every input sentence in batch batch_size = sample['src_tokens'].shape[0] # returns number of rows from sample['src_tokens'] searches = [BeamSearch(args.beam_size, args.max_len - 1, tgt_dict.unk_idx) for i in range(batch_size)] # beam search with beamsize, max seq length and unkindex --> do this B times with torch.no_grad(): # disables gradient calculation # Compute the encoder output encoder_out = model.encoder(sample['src_tokens'], sample['src_lengths']) # __QUESTION 1: What is "go_slice" used for and what do its dimensions represent? # encoder_out = self.encoder(src_tokens, src_lengths) decoder_out = self.decoder(tgt_inputs, encoder_out) go_slice = \ torch.ones(sample['src_tokens'].shape[0], 1).fill_(tgt_dict.eos_idx).type_as(sample['src_tokens']) # vector of ones of length sample['src_tokens'] rows and 1 col filled with eos_idx casted to type sample[ # 'src_tokens'] if args.cuda: go_slice = utils.move_to_cuda(go_slice) # Compute the decoder output at the first time step decoder_out, _ = model.decoder(go_slice, encoder_out) # decoder out = decoder(tgt_inputs, encoder_out) # __QUESTION 2: Why do we keep one top candidate more than the beam size? 
log_probs, next_candidates = torch.topk(torch.log(torch.softmax(decoder_out, dim=2)), args.beam_size + 1, dim=-1) # returns largest k elements (here beam_size+1) of the input torch.log(torch.softmax(decoder_out, # dim=2) in dimension -1 + 1 is taken because the input is given in logarithmic notation # Create number of beam_size beam search nodes for every input sentence for i in range(batch_size): for j in range(args.beam_size): best_candidate = next_candidates[i, :, j] backoff_candidate = next_candidates[i, :, j + 1] best_log_p = log_probs[i, :, j] backoff_log_p = log_probs[i, :, j + 1] next_word = torch.where(best_candidate == tgt_dict.unk_idx, backoff_candidate, best_candidate) log_p = torch.where(best_candidate == tgt_dict.unk_idx, backoff_log_p, best_log_p) log_p = log_p[-1] # Store the encoder_out information for the current input sentence and beam emb = encoder_out['src_embeddings'][:, i, :] lstm_out = encoder_out['src_out'][0][:, i, :] final_hidden = encoder_out['src_out'][1][:, i, :] final_cell = encoder_out['src_out'][2][:, i, :] try: mask = encoder_out['src_mask'][i, :] except TypeError: mask = None node = BeamSearchNode(searches[i], emb, lstm_out, final_hidden, final_cell, mask, torch.cat((go_slice[i], next_word)), log_p, 1) # add normalization here according to paper lp = normalize(node.length) score = node.eval()/lp # Add diverse score = diverse(score, j) # __QUESTION 3: Why do we add the node with a negative score? searches[i].add(-score, node) # Start generating further tokens until max sentence length reached for _ in range(args.max_len - 1): # Get the current nodes to expand nodes = [n[1] for s in searches for n in s.get_current_beams()] if nodes == []: break # All beams ended in EOS # Reconstruct prev_words, encoder_out from current beam search nodes prev_words = torch.stack([node.sequence for node in nodes]) encoder_out["src_embeddings"] = torch.stack([node.emb for node in nodes], dim=1) lstm_out = torch.stack([node.lstm_out for node in nodes], dim=1) final_hidden = torch.stack([node.final_hidden for node in nodes], dim=1) final_cell = torch.stack([node.final_cell for node in nodes], dim=1) encoder_out["src_out"] = (lstm_out, final_hidden, final_cell) try: encoder_out["src_mask"] = torch.stack([node.mask for node in nodes], dim=0) except TypeError: encoder_out["src_mask"] = None with torch.no_grad(): # Compute the decoder output by feeding it the decoded sentence prefix decoder_out, _ = model.decoder(prev_words, encoder_out) # see __QUESTION 2 log_probs, next_candidates = torch.topk(torch.log(torch.softmax(decoder_out, dim=2)), args.beam_size + 1, dim=-1) # Create number of beam_size next nodes for every current node for i in range(log_probs.shape[0]): for j in range(args.beam_size): best_candidate = next_candidates[i, :, j] backoff_candidate = next_candidates[i, :, j + 1] best_log_p = log_probs[i, :, j] backoff_log_p = log_probs[i, :, j + 1] next_word = torch.where(best_candidate == tgt_dict.unk_idx, backoff_candidate, best_candidate) log_p = torch.where(best_candidate == tgt_dict.unk_idx, backoff_log_p, best_log_p) log_p = log_p[-1] next_word = torch.cat((prev_words[i][1:], next_word[-1:])) # Get parent node and beam search object for corresponding sentence node = nodes[i] search = node.search # __QUESTION 4: How are "add" and "add_final" different? What would happen if we did not make this distinction? 
# Store the node as final if EOS is generated if next_word[-1] == tgt_dict.eos_idx: node = BeamSearchNode(search, node.emb, node.lstm_out, node.final_hidden, node.final_cell, node.mask, torch.cat((prev_words[i][0].view([1]), next_word)), node.logp, node.length) # Add length normalization lp = normalize(node.length) score = node.eval()/lp # add diverse score = diverse(score, j) search.add_final(-score, node) # Add the node to current nodes for next iteration else: node = BeamSearchNode(search, node.emb, node.lstm_out, node.final_hidden, node.final_cell, node.mask, torch.cat((prev_words[i][0].view([1]), next_word)), node.logp + log_p, node.length + 1) # Add length normalization lp = normalize(node.length) score = node.eval()/lp # add diverse score = diverse(score, j) search.add(-score, node) # __QUESTION 5: What happens internally when we prune our beams? # How do we know we always maintain the best sequences? for search in searches: search.prune() # Segment into 1 best sentences #best_sents = torch.stack([search.get_best()[1].sequence[1:].cpu() for search in searches]) # segment 3 best oneliner best_sents = torch.stack([n[1].sequence[1:] for s in searches for n in s.get_best()]) # segment into n best sentences #for s in searches: # for n in s.get_best(): # best_sents = torch.stack([n[1].sequence[1:].cpu()]) print('n best sents', best_sents) # concatenates a sequence of tensors, gets the one best here, so we should use the n-best (3 best) here decoded_batch = best_sents.numpy() output_sentences = [decoded_batch[row, :] for row in range(decoded_batch.shape[0])] # __QUESTION 6: What is the purpose of this for loop? temp = list() for sent in output_sentences: first_eos = np.where(sent == tgt_dict.eos_idx)[0] # predicts first eos token if len(first_eos) > 0: # checks if the first eos token is not the beginning (position 0) temp.append(sent[:first_eos[0]]) else: temp.append(sent) output_sentences = temp # Convert arrays of indices into strings of words output_sentences = [tgt_dict.string(sent) for sent in output_sentences] # here: adapt so that it takes the 3-best (aka n-best), % used for no overflow for ii, sent in enumerate(output_sentences): # all_hyps[int(sample['id'].data[ii])] = sent # variant for 3-best all_hyps[(int(sample['id'].data[int(ii / 3)]), int(ii % 3))] = sent # Write to file (write 3 best per sentence together) if args.output is not None: with open(args.output, 'w') as out_file: for sent_id in range(len(all_hyps.keys())): # variant for 1-best # out_file.write(all_hyps[sent_id] + '\n') # variant for 3-best out_file.write(all_hyps[(int(sent_id / 3), int(sent_id % 3))] + '\n')
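# NOTE: `normalize` and `diverse` are called in the beam search above but not defined
# in this excerpt. Plausible sketches, assuming the GNMT length penalty
# lp(Y) = (5 + |Y|)^alpha / 6^alpha and the rank-based diversity penalty used in the
# sibling implementation further below (score - gamma * (rank + 1)); treat both as
# assumptions, not the original helpers.
def normalize_sketch(length, alpha=0.6):
    # GNMT-style length penalty
    return (5 + length) ** alpha / (5 + 1) ** alpha

def diverse_sketch(score, rank, gamma=0.0):
    # Penalize lower-ranked siblings to promote diversity across beams
    return score - gamma * (rank + 1)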
def main(args): """ Main translation function' """ # Load arguments from checkpoint torch.manual_seed(args.seed) state_dict = torch.load( args.checkpoint_path, map_location=lambda s, l: default_restore_location(s, 'cpu')) args_loaded = argparse.Namespace(**{ **vars(args), **vars(state_dict['args']) }) args_loaded.data = args.data args = args_loaded utils.init_logging(args) # Load dictionaries src_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.source_lang))) logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format( args.source_lang, len(src_dict))) tgt_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.target_lang))) logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format( args.target_lang, len(tgt_dict))) # Load dataset test_dataset = Seq2SeqDataset( src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)), tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)), src_dict=src_dict, tgt_dict=tgt_dict) test_loader = torch.utils.data.DataLoader(test_dataset, num_workers=1, collate_fn=test_dataset.collater, batch_sampler=BatchSampler( test_dataset, 9999999, args.batch_size, 1, 0, shuffle=False, seed=args.seed)) # Build model and criterion model = models.build_model(args, src_dict, tgt_dict) if args.cuda: model = model.cuda() model.eval() model.load_state_dict(state_dict['model']) logging.info('Loaded a model from checkpoint {:s}'.format( args.checkpoint_path)) progress_bar = tqdm(test_loader, desc='| Generation', leave=False) # Iterate over the test set all_hyps = {} for i, sample in enumerate(progress_bar): # Create a beam search object or every input sentence in batch batch_size = sample['src_tokens'].shape[0] searches = [ BeamSearch(args.beam_size, args.max_len - 1, tgt_dict.unk_idx) for i in range(batch_size) ] with torch.no_grad(): # Compute the encoder output encoder_out = model.encoder(sample['src_tokens'], sample['src_lengths']) go_slice = \ torch.ones(sample['src_tokens'].shape[0], 1).fill_(tgt_dict.eos_idx).type_as(sample['src_tokens']) # Compute the decoder output at the first time step decoder_out, _ = model.decoder(go_slice, encoder_out) # __QUESTION 1: What happens here and what do 'log_probs' and 'next_candidates' contain? decoder_out = length_normalization( decoder_out) #applies length normalization log_probs, next_candidates = torch.topk(torch.log( torch.softmax(decoder_out, dim=2)), args.beam_size + 1, dim=-1) # Create number of beam_size beam search nodes for every input sentence for i in range(batch_size): for j in range(args.beam_size): # __QUESTION 2: Why do we need backoff candidates? best_candidate = next_candidates[i, :, j] backoff_candidate = next_candidates[i, :, j + 1] best_log_p = log_probs[i, :, j] backoff_log_p = log_probs[i, :, j + 1] next_word = torch.where(best_candidate == tgt_dict.unk_idx, backoff_candidate, best_candidate) log_p = torch.where(best_candidate == tgt_dict.unk_idx, backoff_log_p, best_log_p) log_p = log_p[-1] # Store the encoder_out information for the current input sentence and beam emb = encoder_out['src_embeddings'][:, i, :] lstm_out = encoder_out['src_out'][0][:, i, :] final_hidden = encoder_out['src_out'][1][:, i, :] final_cell = encoder_out['src_out'][2][:, i, :] try: mask = encoder_out['src_mask'][i, :] except TypeError: mask = None # __QUESTION 3: What happens internally when we add a new beam search node? 
node = BeamSearchNode(searches[i], emb, lstm_out, final_hidden, final_cell, mask, torch.cat( (go_slice[i], next_word)), log_p, 1) searches[i].add(-node.eval(), node) # Start generating further tokens until max sentence length reached for _ in range(args.max_len - 1): # Get the current nodes to expand nodes = [n[1] for s in searches for n in s.get_current_beams()] if nodes == []: break # All beams ended in EOS # Reconstruct prev_words, encoder_out from current beam search nodes prev_words = torch.stack([node.sequence for node in nodes]) encoder_out["src_embeddings"] = torch.stack( [node.emb for node in nodes], dim=1) lstm_out = torch.stack([node.lstm_out for node in nodes], dim=1) final_hidden = torch.stack([node.final_hidden for node in nodes], dim=1) final_cell = torch.stack([node.final_cell for node in nodes], dim=1) encoder_out["src_out"] = (lstm_out, final_hidden, final_cell) try: encoder_out["src_mask"] = torch.stack( [node.mask for node in nodes], dim=0) except TypeError: encoder_out["src_mask"] = None with torch.no_grad(): # Compute the decoder output by feeding it the decoded sentence prefix decoder_out, _ = model.decoder(prev_words, encoder_out) # see __QUESTION 1 decoder_out = length_normalization( decoder_out) #length normalization function log_probs, next_candidates = torch.topk(torch.log( torch.softmax(length_normalization(decoder_out), dim=2)), args.beam_size + 1, dim=-1) # Create number of beam_size next nodes for every current node for i in range(log_probs.shape[0]): for j in range(args.beam_size): # see __QUESTION 2 best_candidate = next_candidates[i, :, j] backoff_candidate = next_candidates[i, :, j + 1] best_log_p = log_probs[i, :, j] backoff_log_p = log_probs[i, :, j + 1] next_word = torch.where(best_candidate == tgt_dict.unk_idx, backoff_candidate, best_candidate) log_p = torch.where(best_candidate == tgt_dict.unk_idx, backoff_log_p, best_log_p) log_p = log_p[-1] next_word = torch.cat((prev_words[i][1:], next_word[-1:])) # Get parent node and beam search object for corresponding sentence node = nodes[i] search = node.search # __QUESTION 4: Why do we treat nodes that generated the end-of-sentence token differently? # Store the node as final if EOS is generated if next_word[-1] == tgt_dict.eos_idx: node = BeamSearchNode( search, node.emb, node.lstm_out, node.final_hidden, node.final_cell, node.mask, torch.cat((prev_words[i][0].view([1]), next_word)), node.logp, node.length) search.add_final(-node.eval(), node) # Add the node to current nodes for next iteration else: node = BeamSearchNode( search, node.emb, node.lstm_out, node.final_hidden, node.final_cell, node.mask, torch.cat((prev_words[i][0].view([1]), next_word)), node.logp + log_p, node.length + 1) search.add(-node.eval(), node) # __QUESTION 5: What happens internally when we prune our beams? # How do we know we always maintain the best sequences? for search in searches: search.prune() # Segment into sentences best_sents = torch.stack( [search.get_best()[1].sequence[1:] for search in searches]) decoded_batch = best_sents.numpy() output_sentences = [ decoded_batch[row, :] for row in range(decoded_batch.shape[0]) ] # __QUESTION 6: What is the purpose of this for loop? 
temp = list() for sent in output_sentences: first_eos = np.where(sent == tgt_dict.eos_idx)[0] if len(first_eos) > 0: temp.append(sent[:first_eos[0]]) else: temp.append(sent) output_sentences = temp # Convert arrays of indices into strings of words output_sentences = [tgt_dict.string(sent) for sent in output_sentences] for ii, sent in enumerate(output_sentences): all_hyps[int(sample['id'].data[ii])] = sent # Write to file if args.output is not None: with open(args.output, 'w') as out_file: for sent_id in range(len(all_hyps.keys())): out_file.write(all_hyps[sent_id] + '\n')
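# NOTE: `length_normalization` is not defined in this excerpt, and in the decoding
# loop above it is applied twice in a row (once to `decoder_out` and again inside the
# softmax), which looks unintentional. A sketch of what such a helper might do,
# assuming a GNMT-style penalty lp = (5 + length)^alpha / 6^alpha; this is an
# assumption, not the original implementation.
def length_normalization_sketch(scores, length=1, alpha=0.6):
    lp = (5 + length) ** alpha / (5 + 1) ** alpha
    return scores / lp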
def main(args): """ Main training function. Trains the translation model over the course of several epochs, including dynamic learning rate adjustment and gradient clipping. """ logging.info('Commencing training!') torch.manual_seed(42) np.random.seed(42) utils.init_logging(args) # Load dictionaries src_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.source_lang))) logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(args.source_lang, len(src_dict))) tgt_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.target_lang))) logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(args.target_lang, len(tgt_dict))) # Load datasets def load_data(split): return Seq2SeqDataset( src_file=os.path.join(args.data, '{:s}.{:s}'.format(split, args.source_lang)), tgt_file=os.path.join(args.data, '{:s}.{:s}'.format(split, args.target_lang)), src_dict=src_dict, tgt_dict=tgt_dict) train_dataset = load_data(split='train') if not args.train_on_tiny else load_data(split='tiny_train') valid_dataset = load_data(split='valid') # yichao: enable cuda use_cuda = torch.cuda.is_available() and args.device == 'cuda' device = torch.device("cuda" if use_cuda else "cpu") print("===> Using %s" % device) # Build model and optimization criterion # yichao: enable cuda, i.e. add .to(device) model = models.build_model(args, src_dict, tgt_dict).to(device) logging.info('Built a model with {:d} parameters'.format(sum(p.numel() for p in model.parameters()))) criterion = nn.CrossEntropyLoss(ignore_index=src_dict.pad_idx, reduction='sum').to(device) # Instantiate optimizer and learning rate scheduler optimizer = torch.optim.Adam(model.parameters(), args.lr) # Load last checkpoint if one exists state_dict = utils.load_checkpoint(args, model, optimizer) # lr_scheduler last_epoch = state_dict['last_epoch'] if state_dict is not None else -1 # Track validation performance for early stopping bad_epochs = 0 best_validate = float('inf') for epoch in range(last_epoch + 1, args.max_epoch): train_loader = \ torch.utils.data.DataLoader(train_dataset, num_workers=1, collate_fn=train_dataset.collater, batch_sampler=BatchSampler(train_dataset, args.max_tokens, args.batch_size, 1, 0, shuffle=True, seed=42)) model.train() stats = OrderedDict() stats['loss'] = 0 stats['lr'] = 0 stats['num_tokens'] = 0 stats['batch_size'] = 0 stats['grad_norm'] = 0 stats['clip'] = 0 # Display progress progress_bar = tqdm(train_loader, desc='| Epoch {:03d}'.format(epoch), leave=False, disable=False) # Iterate over the training set for i, sample in enumerate(progress_bar): if len(sample) == 0: continue model.train() ''' ___QUESTION-1-DESCRIBE-F-START___ Describe what the following lines of code do. 
''' # yichao: enable cuda sample['src_tokens'], sample['src_lengths'], sample['tgt_inputs'], sample['tgt_tokens'] = \ sample['src_tokens'].to(device), sample['src_lengths'].to(device), \ sample['tgt_inputs'].to(device), sample['tgt_tokens'].to(device) output, _ = model(sample['src_tokens'], sample['src_lengths'], sample['tgt_inputs']) loss = \ criterion(output.view(-1, output.size(-1)), sample['tgt_tokens'].view(-1)) / len(sample['src_lengths']) loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm) optimizer.step() optimizer.zero_grad() '''___QUESTION-1-DESCRIBE-F-END___''' # Update statistics for progress bar total_loss, num_tokens, batch_size = loss.item(), sample['num_tokens'], len(sample['src_tokens']) stats['loss'] += total_loss * len(sample['src_lengths']) / sample['num_tokens'] stats['lr'] += optimizer.param_groups[0]['lr'] stats['num_tokens'] += num_tokens / len(sample['src_tokens']) stats['batch_size'] += batch_size stats['grad_norm'] += grad_norm stats['clip'] += 1 if grad_norm > args.clip_norm else 0 progress_bar.set_postfix({key: '{:.4g}'.format(value / (i + 1)) for key, value in stats.items()}, refresh=True) logging.info('Epoch {:03d}: {}'.format(epoch, ' | '.join(key + ' {:.4g}'.format( value / len(progress_bar)) for key, value in stats.items()))) # Calculate validation loss valid_perplexity = validate(args, model, criterion, valid_dataset, epoch) model.train() # Save checkpoints if epoch % args.save_interval == 0: utils.save_checkpoint(args, model, optimizer, epoch, valid_perplexity) # lr_scheduler # Check whether to terminate training if valid_perplexity < best_validate: best_validate = valid_perplexity bad_epochs = 0 else: bad_epochs += 1 if bad_epochs >= args.patience: logging.info('No validation set improvements observed for {:d} epochs. Early stop!'.format(args.patience)) break
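# The per-token loss accumulated above relates to the reported validation perplexity
# in the usual way: perplexity = exp(summed cross-entropy / number of target tokens).
# A tiny illustrative computation (the numbers are made up):
# avg_nll = total_loss_sum / total_target_tokens   # e.g. 3.2 nats per token
# perplexity = math.exp(avg_nll)                   # e.g. ~24.5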
def main(args): """ Main training function. Trains the translation model over the course of several epochs, including dynamic learning rate adjustment and gradient clipping. """ logging.info('Commencing training!') torch.manual_seed(42) utils.init_logging(args) # Load dictionaries src_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.source_lang))) logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format( args.source_lang, len(src_dict))) tgt_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.target_lang))) logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format( args.target_lang, len(tgt_dict))) # Load datasets def load_data(split): return Seq2SeqDataset( src_file=os.path.join(args.data, '{:s}.{:s}'.format(split, args.source_lang)), tgt_file=os.path.join(args.data, '{:s}.{:s}'.format(split, args.target_lang)), src_dict=src_dict, tgt_dict=tgt_dict) train_dataset = load_data( split='train') if not args.train_on_tiny else load_data( split='tiny_train') valid_dataset = load_data(split='valid') # Build model and optimization criterion model = models.build_model(args, src_dict, tgt_dict) model_rev = models.build_model(args, tgt_dict, src_dict) logging.info('Built a model with {:d} parameters'.format( sum(p.numel() for p in model.parameters()))) criterion = nn.CrossEntropyLoss(ignore_index=src_dict.pad_idx, reduction='sum') criterion2 = nn.MSELoss(reduction='sum') if args.cuda: model = model.cuda() model_rev = model_rev.cuda() criterion = criterion.cuda() # Instantiate optimizer and learning rate scheduler optimizer = torch.optim.Adam(model.parameters(), args.lr) # Load last checkpoint if one exists state_dict = utils.load_checkpoint(args, model, optimizer) # lr_scheduler utils.load_checkpoint_rev(args, model_rev, optimizer) # lr_scheduler last_epoch = state_dict['last_epoch'] if state_dict is not None else -1 # Track validation performance for early stopping bad_epochs = 0 best_validate = float('inf') for epoch in range(last_epoch + 1, args.max_epoch): train_loader = \ torch.utils.data.DataLoader(train_dataset, num_workers=1, collate_fn=train_dataset.collater, batch_sampler=BatchSampler(train_dataset, args.max_tokens, args.batch_size, 1, 0, shuffle=True, seed=42)) model.train() model_rev.train() stats = OrderedDict() stats['loss'] = 0 stats['lr'] = 0 stats['num_tokens'] = 0 stats['batch_size'] = 0 stats['grad_norm'] = 0 stats['clip'] = 0 # Display progress progress_bar = tqdm(train_loader, desc='| Epoch {:03d}'.format(epoch), leave=False, disable=False) # Iterate over the training set for i, sample in enumerate(progress_bar): if args.cuda: sample = utils.move_to_cuda(sample) if len(sample) == 0: continue model.train() (output, att), src_out = model(sample['src_tokens'], sample['src_lengths'], sample['tgt_inputs']) # print(sample['src_lengths']) # print(sample['tgt_inputs'].size()) # print(sample['src_tokens'].size()) src_inputs = sample['src_tokens'].clone() src_inputs[0, 1:src_inputs.size(1)] = sample['src_tokens'][0, 0:( src_inputs.size(1) - 1)] src_inputs[0, 0] = sample['src_tokens'][0, src_inputs.size(1) - 1] tgt_lengths = sample['src_lengths'].clone( ) #torch.tensor([sample['tgt_tokens'].size(1)]) tgt_lengths += sample['tgt_inputs'].size( 1) - sample['src_tokens'].size(1) # print(tgt_lengths) # print(sample['num_tokens']) # if args.cuda: # tgt_lengths = tgt_lengths.cuda() (output_rev, att_rev), src_out_rev = model_rev(sample['tgt_tokens'], tgt_lengths, src_inputs) # notice that those are without masks already # 
print(sample['tgt_tokens'].view(-1)) d, d_rev = get_diff(att, src_out, att_rev, src_out_rev) # print(sample['src_tokens'].size()) # print(sample['tgt_inputs'].size()) # print(att.size()) # print(src_out.size()) # print(acontext.size()) # print(src_out_rev.size()) # # print(sample['tgt_inputs'].dtype) # # print(sample['src_lengths']) # # print(sample['src_tokens']) # # print('output %s' % str(output.size())) # # print(att) # # print(len(sample['src_lengths'])) # print(d) # print(d_rev) # print(criterion(output.view(-1, output.size(-1)), sample['tgt_tokens'].view(-1)) / len(sample['src_lengths'])) # print(att2) # output=output.cpu().detach().numpy() # output=torch.from_numpy(output).cuda() # output_rev=output_rev.cpu().detach().numpy() # output_rev=torch.from_numpy(output_rev).cuda() loss = \ criterion(output.view(-1, output.size(-1)), sample['tgt_tokens'].view(-1)) / len(sample['src_lengths']) + d +\ criterion(output_rev.view(-1, output_rev.size(-1)), sample['src_tokens'].view(-1)) / len(tgt_lengths) +d_rev loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm) # loss_rev = \ # criterion(output_rev.view(-1, output_rev.size(-1)), sample['src_tokens'].view(-1)) / len(tgt_lengths) # loss_rev.backward() # grad_norm_rev = torch.nn.utils.clip_grad_norm_(model_rev.parameters(), args.clip_norm) optimizer.step() optimizer.zero_grad() # Update statistics for progress bar total_loss, num_tokens, batch_size = ( loss - d - d_rev).item(), sample['num_tokens'], len( sample['src_tokens']) stats['loss'] += total_loss * len( sample['src_lengths']) / sample['num_tokens'] # stats['loss_rev'] += loss_rev.item() * len(sample['src_lengths']) / sample['src_tokens'].size(0) / sample['src_tokens'].size(1) stats['lr'] += optimizer.param_groups[0]['lr'] stats['num_tokens'] += num_tokens / len(sample['src_tokens']) stats['batch_size'] += batch_size stats['grad_norm'] += grad_norm stats['clip'] += 1 if grad_norm > args.clip_norm else 0 progress_bar.set_postfix( { key: '{:.4g}'.format(value / (i + 1)) for key, value in stats.items() }, refresh=True) logging.info('Epoch {:03d}: {}'.format( epoch, ' | '.join(key + ' {:.4g}'.format(value / len(progress_bar)) for key, value in stats.items()))) # Calculate validation loss valid_perplexity = validate(args, model, model_rev, criterion, valid_dataset, epoch) model.train() model_rev.train() # Save checkpoints if epoch % args.save_interval == 0: utils.save_checkpoint(args, model, model_rev, optimizer, epoch, valid_perplexity) # lr_scheduler # Check whether to terminate training if valid_perplexity < best_validate: best_validate = valid_perplexity bad_epochs = 0 else: bad_epochs += 1 if bad_epochs >= args.patience: logging.info( 'No validation set improvements observed for {:d} epochs. Early stop!' .format(args.patience)) break
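# NOTE: `get_diff` is called above but not defined in this excerpt, and `criterion2`
# (an MSELoss) is built in this file but never referenced, so it plausibly belongs
# inside it. A purely hypothetical sketch of an agreement penalty between the forward
# and reverse models' attention maps; the shapes, the transpose, and the weighting are
# all assumptions, not the author's actual method.
def get_diff_sketch(att, src_out, att_rev, src_out_rev, weight=1.0):
    mse = nn.MSELoss(reduction='sum')
    # Assumes att is (batch, tgt_len, src_len) and att_rev is (batch, src_len, tgt_len)
    d = weight * mse(att, att_rev.transpose(1, 2))
    # The caller unpacks two values (d, d_rev); here both directions share one penalty
    return d, d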
def main(args): """ Main translation function' """ # Load arguments from checkpoint torch.manual_seed(args.seed) state_dict = torch.load(args.checkpoint_path, map_location=lambda s, l: default_restore_location(s, 'cpu')) args_loaded = argparse.Namespace(**{**vars(args), **vars(state_dict['args'])}) args_loaded.data = args.data args = args_loaded utils.init_logging(args) # Load dictionaries src_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.source_lang))) logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(args.source_lang, len(src_dict))) tgt_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.target_lang))) logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(args.target_lang, len(tgt_dict))) # Load dataset test_dataset = Seq2SeqDataset( src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)), tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)), src_dict=src_dict, tgt_dict=tgt_dict) test_loader = torch.utils.data.DataLoader(test_dataset, num_workers=1, collate_fn=test_dataset.collater, batch_sampler=BatchSampler(test_dataset, 9999999, args.batch_size, 1, 0, shuffle=False, seed=args.seed)) # Build model and criterion model = models.build_model(args, src_dict, tgt_dict) if args.cuda: model = model.cuda() model.eval() model.load_state_dict(state_dict['model']) logging.info('Loaded a model from checkpoint {:s}'.format(args.checkpoint_path)) progress_bar = tqdm(test_loader, desc='| Generation', leave=False) # Iterate over the test set all_hyps = {} count = 0 for i, sample in enumerate(progress_bar): # Create a beam search object or every input sentence in batch batch_size = sample['src_tokens'].shape[0] searches = [BeamSearch(args.beam_size, args.max_len - 1, tgt_dict.unk_idx) for i in range(batch_size)] with torch.no_grad(): # Compute the encoder output encoder_out = model.encoder(sample['src_tokens'], sample['src_lengths']) # __QUESTION 1: What is "go_slice" used for and what do its dimensions represent? go_slice = \ torch.ones(sample['src_tokens'].shape[0], 1).fill_(tgt_dict.eos_idx).type_as(sample['src_tokens']) if args.cuda: go_slice = utils.move_to_cuda(go_slice) # Compute the decoder output at the first time step decoder_out, _ = model.decoder(go_slice, encoder_out) # __QUESTION 2: Why do we keep one top candidate more than the beam size? 
log_probs, next_candidates = torch.topk(torch.log(torch.softmax(decoder_out, dim=2)), args.beam_size+1, dim=-1) # Create number of beam_size beam search nodes for every input sentence for i in range(batch_size): for j in range(args.beam_size): best_candidate = next_candidates[i, :, j] backoff_candidate = next_candidates[i, :, j+1] best_log_p = log_probs[i, :, j] backoff_log_p = log_probs[i, :, j+1] # For task 3 length normalization # To calculate the score after length normalization lp = (math.pow( (5 + log_probs.shape[1]), args.alpha ))/math.pow( (5+1), args.alpha) next_word = torch.where(best_candidate == tgt_dict.unk_idx, backoff_candidate, best_candidate) log_p = torch.where(best_candidate == tgt_dict.unk_idx, backoff_log_p, best_log_p) log_p = log_p[-1] # Store the encoder_out information for the current input sentence and beam emb = encoder_out['src_embeddings'][:,i,:] lstm_out = encoder_out['src_out'][0][:,i,:] final_hidden = encoder_out['src_out'][1][:,i,:] final_cell = encoder_out['src_out'][2][:,i,:] try: mask = encoder_out['src_mask'][i,:] except TypeError: mask = None node = BeamSearchNode(searches[i], emb, lstm_out, final_hidden, final_cell, mask, torch.cat((go_slice[i], next_word)), log_p, 1) # __QUESTION 3: Why do we add the node with a negative score? # For task 3 and task 4 diversity promoting beam search # When alpha set to 0 and gamma set to 0, the is the original code # When alpha set to non-zero and gamma set to 0, this is for task 3 # When alpha set to 0 or non-zero and gamma non-zero, this is for task 4 searches[i].add(-(node.eval()/lp-(j+1)*args.gamma), node) # Start generating further tokens until max sentence length reached for _ in range(args.max_len-1): # Get the current nodes to expand nodes = [n[1] for s in searches for n in s.get_current_beams()] if nodes == []: break # All beams ended in EOS # Reconstruct prev_words, encoder_out from current beam search nodes prev_words = torch.stack([node.sequence for node in nodes]) encoder_out["src_embeddings"] = torch.stack([node.emb for node in nodes], dim=1) lstm_out = torch.stack([node.lstm_out for node in nodes], dim=1) final_hidden = torch.stack([node.final_hidden for node in nodes], dim=1) final_cell = torch.stack([node.final_cell for node in nodes], dim=1) encoder_out["src_out"] = (lstm_out, final_hidden, final_cell) try: encoder_out["src_mask"] = torch.stack([node.mask for node in nodes], dim=0) except TypeError: encoder_out["src_mask"] = None with torch.no_grad(): # Compute the decoder output by feeding it the decoded sentence prefix decoder_out, _ = model.decoder(prev_words, encoder_out) # see __QUESTION 2 log_probs, next_candidates = torch.topk(torch.log(torch.softmax(decoder_out, dim=2)), args.beam_size+1, dim=-1) for i in range(log_probs.shape[0]): for j in range(args.beam_size): best_candidate = next_candidates[i, :, j] backoff_candidate = next_candidates[i, :, j+1] best_log_p = log_probs[i, :, j] backoff_log_p = log_probs[i, :, j+1] # For task 3 length normalization # To calculate the score after length normalization lp = (math.pow( (5 + log_probs.shape[1]), args.alpha ))/math.pow( (5+1), args.alpha) next_word = torch.where(best_candidate == tgt_dict.unk_idx, backoff_candidate, best_candidate) log_p = torch.where(best_candidate == tgt_dict.unk_idx, backoff_log_p, best_log_p) log_p = log_p[-1] next_word = torch.cat((prev_words[i][1:], next_word[-1:])) # Get parent node and beam search object for corresponding sentence node = nodes[i] search = node.search # __QUESTION 4: How are "add" and "add_final" 
different? What would happen if we did not make this distinction? # Store the node as final if EOS is generated if next_word[-1 ] == tgt_dict.eos_idx: node = BeamSearchNode(search, node.emb, node.lstm_out, node.final_hidden, node.final_cell, node.mask, torch.cat((prev_words[i][0].view([1]), next_word)), node.logp, node.length) # For task 4 diversity promoting beam search. # Gamma is the weight to control the influences of rank on the score. # (j+1) is the rank for the current candidate. search.add_final(-(node.eval()/lp-(j+1)*args.gamma), node) # Add the node to current nodes for next iteration else: node = BeamSearchNode(search, node.emb, node.lstm_out, node.final_hidden, node.final_cell, node.mask, torch.cat((prev_words[i][0].view([1]), next_word)), node.logp + log_p, node.length + 1) # For task 4 diversity promoting beam search. # Gamma is the weight to control the influences of rank on the score. # (j+1) is the rank for the current candidate. search.add(-(node.eval()/lp-(j+1)*args.gamma), node) # print ("loop") # __QUESTION 5: What happens internally when we prune our beams? # How do we know we always maintain the best sequences? for search in searches: search.prune() # Segment into sentences best_sents = torch.stack([search.get_best()[1].sequence[1:].cpu() for search in searches]) decoded_batch = best_sents.numpy() # From line 239 to line 244, the code is for task 4 diversity promoting beam search. # To get the n-best lists # top_n_sent = [] # for search in searches : # top_n = search.get_top_n(args.beam_size) # for i in range(args.beam_size) : # top_n_sent.append(top_n[i][1].sequence[1:]) # best_top_sents = torch.stack(top_n_sent) # Line 248, the code is for task 4 diversity promoting beam search. # To get the n-best lists # decoded_batch = best_top_sents.numpy() output_sentences = [decoded_batch[row, :] for row in range(decoded_batch.shape[0])] # __QUESTION 6: What is the purpose of this for loop? temp = list() for sent in output_sentences: first_eos = np.where(sent == tgt_dict.eos_idx)[0] if len(first_eos) > 0: temp.append(sent[:first_eos[0]]) else: temp.append(sent) output_sentences = temp # Convert arrays of indices into strings of words output_sentences = [tgt_dict.string(sent) for sent in output_sentences] for ii, sent in enumerate(output_sentences): all_hyps[int(sample['id'].data[ii])] = sent # From line 270 to line 272, the code is for task 4 diversity promoting beam search. # To get the n-best lists # for sent in enumerate(output_sentences): # all_hyps[int(count)] = sent # count = count+1 # Write to file if args.output is not None: with open(args.output, 'w') as out_file: for sent_id in range(len(all_hyps.keys())): out_file.write(all_hyps[sent_id] + '\n')
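# The inline length-penalty expression above is duplicated in both candidate loops;
# it is the GNMT-style penalty lp(Y) = (5 + |Y|)^alpha / 6^alpha, with |Y| taken as the
# current prefix length (log_probs.shape[1]). A small helper that could replace the
# two inline copies:
def gnmt_length_penalty(length, alpha):
    return math.pow(5 + length, alpha) / math.pow(5 + 1, alpha)

# Usage, matching the expressions in the loops above:
# lp = gnmt_length_penalty(log_probs.shape[1], args.alpha)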