Example 1
    def predict(self, t: str, target_ric: str) -> List[str]:
        # Connect to Postgres
        engine = create_engine(self.config.db_uri)
        SessionMaker = sessionmaker(bind=engine)
        pg_session = SessionMaker()

        rics = (self.config.rics
                if target_ric in self.config.rics
                else [target_ric] + self.config.rics)

        alignments = load_alignments_from_db(pg_session, rics, t,
                                             self.seqtypes)

        # Write the prediction data
        self.config.dir_output.mkdir(parents=True, exist_ok=True)
        dest_alignments = self.config.dir_output / Path(
            'alignment-predict.json')
        with dest_alignments.open(mode='w') as f:
            writer = jsonlines.Writer(f)
            writer.write(alignments.to_dict())

        predict_iter = create_dataset(self.config, self.device, self.vocab,
                                      rics, self.seqtypes)

        self.model.eval()

        # Take the first (and presumably only) batch from the prediction iterator
        batch = next(iter(predict_iter))

        times = batch.time
        tokens = batch.token
        raw_short_field = stringify_ric_seqtype(Code.N225.value,
                                                SeqType.RawShort)
        latest_vals = [x for x in getattr(batch, raw_short_field).data[:, 0]]
        raw_long_field = stringify_ric_seqtype(Code.N225.value,
                                               SeqType.RawLong)
        latest_closing_vals = get_latest_closing_vals(batch, raw_long_field,
                                                      times)

        loss, pred, attn_weight = self.model(batch, batch.batch_size, tokens,
                                             times, self.criterion, Phase.Test)

        # Map predicted token ids back to words, truncating at the first EOS
        # and dropping the leading BOS
        i_eos = self.vocab.stoi[SpecialToken.EOS.value]
        pred_sents = [
            remove_bos([self.vocab.itos[i] for i in takeuntil(i_eos, sent)])
            for sent in zip(*pred)
        ]

        return replace_tags_with_vals(pred_sents[0], latest_closing_vals[0],
                                      latest_vals[0])
Example 2
    def test_takeuntil(self):
        result = list(takeuntil('</s>', ['Nikkei', 'rises', '</s>', '</s>']))
        self.assertEqual(result, ['Nikkei', 'rises', '</s>'])
Example 3
    def test_takeuntil_missing(self):
        s = ['Dream', 'Theater', 'is', 'one', 'of', 'the', 'greatest', 'bands']
        result = list(takeuntil('a', s))
        self.assertEqual(result, s)
Example 4
    def test_takeuntil_is_generator(self):
        result = takeuntil('</s>', ['Nikkei', 'rises', '</s>', '</s>'])
        self.assertIsInstance(result, GeneratorType)
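
Examples 2-4 pin down the expected behavior of takeuntil: it returns a generator, it yields items up to and including the first occurrence of the sentinel, and it yields the whole sequence when the sentinel never appears. Below is a minimal sketch that satisfies those tests (the repository's actual implementation may differ), together with an assumed remove_bos helper and a toy vocabulary, to illustrate the id-to-word decoding idiom used in Examples 1 and 5.

from typing import Iterable, Iterator, List, TypeVar

T = TypeVar('T')


def takeuntil(stop: T, xs: Iterable[T]) -> Iterator[T]:
    # Yield items up to and including the first occurrence of `stop`;
    # if `stop` never appears, yield everything (matches Examples 2-4).
    for x in xs:
        yield x
        if x == stop:
            return


def remove_bos(sent: List[str]) -> List[str]:
    # Assumed behavior: drop a single leading BOS marker ('<s>') if present.
    return sent[1:] if sent and sent[0] == '<s>' else sent


# Toy vocabulary; '<s>'/'</s>' stand in for SpecialToken.BOS/EOS.
itos = ['<s>', '</s>', 'Nikkei', 'rises']
stoi = {tok: i for i, tok in enumerate(itos)}
i_eos = stoi['</s>']

# Time-major predictions for a batch of two sentences, as in Examples 1 and 5:
# step 0 holds the first token id of each sentence, step 1 the second, and so on.
pred = [[0, 0], [2, 3], [3, 1], [1, 1]]

pred_sents = [
    remove_bos([itos[i] for i in takeuntil(i_eos, sent)])
    for sent in zip(*pred)
]
print(pred_sents)  # [['Nikkei', 'rises', '</s>'], ['rises', '</s>']]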
Example 5
def run(X: Iterator, vocab: Vocab, model: EncoderDecoder,
        optimizer: torch.optim.Optimizer,
        criterion: torch.nn.modules.Module, phase: Phase,
        logger: Logger) -> RunResult:

    if phase in [Phase.Valid, Phase.Test]:
        model.eval()
    else:
        model.train()

    numpy.random.seed(SEED)

    accum_loss = 0.0
    all_article_ids = []
    all_gold_sents = []
    all_pred_sents = []
    all_gold_sents_with_number = []
    all_pred_sents_with_number = []
    attn_weights = []

    for batch in X:

        article_ids = batch.article_id
        times = batch.time
        tokens = batch.token
        raw_short_field = stringify_ric_seqtype(Code.N225.value,
                                                SeqType.RawShort)
        latest_vals = [x for x in getattr(batch, raw_short_field).data[:, 0]]
        raw_long_field = stringify_ric_seqtype(Code.N225.value,
                                               SeqType.RawLong)
        latest_closing_vals = get_latest_closing_vals(batch, raw_long_field,
                                                      times)
        max_n_tokens, _ = tokens.size()

        # Forward
        loss, pred, attn_weight = model(batch, batch.batch_size, tokens, times,
                                        criterion, phase)

        if phase == Phase.Train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if isinstance(model.decoder.attn, Attention):
            attn_weight = numpy.array(list(zip(*attn_weight)))
            attn_weights.extend(attn_weight)

        all_article_ids.extend(article_ids)

        i_eos = vocab.stoi[SpecialToken.EOS.value]
        # Recover gold sentences from token ids for evaluation,
        # truncating at the first EOS and dropping the leading BOS
        gold_sents = [
            remove_bos([vocab.itos[i] for i in takeuntil(i_eos, sent)])
            for sent in zip(*tokens.cpu().numpy())
        ]
        all_gold_sents.extend(gold_sents)

        # Decode the predicted token ids into words in the same way
        pred_sents = [
            remove_bos([vocab.itos[i] for i in takeuntil(i_eos, sent)])
            for sent in zip(*pred)
        ]
        all_pred_sents.extend(pred_sents)

        if phase == Phase.Test:
            z_iter = zip(article_ids, gold_sents, pred_sents, latest_vals,
                         latest_closing_vals)
            for (article_id, gold_sent, pred_sent, latest_val,
                 latest_closing_val) in z_iter:

                bleu = sentence_bleu(
                    [gold_sent],
                    pred_sent,
                    smoothing_function=SmoothingFunction().method1)

                gold_sent_num = replace_tags_with_vals(gold_sent,
                                                       latest_closing_val,
                                                       latest_val)
                all_gold_sents_with_number.append(gold_sent_num)

                pred_sent_num = replace_tags_with_vals(pred_sent,
                                                       latest_closing_val,
                                                       latest_val)
                all_pred_sents_with_number.append(pred_sent_num)

                description = \
                    '\n'.join(['=== {} ==='.format(phase.value.upper()),
                               'Article ID: {}'.format(article_id),
                               'Gold (tag): {}'.format(', '.join(gold_sent)),
                               'Gold (num): {}'.format(', '.join(gold_sent_num)),
                               'Pred (tag): {}'.format(', '.join(pred_sent)),
                               'Pred (num): {}'.format(', '.join(pred_sent_num)),
                               'BLEU: {:.5f}'.format(bleu),
                               'Loss: {:.5f}'.format(loss.item() / max_n_tokens),
                               'Latest: {:.2f}'.format(latest_val),
                               'Closing: {:.2f}'.format(latest_closing_val)])
                logger.info(description)  # TODO: info → debug in release

        accum_loss += loss.item() / max_n_tokens

    return RunResult(accum_loss, all_article_ids, all_gold_sents,
                     all_gold_sents_with_number, all_pred_sents,
                     all_pred_sents_with_number)
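
During the test phase, run scores each predicted sentence against its gold sentence with NLTK's sentence-level BLEU, using SmoothingFunction().method1 so that hypotheses with no higher-order n-gram matches do not score exactly zero. A standalone sketch of that call, with made-up token lists standing in for gold_sent and pred_sent:

from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu

# Made-up token lists; in `run` these come from the decoded batch.
gold_sent = ['Nikkei', 'rises', 'sharply', '</s>']
pred_sent = ['Nikkei', 'rises', '</s>']

bleu = sentence_bleu([gold_sent], pred_sent,
                     smoothing_function=SmoothingFunction().method1)
print('BLEU: {:.5f}'.format(bleu))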