import numpy as np

import paddle.fluid.dygraph as D
import paddle.fluid.layers as L

# `BeamSearchState`, `beam_search_step` and `gen_bias` are helpers defined
# elsewhere in this module.


def beam_search_infilling(model,
                          q_ids,
                          q_sids,
                          sos_id,
                          eos_id,
                          attn_id,
                          max_encode_len=640,
                          max_decode_len=100,
                          beam_width=5,
                          tgt_type_id=3,
                          length_penalty=1.0):
    model.eval()
    # encode the source once; the encoder key/value caches are reused at
    # every decode step
    _, __, info = model(q_ids, q_sids)
    d_batch, d_seqlen = q_ids.shape
    state = BeamSearchState(
        log_probs=L.zeros([d_batch, beam_width], 'float32'),
        lengths=L.zeros([d_batch, beam_width], 'int64'),
        finished=L.zeros([d_batch, beam_width], 'int64'))
    outputs = []

    def reorder_(t, parent_id):
        """reorder cache according to parent beam id"""
        gather_idx = L.where(parent_id != -1)[:, 0] * beam_width + L.reshape(
            parent_id, [-1])
        t = L.gather(t, gather_idx)
        return t

    def tile_(t, times):
        """repeat each row `times` times along a new beam axis, then flatten"""
        _shapes = list(t.shape[1:])
        ret = L.reshape(
            L.expand(L.unsqueeze(t, [1]), [1, times] + [1] * len(_shapes)),
            [-1] + _shapes)
        return ret

    cached_k, cached_v = info['caches']
    cached_k = [tile_(k, beam_width) for k in cached_k]
    cached_v = [tile_(v, beam_width) for v in cached_v]
    past_cache = (cached_k, cached_v)

    q_ids = tile_(q_ids, beam_width)
    # non-padded source length per beam (pad id is 0)
    seqlen = L.reduce_sum(L.cast(q_ids != 0, 'int64'), 1, keep_dim=True)

    cls_ids = L.ones([d_batch * beam_width], dtype='int64') * sos_id
    attn_ids = L.ones([d_batch * beam_width], dtype='int64') * attn_id
    # each step feeds a (previous token, [ATTN] placeholder) pair; the next
    # token is predicted at the [ATTN] position (hence `logits[:, 1]` below)
    ids = L.stack([cls_ids, attn_ids], -1)
    for step in range(max_decode_len):
        bias = gen_bias(q_ids, ids, step)
        # decoder positions continue from the end of the source sequence
        pos_ids = D.to_variable(
            np.tile(
                np.array([[step, step + 1]], dtype=np.int64),
                [d_batch * beam_width, 1]))
        pos_ids += seqlen
        _, logits, info = model(
            ids,
            L.ones_like(ids) * tgt_type_id,
            pos_ids=pos_ids,
            attn_bias=bias,
            past_cache=past_cache)

        output, state = beam_search_step(
            state,
            logits[:, 1],
            eos_id=eos_id,
            beam_width=beam_width,
            is_first_step=(step == 0),
            length_penalty=length_penalty)
        outputs.append(output)

        # append this step's keys/values to the cache, then reorder the
        # cache rows to follow the surviving beams
        past_cached_k, past_cached_v = past_cache
        cached_k, cached_v = info['caches']
        cached_k = [
            reorder_(L.concat([pk, k[:, :1, :]], 1), output.beam_parent_ids)
            for pk, k in zip(past_cached_k, cached_k)
        ]
        cached_v = [
            reorder_(L.concat([pv, v[:, :1, :]], 1), output.beam_parent_ids)
            for pv, v in zip(past_cached_v, cached_v)
        ]
        past_cache = (cached_k, cached_v)

        pred_ids_flatten = L.reshape(output.predicted_ids,
                                     [d_batch * beam_width])
        ids = L.stack([pred_ids_flatten, attn_ids], 1)

        if state.finished.numpy().all():
            break

    final_ids = L.stack([o.predicted_ids for o in outputs], 0)
    final_parent_ids = L.stack([o.beam_parent_ids for o in outputs], 0)
    final_ids = L.gather_tree(final_ids, final_parent_ids)[:, :, 0]  # pick best beam
    final_ids = L.transpose(L.reshape(final_ids, [-1, d_batch * 1]), [1, 0])
    return final_ids
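# As a mental model for what one call to `beam_search_step` above does, here
# is a minimal numpy sketch (an illustrative assumption, not the actual
# implementation): score every (beam, token) pair, take the per-example
# top-k, and record which parent beam each survivor extends. It omits EOS
# handling, the length penalty, and the first-step trick of keeping only
# one live beam.
def toy_beam_search_step(log_probs, step_logits, beam_width):
    """log_probs: [batch, beam]; step_logits: [batch, beam, vocab]."""
    vocab = step_logits.shape[-1]
    # log-softmax over the vocabulary
    step_lp = step_logits - np.log(np.exp(step_logits).sum(-1, keepdims=True))
    # accumulated score of every candidate extension: [batch, beam * vocab]
    total = (log_probs[:, :, None] + step_lp).reshape(step_lp.shape[0], -1)
    topk = np.argsort(-total, axis=-1)[:, :beam_width]
    beam_parent_ids = topk // vocab  # which beam each survivor grew from
    predicted_ids = topk % vocab     # which token it appended
    new_log_probs = np.take_along_axis(total, topk, axis=-1)
    return new_log_probs, beam_parent_ids, predicted_ids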
import numpy as np

from paddle import fluid
from paddle.fluid import layers
from paddle.fluid.dygraph import to_variable
from paddle.fluid.layers.utils import map_structure


# method of a dygraph Transformer model: `self` provides the encoder,
# decoder, and size hyper-parameters (n_layer, n_head, d_key, d_value,
# trg_vocab_size)
def beam_search(self,
                src_word,
                src_pos,
                src_slf_attn_bias,
                trg_word,
                trg_src_attn_bias,
                bos_id=0,
                eos_id=1,
                beam_size=4,
                max_len=256):
    def expand_to_beam_size(tensor, beam_size):
        tensor = layers.reshape(tensor,
                                [tensor.shape[0], 1] + tensor.shape[1:])
        tile_dims = [1] * len(tensor.shape)
        tile_dims[1] = beam_size
        return layers.expand(tensor, tile_dims)

    def merge_batch_beams(tensor):
        return layers.reshape(tensor, [tensor.shape[0] * tensor.shape[1]] +
                              tensor.shape[2:])

    def split_batch_beams(tensor):
        return fluid.layers.reshape(tensor,
                                    shape=[-1, beam_size] +
                                    list(tensor.shape[1:]))

    def mask_probs(probs, finished, noend_mask_tensor):
        # finished beams may only extend with <eos>: their rows become the
        # no-end mask (0 at <eos>, -inf elsewhere); live beams keep `probs`
        # TODO: use where_op
        finished = layers.cast(finished, dtype=probs.dtype)
        probs = layers.elementwise_mul(
            layers.expand(layers.unsqueeze(finished, [2]),
                          [1, 1, self.trg_vocab_size]),
            noend_mask_tensor,
            axis=-1) - layers.elementwise_mul(probs, (finished - 1), axis=0)
        return probs

    def gather(x, indices, batch_pos):
        topk_coordinates = fluid.layers.stack([batch_pos, indices], axis=2)
        return layers.gather_nd(x, topk_coordinates)

    # run encoder
    enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias)

    # constants
    inf = float(1. * 1e7)
    batch_size = enc_output.shape[0]
    max_len = (enc_output.shape[1] + 20) if max_len is None else max_len
    vocab_size_tensor = layers.fill_constant(shape=[1],
                                             dtype="int64",
                                             value=self.trg_vocab_size)
    end_token_tensor = to_variable(
        np.full([batch_size, beam_size], eos_id, dtype="int64"))

    # mask row for finished beams: 0 at <eos>, -inf elsewhere
    noend_array = [-inf] * self.trg_vocab_size
    noend_array[eos_id] = 0
    noend_mask_tensor = to_variable(np.array(noend_array, dtype="float32"))
    batch_pos = layers.expand(
        layers.unsqueeze(
            to_variable(np.arange(0, batch_size, 1, dtype="int64")), [1]),
        [1, beam_size])

    predict_ids = []
    parent_ids = []
    ### initialize states of beam search ###
    # only beam 0 starts live, so the first top-k does not select
    # beam_size copies of the same hypothesis
    log_probs = to_variable(
        np.array([[0.] + [-inf] * (beam_size - 1)] * batch_size,
                 dtype="float32"))
    finished = to_variable(np.full([batch_size, beam_size], 0, dtype="bool"))
    ### initialize inputs and states of transformer decoder ###
    ## init inputs for decoder, shaped `[batch_size*beam_size, ...]`
    trg_word = layers.fill_constant(shape=[batch_size * beam_size, 1],
                                    dtype="int64",
                                    value=bos_id)
    trg_pos = layers.zeros_like(trg_word)
    trg_src_attn_bias = merge_batch_beams(
        expand_to_beam_size(trg_src_attn_bias, beam_size))
    enc_output = merge_batch_beams(expand_to_beam_size(enc_output, beam_size))
    ## init states (caches) for transformer decoder, to be updated
    ## according to the selected beams
    caches = [{
        "k": layers.fill_constant(
            shape=[batch_size * beam_size, self.n_head, 0, self.d_key],
            dtype=enc_output.dtype,
            value=0),
        "v": layers.fill_constant(
            shape=[batch_size * beam_size, self.n_head, 0, self.d_value],
            dtype=enc_output.dtype,
            value=0),
    } for i in range(self.n_layer)]

    for i in range(max_len):
        trg_pos = layers.fill_constant(shape=trg_word.shape,
                                       dtype="int64",
                                       value=i)
        # at step 0 the caches still have a zero-sized time dimension and
        # cannot be reshaped, so merge batch and beam dims from step 1 on
        caches = map_structure(
            lambda x: x if i == 0 else merge_batch_beams(x), caches)
        logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias,
                              enc_output, caches)
        caches = map_structure(split_batch_beams, caches)
        step_log_probs = split_batch_beams(
            fluid.layers.log(fluid.layers.softmax(logits)))
        step_log_probs = mask_probs(step_log_probs, finished,
                                    noend_mask_tensor)
        log_probs = layers.elementwise_add(x=step_log_probs,
                                           y=log_probs,
                                           axis=0)
        log_probs = layers.reshape(log_probs,
                                   [-1, beam_size * self.trg_vocab_size])
        scores = log_probs
        topk_scores, topk_indices = fluid.layers.topk(input=scores,
                                                      k=beam_size)
        # a flat top-k index encodes (parent beam, token): recover both
        beam_indices = fluid.layers.elementwise_floordiv(
            topk_indices, vocab_size_tensor)
        token_indices = fluid.layers.elementwise_mod(topk_indices,
                                                     vocab_size_tensor)

        # update states
        caches = map_structure(lambda x: gather(x, beam_indices, batch_pos),
                               caches)
        log_probs = gather(log_probs, topk_indices, batch_pos)
        finished = gather(finished, beam_indices, batch_pos)
        finished = layers.logical_or(
            finished, layers.equal(token_indices, end_token_tensor))
        trg_word = layers.reshape(token_indices, [-1, 1])

        predict_ids.append(token_indices)
        parent_ids.append(beam_indices)

        if layers.reduce_all(finished).numpy():
            break

    predict_ids = layers.stack(predict_ids, axis=0)
    parent_ids = layers.stack(parent_ids, axis=0)
    finished_seq = layers.transpose(
        layers.gather_tree(predict_ids, parent_ids), [1, 2, 0])
    finished_scores = topk_scores

    return finished_seq, finished_scores
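# Both decoders rebuild full hypotheses at the end with `layers.gather_tree`,
# which back-tracks each final beam through the per-step parent pointers.
# A minimal numpy equivalent, as an illustrative sketch (no special EOS
# handling):
def gather_tree_np(predict_ids, parent_ids):
    """predict_ids, parent_ids: [num_steps, batch_size, beam_size]."""
    num_steps, batch_size, beam_size = predict_ids.shape
    out = np.empty_like(predict_ids)
    out[-1] = predict_ids[-1]
    # each final beam starts by tracing itself
    parents = np.tile(np.arange(beam_size), (batch_size, 1))
    for t in range(num_steps - 2, -1, -1):
        # follow the parent pointer recorded at step t + 1 back to step t
        parents = np.take_along_axis(parent_ids[t + 1], parents, axis=1)
        out[t] = np.take_along_axis(predict_ids[t], parents, axis=1)
    return out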