Example #1
0
def beam_search_infilling(model,
                          q_ids,
                          q_sids,
                          sos_id,
                          eos_id,
                          attn_id,
                          max_encode_len=640,
                          max_decode_len=100,
                          beam_width=5,
                          tgt_type_id=3,
                          length_penalty=1.0):
    """Beam-search decode, feeding the model a [token, attn_id] pair per step.

    The source is run through ``model(q_ids, q_sids)`` once; the caches it
    returns are repeated ``beam_width`` times and then grown by one position
    after every decoding step, re-gathered to follow the surviving beams.

    Args:
        model: network object; must provide ``eval()`` and, when called,
            return a 3-tuple whose last item is a dict with a ``'caches'``
            entry that unpacks into ``(keys, values)`` sequences.
        q_ids: 2-D tensor of source ids (its shape is unpacked into two
            dims). Entries different from 0 are counted as the source length
            (0 is presumably the padding id -- confirm).
        q_sids: second input of the encoding pass (presumably sentence /
            segment ids -- confirm against the model).
        sos_id: id placed in the first slot of the initial decoder input.
        eos_id: forwarded to ``beam_search_step``.
        attn_id: id placed in the second slot of every decoder input.
        max_encode_len: kept for interface compatibility; not read here.
        max_decode_len: maximum number of decoding steps.
            NOTE(review): with 0 the loop never runs and ``outputs`` is
            still empty when it is stacked -- presumably unsupported.
        beam_width: number of beams kept per batch item.
        tgt_type_id: fill value of the type-id input used while decoding.
        length_penalty: forwarded to ``beam_search_step``.

    Returns:
        2-D tensor whose first dimension is the batch size and whose second
        dimension runs over the executed steps, holding the back-traced ids
        of beam index 0.
    """
    model.eval()
    _, __, info = model(q_ids, q_sids)
    d_batch, _src_len = q_ids.shape  # unpacking requires a 2-D q_ids
    n_rows = d_batch * beam_width

    state = BeamSearchState(
        log_probs=L.zeros([d_batch, beam_width], 'float32'),
        lengths=L.zeros([d_batch, beam_width], 'int64'),
        finished=L.zeros([d_batch, beam_width], 'int64'))
    outputs = []

    def _follow_parents(tensor, parent_id):
        """Gather rows of a flattened [batch * beam, ...] tensor by parent beam."""
        # Row index of every (batch, beam) entry, scaled to the start of its
        # batch item in the flattened layout.
        batch_offset = L.where(parent_id != -1)[:, 0] * beam_width
        flat_parent = L.reshape(parent_id, [-1])
        return L.gather(tensor, batch_offset + flat_parent)

    def _repeat_rows(tensor, times):
        """Repeat every leading-axis entry `times` times, consecutively."""
        trailing = list(tensor.shape[1:])
        expand_times = [1, times] + [1] * len(trailing)
        expanded = L.expand(L.unsqueeze(tensor, [1]), expand_times)
        return L.reshape(expanded, [-1] + trailing)

    def _grow_and_reorder(past_list, new_list, parent_id):
        """Append the first new position to each cache, then follow the beams."""
        grown = []
        for past, new in zip(past_list, new_list):
            # Only the first of the two fed positions is kept in the cache.
            appended = L.concat([past, new[:, :1, :]], 1)
            grown.append(_follow_parents(appended, parent_id))
        return grown

    enc_k, enc_v = info['caches']
    enc_k = [_repeat_rows(k, beam_width) for k in enc_k]
    enc_v = [_repeat_rows(v, beam_width) for v in enc_v]
    past_cache = (enc_k, enc_v)

    tiled_q_ids = _repeat_rows(q_ids, beam_width)
    # Per-row count of non-zero source ids; used to offset decoder positions.
    seqlen = L.reduce_sum(
        L.cast(tiled_q_ids != 0, 'int64'), 1, keep_dim=True)

    cls_ids = L.ones([n_rows], dtype='int64') * sos_id
    attn_ids = L.ones([n_rows], dtype='int64') * attn_id
    ids = L.stack([cls_ids, attn_ids], -1)
    for step in range(max_decode_len):
        bias = gen_bias(tiled_q_ids, ids, step)
        step_positions = np.array([[step, step + 1]], dtype=np.int64)
        pos_ids = D.to_variable(np.tile(step_positions, [n_rows, 1]))
        pos_ids += seqlen
        type_ids = L.ones_like(ids) * tgt_type_id
        _, logits, info = model(ids,
                                type_ids,
                                pos_ids=pos_ids,
                                attn_bias=bias,
                                past_cache=past_cache)

        # Scores come from the second fed position (the attn_id slot).
        output, state = beam_search_step(state,
                                         logits[:, 1],
                                         eos_id=eos_id,
                                         beam_width=beam_width,
                                         is_first_step=(step == 0),
                                         length_penalty=length_penalty)
        outputs.append(output)

        prev_k, prev_v = past_cache
        new_k, new_v = info['caches']
        next_k = _grow_and_reorder(prev_k, new_k, output.beam_parent_ids)
        next_v = _grow_and_reorder(prev_v, new_v, output.beam_parent_ids)
        past_cache = (next_k, next_v)

        # Next input: the token just chosen, paired with the attn_id slot.
        chosen_flat = L.reshape(output.predicted_ids, [n_rows])
        ids = L.stack([chosen_flat, attn_ids], 1)

        if state.finished.numpy().all():
            break

    step_ids = L.stack([o.predicted_ids for o in outputs], 0)
    step_parents = L.stack([o.beam_parent_ids for o in outputs], 0)
    # Back-trace through the parent pointers and keep beam index 0.
    best = L.gather_tree(step_ids, step_parents)[:, :, 0]
    final_ids = L.transpose(L.reshape(best, [-1, d_batch * 1]), [1, 0])
    return final_ids
Example #2
0
    def beam_search(self,
                    src_word,
                    src_pos,
                    src_slf_attn_bias,
                    trg_word,
                    trg_src_attn_bias,
                    bos_id=0,
                    eos_id=1,
                    beam_size=4,
                    max_len=256):
        """Decode with beam search using ``self.encoder`` / ``self.decoder``.

        Args:
            src_word, src_pos, src_slf_attn_bias: inputs handed unchanged to
                ``self.encoder``.
            trg_word: not read; it is replaced by a ``bos_id``-filled tensor
                before the first decoding step.
            trg_src_attn_bias: repeated once per beam and passed to
                ``self.decoder`` at every step.
            bos_id: id of the token fed at the first step.
            eos_id: id that marks a beam as finished once it is selected.
            beam_size: number of beams per batch item.
            max_len: maximum number of steps; ``None`` means the encoder
                output's second dimension plus 20. Must allow at least one
                step, because the returned scores are taken from the last
                executed step.

        Returns:
            ``(finished_seq, finished_scores)``: the back-traced token ids
            laid out as ``[batch, beam, step]`` and the top-k accumulated
            log-probabilities of the last executed step (``[batch, beam]``).
        """

        def add_beam_axis(tensor, width):
            # [d0, ...] -> [d0, width, ...]
            with_axis = layers.reshape(
                tensor, [tensor.shape[0], 1] + tensor.shape[1:])
            repeats = [1] * len(with_axis.shape)
            repeats[1] = width
            return layers.expand(with_axis, repeats)

        def fold_beams(tensor):
            # [d0, d1, ...] -> [d0 * d1, ...]
            dims = tensor.shape
            return layers.reshape(tensor, [dims[0] * dims[1]] + dims[2:])

        def unfold_beams(tensor):
            # [d0 * beam_size, ...] -> [d0, beam_size, ...]
            unfolded_shape = [-1, beam_size] + list(tensor.shape[1:])
            return fluid.layers.reshape(tensor, shape=unfolded_shape)

        def freeze_finished(probs, finished, noend_mask_tensor):
            # TODO: use where_op
            # Finished beams get the "EOS only" row (0 at eos_id, -inf
            # elsewhere); unfinished beams keep their own log-probs, since
            # (done - 1) is -1 there and the subtraction restores the sign.
            done = layers.cast(finished, dtype=probs.dtype)
            done_per_token = layers.expand(
                layers.unsqueeze(done, [2]), [1, 1, self.trg_vocab_size])
            eos_only = layers.elementwise_mul(done_per_token,
                                              noend_mask_tensor,
                                              axis=-1)
            negated_live = layers.elementwise_mul(probs, (done - 1), axis=0)
            return eos_only - negated_live

        def pick(x, indices, rows):
            # Select x[rows[b, k], indices[b, k]] for every (b, k).
            coords = fluid.layers.stack([rows, indices], axis=2)
            return layers.gather_nd(x, coords)

        # run encoder
        enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias)

        # constants
        inf = float(1. * 1e7)  # large finite stand-in for infinity
        batch_size = enc_output.shape[0]
        if max_len is None:
            max_len = enc_output.shape[1] + 20
        vocab_size_tensor = layers.fill_constant(shape=[1],
                                                 dtype="int64",
                                                 value=self.trg_vocab_size)
        end_token_tensor = to_variable(
            np.full([batch_size, beam_size], eos_id, dtype="int64"))
        noend_row = [-inf] * self.trg_vocab_size
        noend_row[eos_id] = 0
        noend_mask_tensor = to_variable(np.array(noend_row, dtype="float32"))
        batch_ids = to_variable(np.arange(0, batch_size, 1, dtype="int64"))
        batch_pos = layers.expand(layers.unsqueeze(batch_ids, [1]),
                                  [1, beam_size])

        predict_ids = []
        parent_ids = []

        # beam-search state: only beam 0 starts at score 0, the others at
        # -inf, so their candidates rank far below on the first step
        first_scores = [0.] + [-inf] * (beam_size - 1)
        log_probs = to_variable(
            np.array([first_scores] * batch_size, dtype="float32"))
        finished = to_variable(
            np.full([batch_size, beam_size], 0, dtype="bool"))

        # decoder inputs, shaped `[batch_size * beam_size, ...]`
        trg_word = layers.fill_constant(shape=[batch_size * beam_size, 1],
                                        dtype="int64",
                                        value=bos_id)
        trg_pos = layers.zeros_like(trg_word)  # reassigned at every step
        trg_src_attn_bias = fold_beams(
            add_beam_axis(trg_src_attn_bias, beam_size))
        enc_output = fold_beams(add_beam_axis(enc_output, beam_size))

        # per-layer caches start with a zero-length axis 2 and have to be
        # re-gathered whenever the surviving beams change
        caches = [{
            "k":
            layers.fill_constant(
                shape=[batch_size * beam_size, self.n_head, 0, self.d_key],
                dtype=enc_output.dtype,
                value=0),
            "v":
            layers.fill_constant(
                shape=[batch_size * beam_size, self.n_head, 0, self.d_value],
                dtype=enc_output.dtype,
                value=0),
        } for _ in range(self.n_layer)]

        for step in range(max_len):
            trg_pos = layers.fill_constant(shape=trg_word.shape,
                                           dtype="int64",
                                           value=step)
            # The zero-size initial caches cannot be reshaped, so the first
            # step passes them through untouched.
            if step == 0:
                prepare_cache = lambda x: x
            else:
                prepare_cache = fold_beams
            caches = map_structure(prepare_cache, caches)
            logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias,
                                  enc_output, caches)
            caches = map_structure(unfold_beams, caches)

            step_log_probs = unfold_beams(
                fluid.layers.log(fluid.layers.softmax(logits)))
            step_log_probs = freeze_finished(step_log_probs, finished,
                                             noend_mask_tensor)
            log_probs = layers.elementwise_add(x=step_log_probs,
                                               y=log_probs,
                                               axis=0)
            # Flatten (beam, vocab) so one top-k ranks all continuations.
            log_probs = layers.reshape(log_probs,
                                       [-1, beam_size * self.trg_vocab_size])
            scores = log_probs
            topk_scores, topk_indices = fluid.layers.topk(input=scores,
                                                          k=beam_size)
            # Split each flat index into (source beam, token id).
            beam_indices = fluid.layers.elementwise_floordiv(
                topk_indices, vocab_size_tensor)
            token_indices = fluid.layers.elementwise_mod(
                topk_indices, vocab_size_tensor)

            # update states to follow the selected beams
            caches = map_structure(
                lambda c: pick(c, beam_indices, batch_pos), caches)
            log_probs = pick(log_probs, topk_indices, batch_pos)
            finished = pick(finished, beam_indices, batch_pos)
            hit_eos = layers.equal(token_indices, end_token_tensor)
            finished = layers.logical_or(finished, hit_eos)
            trg_word = layers.reshape(token_indices, [-1, 1])

            predict_ids.append(token_indices)
            parent_ids.append(beam_indices)

            if layers.reduce_all(finished).numpy():
                break

        token_history = layers.stack(predict_ids, axis=0)
        parent_history = layers.stack(parent_ids, axis=0)
        traced = layers.gather_tree(token_history, parent_history)
        # [step, batch, beam] -> [batch, beam, step]
        finished_seq = layers.transpose(traced, [1, 2, 0])
        finished_scores = topk_scores

        return finished_seq, finished_scores