def run_viz(args):
    """Parse a few short test sentences and visualize self-attention maps.

    Loads a self-attentive chart parser from ``args.model_path_base``,
    instruments every ``MultiHeadAttention.forward`` call to capture the
    attention tensors, parses up to one short sentence from
    ``args.viz_path``, and hands each sentence's attention maps to
    ``viz.viz_attention``.

    Args:
        args: namespace with ``model_path_base`` (``.pt`` checkpoint path),
            ``viz_path`` (treebank file), and ``eval_batch_size``.

    Side effects: prints progress; temporarily monkey-patches
    ``parse_nk.MultiHeadAttention.forward`` (restored on exit).
    """
    assert args.model_path_base.endswith(".pt"), "Only pytorch savefiles supported"

    print("Loading test trees from {}...".format(args.viz_path))
    viz_treebank = trees.load_trees(args.viz_path)
    print("Loaded {:,} test examples.".format(len(viz_treebank)))

    print("Loading model from {}...".format(args.model_path_base))
    info = torch_load(args.model_path_base)
    assert 'hparams' in info['spec'], "Only self-attentive models are supported"
    parser = parse_nk.NKChartParser.from_spec(info['spec'], info['state_dict'])

    from viz import viz_attention

    # Intercept attention: each forward call stores its attention weights
    # under 'attns<k>' and bumps the 'stack' counter. The closure reads the
    # *current* binding of stowed_values, which is reset per batch below.
    stowed_values = {}
    orig_multihead_forward = parse_nk.MultiHeadAttention.forward

    def wrapped_multihead_forward(self, inp, batch_idxs, **kwargs):
        res, attns = orig_multihead_forward(self, inp, batch_idxs, **kwargs)
        stowed_values[f'attns{stowed_values["stack"]}'] = attns.cpu().data.numpy()
        stowed_values['stack'] += 1
        return res, attns

    parse_nk.MultiHeadAttention.forward = wrapped_multihead_forward
    try:
        # Select the sentences we will actually be visualizing
        max_len_viz = 15
        if max_len_viz > 0:
            viz_treebank = [tree for tree in viz_treebank
                            if len(list(tree.leaves())) <= max_len_viz]
        viz_treebank = viz_treebank[:1]

        print("Parsing viz sentences...")
        for start_index in range(0, len(viz_treebank), args.eval_batch_size):
            subbatch_trees = viz_treebank[start_index:start_index + args.eval_batch_size]
            subbatch_sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()]
                                  for tree in subbatch_trees]
            stowed_values = dict(stack=0)
            predicted, _ = parser.parse_batch(subbatch_sentences)
            del _
            predicted = [p.convert() for p in predicted]
            stowed_values['predicted'] = predicted

            for snum, sentence in enumerate(subbatch_sentences):
                sentence_words = [tokens.START] + [x[1] for x in sentence] + [tokens.STOP]
                for stacknum in range(stowed_values['stack']):
                    attns_padded = stowed_values[f'attns{stacknum}']
                    # Sentences are interleaved in the padded batch, so take
                    # every len(subbatch)-th row starting at this sentence's
                    # offset, trimmed to the real (unpadded) sentence length.
                    attns = attns_padded[snum::len(subbatch_sentences),
                                         :len(sentence_words), :len(sentence_words)]
                    viz_attention(sentence_words, attns)
    finally:
        # Fix: the original never undid the monkey-patch, leaking the
        # instrumentation into any later use of the model in this process.
        parse_nk.MultiHeadAttention.forward = orig_multihead_forward
def run_viz(args):
    """Parse test sentences with a speech-aware parser and dump attention maps.

    Loads a ``SpeechParser`` from ``args.model_path_base``, optionally loads
    per-utterance speech feature pickles, instruments
    ``parse_model.MultiHeadAttention.forward`` to capture attention tensors,
    parses every test sentence of length <= 40, and pickles the collected
    ``viz_attention`` outputs to ``args.viz_out``.

    Args:
        args: namespace with ``model_path_base``, ``viz_path``,
            ``viz_sent_id_path``, ``feature_path``, ``viz_prefix``,
            ``eval_batch_size``, and ``viz_out``.

    NOTE(review): this redefines ``run_viz`` declared earlier in the file;
    only this definition survives at import time — confirm that is intended.
    """
    assert args.model_path_base.endswith(".pt"), "Only pytorch savefiles supported"

    print("Loading test trees from {}...".format(args.viz_path))
    viz_treebank, viz_sent_ids = trees.load_trees_with_idx(args.viz_path,
                                                           args.viz_sent_id_path)
    print("Loaded {:,} test examples.".format(len(viz_treebank)))

    print("Loading model from {}...".format(args.model_path_base))
    info = torch_load(args.model_path_base)
    assert 'hparams' in info['spec'], "Only self-attentive models are supported"
    parser = parse_model.SpeechParser.from_spec(info['spec'], info['state_dict'])

    # Load one pickle of precomputed features per feature type, keyed by type.
    viz_feat_dict = {}
    if info['spec']['speech_features'] is not None:
        speech_features = info['spec']['speech_features']
        print("Loading speech features for test set...")
        for feat_type in speech_features:
            print("\t", feat_type)
            feat_path = os.path.join(args.feature_path,
                                     args.viz_prefix + '_' + feat_type + '.pickle')
            # latin1 encoding: features were pickled under Python 2.
            with open(feat_path, 'rb') as f:
                viz_feat_dict[feat_type] = pickle.load(f, encoding='latin1')

    from viz import viz_attention

    # Intercept attention: each forward call stores its attention weights
    # under 'attns<k>' and bumps the 'stack' counter. The closure reads the
    # *current* binding of stowed_values, which is reset per batch below.
    stowed_values = {}
    orig_multihead_forward = parse_model.MultiHeadAttention.forward

    def wrapped_multihead_forward(self, inp, batch_idxs, **kwargs):
        res, attns = orig_multihead_forward(self, inp, batch_idxs, **kwargs)
        stowed_values['attns{}'.format(stowed_values["stack"])] = attns.cpu().data.numpy()
        stowed_values['stack'] += 1
        return res, attns

    parse_model.MultiHeadAttention.forward = wrapped_multihead_forward
    try:
        # Select the sentences we will actually be visualizing
        max_len_viz = 40
        if max_len_viz > 0:
            viz_treebank = [tree for tree in viz_treebank
                            if len(list(tree.leaves())) <= max_len_viz]

        print("Parsing viz sentences...")
        viz_data = []
        for start_index in range(0, len(viz_treebank), args.eval_batch_size):
            subbatch_trees = viz_treebank[start_index:start_index + args.eval_batch_size]
            subbatch_sent_ids = viz_sent_ids[start_index:start_index + args.eval_batch_size]
            subbatch_sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()]
                                  for tree in subbatch_trees]
            subbatch_features = load_features(subbatch_sent_ids, viz_feat_dict)
            stowed_values = dict(stack=0)
            predicted, _ = parser.parse_batch(subbatch_sentences,
                                              subbatch_sent_ids, subbatch_features)
            del _
            predicted = [p.convert() for p in predicted]
            stowed_values['predicted'] = predicted

            for snum, sentence in enumerate(subbatch_sentences):
                sentence_words = [tokens.START] + [x[1] for x in sentence] + [tokens.STOP]
                for stacknum in range(stowed_values['stack']):
                    attns_padded = stowed_values['attns{}'.format(stacknum)]
                    # Sentences are interleaved in the padded batch, so take
                    # every len(subbatch)-th row starting at this sentence's
                    # offset, trimmed to the real (unpadded) sentence length.
                    attns = attns_padded[snum::len(subbatch_sentences),
                                         :len(sentence_words), :len(sentence_words)]
                    viz_data.append(viz_attention(sentence_words, attns))
    finally:
        # Fix: the original never undid the monkey-patch, leaking the
        # instrumentation into any later use of the model in this process.
        parse_model.MultiHeadAttention.forward = orig_multihead_forward

    # Fix: use a context manager so the handle is closed even if dump fails.
    with open(args.viz_out, 'wb') as outf:
        pickle.dump(viz_data, outf)