Example #1
def beam_search(decoder: Decoder,
                size: int,
                bos_index: int,
                eos_index: int,
                pad_index: int,
                encoder_output: Tensor,
                encoder_hidden: Tensor,
                src_mask: Tensor,
                max_output_length: int,
                alpha: float,
                embed: Embeddings,
                n_best: int = 1) -> (np.array, np.array):
    """
    Beam search with size k. Follows OpenNMT-py implementation.
    In each decoding step, find the k most likely partial hypotheses.

    :param decoder: decoder that predicts one step at a time
    :param size: size of the beam
    :param bos_index: index of the begin-of-sequence token
    :param eos_index: index of the end-of-sequence token
    :param pad_index: index of the padding token
    :param encoder_output: encoder states to attend to
    :param encoder_hidden: last encoder state, used to initialize the decoder
    :param src_mask: mask marking valid source positions
    :param max_output_length: maximum number of decoding steps
    :param alpha: `alpha` factor for length penalty
    :param embed: target embedding layer
    :param n_best: return this many hypotheses, <= beam size
    :return:
        - stacked_output: output hypotheses (2d array of indices),
        - stacked_attention_scores: attention scores (3d array)
    """
    # init
    batch_size = src_mask.size(0)
    # pylint: disable=protected-access
    hidden = decoder._init_hidden(encoder_hidden)

    # tile hidden decoder states and encoder output beam_size times
    hidden = tile(hidden, size, dim=1)  # layers x batch*k x dec_hidden_size
    att_vectors = None

    encoder_output = tile(encoder_output.contiguous(), size,
                          dim=0)  # batch*k x src_len x enc_hidden_size

    src_mask = tile(src_mask, size, dim=0)  # batch*k x 1 x src_len

    batch_offset = torch.arange(batch_size,
                                dtype=torch.long,
                                device=encoder_output.device)
    beam_offset = torch.arange(0,
                               batch_size * size,
                               step=size,
                               dtype=torch.long,
                               device=encoder_output.device)
    alive_seq = torch.full([batch_size * size, 1],
                           bos_index,
                           dtype=torch.long,
                           device=encoder_output.device)

    # Give full probability to the first beam on the first step.
    # pylint: disable=not-callable
    topk_log_probs = (torch.tensor(
        [0.0] + [float("-inf")] * (size - 1),
        device=encoder_output.device).repeat(batch_size))

    # Structure that holds finished hypotheses.
    hypotheses = [[] for _ in range(batch_size)]

    results = {}
    results["predictions"] = [[] for _ in range(batch_size)]
    results["scores"] = [[] for _ in range(batch_size)]
    results["gold_score"] = [0] * batch_size

    for step in range(max_output_length):
        decoder_input = alive_seq[:, -1].view(-1, 1)

        # expand current hypotheses
        # decode one single step
        # out: logits for final softmax
        # pylint: disable=unused-variable
        out, hidden, att_scores, att_vectors = decoder(
            encoder_output=encoder_output,
            encoder_hidden=encoder_hidden,
            src_mask=src_mask,
            trg_embed=embed(decoder_input),
            hidden=hidden,
            prev_att_vector=att_vectors,
            unrol_steps=1)

        log_probs = F.log_softmax(out,
                                  dim=-1).squeeze(1)  # batch*k x trg_vocab

        # multiply probs by the beam probability (=add logprobs)
        log_probs += topk_log_probs.view(-1).unsqueeze(1)
        curr_scores = log_probs

        # compute length penalty
        if alpha > -1:
            length_penalty = ((5.0 + (step + 1)) / 6.0)**alpha
            curr_scores /= length_penalty

        # flatten log_probs into a list of possibilities
        curr_scores = curr_scores.reshape(-1, size * decoder.output_size)

        # pick currently best top k hypotheses (flattened order)
        topk_scores, topk_ids = curr_scores.topk(size, dim=-1)

        if alpha > -1:
            # recover original log probs
            topk_log_probs = topk_scores * length_penalty
        else:
            topk_log_probs = topk_scores.clone()

        # reconstruct beam origin and true word ids from flattened order
        topk_beam_index = topk_ids.floor_divide(decoder.output_size)
        topk_ids = topk_ids.fmod(decoder.output_size)

        # map beam_index to batch_index in the flat representation
        batch_index = (topk_beam_index +
                       beam_offset[:topk_beam_index.size(0)].unsqueeze(1))
        select_indices = batch_index.view(-1)

        # append latest prediction
        alive_seq = torch.cat(
            [alive_seq.index_select(0, select_indices),
             topk_ids.view(-1, 1)], -1)  # batch_size*k x hyp_len

        is_finished = topk_ids.eq(eos_index)
        if step + 1 == max_output_length:
            is_finished.fill_(1)
        # end condition is whether the top beam is finished
        end_condition = is_finished[:, 0].eq(1)

        # save finished hypotheses
        if is_finished.any():
            predictions = alive_seq.view(-1, size, alive_seq.size(-1))
            for i in range(is_finished.size(0)):
                b = batch_offset[i]
                if end_condition[i]:
                    is_finished[i].fill_(1)
                finished_hyp = is_finished[i].nonzero().view(-1)
                # store finished hypotheses for this batch
                for j in finished_hyp:
                    hypotheses[b].append(
                        (topk_scores[i,
                                     j], predictions[i, j,
                                                     1:])  # ignore start_token
                    )
                # if the batch reached the end, save the n_best hypotheses
                if end_condition[i]:
                    best_hyp = sorted(hypotheses[b],
                                      key=lambda x: x[0],
                                      reverse=True)
                    for n, (score, pred) in enumerate(best_hyp):
                        if n >= n_best:
                            break
                        results["scores"][b].append(score)
                        results["predictions"][b].append(pred)
            non_finished = end_condition.eq(0).nonzero().view(-1)
            # if all sentences are translated, no need to go further
            # pylint: disable=len-as-condition
            if len(non_finished) == 0:
                break
            # remove finished batches for the next step
            topk_log_probs = topk_log_probs.index_select(0, non_finished)
            batch_index = batch_index.index_select(0, non_finished)
            batch_offset = batch_offset.index_select(0, non_finished)
            alive_seq = predictions.index_select(0, non_finished) \
                .view(-1, alive_seq.size(-1))

        # reorder indices, outputs and masks
        select_indices = batch_index.view(-1)
        encoder_output = encoder_output.index_select(0, select_indices)
        src_mask = src_mask.index_select(0, select_indices)

        if isinstance(hidden, tuple):
            # for LSTMs, states are tuples of tensors
            h, c = hidden
            h = h.index_select(1, select_indices)
            c = c.index_select(1, select_indices)
            hidden = (h, c)
        else:
            # for GRUs, states are single tensors
            hidden = hidden.index_select(1, select_indices)

        att_vectors = att_vectors.index_select(0, select_indices)

    def pad_and_stack_hyps(hyps, pad_value):
        filled = np.ones(
            (len(hyps), max([h.shape[0]
                             for h in hyps])), dtype=int) * pad_value
        for j, h in enumerate(hyps):
            for k, i in enumerate(h):
                filled[j, k] = i
        return filled

    # from results to stacked outputs
    assert n_best == 1
    # only works for n_best=1 for now
    final_outputs = pad_and_stack_hyps(
        [r[0].cpu().numpy() for r in results["predictions"]],
        pad_value=pad_index)

    # TODO also return attention scores and probabilities
    return final_outputs, None
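All of these examples rely on a `tile` helper that is not shown. Judging from the calls and the shape comments above, it repeats each batch element `size` times along the given dimension so that the k beam copies of one sentence sit next to each other. A minimal sketch of that behaviour (the real helper may well be implemented differently, e.g. via permute/view/repeat):

import torch


def tile(x, count: int, dim: int = 0):
    """Repeat every element of `x` `count` times along `dim`,
    e.g. [a, b] -> [a, a, b, b] for count=2, dim=0."""
    if isinstance(x, tuple):  # e.g. LSTM states (h, c)
        return tuple(tile(t, count, dim=dim) for t in x)
    return torch.repeat_interleave(x, count, dim=dim)


# usage as in the search above (shapes as in the comments there):
# encoder_output = tile(encoder_output.contiguous(), size, dim=0)
# src_mask = tile(src_mask, size, dim=0)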
Example #2
def beam_search(model: Model,
                size: int,
                encoder_output: Tensor,
                encoder_hidden: Tensor,
                src_mask: Tensor,
                max_output_length: int,
                alpha: float,
                n_best: int = 1) -> (np.array, np.array):
    """
    Beam search with size k.
    Inspired by OpenNMT-py, adapted for Transformer.

    In each decoding step, find the k most likely partial hypotheses.

    :param model: model with encoder, decoder and embeddings
    :param size: size of the beam
    :param encoder_output: encoder states to attend to
    :param encoder_hidden: last encoder state, used to initialize recurrent decoders
    :param src_mask: mask marking valid source positions
    :param max_output_length: maximum number of decoding steps
    :param alpha: `alpha` factor for length penalty
    :param n_best: return this many hypotheses, <= beam (currently only 1)
    :return:
        - stacked_output: output hypotheses (2d array of indices),
        - stacked_attention_scores: attention scores (3d array)
    """
    assert size > 0, 'Beam size must be >0.'
    assert n_best <= size, 'Can only return {} best hypotheses.'.format(size)

    # init
    bos_index = model.bos_index
    eos_index = model.eos_index
    pad_index = model.pad_index
    trg_vocab_size = model.decoder.output_size
    device = encoder_output.device
    transformer = isinstance(model.decoder, TransformerDecoder)
    batch_size = src_mask.size(0)
    att_vectors = None  # not used for Transformer
    hidden = None  # not used for Transformer
    trg_mask = None  # not used for RNN

    # Recurrent models only: initialize RNN hidden state
    # pylint: disable=protected-access
    if not transformer:
        # tile encoder states and decoder initial states beam_size times
        hidden = model.decoder._init_hidden(encoder_hidden)
        hidden = tile(hidden, size,
                      dim=1)  # layers x batch*k x dec_hidden_size
        # DataParallel splits the batch along dim 0,
        # so move the batch dimension to the front here.
        if isinstance(hidden, tuple):
            h, c = hidden
            hidden = (h.permute(1, 0, 2), c.permute(1, 0, 2))
        else:
            hidden = hidden.permute(1, 0, 2)
            # batch*k x layers x dec_hidden_size

    encoder_output = tile(encoder_output.contiguous(), size,
                          dim=0)  # batch*k x src_len x enc_hidden_size
    src_mask = tile(src_mask, size, dim=0)  # batch*k x 1 x src_len

    # Transformer only: create target mask
    if transformer:
        trg_mask = src_mask.new_ones([1, 1, 1])  # transformer only
        if isinstance(model, torch.nn.DataParallel):
            trg_mask = torch.stack(
                [src_mask.new_ones([1, 1]) for _ in model.device_ids])

    # numbering elements in the batch
    batch_offset = torch.arange(batch_size, dtype=torch.long, device=device)

    # numbering elements in the extended batch, i.e. beam size copies of each
    # batch element
    beam_offset = torch.arange(0,
                               batch_size * size,
                               step=size,
                               dtype=torch.long,
                               device=device)

    # keeps track of the top beam size hypotheses to expand for each element
    # in the batch to be further decoded (that are still "alive")
    alive_seq = torch.full([batch_size * size, 1],
                           bos_index,
                           dtype=torch.long,
                           device=device)

    # Give full probability to the first beam on the first step.
    topk_log_probs = torch.zeros(batch_size, size, device=device)
    topk_log_probs[:, 1:] = float("-inf")

    # Structure that holds finished hypotheses.
    hypotheses = [[] for _ in range(batch_size)]

    results = {
        "predictions": [[] for _ in range(batch_size)],
        "scores": [[] for _ in range(batch_size)],
        "gold_score": [0] * batch_size,
    }

    for step in range(max_output_length):
        # This decides which part of the predicted sentence we feed to the
        # decoder to make the next prediction.
        # For Transformer, we feed the complete predicted sentence so far.
        # For Recurrent models, only feed the previous target word prediction
        if transformer:  # Transformer
            decoder_input = alive_seq  # complete prediction so far
        else:  # Recurrent
            decoder_input = alive_seq[:, -1].view(-1, 1)  # only the last word

        # expand current hypotheses
        # decode one single step
        # logits: logits for final softmax
        # pylint: disable=unused-variable
        with torch.no_grad():
            logits, hidden, att_scores, att_vectors = model(
                return_type="decode",
                encoder_output=encoder_output,
                encoder_hidden=None,  # used to initialize decoder_hidden only
                src_mask=src_mask,
                trg_input=decoder_input,  #trg_embed = embed(decoder_input)
                decoder_hidden=hidden,
                att_vector=att_vectors,
                unroll_steps=1,
                trg_mask=trg_mask  # subsequent mask for Transformer only
            )

        # For the Transformer we made predictions for all time steps up to
        # this point, so we only want to know about the last time step.
        if transformer:
            logits = logits[:, -1]  # keep only the last time step
            hidden = None  # we don't need to keep it for transformer

        # batch*k x trg_vocab
        log_probs = F.log_softmax(logits, dim=-1).squeeze(1)

        # multiply probs by the beam probability (=add logprobs)
        log_probs += topk_log_probs.view(-1).unsqueeze(1)
        curr_scores = log_probs.clone()

        # compute length penalty
        if alpha > -1:
            length_penalty = ((5.0 + (step + 1)) / 6.0)**alpha
            curr_scores /= length_penalty

        # flatten log_probs into a list of possibilities
        curr_scores = curr_scores.reshape(-1, size * trg_vocab_size)

        # pick currently best top k hypotheses (flattened order)
        topk_scores, topk_ids = curr_scores.topk(size, dim=-1)

        if alpha > -1:
            # recover original log probs
            topk_log_probs = topk_scores * length_penalty
        else:
            topk_log_probs = topk_scores.clone()

        # reconstruct beam origin and true word ids from flattened order
        topk_beam_index = topk_ids.floor_divide(trg_vocab_size)
        topk_ids = topk_ids.fmod(trg_vocab_size)

        # map beam_index to batch_index in the flat representation
        batch_index = (topk_beam_index +
                       beam_offset[:topk_beam_index.size(0)].unsqueeze(1))
        select_indices = batch_index.view(-1)

        # append latest prediction
        alive_seq = torch.cat(
            [alive_seq.index_select(0, select_indices),
             topk_ids.view(-1, 1)], -1)  # batch_size*k x hyp_len

        is_finished = topk_ids.eq(eos_index)
        if step + 1 == max_output_length:
            is_finished.fill_(True)
        # end condition is whether the top beam is finished
        end_condition = is_finished[:, 0].eq(True)

        # save finished hypotheses
        if is_finished.any():
            predictions = alive_seq.view(-1, size, alive_seq.size(-1))
            for i in range(is_finished.size(0)):
                b = batch_offset[i]
                if end_condition[i]:
                    is_finished[i].fill_(1)
                finished_hyp = is_finished[i].nonzero(as_tuple=False).view(-1)
                # store finished hypotheses for this batch
                for j in finished_hyp:
                    # Check if the prediction has more than one EOS.
                    # If it has more than one EOS, it means that the
                    # prediction should have already been added to
                    # the hypotheses, so you don't have to add them again.
                    if (predictions[i, j, 1:]
                            == eos_index).nonzero(as_tuple=False).numel() < 2:
                        # ignore start_token
                        hypotheses[b].append(
                            (topk_scores[i, j], predictions[i, j, 1:]))
                # if the batch reached the end, save the n_best hypotheses
                if end_condition[i]:
                    best_hyp = sorted(hypotheses[b],
                                      key=lambda x: x[0],
                                      reverse=True)
                    for n, (score, pred) in enumerate(best_hyp):
                        if n >= n_best:
                            break
                        results["scores"][b].append(score)
                        results["predictions"][b].append(pred)
            non_finished = end_condition.eq(False).nonzero(
                as_tuple=False).view(-1)
            # if all sentences are translated, no need to go further
            # pylint: disable=len-as-condition
            if len(non_finished) == 0:
                break
            # remove finished batches for the next step
            topk_log_probs = topk_log_probs.index_select(0, non_finished)
            batch_index = batch_index.index_select(0, non_finished)
            batch_offset = batch_offset.index_select(0, non_finished)
            alive_seq = predictions.index_select(0, non_finished) \
                .view(-1, alive_seq.size(-1))

        # reorder indices, outputs and masks
        select_indices = batch_index.view(-1)
        encoder_output = encoder_output.index_select(0, select_indices)
        src_mask = src_mask.index_select(0, select_indices)

        if hidden is not None and not transformer:
            if isinstance(hidden, tuple):
                # for LSTMs, states are tuples of tensors
                h, c = hidden
                h = h.index_select(0, select_indices)
                c = c.index_select(0, select_indices)
                hidden = (h, c)
            else:
                # for GRUs, states are single tensors
                hidden = hidden.index_select(0, select_indices)

        if att_vectors is not None:
            att_vectors = att_vectors.index_select(0, select_indices)

    def pad_and_stack_hyps(hyps, pad_value):
        filled = np.ones(
            (len(hyps), max([h.shape[0]
                             for h in hyps])), dtype=int) * pad_value
        for j, h in enumerate(hyps):
            for k, i in enumerate(h):
                filled[j, k] = i
        return filled

    # from results to stacked outputs
    assert n_best == 1
    # only works for n_best=1 for now
    final_outputs = pad_and_stack_hyps(
        [r[0].cpu().numpy() for r in results["predictions"]],
        pad_value=pad_index)

    return final_outputs, None
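The `alpha > -1` branch above implements a GNMT-style length penalty, lp(t) = ((5 + t) / 6) ** alpha: candidate scores are divided by the penalty before the top-k selection, and the selected scores are multiplied by the same factor again to recover plain summed log probabilities for the next step. A small self-contained check of that round trip (all names here are illustrative):

import torch

alpha = 1.0
step = 4                                   # zero-based decoding step
length_penalty = ((5.0 + (step + 1)) / 6.0) ** alpha

log_probs = torch.randn(2, 6)              # pretend batch*k x vocab accumulated log probs
curr_scores = log_probs / length_penalty   # penalized scores used for topk
topk_scores, topk_ids = curr_scores.topk(3, dim=-1)

# multiplying by the penalty recovers the original log probabilities,
# which is what the search keeps in `topk_log_probs`
assert torch.allclose(topk_scores * length_penalty,
                      log_probs.gather(-1, topk_ids))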
Example #3
def beam_search(decoder: Decoder,
                generator: Gen,
                size: int,
                bos_index: int,
                eos_index: int,
                pad_index: int,
                encoder_output: Tensor,
                encoder_hidden: Tensor,
                src_mask: Tensor,
                max_output_length: int,
                alpha: float,
                embed: Embeddings,
                n_best: int = 1,
                knowledgebase: Tuple = None) -> (np.array, np.array, np.array):
    """
    Beam search with size k.
    Inspired by OpenNMT-py, adapted for Transformer.

    In each decoding step, find the k most likely partial hypotheses.

    :param decoder: decoder that predicts one step at a time
    :param generator: output layer that maps decoder states (and KB scores) to
        vocabulary log probabilities
    :param size: size of the beam
    :param bos_index: index of the begin-of-sequence token
    :param eos_index: index of the end-of-sequence token
    :param pad_index: index of the padding token
    :param encoder_output: encoder states to attend to
    :param encoder_hidden: last encoder state, used to initialize the decoder
    :param src_mask: mask marking valid source positions
    :param max_output_length: maximum number of decoding steps
    :param alpha: `alpha` factor for length penalty
    :param embed: target embedding layer
    :param n_best: return this many hypotheses, <= beam
    :param knowledgebase: tuple of knowledgebase tensors (keys, values, value
        embeddings and mask) used during decoding
    :return:
        - stacked_output: output hypotheses (2d array of indices),
        - stacked_attention_scores: attention scores (3d array)
        - stacked_kb_att_scores: kb attention scores (3d array)
    """

    with torch.no_grad():
        # run the initializations below without gradient tracking;
        # this should prevent sporadic CUDA errors

        # init
        transformer = isinstance(decoder, TransformerDecoder)
        batch_size = src_mask.size(0)
        att_vectors = None  # not used for Transformer

        # Recurrent models only: initialize RNN hidden state
        # pylint: disable=protected-access
        if not transformer:
            hidden = decoder._init_hidden(encoder_hidden)
        else:
            hidden = None

        # tile encoder states and decoder initial states beam_size times
        if hidden is not None:
            hidden = tile(hidden, size,
                          dim=1)  # layers x batch*k x dec_hidden_size

        encoder_output = tile(encoder_output.contiguous(), size,
                              dim=0)  # batch*k x src_len x enc_hidden_size
        src_mask = tile(src_mask, size, dim=0)  # batch*k x 1 x src_len

        # Transformer only: create target mask
        if transformer:
            trg_mask = src_mask.new_ones([1, 1, 1])  # transformer only
        else:
            trg_mask = None

        # numbering elements in the batch
        batch_offset = torch.arange(batch_size,
                                    dtype=torch.long,
                                    device=encoder_output.device)

        # numbering elements in the extended batch, i.e. beam size copies of each
        # batch element
        beam_offset = torch.arange(0,
                                   batch_size * size,
                                   step=size,
                                   dtype=torch.long,
                                   device=encoder_output.device)

        # keeps track of the top beam size hypotheses to expand for each element
        # in the batch to be further decoded (that are still "alive")
        alive_seq = torch.full([batch_size * size, 1],
                               bos_index,
                               dtype=torch.long,
                               device=encoder_output.device)

        # Give full probability to the first beam on the first step.
        # pylint: disable=not-callable
        topk_log_probs = torch.zeros(batch_size,
                                     size,
                                     device=encoder_output.device)
        topk_log_probs[:, 1:] = float("-inf")

        # Structure that holds finished hypotheses in order of completion.
        hypotheses = [[] for _ in range(batch_size)]

        results = {}

        results["predictions"] = [[] for _ in range(batch_size)]
        results["scores"] = [[] for _ in range(batch_size)]
        results["att_scores"] = [[] for _ in range(batch_size)]
        results["kb_att_scores"] = [[] for _ in range(batch_size)]

        # kb task: also tile kb tensors along batch dimension as done with other inputs above
        if knowledgebase is not None:
            kb_values = tile(knowledgebase[1], size, dim=0)
            kb_mask = tile(knowledgebase[-1], size, dim=0)
            kb_values_embed = tile(knowledgebase[2], size, dim=0)

            kb_size = kb_values.size(1)
            kb_keys = knowledgebase[0]

            if isinstance(kb_keys, tuple):
                kb_keys = tuple(
                    [tile(key_dim, size, dim=0) for key_dim in kb_keys])
            else:
                kb_keys = tile(kb_keys, size, dim=0)

            # empty running attention histories: batch*k x src_len x 0 time steps
            att_alive = torch.empty(batch_size * size,
                                    encoder_output.size(1),
                                    0,
                                    dtype=torch.float32,
                                    device=encoder_output.device)

            # batch*k x KB x 0 time steps
            kb_att_alive = torch.empty(batch_size * size,
                                       kb_size,
                                       0,
                                       dtype=torch.float32,
                                       device=encoder_output.device)

            debug_tnsrs = (kb_values, kb_mask, kb_values_embed,
                           (kb_keys if isinstance(kb_keys, torch.Tensor) else
                            kb_keys[0]), alive_seq)
            assert set([t.size(0) for t in debug_tnsrs
                        ]) == set([batch_size * size
                                   ]), [t.shape for t in debug_tnsrs]

            stacked_attention_scores = [[] for _ in range(batch_size)]
            stacked_kb_att_scores = [[] for _ in range(batch_size)]

            util_dims_cache = None
            kb_feed_hidden_cache = None

        else:
            kb_keys, kb_values, kb_mask = None, None, None
            kb_size = None
            att_alive = None
            kb_att_alive = None
            stacked_attention_scores, stacked_kb_att_scores = None, None

    for step in range(max_output_length):

        # This decides which part of the predicted sentence we feed to the
        # decoder to make the next prediction.
        # For Transformer, we feed the complete predicted sentence so far.
        # For Recurrent models, only feed the previous target word prediction
        if transformer:  # Transformer
            decoder_input = alive_seq  # complete prediction so far
        else:  # Recurrent
            decoder_input = alive_seq[:, -1].view(-1, 1)  # only the last word

        # expand current hypotheses
        # decode one single step
        # pylint: disable=unused-variable
        trg_embed = embed(decoder_input)

        hidden, att_scores, att_vectors, kb_scores, util_dims_cache, kb_feed_hidden_cache = decoder(
            encoder_output=encoder_output,
            encoder_hidden=encoder_hidden,
            src_mask=src_mask,
            trg_embed=trg_embed,
            hidden=hidden,
            prev_att_vector=att_vectors,
            unroll_steps=1,
            trg_mask=trg_mask,  # subsequent mask for Transformer only
            kb_keys=kb_keys,  # None by default 
            kb_mask=kb_mask,
            kb_values_embed=kb_values_embed,
            util_dims_cache=util_dims_cache,
            kb_feed_hidden_cache=kb_feed_hidden_cache)

        try:
            # generator applies output layer, biases towards KB values, then applies log_softmax
            log_probs = generator(att_vectors,
                                  kb_values=kb_values,
                                  kb_probs=kb_scores)
        except Exception as e:
            print(kb_scores.shape)
            print(kb_mask_before_index)
            print(kb_mask_after_index)
            raise e

        # hidden = ?? x batch*k x dec hidden        #FIXME why 3 ??????
        # att_scores = batch*k x 1 x src_len        #TODO  Find correct beam in dim 0 at every timestep.
        # att_vectors = batch*k x 1 x dec hidden
        # kb_scores = batch*k x 1 x KB              #TODO  find correct beam in dim 0 at every timestep
        # log_probs = batch*k x 1 x trg_voc

        # For the Transformer we made predictions for all time steps up to
        # this point, so we only want to know about the last time step.
        if transformer:
            log_probs = log_probs[:, -1]  # keep only the last time step
            hidden = None  # we don't need to keep it for transformer

        # batch * k x trg_vocab
        log_probs = log_probs.squeeze(1)

        # multiply probs by the probability of each beam thus far ( = add logprobs)
        try:
            log_probs += topk_log_probs.view(-1).unsqueeze(1)
        except Exception as e:
            dbg_tnsrs = [
                hidden, att_scores, att_vectors, kb_scores, util_dims_cache,
                kb_feed_hidden_cache
            ]
            print([t.shape for t in dbg_tnsrs if isinstance(t, torch.Tensor)])
            print(
                [t.size(0) for t in dbg_tnsrs if isinstance(t, torch.Tensor)])
            print(step)
            print(encoder_output.shape)
            print(select_indices)
            print(batch_index)
            print(non_finished)
            print(non_finished.shape)
            print(batch_size * size)
            raise e
        curr_scores = log_probs

        # compute length penalty
        if alpha > -1:
            length_penalty = ((5.0 + (step + 1)) / 6.0)**alpha
            curr_scores /= length_penalty

        # flatten log_probs into a list of possibilities
        curr_scores = curr_scores.reshape(
            -1, size * generator.output_size)  # batch x k * voc FIXME

        # pick currently best top k hypotheses (flattened order)
        topk_scores, topk_ids = curr_scores.topk(size,
                                                 dim=-1)  # each: batch x k

        if alpha > -1:
            # recover original log probs
            topk_log_probs = topk_scores * length_penalty  # b x k
        else:
            topk_log_probs = topk_scores.clone()

        # reconstruct beam origin and true word ids from flattened order:
        # topk was taken over the flattened beam*vocab dimension, so dividing
        # by the vocab size gives the beam of origin and the remainder gives
        # the token id within the vocabulary
        topk_beam_index = (topk_ids // generator.output_size).to(
            dtype=torch.int64)
        topk_ids = topk_ids.fmod(generator.output_size)

        # map beam_index to batch_index in the flat representation
        batch_index = (topk_beam_index +
                       beam_offset[:topk_beam_index.size(0)].unsqueeze(1))
        select_indices = batch_index.view(-1)  # batch * k

        # append latest prediction
        alive_seq = torch.cat(
            [
                alive_seq.index_select(
                    0, select_indices
                ),  # index first dim (batch * k) with the beams we want to continue this step
                topk_ids.view(-1, 1)
            ],
            -1)  # batch_size*k x hyp_len

        if knowledgebase is not None:
            # print(f"kb_att_alive.shape: {kb_att_alive.shape}")
            # print(f"kb_size: {kb_size}")
            # print(kb_att_alive.index_select(0,select_indices).shape)
            # print(kb_scores.transpose(1,2).index_select(0,select_indices).shape)

            if att_scores is not None:
                # FIXME sometimes this way sometimes the other idk
                try:
                    att_alive = torch.cat(  # batch * k x src len x time
                        [
                            att_alive.index_select(0, select_indices),
                            att_scores.transpose(1, 2).index_select(
                                0, select_indices).contiguous()
                        ], -1)
                except Exception as e:
                    print(f"step: {step}")
                    print(select_indices)
                    print(f"att_alive.shape: {att_alive.shape}")
                    print(f"encoder steps: {encoder_output.size(1)}")
                    print(
                        att_scores.transpose(1, 2).index_select(
                            0, select_indices).shape)
                    raise e

            kb_att_alive = torch.cat(  # batch * k x KB x time
                [
                    kb_att_alive.index_select(0, select_indices),
                    kb_scores.transpose(1, 2).index_select(
                        0, select_indices).contiguous()
                ], -1)

        # which batches are finished?
        is_finished = topk_ids.eq(eos_index)  # batch x k
        if step + 1 == max_output_length:
            # force finish
            is_finished.fill_(True)
        # end condition is whether the top beam of given batch is finished
        end_condition = is_finished[:, 0].eq(True)

        # save finished hypotheses if any of the batches finished
        if is_finished.any():

            predictions = alive_seq.view(
                -1, size, alive_seq.size(-1))  # batch x k x time

            for i in range(is_finished.size(0)):  # iter over batches

                b = batch_offset[i]
                if end_condition[i]:
                    # this batch finished
                    is_finished[i].fill_(True)

                finished_hyp = is_finished[i].nonzero(as_tuple=False).view(
                    -1)  # k

                # store finished hypotheses for this batch
                # (that doesn't mean the batch is completely finished,
                # hence the list 'hypotheses' is maintained outside the unroll loop)
                for j in finished_hyp:  # iter over finished beams

                    # first time EOS appears in this beam, save it as hypothesis
                    # (also save attentions here)
                    if (predictions[i, j, 1:]
                            == eos_index).nonzero(as_tuple=False).numel() < 2:
                        hypotheses[b].append((
                            topk_scores[
                                i, j],  # for sorting beams by prob (below)
                            predictions[i, j, 1:])  # ignore BOS token 
                                             )
                        if knowledgebase is not None:

                            # batch x k x src len x time
                            if 0 not in att_alive.shape:
                                # at least one attention matrix has been inserted
                                attentions = att_alive.view(
                                    -1, size, att_alive.size(-2),
                                    att_alive.size(-1))
                                stacked_attention_scores[b].append(
                                    attentions[i, j].cpu().numpy())
                            else:
                                attentions = None

                            # batch x k x KB x time
                            kb_attentions = kb_att_alive.view(
                                -1, size, kb_att_alive.size(-2),
                                kb_att_alive.size(-1))

                            stacked_kb_att_scores[b].append(
                                kb_attentions[i, j].cpu().numpy())

                # if the batch reached the end, save the n best hypotheses (and their attentions and kb attentions)
                if end_condition[i]:
                    # (hypotheses[b] is list of the completed hypotheses of this batch in order of completion => find out which is best)
                    # (stacked_attention_scores[b] and stacked_kb_att_scores[b] are also in order of completion)

                    # which beam is best?
                    best_hyps_descending = sorted(hypotheses[b],
                                                  key=lambda x: x[0],
                                                  reverse=True)

                    dbg = np.array(
                        [hyp[1].cpu().numpy() for hyp in best_hyps_descending])
                    print(dbg.shape, dbg[0])

                    if knowledgebase is not None:

                        print(hypotheses[b][0], type(hypotheses[b][0]))

                        scores, hyps = zip(*hypotheses[b])
                        sort_key = np.array(scores)
                        hyps = np.array([hyp.cpu().numpy() for hyp in hyps])

                        # indices that would sort hyp[b] in descending order of beam score
                        best_hyps_idx = np.argsort(sort_key)[::-1].copy()
                        best_hyps_d_ = hyps[best_hyps_idx]

                        # sanity check implementation
                        try:
                            assert set([(t1 == t2).all()
                                        for t1, t2 in zip(best_hyps_d_, dbg)
                                        ]) == {True}
                        except Exception as e:
                            print(best_hyps_d_.dtype)
                            print(dbg.dtype)
                            print([[t.dtype for t in tup]
                                   for tup in (best_hyps_d_, dbg)])
                            raise e

                        assert n_best == 1, f"This is a massive clutch: Currently indexing only top 1 beam while saving attentions"

                        # FIXME TODO NOTE XXX

                        if 0 not in att_alive.shape:
                            best_atts_d_ = [
                                stacked_attention_scores[b][best_hyps_idx[0]]
                            ]
                        else:
                            best_atts_d_ = None
                        best_kb_atts_d_ = [
                            stacked_kb_att_scores[b][best_hyps_idx[0]]
                        ]

                    # TODO replace best_hyps_descending with best_hyps_d_ FIXME XXX (after cluster beam test)
                    for n, (score, pred) in enumerate(best_hyps_descending):
                        if n >= n_best:
                            break
                        results["scores"][b].append(score)
                        results["predictions"][b].append(pred)

                        if knowledgebase is not None:
                            if best_atts_d_ is not None:
                                results["att_scores"][b].append(
                                    best_atts_d_[n])
                            results["kb_att_scores"][b].append(
                                best_kb_atts_d_[n])

            non_finished = end_condition.eq(False).nonzero(
                as_tuple=False).view(-1)  # batch
            # if all sentences are translated, no need to go further
            # pylint: disable=len-as-condition
            if len(non_finished) == 0:
                break

            # remove finished batches for the next step
            batch_index = batch_index.index_select(0, non_finished)
            batch_offset = batch_offset.index_select(0, non_finished)

            topk_log_probs = topk_log_probs.index_select(0, non_finished)
            alive_seq = predictions.index_select(0, non_finished) \
                .view(-1, alive_seq.size(-1))

            if knowledgebase is not None:

                # briefly go to
                # batch x k x time x att
                # to easily index_select finished batches in batch dimension 0

                # afterwards reshape to
                # batch * k x time x att

                # where att = src_len for alive attentions, and att = kb_size for kb_attentions alive

                if 0 not in att_alive.shape:
                    att_alive = att_alive.view(-1, size, att_alive.size(-2), att_alive.size(-1)) \
                        .index_select(0, non_finished)
                    att_alive = att_alive.view(-1, att_alive.size(-2),
                                               att_alive.size(-1))

                kb_att_alive = kb_att_alive.view(-1, size, kb_att_alive.size(-2), kb_att_alive.size(-1)) \
                        .index_select(0, non_finished)
                kb_att_alive = kb_att_alive.view(-1, kb_att_alive.size(-2),
                                                 kb_att_alive.size(-1))

        # reorder indices, outputs and masks using this
        select_indices = batch_index.view(-1)

        encoder_output = encoder_output.index_select(0, select_indices)
        src_mask = src_mask.index_select(0, select_indices)  # for transformer

        if hidden is not None and not transformer:
            # reshape hidden to correct shape for next step
            if isinstance(hidden, tuple):
                # for LSTMs, states are tuples of tensors
                h, c = hidden
                h = h.index_select(1, select_indices)
                c = c.index_select(1, select_indices)
                hidden = (h, c)
            else:
                # for GRUs, states are single tensors
                hidden = hidden.index_select(1, select_indices)

        if att_vectors is not None:

            if isinstance(att_vectors, tuple):
                att_vectors = tuple([
                    att_v.index_select(0, select_indices)
                    for att_v in att_vectors
                ])
            else:
                att_vectors = att_vectors.index_select(0, select_indices)

        if knowledgebase is not None:

            kb_values = kb_values.index_select(0, select_indices)

            if isinstance(kb_keys, tuple):
                kb_keys = tuple([
                    key_dim.index_select(0, select_indices)
                    for key_dim in kb_keys
                ])
            else:
                kb_keys = kb_keys.index_select(0, select_indices)

            if util_dims_cache is not None:
                util_dims_cache = [
                    utils.index_select(0, select_indices)
                    for utils in util_dims_cache if utils is not None
                ]
            if kb_feed_hidden_cache is not None:
                try:
                    kb_feed_hidden_cache = [
                        kbf_hidden.index_select(0, select_indices)
                        for kbf_hidden in kb_feed_hidden_cache
                        if kbf_hidden is not None
                    ]
                except IndexError as IE:
                    print(hidden[0].shape)
                    print([t.shape for t in kb_feed_hidden_cache])
                    print(select_indices)
                    print(select_indices.shape)
                    print(size)
                    print(generator.output_size)
                    raise IE
            kb_mask_before_index = kb_mask.shape
            kb_mask = kb_mask.index_select(0, select_indices)
            kb_mask_after_index = kb_mask.shape

    def pad_and_stack_hyps(hyps, pad_value):
        # hyps is arrays of hypotheses
        filled = np.ones(
            (len(hyps), max([h.shape[0]
                             for h in hyps])), dtype=int) * pad_value
        for j, h in enumerate(hyps):
            for k, i in enumerate(h):
                filled[j, k] = i
        return filled

    def pad_and_stack_attention_matrices(atts, pad_value=float("-inf")):
        assert len(list(set([att.shape[1] for att in atts]))) == 1, \
            f"attention matrices have differing attention key bag dimension: {[att.shape[1] for att in atts]}"
        # atts is array of attention matrices, each of dims time x att_dim, where time dims may vary from matrix to matrix
        # NOTE pad_value is used in model.postprocess to recover original part of matrix
        try:
            filled = np.ones(
                (len(atts), max([att.shape[-2]
                                 for att in atts]), atts[0].shape[-1]),
                dtype=atts[0].dtype)
            filled = filled * pad_value
        except Exception as e:
            print(atts[0].shape)
            raise e
        for batch_element_index, attention_matrix in enumerate(atts):
            for t, attentions_at_decoding_step in enumerate(attention_matrix):
                for attention_key, score in enumerate(
                        attentions_at_decoding_step):

                    filled[batch_element_index, t, attention_key] = score
        return filled  # b x time x attention keys

    # from results to stacked outputs
    assert n_best == 1
    # only works for n_best=1 for now

    # final_outputs = batch x time
    final_outputs = pad_and_stack_hyps(
        [r[0].cpu().numpy() for r in results["predictions"]],
        pad_value=pad_index)

    if knowledgebase is not None:
        # TODO FIXME confirm this implementation

        # stacked_attention_scores: batch x max output len x src len
        if len(results["att_scores"][0]):
            stacked_attention_scores = pad_and_stack_attention_matrices(
                [atts[0].T for atts in results["att_scores"]])
        else:
            stacked_attention_scores = None

        # stacked_kb_att_scores: batch x max output len x kb
        stacked_kb_att_scores = pad_and_stack_attention_matrices(
            [kb_atts[0].T for kb_atts in results["kb_att_scores"]])

    return final_outputs, stacked_attention_scores, stacked_kb_att_scores
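The floor-division and modulo step in the examples above can be verified on a toy tensor: `topk` is taken over the flattened beam-times-vocabulary dimension, so its indices range over `size * vocab_size`; dividing by the vocabulary size recovers the beam a candidate extends, and the remainder is the token id (the sizes below are made up):

import torch

size, vocab_size = 3, 5                 # beam size and vocabulary size
# one batch element; a score for every (beam, token) pair, flattened
curr_scores = torch.arange(size * vocab_size, dtype=torch.float).view(1, -1)

topk_scores, topk_ids = curr_scores.topk(size, dim=-1)
beam_index = topk_ids // vocab_size     # which beam the candidate came from
token_ids = topk_ids.fmod(vocab_size)   # which vocabulary entry it adds

assert topk_ids.tolist() == [[14, 13, 12]]
assert beam_index.tolist() == [[2, 2, 2]]   # best scores all sit in beam 2 here
assert token_ids.tolist() == [[4, 3, 2]]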
Example #4
 def tile(self, size, dim):
     self.states = {
         k: tile(v.contiguous(), size, dim=dim)
         for k, v in self.states.items()
     }
Example #5
 def tile(self, size, dim):
     self.states = tile(self.states.contiguous(), size, dim=dim)
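Examples #4 and #5 look like `tile` methods on two encoder-output wrappers: one keeps a dict of state tensors, the other a single tensor. Purely as an illustration of how such a wrapper would be used before the search in Example #6 (the class name and fields below are hypothetical, not the actual API):

import torch


class EncoderOutput:
    """Hypothetical container for named encoder state tensors."""

    def __init__(self, states: dict):
        self.states = states

    def tile(self, size: int, dim: int) -> None:
        # repeat every stored tensor `size` times along `dim`, in place
        self.states = {
            k: torch.repeat_interleave(v.contiguous(), size, dim=dim)
            for k, v in self.states.items()
        }


enc = EncoderOutput({"src": torch.randn(2, 7, 16)})  # batch x src_len x hidden
enc.tile(5, dim=0)                                   # -> batch*beam x src_len x hidden
assert enc.states["src"].size(0) == 10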
Example #6
def beam_search(model,
                size: int,
                encoder_output,
                masks: Dict[str, Tensor],
                max_output_length: int,
                scorer,
                labels: dict = None,
                return_scores: bool = False):
    """
    Beam search with size k.

    In each decoding step, find the k most likely partial hypotheses.

    :param model: model (possibly an ensemble) with encoder, decoder and embeddings
    :param size: size of the beam
    :param encoder_output: encoder states to attend to
    :param masks: dictionary of masks; any "trg" entry is rebuilt inside
    :param max_output_length: maximum number of decoding steps
    :param scorer: function for rescoring hypotheses (e.g. a length penalty)
    :param labels: optional labels passed through to decoding
    :param return_scores: also return the score of each returned hypothesis
    :return:
        - stacked_output: output hypotheses (2d array of indices),
        - stacked_attention_scores: attention scores (currently None),
        - final_scores: hypothesis scores if ``return_scores`` is set, else None
    """
    transformer = model.is_transformer
    any_mask = next(iter(masks.values()))
    batch_size = any_mask.size(0)

    att_vectors = None  # not used for Transformer
    device = encoder_output.device

    if model.is_ensemble:
        # run model.ensemble_bridge, I guess
        hidden = model.ensemble_bridge(encoder_output)
    else:
        if not transformer and model.decoder.bridge_layer is not None:
            hidden = model.decoder.bridge_layer(encoder_output.hidden)
        else:
            hidden = None

    # tile encoder states and decoder initial states beam_size times
    if hidden is not None:
        # layers x batch*k x dec_hidden_size
        if isinstance(hidden, list):
            hidden = [
                tile(h, size, dim=1) if h is not None else None for h in hidden
            ]
        else:
            hidden = tile(hidden, size, dim=1)

    # encoder_output: batch*k x src_len x enc_hidden_size
    encoder_output.tile(size, dim=0)
    masks = {k: tile(v, size, dim=0) for k, v in masks.items() if k != "trg"}
    masks["trg"] = any_mask.new_ones([1, 1, 1]) if transformer else None

    # numbering elements in the batch
    batch_offset = torch.arange(batch_size, dtype=torch.long, device=device)

    # beam_size copies of each batch element
    beam_offset = torch.arange(0,
                               batch_size * size,
                               step=size,
                               dtype=torch.long,
                               device=device)

    # keeps track of the top beam size hypotheses to expand for each
    # element in the batch to be further decoded (that are still "alive")
    alive_seq = beam_offset.new_full((batch_size * size, 1), model.bos_index)
    prev_y = alive_seq if transformer else alive_seq[:, -1].view(-1, 1)

    # Give full probability to the first beam on the first step.
    # pylint: disable=not-callable
    current_beam = torch.tensor([0.0] + [float("-inf")] * (size - 1),
                                device=device).repeat(batch_size, 1)

    results = {
        "predictions": [[] for _ in range(batch_size)],
        "scores": [[] for _ in range(batch_size)],
        "gold_score": [0] * batch_size
    }

    for step in range(1, max_output_length + 1):

        # decode a single step
        log_probs, hidden, _, att_vectors = model.decode(
            trg_input=prev_y,
            encoder_output=encoder_output,
            masks=masks,
            decoder_hidden=hidden,
            prev_att_vector=att_vectors,
            unroll_steps=1,
            generate="log",
            labels=labels)
        log_probs = log_probs.squeeze(1)  # log_probs: batch*k x trg_vocab

        # multiply probs by the beam probability (=add logprobs)
        raw_scores = log_probs + current_beam.view(-1).unsqueeze(1)

        # flatten log_probs into a list of possibilities
        vocab_size = log_probs.size(-1)  # vocab size
        raw_scores = raw_scores.reshape(-1, size * vocab_size)

        # apply an additional scorer, such as a length penalty
        scores = scorer(raw_scores, step) if scorer is not None else raw_scores

        # pick currently best top k hypotheses (flattened order)
        topk_scores, topk_ids = scores.topk(size, dim=-1)

        # If using a length penalty, scores are distinct from log probs.
        # The beam keeps track of log probabilities regardless
        current_beam = topk_scores if scorer is None \
            else raw_scores.gather(1, topk_ids)

        # reconstruct beam origin and true word ids from flattened order
        topk_beam_index = topk_ids.floor_divide(vocab_size)
        topk_ids = topk_ids.fmod(vocab_size)

        # map beam_index to batch_index in the flat representation
        b_off = beam_offset[:topk_beam_index.size(0)].unsqueeze(1)
        batch_index = topk_beam_index + b_off
        select_ix = batch_index.view(-1)

        # append latest prediction (result: batch_size*k x hyp_len)
        selected_alive_seq = alive_seq.index_select(0, select_ix)
        alive_seq = torch.cat([selected_alive_seq, topk_ids.view(-1, 1)], -1)

        is_finished = topk_ids.eq(model.eos_index)  # batch x beam
        if step == max_output_length:
            is_finished.fill_(1)
        top_finished = is_finished[:, 0].eq(1)  # batch

        # save finished hypotheses
        seq_len = alive_seq.size(-1)
        predictions = alive_seq.view(-1, size, seq_len)
        ix = top_finished.nonzero().view(-1)
        for i in ix:
            finished_scores = topk_scores[i]
            finished_preds = predictions[i, :, 1:]
            b = batch_offset[i]
            # if you desire more hypotheses, you can use topk/sort
            top_score, top_pred_ix = finished_scores.max(dim=0)
            top_pred = finished_preds[top_pred_ix]
            results["scores"][b].append(top_score)
            results["predictions"][b].append(top_pred)

        if top_finished.all():
            break

        # remove finished batches for the next step
        unfinished = top_finished.eq(0).nonzero().view(-1)
        current_beam = current_beam.index_select(0, unfinished)
        batch_index = batch_index.index_select(0, unfinished)
        batch_offset = batch_offset.index_select(0, unfinished)
        alive_seq = predictions.index_select(0, unfinished).view(-1, seq_len)

        # reorder indices, outputs and masks
        select_ix = batch_index.view(-1)
        encoder_output.index_select(select_ix)
        masks = {
            k: v.index_select(0, select_ix) if k != "trg" else v
            for k, v in masks.items()
        }

        if model.is_ensemble:
            if not transformer:
                new_hidden = []
                for h_i in hidden:
                    if isinstance(h_i, tuple):
                        # for LSTMs, states are tuples of tensors
                        h, c = h_i
                        h = h.index_select(1, select_ix)
                        c = c.index_select(1, select_ix)
                        new_h_i = h, c
                    else:
                        # for GRUs, states are single tensors
                        new_h_i = h_i.index_select(1, select_ix)
                    new_hidden.append(new_h_i)
                hidden = new_hidden
        else:
            if hidden is not None and not transformer:
                if isinstance(hidden, tuple):
                    # for LSTMs, states are tuples of tensors
                    h, c = hidden
                    h = h.index_select(1, select_ix)
                    c = c.index_select(1, select_ix)
                    hidden = h, c
                else:
                    # for GRUs, states are single tensors
                    hidden = hidden.index_select(1, select_ix)

        if att_vectors is not None:
            if model.is_ensemble:
                att_vectors = [
                    av.index_select(0, select_ix) if av is not None else None
                    for av in att_vectors
                ]
            else:
                att_vectors = att_vectors.index_select(0, select_ix)

        prev_y = alive_seq if transformer else alive_seq[:, -1].view(-1, 1)

    # is moving to cpu necessary/good?
    final_outputs = pad_and_stack_hyps(
        [r[0].cpu() for r in results["predictions"]], model.pad_index)
    if return_scores:
        final_scores = torch.stack([s[0] for s in results["scores"]])
        return final_outputs, None, final_scores
    else:
        return final_outputs, None, None
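Example #6 calls a module-level `pad_and_stack_hyps` that is not shown; presumably it behaves like the nested helper of the same name in Examples #1-#3, padding every hypothesis to the longest one with the pad index and stacking them into a 2-D array. A sketch along those lines, accepting the CPU tensors it is handed above:

import numpy as np


def pad_and_stack_hyps(hyps, pad_value):
    """Pad 1-D hypotheses (CPU tensors or arrays) to equal length and stack them."""
    hyps = [np.asarray(h) for h in hyps]
    max_len = max(h.shape[0] for h in hyps)
    filled = np.full((len(hyps), max_len), pad_value, dtype=int)
    for j, h in enumerate(hyps):
        filled[j, :h.shape[0]] = h
    return filled  # batch x max_hyp_len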