Example #1
import math

import torch

# _sqd and pad_and_stack_hyps are assumed to be defined elsewhere in
# this module
def single_queue_decode(model,
                        encoder_output,
                        masks,
                        max_output_length,
                        max_hyps=1,
                        labels: dict = None,
                        buffer_size: int = 1,
                        gamma=float("-inf")):

    enc_outs = encoder_output.split()
    ys = []
    gammas = []
    for i, enc_out in enumerate(enc_outs):
        seq_masks = {k: v[i].unsqueeze(0) for k, v in masks.items()}
        if model.is_transformer:
            seq_masks["trg"] = torch.ones(1,
                                          1,
                                          1,
                                          dtype=torch.bool,
                                          device=masks["trg"].device)
        if labels is not None:
            seq_labels = {k: v[i].unsqueeze(0) for k, v in labels.items()}
        else:
            seq_labels = None

        # as in full_traversal, we have a generator that lazily produces
        # hypotheses
        hypotheses = _sqd(model,
                          enc_out,
                          seq_masks,
                          max_output_length,
                          labels=seq_labels,
                          buffer_size=buffer_size,
                          gamma=gamma)
        total_p = 0.0
        current_best = float("-inf")
        generated = []
        for j in range(max_hyps):
            try:
                p, y = next(hypotheses)
            except StopIteration:
                break
            generated.append((p, y))
            total_p += math.exp(p)
            if p > current_best:
                current_best = p

        best_p, best_y = max(generated) if generated else (float("-inf"),
                                                           (model.bos_index, ))
        ys.append(best_y)
        gammas.append(best_p)

    outputs = pad_and_stack_hyps(ys, model.pad_index)
    # any_finished = outputs.eq(3).any(dim=-1)

    scores = torch.tensor(gammas, dtype=torch.float, device=outputs.device)
    # assert scores[~any_finished].eq(float("-inf")).all()

    return outputs, scores
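
For context, a minimal self-contained sketch of the consumption pattern above: hypotheses arrive lazily from a generator in best-first order, and the caller accumulates probability mass while tracking the best one. All names below are illustrative toy stand-ins, not part of the module.

import math

def toy_hypotheses():
    # stand-in for the _sqd generator: yields (log_prob, sequence)
    # pairs in best-first order
    yield -0.5, (2, 7, 3)
    yield -1.6, (2, 9, 3)
    yield -2.3, (2, 7, 9, 3)

generated = []
total_p = 0.0
hyps = toy_hypotheses()
for _ in range(2):  # max_hyps = 2
    try:
        p, y = next(hyps)
    except StopIteration:
        break
    generated.append((p, y))
    total_p += math.exp(p)  # running probability mass of the set

best_p, best_y = max(generated)
print(best_p, best_y, total_p)  # -0.5 (2, 7, 3) 0.808...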
Example #2
import torch

# _sqd and pad_and_stack_hyps are assumed to be defined elsewhere in
# this module
def iterative_deepening(model,
                        encoder_output,
                        masks,
                        max_output_length,
                        labels: dict = None,
                        buffer_size: int = 1,
                        max_hyps=1,
                        verbose=False):
    enc_outs = encoder_output.split()
    ys = []
    gammas = []
    for i, enc_out in enumerate(enc_outs):
        best_y = None
        gamma = float("-inf")
        for max_len in range(max_output_length):
            if verbose:
                print("new max length", max_len)
            seq_masks = {k: v[i].unsqueeze(0) for k, v in masks.items()}
            if model.is_transformer:
                seq_masks["trg"] = torch.ones(1,
                                              1,
                                              1,
                                              dtype=torch.bool,
                                              device=masks["trg"].device)
            if labels is not None:
                seq_labels = {k: v[i].unsqueeze(0) for k, v in labels.items()}
            else:
                seq_labels = None
            hypotheses = _sqd(model,
                              enc_out,
                              masks=seq_masks,
                              max_output_length=max_len,
                              gamma=gamma,
                              labels=seq_labels,
                              buffer_size=buffer_size)
            for j in range(max_hyps):
                try:
                    p, y = next(hypotheses)
                except StopIteration:
                    break
                if p > gamma:
                    gamma = p
                    best_y = y
                    if verbose:
                        y_seq = [model.trg_vocab.itos[y_i] for y_i in best_y]
                        print(gamma, y_seq)
            if verbose:
                print()
        ys.append(best_y)
        gammas.append(gamma)

    outputs = pad_and_stack_hyps(ys, model.pad_index)
    # any_finished = outputs.eq(3).any(dim=-1)

    scores = torch.tensor(gammas, dtype=torch.float, device=outputs.device)
    # assert scores[~any_finished].eq(float("-inf")).all()

    return outputs, scores
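
Each pass of iterative_deepening re-runs the search with a longer length limit, and gamma, the best score found so far, lets the inner search prune anything that cannot beat it. A self-contained toy sketch of that interaction, in which toy_search is a hypothetical stand-in for _sqd:

def toy_search(max_len, gamma, seqs):
    # stand-in for _sqd under a length limit: yield candidates no longer
    # than max_len whose log-prob beats the pruning threshold gamma
    for p, y in seqs:
        if len(y) <= max_len and p > gamma:
            yield p, y

seqs = [(-2.0, (2, 5, 3)), (-0.7, (2, 5, 8, 3)), (-1.1, (2, 3))]
gamma, best_y = float("-inf"), None
for max_len in range(1, 5):
    for p, y in toy_search(max_len, gamma, seqs):
        if p > gamma:
            gamma, best_y = p, y  # raising gamma prunes later passes
print(gamma, best_y)  # -0.7 (2, 5, 8, 3)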
Example #3
import torch

# _depth_first_search and pad_and_stack_hyps are assumed to be defined
# elsewhere in this module
def depth_first_search(model,
                       encoder_output,
                       masks,
                       max_output_length,
                       labels: dict = None):
    # non-recursive wrapper around the recursive function that does DFS
    enc_outs = encoder_output.split()
    ys = []
    gammas = []
    for i, enc_out in enumerate(enc_outs):
        seq_masks = {k: v[i].unsqueeze(0) for k, v in masks.items()}
        if model.is_transformer:
            seq_masks["trg"] = torch.ones(1,
                                          1,
                                          1,
                                          dtype=torch.bool,
                                          device=masks["trg"].device)
        if labels is not None:
            seq_labels = {k: v[i].unsqueeze(0) for k, v in labels.items()}
        else:
            seq_labels = None
        y, gamma = _depth_first_search(
            model,
            enc_out,
            seq_masks,
            max_output_length + 1,  # fixing fencepost?
            labels=seq_labels)
        y = y[1:] if y is not None else None  # strip the BOS token; necessary?
        ys.append(y)
        gammas.append(gamma)

    outputs = pad_and_stack_hyps(ys, model.pad_index)
    # any_finished = outputs.eq(3).any(dim=-1)

    scores = torch.tensor(gammas, dtype=torch.float, device=outputs.device)
    # assert scores[~any_finished].eq(float("-inf")).all()

    return outputs, scores
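
The recursive helper _depth_first_search is not shown in this listing. Purely to illustrate the shape such a helper could take, here is a hypothetical, self-contained DFS over a toy next-token table (TOY_LM, toy_dfs, and the token ids are all made up; 2 plays the role of BOS and 3 of EOS):

import math

# toy next-token log-probs keyed by prefix; unknown prefixes force EOS
TOY_LM = {
    (2,): {5: math.log(0.6), 3: math.log(0.4)},
    (2, 5): {3: math.log(1.0)},
}

def toy_dfs(prefix, logp, max_len):
    # return (best_sequence, best_logp) among completions of prefix
    if prefix[-1] == 3 or len(prefix) >= max_len:
        return prefix, logp
    best_y, best_p = None, float("-inf")
    for tok, lp in TOY_LM.get(prefix, {3: 0.0}).items():
        y, p = toy_dfs(prefix + (tok,), logp + lp, max_len)
        if p > best_p:
            best_y, best_p = y, p
    return best_y, best_p

print(toy_dfs((2,), 0.0, max_len=4))  # ((2, 5, 3), log 0.6)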
Example #4
import math

import torch

# _traverse and pad_and_stack_hyps are assumed to be defined elsewhere
# in this module
def full_traversal(model,
                   encoder_output,
                   masks,
                   max_output_length,
                   max_hyps=1,
                   mode="dfs",
                   labels: dict = None,
                   break_at_argmax=False,
                   break_at_p=1.0):
    assert mode in ["dfs", "bfs"]

    enc_outs = encoder_output.split()
    ys = []
    gammas = []
    n_hyps = []
    for i, enc_out in enumerate(enc_outs):
        seq_masks = {k: v[i].unsqueeze(0) for k, v in masks.items()}
        if model.is_transformer:
            seq_masks["trg"] = torch.ones(1,
                                          1,
                                          1,
                                          dtype=torch.bool,
                                          device=masks["trg"].device)
        if labels is not None:
            seq_labels = {k: v[i].unsqueeze(0) for k, v in labels.items()}
        else:
            seq_labels = None
        hypotheses = _traverse(model,
                               enc_out,
                               seq_masks,
                               max_output_length,
                               mode=mode,
                               labels=seq_labels,
                               prune=break_at_argmax)
        total_p = 0.0
        current_best = float("-inf")
        generated = []
        for j in range(max_hyps):
            try:
                p, y = next(hypotheses)
                # print(p, [model.trg_vocab.itos[y_i] for y_i in y])
            except StopIteration:
                break
            generated.append((p, y))
            total_p += math.exp(p)
            if p > current_best:
                current_best = p
            if total_p > break_at_p or (break_at_argmax and
                                        math.exp(current_best) > 1 - total_p):
                # either the best hypothesis outweighs all unexplored mass
                # (so it is provably the argmax), or the hypothesis set
                # already covers a very large part of the mass
                break
        best_p, best_y = max(generated) if generated else (float("-inf"),
                                                           (model.bos_index, ))
        ys.append(best_y)
        gammas.append(best_p)
        n_hyps.append(len(generated))

    outputs = pad_and_stack_hyps(ys, model.pad_index)
    # any_finished = outputs.eq(3).any(dim=-1)

    scores = torch.tensor(gammas, dtype=torch.float, device=outputs.device)
    # assert scores[~any_finished].eq(float("-inf")).all()

    # idea: return the size of the set of hypotheses

    return outputs, scores
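
The early exit in full_traversal rests on a simple fact: the hypothesis probabilities sum to at most one, so once the best hypothesis found satisfies exp(current_best) > 1 - total_p, no unexplored hypothesis can beat it and it is provably the argmax. A small numeric check with illustrative numbers:

import math

# illustrative log-probs of hypotheses enumerated best-first
log_ps = [math.log(0.5), math.log(0.3), math.log(0.1)]
total_p, current_best, n_explored = 0.0, float("-inf"), 0
for p in log_ps:
    total_p += math.exp(p)
    current_best = max(current_best, p)
    n_explored += 1
    if math.exp(current_best) > 1 - total_p:
        # at most 1 - total_p mass is unexplored, and no single
        # unexplored hypothesis can beat exp(current_best)
        break
print(n_explored, math.exp(current_best), total_p)  # 2 0.5 0.8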
Example #5
from typing import Dict

import torch
from torch import Tensor

# tile and pad_and_stack_hyps are assumed to be defined elsewhere in
# this module
def beam_search(model,
                size: int,
                encoder_output,
                masks: Dict[str, Tensor],
                max_output_length: int,
                scorer,
                labels: dict = None,
                return_scores: bool = False):
    """
    Beam search with size k.

    In each decoding step, find the k most likely partial hypotheses.

    :param decoder:
    :param size: size of the beam
    :param encoder_output:
    :param masks:
    :param max_output_length:
    :param scorer: function for rescoring hypotheses
    :param embed:
    :return:
        - stacked_output: output hypotheses (2d array of indices),
        - stacked_attention_scores: attention scores (3d array)
    """
    transformer = model.is_transformer
    any_mask = next(iter(masks.values()))
    batch_size = any_mask.size(0)

    att_vectors = None  # not used for Transformer
    device = encoder_output.device

    if model.is_ensemble:
        # run model.ensemble_bridge, I guess
        hidden = model.ensemble_bridge(encoder_output)
    else:
        if not transformer and model.decoder.bridge_layer is not None:
            hidden = model.decoder.bridge_layer(encoder_output.hidden)
        else:
            hidden = None

    # tile encoder states and decoder initial states beam_size times
    if hidden is not None:
        # layers x batch*k x dec_hidden_size
        if isinstance(hidden, list):
            hidden = [
                tile(h, size, dim=1) if h is not None else None for h in hidden
            ]
        else:
            hidden = tile(hidden, size, dim=1)

    # encoder_output: batch*k x src_len x enc_hidden_size
    encoder_output.tile(size, dim=0)
    masks = {k: tile(v, size, dim=0) for k, v in masks.items() if k != "trg"}
    masks["trg"] = any_mask.new_ones([1, 1, 1]) if transformer else None

    # numbering elements in the batch
    batch_offset = torch.arange(batch_size, dtype=torch.long, device=device)

    # beam_size copies of each batch element
    beam_offset = torch.arange(0,
                               batch_size * size,
                               step=size,
                               dtype=torch.long,
                               device=device)

    # keeps track of the top beam size hypotheses to expand for each
    # element in the batch to be further decoded (that are still "alive")
    alive_seq = beam_offset.new_full((batch_size * size, 1), model.bos_index)
    prev_y = alive_seq if transformer else alive_seq[:, -1].view(-1, 1)

    # Give full probability to the first beam on the first step.
    # pylint: disable=not-callable
    current_beam = torch.tensor([0.0] + [float("-inf")] * (size - 1),
                                device=device).repeat(batch_size, 1)

    results = {
        "predictions": [[] for _ in range(batch_size)],
        "scores": [[] for _ in range(batch_size)],
        "gold_score": [0] * batch_size
    }

    for step in range(1, max_output_length + 1):

        # decode a single step
        log_probs, hidden, _, att_vectors = model.decode(
            trg_input=prev_y,
            encoder_output=encoder_output,
            masks=masks,
            decoder_hidden=hidden,
            prev_att_vector=att_vectors,
            unroll_steps=1,
            generate="log",
            labels=labels)
        log_probs = log_probs.squeeze(1)  # log_probs: batch*k x trg_vocab

        # multiply probs by the beam probability (=add logprobs)
        raw_scores = log_probs + current_beam.view(-1).unsqueeze(1)

        # flatten log_probs into a list of possibilities
        vocab_size = log_probs.size(-1)  # vocab size
        raw_scores = raw_scores.reshape(-1, size * vocab_size)

        # apply an additional scorer, such as a length penalty
        scores = scorer(raw_scores, step) if scorer is not None else raw_scores

        # pick currently best top k hypotheses (flattened order)
        topk_scores, topk_ids = scores.topk(size, dim=-1)

        # If using a length penalty, scores are distinct from log probs.
        # The beam keeps track of log probabilities regardless
        current_beam = topk_scores if scorer is None \
            else raw_scores.gather(1, topk_ids)

        # reconstruct beam origin and true word ids from flattened order
        topk_beam_index = topk_ids.div(vocab_size, rounding_mode="floor")
        topk_ids = topk_ids.fmod(vocab_size)

        # map beam_index to batch_index in the flat representation
        b_off = beam_offset[:topk_beam_index.size(0)].unsqueeze(1)
        batch_index = topk_beam_index + b_off
        select_ix = batch_index.view(-1)

        # append latest prediction (result: batch_size*k x hyp_len)
        selected_alive_seq = alive_seq.index_select(0, select_ix)
        alive_seq = torch.cat([selected_alive_seq, topk_ids.view(-1, 1)], -1)

        is_finished = topk_ids.eq(model.eos_index)  # batch x beam
        if step == max_output_length:
            is_finished.fill_(1)
        top_finished = is_finished[:, 0].eq(1)  # batch

        # save finished hypotheses
        seq_len = alive_seq.size(-1)
        predictions = alive_seq.view(-1, size, seq_len)
        ix = top_finished.nonzero().view(-1)
        for i in ix:
            finished_scores = topk_scores[i]
            finished_preds = predictions[i, :, 1:]
            b = batch_offset[i]
            # if you desire more hypotheses, you can use topk/sort
            top_score, top_pred_ix = finished_scores.max(dim=0)
            top_pred = finished_preds[top_pred_ix]
            results["scores"][b].append(top_score)
            results["predictions"][b].append(top_pred)

        if top_finished.all():
            break

        # remove finished batches for the next step
        unfinished = top_finished.eq(0).nonzero().view(-1)
        current_beam = current_beam.index_select(0, unfinished)
        batch_index = batch_index.index_select(0, unfinished)
        batch_offset = batch_offset.index_select(0, unfinished)
        alive_seq = predictions.index_select(0, unfinished).view(-1, seq_len)

        # reorder indices, outputs and masks
        select_ix = batch_index.view(-1)
        encoder_output.index_select(select_ix)
        masks = {
            k: v.index_select(0, select_ix) if k != "trg" else v
            for k, v in masks.items()
        }

        if model.is_ensemble:
            if not transformer:
                new_hidden = []
                for h_i in hidden:
                    if isinstance(h_i, tuple):
                        # for LSTMs, states are tuples of tensors
                        h, c = h_i
                        h = h.index_select(1, select_ix)
                        c = c.index_select(1, select_ix)
                        new_h_i = h, c
                    else:
                        # for GRUs, states are single tensors
                        new_h_i = h_i.index_select(1, select_ix)
                    new_hidden.append(new_h_i)
                hidden = new_hidden
        else:
            if hidden is not None and not transformer:
                if isinstance(hidden, tuple):
                    # for LSTMs, states are tuples of tensors
                    h, c = hidden
                    h = h.index_select(1, select_ix)
                    c = c.index_select(1, select_ix)
                    hidden = h, c
                else:
                    # for GRUs, states are single tensors
                    hidden = hidden.index_select(1, select_ix)

        if att_vectors is not None:
            if model.is_ensemble:
                att_vectors = [
                    av.index_select(0, select_ix) if av is not None else None
                    for av in att_vectors
                ]
            else:
                att_vectors = att_vectors.index_select(0, select_ix)

        prev_y = alive_seq if transformer else alive_seq[:, -1].view(-1, 1)

    # is moving to cpu necessary/good?
    final_outputs = pad_and_stack_hyps(
        [r[0].cpu() for r in results["predictions"]], model.pad_index)
    if return_scores:
        final_scores = torch.stack([s[0] for s in results["scores"]])
        return final_outputs, None, final_scores
    else:
        return final_outputs, None, None
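
The trickiest bookkeeping in beam_search is the flattened top-k: the scores of all beams belonging to one batch element are flattened into a single size * vocab_size row, and each winner's beam of origin and token id are recovered with floor division and modulo. A self-contained sketch with toy dimensions (all names are illustrative):

import torch

batch_size, size, vocab_size = 2, 3, 5
# toy scores for each of the batch_size * size beams over the vocabulary
scores = torch.randn(batch_size * size, vocab_size)

# flatten each batch element's size * vocab_size candidates, take top-k
flat = scores.view(batch_size, size * vocab_size)
topk_scores, topk_ids = flat.topk(size, dim=-1)

# recover which beam each winner came from, and which token it is
topk_beam_index = torch.div(topk_ids, vocab_size, rounding_mode="floor")
topk_token_ids = topk_ids.fmod(vocab_size)

# map (batch, beam) pairs back to rows of the flat batch*size layout
beam_offset = torch.arange(0, batch_size * size, step=size)
select_ix = (topk_beam_index + beam_offset.unsqueeze(1)).view(-1)
print(topk_token_ids, select_ix)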