def dev_predict(task_path, src_str, is_plot=True):
    """Predict on one sentence in dev mode and expose the model's internals.

    Loads the checkpoint stored at `task_path`, switches its model to dev
    mode, predicts on the whitespace-split `src_str`, and converts every value
    the model reported under ``other["test"]`` to a numpy array.

    Args:
        task_path (str): path to the saved task directory containing, amongst
            other, the model.
        src_str (str): source sentence that will be used to predict.
        is_plot (bool, optional): whether to also plot the attention pattern.

    Returns:
        out_words (list): decoder predictions.
        other (dictionary): additional information used for predictions.
        test (dictionary): information only stored in dev mode; each value is
            a squeezed numpy array truncated to ``other["length"][0]`` entries.
    """
    checkpoint = Checkpoint.load(task_path)
    checkpoint.model.set_dev_mode()
    predictor = Predictor(checkpoint.model, checkpoint.input_vocab,
                          checkpoint.output_vocab)
    out_words, other = predictor.predict(src_str.split())

    test = {}
    for name, value in other["test"].items():
        # A value is either a single tensor or a sequence of tensors to join.
        if not isinstance(value, torch.Tensor):
            value = torch.cat(value)
        as_array = value.detach().cpu().numpy().squeeze()
        test[name] = as_array[:other["length"][0]]

    if is_plot:
        AttentionVisualizer(task_path)(src_str)

    return out_words, other, test
def test_predict(self):
    """Every token predicted for a toy sentence belongs to the output vocab."""
    predictor = Predictor(self.seq2seq, self.dataset.input_vocab,
                          self.dataset.output_vocab)
    src_seq = ["I", "am", "fat"]
    tgt_seq = predictor.predict(src_seq)
    for tok in tgt_seq:
        # assertIn names the offending token on failure; assertTrue only
        # reports "False is not true".
        self.assertIn(tok, self.dataset.output_vocab._token2index)
class TestPredictor(unittest.TestCase):
    """Smoke test: an untrained LSTM seq2seq only predicts in-vocab tokens."""

    @classmethod
    def setUpClass(cls):
        # setUpClass receives the class, not an instance: the predictor
        # becomes a class attribute shared by every test method.
        test_path = os.path.dirname(os.path.realpath(__file__))
        src = SourceField()
        trg = TargetField()
        dataset = torchtext.data.TabularDataset(
            path=os.path.join(test_path, 'data/eng-fra.txt'),
            format='tsv',
            fields=[('src', src), ('trg', trg)],
        )
        src.build_vocab(dataset)
        trg.build_vocab(dataset)

        encoder = EncoderRNN(len(src.vocab), 10, 10, rnn_cell='lstm')
        decoder = DecoderRNN(len(trg.vocab), 10, 10, trg.sos_id, trg.eos_id,
                             rnn_cell='lstm')
        seq2seq = Seq2seq(encoder, decoder)
        cls.predictor = Predictor(seq2seq, src.vocab, trg.vocab)

    def test_predict(self):
        src_seq = "I am fat"
        tgt_seq = self.predictor.predict(src_seq.split(' '))
        for tok in tgt_seq:
            # assertIn reports the offending token on failure.
            self.assertIn(tok, self.predictor.tgt_vocab.stoi)
def predict_with_checkpoint(checkpoint_path, sequence, hierarchial=False,
                            remote=None, word_vectors=None, beam_size=5):
    """Load a checkpoint and decode `sequence` with a top-k (beam) decoder.

    Args:
        checkpoint_path: path passed to ``Checkpoint.load``.
        sequence: for the flat predictor, a whitespace-separated string; for
            the hierarchical predictor, an iterable of strings, each turned
            into its whitespace-separated tokens joined with ``'|'``.
        hierarchial: use ``HierarchialPredictor`` instead of ``Predictor``.
        remote: not used by this function; kept for backward compatibility.
        word_vectors: optional pretrained vectors. When given, they are
            wrapped in ``Word2Vectors`` for both vocabularies, using
            ``word_vectors.dim_size`` as the dimension.
        beam_size: number of beams for ``TopKDecoder`` (default 5, the value
            that used to be hard-coded).

    Returns:
        str: the predicted tokens joined with single spaces.
    """
    checkpoint = Checkpoint.load(checkpoint_path)
    seq2seq = checkpoint.model
    input_vocab = checkpoint.input_vocab
    output_vocab = checkpoint.output_vocab

    # Clear any word vectors carried by the loaded model before (optionally)
    # attaching the caller's ones.
    seq2seq.encoder.word_vectors, seq2seq.decoder.word_vectors = None, None
    # `is not None` instead of `!= None`: the latter dispatches to the
    # object's __ne__, which vector containers may override.
    if word_vectors is not None:
        input_vects = Word2Vectors(input_vocab, word_vectors,
                                   word_vectors.dim_size)
        output_vects = Word2Vectors(output_vocab, word_vectors,
                                    word_vectors.dim_size)
        seq2seq.encoder.word_vectors = input_vects
        seq2seq.decoder.word_vectors = output_vects

    seq2seq.decoder = TopKDecoder(seq2seq.decoder, beam_size)

    if not hierarchial:
        predictor = Predictor(seq2seq, input_vocab, output_vocab)
        seq = sequence.strip().split()
    else:
        predictor = HierarchialPredictor(seq2seq, input_vocab, output_vocab)
        seq = ['|'.join(x.split()) for x in sequence]

    return ' '.join(predictor.predict(seq))
class eval_tool:
    """Load a checkpoint once and run predictions on raw input strings."""

    def __init__(self, ckpt_path='./Res/PretrainModel/2019_12_27_08_48_21/'):
        loaded = Checkpoint.load(ckpt_path)
        self.seq2seq, self.input_vocab, self.output_vocab = (
            loaded.model, loaded.input_vocab, loaded.output_vocab)
        self.predictor = Predictor(self.seq2seq, self.input_vocab,
                                   self.output_vocab)

    def predict(self, input_str):
        """Trim `input_str`, split it on whitespace and return the prediction."""
        tokens = input_str.strip().split()
        return self.predictor.predict(tokens)
def evaluate_model(model, data, src_field, tgt_field, file_props=None):
    """Score lemma predictions on `data`, report them and dump a CSV.

    Adds a ``pred_lemma`` column to `data` in place. For each row, ``word``
    is fed to the predictor as a list of characters, and the predicted
    symbols except the last are concatenated.

    Args:
        model: trained model wrapped by ``Predictor``.
        data: DataFrame with ``word`` and ``lemma`` columns.
        src_field, tgt_field: fields whose ``vocab`` is given to ``Predictor``.
        file_props: mapping (or iterable of ``(key, value)`` pairs) encoded as
            ``key=value`` pieces joined by ``-`` in the CSV file name, e.g.
            ``{"lr": 0.1}`` -> ``./dev_lr=0.1.csv``.

    The accuracy is reported through ``EXPERIMENT.metric``. It is the
    fraction of rows where ``pred_lemma == lemma``, or NaN for an empty
    frame. The mispredicted rows are reported through ``EXPERIMENT.log``.
    """
    # None instead of a shared mutable `{}` default.
    if file_props is None:
        file_props = {}
    # Iterating a dict yields only its keys, so the former
    # `for k, v in file_props` could not unpack them into pairs.
    if hasattr(file_props, "items"):
        prop_pairs = file_props.items()
    else:
        prop_pairs = file_props

    predictor = Predictor(model, src_field.vocab, tgt_field.vocab)
    data["pred_lemma"] = [
        "".join(predictor.predict(list(e.word))[:-1])
        for e in data.itertuples()
    ]

    n_rows = len(data.lemma)
    n_correct = sum(int(row.pred_lemma == row.lemma)
                    for row in data.itertuples())
    # Guard the empty frame instead of raising ZeroDivisionError.
    acc = n_correct / n_rows if n_rows else float("nan")
    EXPERIMENT.metric("Dev accuracy", acc)

    data.to_csv("./dev_{}.csv".format(
        "-".join("{}={}".format(k, v) for k, v in prop_pairs)))

    EXPERIMENT.log("Incorrect predictions")
    EXPERIMENT.log(
        str(data[data["lemma"] != data["pred_lemma"]][[
            "word", "lemma", "pred_lemma"
        ]]))
def test(expt_dir, checkpoint, test_file, output_file):
    """Decode the questions in `test_file` and append the answers to `output_file`.

    Args:
        expt_dir: experiment directory containing the checkpoint folder.
        checkpoint: checkpoint name under ``Checkpoint.CHECKPOINT_DIR_NAME``.
        test_file: input file. Each line is split on the literal ``'<s>'``
            and its second-to-last field is the question.
        output_file: file the answers are appended to, one per line.

    Each question is tokenized with ``basic_tokenizer``. The last predicted
    token is dropped and the remaining tokens are concatenated without
    separators. Lines with no ``'<s>'`` separator (including blank lines)
    are skipped.

    Raises:
        Exception: if `checkpoint` is None.
    """
    if checkpoint is None:
        raise Exception("checkpoint path does not exist")

    checkpoint_path = os.path.join(expt_dir, Checkpoint.CHECKPOINT_DIR_NAME,
                                   checkpoint)
    logging.info("loading checkpoint from {}".format(checkpoint_path))
    loaded = Checkpoint.load(checkpoint_path)
    predictor = Predictor(loaded.model, loaded.input_vocab,
                          loaded.output_vocab)

    # Text mode: the answers are `str`, which a file opened in 'ab' rejects
    # on Python 3. `with` also guarantees the output file gets closed.
    with open(output_file, 'a', encoding='utf-8') as output, \
            open(test_file) as f:
        for line_ in f:
            line = line_.strip().split('<s>')
            # str.split always returns at least one element, so the old
            # `len(line) != 0` check never fired; line[-2] needs two fields.
            if len(line) < 2:
                continue
            question = basic_tokenizer(line[-2])
            answer = predictor.predict(question)[:-1]
            output.write(''.join(answer) + '\n')
best_model_dir=opt.best_model_dir, batch_size=opt.batch_size, checkpoint_every=opt.checkpoint_every, print_every=opt.print_every, max_epochs=opt.max_epochs, max_steps=opt.max_steps, max_checkpoints_num=opt.max_checkpoints_num, best_ppl=opt.best_ppl, device=device, multi_gpu=multi_gpu, logger=logger) seq2seq = t.train(seq2seq, data=train, start_step=opt.skip_steps, dev_data=dev, optimizer=optimizer, teacher_forcing_ratio=opt.teacher_forcing_ratio) elif opt.phase == "infer": # Predict predictor = Predictor(seq2seq, src_vocab.word2idx, tgt_vocab.idx2word, device) while True: seq_str = input("Type in a source sequence:") seq = seq_str.strip().split() ans = predictor.predict_n(seq, n=opt.beam_width) \ if opt.beam_width > 1 else predictor.predict(seq) print(ans)
# train
t = SupervisedTrainer(
    loss=loss,
    batch_size=32,
    checkpoint_every=50,
    print_every=10,
    expt_dir=opt.expt_dir,
)
seq2seq = t.train(
    seq2seq,
    train,
    num_epochs=6,
    dev_data=dev,
    optimizer=optimizer,
    teacher_forcing_ratio=0.5,
    resume=opt.resume,
)

# The trained model must reach a dev loss below 1.5.
evaluator = Evaluator(loss=loss, batch_size=32)
dev_loss, accuracy = evaluator.evaluate(seq2seq, dev)
assert dev_loss < 1.5

# Wrap the trained decoder in a 3-beam top-k decoder and check that the
# prediction (minus its last token) is the input reversed. Reversing the
# characters of inp_seq equals reversing its tokens because every token is a
# single character.
beam_search = Seq2seq(seq2seq.encoder, TopKDecoder(seq2seq.decoder, 3))
predictor = Predictor(beam_search, input_vocab, output_vocab)
inp_seq = "1 3 5 7 9"
seq = predictor.predict(inp_seq.split())
assert " ".join(seq[:-1]) == inp_seq[::-1]
# Optimizer and learning-rate scheduler can be customized by explicitly
# constructing the objects and passing them to the trainer, e.g.:
#
# optimizer = Optimizer(torch.optim.Adam(seq2seq.parameters()), max_grad_norm=5)
# scheduler = StepLR(optimizer.optimizer, 1)
# optimizer.set_scheduler(scheduler)

# train
t = SupervisedTrainer(loss=loss,
                      batch_size=params['batch_size'],
                      checkpoint_every=50,
                      print_every=25,
                      expt_dir=opt.expt_dir,
                      tensorboard=True)
seq2seq = t.train(seq2seq, train,
                  num_epochs=params['num_epochs'],
                  dev_data=dev,
                  optimizer=optimizer,
                  teacher_forcing_ratio=0.5,
                  resume=opt.resume)

# Interactive loop: read a sentence, print the prediction joined by spaces.
predictor = Predictor(seq2seq, input_vocab, output_vocab)
while True:
    seq_str = raw_input("Type in a source sequence: ")
    seq = seq_str.strip().split()
    print(' '.join(predictor.predict(seq)))
#sen[max_len,batch_size] a = [] for i in range(len(sen)): phrase = "" for j in range(len(sen[i])): if sen[i][j] != "<eos>": #print("printing word :",sen[i][j]) if sen[i][j] == "." or sen[i][j] == ",": phrase = phrase + sen[i][j] else: phrase = phrase + " " + sen[i][j] else: a.append(phrase) break return a for i in range(len(data)): seq_str = data.iloc[i]["src"] print(seq_str) seq = seq_str.strip().split() pred.append(predictor.predict(seq)) print(pred) pred_target = sentence_gen(pred) print(len(pred_target)) pred_target = pd.DataFrame(pred_target) pred_target.columns = ["pred"] pred_target.to_csv("output.csv", sep=",")
def main():
    '''Main Function: decode every sequence of `-src` with a trained RNN
    seq2seq model and write the space-joined predictions to `-output`,
    one line per input sequence.'''
    parser = argparse.ArgumentParser(description='sum_file.py')
    parser.add_argument('-model', required=True,
                        help='Path to model .pt file')
    parser.add_argument('-src', required=True,
                        help='Source sequence to decode (one line per sequence)')
    # Help text previously duplicated -src's; this file provides the
    # 'settings' and 'dict' entries read below.
    parser.add_argument('-vocab', required=True,
                        help='Path to the preprocessed data file (vocabulary dicts and settings)')
    parser.add_argument('-output', default='pred.txt',
                        help="""Path to output the predictions (each line will be the decoded sequence""")
    parser.add_argument('-beam_size', type=int, default=5,
                        help='Beam size')
    parser.add_argument('-batch_size', type=int, default=30,
                        help='Batch size')
    parser.add_argument('-n_best', type=int, default=1,
                        help="""If verbose is set, will output the n_best decoded sentences""")
    parser.add_argument('-no_cuda', action='store_true')
    opt = parser.parse_args()
    # NOTE(review): -beam_size, -batch_size and -n_best are parsed but not
    # used below; sequences are decoded one at a time.

    # Use CUDA only when it is requested AND available, so the script still
    # runs on CPU-only machines.
    opt.cuda = not opt.no_cuda and torch.cuda.is_available()
    device = torch.device('cuda' if opt.cuda else 'cpu')

    # Prepare DataLoader
    preprocess_data = torch.load(opt.vocab)
    preprocess_settings = preprocess_data['settings']
    test_src_word_insts = read_instances_from_file(
        opt.src,
        preprocess_settings.max_word_seq_len,
        preprocess_settings.keep_case,
        preprocess_settings.mode)
    test_src_insts = convert_instance_to_idx_seq(
        test_src_word_insts, preprocess_data['dict']['src'])

    # prepare model; map_location lets a checkpoint saved from GPU be loaded
    # on CPU (e.g. with -no_cuda).
    checkpoint = torch.load(opt.model, map_location=device)
    model_opt = checkpoint['settings']
    model_opt.bidirectional = True
    encoder = EncoderRNN(model_opt.src_vocab_size,
                         model_opt.max_token_seq_len,
                         model_opt.d_model,
                         bidirectional=model_opt.bidirectional,
                         variable_lengths=True)
    # The decoder's hidden size doubles when the encoder is bidirectional.
    decoder_hidden = (model_opt.d_model * 2 if model_opt.bidirectional
                      else model_opt.d_model)
    # NOTE(review): eos_id receives Constants.BOS and sos_id receives
    # Constants.EOS. This looks swapped but is kept because it must match how
    # the checkpoint was trained -- confirm before changing.
    decoder = DecoderRNN(model_opt.tgt_vocab_size,
                         model_opt.max_token_seq_len,
                         decoder_hidden,
                         n_layers=model_opt.n_layer,
                         dropout_p=model_opt.dropout,
                         use_attention=True,
                         bidirectional=model_opt.bidirectional,
                         eos_id=Constants.BOS,
                         sos_id=Constants.EOS)
    model = Seq2seq(encoder, decoder).to(device)
    # Training used DataParallel, so wrap the model the same way for the
    # saved state-dict keys to match (presumably its "module." prefix).
    model = nn.DataParallel(model)
    model.load_state_dict(checkpoint['model'])
    print('[Info] Trained model state loaded.')

    predictor = Predictor(model, preprocess_data['dict']['tgt'])

    with open(opt.output, 'w') as f:
        for src_seq in tqdm(test_src_insts, mininterval=2, desc=' - (Test)',
                            leave=False):
            pred_line = ' '.join(predictor.predict(src_seq))
            f.write(pred_line + '\n')
    print('[Info] Finished.')
def sample(
        experiment_directory='/home/xweiwang/RL/seq2seq/experiment',
        checkpoint='2019_05_18_20_32_54',
        resume=True,
        log_level='info',
):
    """Load a saved checkpoint and run an endless interactive prediction loop.

    Args:
        experiment_directory: experiment directory given to
            ``load_checkpoint``.
        checkpoint: name of the checkpoint to load from that directory.
        resume: not used by this function (the training path is disabled).
        log_level: name of a ``logging`` level, case-insensitive.

    Each line read from stdin is split on whitespace, and the predictor's
    output is printed.
    """
    numeric_level = getattr(logging, log_level.upper())
    logging.basicConfig(format=LOG_FORMAT, level=numeric_level)

    logging.info('experiment_directory: %s', experiment_directory)
    logging.info('checkpoint: %s', checkpoint)

    model, src_vocab, tgt_vocab = load_checkpoint(experiment_directory,
                                                  checkpoint)
    predictor = Predictor(model, src_vocab, tgt_vocab)

    while True:
        tokens = input('Type in a source sequence: ').strip().split()
        print(predictor.predict(tokens))
def eval_fa_equiv(model, data, input_vocab, output_vocab):
    """Evaluate `model` on `data` by exact match and by regex DFA equivalence.

    Examples are decoded one at a time with ``Predictor``. The prediction has
    its last token dropped, ``<pad>`` tokens removed, and is passed through
    ``refine_outout`` before being compared to the decoded target.
    Predictions that differ from the target string are checked by the
    external ``regexDFAEquals.py`` script (run with ``python2``); output
    whose first byte is ``1`` counts as DFA-equivalent.

    The DFA accuracy is appended as one line to
    ``./time_logs/log_score_time.txt``.

    Returns:
        tuple: ``(token_accuracy, string_accuracy, dfa_accuracy)``, each NaN
        when its denominator is zero. ``token_accuracy`` counts each generated
        token once and scores it as a match when it equals the target token at
        the same position.
    """
    model.eval()

    # Older torchtext API: -1 presumably selects the CPU, None the default GPU.
    device = None if torch.cuda.is_available() else -1
    batch_iterator = torchtext.data.BucketIterator(
        dataset=data, batch_size=1,
        sort=False, sort_key=lambda x: len(x.src),
        device=device, train=False)

    predictor = Predictor(model, input_vocab, output_vocab)

    num_samples = 0
    perfect_samples = 0
    dfa_perfect_samples = 0
    match = 0
    total = 0

    with torch.no_grad():
        for batch in batch_iterator:
            num_samples += 1

            input_variables, _ = getattr(batch, seq2seq.src_field_name)
            target_variables = getattr(batch, seq2seq.tgt_field_name)
            target_string = decode_tensor(target_variables, output_vocab)
            input_string = decode_tensor(input_variables, input_vocab)

            predicted = predictor.predict(input_string.strip().split())[:-1]
            generated_string = refine_outout(
                ' '.join(tok for tok in predicted if tok != '<pad>'))

            if target_string == generated_string:
                perfect_samples += 1
                dfa_perfect_samples += 1
            else:
                # Only differing strings need the slow external equivalence
                # check; identical strings are trivially equivalent.
                pos_example = subprocess.check_output([
                    'python2', 'regexDFAEquals.py',
                    '--gold', '{}'.format(target_string),
                    '--predicted', '{}'.format(generated_string)
                ])
                # check_output returns bytes; this is what the former
                # `str(pos_example)[2] == '1'` tested via the b'...' repr.
                if pos_example[:1] == b'1':
                    dfa_perfect_samples += 1

            target_tokens = target_string.split()
            generated_tokens = generated_string.split()
            for idx, token in enumerate(generated_tokens):
                # Count each generated position once (it used to be counted
                # twice once past the end of the target).
                total += 1
                if idx < len(target_tokens) and target_tokens[idx] == token:
                    match += 1

    accuracy = match / total if total else float('nan')
    string_accuracy = (perfect_samples / num_samples
                       if num_samples else float('nan'))
    dfa_accuracy = (dfa_perfect_samples / num_samples
                    if num_samples else float('nan'))

    with open('./time_logs/log_score_time.txt', 'a') as f:
        f.write('{}\n'.format(dfa_accuracy))

    return accuracy, string_accuracy, dfa_accuracy
parser.add_argument('--log-level', dest='log_level',
                    default='info',
                    help='Logging level.')
opt = parser.parse_args()

LOG_FORMAT = '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
logging.basicConfig(format=LOG_FORMAT,
                    level=getattr(logging, opt.log_level.upper()))
logging.info(opt)

if opt.load_checkpoint is not None:
    # Build the path once, then log it and load from it.
    checkpoint_path = os.path.join(opt.expt_dir,
                                   Checkpoint.CHECKPOINT_DIR_NAME,
                                   opt.load_checkpoint)
    logging.info("loading checkpoint from {}".format(checkpoint_path))
    checkpoint = Checkpoint.load(checkpoint_path)
    seq2seq = checkpoint.model
    input_vocab = checkpoint.input_vocab
    output_vocab = checkpoint.output_vocab

# NOTE(review): without --load_checkpoint, seq2seq / input_vocab /
# output_vocab must already be defined earlier, or this raises NameError.
predictor = Predictor(seq2seq, input_vocab, output_vocab)

while True:
    seq_str = raw_input("Type in a source sequence:")
    seq = seq_str.strip().split()
    print(predictor.predict(seq))
# topk_predictor = Predictor(seq2top, input_vocab, output_vocab, vectors) if config['pull embeddings']: out_vecs = {} if config['feat embeddings']: feats = {} of = open(config['feat output'], 'wb') # TODO add option to save output src = SourceField() feat = SourceField() tgt = TargetField() # pdb.set_trace() for key in tqdm(input_vocab.freqs.keys()): try: guess, enc_out = predictor.predict([key]) except: print("guess, enc_out = predictor.predict([key]) didn't work") pdb.set_trace() # TODO first try averaging # (Pdb) # test[3].mean(-1).shape # torch.Size([1, 13]) # (Pdb) # test[3].mean(-2).shape # torch.Size([1, 600]) feats[key] = {} feats[key]['src'] = key feats[key]['tgt'] = key feats[key]['guess'] = ''.join(guess) feats[key]['embed'] = enc_out
# Optional optimizer / scheduler customization, e.g.:
#
# optimizer = Optimizer(torch.optim.Adam(seq2seq.parameters()), max_grad_norm=5)
# scheduler = StepLR(optimizer.optimizer, 1)
# optimizer.set_scheduler(scheduler)

# train
t = SupervisedTrainer(loss=loss,
                      batch_size=batch_size,
                      checkpoint_every=50,
                      print_every=10,
                      expt_dir=opt.expt_dir)
seq2seq = t.train(seq2seq, train,
                  num_epochs=num_epochs,
                  dev_data=dev,
                  optimizer=optimizer,
                  teacher_forcing_ratio=0.5,
                  resume=opt.resume)

predictor = Predictor(seq2seq, input_vocab)

# Read two sequences, wrap each in the field's SOS/EOS symbols and predict on
# the pair.
while True:
    seq_str = raw_input("Type in a source sequence:")
    seq_1 = ([first_field.SYM_SOS]
             + seq_str.strip().split()
             + [first_field.SYM_EOS])
    seq_str = raw_input("Type in a source sequence:")
    seq_2 = ([first_field.SYM_SOS]
             + seq_str.strip().split()
             + [first_field.SYM_EOS])
    print(predictor.predict([seq_1, seq_2]))
seqs_x.append(seq_x) POSs.append(POS) rhythms.append(rhythm) lengths.append(length) return seqs_x, lengths, POSs, rhythms predictor = Predictor(seq2seq, input_vocab, output_vocab) seqs_x, lengths, POSs, rhythms = read_dev(opt.dev_path) preds = [] for i, seq in enumerate(seqs_x): print(i) seq = seq.strip().split() pred = predictor.predict(seq) preds.append(pred) with open(opt.output_path, 'w', encoding='utf8') as f: for pred in preds: if len(pred) == 3: row = ['我'] else: row = pred[1:-2] for i in range(len(row)): if row[i] == '<unk>': row[i] = '我' f.write("%s\n" % (' '.join(row)))
with torch.no_grad(): for batch in batch_iterator: num_samples = num_samples + 1 input_variables, input_lengths = getattr(batch, seq2seq.src_field_name) target_variables = getattr(batch, seq2seq.tgt_field_name) target_string = decode_tensor(target_variables, output_vocab) input_string = decode_tensor(input_variables, input_vocab) generated_string = ' '.join([x for x in predictor.predict(input_string.strip().split())[:-1] if x != '<pad>']) print("Input string: ", input_string) print("Targ : ", target_string) print("Pred : ", refine_outout(generated_string)) generated_string = refine_outout(generated_string) pos_example = subprocess.check_output(['python2', 'regexDFAEquals.py', '--gold', '{}'.format(target_string), '--predicted', '{}'.format(generated_string)]) if target_string == generated_string: perfect_samples = perfect_samples + 1 dfa_perfect_samples = dfa_perfect_samples + 1 print('String Equivalent')
# Optional optimizer / scheduler customization, e.g.:
#
# optimizer = Optimizer(torch.optim.Adam(seq2seq.parameters()), max_grad_norm=5)
# scheduler = StepLR(optimizer.optimizer, 1)
# optimizer.set_scheduler(scheduler)

# train
t = SupervisedTrainer(loss=loss, batch_size=32, checkpoint_every=50,
                      print_every=10, expt_dir=opt.expt_dir)
seq2seq = t.train(seq2seq, train, num_epochs=102, dev_data=dev,
                  optimizer=optimizer, teacher_forcing_ratio=0.5,
                  resume=opt.resume)

predictor = Predictor(seq2seq, input_vocab, output_vocab)
while True:
    seq_str = raw_input("Type in a source sequence:")
    seq = seq_str.strip().split()
    prediction = predictor.predict(seq)
    # Print each source token whose aligned predicted tag is 'B'. zip stops at
    # the shorter sequence, so predicted tokens beyond the input length (e.g.
    # a trailing end symbol) no longer index past the end of `seq`.
    for token, tag in zip(seq, prediction):
        if tag == 'B':
            print(token, " ")
    # Reuse the prediction instead of decoding the same input twice.
    print(prediction)
seq2seq = Seq2seq(encoder, decoder)
if torch.cuda.is_available():
    seq2seq.cuda()

# Initialize every parameter uniformly in [-0.08, 0.08].
for param in seq2seq.parameters():
    param.data.uniform_(-0.08, 0.08)

# train
t = SupervisedTrainer(loss=loss,
                      batch_size=10000,
                      checkpoint_every=50,
                      print_every=10,
                      expt_dir=opt.expt_dir)
seq2seq = t.train(seq2seq, train,
                  num_epochs=10,
                  dev_data=validation,
                  optimizer=optimizer,
                  teacher_forcing_ratio=0.5,
                  resume=opt.resume)

# Interactive loop: tokenize with sentence_to_words, echo the tokens, then
# print the prediction.
predictor = Predictor(seq2seq, input_vocab, output_vocab)
while True:
    sentence = raw_input("Type in a source sequence:")
    words = sentence_to_words(sentence)
    print(words)
    print(predictor.predict(words))
t = SupervisedTrainer(loss=loss, batch_size=32, checkpoint_every=50,
                      print_every=10, expt_dir=opt.expt_dir)
seq2seq = t.train(seq2seq, train, num_epochs=40, dev_data=dev,
                  optimizer=optimizer, teacher_forcing_ratio=0.5,
                  resume=opt.resume)

predictor = Predictor(seq2seq, input_vocab, output_vocab)

# Decode every "source<TAB>target" row of the test file and write
# "target<TAB>prediction" lines to seq_pred.txt. The last predicted token is
# dropped from each prediction.
with open(opt.test_path) as f:
    content = [x.strip() for x in f.readlines()]

output_lines = []
for row in content:
    fields = row.split("\t")
    # Blank or tab-less rows used to raise IndexError on fields[1].
    if len(fields) < 2:
        continue
    seq_in, seq_out = fields[0], fields[1]
    seq_pred = predictor.predict(seq_in.strip().split())
    seq_pred = " ".join(seq_pred[:-1])
    output_lines.append(seq_out + "\t" + seq_pred + "\n")
# One join instead of repeated `+=` string concatenation.
output = "".join(output_lines)

output_file = "seq_pred.txt"
with open(output_file, "w") as text_train:
    text_train.write(output)
class AttentionVisualizer(object):
    """Object for visualizing the attention pattern of a given prediction.

    Args:
        task_path (str): path of the saved task / checkpoint. Its
            second-to-last "/"-separated component is used as the model name.
        figsize (tuple, optional): (width, height) of the final matplotlib
            figure.
        decimals (int, optional): number of decimals to show when printing
            any number.
        is_show_attn_split (bool, optional): whether to show the content and
            positional attention, when available, in addition to the full
            attention.
        is_show_evaluation (bool, optional): whether to show the evaluation
            metric if the target is given.
        output_length_key, attention_key, position_attn_key, content_attn_key
            (str, optional): keys of the respective values in the dictionary
            returned by the prediction.
        positional_table_labels (dictionary, optional): mapping from the name
            shown in the figure (keys) to the key in the returned dictionary
            (values). The order is the one used to plot the table (in
            python > 3.6). Defaults to a copy of
            ``_DEFAULT_POSITIONAL_TABLE_LABELS``.
        is_show_name (bool, optional): whether to show the name of the model
            as the title of the figure.
        max_src, max_out, max_tgt (int, optional): maximum number of tokens to
            show for the source, the output and the target, so as not to
            clutter the plots.
        kwargs: Additional arguments to `MetricComputer`.

    Raises:
        AttentionException: if the model's decoder is not using attention.
    """

    # Name shown in the positional table -> key in the predictor's dictionary.
    _DEFAULT_POSITIONAL_TABLE_LABELS = {
        "λ%": "position_percentage",
        "C.γ": "content_confidence",
        # "lgt": "approx_max_logit",
        "C.λ": "pos_confidence",
        "μ": "mu",
        "σ": "sigma",
        "w_α": "mean_attn_old_weight",
        "w_j/n": "rel_counter_decoder_weight",
        "w_1/n": "single_step_weight",
        "w_μ": "mu_old_weight",
        "w_γ": "mean_content_old_weight",
        "w_1": "bias_weight",
        # "% carry": "carry_rates",
    }

    def __init__(self,
                 task_path,
                 figsize=(15, 13),
                 decimals=2,
                 is_show_attn_split=True,
                 is_show_evaluation=True,
                 output_length_key='length',
                 attention_key="attention_score",
                 position_attn_key='position_attention',
                 content_attn_key='content_attention',
                 positional_table_labels=None,
                 is_show_name=True,
                 max_src=17,
                 max_out=13,
                 max_tgt=13,
                 **kwargs):
        if positional_table_labels is None:
            # Fresh copy per instance: a shared mutable default would leak
            # changes made through one instance into every other one.
            positional_table_labels = dict(
                self._DEFAULT_POSITIONAL_TABLE_LABELS)

        check = Checkpoint.load(task_path)
        self.model = check.model
        # store some interesting variables
        self.model.set_dev_mode()

        self.predictor = Predictor(self.model, check.input_vocab,
                                   check.output_vocab)
        self.model_name = task_path.split("/")[-2]
        self.figsize = figsize
        self.decimals = decimals
        self.is_show_attn_split = is_show_attn_split
        self.is_show_evaluation = is_show_evaluation
        self.positional_table_labels = positional_table_labels
        self.is_show_name = is_show_name

        self.max_src = max_src
        self.max_out = max_out
        self.max_tgt = max_tgt

        self.output_length_key = output_length_key
        self.attention_key = attention_key
        self.position_attn_key = position_attn_key
        self.content_attn_key = content_attn_key

        if self.is_show_evaluation:
            self.is_symbol_rewriting = "symbol rewriting" in task_path.lower()
            self.metric_computer = MetricComputer(
                check, is_symbol_rewriting=self.is_symbol_rewriting, **kwargs)

        if self.model.decoder.is_attention is None:
            raise AttentionException("Model is not using attention.")

    def __call__(self, src_str, tgt_str=None):
        """Plots the attention for the current example.

        Args:
            src_str (str): source of the example.
            tgt_str (str, optional): target of the example, must be given in
                order to show the final metric.

        Returns:
            fig (plt.Figure): plotted attention figure.
        """
        out_words, other = self.predictor.predict(src_str.split())

        full_src_str = src_str
        full_out_str = " ".join(out_words)
        full_tgt_str = tgt_str

        additional, additional_text = self._format_additional(other)

        additional, src_words, out_words, tgt_str = self._subset(
            additional, src_str.split(), out_words, tgt_str)

        # `is_show_name` promises the model name as the title; it used to be
        # initialised to "" so the name never appeared.
        title = self.model_name if self.is_show_name else None

        if tgt_str is not None:
            if self.is_show_name:
                title += "\n tgt_str: {} - ".format(tgt_str)
            else:
                title = "tgt_str: {} - ".format(tgt_str)

            # metric_computer only exists when is_show_evaluation is True;
            # accessing it otherwise raised AttributeError.
            if self.is_show_evaluation:
                if self.metric_computer.is_predict_eos:
                    is_wrong_length = (len(full_out_str.split()) !=
                                       len(full_tgt_str.split()))

                    if self.is_symbol_rewriting and is_wrong_length:
                        warnings.warn(
                            "Cannot currently show the metric for symbol rewriting if output is not the right length."
                        )
                    else:
                        metrics = self.metric_computer(full_src_str,
                                                       full_out_str,
                                                       full_tgt_str)

                        for name, val in metrics.items():
                            title += "{}: {:.2g} ".format(name, val)
                else:
                    warnings.warn(
                        "Cannot currently show the metric in the attention plots when `is_predict_eos=False`"
                    )

        if self.attention_key not in additional:
            raise ValueError(
                "`{}` not returned by predictor. Make sure the model uses attention."
                .format(self.attention_key))

        attention = additional[self.attention_key]

        if self.position_attn_key in additional:
            # Only keep the table columns the predictor actually returned.
            filtered_pos_table_labels = {
                k: v
                for k, v in self.positional_table_labels.items()
                if v in additional
            }
            table_values = np.stack([
                np.around(additional[name], decimals=self.decimals)
                for name in filtered_pos_table_labels.values()
            ]).T

        if self.is_show_attn_split and (self.position_attn_key in additional
                                        and self.content_attn_key in additional):
            content_attention = additional.get(self.content_attn_key)
            positional_attention = additional.get(self.position_attn_key)

            fig, axs = plt.subplots(2, 2, figsize=self.figsize)
            _plot_attention(src_words, out_words, attention, axs[0, 0],
                            is_colorbar=False, title="Final Attention")
            _plot_table(table_values, list(filtered_pos_table_labels.keys()),
                        axs[0, 1])
            _plot_attention(src_words, out_words, content_attention,
                            axs[1, 0], title="Content Attention")
            _plot_attention(src_words, out_words, positional_attention,
                            axs[1, 1], title="Positional Attention")

        elif self.position_attn_key in additional:
            fig, axs = plt.subplots(1, 2, figsize=self.figsize)
            _plot_attention(src_words, out_words, attention, axs[0],
                            title="Final Attention")
            _plot_table(table_values, list(filtered_pos_table_labels.keys()),
                        axs[1])

        else:
            fig, ax = plt.subplots(1, 1, figsize=self.figsize)
            _plot_attention(src_words, out_words, attention, ax,
                            title="Final Attention")

        fig.text(0.5, 0.02, ' | '.join(additional_text), ha='center',
                 va='center', size=13)

        if title is not None:
            plt.suptitle(title, size=13, weight="bold")

        fig.tight_layout()
        fig.subplots_adjust(bottom=0.07, top=0.83)

        return fig

    def _format_additional(self, additional):
        """Format the additional dictionary returned by the predictor.

        Note: "visualize" and "losses" are popped from `additional` itself.

        Returns:
            output (dict): the output length under `output_length_key`, and
                every other flattened value as a squeezed numpy array
                truncated to that length. "bb_gates" and "mu_weights" are
                removed; "mu_weights" is split into one "<label>_weight"
                entry per building block.
            additional_text (list of str): carry-rate and building-block gate
                summaries shown under the figure.
        """

        def _format_carry_rates(carry_rates):
            if carry_rates is None:
                return "Carry % : None"
            # Same tensor / list-of-tensors handling as the other values
            # below: a plain list has no .mean() / .median().
            if not isinstance(carry_rates, torch.Tensor):
                carry_rates = torch.cat(carry_rates)
            mean_carry_rates = np.around(carry_rates.mean().item(),
                                         decimals=self.decimals)
            median_carry_rates = np.around(carry_rates.median().item(),
                                           decimals=self.decimals)
            return "Carry % : mean: {}; median: {}".format(
                mean_carry_rates, median_carry_rates)

        def _format_bb_gates(gates):
            if gates is None:
                return "BB Weight Mean Gates : None"
            mean_gates = np.around(gates.mean(0), decimals=self.decimals)
            return "BB Weight Mean Gates : {}".format(mean_gates)

        def _format_mu_weights(mu_weights):
            # Store column i of the weights under "<i-th label>_weight".
            if mu_weights is not None:
                building_blocks_labels = self.model.decoder.position_attention.bb_labels
                for i, label in enumerate(building_blocks_labels):
                    output[label + "_weight"] = mu_weights[:, i]

        # this is only for training visualization not predict
        additional.pop("visualize", None)
        additional.pop("losses", None)

        additional_text = []
        additional = flatten_dict(additional)

        output = dict()
        output[self.output_length_key] = additional.pop(
            self.output_length_key)[0]

        for k, v in additional.items():
            tensor = v if isinstance(v, torch.Tensor) else torch.cat(v)
            output[k] = tensor.detach().cpu().numpy().squeeze(
            )[:output[self.output_length_key]]

        carry_txt = _format_carry_rates(additional.pop("carry_rates", None))
        bb_gates_txt = _format_bb_gates(output.pop("bb_gates", None))
        additional_text.append(carry_txt)
        additional_text.append(bb_gates_txt)

        _format_mu_weights(output.pop("mu_weights", None))

        return output, additional_text

    def _subset(self, additional, src_words, out_words, tgt_str=None):
        """Subsets the objects in the additional dictionary in order not to
        clutter the visualization.

        Too-long outputs / sources keep their first and last ``max // 2``
        tokens. Numpy arrays are cut the same way: along axis 0 for outputs,
        and along axis 1 (2-D arrays only) for sources. A too-long target
        string keeps its first and last ``max_tgt // 2`` words around "...".
        """
        n_src = len(src_words)
        n_out = len(out_words)

        if n_out > self.max_out:
            subset_out = self.max_out // 2
            out_words = out_words[:subset_out] + out_words[-subset_out:]
            for k, v in additional.items():
                if isinstance(v, np.ndarray):
                    additional[k] = np.concatenate(
                        (v[:subset_out], v[-subset_out:]), axis=0)

        if n_src > self.max_src:
            subset_src = self.max_src // 2
            src_words = src_words[:subset_src] + src_words[-subset_src:]
            for k, v in additional.items():
                if isinstance(v, np.ndarray) and v.ndim == 2:
                    additional[k] = np.concatenate(
                        (v[:, :subset_src], v[:, -subset_src:]), axis=1)

        if tgt_str is not None:
            tgt_words = tgt_str.split()
            n_tgt = len(tgt_words)
            if n_tgt > self.max_tgt:
                subset_target = self.max_tgt // 2
                tgt_str = " ".join(tgt_words[:subset_target] + ["..."] +
                                   tgt_words[-subset_target:])

        return additional, src_words, out_words, tgt_str
help='Logging level.') args = parser.parse_args() LOG_FORMAT = '%(asctime)s %(name)-12s %(levelname)-8s %(message)s' logging.basicConfig(format=LOG_FORMAT, level=getattr(logging, args.log_level.upper())) logging.info(args) logging.info("loading checkpoint from {}".format(args.trained_model_dir)) checkpoint_path = args.trained_model_dir checkpoint = Checkpoint.load(checkpoint_path) seq2seq = checkpoint.model input_vocab = checkpoint.input_vocab output_vocab = checkpoint.output_vocab predictor = Predictor(seq2seq, input_vocab, output_vocab) with open(args.text_path, mode='r', encoding='utf-8') as file: file.readline() text = file.read().replace('\n', '') sentences = nltk.sent_tokenize(text) results = [] for sentence in sentences: words = sentence_to_words(sentence) result = predictor.predict(words) result.remove('<eos>') if result: results.extend(result) print("\n".join(results))