def main(args): """ Main translation function' """ # Load arguments from checkpoint torch.manual_seed(args.seed) state_dict = torch.load( args.checkpoint_path, map_location=lambda s, l: default_restore_location(s, 'cpu')) args_loaded = argparse.Namespace(**{ **vars(args), **vars(state_dict['args']) }) args_loaded.data = args.data args = args_loaded utils.init_logging(args) # Load dictionaries src_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.source_lang))) logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format( args.source_lang, len(src_dict))) tgt_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.target_lang))) logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format( args.target_lang, len(tgt_dict))) # Load dataset test_dataset = Seq2SeqDataset( src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)), tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)), src_dict=src_dict, tgt_dict=tgt_dict) test_loader = torch.utils.data.DataLoader(test_dataset, num_workers=1, collate_fn=test_dataset.collater, batch_sampler=BatchSampler( test_dataset, 9999999, args.batch_size, 1, 0, shuffle=False, seed=args.seed)) # Build model and criterion model = models.build_model(args, src_dict, tgt_dict) if args.cuda: model = model.cuda() model.eval() model.load_state_dict(state_dict['model']) logging.info('Loaded a model from checkpoint {:s}'.format( args.checkpoint_path)) progress_bar = tqdm(test_loader, desc='| Generation', leave=False) # Iterate over the test set all_hyps = {} for i, sample in enumerate(progress_bar): # Create a beam search object or every input sentence in batch batch_size = sample['src_tokens'].shape[0] #searches = [BeamSearch(args.beam_size, args.max_len - 1, tgt_dict.unk_idx) for i in range(batch_size)] #Task 3: Normalization searches = [ BeamSearch(args.beam_size, args.max_len - 1, tgt_dict.unk_idx, args.alpha_i / args.alpha_max) for i in range(batch_size) ] with torch.no_grad(): # Compute the encoder output encoder_out = model.encoder(sample['src_tokens'], sample['src_lengths']) go_slice = \ torch.ones(sample['src_tokens'].shape[0], 1).fill_(tgt_dict.eos_idx).type_as(sample['src_tokens']) # Compute the decoder output at the first time step decoder_out, _ = model.decoder(go_slice, encoder_out) # __QUESTION 1: What happens here and what do 'log_probs' and 'next_candidates' contain? log_probs, next_candidates = torch.topk(torch.log( torch.softmax(decoder_out, dim=2)), args.beam_size + 1, dim=-1) # Create number of beam_size beam search nodes for every input sentence for i in range(batch_size): for j in range(args.beam_size): # __QUESTION 2: Why do we need backoff candidates? best_candidate = next_candidates[i, :, j] backoff_candidate = next_candidates[i, :, j + 1] best_log_p = log_probs[i, :, j] backoff_log_p = log_probs[i, :, j + 1] next_word = torch.where(best_candidate == tgt_dict.unk_idx, backoff_candidate, best_candidate) log_p = torch.where(best_candidate == tgt_dict.unk_idx, backoff_log_p, best_log_p) log_p = log_p[-1] # Store the encoder_out information for the current input sentence and beam emb = encoder_out['src_embeddings'][:, i, :] lstm_out = encoder_out['src_out'][0][:, i, :] final_hidden = encoder_out['src_out'][1][:, i, :] final_cell = encoder_out['src_out'][2][:, i, :] try: mask = encoder_out['src_mask'][i, :] except TypeError: mask = None # __QUESTION 3: What happens internally when we add a new beam search node? 
node = BeamSearchNode(searches[i], emb, lstm_out, final_hidden, final_cell, mask, torch.cat( (go_slice[i], next_word)), log_p, 1) searches[i].add(-node.eval(), node) # Start generating further tokens until max sentence length reached for _ in range(args.max_len - 1): # Get the current nodes to expand nodes = [n[1] for s in searches for n in s.get_current_beams()] if nodes == []: break # All beams ended in EOS # Reconstruct prev_words, encoder_out from current beam search nodes prev_words = torch.stack([node.sequence for node in nodes]) encoder_out["src_embeddings"] = torch.stack( [node.emb for node in nodes], dim=1) lstm_out = torch.stack([node.lstm_out for node in nodes], dim=1) final_hidden = torch.stack([node.final_hidden for node in nodes], dim=1) final_cell = torch.stack([node.final_cell for node in nodes], dim=1) encoder_out["src_out"] = (lstm_out, final_hidden, final_cell) try: encoder_out["src_mask"] = torch.stack( [node.mask for node in nodes], dim=0) except TypeError: encoder_out["src_mask"] = None with torch.no_grad(): # Compute the decoder output by feeding it the decoded sentence prefix decoder_out, _ = model.decoder(prev_words, encoder_out) # see __QUESTION 1 log_probs, next_candidates = torch.topk(torch.log( torch.softmax(decoder_out, dim=2)), args.beam_size + 1, dim=-1) # Create number of beam_size next nodes for every current node for i in range(log_probs.shape[0]): for j in range(args.beam_size): # see __QUESTION 2 best_candidate = next_candidates[i, :, j] backoff_candidate = next_candidates[i, :, j + 1] best_log_p = log_probs[i, :, j] backoff_log_p = log_probs[i, :, j + 1] next_word = torch.where(best_candidate == tgt_dict.unk_idx, backoff_candidate, best_candidate) log_p = torch.where(best_candidate == tgt_dict.unk_idx, backoff_log_p, best_log_p) log_p = log_p[-1] next_word = torch.cat((prev_words[i][1:], next_word[-1:])) # Get parent node and beam search object for corresponding sentence node = nodes[i] search = node.search # __QUESTION 4: Why do we treat nodes that generated the end-of-sentence token differently? # Store the node as final if EOS is generated if next_word[-1] == tgt_dict.eos_idx: node = BeamSearchNode( search, node.emb, node.lstm_out, node.final_hidden, node.final_cell, node.mask, torch.cat((prev_words[i][0].view([1]), next_word)), node.logp, node.length) search.add_final(-node.eval(), node) # Add the node to current nodes for next iteration else: node = BeamSearchNode( search, node.emb, node.lstm_out, node.final_hidden, node.final_cell, node.mask, torch.cat((prev_words[i][0].view([1]), next_word)), node.logp + log_p, node.length + 1) #search.add(-node.eval(), node) #Task 4: Diversity of Beam Search diversity_penalty = j * args.gamma_i / args.gamma_max search.add(-node.eval() + diversity_penalty, node) # __QUESTION 5: What happens internally when we prune our beams? # How do we know we always maintain the best sequences? for search in searches: search.prune() # Segment into sentences best_sents = torch.stack( [search.get_best()[1].sequence[1:] for search in searches]) decoded_batch = best_sents.numpy() output_sentences = [ decoded_batch[row, :] for row in range(decoded_batch.shape[0]) ] # __QUESTION 6: What is the purpose of this for loop? 
temp = list() for sent in output_sentences: first_eos = np.where(sent == tgt_dict.eos_idx)[0] if len(first_eos) > 0: temp.append(sent[:first_eos[0]]) else: temp.append(sent) output_sentences = temp # Convert arrays of indices into strings of words output_sentences = [tgt_dict.string(sent) for sent in output_sentences] for ii, sent in enumerate(output_sentences): all_hyps[int(sample['id'].data[ii])] = sent # Write to file if args.output is not None: with open(args.output, 'w') as out_file: for sent_id in range(len(all_hyps.keys())): out_file.write(all_hyps[sent_id] + '\n')
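# ---------------------------------------------------------------------------
# A minimal, self-contained sketch of the two scoring tweaks used above (Task 3
# and Task 4), so they can be sanity-checked in isolation. Here 'alpha' and
# 'gamma' stand in for args.alpha_i / args.alpha_max and
# args.gamma_i / args.gamma_max; both helper names are illustrative and not
# part of the assignment codebase, and how BeamSearch consumes its alpha
# argument internally is not shown in this file.
# ---------------------------------------------------------------------------


def length_normalized_score(logp, length, alpha):
    # Task 3: GNMT-style length normalization (Wu et al., 2016): divide the
    # cumulative log-probability by ((5 + |Y|) / 6)**alpha so that longer
    # hypotheses are not unfairly penalized against short ones.
    lp = ((5.0 + length) / 6.0)**alpha
    return logp / lp


def rank_penalized_score(logp, rank, gamma):
    # Task 4: rank-based diversity penalty: subtract gamma * rank, where rank
    # is the candidate's position j among its siblings, discouraging beams
    # that all expand the same parent node.
    return logp - gamma * rank


# Example: with gamma = 0.5, two siblings with the same raw score
#   rank_penalized_score(-4.0, rank=0, gamma=0.5)  -> -4.0
#   rank_penalized_score(-4.0, rank=2, gamma=0.5)  -> -5.0
# and with alpha = 1.0, normalization shrinks the short/long gap:
#   length_normalized_score(-6.0, length=6, alpha=1.0)    -> ~-3.27
#   length_normalized_score(-12.0, length=12, alpha=1.0)  -> ~-4.24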
def main(args): """ Main translation function' """ # Load arguments from checkpoint torch.manual_seed(args.seed) state_dict = torch.load( args.checkpoint_path, map_location=lambda s, l: default_restore_location(s, 'cpu')) args_loaded = argparse.Namespace(**{ **vars(args), **vars(state_dict['args']) }) args_loaded.data = args.data args = args_loaded utils.init_logging(args) # Load dictionaries src_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.source_lang))) logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format( args.source_lang, len(src_dict))) tgt_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.target_lang))) logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format( args.target_lang, len(tgt_dict))) # Load dataset test_dataset = Seq2SeqDataset( src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)), tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)), src_dict=src_dict, tgt_dict=tgt_dict) test_loader = torch.utils.data.DataLoader(test_dataset, num_workers=1, collate_fn=test_dataset.collater, batch_sampler=BatchSampler( test_dataset, 9999999, args.batch_size, 1, 0, shuffle=False, seed=args.seed)) # Build model and criterion model = models.build_model(args, src_dict, tgt_dict) if args.cuda: model = model.cuda() model.eval() model.load_state_dict(state_dict['model']) logging.info('Loaded a model from checkpoint {:s}'.format( args.checkpoint_path)) progress_bar = tqdm(test_loader, desc='| Generation', leave=False) # Iterate over the test set all_hyps = {} for i, sample in enumerate(progress_bar): with torch.no_grad(): # Compute the encoder output encoder_out = model.encoder(sample['src_tokens'], sample['src_lengths']) go_slice = \ torch.ones(sample['src_tokens'].shape[0], 1).fill_(tgt_dict.eos_idx).type_as(sample['src_tokens']) if args.cuda: go_slice = utils.move_to_cuda(go_slice) prev_words = go_slice next_words = None for _ in range(args.max_len): with torch.no_grad(): # Compute the decoder output by repeatedly feeding it the decoded sentence prefix decoder_out, _ = model.decoder(prev_words, encoder_out) # Suppress <UNK>s _, next_candidates = torch.topk(decoder_out, 2, dim=-1) best_candidates = next_candidates[:, :, 0] backoff_candidates = next_candidates[:, :, 1] next_words = torch.where(best_candidates == tgt_dict.unk_idx, backoff_candidates, best_candidates) prev_words = torch.cat([go_slice, next_words], dim=1) # Segment into sentences decoded_batch = next_words.cpu().numpy() output_sentences = [ decoded_batch[row, :] for row in range(decoded_batch.shape[0]) ] assert (len(output_sentences) == len(sample['id'].data)) # Remove padding temp = list() for sent in output_sentences: first_eos = np.where(sent == tgt_dict.eos_idx)[0] if len(first_eos) > 0: temp.append(sent[:first_eos[0]]) else: temp.append([]) output_sentences = temp # Convert arrays of indices into strings of words output_sentences = [tgt_dict.string(sent) for sent in output_sentences] # Save translations assert (len(output_sentences) == len(sample['id'].data)) for ii, sent in enumerate(output_sentences): all_hyps[int(sample['id'].data[ii])] = sent # Write to file if args.output is not None: with open(args.output, 'w') as out_file: for sent_id in range(len(all_hyps.keys())): out_file.write(all_hyps[sent_id] + '\n')
def main(args): """ Main training function. Trains the translation model over the course of several epochs, including dynamic learning rate adjustment and gradient clipping. """ if args.dropout: args.encoder_dropout_in = args.dropout args.encoder_dropout_out = args.dropout args.decoder_dropout_in = args.dropout args.decoder_dropout_out = args.dropout print(args) logging.info('Commencing training!') torch.manual_seed(42) utils.init_logging(args) # Load dictionaries src_dict = Dictionary.load( os.path.join( args.data, 'dict{:s}.{:s}'.format(args.bpe_vocab_size, args.source_lang))) logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format( args.source_lang, len(src_dict))) tgt_dict = Dictionary.load( os.path.join( args.data, 'dict{:s}.{:s}'.format(args.bpe_vocab_size, args.target_lang))) logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format( args.target_lang, len(tgt_dict))) # Load datasets def load_data(split): return Seq2SeqDataset( src_file=os.path.join(args.data, '{:s}.{:s}'.format(split, args.source_lang)), tgt_file=os.path.join(args.data, '{:s}.{:s}'.format(split, args.target_lang)), src_dict=src_dict, tgt_dict=tgt_dict) train_dataset = load_data( split='train') if not args.train_on_tiny else load_data( split='tiny_train') valid_dataset = load_data(split='valid') # Build model and optimization criterion model = models.build_model(args, src_dict, tgt_dict) logging.info('Built a model with {:d} parameters'.format( sum(p.numel() for p in model.parameters()))) criterion = nn.CrossEntropyLoss(ignore_index=src_dict.pad_idx, reduction='sum') if args.cuda: model = model.cuda() criterion = criterion.cuda() # Instantiate optimizer and learning rate scheduler optimizer = torch.optim.Adam(model.parameters(), args.lr) # Load last checkpoint if one exists state_dict = utils.load_checkpoint(args, model, optimizer) # lr_scheduler last_epoch = state_dict['last_epoch'] if state_dict is not None else -1 # Track validation performance for early stopping bad_epochs = 0 best_validate = float('inf') for epoch in range(last_epoch + 1, args.max_epoch): train_loader = \ torch.utils.data.DataLoader(train_dataset, num_workers=1, collate_fn=train_dataset.collater, batch_sampler=BatchSampler(train_dataset, args.max_tokens, args.batch_size, 1, 0, shuffle=True, seed=42)) model.train() stats = OrderedDict() stats['loss'] = 0 stats['lr'] = 0 stats['num_tokens'] = 0 stats['batch_size'] = 0 stats['grad_norm'] = 0 stats['clip'] = 0 # Display progress progress_bar = tqdm(train_loader, desc='| Epoch {:03d}'.format(epoch), leave=False, disable=False) # Iterate over the training set for i, sample in enumerate(progress_bar): if args.cuda: sample = utils.move_to_cuda(sample) if len(sample) == 0: continue model.train() output, _ = model(sample['src_tokens'], sample['src_lengths'], sample['tgt_inputs']) loss = \ criterion(output.view(-1, output.size(-1)), sample['tgt_tokens'].view(-1)) / len(sample['src_lengths']) loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm) optimizer.step() optimizer.zero_grad() # Update statistics for progress bar total_loss, num_tokens, batch_size = loss.item( ), sample['num_tokens'], len(sample['src_tokens']) stats['loss'] += total_loss * len( sample['src_lengths']) / sample['num_tokens'] stats['lr'] += optimizer.param_groups[0]['lr'] stats['num_tokens'] += num_tokens / len(sample['src_tokens']) stats['batch_size'] += batch_size stats['grad_norm'] += grad_norm stats['clip'] += 1 if grad_norm > args.clip_norm else 0 
progress_bar.set_postfix( { key: '{:.4g}'.format(value / (i + 1)) for key, value in stats.items() }, refresh=True) logging.info('Epoch {:03d}: {}'.format( epoch, ' | '.join(key + ' {:.4g}'.format(value / len(progress_bar)) for key, value in stats.items()))) # Calculate validation loss valid_perplexity = validate(args, model, criterion, valid_dataset, epoch) model.train() # Save checkpoints if epoch % args.save_interval == 0: utils.save_checkpoint(args, model, optimizer, epoch, valid_perplexity) # lr_scheduler # Check whether to terminate training if valid_perplexity < best_validate: best_validate = valid_perplexity bad_epochs = 0 else: bad_epochs += 1 if bad_epochs >= args.patience: logging.info( 'No validation set improvements observed for {:d} epochs. Early stop!' .format(args.patience)) break
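# ---------------------------------------------------------------------------
# The bookkeeping above mixes two normalizations: the loss used for backprop is
# the token-level cross-entropy summed over the batch and divided by the number
# of *sentences*, while the 'loss' statistic rescales it back to a per-token
# average. A small sketch of that arithmetic with made-up numbers:
# ---------------------------------------------------------------------------

import math

summed_ce = 240.0   # criterion(...) output with reduction='sum'
num_sents = 16      # len(sample['src_lengths'])
num_tokens = 400    # sample['num_tokens']

loss = summed_ce / num_sents               # per-sentence loss, backpropagated
per_token = loss * num_sents / num_tokens  # what stats['loss'] accumulates
assert abs(per_token - summed_ce / num_tokens) < 1e-9

# Perplexity, which validate() presumably reports as valid_perplexity, is the
# exponential of the per-token cross-entropy (natural log base here):
perplexity = math.exp(per_token)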
def main(args): """ Main function. Visualizes attention weight arrays as nifty heat-maps. """ mpl.rc('font', family='VL Gothic') torch.manual_seed(42) state_dict = torch.load( args.checkpoint_path, map_location=lambda s, l: default_restore_location(s, 'cpu')) args = argparse.Namespace(**{**vars(args), **vars(state_dict['args'])}) utils.init_logging(args) # Load dictionaries src_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.source_lang))) print('Loaded a source dictionary ({:s}) with {:d} words'.format( args.source_lang, len(src_dict))) tgt_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.target_lang))) print('Loaded a target dictionary ({:s}) with {:d} words'.format( args.target_lang, len(tgt_dict))) # Load dataset test_dataset = Seq2SeqDataset( src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)), tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)), src_dict=src_dict, tgt_dict=tgt_dict) vis_loader = torch.utils.data.DataLoader(test_dataset, num_workers=1, collate_fn=test_dataset.collater, batch_sampler=BatchSampler( test_dataset, None, 1, 1, 0, shuffle=False, seed=42)) # Build model and optimization criterion model = models.build_model(args, src_dict, tgt_dict) if args.cuda: model = model.cuda() model.load_state_dict(state_dict['model']) print('Loaded a model from checkpoint {:s}'.format(args.checkpoint_path)) # Store attention weight arrays attn_records = list() # Iterate over the visualization set for i, sample in enumerate(vis_loader): if args.cuda: sample = utils.move_to_cuda(sample) if len(sample) == 0: continue # Perform forward pass output, attn_weights = model(sample['src_tokens'], sample['src_lengths'], sample['tgt_inputs']) attn_records.append((sample, attn_weights)) # Only visualize the first 10 sentence pairs if i >= 10: break # Generate heat-maps and store them at the designated location if not os.path.exists(args.vis_dir): os.makedirs(args.vis_dir) for record_id, record in enumerate(attn_records): # Unpack sample, attn_map = record src_ids = utils.strip_pad(sample['src_tokens'].data, tgt_dict.pad_idx) tgt_ids = utils.strip_pad(sample['tgt_inputs'].data, tgt_dict.pad_idx) # Convert indices into word tokens src_str = src_dict.string(src_ids).split(' ') + ['<EOS>'] tgt_str = tgt_dict.string(tgt_ids).split(' ') + ['<EOS>'] # Generate heat-maps attn_map = attn_map.squeeze(dim=0).transpose(1, 0).detach().numpy() attn_df = pd.DataFrame(attn_map, index=src_str, columns=tgt_str) sns.heatmap(attn_df, cmap='Blues', linewidths=0.25, vmin=0.0, vmax=1.0, xticklabels=True, yticklabels=True, fmt='.3f') plt.yticks(rotation=0) plot_path = os.path.join(args.vis_dir, 'sentence_{:d}.png'.format(record_id)) plt.savefig(plot_path, dpi='figure', pad_inches=1, bbox_inches='tight') plt.clf() print( 'Done! Visualized attention maps have been saved to the \'{:s}\' directory!' .format(args.vis_dir))
def main(args):
    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported.')
    torch.manual_seed(args.seed)
    torch.cuda.set_device(args.device_id)
    utils.init_logging(args)

    if args.distributed_world_size > 1:
        torch.distributed.init_process_group(
            backend=args.distributed_backend,
            init_method=args.distributed_init_method,
            world_size=args.distributed_world_size,
            rank=args.distributed_rank)

    # Load dictionaries
    src_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{}'.format(args.source_lang)))
    logging.info('Loaded a source dictionary ({}) with {} words'.format(
        args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{}'.format(args.target_lang)))
    logging.info('Loaded a target dictionary ({}) with {} words'.format(
        args.target_lang, len(tgt_dict)))

    # Load datasets
    def load_data(split):
        return Seq2SeqDataset(
            src_file=os.path.join(args.data,
                                  '{}.{}'.format(split, args.source_lang)),
            tgt_file=os.path.join(args.data,
                                  '{}.{}'.format(split, args.target_lang)),
            src_dict=src_dict,
            tgt_dict=tgt_dict)

    train_dataset = load_data(split='train')
    valid_dataset = load_data(split='valid')

    # Build model and criterion
    model = models.build_model(args, src_dict, tgt_dict).cuda()
    logging.info('Built a model with {} parameters'.format(
        sum(p.numel() for p in model.parameters())))
    criterion = nn.CrossEntropyLoss(ignore_index=src_dict.pad_idx,
                                    reduction='sum').cuda()

    # Build an optimizer and a learning rate schedule
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=20, min_lr=args.min_lr, factor=args.lr_shrink)

    # Load last checkpoint if one exists
    state_dict = utils.load_checkpoint(args, model, optimizer, lr_scheduler)
    last_epoch = state_dict['last_epoch'] if state_dict is not None else -1

    for epoch in range(last_epoch + 1, args.max_epoch):
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            num_workers=args.num_workers,
            collate_fn=train_dataset.collater,
            batch_sampler=BatchSampler(train_dataset,
                                       args.max_tokens,
                                       args.batch_size,
                                       args.distributed_world_size,
                                       args.distributed_rank,
                                       shuffle=True,
                                       seed=args.seed))

        model.train()
        stats = {
            'loss': 0.,
            'lr': 0.,
            'num_tokens': 0.,
            'batch_size': 0.,
            'grad_norm': 0.,
            'clip': 0.
        }
        progress_bar = tqdm(train_loader,
                            desc='| Epoch {:03d}'.format(epoch),
                            leave=False,
                            disable=(args.distributed_rank != 0))

        for i, sample in enumerate(progress_bar):
            sample = utils.move_to_cuda(sample)
            if len(sample) == 0:
                continue

            # Forward and backward pass
            output, _ = model(sample['src_tokens'], sample['src_lengths'],
                              sample['tgt_inputs'])
            loss = criterion(output.view(-1, output.size(-1)),
                             sample['tgt_tokens'].view(-1))
            optimizer.zero_grad()
            loss.backward()

            # Reduce gradients across all GPUs
            if args.distributed_world_size > 1:
                utils.reduce_grads(model.parameters())
                total_loss, num_tokens, batch_size = list(
                    map(
                        sum,
                        zip(*utils.all_gather_list([
                            loss.item(), sample['num_tokens'],
                            len(sample['src_tokens'])
                        ]))))
            else:
                total_loss, num_tokens, batch_size = loss.item(
                ), sample['num_tokens'], len(sample['src_tokens'])

            # Normalize gradients by number of tokens and perform clipping
            for param in model.parameters():
                param.grad.data.div_(num_tokens)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.clip_norm)
            optimizer.step()

            # Update statistics for progress bar
            stats['loss'] += total_loss / num_tokens / math.log(2)
            stats['lr'] += optimizer.param_groups[0]['lr']
            stats['num_tokens'] += num_tokens / len(sample['src_tokens'])
            stats['batch_size'] += batch_size
            stats['grad_norm'] += grad_norm
            stats['clip'] += 1 if grad_norm > args.clip_norm else 0
            progress_bar.set_postfix(
                {
                    key: '{:.4g}'.format(value / (i + 1))
                    for key, value in stats.items()
                },
                refresh=True)

        logging.info('Epoch {:03d}: {}'.format(
            epoch, ' | '.join(key + ' {:.4g}'.format(value / len(progress_bar))
                              for key, value in stats.items())))

        # Adjust learning rate based on validation loss
        valid_loss = validate(args, model, criterion, valid_dataset, epoch)
        lr_scheduler.step(valid_loss)

        # Save checkpoints
        if epoch % args.save_interval == 0:
            utils.save_checkpoint(args, model, optimizer, lr_scheduler, epoch,
                                  valid_loss)

        if optimizer.param_groups[0]['lr'] <= args.min_lr:
            logging.info('Done training!')
            break
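# ---------------------------------------------------------------------------
# utils.reduce_grads is project code whose body is not shown here; a common
# implementation, sketched below under that assumption, sums each gradient
# across workers with torch.distributed.all_reduce. Combined with the division
# by the globally gathered token count above, this yields the gradient of the
# token-averaged loss over the whole distributed batch. SUM (rather than a mean
# over workers) is the right reduction precisely because that later division
# already accounts for the global batch size.
# ---------------------------------------------------------------------------

import torch
import torch.distributed as dist


def reduce_grads_sketch(params):
    # Sum each parameter's gradient across all processes in the default group.
    for param in params:
        if param.grad is not None:
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)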
def main(args):
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    # Load arguments from checkpoint
    torch.manual_seed(args.seed)
    state_dict = torch.load(
        args.checkpoint_path,
        map_location=lambda s, l: default_restore_location(s, 'cpu'))
    args = argparse.Namespace(**{**vars(args), **vars(state_dict['args'])})
    utils.init_logging(args)

    # Load dictionaries
    all_dict = Dictionary.load(os.path.join(args.data,
                                            'dict.{}'.format('all')))
    logging.info('Loaded a source dictionary with {} words'.format(
        len(all_dict)))

    # Load dataset
    test_dataset = How2Dataset(src_file=os.path.join(args.data,
                                                     'test.{}'.format('tran')),
                               tgt_file=os.path.join(args.data,
                                                     'test.{}'.format('desc')),
                               all_dict=all_dict,
                               video_file=args.video_file,
                               video_dir=args.video_dir)

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        num_workers=args.num_workers,
        collate_fn=test_dataset.collater,
        batch_sampler=BatchSampler(
            test_dataset,
            args.max_tokens,
            1,
            1,  # args.batch_size, args.distributed_world_size, args.distributed_rank
            shuffle=False,
            seed=args.seed))

    # Build model and criterion
    model = models.build_model(args, all_dict).cuda()
    model.load_state_dict(state_dict['model'])
    logging.info('Loaded a model from checkpoint {}'.format(
        args.checkpoint_path))

    translator = SequenceGenerator(
        model,
        all_dict,
        beam_size=args.beam_size,
        maxlen=args.max_len,
        stop_early=eval(args.stop_early),
        normalize_scores=eval(args.normalize_scores),
        len_penalty=args.len_penalty,
        unk_penalty=args.unk_penalty,
    )

    progress_bar = tqdm(test_loader, desc='| Generation', leave=False)
    for i, sample in enumerate(progress_bar):
        # logging.info(sample)
        sample = utils.move_to_cuda(sample)
        with torch.no_grad():
            hypos = translator.generate(sample['src_tokens'],
                                        sample['src_lengths'],
                                        sample['video_inputs'])
        for i, (sample_id, hypos) in enumerate(zip(sample['id'].data, hypos)):
            src_tokens = utils.strip_pad(sample['src_tokens'].data[i, :],
                                         all_dict.pad_idx)
            has_target = sample['tgt_tokens'] is not None
            target_tokens = utils.strip_pad(
                sample['tgt_tokens'].data[i, :],
                all_dict.pad_idx).int().cpu() if has_target else None

            src_str = all_dict.string(src_tokens, args.remove_bpe)
            target_str = all_dict.string(
                target_tokens, args.remove_bpe) if has_target else ''

            if not args.quiet:
                print('S-{}\t{}'.format(sample_id, src_str))
                if has_target:
                    print('T-{}\t{}'.format(sample_id,
                                            colored(target_str, 'green')))

            # Process top predictions
            for i, hypo in enumerate(hypos[:min(len(hypos), args.num_hypo)]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu(),
                    tgt_dict=all_dict,
                    remove_bpe=args.remove_bpe,
                )

                if not args.quiet:
                    print('H-{}\t{}'.format(sample_id,
                                            colored(hypo_str, 'blue')))
                    print('P-{}\t{}'.format(
                        sample_id, ' '.join(
                            map(lambda x: '{:.4f}'.format(x),
                                hypo['positional_scores'].tolist()))))
                    print('A-{}\t{}'.format(
                        sample_id,
                        ' '.join(map(lambda x: str(x.item()), alignment))))

                # Score only the top hypothesis
                if has_target and i == 0:
                    # Convert back to tokens for evaluation with unk replacement and/or without BPE
                    target_tokens = all_dict.binarize(target_str,
                                                      word_tokenize,
                                                      add_if_not_exist=True)
                    print(target_tokens)
def main(args): """ Main translation function' """ # Load arguments from checkpoint torch.manual_seed(args.seed) state_dict = torch.load(args.checkpoint_path, map_location=lambda s, l: default_restore_location(s, 'cpu')) args_loaded = argparse.Namespace(**{**vars(args), **vars(state_dict['args'])}) args_loaded.data = args.data args = args_loaded utils.init_logging(args) # Load dictionaries src_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.source_lang))) logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(args.source_lang, len(src_dict))) tgt_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.target_lang))) logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(args.target_lang, len(tgt_dict))) # Load dataset test_dataset = Seq2SeqDataset( src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)), tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)), src_dict=src_dict, tgt_dict=tgt_dict) test_loader = torch.utils.data.DataLoader(test_dataset, num_workers=1, collate_fn=test_dataset.collater, batch_sampler=BatchSampler(test_dataset, 9999999, args.batch_size, 1, 0, shuffle=False, seed=args.seed)) # Build model and criterion model = models.build_model(args, src_dict, tgt_dict) if args.cuda: model = model.cuda() model.eval() model.load_state_dict(state_dict['model']) logging.info('Loaded a model from checkpoint {:s}'.format(args.checkpoint_path)) progress_bar = tqdm(test_loader, desc='| Generation', leave=False) # Iterate over the test set all_hyps = {} adapted_beam_nbest = [] for i, sample in enumerate(progress_bar): # Create a beam search object or every input sentence in batch batch_size = sample['src_tokens'].shape[0] #print("batch:", batch_size) searches = [BeamSearch(args.beam_size, args.max_len - 1, tgt_dict.unk_idx) for i in range(batch_size)] #print(searches) with torch.no_grad(): # Compute the encoder output encoder_out = model.encoder(sample['src_tokens'], sample['src_lengths']) #print(encoder_out) # __QUESTION 1: What is "go_slice" used for and what do its dimensions represent? go_slice = \ torch.ones(sample['src_tokens'].shape[0], 1).fill_(tgt_dict.eos_idx).type_as(sample['src_tokens']) #print(go_slice) #print(go_slice.size()) if args.cuda: go_slice = utils.move_to_cuda(go_slice) # Compute the decoder output at the first time step decoder_out, _ = model.decoder(go_slice, encoder_out) #print(decoder_out) # __QUESTION 2: Why do we keep one top candidate more than the beam size? # ANS: to anticipate the EOS token? 
log_probs, next_candidates = torch.topk(torch.log(torch.softmax(decoder_out, dim=2)), args.beam_size+1, dim=-1) #log_probs, next_candidates = torch.topk(torch.log(torch.softmax(decoder_out, dim=2)), #k=args.nbest+1, dim=-1) #print(log_probs) #print(next_candidates) # Create number of beam_size beam search nodes for every input sentence for i in range(batch_size): for j in range(args.beam_size): best_candidate = next_candidates[i, :, j] backoff_candidate = next_candidates[i, :, j+1] best_log_p = log_probs[i, :, j] backoff_log_p = log_probs[i, :, j+1] next_word = torch.where(best_candidate == tgt_dict.unk_idx, backoff_candidate, best_candidate) log_p = torch.where(best_candidate == tgt_dict.unk_idx, backoff_log_p, best_log_p) log_p = log_p[-1] # Store the encoder_out information for the current input sentence and beam emb = encoder_out['src_embeddings'][:,i,:] lstm_out = encoder_out['src_out'][0][:,i,:] final_hidden = encoder_out['src_out'][1][:,i,:] final_cell = encoder_out['src_out'][2][:,i,:] try: mask = encoder_out['src_mask'][i,:] except TypeError: mask = None #Node here is equivalent to one hypothesis (predicted word) node = BeamSearchNode(searches[i], emb, lstm_out, final_hidden, final_cell, mask, torch.cat((go_slice[i], next_word)), log_p, 1) #print(node) #exit() #print(next_word) # __QUESTION 3: Why do we add the node with a negative score? #normalizer = (((5+len(next_candidates))**args.alpha)) / ((5+1)**args.alpha) #length_norm_results = log_probs/normalizer searches[i].add(-(node.eval()), node) #print(node.eval()*(log_p / (((5+len(next_candidates))**args.alpha)) / ((5+1)**args.alpha))) #exit() #print(searches[i].add(-node.eval(), node)) # Start generating further tokens until max sentence length reached for _ in range(args.max_len-1): # Get the current nodes to expand nodes = [n[1] for s in searches for n in s.get_current_beams()] if nodes == []: break # All beams ended in EOS # Reconstruct prev_words, encoder_out from current beam search nodes prev_words = torch.stack([node.sequence for node in nodes]) encoder_out["src_embeddings"] = torch.stack([node.emb for node in nodes], dim=1) lstm_out = torch.stack([node.lstm_out for node in nodes], dim=1) final_hidden = torch.stack([node.final_hidden for node in nodes], dim=1) final_cell = torch.stack([node.final_cell for node in nodes], dim=1) encoder_out["src_out"] = (lstm_out, final_hidden, final_cell) try: encoder_out["src_mask"] = torch.stack([node.mask for node in nodes], dim=0) except TypeError: encoder_out["src_mask"] = None with torch.no_grad(): # Compute the decoder output by feeding it the decoded sentence prefix decoder_out, _ = model.decoder(prev_words, encoder_out) # see __QUESTION 2 log_probs, next_candidates = torch.topk(torch.log(torch.softmax(decoder_out, dim=2)), args.beam_size+1, dim=-1) #print(decoder_out) #normalizer = (((5+len(next_candidates))**args.alpha)) / ((5+1)**args.alpha) #length_norm_results = log_probs/normalizer #print(length_norm_results) # Create number of beam_size next nodes for every current node for i in range(log_probs.shape[0]): for j in range(args.beam_size): best_candidate = next_candidates[i, :, j] #print(best_candidate) backoff_candidate = next_candidates[i, :, j+1] best_log_p = log_probs[i, :, j] backoff_log_p = log_probs[i, :, j+1] #print(backoff_log_p) next_word = torch.where(best_candidate == tgt_dict.unk_idx, backoff_candidate, best_candidate) log_p = torch.where(best_candidate == tgt_dict.unk_idx, backoff_log_p, best_log_p) log_p = log_p[-1] next_word = torch.cat((prev_words[i][1:], 
next_word[-1:])) # Get parent node and beam search object for corresponding sentence node = nodes[i] search = node.search # __QUESTION 4: How are "add" and "add_final" different? What would happen if we did not make this distinction? # Store the node as final if EOS is generated if next_word[-1 ] == tgt_dict.eos_idx: node = BeamSearchNode(search, node.emb, node.lstm_out, node.final_hidden, node.final_cell, node.mask, torch.cat((prev_words[i][0].view([1]), next_word)), node.logp, node.length) #print(node) search.add_final(-node.eval()*(log_p / (((5+len(next_candidates))**args.alpha)) / ((5+1)**args.alpha)), node) # Add the node to current nodes for next iteration else: #This is where I'll add the gamma for adapted beam search? node = BeamSearchNode(search, node.emb, node.lstm_out, node.final_hidden, node.final_cell, node.mask, torch.cat((prev_words[i][0].view([1]), next_word)), node.logp + log_p, node.length + 1) search.add(-node.eval(), node) #search.add(-node.eval() * args.gamma, node) # __QUESTION 5: What happens internally when we prune our beams? #Question 5: How do we know we always maintain the best sequences? for search in searches: search.prune() #print(searches) # Segment into sentences # -- get top n search? #best_sents = torch.stack([search.get_best()[1].sequence[1:].cpu() for search in searches]) #print(search.get_best(args.nbest)) #n_best_sents = torch.stack([sent[1].sequence[1:].cpu() for search in searches for sent in search.get_best(args.nbest)]) n_best_sents = torch.stack([search.get_best()[1].sequence[1:].cpu()[:args.nbest] for search in searches]) #n_best_sents = torch.stack([search.get_best()[1].sequence[1:].cpu() for search in searches[:args.nbest]]) #print(best_sents) #exit() decoded_batch = n_best_sents.numpy() output_sentences = [decoded_batch[row, :] for row in range(decoded_batch.shape[0])] # __QUESTION 6: What is the purpose of this for loop? temp = list() for sent in output_sentences: first_eos = np.where(sent == tgt_dict.eos_idx)[0] if len(first_eos) > 0: temp.append(sent[:first_eos[0]]) else: temp.append(sent) output_sentences = temp #print(output_sentences) # Convert arrays of indices into strings of words output_sentences = [tgt_dict.string(sent) for sent in output_sentences] if args.nbest > 1: adapted_beam_nbest.extend(output_sentences) else: for ii, sent in enumerate(output_sentences): all_hyps[int(sample['id'].data[ii])] = sent #print(adapted_beam_nbest) # Write to file if args.output is not None: with open(args.output, 'w') as out_file: if args.nbest > 1: for sent in adapted_beam_nbest: out_file.write(sent + '\n') else: for sent_id in range(len(all_hyps.keys())): out_file.write(all_hyps[sent_id] + '\n')
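# ---------------------------------------------------------------------------
# The BeamSearch class itself is not shown in this file; the negative scores
# passed to add()/add_final() above make sense if it is backed by Python's
# heapq, which is a min-heap: pushing -log_p makes the *most* probable node pop
# first, and pruning keeps the smallest (best) entries. A minimal sketch of
# that pattern follows; all names here are illustrative, not the assignment's
# actual API.
# ---------------------------------------------------------------------------

import heapq


class TinyBeam:
    def __init__(self, beam_size):
        self.beam_size = beam_size
        self.nodes = []  # min-heap of (score, tie_breaker, payload)

    def add(self, score, payload):
        heapq.heappush(self.nodes, (score, id(payload), payload))

    def prune(self):
        # Keep only the beam_size entries with the smallest (best negated) scores.
        self.nodes = heapq.nsmallest(self.beam_size, self.nodes)
        heapq.heapify(self.nodes)


beam = TinyBeam(beam_size=2)
for logp, word in [(-0.1, 'a'), (-2.3, 'b'), (-0.7, 'c')]:
    beam.add(-logp, word)
beam.prune()
print(sorted(n[0] for n in beam.nodes))  # [0.1, 0.7]: the two most probable survive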
def main(args): """ Main training function. Trains the translation model over the course of several epochs, including dynamic learning rate adjustment and gradient clipping. """ logging.info('Commencing training!') torch.manual_seed(42) np.random.seed(42) utils.init_logging(args) # Load dictionaries src_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.source_lang))) logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format( args.source_lang, len(src_dict))) tgt_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.target_lang))) logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format( args.target_lang, len(tgt_dict))) # Load datasets def load_data(split): return Seq2SeqDataset( src_file=os.path.join(args.data, '{:s}.{:s}'.format(split, args.source_lang)), tgt_file=os.path.join(args.data, '{:s}.{:s}'.format(split, args.target_lang)), src_dict=src_dict, tgt_dict=tgt_dict) train_dataset = load_data( split='train') if not args.train_on_tiny else load_data( split='tiny_train') valid_dataset = load_data(split='valid') # Check CUDA availability device = torch.device("cpu") if torch.cuda.is_available(): device = torch.device("cuda:0") # Build model and optimization criterion model = models.build_model(args, src_dict, tgt_dict).to(device) logging.info('Built a model with {:d} parameters'.format( sum(p.numel() for p in model.parameters()))) criterion = nn.CrossEntropyLoss(ignore_index=src_dict.pad_idx, reduction='sum').to(device) # Instantiate optimizer and learning rate scheduler optimizer = torch.optim.Adam(model.parameters(), args.lr) # Load last checkpoint if one exists state_dict = utils.load_checkpoint(args, model, optimizer) # lr_scheduler last_epoch = state_dict['last_epoch'] if state_dict is not None else -1 # Track validation performance for early stopping bad_epochs = 0 best_validate = float('inf') for epoch in range(last_epoch + 1, args.max_epoch): train_loader = \ torch.utils.data.DataLoader(train_dataset, collate_fn=train_dataset.collater, # num_workers=1, batch_sampler=BatchSampler(train_dataset, args.max_tokens, args.batch_size, 1, 0, shuffle=True, seed=42)) model.train() stats = OrderedDict() stats['loss'] = 0 stats['lr'] = 0 stats['num_tokens'] = 0 stats['batch_size'] = 0 stats['grad_norm'] = 0 stats['clip'] = 0 # Display progress progress_bar = tqdm(train_loader, desc='| Epoch {:03d}'.format(epoch), leave=False, disable=False) # Iterate over the training set for i, sample in enumerate(progress_bar): if len(sample) == 0: continue model.train() for k, v in sample.items(): if isinstance(v, torch.Tensor): sample[k] = v.to(device) ''' ___QUESTION-1-DESCRIBE-F-START___ Describe what the following lines of code do. The following lines of code create a single iteration of training the model. The model is first used to predict a batch, then the cross-entropy loss of the predictions is calculated and backpropagated through the network. It is worth noting that the gradients are normalized in order to ensure we do not encounter any exploding gradients. Finaly the network is updated using the Adam optimizer, before the gradients are reset to zero for the next iteration. 
''' output, _ = model(sample['src_tokens'], sample['src_lengths'], sample['tgt_inputs']) loss = \ criterion(output.view(-1, output.size(-1)), sample['tgt_tokens'].view(-1)) / len(sample['src_lengths']) loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm) optimizer.step() optimizer.zero_grad() '''___QUESTION-1-DESCRIBE-F-END___''' # Update statistics for progress bar total_loss, num_tokens, batch_size = loss.item( ), sample['num_tokens'], len(sample['src_tokens']) stats['loss'] += total_loss * len( sample['src_lengths']) / sample['num_tokens'] stats['lr'] += optimizer.param_groups[0]['lr'] stats['num_tokens'] += num_tokens / len(sample['src_tokens']) stats['batch_size'] += batch_size stats['grad_norm'] += grad_norm stats['clip'] += 1 if grad_norm > args.clip_norm else 0 progress_bar.set_postfix( { key: '{:.4g}'.format(value / (i + 1)) for key, value in stats.items() }, refresh=True) logging.info('Epoch {:03d}: {}'.format( epoch, ' | '.join(key + ' {:.4g}'.format(value / len(progress_bar)) for key, value in stats.items()))) # Calculate validation loss valid_perplexity = validate(args, model, criterion, valid_dataset, epoch, device) model.train() # Save checkpoints if epoch % args.save_interval == 0: utils.save_checkpoint(args, model, optimizer, epoch, valid_perplexity) # lr_scheduler # Check whether to terminate training if valid_perplexity < best_validate: best_validate = valid_perplexity bad_epochs = 0 else: bad_epochs += 1 if bad_epochs >= args.patience: logging.info( 'No validation set improvements observed for {:d} epochs. Early stop!' .format(args.patience)) break
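# ---------------------------------------------------------------------------
# A small numeric illustration of the clipping step described above:
# clip_grad_norm_ rescales gradients in place when their global L2 norm exceeds
# the given max_norm, and returns the norm computed *before* clipping, which is
# what the 'clip' statistic counts against args.clip_norm.
# ---------------------------------------------------------------------------

import torch

w = torch.nn.Parameter(torch.zeros(3))
w.grad = torch.tensor([3.0, 4.0, 0.0])  # global L2 norm is 5.0
norm = torch.nn.utils.clip_grad_norm_([w], max_norm=1.0)
print(float(norm))           # 5.0, the pre-clipping norm
print(w.grad.norm().item())  # ~1.0 after in-place rescaling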