def validate(args, model, model_rev, criterion, valid_dataset, epoch):
    """ Validates model performance on a held-out development set. """
    valid_loader = \
        torch.utils.data.DataLoader(valid_dataset, num_workers=1,
                                    collate_fn=valid_dataset.collater,
                                    batch_sampler=BatchSampler(valid_dataset, args.max_tokens, args.batch_size,
                                                               1, 0, shuffle=False, seed=42))
    model.eval()
    model_rev.eval()
    stats = OrderedDict()
    stats['valid_loss'] = 0
    stats['num_tokens'] = 0
    stats['batch_size'] = 0

    # Iterate over the validation set
    for i, sample in enumerate(valid_loader):
        if args.cuda:
            sample = utils.move_to_cuda(sample)
        if len(sample) == 0:
            continue
        with torch.no_grad():
            # Compute loss
            (output, attn_scores), src_out = model(sample['src_tokens'], sample['src_lengths'], sample['tgt_inputs'])
            src_inputs = sample['src_tokens'].clone()
            src_inputs[0, 1:src_inputs.size(1)] = sample['src_tokens'][0, 0:(src_inputs.size(1) - 1)]
            src_inputs[0, 0] = sample['src_tokens'][0, src_inputs.size(1) - 1]
            tgt_lengths = sample['src_lengths'].clone()  # torch.tensor([sample['tgt_tokens'].size(1)])
            tgt_lengths += sample['tgt_inputs'].size(1) - sample['src_tokens'].size(1)
            (output_rev, attn_scores_rev), src_out_rev = model_rev(sample['tgt_tokens'], tgt_lengths, src_inputs)
            d, d_rev = get_diff(attn_scores, src_out, attn_scores_rev, src_out_rev)
            loss = criterion(output.view(-1, output.size(-1)), sample['tgt_tokens'].view(-1)) + d + \
                criterion(output_rev.view(-1, output_rev.size(-1)), sample['src_tokens'].view(-1)) / len(tgt_lengths) + d_rev

        # Update tracked statistics
        stats['valid_loss'] += loss.item()
        stats['num_tokens'] += sample['num_tokens']
        stats['batch_size'] += len(sample['src_tokens'])

    # Calculate validation perplexity
    stats['valid_loss'] = stats['valid_loss'] / stats['num_tokens']
    perplexity = np.exp(stats['valid_loss'])
    stats['num_tokens'] = stats['num_tokens'] / stats['batch_size']
    logging.info('Epoch {:03d}: {}'.format(
        epoch, ' | '.join(key + ' {:.3g}'.format(value) for key, value in stats.items())) +
        ' | valid_perplexity {:.3g}'.format(perplexity))
    return perplexity
def forward(self, src_tokens, src_lengths):
    """ Performs a single forward pass through the instantiated encoder sub-network. """
    # Embed tokens and apply dropout
    batch_size, src_time_steps = src_tokens.size()
    if self.is_cuda:
        src_tokens = utils.move_to_cuda(src_tokens)
    src_embeddings = self.embedding(src_tokens)
    _src_embeddings = F.dropout(src_embeddings, p=self.dropout_in, training=self.training)

    # Transpose batch: [batch_size, src_time_steps, num_features] -> [src_time_steps, batch_size, num_features]
    src_embeddings = _src_embeddings.transpose(0, 1)

    # Pack embedded tokens into a PackedSequence
    packed_source_embeddings = nn.utils.rnn.pack_padded_sequence(src_embeddings, src_lengths)

    # Pass source input through the recurrent layer(s)
    packed_outputs, (final_hidden_states, final_cell_states) = self.lstm(packed_source_embeddings)

    # Unpack LSTM outputs and apply dropout
    lstm_output, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs, padding_value=0.)
    lstm_output = F.dropout(lstm_output, p=self.dropout_out, training=self.training)
    assert list(lstm_output.size()) == [src_time_steps, batch_size, self.output_dim]  # sanity check

    if self.bidirectional:
        def combine_directions(outs):
            return torch.cat([outs[0:outs.size(0):2], outs[1:outs.size(0):2]], dim=2)
        final_hidden_states = combine_directions(final_hidden_states)
        final_cell_states = combine_directions(final_cell_states)

    # Generate mask zeroing-out padded positions in encoder inputs
    src_mask = src_tokens.eq(self.dictionary.pad_idx)

    return {
        'src_embeddings': _src_embeddings.transpose(0, 1),
        'src_out': (lstm_output, final_hidden_states, final_cell_states),
        'src_mask': src_mask if src_mask.any() else None
    }
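# A quick shape sketch for combine_directions() above, assuming PyTorch's documented layout for
# bidirectional LSTM states ([num_layers * num_directions, batch, hidden], with the two
# directions of each layer interleaved). The sizes below are made up purely for illustration.
def _sketch_combine_directions_shapes():
    import torch
    import torch.nn as nn

    lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2, bidirectional=True)
    dummy_input = torch.randn(5, 3, 8)  # [src_time_steps, batch_size, num_features]
    _, (final_hidden_states, _) = lstm(dummy_input)
    print(final_hidden_states.shape)  # [4, 3, 16] = [2 layers * 2 directions, batch, hidden]

    def combine_directions(outs):
        # Even indices hold forward-direction states, odd indices backward-direction states
        return torch.cat([outs[0:outs.size(0):2], outs[1:outs.size(0):2]], dim=2)

    print(combine_directions(final_hidden_states).shape)  # [2, 3, 32] = [layers, batch, 2 * hidden]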
def validate(args, model, criterion, valid_dataset, epoch):
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, num_workers=args.num_workers, collate_fn=valid_dataset.collater,
        batch_sampler=BatchSampler(valid_dataset, args.max_tokens, args.batch_size,
                                   args.distributed_world_size, args.distributed_rank,
                                   shuffle=True, seed=args.seed))
    model.eval()
    stats = {'valid_loss': 0, 'num_tokens': 0, 'batch_size': 0}
    progress_bar = tqdm(valid_loader, desc='| Epoch {:03d}'.format(epoch), leave=False)

    for i, sample in enumerate(progress_bar):
        sample = utils.move_to_cuda(sample)
        if len(sample) == 0:
            continue
        with torch.no_grad():
            output, attn_scores = model(sample['src_tokens'], sample['src_lengths'],
                                        sample['tgt_inputs'], sample['video_inputs'])
            loss = criterion(output.view(-1, output.size(-1)), sample['tgt_tokens'].view(-1))
        stats['valid_loss'] += loss.item() / sample['num_tokens'] / math.log(2)
        stats['num_tokens'] += sample['num_tokens'] / len(sample['src_tokens'])
        stats['batch_size'] += len(sample['src_tokens'])
        progress_bar.set_postfix(
            {key: '{:.3g}'.format(value / (i + 1)) for key, value in stats.items()},
            refresh=True)

    logging.info('Epoch {:03d}: {}'.format(
        epoch, ' | '.join(key + ' {:.3g}'.format(value / len(progress_bar))
                          for key, value in stats.items())))
    return stats['valid_loss'] / len(progress_bar)
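# The validate() above accumulates the per-token cross-entropy in bits (note the division by
# math.log(2)), so its return value is an average of bits per token rather than a natural-log
# loss. A minimal sketch of converting that figure into perplexity; the helper name below is
# illustrative and not part of the original code.
def perplexity_from_bits(avg_loss_bits):
    # H bits/token  ->  perplexity = 2 ** H (e.g. 3.2 bits/token is a perplexity of roughly 9.2)
    return 2.0 ** avg_loss_bits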
def main(args): """ Main training function. Trains the translation model over the course of several epochs, including dynamic learning rate adjustment and gradient clipping. """ logging.info('Commencing training!') torch.manual_seed(42) utils.init_logging(args) # Load dictionaries src_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.source_lang))) logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format( args.source_lang, len(src_dict))) tgt_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.target_lang))) logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format( args.target_lang, len(tgt_dict))) # Load datasets def load_data(split): return Seq2SeqDataset( src_file=os.path.join(args.data, '{:s}.{:s}'.format(split, args.source_lang)), tgt_file=os.path.join(args.data, '{:s}.{:s}'.format(split, args.target_lang)), src_dict=src_dict, tgt_dict=tgt_dict) valid_dataset = load_data(split='valid') # Build model and optimization criterion model = models.build_model(args, src_dict, tgt_dict) logging.info('Built a model with {:d} parameters'.format( sum(p.numel() for p in model.parameters()))) criterion = nn.CrossEntropyLoss(ignore_index=src_dict.pad_idx, reduction='sum') if args.cuda: model = model.cuda() criterion = criterion.cuda() # Instantiate optimizer and learning rate scheduler optimizer = torch.optim.Adam(model.parameters(), args.lr) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=1) # Load last checkpoint if one exists state_dict = utils.load_checkpoint(args, model, optimizer, scheduler) # lr_scheduler last_epoch = state_dict['last_epoch'] if state_dict is not None else -1 # Track validation performance for early stopping bad_epochs = 0 best_validate = float('inf') for epoch in range(last_epoch + 1, args.max_epoch): ## BPE Dropout # Set the seed to be equal to the epoch # (this way we guarantee same seeds over multiple training runs, but not for each training epoch) seed = epoch bpe_dropout_if_needed(seed, args.bpe_dropout) # Load the BPE (dropout-ed) training data train_dataset = load_data( split='train') if not args.train_on_tiny else load_data( split='tiny_train') train_loader = \ torch.utils.data.DataLoader(train_dataset, num_workers=1, collate_fn=train_dataset.collater, batch_sampler=BatchSampler(train_dataset, args.max_tokens, args.batch_size, 1, 0, shuffle=True, seed=42)) model.train() stats = OrderedDict() stats['loss'] = 0 stats['lr'] = 0 stats['num_tokens'] = 0 stats['batch_size'] = 0 stats['grad_norm'] = 0 stats['clip'] = 0 # Display progress progress_bar = tqdm(train_loader, desc='| Epoch {:03d}'.format(epoch), leave=False, disable=False) # Iterate over the training set for i, sample in enumerate(progress_bar): if args.cuda: sample = utils.move_to_cuda(sample) if len(sample) == 0: continue model.train() output, _ = model(sample['src_tokens'], sample['src_lengths'], sample['tgt_inputs']) loss = \ criterion(output.view(-1, output.size(-1)), sample['tgt_tokens'].view(-1)) / len(sample['src_lengths']) loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm) optimizer.step() optimizer.zero_grad() # Update statistics for progress bar total_loss, num_tokens, batch_size = loss.item( ), sample['num_tokens'], len(sample['src_tokens']) stats['loss'] += total_loss * len( sample['src_lengths']) / sample['num_tokens'] stats['lr'] += optimizer.param_groups[0]['lr'] stats['num_tokens'] += num_tokens / len(sample['src_tokens']) 
stats['batch_size'] += batch_size stats['grad_norm'] += grad_norm stats['clip'] += 1 if grad_norm > args.clip_norm else 0 progress_bar.set_postfix( { key: '{:.4g}'.format(value / (i + 1)) for key, value in stats.items() }, refresh=True) logging.info('Epoch {:03d}: {}'.format( epoch, ' | '.join(key + ' {:.4g}'.format(value / len(progress_bar)) for key, value in stats.items()))) # Calculate validation loss valid_perplexity, valid_loss = validate(args, model, criterion, valid_dataset, epoch) model.train() # Scheduler step if args.adaptive_lr: scheduler.step(valid_loss) # Save checkpoints if epoch % args.save_interval == 0: utils.save_checkpoint(args, model, optimizer, scheduler, epoch, valid_perplexity) # lr_scheduler # Check whether to terminate training if valid_perplexity < best_validate: best_validate = valid_perplexity bad_epochs = 0 else: bad_epochs += 1 if bad_epochs >= args.patience: logging.info( 'No validation set improvements observed for {:d} epochs. Early stop!' .format(args.patience)) break
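# bpe_dropout_if_needed(seed, args.bpe_dropout) is called at the top of every epoch above but is
# not defined in this file. Per the surrounding comments, the intent is to re-sample the subword
# segmentation of the training data once per epoch with a reproducible seed before reloading the
# split. The sketch below is only an assumption about what such a helper could look like; the
# file paths and the toy segmenter are hypothetical, not part of the original project.
import random

def _toy_stochastic_segment(word, dropout_p):
    # Stand-in for a real BPE segmenter with merge dropout: with probability dropout_p one
    # "merge" is skipped, i.e. the word is split at a random position (subword-nmt style '@@ ').
    if len(word) > 1 and random.random() < dropout_p:
        cut = random.randint(1, len(word) - 1)
        return word[:cut] + '@@ ' + word[cut:]
    return word

def bpe_dropout_if_needed(seed, dropout_p, raw_path='train.raw', out_path='train.bpe'):
    if dropout_p <= 0:
        return  # keep the deterministic BPE segmentation as-is
    random.seed(seed)  # seed == epoch: identical across runs, different between epochs
    with open(raw_path, encoding='utf-8') as fin, open(out_path, 'w', encoding='utf-8') as fout:
        for line in fin:
            fout.write(' '.join(_toy_stochastic_segment(w, dropout_p) for w in line.split()) + '\n')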
def main(args): """ Main translation function' """ # Load arguments from checkpoint torch.manual_seed(args.seed) state_dict = torch.load( args.checkpoint_path, map_location=lambda s, l: default_restore_location(s, 'cpu')) args_loaded = argparse.Namespace(**{ **vars(args), **vars(state_dict['args']) }) args_loaded.data = args.data args = args_loaded utils.init_logging(args) # Load dictionaries src_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.source_lang))) logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format( args.source_lang, len(src_dict))) tgt_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.target_lang))) logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format( args.target_lang, len(tgt_dict))) # Load dataset test_dataset = Seq2SeqDataset( src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)), tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)), src_dict=src_dict, tgt_dict=tgt_dict) test_loader = torch.utils.data.DataLoader(test_dataset, num_workers=1, collate_fn=test_dataset.collater, batch_sampler=BatchSampler( test_dataset, 9999999, args.batch_size, 1, 0, shuffle=False, seed=args.seed)) # Build model and criterion model = models.build_model(args, src_dict, tgt_dict) if args.cuda: model = model.cuda() model.eval() model.load_state_dict(state_dict['model']) logging.info('Loaded a model from checkpoint {:s}'.format( args.checkpoint_path)) progress_bar = tqdm(test_loader, desc='| Generation', leave=False) # Iterate over the test set all_hyps = {} for i, sample in enumerate(progress_bar): with torch.no_grad(): # Compute the encoder output encoder_out = model.encoder(sample['src_tokens'], sample['src_lengths']) go_slice = \ torch.ones(sample['src_tokens'].shape[0], 1).fill_(tgt_dict.eos_idx).type_as(sample['src_tokens']) if args.cuda: go_slice = utils.move_to_cuda(go_slice) prev_words = go_slice next_words = None for _ in range(args.max_len): with torch.no_grad(): # Compute the decoder output by repeatedly feeding it the decoded sentence prefix decoder_out, _ = model.decoder(prev_words, encoder_out) # Suppress <UNK>s _, next_candidates = torch.topk(decoder_out, 2, dim=-1) best_candidates = next_candidates[:, :, 0] backoff_candidates = next_candidates[:, :, 1] next_words = torch.where(best_candidates == tgt_dict.unk_idx, backoff_candidates, best_candidates) prev_words = torch.cat([go_slice, next_words], dim=1) # Segment into sentences decoded_batch = next_words.cpu().numpy() output_sentences = [ decoded_batch[row, :] for row in range(decoded_batch.shape[0]) ] assert (len(output_sentences) == len(sample['id'].data)) # Remove padding temp = list() for sent in output_sentences: first_eos = np.where(sent == tgt_dict.eos_idx)[0] if len(first_eos) > 0: temp.append(sent[:first_eos[0]]) else: temp.append([]) output_sentences = temp # Convert arrays of indices into strings of words output_sentences = [tgt_dict.string(sent) for sent in output_sentences] # Save translations assert (len(output_sentences) == len(sample['id'].data)) for ii, sent in enumerate(output_sentences): all_hyps[int(sample['id'].data[ii])] = sent # Write to file if args.output is not None: with open(args.output, 'w') as out_file: for sent_id in range(len(all_hyps.keys())): out_file.write(all_hyps[sent_id] + '\n')
def main(args): # Load arguments from checkpoint torch.manual_seed(args.seed) state_dict = torch.load(args.checkpoint_path, map_location=lambda s, l: default_restore_location(s, 'cpu')) args = argparse.Namespace(**{**vars(args), **vars(state_dict['args'])}) utils.init_logging(args) # Load dictionaries src_dict = Dictionary.load(os.path.join(args.data, 'dict.{}'.format(args.source_lang))) logging.info('Loaded a source dictionary ({}) with {} words'.format(args.source_lang, len(src_dict))) tgt_dict = Dictionary.load(os.path.join(args.data, 'dict.{}'.format(args.target_lang))) logging.info('Loaded a target dictionary ({}) with {} words'.format(args.target_lang, len(tgt_dict))) # Load dataset test_dataset = Seq2SeqDataset( src_file=os.path.join(args.data, 'test.{}'.format(args.source_lang)), tgt_file=os.path.join(args.data, 'test.{}'.format(args.target_lang)), src_dict=src_dict, tgt_dict=tgt_dict) test_loader = torch.utils.data.DataLoader( test_dataset, num_workers=args.num_workers, collate_fn=test_dataset.collater, batch_sampler=BatchSampler( test_dataset, args.max_tokens, args.batch_size, args.distributed_world_size, args.distributed_rank, shuffle=False, seed=args.seed)) # Build model and criterion model = models.build_model(args, src_dict, tgt_dict).cuda() model.load_state_dict(state_dict['model']) logging.info('Loaded a model from checkpoint {}'.format(args.checkpoint_path)) translator = SequenceGenerator( model, tgt_dict, beam_size=args.beam_size, maxlen=args.max_len, stop_early=eval(args.stop_early), normalize_scores=eval(args.normalize_scores), len_penalty=args.len_penalty, unk_penalty=args.unk_penalty, ) progress_bar = tqdm(test_loader, desc='| Generation', leave=False) for i, sample in enumerate(progress_bar): sample = utils.move_to_cuda(sample) with torch.no_grad(): hypos = translator.generate(sample['src_tokens'], sample['src_lengths']) for i, (sample_id, hypos) in enumerate(zip(sample['id'].data, hypos)): src_tokens = utils.strip_pad(sample['src_tokens'].data[i, :], tgt_dict.pad_idx) has_target = sample['tgt_tokens'] is not None target_tokens = utils.strip_pad(sample['tgt_tokens'].data[i, :], tgt_dict.pad_idx).int().cpu() if has_target else None src_str = src_dict.string(src_tokens, args.remove_bpe) target_str = tgt_dict.string(target_tokens, args.remove_bpe) if has_target else '' if not args.quiet: print('S-{}\t{}'.format(sample_id, src_str)) if has_target: print('T-{}\t{}'.format(sample_id, colored(target_str, 'green'))) # Process top predictions for i, hypo in enumerate(hypos[:min(len(hypos), args.num_hypo)]): hypo_tokens, hypo_str, alignment = utils.post_process_prediction( hypo_tokens=hypo['tokens'].int().cpu(), src_str=src_str, alignment=hypo['alignment'].int().cpu(), tgt_dict=tgt_dict, remove_bpe=args.remove_bpe, ) if not args.quiet: print('H-{}\t{}'.format(sample_id, colored(hypo_str, 'blue'))) print('P-{}\t{}'.format(sample_id, ' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist())))) print('A-{}\t{}'.format(sample_id, ' '.join(map(lambda x: str(x.item()), alignment)))) # Score only the top hypothesis if has_target and i == 0: # Convert back to tokens for evaluation with unk replacement and/or without BPE target_tokens = tgt_dict.binarize(target_str, word_tokenize, add_if_not_exist=True)
def main(args): if not torch.cuda.is_available(): raise NotImplementedError('Training on CPU is not supported.') torch.manual_seed(args.seed) torch.cuda.set_device(args.device_id) utils.init_logging(args) if args.distributed_world_size > 1: torch.distributed.init_process_group( backend=args.distributed_backend, init_method=args.distributed_init_method, world_size=args.distributed_world_size, rank=args.distributed_rank) # Load dictionaries src_dict = Dictionary.load( os.path.join(args.data, 'dict.{}'.format(args.source_lang))) logging.info('Loaded a source dictionary ({}) with {} words'.format( args.source_lang, len(src_dict))) tgt_dict = Dictionary.load( os.path.join(args.data, 'dict.{}'.format(args.target_lang))) logging.info('Loaded a target dictionary ({}) with {} words'.format( args.target_lang, len(tgt_dict))) # Load datasets def load_data(split): return Seq2SeqDataset( src_file=os.path.join(args.data, '{}.{}'.format(split, args.source_lang)), tgt_file=os.path.join(args.data, '{}.{}'.format(split, args.target_lang)), src_dict=src_dict, tgt_dict=tgt_dict) train_dataset = load_data(split='train') valid_dataset = load_data(split='valid') # Build model and criterion model = models.build_model(args, src_dict, tgt_dict).cuda() logging.info('Built a model with {} parameters'.format( sum(p.numel() for p in model.parameters()))) criterion = nn.CrossEntropyLoss(ignore_index=src_dict.pad_idx, reduction='sum').cuda() # Build an optimizer and a learning rate schedule optimizer = torch.optim.SGD(model.parameters(), args.lr, args.momentum, weight_decay=args.weight_decay, nesterov=True) lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, patience=0, min_lr=args.min_lr, factor=args.lr_shrink) # Load last checkpoint if one exists state_dict = utils.load_checkpoint(args, model, optimizer, lr_scheduler) last_epoch = state_dict['last_epoch'] if state_dict is not None else -1 for epoch in range(last_epoch + 1, args.max_epoch): train_loader = torch.utils.data.DataLoader( train_dataset, num_workers=args.num_workers, collate_fn=train_dataset.collater, batch_sampler=BatchSampler(train_dataset, args.max_tokens, args.batch_size, args.distributed_world_size, args.distributed_rank, shuffle=True, seed=args.seed)) model.train() stats = { 'loss': 0., 'lr': 0., 'num_tokens': 0., 'batch_size': 0., 'grad_norm': 0., 'clip': 0. 
} progress_bar = tqdm(train_loader, desc='| Epoch {:03d}'.format(epoch), leave=False, disable=(args.distributed_rank != 0)) for i, sample in enumerate(progress_bar): sample = utils.move_to_cuda(sample) if len(sample) == 0: continue # Forward and backward pass output, _ = model(sample['src_tokens'], sample['src_lengths'], sample['tgt_inputs']) loss = criterion(output.view(-1, output.size(-1)), sample['tgt_tokens'].view(-1)) optimizer.zero_grad() loss.backward() # Reduce gradients across all GPUs if args.distributed_world_size > 1: utils.reduce_grads(model.parameters()) total_loss, num_tokens, batch_size = list( map( sum, zip(*utils.all_gather_list([ loss.item(), sample['num_tokens'], len(sample['src_tokens']) ])))) else: total_loss, num_tokens, batch_size = loss.item( ), sample['num_tokens'], len(sample['src_tokens']) # Normalize gradients by number of tokens and perform clipping for param in model.parameters(): param.grad.data.div_(num_tokens) grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm) optimizer.step() # Update statistics for progress bar stats['loss'] += total_loss / num_tokens / math.log(2) stats['lr'] += optimizer.param_groups[0]['lr'] stats['num_tokens'] += num_tokens / len(sample['src_tokens']) stats['batch_size'] += batch_size stats['grad_norm'] += grad_norm stats['clip'] += 1 if grad_norm > args.clip_norm else 0 progress_bar.set_postfix( { key: '{:.4g}'.format(value / (i + 1)) for key, value in stats.items() }, refresh=True) logging.info('Epoch {:03d}: {}'.format( epoch, ' | '.join(key + ' {:.4g}'.format(value / len(progress_bar)) for key, value in stats.items()))) # Adjust learning rate based on validation loss valid_loss = validate(args, model, criterion, valid_dataset, epoch) lr_scheduler.step(valid_loss) # Save checkpoints if epoch % args.save_interval == 0: utils.save_checkpoint(args, model, optimizer, lr_scheduler, epoch, valid_loss) if optimizer.param_groups[0]['lr'] <= args.min_lr: logging.info('Done training!') break
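# utils.reduce_grads() is called in the distributed branch above but its implementation is not
# shown here. A minimal sketch of what such a helper commonly does with torch.distributed: sum
# each parameter's gradient across workers in place. The training loop above then normalizes the
# summed gradients by the globally aggregated token count, so no division by the world size is
# done here. This is an assumption about the helper, not its actual code.
import torch.distributed as dist

def reduce_grads(parameters):
    for param in parameters:
        if param.grad is not None:
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)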
def main(args): """ Main function. Visualizes attention weight arrays as nifty heat-maps. """ mpl.rc('font', family='VL Gothic') torch.manual_seed(42) state_dict = torch.load( args.checkpoint_path, map_location=lambda s, l: default_restore_location(s, 'cpu')) args = argparse.Namespace(**{**vars(args), **vars(state_dict['args'])}) utils.init_logging(args) # Load dictionaries src_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.source_lang))) print('Loaded a source dictionary ({:s}) with {:d} words'.format( args.source_lang, len(src_dict))) tgt_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.target_lang))) print('Loaded a target dictionary ({:s}) with {:d} words'.format( args.target_lang, len(tgt_dict))) # Load dataset test_dataset = Seq2SeqDataset( src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)), tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)), src_dict=src_dict, tgt_dict=tgt_dict) vis_loader = torch.utils.data.DataLoader(test_dataset, num_workers=1, collate_fn=test_dataset.collater, batch_sampler=BatchSampler( test_dataset, None, 1, 1, 0, shuffle=False, seed=42)) # Build model and optimization criterion model = models.build_model(args, src_dict, tgt_dict) if args.cuda: model = model.cuda() model.load_state_dict(state_dict['model']) print('Loaded a model from checkpoint {:s}'.format(args.checkpoint_path)) # Store attention weight arrays attn_records = list() # Iterate over the visualization set for i, sample in enumerate(vis_loader): if args.cuda: sample = utils.move_to_cuda(sample) if len(sample) == 0: continue # Perform forward pass output, attn_weights = model(sample['src_tokens'], sample['src_lengths'], sample['tgt_inputs']) attn_records.append((sample, attn_weights)) # Only visualize the first 10 sentence pairs if i >= 10: break # Generate heat-maps and store them at the designated location if not os.path.exists(args.vis_dir): os.makedirs(args.vis_dir) for record_id, record in enumerate(attn_records): # Unpack sample, attn_map = record src_ids = utils.strip_pad(sample['src_tokens'].data, tgt_dict.pad_idx) tgt_ids = utils.strip_pad(sample['tgt_inputs'].data, tgt_dict.pad_idx) # Convert indices into word tokens src_str = src_dict.string(src_ids).split(' ') + ['<EOS>'] tgt_str = tgt_dict.string(tgt_ids).split(' ') + ['<EOS>'] # Generate heat-maps attn_map = attn_map.squeeze(dim=0).transpose(1, 0).cpu().detach().numpy() attn_df = pd.DataFrame(attn_map, index=src_str, columns=tgt_str) sns.heatmap(attn_df, cmap='Blues', linewidths=0.25, vmin=0.0, vmax=1.0, xticklabels=True, yticklabels=True, fmt='.3f') plt.yticks(rotation=0) plot_path = os.path.join(args.vis_dir, 'sentence_{:d}.png'.format(record_id)) plt.savefig(plot_path, dpi='figure', pad_inches=1, bbox_inches='tight') plt.clf() print( 'Done! Visualized attention maps have been saved to the \'{:s}\' directory!' .format(args.vis_dir))
def main(args): """ Main translation function' """ # Load arguments from checkpoint torch.manual_seed(args.seed) # sets the random seed from pytorch random number generators state_dict = torch.load(args.checkpoint_path, map_location=lambda s, l: default_restore_location(s, 'cpu')) args_loaded = argparse.Namespace(**{**vars(args), **vars(state_dict['args'])}) args_loaded.data = args.data args = args_loaded utils.init_logging(args) # Load dictionaries src_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.source_lang))) logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(args.source_lang, len(src_dict))) tgt_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.target_lang))) logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(args.target_lang, len(tgt_dict))) # Load dataset test_dataset = Seq2SeqDataset( src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)), tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)), src_dict=src_dict, tgt_dict=tgt_dict) test_loader = torch.utils.data.DataLoader(test_dataset, num_workers=1, collate_fn=test_dataset.collater, batch_sampler=BatchSampler(test_dataset, 9999999, args.batch_size, 1, 0, shuffle=False, seed=args.seed)) # Build model and criterion model = models.build_model(args, src_dict, tgt_dict) if args.cuda: model = model.cuda() model.eval() model.load_state_dict(state_dict['model']) logging.info('Loaded a model from checkpoint {:s}'.format(args.checkpoint_path)) progress_bar = tqdm(test_loader, desc='| Generation', leave=False) # Iterate over the test set all_hyps = {} for i, sample in enumerate(progress_bar): # Create a beam search object or every input sentence in batch batch_size = sample['src_tokens'].shape[0] # returns number of rows from sample['src_tokens'] searches = [BeamSearch(args.beam_size, args.max_len - 1, tgt_dict.unk_idx) for i in range(batch_size)] # beam search with beamsize, max seq length and unkindex --> do this B times with torch.no_grad(): # disables gradient calculation # Compute the encoder output encoder_out = model.encoder(sample['src_tokens'], sample['src_lengths']) # __QUESTION 1: What is "go_slice" used for and what do its dimensions represent? # encoder_out = self.encoder(src_tokens, src_lengths) decoder_out = self.decoder(tgt_inputs, encoder_out) go_slice = \ torch.ones(sample['src_tokens'].shape[0], 1).fill_(tgt_dict.eos_idx).type_as(sample['src_tokens']) # vector of ones of length sample['src_tokens'] rows and 1 col filled with eos_idx casted to type sample[ # 'src_tokens'] if args.cuda: go_slice = utils.move_to_cuda(go_slice) # Compute the decoder output at the first time step decoder_out, _ = model.decoder(go_slice, encoder_out) # decoder out = decoder(tgt_inputs, encoder_out) # __QUESTION 2: Why do we keep one top candidate more than the beam size? 
log_probs, next_candidates = torch.topk(torch.log(torch.softmax(decoder_out, dim=2)), args.beam_size + 1, dim=-1) # returns largest k elements (here beam_size+1) of the input torch.log(torch.softmax(decoder_out, # dim=2) in dimension -1 + 1 is taken because the input is given in logarithmic notation # Create number of beam_size beam search nodes for every input sentence for i in range(batch_size): for j in range(args.beam_size): best_candidate = next_candidates[i, :, j] backoff_candidate = next_candidates[i, :, j + 1] best_log_p = log_probs[i, :, j] backoff_log_p = log_probs[i, :, j + 1] next_word = torch.where(best_candidate == tgt_dict.unk_idx, backoff_candidate, best_candidate) log_p = torch.where(best_candidate == tgt_dict.unk_idx, backoff_log_p, best_log_p) log_p = log_p[-1] # Store the encoder_out information for the current input sentence and beam emb = encoder_out['src_embeddings'][:, i, :] lstm_out = encoder_out['src_out'][0][:, i, :] final_hidden = encoder_out['src_out'][1][:, i, :] final_cell = encoder_out['src_out'][2][:, i, :] try: mask = encoder_out['src_mask'][i, :] except TypeError: mask = None node = BeamSearchNode(searches[i], emb, lstm_out, final_hidden, final_cell, mask, torch.cat((go_slice[i], next_word)), log_p, 1) # add normalization here according to paper lp = normalize(node.length) score = node.eval()/lp # Add diverse score = diverse(score, j) # __QUESTION 3: Why do we add the node with a negative score? searches[i].add(-score, node) # Start generating further tokens until max sentence length reached for _ in range(args.max_len - 1): # Get the current nodes to expand nodes = [n[1] for s in searches for n in s.get_current_beams()] if nodes == []: break # All beams ended in EOS # Reconstruct prev_words, encoder_out from current beam search nodes prev_words = torch.stack([node.sequence for node in nodes]) encoder_out["src_embeddings"] = torch.stack([node.emb for node in nodes], dim=1) lstm_out = torch.stack([node.lstm_out for node in nodes], dim=1) final_hidden = torch.stack([node.final_hidden for node in nodes], dim=1) final_cell = torch.stack([node.final_cell for node in nodes], dim=1) encoder_out["src_out"] = (lstm_out, final_hidden, final_cell) try: encoder_out["src_mask"] = torch.stack([node.mask for node in nodes], dim=0) except TypeError: encoder_out["src_mask"] = None with torch.no_grad(): # Compute the decoder output by feeding it the decoded sentence prefix decoder_out, _ = model.decoder(prev_words, encoder_out) # see __QUESTION 2 log_probs, next_candidates = torch.topk(torch.log(torch.softmax(decoder_out, dim=2)), args.beam_size + 1, dim=-1) # Create number of beam_size next nodes for every current node for i in range(log_probs.shape[0]): for j in range(args.beam_size): best_candidate = next_candidates[i, :, j] backoff_candidate = next_candidates[i, :, j + 1] best_log_p = log_probs[i, :, j] backoff_log_p = log_probs[i, :, j + 1] next_word = torch.where(best_candidate == tgt_dict.unk_idx, backoff_candidate, best_candidate) log_p = torch.where(best_candidate == tgt_dict.unk_idx, backoff_log_p, best_log_p) log_p = log_p[-1] next_word = torch.cat((prev_words[i][1:], next_word[-1:])) # Get parent node and beam search object for corresponding sentence node = nodes[i] search = node.search # __QUESTION 4: How are "add" and "add_final" different? What would happen if we did not make this distinction? 
# Store the node as final if EOS is generated if next_word[-1] == tgt_dict.eos_idx: node = BeamSearchNode(search, node.emb, node.lstm_out, node.final_hidden, node.final_cell, node.mask, torch.cat((prev_words[i][0].view([1]), next_word)), node.logp, node.length) # Add length normalization lp = normalize(node.length) score = node.eval()/lp # add diverse score = diverse(score, j) search.add_final(-score, node) # Add the node to current nodes for next iteration else: node = BeamSearchNode(search, node.emb, node.lstm_out, node.final_hidden, node.final_cell, node.mask, torch.cat((prev_words[i][0].view([1]), next_word)), node.logp + log_p, node.length + 1) # Add length normalization lp = normalize(node.length) score = node.eval()/lp # add diverse score = diverse(score, j) search.add(-score, node) # __QUESTION 5: What happens internally when we prune our beams? # How do we know we always maintain the best sequences? for search in searches: search.prune() # Segment into 1 best sentences #best_sents = torch.stack([search.get_best()[1].sequence[1:].cpu() for search in searches]) # segment 3 best oneliner best_sents = torch.stack([n[1].sequence[1:] for s in searches for n in s.get_best()]) # segment into n best sentences #for s in searches: # for n in s.get_best(): # best_sents = torch.stack([n[1].sequence[1:].cpu()]) print('n best sents', best_sents) # concatenates a sequence of tensors, gets the one best here, so we should use the n-best (3 best) here decoded_batch = best_sents.numpy() output_sentences = [decoded_batch[row, :] for row in range(decoded_batch.shape[0])] # __QUESTION 6: What is the purpose of this for loop? temp = list() for sent in output_sentences: first_eos = np.where(sent == tgt_dict.eos_idx)[0] # predicts first eos token if len(first_eos) > 0: # checks if the first eos token is not the beginning (position 0) temp.append(sent[:first_eos[0]]) else: temp.append(sent) output_sentences = temp # Convert arrays of indices into strings of words output_sentences = [tgt_dict.string(sent) for sent in output_sentences] # here: adapt so that it takes the 3-best (aka n-best), % used for no overflow for ii, sent in enumerate(output_sentences): # all_hyps[int(sample['id'].data[ii])] = sent # variant for 3-best all_hyps[(int(sample['id'].data[int(ii / 3)]), int(ii % 3))] = sent # Write to file (write 3 best per sentence together) if args.output is not None: with open(args.output, 'w') as out_file: for sent_id in range(len(all_hyps.keys())): # variant for 1-best # out_file.write(all_hyps[sent_id] + '\n') # variant for 3-best out_file.write(all_hyps[(int(sent_id / 3), int(sent_id % 3))] + '\n')
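# normalize() and diverse() are used by the beam search above but are not defined in this
# excerpt. Sketches of plausible implementations follow, assuming the GNMT-style length penalty
# that appears in comments elsewhere in this code, ((5 + length) ** alpha) / ((5 + 1) ** alpha),
# and a simple rank-based diversity penalty; ALPHA and GAMMA are assumed constants, not values
# taken from the original project.
ALPHA = 0.6   # length-penalty strength (assumed)
GAMMA = 0.1   # per-rank diversity penalty (assumed)

def normalize(length, alpha=ALPHA):
    # Length penalty lp(Y) = (5 + |Y|) ** alpha / (5 + 1) ** alpha; dividing a score by this
    # keeps long hypotheses competitive with short ones.
    return ((5.0 + length) ** alpha) / ((5.0 + 1.0) ** alpha)

def diverse(score, rank, gamma=GAMMA):
    # Penalize lower-ranked expansions of the same parent so sibling beams spread out over
    # different continuations (rank is the candidate's position j within its parent's top-k).
    return score - gamma * rank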
def main(args): """ Main translation function' """ # Load arguments from checkpoint torch.manual_seed(args.seed) state_dict = torch.load( args.checkpoint_path, map_location=lambda s, l: default_restore_location(s, 'cpu')) args_loaded = argparse.Namespace(**{ **vars(args), **vars(state_dict['args']) }) args_loaded.data = args.data args = args_loaded utils.init_logging(args) # Load dictionaries src_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.source_lang))) logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format( args.source_lang, len(src_dict))) tgt_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.target_lang))) logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format( args.target_lang, len(tgt_dict))) # Load dataset test_dataset = Seq2SeqDataset( src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)), tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)), src_dict=src_dict, tgt_dict=tgt_dict) test_loader = torch.utils.data.DataLoader(test_dataset, num_workers=1, collate_fn=test_dataset.collater, batch_sampler=BatchSampler( test_dataset, 9999999, args.batch_size, 1, 0, shuffle=False, seed=args.seed)) # Build model and criterion model = models.build_model(args, src_dict, tgt_dict) if args.cuda: model = model.cuda() model.eval() model.load_state_dict(state_dict['model']) logging.info('Loaded a model from checkpoint {:s}'.format( args.checkpoint_path)) progress_bar = tqdm(test_loader, desc='| Generation', leave=False) # Iterate over the test set all_hyps = {} for i, sample in enumerate(progress_bar): # Create a beam search object or every input sentence in batch batch_size = sample['src_tokens'].shape[0] #print("batch:", batch_size) searches = [ BeamSearch(args.beam_size, args.max_len - 1, tgt_dict.unk_idx) for i in range(batch_size) ] #print(searches) with torch.no_grad(): # Compute the encoder output encoder_out = model.encoder(sample['src_tokens'], sample['src_lengths']) #print(encoder_out) # __QUESTION 1: What is "go_slice" used for and what do its dimensions represent? go_slice = \ torch.ones(sample['src_tokens'].shape[0], 1).fill_(tgt_dict.eos_idx).type_as(sample['src_tokens']) #print(go_slice) #print(go_slice.size()) if args.cuda: go_slice = utils.move_to_cuda(go_slice) # Compute the decoder output at the first time step decoder_out, _ = model.decoder(go_slice, encoder_out) #print(decoder_out) # __QUESTION 2: Why do we keep one top candidate more than the beam size? # ANS: to anticipate the EOS token? 
log_probs, next_candidates = torch.topk(torch.log( torch.softmax(decoder_out, dim=2)), args.beam_size + 1, dim=-1) #print(log_probs) #print(next_candidates) # Create number of beam_size beam search nodes for every input sentence for i in range(batch_size): for j in range(args.beam_size): best_candidate = next_candidates[i, :, j] backoff_candidate = next_candidates[i, :, j + 1] best_log_p = log_probs[i, :, j] backoff_log_p = log_probs[i, :, j + 1] next_word = torch.where(best_candidate == tgt_dict.unk_idx, backoff_candidate, best_candidate) log_p = torch.where(best_candidate == tgt_dict.unk_idx, backoff_log_p, best_log_p) log_p = log_p[-1] # Store the encoder_out information for the current input sentence and beam emb = encoder_out['src_embeddings'][:, i, :] lstm_out = encoder_out['src_out'][0][:, i, :] final_hidden = encoder_out['src_out'][1][:, i, :] final_cell = encoder_out['src_out'][2][:, i, :] try: mask = encoder_out['src_mask'][i, :] except TypeError: mask = None #Node here is equivalent to one hypothesis (predicted word) node = BeamSearchNode(searches[i], emb, lstm_out, final_hidden, final_cell, mask, torch.cat( (go_slice[i], next_word)), log_p, 1) #print(node) #exit() #print(next_word) # __QUESTION 3: Why do we add the node with a negative score? #normalizer = (((5+len(next_candidates))**args.alpha)) / ((5+1)**args.alpha) #length_norm_results = log_probs/normalizer searches[i].add(-(node.eval()), node) #print(node.eval()*(log_p / (((5+len(next_candidates))**args.alpha)) / ((5+1)**args.alpha))) #exit() #print(searches[i].add(-node.eval(), node)) # Start generating further tokens until max sentence length reached for _ in range(args.max_len - 1): # Get the current nodes to expand nodes = [n[1] for s in searches for n in s.get_current_beams()] if nodes == []: break # All beams ended in EOS # Reconstruct prev_words, encoder_out from current beam search nodes prev_words = torch.stack([node.sequence for node in nodes]) encoder_out["src_embeddings"] = torch.stack( [node.emb for node in nodes], dim=1) lstm_out = torch.stack([node.lstm_out for node in nodes], dim=1) final_hidden = torch.stack([node.final_hidden for node in nodes], dim=1) final_cell = torch.stack([node.final_cell for node in nodes], dim=1) encoder_out["src_out"] = (lstm_out, final_hidden, final_cell) try: encoder_out["src_mask"] = torch.stack( [node.mask for node in nodes], dim=0) except TypeError: encoder_out["src_mask"] = None with torch.no_grad(): # Compute the decoder output by feeding it the decoded sentence prefix decoder_out, _ = model.decoder(prev_words, encoder_out) # see __QUESTION 2 log_probs, next_candidates = torch.topk(torch.log( torch.softmax(decoder_out, dim=2)), args.beam_size + 1, dim=-1) #print(decoder_out) #normalizer = (((5+len(next_candidates))**args.alpha)) / ((5+1)**args.alpha) #length_norm_results = log_probs/normalizer #print(length_norm_results) # Create number of beam_size next nodes for every current node # ---- i think I have to add length norm here? 
for i in range(log_probs.shape[0]): for j in range(args.beam_size): best_candidate = next_candidates[i, :, j] #print(best_candidate) backoff_candidate = next_candidates[i, :, j + 1] best_log_p = log_probs[i, :, j] backoff_log_p = log_probs[i, :, j + 1] next_word = torch.where(best_candidate == tgt_dict.unk_idx, backoff_candidate, best_candidate) log_p = torch.where(best_candidate == tgt_dict.unk_idx, backoff_log_p, best_log_p) log_p = log_p[-1] next_word = torch.cat((prev_words[i][1:], next_word[-1:])) # Get parent node and beam search object for corresponding sentence node = nodes[i] search = node.search # __QUESTION 4: How are "add" and "add_final" different? What would happen if we did not make this distinction? # Store the node as final if EOS is generated if next_word[-1] == tgt_dict.eos_idx: node = BeamSearchNode( search, node.emb, node.lstm_out, node.final_hidden, node.final_cell, node.mask, torch.cat((prev_words[i][0].view([1]), next_word)), node.logp, node.length) #print(node) search.add_final( -node.eval() * (log_p / (((5 + len(next_candidates))**args.alpha)) / ((5 + 1)**args.alpha)), node) # Add the node to current nodes for next iteration else: #This is where I'll add the gamma for adapted beam search? node = BeamSearchNode( search, node.emb, node.lstm_out, node.final_hidden, node.final_cell, node.mask, torch.cat((prev_words[i][0].view([1]), next_word)), node.logp + log_p, node.length + 1) search.add(-node.eval(), node) # __QUESTION 5: What happens internally when we prune our beams? # How do we know we always maintain the best sequences? for search in searches: search.prune() #print(searches) # Segment into sentences best_sents = torch.stack( [search.get_best()[1].sequence[1:].cpu() for search in searches]) decoded_batch = best_sents.numpy() output_sentences = [ decoded_batch[row, :] for row in range(decoded_batch.shape[0]) ] # __QUESTION 6: What is the purpose of this for loop? temp = list() for sent in output_sentences: first_eos = np.where(sent == tgt_dict.eos_idx)[0] if len(first_eos) > 0: temp.append(sent[:first_eos[0]]) else: temp.append(sent) output_sentences = temp #print(output_sentences) # Convert arrays of indices into strings of words output_sentences = [tgt_dict.string(sent) for sent in output_sentences] for ii, sent in enumerate(output_sentences): all_hyps[int(sample['id'].data[ii])] = sent # Write to file if args.output is not None: with open(args.output, 'w') as out_file: for sent_id in range(len(all_hyps.keys())): out_file.write(all_hyps[sent_id] + '\n')
def main(args): """ Main training function. Trains the translation model over the course of several epochs, including dynamic learning rate adjustment and gradient clipping. """ logging.info('Commencing training!') torch.manual_seed(42) utils.init_logging(args) # Load dictionaries src_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.source_lang))) logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format( args.source_lang, len(src_dict))) tgt_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.target_lang))) logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format( args.target_lang, len(tgt_dict))) # Load datasets def load_data(split): return Seq2SeqDataset( src_file=os.path.join(args.data, '{:s}.{:s}'.format(split, args.source_lang)), tgt_file=os.path.join(args.data, '{:s}.{:s}'.format(split, args.target_lang)), src_dict=src_dict, tgt_dict=tgt_dict) train_dataset = load_data( split='train') if not args.train_on_tiny else load_data( split='tiny_train') valid_dataset = load_data(split='valid') # Build model and optimization criterion model = models.build_model(args, src_dict, tgt_dict) logging.info('Built a model with {:d} parameters'.format( sum(p.numel() for p in model.parameters()))) criterion = nn.CrossEntropyLoss(ignore_index=src_dict.pad_idx, reduction='sum') if args.cuda: model = model.cuda() criterion = criterion.cuda() # Instantiate optimizer and learning rate scheduler optimizer = torch.optim.Adam(model.parameters(), args.lr) # Load last checkpoint if one exists state_dict = utils.load_checkpoint(args, model, optimizer) # lr_scheduler last_epoch = state_dict['last_epoch'] if state_dict is not None else -1 # Track validation performance for early stopping bad_epochs = 0 best_validate = float('inf') for epoch in range(last_epoch + 1, args.max_epoch): train_loader = \ torch.utils.data.DataLoader(train_dataset, num_workers=1, collate_fn=train_dataset.collater, batch_sampler=BatchSampler(train_dataset, args.max_tokens, args.batch_size, 1, 0, shuffle=True, seed=42)) model.train() stats = OrderedDict() stats['loss'] = 0 stats['lr'] = 0 stats['num_tokens'] = 0 stats['batch_size'] = 0 stats['grad_norm'] = 0 stats['clip'] = 0 # Display progress progress_bar = tqdm(train_loader, desc='| Epoch {:03d}'.format(epoch), leave=False, disable=False) # Iterate over the training set for i, sample in enumerate(progress_bar): if args.cuda: sample = utils.move_to_cuda(sample) if len(sample) == 0: continue model.train() ''' ___QUESTION-1-DESCRIBE-F-START___ Describe what the following lines of code do. ''' ''' First, the encoder is constructed. Then the loss is computed using cross entropy. Then the error is propagated backwards through the network. After that, the gradient of the loss function is calculated using pytorch. Then the weights are updated based on the current gradient. Finally, the gradient of all model parameters is set to 0. 
''' output, _ = model(sample['src_tokens'], sample['src_lengths'], sample['tgt_inputs']) loss = \ criterion(output.view(-1, output.size(-1)), sample['tgt_tokens'].view(-1)) / len(sample['src_lengths']) loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm) optimizer.step() optimizer.zero_grad() '''___QUESTION-1-DESCRIBE-F-END___''' # Update statistics for progress bar total_loss, num_tokens, batch_size = loss.item( ), sample['num_tokens'], len(sample['src_tokens']) stats['loss'] += total_loss * len( sample['src_lengths']) / sample['num_tokens'] stats['lr'] += optimizer.param_groups[0]['lr'] stats['num_tokens'] += num_tokens / len(sample['src_tokens']) stats['batch_size'] += batch_size stats['grad_norm'] += grad_norm stats['clip'] += 1 if grad_norm > args.clip_norm else 0 progress_bar.set_postfix( { key: '{:.4g}'.format(value / (i + 1)) for key, value in stats.items() }, refresh=True) logging.info('Epoch {:03d}: {}'.format( epoch, ' | '.join(key + ' {:.4g}'.format(value / len(progress_bar)) for key, value in stats.items()))) # Calculate validation loss valid_perplexity = validate(args, model, criterion, valid_dataset, epoch) model.train() # Save checkpoints if epoch % args.save_interval == 0: utils.save_checkpoint(args, model, optimizer, epoch, valid_perplexity) # lr_scheduler # Check whether to terminate training if valid_perplexity < best_validate: best_validate = valid_perplexity bad_epochs = 0 else: bad_epochs += 1 if bad_epochs >= args.patience: logging.info( 'No validation set improvements observed for {:d} epochs. Early stop!' .format(args.patience)) break
def forward(self, tgt_inputs, encoder_out, incremental_state=None): """ Performs the forward pass through the instantiated model. """ # Optionally, feed decoder input token-by-token if incremental_state is not None: tgt_inputs = tgt_inputs[:, -1:] # __LEXICAL: Following code is to assist with the LEXICAL MODEL implementation # Recover encoder input src_embeddings = encoder_out['src_embeddings'] src_out, src_hidden_states, src_cell_states = encoder_out['src_out'] src_mask = encoder_out['src_mask'] src_time_steps = src_out.size(0) # Embed target tokens and apply dropout batch_size, tgt_time_steps = tgt_inputs.size() tgt_embeddings = self.embedding(tgt_inputs) tgt_embeddings = F.dropout(tgt_embeddings, p=self.dropout_in, training=self.training) # Transpose batch: [batch_size, tgt_time_steps, num_features] -> [tgt_time_steps, batch_size, num_features] tgt_embeddings = tgt_embeddings.transpose(0, 1) # Initialize previous states (or retrieve from cache during incremental generation) cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state') if cached_state is not None: tgt_hidden_states, tgt_cell_states, input_feed = cached_state else: tgt_hidden_states = [torch.zeros(tgt_inputs.size()[0], self.hidden_size) for i in range(len(self.layers))] tgt_cell_states = [torch.zeros(tgt_inputs.size()[0], self.hidden_size) for i in range(len(self.layers))] input_feed = tgt_embeddings.data.new(batch_size, self.hidden_size).zero_() if self.layers[0].weight_ih.is_cuda: tgt_hidden_states = utils.move_to_cuda(tgt_hidden_states) tgt_cell_states = utils.move_to_cuda(tgt_cell_states) # Initialize attention output node attn_weights = tgt_embeddings.data.new(batch_size, tgt_time_steps, src_time_steps).zero_() rnn_outputs = [] # __LEXICAL: Following code is to assist with the LEXICAL MODEL implementation # Cache lexical context vectors per translation time-step lexical_contexts = [] for j in range(tgt_time_steps): # Concatenate the current token embedding with output from previous time step (i.e. 
'input feeding') lstm_input = torch.cat([tgt_embeddings[j, :, :], input_feed], dim=1) for layer_id, rnn_layer in enumerate(self.layers): # Pass target input through the recurrent layer(s) tgt_hidden_states[layer_id], tgt_cell_states[layer_id] = \ rnn_layer(lstm_input, (tgt_hidden_states[layer_id], tgt_cell_states[layer_id])) # Current hidden state becomes input to the subsequent layer; apply dropout lstm_input = F.dropout(tgt_hidden_states[layer_id], p=self.dropout_out, training=self.training) if self.attention is None: input_feed = tgt_hidden_states[-1] else: input_feed, step_attn_weights = self.attention(tgt_hidden_states[-1], src_out, src_mask) attn_weights[:, j, :] = step_attn_weights if self.use_lexical_model: # __LEXICAL: Compute and collect LEXICAL MODEL context vectors here # TODO: --------------------------------------------------------------------- CUT pass # TODO: --------------------------------------------------------------------- /CUT input_feed = F.dropout(input_feed, p=self.dropout_out, training=self.training) rnn_outputs.append(input_feed) # Cache previous states (only used during incremental, auto-regressive generation) utils.set_incremental_state( self, incremental_state, 'cached_state', (tgt_hidden_states, tgt_cell_states, input_feed)) # Collect outputs across time steps decoder_output = torch.cat(rnn_outputs, dim=0).view(tgt_time_steps, batch_size, self.hidden_size) # Transpose batch back: [tgt_time_steps, batch_size, num_features] -> [batch_size, tgt_time_steps, num_features] decoder_output = decoder_output.transpose(0, 1) # Final projection decoder_output = self.final_projection(decoder_output) if self.use_lexical_model: # __LEXICAL: Incorporate the LEXICAL MODEL into the prediction of target tokens here pass # TODO: --------------------------------------------------------------------- /CUT return decoder_output, attn_weights
def main(args): """ Main training function. Trains the translation model over the course of several epochs, including dynamic learning rate adjustment and gradient clipping. """ logging.info('Commencing training!') torch.manual_seed(42) utils.init_logging(args) # Load dictionaries src_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.source_lang))) logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format( args.source_lang, len(src_dict))) tgt_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.target_lang))) logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format( args.target_lang, len(tgt_dict))) # Load datasets def load_data(split): return Seq2SeqDataset( src_file=os.path.join(args.data, '{:s}.{:s}'.format(split, args.source_lang)), tgt_file=os.path.join(args.data, '{:s}.{:s}'.format(split, args.target_lang)), src_dict=src_dict, tgt_dict=tgt_dict) train_dataset = load_data( split='train') if not args.train_on_tiny else load_data( split='tiny_train') valid_dataset = load_data(split='valid') # Build model and optimization criterion model = models.build_model(args, src_dict, tgt_dict) model_rev = models.build_model(args, tgt_dict, src_dict) logging.info('Built a model with {:d} parameters'.format( sum(p.numel() for p in model.parameters()))) criterion = nn.CrossEntropyLoss(ignore_index=src_dict.pad_idx, reduction='sum') criterion2 = nn.MSELoss(reduction='sum') if args.cuda: model = model.cuda() model_rev = model_rev.cuda() criterion = criterion.cuda() # Instantiate optimizer and learning rate scheduler optimizer = torch.optim.Adam(model.parameters(), args.lr) # Load last checkpoint if one exists state_dict = utils.load_checkpoint(args, model, optimizer) # lr_scheduler utils.load_checkpoint_rev(args, model_rev, optimizer) # lr_scheduler last_epoch = state_dict['last_epoch'] if state_dict is not None else -1 # Track validation performance for early stopping bad_epochs = 0 best_validate = float('inf') for epoch in range(last_epoch + 1, args.max_epoch): train_loader = \ torch.utils.data.DataLoader(train_dataset, num_workers=1, collate_fn=train_dataset.collater, batch_sampler=BatchSampler(train_dataset, args.max_tokens, args.batch_size, 1, 0, shuffle=True, seed=42)) model.train() model_rev.train() stats = OrderedDict() stats['loss'] = 0 stats['lr'] = 0 stats['num_tokens'] = 0 stats['batch_size'] = 0 stats['grad_norm'] = 0 stats['clip'] = 0 # Display progress progress_bar = tqdm(train_loader, desc='| Epoch {:03d}'.format(epoch), leave=False, disable=False) # Iterate over the training set for i, sample in enumerate(progress_bar): if args.cuda: sample = utils.move_to_cuda(sample) if len(sample) == 0: continue model.train() (output, att), src_out = model(sample['src_tokens'], sample['src_lengths'], sample['tgt_inputs']) # print(sample['src_lengths']) # print(sample['tgt_inputs'].size()) # print(sample['src_tokens'].size()) src_inputs = sample['src_tokens'].clone() src_inputs[0, 1:src_inputs.size(1)] = sample['src_tokens'][0, 0:( src_inputs.size(1) - 1)] src_inputs[0, 0] = sample['src_tokens'][0, src_inputs.size(1) - 1] tgt_lengths = sample['src_lengths'].clone( ) #torch.tensor([sample['tgt_tokens'].size(1)]) tgt_lengths += sample['tgt_inputs'].size( 1) - sample['src_tokens'].size(1) # print(tgt_lengths) # print(sample['num_tokens']) # if args.cuda: # tgt_lengths = tgt_lengths.cuda() (output_rev, att_rev), src_out_rev = model_rev(sample['tgt_tokens'], tgt_lengths, src_inputs) # notice that those are without masks already # 
print(sample['tgt_tokens'].view(-1)) d, d_rev = get_diff(att, src_out, att_rev, src_out_rev) # print(sample['src_tokens'].size()) # print(sample['tgt_inputs'].size()) # print(att.size()) # print(src_out.size()) # print(acontext.size()) # print(src_out_rev.size()) # # print(sample['tgt_inputs'].dtype) # # print(sample['src_lengths']) # # print(sample['src_tokens']) # # print('output %s' % str(output.size())) # # print(att) # # print(len(sample['src_lengths'])) # print(d) # print(d_rev) # print(criterion(output.view(-1, output.size(-1)), sample['tgt_tokens'].view(-1)) / len(sample['src_lengths'])) # print(att2) # output=output.cpu().detach().numpy() # output=torch.from_numpy(output).cuda() # output_rev=output_rev.cpu().detach().numpy() # output_rev=torch.from_numpy(output_rev).cuda() loss = \ criterion(output.view(-1, output.size(-1)), sample['tgt_tokens'].view(-1)) / len(sample['src_lengths']) + d +\ criterion(output_rev.view(-1, output_rev.size(-1)), sample['src_tokens'].view(-1)) / len(tgt_lengths) +d_rev loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm) # loss_rev = \ # criterion(output_rev.view(-1, output_rev.size(-1)), sample['src_tokens'].view(-1)) / len(tgt_lengths) # loss_rev.backward() # grad_norm_rev = torch.nn.utils.clip_grad_norm_(model_rev.parameters(), args.clip_norm) optimizer.step() optimizer.zero_grad() # Update statistics for progress bar total_loss, num_tokens, batch_size = ( loss - d - d_rev).item(), sample['num_tokens'], len( sample['src_tokens']) stats['loss'] += total_loss * len( sample['src_lengths']) / sample['num_tokens'] # stats['loss_rev'] += loss_rev.item() * len(sample['src_lengths']) / sample['src_tokens'].size(0) / sample['src_tokens'].size(1) stats['lr'] += optimizer.param_groups[0]['lr'] stats['num_tokens'] += num_tokens / len(sample['src_tokens']) stats['batch_size'] += batch_size stats['grad_norm'] += grad_norm stats['clip'] += 1 if grad_norm > args.clip_norm else 0 progress_bar.set_postfix( { key: '{:.4g}'.format(value / (i + 1)) for key, value in stats.items() }, refresh=True) logging.info('Epoch {:03d}: {}'.format( epoch, ' | '.join(key + ' {:.4g}'.format(value / len(progress_bar)) for key, value in stats.items()))) # Calculate validation loss valid_perplexity = validate(args, model, model_rev, criterion, valid_dataset, epoch) model.train() model_rev.train() # Save checkpoints if epoch % args.save_interval == 0: utils.save_checkpoint(args, model, model_rev, optimizer, epoch, valid_perplexity) # lr_scheduler # Check whether to terminate training if valid_perplexity < best_validate: best_validate = valid_perplexity bad_epochs = 0 else: bad_epochs += 1 if bad_epochs >= args.patience: logging.info( 'No validation set improvements observed for {:d} epochs. Early stop!' .format(args.patience)) break
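# get_diff() couples the forward and reverse models in the loss above, but its implementation is
# not part of this excerpt. One plausible form such an agreement term can take is attention
# agreement: the forward model's attention matrix should roughly match the transpose of the
# reverse model's. The sketch below is an assumption in that spirit (it ignores src_out and
# src_out_rev, which the real helper also receives) and is not the project's actual code.
import torch.nn as nn

def get_diff(attn_scores, src_out, attn_scores_rev, src_out_rev):
    # attn_scores:     [batch, tgt_time_steps, src_time_steps] from the forward model
    # attn_scores_rev: [batch, src_time_steps, tgt_time_steps] from the reverse model
    agreement = nn.MSELoss(reduction='sum')
    # Detach the "target" attention in each term so each model is pulled toward the other's
    # (fixed) alignment rather than both collapsing toward a degenerate compromise.
    d = agreement(attn_scores, attn_scores_rev.transpose(1, 2).detach())
    d_rev = agreement(attn_scores_rev, attn_scores.transpose(1, 2).detach())
    return d, d_rev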