def main(args):
    torch.manual_seed(args.seed)

    # Load arguments from checkpoint
    state_dict = torch.load(args.checkpoint_path,
                            map_location=lambda s, l: default_restore_location(s, 'cpu'))
    args = argparse.Namespace(**{**vars(args), **vars(state_dict['args'])})
    utils.init_logging(args)

    # Load dictionaries
    src_dict = Dictionary.load(os.path.join(args.data, 'dict.{}'.format(args.source_lang)))
    logging.info('Loaded a source dictionary ({}) with {} words'.format(args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(os.path.join(args.data, 'dict.{}'.format(args.target_lang)))
    logging.info('Loaded a target dictionary ({}) with {} words'.format(args.target_lang, len(tgt_dict)))

    # Load dataset
    test_dataset = Seq2SeqDataset(
        src_file=os.path.join(args.data, 'test.{}'.format(args.source_lang)),
        tgt_file=os.path.join(args.data, 'test.{}'.format(args.target_lang)),
        src_dict=src_dict, tgt_dict=tgt_dict)

    test_loader = torch.utils.data.DataLoader(
        test_dataset, num_workers=args.num_workers, collate_fn=test_dataset.collater,
        batch_sampler=BatchSampler(
            test_dataset, args.max_tokens, args.batch_size, args.distributed_world_size,
            args.distributed_rank, shuffle=False, seed=args.seed))

    # Build model and criterion
    model = models.build_model(args, src_dict, tgt_dict).cuda()
    model.load_state_dict(state_dict['model'])
    logging.info('Loaded a model from checkpoint {}'.format(args.checkpoint_path))

    translator = SequenceGenerator(
        model, tgt_dict,
        beam_size=args.beam_size,
        maxlen=args.max_len,
        stop_early=eval(args.stop_early),
        normalize_scores=eval(args.normalize_scores),
        len_penalty=args.len_penalty,
        unk_penalty=args.unk_penalty,
    )

    progress_bar = tqdm(test_loader, desc='| Generation', leave=False)

    for i, sample in enumerate(progress_bar):
        sample = utils.move_to_cuda(sample)
        with torch.no_grad():
            hypos = translator.generate(sample['src_tokens'], sample['src_lengths'])

        for i, (sample_id, hypos) in enumerate(zip(sample['id'].data, hypos)):
            src_tokens = utils.strip_pad(sample['src_tokens'].data[i, :], tgt_dict.pad_idx)
            has_target = sample['tgt_tokens'] is not None
            target_tokens = utils.strip_pad(sample['tgt_tokens'].data[i, :],
                                            tgt_dict.pad_idx).int().cpu() if has_target else None

            src_str = src_dict.string(src_tokens, args.remove_bpe)
            target_str = tgt_dict.string(target_tokens, args.remove_bpe) if has_target else ''

            if not args.quiet:
                print('S-{}\t{}'.format(sample_id, src_str))
                if has_target:
                    print('T-{}\t{}'.format(sample_id, colored(target_str, 'green')))

            # Process top predictions
            for i, hypo in enumerate(hypos[:min(len(hypos), args.num_hypo)]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu(),
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )

                if not args.quiet:
                    print('H-{}\t{}'.format(sample_id, colored(hypo_str, 'blue')))
                    print('P-{}\t{}'.format(sample_id, ' '.join(map(
                        lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist()))))
                    print('A-{}\t{}'.format(sample_id, ' '.join(map(
                        lambda x: str(x.item()), alignment))))

                # Score only the top hypothesis
                if has_target and i == 0:
                    # Convert back to tokens for evaluation with unk replacement and/or without BPE
                    target_tokens = tgt_dict.binarize(target_str, word_tokenize, add_if_not_exist=True)
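# The block below is a minimal, hypothetical CLI sketch added for illustration only.
# It covers just the attributes that main() above reads from `args`
# (checkpoint_path, data, num_workers, beam-search knobs, output options);
# the function name, flag names, and defaults are assumptions, not the
# repository's actual argument parser.
def get_args_sketch():
    """Hypothetical argument wiring for main(); attribute names mirror the code above."""
    parser = argparse.ArgumentParser('Sequence-to-sequence generation (sketch)')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--data', required=True, help='directory with dict.* and test.* files')
    parser.add_argument('--checkpoint-path', required=True, help='checkpoint to load model/args from')
    parser.add_argument('--num-workers', type=int, default=0)
    parser.add_argument('--max-tokens', type=int, default=None)
    parser.add_argument('--batch-size', type=int, default=1)
    parser.add_argument('--distributed-world-size', type=int, default=1)
    parser.add_argument('--distributed-rank', type=int, default=0)
    parser.add_argument('--beam-size', type=int, default=5)
    parser.add_argument('--max-len', type=int, default=200)
    parser.add_argument('--stop-early', default='True')        # consumed via eval() in main()
    parser.add_argument('--normalize-scores', default='True')  # consumed via eval() in main()
    parser.add_argument('--len-penalty', type=float, default=1.0)
    parser.add_argument('--unk-penalty', type=float, default=0.0)
    parser.add_argument('--num-hypo', type=int, default=1)
    parser.add_argument('--remove-bpe', default=None)
    parser.add_argument('--quiet', action='store_true')
    return parser.parse_args()

# Hypothetical entry point using the sketch above:
# if __name__ == '__main__':
#     main(get_args_sketch())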
def main(args): """ Main function. Visualizes attention weight arrays as nifty heat-maps. """ mpl.rc('font', family='VL Gothic') torch.manual_seed(42) state_dict = torch.load( args.checkpoint_path, map_location=lambda s, l: default_restore_location(s, 'cpu')) args = argparse.Namespace(**{**vars(args), **vars(state_dict['args'])}) utils.init_logging(args) # Load dictionaries src_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.source_lang))) print('Loaded a source dictionary ({:s}) with {:d} words'.format( args.source_lang, len(src_dict))) tgt_dict = Dictionary.load( os.path.join(args.data, 'dict.{:s}'.format(args.target_lang))) print('Loaded a target dictionary ({:s}) with {:d} words'.format( args.target_lang, len(tgt_dict))) # Load dataset test_dataset = Seq2SeqDataset( src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)), tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)), src_dict=src_dict, tgt_dict=tgt_dict) vis_loader = torch.utils.data.DataLoader(test_dataset, num_workers=1, collate_fn=test_dataset.collater, batch_sampler=BatchSampler( test_dataset, None, 1, 1, 0, shuffle=False, seed=42)) # Build model and optimization criterion model = models.build_model(args, src_dict, tgt_dict) if args.cuda: model = model.cuda() model.load_state_dict(state_dict['model']) print('Loaded a model from checkpoint {:s}'.format(args.checkpoint_path)) # Store attention weight arrays attn_records = list() # Iterate over the visualization set for i, sample in enumerate(vis_loader): if args.cuda: sample = utils.move_to_cuda(sample) if len(sample) == 0: continue # Perform forward pass output, attn_weights = model(sample['src_tokens'], sample['src_lengths'], sample['tgt_inputs']) attn_records.append((sample, attn_weights)) # Only visualize the first 10 sentence pairs if i >= 10: break # Generate heat-maps and store them at the designated location if not os.path.exists(args.vis_dir): os.makedirs(args.vis_dir) for record_id, record in enumerate(attn_records): # Unpack sample, attn_map = record src_ids = utils.strip_pad(sample['src_tokens'].data, tgt_dict.pad_idx) tgt_ids = utils.strip_pad(sample['tgt_inputs'].data, tgt_dict.pad_idx) # Convert indices into word tokens src_str = src_dict.string(src_ids).split(' ') + ['<EOS>'] tgt_str = tgt_dict.string(tgt_ids).split(' ') + ['<EOS>'] # Generate heat-maps attn_map = attn_map.squeeze(dim=0).transpose(1, 0).cpu().detach().numpy() attn_df = pd.DataFrame(attn_map, index=src_str, columns=tgt_str) sns.heatmap(attn_df, cmap='Blues', linewidths=0.25, vmin=0.0, vmax=1.0, xticklabels=True, yticklabels=True, fmt='.3f') plt.yticks(rotation=0) plot_path = os.path.join(args.vis_dir, 'sentence_{:d}.png'.format(record_id)) plt.savefig(plot_path, dpi='figure', pad_inches=1, bbox_inches='tight') plt.clf() print( 'Done! Visualized attention maps have been saved to the \'{:s}\' directory!' .format(args.vis_dir))