Example #1
    def validate(self):
        rec_losses = []
        bleus = []

        for batch in self.val_dataloader:
            batch.src, batch.trg = cudable(batch.src), cudable(batch.trg)
            # CE loss
            rec_loss = self.loss_on_batch(batch)
            rec_losses.append(rec_loss.item())

            # BLEU
            encs, enc_mask = self.transformer.encoder(batch.src)
            preds = InferenceState({
                'model': self.transformer.decoder,
                'inputs': encs,
                'enc_mask': enc_mask,
                'vocab': self.vocab_trg,
                'max_len': 50
            }).inference()
            preds = itos_many(preds, self.vocab_trg)
            gold = itos_many(batch.trg, self.vocab_trg)
            bleu = compute_bleu_for_sents(preds, gold)
            bleus.append(bleu)

        self.writer.add_scalar('val/rec_loss', np.mean(rec_losses), self.num_iters_done)
        self.writer.add_scalar('val/bleu', np.mean(bleus), self.num_iters_done)
        self.losses['val_bleu'].append(np.mean(bleus))

        # Note: `preds` and `gold` here hold only the last validation batch
        texts = ['Translation: {}\n\n Gold: {}'.format(t, g) for t, g in zip(preds, gold)]
        text = '\n\n ================== \n\n'.join(texts[:10])
        self.writer.add_text('Samples', text, self.num_iters_done)
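Every example on this page calls `itos_many`, whose implementation is not shown. A minimal sketch of what it presumably does, mapping a batch of token-index sequences back to strings through the vocab (the torchtext-style `vocab.itos` lookup and the `<eos>`/`<pad>` token names are assumptions, not the project's actual code):

def itos_many(seqs, vocab, sep=' '):
    # Sketch only: decode each index sequence until EOS/padding is hit
    sents = []
    for seq in seqs:
        tokens = []
        for idx in seq:
            token = vocab.itos[int(idx)]
            if token in ('<eos>', '<pad>'):
                break  # stop at end-of-sequence or padding
            tokens.append(token)
        sents.append(sep.join(tokens))
    return sents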
Example #2
def predict(lines):
    lines = tokenize(lines)
    # Build a dataset from the lines, then batch it
    src, trg = generate_dataset_with_middle_chars(lines)
    examples = [Example.fromlist([m, o], fields) for m, o in zip(src, trg)]
    ds = Dataset(examples, fields)
    dataloader = data.BucketIterator(ds, batch_size, repeat=False, shuffle=False)

    word_translations = []

    for batch in dataloader:
        # Generating predictions
        batch.src = cudable(batch.src)
        batch.trg = cudable(batch.trg)
        morphs = morph_chars_idx(batch.trg, field.vocab)  # morphological features (see Example #6)
        morphs = cudable(torch.from_numpy(morphs).float())
        first_chars_embs = decoder.embed(batch.trg[:, :n_first_chars])

        z = encoder(batch.src)
        z = merge_z(torch.cat([z, morphs], dim=1))  # condition the latent code on morphology
        z = decoder.gru(first_chars_embs, z.unsqueeze(0))[1].squeeze(0)  # warm up the decoder on the given first chars
        out = simple_inference(decoder, z, field.vocab, max_len=30)

        first_chars = batch.trg[:, :n_first_chars].cpu().numpy().tolist()
        results = [s + p for s, p in zip(first_chars, out)]
        results = itos_many(results, field.vocab, sep='')

        word_translations.extend(results)

    # Regroup word-level outputs back into the original sentences
    transfered = group_by_lens(word_translations, [len(s.split()) for s in lines])
    transfered = [mix_transfered(o, t) for o, t in zip(lines, transfered)]

    return transfered
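A hypothetical call to this function; the input lines are made up for illustration and the model globals (`encoder`, `decoder`, etc.) are assumed to be already loaded:

# Hypothetical usage; input sentences are illustrative only.
lines = ['one two three', 'four five']
translations = predict(lines)
# group_by_lens restores the original sentence grouping, so each
# input line maps to exactly one output line:
assert len(translations) == len(lines)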
Example #3
    def validate(self):
        sources = []
        generated = []
        batches = list(self.val_dataloader)

        for batch in random.sample(batches, self.config.display_k_val_examples):
            src = cudable(batch.text)
            results = self.lm.inference(src.squeeze(), self.vocab,
                                        eos_token=self.eos, max_len=250)
            results = itos_many([results], self.vocab, sep='')

            generated.extend(results)
            sources.extend(itos_many(batch.text, self.vocab, sep=''))

        texts = ['`{} => {}`'.format(s, g) for s, g in zip(sources, generated)]
        text = '\n\n'.join(texts)

        self.writer.add_text('Samples', text, self.num_iters_done)
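One caveat the snippet leaves implicit: sampling like this is normally run under `torch.no_grad()` so no autograd graph is built during validation. This is general PyTorch practice, not something the original code shows, and the trainer instance below is hypothetical:

import torch

with torch.no_grad():  # skip gradient bookkeeping while sampling
    trainer.validate()  # `trainer` is a hypothetical instance of this class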
Example #4
def transfer_style(transfer_style_on_batch, dataloader, vocab, sep=' '):
    """
    Produces predictions for a given dataloader
    """
    domain_x_to_domain_y = []
    domain_y_to_domain_x = []
    domain_x_to_domain_x = []
    domain_y_to_domain_y = []
    gold_domain_x = []
    gold_domain_y = []

    for batch in dataloader:
        x2y, y2x, x2x, y2y = transfer_style_on_batch(batch)

        domain_x_to_domain_y.extend(x2y)
        domain_y_to_domain_x.extend(y2x)
        domain_x_to_domain_x.extend(x2x)
        domain_y_to_domain_y.extend(y2y)

        gold_domain_x.extend(batch.domain_x.detach().cpu().numpy().tolist())
        gold_domain_y.extend(batch.domain_y.detach().cpu().numpy().tolist())

    # Converting to sentences
    x2y_sents = itos_many(domain_x_to_domain_y, vocab, sep=sep)
    y2x_sents = itos_many(domain_y_to_domain_x, vocab, sep=sep)
    x2x_sents = itos_many(domain_x_to_domain_x, vocab, sep=sep)
    y2y_sents = itos_many(domain_y_to_domain_y, vocab, sep=sep)
    gx_sents = itos_many(gold_domain_x, vocab, sep=sep)
    gy_sents = itos_many(gold_domain_y, vocab, sep=sep)

    return x2y_sents, y2x_sents, x2x_sents, y2y_sents, gx_sents, gy_sents
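A sketch of how the six returned lists might be consumed downstream, reusing `compute_bleu_for_sents` from Example #1. The `transfer_fn` closure and the dataloader name are assumptions, not code from the project:

# Hypothetical downstream usage; transfer_fn is assumed to be a closure
# over the trained models, e.g. a trainer's transfer_style_on_batch.
x2y, y2x, x2x, y2y, gx, gy = transfer_style(transfer_fn, val_dataloader, vocab)
print('x->y BLEU against gold y:', compute_bleu_for_sents(x2y, gy))
print('x->x reconstruction BLEU:', compute_bleu_for_sents(x2x, gx))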
Example #5
    def predict(sentences: List[str],
                n_lines: int,
                temperature: float = None,
                max_len: int = None):
        "For each sentence, generates `n_lines` lines sequentially to form a dialog"
        dialogs = list(sentences)  # Copy so we don't mutate the original list
        batch_size = len(dialogs)
        temperature = temperature or DEFAULT_TEMPERATURE
        max_len = max_len or DEFAULT_MAX_LINE_LEN

        for _ in range(n_lines):
            examples = [
                Example.fromlist([EOS_TOKEN.join(d)], [('text', field)])
                for d in dialogs
            ]
            dataset = Dataset(examples, [('text', field)])
            dataloader = data.BucketIterator(dataset,
                                             batch_size,
                                             shuffle=False,
                                             repeat=False)
            batch = next(iter(dataloader))  # There is a single batch
            # pad_first is set, so truncating on the left keeps the most recent context
            text = cudable(batch.text[:, -MAX_CONTEXT_SIZE:])

            if model_cls_name == 'CharLMFromEmbs':
                z = lm.init_z(text.size(0), 1)
                z = lm(z, text, return_z=True)[1]
            elif model_cls_name == 'ConditionalLM':
                z = cudable(torch.zeros(2, len(text), 2048))
                z = lm(z, text, style=1, return_z=True)[1]
            elif model_cls_name == 'WeightedLMEnsemble':
                z = cudable(torch.zeros(2, 1, len(text), 4096))
                z = lm(z, text, return_z=True)[1]
            else:
                embs = lm.embed(text)
                z = lm.gru(embs)[1]

            next_lines = InferenceState({
                'model': lm,
                'inputs': z,
                'vocab': field.vocab,
                'max_len': max_len,
                'bos_token': EOS_TOKEN,  # We start inferring a new reply when we see EOS
                'eos_token': EOS_TOKEN,
                'temperature': temperature,
                'sample_type': 'sample',
                'inputs_batch_dim': 1 if model_cls_name != 'WeightedLMEnsemble' else 2,
                'substitute_inputs': True,
                'kwargs': inference_kwargs
            }).inference()

            next_lines = itos_many(next_lines, field.vocab, sep='')
            next_lines = [slice_unfinished_sentence(l) for l in next_lines]
            dialogs = [d + EOS_TOKEN + l for d, l in zip(dialogs, next_lines)]

        dialogs = [d.split(EOS_TOKEN) for d in dialogs]
        dialogs = [[s for s in d if len(s) != 0] for d in dialogs]
        dialogs = [assign_speakers(d) for d in dialogs]

        return dialogs
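A hypothetical call to this method; the prompts are illustrative only and the surrounding model globals (`lm`, `field`, etc.) are assumed to be loaded:

# Hypothetical usage; prompts are illustrative only.
dialogs = predict(['Hello there.', 'Nice weather today.'], n_lines=3)
for dialog in dialogs:
    print(dialog)  # a list of speaker-assigned lines grown from the prompt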
Example #6
def morph_chars_idx(chars_idx, vocab):
    # Decode char-index matrices back into words, then featurize each word
    words = itos_many(chars_idx, vocab, sep='')
    out = [word_to_onehot_features(w) for w in words]

    return np.stack(out)
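`word_to_onehot_features` is not shown on this page. A purely hypothetical stand-in that returns a fixed-length binary feature vector, which is all `np.stack` requires; the project's real morphological feature set is unknown:

import numpy as np

VOWELS = set('aeiouy')

def word_to_onehot_features(word: str) -> np.ndarray:
    # Purely hypothetical features; the real feature set is project-specific.
    return np.array([
        float(word[:1].isupper()),                   # capitalized
        float(len(word) > 6),                        # long word
        float(word.endswith('s')),                   # plural-ish suffix
        float(sum(c in VOWELS for c in word) >= 3),  # vowel-heavy
    ], dtype=np.float32)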