Example #1
def dev_predict(task_path, src_str, is_plot=True):
    """Helper used to visualize and understand why and what the model predicts.

    Args:
        task_path (str): path to the saved task directory containing, amongst
            other, the model.
        src_str (str): source sentence that will be used to predict.
        is_plot (bool, optional): whether to plots the attention pattern.

    Returns:
        out_words (list): decoder predictions.
        other (dictionary): additional information used for predictions.
        test (dictionary): additional information that is only stored in dev mode.
            These can include temporary variables that do not have to be stored in
            `other` but that can still be interesting to inspect.
    """
    check = Checkpoint.load(task_path)
    check.model.set_dev_mode()

    predictor = Predictor(check.model, check.input_vocab, check.output_vocab)
    out_words, other = predictor.predict(src_str.split())

    test = dict()

    for k, v in other["test"].items():
        tensor = v if isinstance(v, torch.Tensor) else torch.cat(v)
        test[k] = tensor.detach().cpu().numpy().squeeze()[:other["length"][0]]
        # except: # for using "step"
        # test[k] = v

    if is_plot:
        visualizer = AttentionVisualizer(task_path)
        visualizer(src_str)

    return out_words, other, test
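
A minimal usage sketch for the helper above; the task directory and input string are hypothetical, and plotting is switched off so only the returned values are inspected:

# Hedged usage sketch: checkpoint directory and source string are placeholders.
task_path = "checkpoints/lookup_table/2020_01_01_00_00_00/"
out_words, other, test = dev_predict(task_path, "a b c d", is_plot=False)
print("Prediction:", " ".join(out_words))
print("Dev-mode tensors collected:", sorted(test.keys()))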
Example #2
def test_predict(self):
    predictor = Predictor(self.seq2seq,
                          self.dataset.input_vocab, self.dataset.output_vocab)
    src_seq = ["I", "am", "fat"]
    tgt_seq = predictor.predict(src_seq)
    for tok in tgt_seq:
        self.assertTrue(tok in self.dataset.output_vocab._token2index)
Example #3
class TestPredictor(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        test_path = os.path.dirname(os.path.realpath(__file__))
        src = SourceField()
        trg = TargetField()
        dataset = torchtext.data.TabularDataset(
            path=os.path.join(test_path, 'data/eng-fra.txt'),
            format='tsv',
            fields=[('src', src), ('trg', trg)],
        )
        src.build_vocab(dataset)
        trg.build_vocab(dataset)

        encoder = EncoderRNN(len(src.vocab), 10, 10, rnn_cell='lstm')
        decoder = DecoderRNN(len(trg.vocab),
                             10,
                             10,
                             trg.sos_id,
                             trg.eos_id,
                             rnn_cell='lstm')
        seq2seq = Seq2seq(encoder, decoder)
        cls.predictor = Predictor(seq2seq, src.vocab, trg.vocab)

    def test_predict(self):
        src_seq = "I am fat"
        tgt_seq = self.predictor.predict(src_seq.split(' '))
        for tok in tgt_seq:
            self.assertTrue(tok in self.predictor.tgt_vocab.stoi)
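
For reference, a small sketch of driving the test case above directly with the standard `unittest` runner (nothing pytorch-seq2seq specific is assumed here):

import unittest

# Load and run only TestPredictor; verbosity=2 prints each test method name.
suite = unittest.TestLoader().loadTestsFromTestCase(TestPredictor)
unittest.TextTestRunner(verbosity=2).run(suite)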
Example #4
def predict_with_checkpoint(checkpoint_path,
                            sequence,
                            hierarchial=False,
                            remote=None,
                            word_vectors=None):
    checkpoint = Checkpoint.load(checkpoint_path)
    seq2seq = checkpoint.model
    input_vocab = checkpoint.input_vocab
    output_vocab = checkpoint.output_vocab

    seq2seq.encoder.word_vectors, seq2seq.decoder.word_vectors = None, None
    if word_vectors is not None:
        input_vects = Word2Vectors(input_vocab, word_vectors, word_vectors.dim_size)
        output_vects = Word2Vectors(output_vocab, word_vectors, word_vectors.dim_size)
        seq2seq.encoder.word_vectors, seq2seq.decoder.word_vectors = input_vects, output_vects

    # Wrap the trained decoder for beam search with a beam width of 5.
    seq2seq.decoder = TopKDecoder(seq2seq.decoder, 5)

    if not hierarchial:
        predictor = Predictor(seq2seq, input_vocab, output_vocab)
        seq = sequence.strip().split()
    else:
        predictor = HierarchialPredictor(seq2seq, input_vocab, output_vocab)
        seq = ['|'.join(x.split()) for x in sequence]

    return ' '.join(predictor.predict(seq))
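
A hedged call sketch for the helper above, exercising the non-hierarchical branch; the checkpoint directory is a placeholder:

# Hedged usage sketch: the checkpoint path is hypothetical.
checkpoint_dir = "./experiment/checkpoints/2020_01_01_00_00_00"
reply = predict_with_checkpoint(checkpoint_dir, "how are you ?")
print(reply)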
Example #5
class eval_tool:
    def __init__(self, ckpt_path='./Res/PretrainModel/2019_12_27_08_48_21/'):
        checkpoint = Checkpoint.load(ckpt_path)
        self.seq2seq = checkpoint.model
        self.input_vocab = checkpoint.input_vocab
        self.output_vocab = checkpoint.output_vocab
        self.predictor = Predictor(self.seq2seq, self.input_vocab,
                                   self.output_vocab)

    def predict(self, input_str):
        return self.predictor.predict(input_str.strip().split())
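
A short usage sketch for the wrapper above, assuming the pretrained checkpoint directory from the default argument actually exists on disk:

# Hedged usage sketch: relies on the default checkpoint path defined above.
tool = eval_tool()
print(tool.predict("this is a test sentence"))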
Example #6
def evaluate_model(model, data, src_field, tgt_field, file_props={}):
    predictor = Predictor(model, src_field.vocab, tgt_field.vocab)
    data["pred_lemma"] = [
        "".join(predictor.predict(list(e.word))[:-1])
        for e in data.itertuples()
    ]
    acc = 0
    for word in data.itertuples():
        acc += int(word.pred_lemma == word.lemma)
    acc /= len(data.lemma)
    EXPERIMENT.metric("Dev accuracy", acc)
    data.to_csv("./dev_{}.csv".format("-".join("{}={}".format(k, v)
                                               for k, v in file_props)))
    EXPERIMENT.log("Incorrect predictions")
    EXPERIMENT.log(
        str(data[data["lemma"] != data["pred_lemma"]][[
            "word", "lemma", "pred_lemma"
        ]]))
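
The function above appears to expect a pandas DataFrame with `word` and `lemma` columns plus a global `EXPERIMENT` tracker; a hedged sketch of the kind of frame it iterates over, with the call itself left commented out because `model` and the fields come from the surrounding script:

import pandas as pd

# Hedged sketch of the expected dev data layout (column names inferred from the code).
dev_df = pd.DataFrame({
    "word": ["running", "cats"],
    "lemma": ["run", "cat"],
})
# evaluate_model(model, dev_df, src_field, tgt_field,
#                file_props={"lang": "en", "fold": 1})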
Example #7
def test(expt_dir, checkpoint, test_file, output_file):
    if checkpoint is not None:
        checkpoint_path = os.path.join(expt_dir,
                                       Checkpoint.CHECKPOINT_DIR_NAME,
                                       checkpoint)
        logging.info("loading checkpoint from {}".format(checkpoint_path))
        checkpoint = Checkpoint.load(checkpoint_path)
        seq2seq = checkpoint.model
        input_vocab = checkpoint.input_vocab
        output_vocab = checkpoint.output_vocab
    else:
        raise Exception("checkpoint path does not exist")

    predictor = Predictor(seq2seq, input_vocab, output_vocab)

    output = open(output_file, 'a')

    with open(test_file) as f:
        for line_ in f:
            line = line_.strip().split('<s>')
            if len(line) >= 2:
                question = basic_tokenizer(line[-2])
                answer = predictor.predict(question)[:-1]
                output.write(''.join(answer) + '\n')
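
A hedged invocation sketch for `test`; every path and name below is a placeholder:

# Hedged usage sketch: experiment directory, checkpoint name and files are hypothetical.
test(expt_dir="./experiment",
     checkpoint="2020_01_01_00_00_00",
     test_file="data/test_dialogues.txt",
     output_file="predictions.txt")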
Example #8
                              best_model_dir=opt.best_model_dir,
                              batch_size=opt.batch_size,
                              checkpoint_every=opt.checkpoint_every,
                              print_every=opt.print_every,
                              max_epochs=opt.max_epochs,
                              max_steps=opt.max_steps,
                              max_checkpoints_num=opt.max_checkpoints_num,
                              best_ppl=opt.best_ppl,
                              device=device,
                              multi_gpu=multi_gpu,
                              logger=logger)

        seq2seq = t.train(seq2seq,
                          data=train,
                          start_step=opt.skip_steps,
                          dev_data=dev,
                          optimizer=optimizer,
                          teacher_forcing_ratio=opt.teacher_forcing_ratio)

    elif opt.phase == "infer":
        # Predict
        predictor = Predictor(seq2seq, src_vocab.word2idx, tgt_vocab.idx2word,
                              device)

        while True:
            seq_str = input("Type in a source sequence:")
            seq = seq_str.strip().split()
            ans = predictor.predict_n(seq, n=opt.beam_width) \
                if opt.beam_width > 1 else predictor.predict(seq)
            print(ans)
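
In the upstream pytorch-seq2seq library, `Predictor.predict` returns a single decoded token list while `Predictor.predict_n` returns `n` candidate lists; the variant above is passed raw vocab dicts and a device, so its exact signature may differ. A hedged sketch of the usual calls:

# Hedged sketch against the upstream Predictor API.
tokens = predictor.predict("1 3 5".split())              # one decoded token list
candidates = predictor.predict_n("1 3 5".split(), n=3)   # list of 3 token lists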
Example #9
    # train
    t = SupervisedTrainer(
        loss=loss,
        batch_size=32,
        checkpoint_every=50,
        print_every=10,
        expt_dir=opt.expt_dir,
    )

    seq2seq = t.train(
        seq2seq,
        train,
        num_epochs=6,
        dev_data=dev,
        optimizer=optimizer,
        teacher_forcing_ratio=0.5,
        resume=opt.resume,
    )

evaluator = Evaluator(loss=loss, batch_size=32)
dev_loss, accuracy = evaluator.evaluate(seq2seq, dev)
assert dev_loss < 1.5

beam_search = Seq2seq(seq2seq.encoder, TopKDecoder(seq2seq.decoder, 3))

predictor = Predictor(beam_search, input_vocab, output_vocab)
inp_seq = "1 3 5 7 9"
seq = predictor.predict(inp_seq.split())
assert " ".join(seq[:-1]) == inp_seq[::-1]
Example #10
        # Optimizer and learning rate scheduler can be customized by
        # explicitly constructing the objects and passing them to the trainer.
        #
        # optimizer = Optimizer(torch.optim.Adam(seq2seq.parameters()), max_grad_norm=5)
        # scheduler = StepLR(optimizer.optimizer, 1)
        # optimizer.set_scheduler(scheduler)

    # train
    t = SupervisedTrainer(loss=loss,
                          batch_size=params['batch_size'],
                          checkpoint_every=50,
                          print_every=25,
                          expt_dir=opt.expt_dir,
                          tensorboard=True)

    seq2seq = t.train(seq2seq,
                      train,
                      num_epochs=params['num_epochs'],
                      dev_data=dev,
                      optimizer=optimizer,
                      teacher_forcing_ratio=0.5,
                      resume=opt.resume)

predictor = Predictor(seq2seq, input_vocab, output_vocab)

while True:
    seq_str = raw_input("Type in a source sequence: ")
    seq = seq_str.strip().split()
    print(' '.join((predictor.predict(seq))))
Example #11
def sentence_gen(sen):
    # sen: [max_len, batch_size], a list of predicted token lists
    a = []

    for i in range(len(sen)):
        phrase = ""
        for j in range(len(sen[i])):
            if sen[i][j] != "<eos>":
                #print("printing word :",sen[i][j])
                if sen[i][j] == "." or sen[i][j] == ",":
                    phrase = phrase + sen[i][j]
                else:
                    phrase = phrase + " " + sen[i][j]
            else:
                a.append(phrase)
                break
    return a


for i in range(len(data)):
    seq_str = data.iloc[i]["src"]
    print(seq_str)
    seq = seq_str.strip().split()
    pred.append(predictor.predict(seq))

print(pred)
pred_target = sentence_gen(pred)
print(len(pred_target))
pred_target = pd.DataFrame(pred_target)

pred_target.columns = ["pred"]
pred_target.to_csv("output.csv", sep=",")
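
A toy check of the `sentence_gen` helper defined at the top of this example; the nested list mimics what the prediction loop accumulates in `pred`:

# Hedged sketch: toy predictions ending in '<eos>', as produced by Predictor.predict.
toy_pred = [["i", "am", "fine", ".", "<eos>"],
            ["hello", "world", "<eos>"]]
print(sentence_gen(toy_pred))  # [' i am fine.', ' hello world']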
Example #12
def main():
    '''Main Function'''

    parser = argparse.ArgumentParser(description='sum_file.py')

    parser.add_argument('-model', required=True,
                        help='Path to model .pt file')
    parser.add_argument('-src', required=True,
                        help='Source sequence to decode (one line per sequence)')
    parser.add_argument('-vocab', required=True,
                        help='Path to the preprocessed data file with vocabulary and settings')
    parser.add_argument('-output', default='pred.txt',
                        help="""Path to output the predictions (each line will
                        be the decoded sequence)""")
    parser.add_argument('-beam_size', type=int, default=5,
                        help='Beam size')
    parser.add_argument('-batch_size', type=int, default=30,
                        help='Batch size')
    parser.add_argument('-n_best', type=int, default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")
    parser.add_argument('-no_cuda', action='store_true')

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda

    # Prepare DataLoader
    preprocess_data = torch.load(opt.vocab)
    preprocess_settings = preprocess_data['settings']
    test_src_word_insts = read_instances_from_file(
        opt.src,
        preprocess_settings.max_word_seq_len,
        preprocess_settings.keep_case,
        preprocess_settings.mode)
    test_src_insts = convert_instance_to_idx_seq(
        test_src_word_insts, preprocess_data['dict']['src'])

    # prepare model
    device = torch.device('cuda' if opt.cuda else 'cpu')
    checkpoint = torch.load(opt.model)
    model_opt = checkpoint['settings']
    
    model_opt.bidirectional = True
    encoder = EncoderRNN(model_opt.src_vocab_size, model_opt.max_token_seq_len, model_opt.d_model,
                            bidirectional=model_opt.bidirectional, variable_lengths=True)
    decoder = DecoderRNN(model_opt.tgt_vocab_size, model_opt.max_token_seq_len, model_opt.d_model * 2 if model_opt.bidirectional else model_opt.d_model,
                            n_layers=model_opt.n_layer, dropout_p=model_opt.dropout, use_attention=True, bidirectional=model_opt.bidirectional,
                            sos_id=Constants.BOS, eos_id=Constants.EOS)
    model = Seq2seq(encoder, decoder).to(device)
    model = nn.DataParallel(model)  # wrap in DataParallel because training used it

    model.load_state_dict(checkpoint['model'])
    print('[Info] Trained model state loaded.')

    predictor = Predictor(model, preprocess_data['dict']['tgt'])

    with open(opt.output, 'w') as f:
        for src_seq in tqdm(test_src_insts, mininterval=2, desc='  - (Test)', leave=False):
            pred_line = ' '.join(predictor.predict(src_seq))
            f.write(pred_line + '\n')
    print('[Info] Finished.')
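
A hedged way to exercise `main` without a shell is to patch `sys.argv` before the call; all file paths below are placeholders:

import sys

# Hedged sketch: simulate the command line for main(); paths are hypothetical.
sys.argv = ["sum_file.py",
            "-model", "trained.chkpt",
            "-vocab", "preprocess.data",
            "-src", "test.src.txt",
            "-output", "pred.txt"]
main()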
Example #13
def sample(
    #         train_source,
    #         train_target,
    #         dev_source,
    #         dev_target,
    experiment_directory='/home/xweiwang/RL/seq2seq/experiment',
    checkpoint='2019_05_18_20_32_54',
    resume=True,
    log_level='info',
):
    """
    # Sample usage

        TRAIN_SRC=data/toy_reverse/train/src.txt
        TRAIN_TGT=data/toy_reverse/train/tgt.txt
        DEV_SRC=data/toy_reverse/dev/src.txt
        DEV_TGT=data/toy_reverse/dev/tgt.txt

    ## Training
    ```shell
    $ ./examples/sample.py $TRAIN_SRC $TRAIN_TGT $DEV_SRC $DEV_TGT -expt
    $EXPT_PATH
    ```
    ## Resuming from the latest checkpoint of the experiment
    ```shell
    $ ./examples/sample.py $TRAIN_SRC $TRAIN_TGT $DEV_SRC $DEV_TGT -expt
    $EXPT_PATH -r
    ```
    ## Resuming from a specific checkpoint
    ```shell
    $ python examples/sample.py $TRAIN_SRC $TRAIN_TGT $DEV_SRC $DEV_TGT -expt
    $EXPT_PATH -c $CHECKPOINT_DIR
    ```
    """
    logging.basicConfig(
        format=LOG_FORMAT,
        level=getattr(logging, log_level.upper()),
    )
    #     logging.info('train_source: %s', train_source)
    #     logging.info('train_target: %s', train_target)
    #     logging.info('dev_source: %s', dev_source)
    #     logging.info('dev_target: %s', dev_target)
    logging.info('experiment_directory: %s', experiment_directory)
    logging.info('checkpoint: %s', checkpoint)

    #     if checkpoint:
    seq2seq, input_vocab, output_vocab = load_checkpoint(
        experiment_directory, checkpoint)
    #     else:
    #         seq2seq, input_vocab, output_vocab = train_model(
    #             train_source,
    #             train_target,
    #             dev_source,
    #             dev_target,
    #             experiment_directory,
    #             resume=resume,
    #         )
    predictor = Predictor(seq2seq, input_vocab, output_vocab)
    while True:
        seq_str = input('Type in a source sequence: ')
        seq = seq_str.strip().split()
        print(predictor.predict(seq))
Example #14
def eval_fa_equiv(model, data, input_vocab, output_vocab):
    loss = NLLLoss()
    batch_size = 1

    model.eval()

    loss.reset()
    match = 0
    total = 0

    device = None if torch.cuda.is_available() else -1
    batch_iterator = torchtext.data.BucketIterator(
        dataset=data,
        batch_size=batch_size,
        sort=False,
        sort_key=lambda x: len(x.src),
        device=device,
        train=False)
    tgt_vocab = data.fields[seq2seq.tgt_field_name].vocab
    pad = tgt_vocab.stoi[data.fields[seq2seq.tgt_field_name].pad_token]

    predictor = Predictor(model, input_vocab, output_vocab)

    num_samples = 0
    perfect_samples = 0
    dfa_perfect_samples = 0

    match = 0
    total = 0

    with torch.no_grad():
        for batch in batch_iterator:
            num_samples = num_samples + 1

            input_variables, input_lengths = getattr(batch,
                                                     seq2seq.src_field_name)

            target_variables = getattr(batch, seq2seq.tgt_field_name)

            target_string = decode_tensor(target_variables, output_vocab)

            #target_string = target_string + " <eos>"

            input_string = decode_tensor(input_variables, input_vocab)

            generated_string = ' '.join([
                x for x in predictor.predict(input_string.strip().split())[:-1]
                if x != '<pad>'
            ])

            #str(pos_example)[2]

            generated_string = refine_outout(generated_string)

            #str(pos_example)[2]

            pos_example = subprocess.check_output([
                'python2', 'regexDFAEquals.py', '--gold',
                '{}'.format(target_string), '--predicted',
                '{}'.format(generated_string)
            ])

            if target_string == generated_string:
                perfect_samples = perfect_samples + 1
                dfa_perfect_samples = dfa_perfect_samples + 1
            elif str(pos_example)[2] == '1':
                dfa_perfect_samples = dfa_perfect_samples + 1

            target_tokens = target_string.split()
            generated_tokens = generated_string.split()

            shorter_len = min(len(target_tokens), len(generated_tokens))

            for idx in range(len(generated_tokens)):
                total = total + 1

                if idx >= len(target_tokens):
                    total = total + 1
                elif target_tokens[idx] == generated_tokens[idx]:
                    match = match + 1

            if total == 0:
                accuracy = float('nan')
            else:
                accuracy = match / total

            string_accuracy = perfect_samples / num_samples
            dfa_accuracy = dfa_perfect_samples / num_samples

        with open('./time_logs/log_score_time.txt', 'a') as f:
            f.write('{}\n'.format(dfa_accuracy))
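
`decode_tensor` and `refine_outout` are helpers from the original script that are not shown here; a hedged sketch of what a `decode_tensor` along these lines might look like (the special-token names are assumptions):

def decode_tensor(tensor, vocab):
    # Hedged sketch: map a 1 x seq_len index tensor back to a space-joined string,
    # dropping the usual special tokens (names assumed).
    tokens = [vocab.itos[idx] for idx in tensor.view(-1).tolist()]
    return ' '.join(t for t in tokens if t not in ('<pad>', '<sos>', '<eos>'))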
Example #15
parser.add_argument('--log-level',
                    dest='log_level',
                    default='info',
                    help='Logging level.')

opt = parser.parse_args()

LOG_FORMAT = '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
logging.basicConfig(format=LOG_FORMAT,
                    level=getattr(logging, opt.log_level.upper()))
logging.info(opt)

if opt.load_checkpoint is not None:
    logging.info("loading checkpoint from {}".format(
        os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME,
                     opt.load_checkpoint)))
    checkpoint_path = os.path.join(opt.expt_dir,
                                   Checkpoint.CHECKPOINT_DIR_NAME,
                                   opt.load_checkpoint)
    checkpoint = Checkpoint.load(checkpoint_path)
    seq2seq = checkpoint.model
    input_vocab = checkpoint.input_vocab
    output_vocab = checkpoint.output_vocab

predictor = Predictor(seq2seq, input_vocab, output_vocab)

while True:
    seq_str = raw_input("Type in a source sequence:")
    seq = seq_str.strip().split()
    print(predictor.predict(seq))
Example #16
# topk_predictor = Predictor(seq2top, input_vocab, output_vocab, vectors)

if config['pull embeddings']:
    out_vecs = {}

if config['feat embeddings']:
    feats = {}
    of = open(config['feat output'], 'wb')
    # TODO add option to save output
    src = SourceField()
    feat = SourceField()
    tgt = TargetField()
    # pdb.set_trace()
    for key in tqdm(input_vocab.freqs.keys()):
        try:
            guess, enc_out = predictor.predict([key])
        except:
            print("guess, enc_out = predictor.predict([key]) didn't work")
            pdb.set_trace()
        # TODO first try averaging
        # (Pdb)
        # test[3].mean(-1).shape
        # torch.Size([1, 13])
        # (Pdb)
        # test[3].mean(-2).shape
        # torch.Size([1, 600])
        feats[key] = {}
        feats[key]['src'] = key
        feats[key]['tgt'] = key
        feats[key]['guess'] = ''.join(guess)
        feats[key]['embed'] = enc_out
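
The snippet above opens `of` for the feature output but ends before anything is written; a hedged continuation, assuming a plain pickle dump is acceptable (the original script may serialize differently):

import pickle

# Hedged continuation: persist the collected per-token features; the format is an assumption.
pickle.dump(feats, of)
of.close()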
Example #17
            #
            # optimizer = Optimizer(torch.optim.Adam(seq2seq.parameters()), max_grad_norm=5)
            # scheduler = StepLR(optimizer.optimizer, 1)
            # optimizer.set_scheduler(scheduler)

        # train
        t = SupervisedTrainer(loss=loss,
                              batch_size=batch_size,
                              checkpoint_every=50,
                              print_every=10,
                              expt_dir=opt.expt_dir)

        seq2seq = t.train(seq2seq,
                          train,
                          num_epochs=num_epochs,
                          dev_data=dev,
                          optimizer=optimizer,
                          teacher_forcing_ratio=0.5,
                          resume=opt.resume)

    predictor = Predictor(seq2seq, input_vocab)

    while True:
        seq_str = raw_input("Type in a source sequence:")
        seq_1 = [first_field.SYM_SOS] + seq_str.strip().split() + [first_field.SYM_EOS]
        seq_str = raw_input("Type in a source sequence:")
        seq_2 = [first_field.SYM_SOS] + seq_str.strip().split() + [first_field.SYM_EOS]
        print(predictor.predict([seq_1, seq_2]))
Example #18
            seqs_x.append(seq_x)
            POSs.append(POS)
            rhythms.append(rhythm)
            lengths.append(length)

    return seqs_x, lengths, POSs, rhythms


predictor = Predictor(seq2seq, input_vocab, output_vocab)
seqs_x, lengths, POSs, rhythms = read_dev(opt.dev_path)

preds = []
for i, seq in enumerate(seqs_x):
    print(i)
    seq = seq.strip().split()
    pred = predictor.predict(seq)
    preds.append(pred)

with open(opt.output_path, 'w', encoding='utf8') as f:
    for pred in preds:
        if len(pred) == 3:
            row = ['我']
        else:
            row = pred[1:-2]
        for i in range(len(row)):
            if row[i] == '<unk>':
                row[i] = '我'

        f.write("%s\n" % (' '.join(row)))
Example #19
with torch.no_grad():
    for batch in batch_iterator:
        num_samples = num_samples + 1

        input_variables, input_lengths = getattr(batch, seq2seq.src_field_name)
        target_variables = getattr(batch, seq2seq.tgt_field_name)

        target_string = decode_tensor(target_variables, output_vocab)
        input_string = decode_tensor(input_variables, input_vocab)

        generated_string = ' '.join([x for x in predictor.predict(input_string.strip().split())[:-1]
                                     if x != '<pad>'])

        print("Input string: ", input_string)
        print("Targ   : ", target_string)
        print("Pred   : ", refine_outout(generated_string))

        generated_string = refine_outout(generated_string)

        pos_example = subprocess.check_output([
            'python2', 'regexDFAEquals.py',
            '--gold', '{}'.format(target_string),
            '--predicted', '{}'.format(generated_string)
        ])

        if target_string == generated_string:
            perfect_samples = perfect_samples + 1
            dfa_perfect_samples = dfa_perfect_samples + 1
            print('String Equivalent')
Example #20
        # optimizer = Optimizer(torch.optim.Adam(seq2seq.parameters()), max_grad_norm=5)
        # scheduler = StepLR(optimizer.optimizer, 1)
        # optimizer.set_scheduler(scheduler)

    # train
    t = SupervisedTrainer(loss=loss,
                          batch_size=32,
                          checkpoint_every=50,
                          print_every=10,
                          expt_dir=opt.expt_dir)

    seq2seq = t.train(seq2seq,
                      train,
                      num_epochs=102,
                      dev_data=dev,
                      optimizer=optimizer,
                      teacher_forcing_ratio=0.5,
                      resume=opt.resume)

predictor = Predictor(seq2seq, input_vocab, output_vocab)

while True:
    seq_str = raw_input("Type in a source sequence:")
    seq = seq_str.strip().split()
    prediction = predictor.predict(seq)
    for ind, x in enumerate(prediction):
        if x == 'B':
            print(seq[ind], " ")

    print(predictor.predict(seq))
Example #21
        seq2seq = Seq2seq(encoder, decoder)
        if torch.cuda.is_available():
            seq2seq.cuda()

        for param in seq2seq.parameters():
            param.data.uniform_(-0.08, 0.08)

    # train
    t = SupervisedTrainer(loss=loss,
                          batch_size=10000,
                          checkpoint_every=50,
                          print_every=10,
                          expt_dir=opt.expt_dir)

    seq2seq = t.train(seq2seq,
                      train,
                      num_epochs=10,
                      dev_data=validation,
                      optimizer=optimizer,
                      teacher_forcing_ratio=0.5,
                      resume=opt.resume)

predictor = Predictor(seq2seq, input_vocab, output_vocab)

while True:
    sentence = raw_input("Type in a source sequence:")
    words = sentence_to_words(sentence)

    print(words)
    print(predictor.predict(words))
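
`sentence_to_words` is defined elsewhere in the original script; a hedged sketch of a tokenizer along those lines (lower-casing and splitting on non-word characters are assumptions):

import re

def sentence_to_words(sentence):
    # Hedged sketch: lower-case and split on runs of non-word characters.
    return [w for w in re.split(r"\W+", sentence.lower()) if w]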
Example #22
    t = SupervisedTrainer(loss=loss, batch_size=32,
                          checkpoint_every=50,
                          print_every=10, expt_dir=opt.expt_dir)

    seq2seq = t.train(seq2seq, train,
                      num_epochs=40, dev_data=dev,
                      optimizer=optimizer,
                      teacher_forcing_ratio=0.5,
                      resume=opt.resume)

predictor = Predictor(seq2seq, input_vocab, output_vocab)

#while True:
    #seq_str = raw_input("Type in a source sequence:")
    #seq = seq_str.strip().split()
    #print(predictor.predict(seq))
with open(opt.test_path) as f:
    content = f.readlines()
content = [x.strip() for x in content]

output = ""
for row in content:
    seq_in = row.split("\t")[0]
    seq_out = row.split("\t")[1]
    seq_pred = predictor.predict(seq_in.strip().split())
    seq_pred = " ".join(seq_pred[:-1])
    output += seq_out + "\t" + seq_pred + "\n"
output_file = "seq_pred.txt"
text_train = open(output_file, "w")
text_train.write(output)
text_train.close()
Example #23
class AttentionVisualizer(object):
    """Object for visualizing the attention pattern of a given prediction.

    Args:
        task_path (str): path to the saved checkpoint directory.
        figsize (tuple, optional): (width, height) of the final matplotlib figure.
        decimals (int, optional): number of decimals to show when printing any number.
        is_show_attn_split (bool, optional): whether to show the content and
            positional attention, if there is one, in addition to the full attention.
        is_show_evaluation (bool, optional): whether to show the evaluation metric
            if the target is given.
        output_length_key, attention_key, position_attn_key, content_attn_key (str,
            optional): keys of the respective values in the dictionary returned by
            the prediction.
        positional_table_labels (dictionary, optional): mapping from the keys in the
            return dictionary (the values) to the name under which each should be
            shown in the figure (the keys). The order is the one that will be used
            to plot the table (in Python > 3.6).
        is_show_name (bool, optional): whether to show the name of the model as
            the title of the figure.
        max_src, max_out, max_tgt (int, optional): maximum number of tokens to show
            for the source, the output and the target. Used in order not to clutter
            the plots too much.
        kwargs:
            Additional arguments to `MetricComputer`.
    """
    def __init__(
            self,
            task_path,
            figsize=(15, 13),
            decimals=2,
            is_show_attn_split=True,
            is_show_evaluation=True,
            output_length_key='length',
            attention_key="attention_score",
            position_attn_key='position_attention',
            content_attn_key='content_attention',
            positional_table_labels={
                "λ%": "position_percentage",
                "C.γ": "content_confidence",
                #"lgt": "approx_max_logit",
                "C.λ": "pos_confidence",
                "μ": "mu",
                "σ": "sigma",
                "w_α": "mean_attn_old_weight",
                "w_j/n": "rel_counter_decoder_weight",
                "w_1/n": "single_step_weight",
                "w_μ": "mu_old_weight",
                "w_γ": "mean_content_old_weight",
                "w_1": "bias_weight"
            },
            # "% carry": "carry_rates",
            is_show_name=True,
            max_src=17,
            max_out=13,
            max_tgt=13,
            **kwargs):

        check = Checkpoint.load(task_path)
        self.model = check.model
        # store some interesting variables
        self.model.set_dev_mode()

        self.predictor = Predictor(self.model, check.input_vocab,
                                   check.output_vocab)
        self.model_name = task_path.split("/")[-2]
        self.figsize = figsize
        self.decimals = decimals
        self.is_show_attn_split = is_show_attn_split
        self.is_show_evaluation = is_show_evaluation
        self.positional_table_labels = positional_table_labels
        self.is_show_name = is_show_name

        self.max_src = max_src
        self.max_out = max_out
        self.max_tgt = max_tgt

        self.output_length_key = output_length_key
        self.attention_key = attention_key
        self.position_attn_key = position_attn_key
        self.content_attn_key = content_attn_key

        if self.is_show_evaluation:
            self.is_symbol_rewriting = "symbol rewriting" in task_path.lower()
            self.metric_computer = MetricComputer(
                check, is_symbol_rewriting=self.is_symbol_rewriting, **kwargs)

        if self.model.decoder.is_attention is None:
            raise AttentionException("Model is not using attention.")

    def __call__(self, src_str, tgt_str=None):
        """Plots the attention for the current example.

        Args:
            src_str (str): source of the example.
            tgt_str (str, optional): target of the example; must be given in
                order to show the final metric.
        Returns:
            fig (plt.Figure): plotted attention figure.
        """
        out_words, other = self.predictor.predict(src_str.split())

        full_src_str = src_str
        full_out_str = " ".join(out_words)
        full_tgt_str = tgt_str

        additional, additional_text = self._format_additional(other)
        additional, src_words, out_words, tgt_str = self._subset(
            additional, src_str.split(), out_words, tgt_str)

        if self.is_show_name:
            title = ""
        else:
            title = None

        if tgt_str is not None:
            if self.is_show_name:
                title += "\n tgt_str: {} - ".format(tgt_str)
            else:
                title = "tgt_str: {} - ".format(tgt_str)

            if self.metric_computer.is_predict_eos:
                is_output_wrong_length = (len(full_out_str.split()) !=
                                          len(full_tgt_str.split()))
                if self.is_symbol_rewriting and is_output_wrong_length:
                    warnings.warn(
                        "Cannot currently show the metric for symbol rewriting if output is not the right length."
                    )

                else:
                    metrics = self.metric_computer(full_src_str, full_out_str,
                                                   full_tgt_str)

                    for name, val in metrics.items():
                        title += "{}: {:.2g}  ".format(name, val)
            else:
                warnings.warn(
                    "Cannot currently show the metric in the attention plots when `is_predict_eos=False`"
                )

        if self.attention_key not in additional:
            raise ValueError(
                "`{}` not returned by predictor. Make sure the model uses attention."
                .format(self.attention_key))

        attention = additional[self.attention_key]

        if self.position_attn_key in additional:
            filtered_pos_table_labels = {
                k: v
                for k, v in self.positional_table_labels.items()
                if v in additional
            }
            table_values = np.stack([
                np.around(additional[name], decimals=self.decimals)
                for name in filtered_pos_table_labels.values()
            ]).T

        if self.is_show_attn_split and (self.position_attn_key in additional
                                        and self.content_attn_key
                                        in additional):
            content_attention = additional.get(self.content_attn_key)
            positional_attention = additional.get(self.position_attn_key)

            fig, axs = plt.subplots(2, 2, figsize=self.figsize)
            _plot_attention(src_words,
                            out_words,
                            attention,
                            axs[0, 0],
                            is_colorbar=False,
                            title="Final Attention")
            _plot_table(table_values, list(filtered_pos_table_labels.keys()),
                        axs[0, 1])
            _plot_attention(src_words,
                            out_words,
                            content_attention,
                            axs[1, 0],
                            title="Content Attention")
            _plot_attention(src_words,
                            out_words,
                            positional_attention,
                            axs[1, 1],
                            title="Positional Attention")

        elif self.position_attn_key in additional:
            fig, axs = plt.subplots(1, 2, figsize=self.figsize)
            _plot_attention(src_words,
                            out_words,
                            attention,
                            axs[0],
                            title="Final Attention")
            _plot_table(table_values, list(filtered_pos_table_labels.keys()),
                        axs[1])
        else:
            fig, ax = plt.subplots(1, 1, figsize=self.figsize)
            _plot_attention(src_words,
                            out_words,
                            attention,
                            ax,
                            title="Final Attention")

        fig.text(0.5,
                 0.02,
                 ' | '.join(additional_text),
                 ha='center',
                 va='center',
                 size=13)

        if title is not None:
            plt.suptitle(title, size=13, weight="bold")
        fig.tight_layout()
        fig.subplots_adjust(bottom=0.07, top=0.83)

        return fig

    def _format_additional(self, additional):
        """Format the additinal dictionary returned by the predictor."""
        def _format_carry_rates(carry_rates):
            if carry_rates is None:
                return "Carry % : None"
            mean_carry_rates = np.around(carry_rates.mean().item(),
                                         decimals=self.decimals)
            median_carry_rates = np.around(carry_rates.median().item(),
                                           decimals=self.decimals)
            return "Carry % : mean: {}; median: {}".format(
                mean_carry_rates, median_carry_rates)

        def _format_bb_gates(gates):
            if gates is None:
                return "BB Weight Mean Gates : None"
            mean_gates = np.around(gates.mean(0), decimals=self.decimals)
            return "BB Weight Mean Gates : {}".format(mean_gates)

        def _format_mu_weights(mu_weights):
            if mu_weights is not None:
                building_blocks_labels = self.model.decoder.position_attention.bb_labels
                for i, label in enumerate(building_blocks_labels):
                    output[label + "_weight"] = mu_weights[:, i]

        output = dict()

        # "visualize" is only used for training visualization, not prediction.
        additional.pop("visualize", None)
        additional.pop("losses", None)

        additional_text = []
        additional = flatten_dict(additional)

        output = dict()
        output[self.output_length_key] = additional.pop(
            self.output_length_key)[0]

        for k, v in additional.items():
            tensor = v if isinstance(v, torch.Tensor) else torch.cat(v)
            output[k] = tensor.detach().cpu().numpy().squeeze(
            )[:output[self.output_length_key]]

        carry_txt = _format_carry_rates(additional.pop("carry_rates", None))
        bb_gates_txt = _format_bb_gates(output.pop("bb_gates", None))
        additional_text.append(carry_txt)
        additional_text.append(bb_gates_txt)

        _format_mu_weights(output.pop("mu_weights", None))

        return output, additional_text

    def _subset(self, additional, src_words, out_words, tgt_str=None):
        """Subsets the objects in the additional dictionary in order not to
        clotter the visualization.
        """
        n_src = len(src_words)
        n_out = len(out_words)

        if n_out > self.max_out:
            subset_out = self.max_out // 2
            out_words = out_words[:subset_out] + out_words[-subset_out:]
            for k, v in additional.items():
                if isinstance(v, np.ndarray):
                    additional[k] = np.concatenate(
                        (v[:subset_out], v[-subset_out:]), axis=0)

        if n_src > self.max_src:
            subset_src = self.max_src // 2
            src_words = src_words[:subset_src] + src_words[-subset_src:]
            for k, v in additional.items():
                if isinstance(v, np.ndarray) and v.ndim == 2:
                    additional[k] = np.concatenate(
                        (v[:, :subset_src], v[:, -subset_src:]), axis=1)

        if tgt_str is not None:
            tgt_words = tgt_str.split()
            n_tgt = len(tgt_words)
            if n_tgt > self.max_tgt:
                subset_target = self.max_tgt // 2
                tgt_str = " ".join(tgt_words[:subset_target] + ["..."] +
                                   tgt_words[-subset_target:])

        return additional, src_words, out_words, tgt_str
Example #24
                    help='Logging level.')

args = parser.parse_args()

LOG_FORMAT = '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
logging.basicConfig(format=LOG_FORMAT, level=getattr(logging, args.log_level.upper()))
logging.info(args)

logging.info("loading checkpoint from {}".format(args.trained_model_dir))
checkpoint_path = args.trained_model_dir
checkpoint = Checkpoint.load(checkpoint_path)
seq2seq = checkpoint.model
input_vocab = checkpoint.input_vocab
output_vocab = checkpoint.output_vocab

predictor = Predictor(seq2seq, input_vocab, output_vocab)

with open(args.text_path, mode='r', encoding='utf-8') as file:
    file.readline()
    text = file.read().replace('\n', '')
    sentences = nltk.sent_tokenize(text)
    results = []
    for sentence in sentences:
        words = sentence_to_words(sentence)

        result = predictor.predict(words)
        result.remove('<eos>')
        if result:
            results.extend(result)
    print("\n".join(results))