def step(self, time, inputs, states, **kwargs):
    """Advance beam search by one Transformer decoding step.

    Unlike an RNN cell, the Transformer cell is fed a (token, position) pair
    at every step: `inputs` is reshaped into a column of token ids and a
    matching column of position ids (every entry equal to `time`) is built.

    Args:
        time: Current decoding step; used as the position id and passed on
            to `_beam_search_step`.
        inputs: Token ids chosen at the previous step.
        states: Beam search state whose `cell_states` hold the cell states.
        **kwargs: Forwarded unchanged to `self.cell`.

    Returns:
        tuple: `(beam_search_output, beam_search_state, next_inputs,
        finished)` where `next_inputs` is `beam_search_output.predicted_ids`
        and `finished` is `beam_search_state.finished`.
    """

    def _to_2d_logits(tensor):
        # BeamSearchDecoder works on 2D logits, so drop the length-1 axis 1.
        return paddle.squeeze(tensor, [1]) if len(tensor.shape) == 3 else tensor

    tokens = paddle.reshape(inputs, [-1, 1])
    positions = paddle.ones_like(tokens) * time
    merged_states = map_structure(self._merge_batch_beams_with_var_dim,
                                  states.cell_states)

    step_logits, new_states = self.cell((tokens, positions), merged_states,
                                        **kwargs)

    step_logits = map_structure(self._split_batch_beams,
                                map_structure(_to_2d_logits, step_logits))
    new_states = map_structure(self._split_batch_beams_with_var_dim,
                               new_states)

    output, beam_state = self._beam_search_step(time=time,
                                                logits=step_logits,
                                                next_cell_states=new_states,
                                                beam_state=states)
    return output, beam_state, output.predicted_ids, beam_state.finished
def step(self, time, inputs, states, **kwargs):
    """Advance beam search by one Transformer decoding step, optionally
    forcing the prefix given in `trg_word`.

    The Transformer cell is fed a (token, position) pair at every step:
    `inputs` is reshaped into a column of token ids and a matching column of
    position ids (every entry equal to `time`) is built. All `kwargs`
    (including `trg_word` / `trg_length`) are forwarded to `self.cell`.

    When `kwargs` contains a non-None `trg_word` whose dimension 1 is larger
    than `time`, the result of the regular beam search step is replaced by
    `self.force_decoding(...)`. In dygraph mode this is a plain Python
    branch; otherwise it is expressed with `paddle.static.nn.case`.

    Returns:
        tuple: `(beam_search_output, beam_search_state, next_inputs,
        finished)` where `next_inputs` is `beam_search_output.predicted_ids`
        and `finished` is `beam_search_state.finished`.
    """

    def _to_2d_logits(tensor):
        # BeamSearchDecoder works on 2D logits, so drop the length-1 axis 1.
        return paddle.squeeze(tensor, [1]) if len(tensor.shape) == 3 else tensor

    tokens = paddle.reshape(inputs, [-1, 1])
    positions = paddle.ones_like(tokens) * time
    merged_states = map_structure(self._merge_batch_beams_with_var_dim,
                                  states.cell_states)

    step_logits, new_states = self.cell((tokens, positions), merged_states,
                                        **kwargs)

    step_logits = map_structure(self._split_batch_beams,
                                map_structure(_to_2d_logits, step_logits))
    new_states = map_structure(self._split_batch_beams_with_var_dim,
                               new_states)

    output, beam_state = self._beam_search_step(time=time,
                                                logits=step_logits,
                                                next_cell_states=new_states,
                                                beam_state=states)

    trg_word = kwargs.get("trg_word", None)
    if trg_word is not None:
        trg_length = kwargs.get("trg_length")
        if in_dygraph_mode():
            if paddle.shape(trg_word)[1] > time:
                output, beam_state = self.force_decoding(
                    output, beam_state, trg_word, trg_length, time)
        else:
            from functools import partial

            def _keep_unchanged(beam_search_output, beam_search_state):
                return beam_search_output, beam_search_state

            force_fn = partial(self.force_decoding,
                               beam_search_output=output,
                               beam_search_state=beam_state,
                               trg_word=trg_word,
                               trg_length=trg_length,
                               time=time)
            keep_fn = partial(_keep_unchanged,
                              beam_search_output=output,
                              beam_search_state=beam_state)
            still_forcing = paddle.shape(trg_word)[1] > time
            output, beam_state = paddle.static.nn.case(
                [(still_forcing, force_fn)], default=keep_fn)

    return output, beam_state, output.predicted_ids, beam_state.finished
def tile_beam_merge_with_batch(t, beam_size):
    r"""
    Repeat each batch entry `beam_size` times and fold the copies into the
    batch dimension.

    Every tensor `x` in `t`, shaped `[batch_size, s0, s1, ...]` with entries
    `x[0], ..., x[batch_size - 1]`, is turned into a tensor shaped
    `[batch_size * beam_size, s0, s1, ...]` whose entries are
    `x[0], x[0], ..., x[1], x[1], ...`, i.e. each entry repeated `beam_size`
    times. The work is delegated to
    `nn.decode.BeamSearchDecoder.tile_beam_merge_with_batch`, applied to
    every tensor in `t` via `map_structure`.

    Args:
        t (Tensor|list|tuple): A tensor, or a (possibly nested) structure of
            tensors, each shaped `[batch_size, ...]`.
        beam_size (int): The beam width used in beam search.

    Returns:
        Tensor|list|tuple: Same structure as `t`, with every tensor reshaped
        to `[batch_size * beam_size, ...]` and the same data type.

    Example:
        .. code-block:: python

            import paddle
            from paddlenlp.transformers import TransformerBeamSearchDecoder

            t = paddle.rand(shape=[10, 10])
            TransformerBeamSearchDecoder.tile_beam_merge_with_batch(t, beam_size=4)
    """

    def _tile_one(tensor):
        return nn.decode.BeamSearchDecoder.tile_beam_merge_with_batch(
            tensor, beam_size)

    return map_structure(_tile_one, t)
def __init__(self):
    """Build the model and collect input descriptions from `_shape` args.

    For `self.forward` and `self.loss`, every trailing argument whose name
    ends in `_shape` must carry a default value; that default (a possibly
    nested structure of shapes) is wrapped with `self._InputDesc` and stored
    in `self._data_descs` under the argument name without the `_shape`
    suffix.

    Raises:
        AssertionError: if a `_shape` argument is followed by a non-shape
            argument, or has no default value.
    """
    super(Model, self).__init__("model")
    self._dygraph_mode = is_eager
    # extract data desc from method arguments
    self._data_descs = {}
    # `inspect.getargspec` was removed in Python 3.11; `getfullargspec`
    # exposes the same `.args` / `.defaults` fields (and, like getargspec,
    # keeps the bound `self` in `.args`). Fall back for Python 2.
    getargspec = getattr(inspect, "getfullargspec", None) or inspect.getargspec
    is_sequence_ori = utils.is_sequence
    # Temporarily treat flat lists/tuples of ints as leaves (shapes) so that
    # map_structure walks nested structures of shapes.
    utils.is_sequence = self._InputDesc._is_shape_sequence
    try:
        for func in [self.forward, self.loss]:
            flag = True
            func_argspec = getargspec(func)
            # `defaults` is None when the function has no default values.
            defaults = func_argspec.defaults or ()
            for i, arg in enumerate(func_argspec.args[::-1]):
                if arg.endswith("_shape"):
                    assert flag, "_shape arguments must be at the rear."
                    # The i-th argument from the end has a default only if
                    # i < len(defaults); otherwise defaults[-i - 1] is invalid.
                    assert i < len(
                        defaults), "The shape argument must have default value."
                    self._data_descs[arg[:-len("_shape")]] = map_structure(
                        lambda shape: self._InputDesc(shape), defaults[-i - 1])
                else:
                    # switch flag: no `_shape` argument may precede this one
                    flag = False
    finally:
        # Always restore the global hook, even if an assertion fired.
        utils.is_sequence = is_sequence_ori
    print(self._data_descs)
def body_func(step_idx, pre_ids, pre_scores, gather_idx, caches,
              trg_src_attn_bias):
    """Loop body for one step of static-graph beam search decoding.

    Gathers the decoder caches and the target-source attention bias by
    `gather_idx` (the parent indices returned by the previous
    `layers.beam_search` call), runs the decoder for one step, expands each
    beam with its top-`beam_size` tokens and lets `layers.beam_search` pick
    the surviving beams. Selected ids and scores are written into the `ids`
    and `scores` arrays at the incremented step index.

    NOTE(review): relies on names from the enclosing scope (`wrap_decoder`,
    the model hyper-parameters, `enc_output`, `beam_size`, `bos_idx`,
    `eos_idx`, `ids`, `scores`); presumably this is the body of a while loop
    whose loop variables match the returned tuple -- confirm with the caller.

    Returns:
        tuple: updated `(step_idx, pre_ids, pre_scores, gather_idx, caches,
        trg_src_attn_bias)` loop variables.
    """
    # gather cell states corresponding to selected parent
    pre_caches = map_structure(
        lambda x: layers.gather(x, index=gather_idx), caches)
    pre_src_attn_bias = layers.gather(trg_src_attn_bias, index=gather_idx)
    # Position ids: a [-1, 1] column of ones (batch taken from the attention
    # bias) scaled by the current step index.
    pre_pos = layers.elementwise_mul(
        x=layers.fill_constant_batch_size_like(
            input=pre_src_attn_bias,  # can't use a LoD tensor here
            value=1,
            shape=[-1, 1],
            dtype=pre_ids.dtype),
        y=step_idx,
        axis=0)
    logits = wrap_decoder((pre_ids, pre_pos, None, pre_src_attn_bias),
                          trg_vocab_size,
                          max_in_len,
                          n_layer,
                          n_head,
                          d_key,
                          d_value,
                          d_model,
                          d_inner_hid,
                          prepostprocess_dropout,
                          attention_dropout,
                          relu_dropout,
                          preprocess_cmd,
                          postprocess_cmd,
                          weight_sharing,
                          enc_output=enc_output,
                          caches=pre_caches,
                          bos_idx=bos_idx)
    # intra-beam topK
    topk_scores, topk_indices = layers.topk(
        input=layers.softmax(logits), k=beam_size)
    # Accumulated score = log(prob of candidate token) + score of its beam.
    accu_scores = layers.elementwise_add(x=layers.log(topk_scores),
                                         y=pre_scores,
                                         axis=0)
    # beam_search op uses lod to differentiate branches.
    accu_scores = layers.lod_reset(accu_scores, pre_ids)
    # topK reduction across beams, also contains special handling of
    # ended beams and ended sentences (batch reduction)
    selected_ids, selected_scores, gather_idx = layers.beam_search(
        pre_ids=pre_ids,
        pre_scores=pre_scores,
        ids=topk_indices,
        scores=accu_scores,
        beam_size=beam_size,
        end_id=eos_idx,
        return_parent_idx=True)
    step_idx = layers.increment(x=step_idx, value=1.0, in_place=False)
    layers.array_write(selected_ids, i=step_idx, array=ids)
    layers.array_write(selected_scores, i=step_idx, array=scores)
    return (step_idx, selected_ids, selected_scores, gather_idx, pre_caches,
            pre_src_attn_bias)
def test_case(self):
    """map_structure keeps dict structure; assert_same_structure rejects a
    structure with a different number of elements."""

    def increment(x):
        return x + 1

    inputs = {"key1": 1, "key2": 2}
    outputs = utils.map_structure(increment, inputs)
    utils.assert_same_structure(inputs, outputs)

    # Adding a key makes `inputs` structurally different from `outputs`.
    # The previous version swallowed the ValueError but also passed when no
    # error was raised at all, so the check could never fail.
    inputs["key3"] = 3
    try:
        utils.assert_same_structure(inputs, outputs)
    except ValueError:
        pass
    else:
        raise AssertionError(
            "assert_same_structure should raise ValueError for mismatched "
            "structures")
def varbase_to_numpy(self, res):
    """Convert tensor results to numpy.

    A list/tuple (possibly nested) is converted leaf-by-leaf with
    `map_structure`, keeping its structure; any other value is converted
    with `.numpy()` and wrapped in a one-element list.
    """
    if not isinstance(res, (list, tuple)):
        return [res.numpy()]
    return map_structure(lambda tensor: tensor.numpy(), res)
def beam_search(self, input_ids, beam_scorer, logits_processors, max_length,
                pad_token_id, eos_token_id, **model_kwargs):
    """Generate sequences with beam search.

    Args:
        input_ids: Prompt ids shaped `[batch_size * num_beams, cur_len]`.
        beam_scorer: Scorer that keeps the beam hypotheses
            (`process` / `finalize` / `is_done`).
        logits_processors: Callable `(input_ids, logits) -> logits` applied
            before scoring.
        max_length: Decoding stops once the sequence reaches this length.
        pad_token_id, eos_token_id: Forwarded to the beam scorer.
        **model_kwargs: Extra model inputs; an optional `"cache"` entry is
            reordered to follow the selected beams.

    Returns:
        tuple: `(pred_ids, scores)` from `beam_scorer.finalize`, with the
        prompt (the first `origin_len` positions) stripped from `pred_ids`.
    """
    batch_size = len(beam_scorer._beam_hyps)
    num_beams = beam_scorer.num_beams

    batch_beam_size, cur_len = input_ids.shape
    origin_len = cur_len

    assert (
        num_beams * batch_size == batch_beam_size
    ), "Batch dimension of `input_ids` should be {}, but received {}.".format(
        num_beams * batch_size, batch_beam_size)

    # Only the first beam of each batch entry starts "alive" so identical
    # beams are not expanded num_beams times at the first step.
    beam_scores = paddle.zeros((batch_size, num_beams),
                               dtype=paddle.get_default_dtype())
    beam_scores[:, 1:] = -1e9
    beam_scores = paddle.reshape(beam_scores, [-1])

    # NOTE(review): if `cur_len >= max_length` on entry, the loop never runs
    # and `next_tokens` / `next_indices` below are unbound.
    while cur_len < max_length:
        # prepare model inputs & get model output
        model_inputs = self.prepare_inputs_for_generation(
            input_ids, **model_kwargs)
        outputs = self(**model_inputs)
        logits = outputs[0] if isinstance(outputs, tuple) else outputs
        # keep only the last position: [batch_size * num_beams, vocab_size]
        logits = logits[:, -1, :]

        # pre-process distribution
        logits = self.adjust_logits_during_generation(logits)
        logits = logits_processors(input_ids, logits)

        # beam search
        # [batch_size * num_beams, vocab_size]
        next_scores = F.softmax(logits)
        next_scores = paddle.log(next_scores)
        next_scores = next_scores + beam_scores.unsqueeze(-1)

        # reshape for beam search: rank all candidates of a batch entry jointly
        vocab_size = next_scores.shape[-1]
        next_scores = next_scores.reshape(
            [batch_size, num_beams * vocab_size])

        next_scores, next_tokens = paddle.topk(next_scores,
                                               2 * num_beams,
                                               axis=1)

        # Decode flat candidate ids into (source beam, token id).
        next_indices = next_tokens // vocab_size
        next_tokens = next_tokens % vocab_size

        # stateless
        beam_outputs = beam_scorer.process(
            input_ids,
            next_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
        )
        beam_scores = beam_outputs["next_beam_scores"]
        beam_next_tokens = beam_outputs["next_beam_tokens"]
        beam_idx = beam_outputs["next_beam_indices"]

        cur_len += 1
        input_ids = paddle.concat([
            paddle.index_select(input_ids, beam_idx),
            beam_next_tokens.unsqueeze(-1)
        ],
                                  axis=-1)

        if beam_scorer.is_done:
            break
        model_kwargs = self.update_model_kwargs_for_generation(
            outputs, model_kwargs)
        # The cache entry may be absent (e.g. models that do not return a
        # cache); indexing it directly raised KeyError.
        if model_kwargs.get("cache") is not None:
            # reorder the cache so it follows the selected parent beams
            model_kwargs["cache"] = map_structure(
                lambda x: paddle.index_select(x, beam_idx),
                model_kwargs["cache"])

    pred_ids, scores = beam_scorer.finalize(input_ids,
                                            beam_scores,
                                            next_tokens,
                                            next_indices,
                                            pad_token_id=pad_token_id,
                                            eos_token_id=eos_token_id)
    return pred_ids[:, origin_len:], scores
def beam_search(self,
                src_word,
                src_pos,
                src_slf_attn_bias,
                trg_word,
                trg_src_attn_bias,
                bos_id=0,
                eos_id=1,
                beam_size=4,
                max_len=256):
    """Beam search decoding for the dygraph Transformer.

    Runs the encoder once, then decodes step by step for at most `max_len`
    steps (if `max_len` is None: encoder length + 20), keeping `beam_size`
    beams per batch entry. It stops early once every beam is finished, i.e.
    has emitted `eos_id`. Finished beams can only be extended with `eos_id`
    at zero cost (see `mask_probs`).

    NOTE(review): the `trg_word` argument is overwritten with a BOS column
    before use, so the caller's value is ignored. The early exit calls
    `.numpy()`, so this presumably only works in dygraph mode. If
    `max_len` is 0, `topk_scores` is unbound at the end.

    Returns:
        tuple: `(finished_seq, finished_scores)`. `finished_seq` is the
        `gather_tree` back-traced ids transposed with `[1, 2, 0]` (from
        `[steps, batch, beam]` to `[batch, beam, steps]`). `finished_scores`
        are the accumulated log-probabilities from the last step's top-k.
    """

    def expand_to_beam_size(tensor, beam_size):
        # [batch, ...] -> [batch, beam_size, ...] by tiling a new axis 1.
        tensor = layers.reshape(tensor,
                                [tensor.shape[0], 1] + tensor.shape[1:])
        tile_dims = [1] * len(tensor.shape)
        tile_dims[1] = beam_size
        return layers.expand(tensor, tile_dims)

    def merge_batch_beams(tensor):
        # [batch, beam, ...] -> [batch * beam, ...]
        return layers.reshape(tensor, [tensor.shape[0] * tensor.shape[1]] +
                              tensor.shape[2:])

    def split_batch_beams(tensor):
        # [batch * beam, ...] -> [batch, beam, ...]
        return fluid.layers.reshape(tensor,
                                    shape=[-1, beam_size] +
                                    list(tensor.shape[1:]))

    def mask_probs(probs, finished, noend_mask_tensor):
        # For finished beams (finished == 1) the result is noend_mask
        # (0 for eos, -inf elsewhere); for live beams it is `probs`
        # unchanged: f * mask + probs * (1 - f).
        # TODO: use where_op
        finished = layers.cast(finished, dtype=probs.dtype)
        probs = layers.elementwise_mul(
            layers.expand(layers.unsqueeze(finished, [2]),
                          [1, 1, self.trg_vocab_size]),
            noend_mask_tensor,
            axis=-1) - layers.elementwise_mul(probs, (finished - 1), axis=0)
        return probs

    def gather(x, indices, batch_pos):
        # Select x[b, indices[b, k]] for every (b, k) pair.
        topk_coordinates = fluid.layers.stack([batch_pos, indices], axis=2)
        return layers.gather_nd(x, topk_coordinates)

    # run encoder
    enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias)

    # constant number
    # A large finite value used as "infinity" for masking.
    inf = float(1. * 1e7)
    batch_size = enc_output.shape[0]
    max_len = (enc_output.shape[1] + 20) if max_len is None else max_len
    vocab_size_tensor = layers.fill_constant(shape=[1],
                                             dtype="int64",
                                             value=self.trg_vocab_size)
    end_token_tensor = to_variable(
        np.full([batch_size, beam_size], eos_id, dtype="int64"))
    noend_array = [-inf] * self.trg_vocab_size
    noend_array[eos_id] = 0
    noend_mask_tensor = to_variable(np.array(noend_array, dtype="float32"))
    # batch_pos[b, k] == b, used to build gather_nd coordinates.
    batch_pos = layers.expand(
        layers.unsqueeze(
            to_variable(np.arange(0, batch_size, 1, dtype="int64")), [1]),
        [1, beam_size])

    predict_ids = []
    parent_ids = []
    ### initialize states of beam search ###
    # Only beam 0 starts alive so the identical initial beams are not
    # expanded beam_size times.
    log_probs = to_variable(
        np.array([[0.] + [-inf] * (beam_size - 1)] * batch_size,
                 dtype="float32"))
    finished = to_variable(
        np.full([batch_size, beam_size], 0, dtype="bool"))
    ### initialize inputs and states of transformer decoder ###
    ## init inputs for decoder, shaped `[batch_size*beam_size, ...]`
    trg_word = layers.fill_constant(shape=[batch_size * beam_size, 1],
                                    dtype="int64",
                                    value=bos_id)
    trg_pos = layers.zeros_like(trg_word)
    trg_src_attn_bias = merge_batch_beams(
        expand_to_beam_size(trg_src_attn_bias, beam_size))
    enc_output = merge_batch_beams(
        expand_to_beam_size(enc_output, beam_size))
    ## init states (caches) for transformer, need to be updated according to selected beam
    caches = [{
        "k":
        layers.fill_constant(
            shape=[batch_size * beam_size, self.n_head, 0, self.d_key],
            dtype=enc_output.dtype,
            value=0),
        "v":
        layers.fill_constant(
            shape=[batch_size * beam_size, self.n_head, 0, self.d_value],
            dtype=enc_output.dtype,
            value=0),
    } for i in range(self.n_layer)]

    for i in range(max_len):
        trg_pos = layers.fill_constant(shape=trg_word.shape,
                                       dtype="int64",
                                       value=i)
        caches = map_structure(  # can not be reshaped since the 0 size
            lambda x: x if i == 0 else merge_batch_beams(x), caches)
        logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias,
                              enc_output, caches)
        caches = map_structure(split_batch_beams, caches)
        step_log_probs = split_batch_beams(
            fluid.layers.log(fluid.layers.softmax(logits)))
        step_log_probs = mask_probs(step_log_probs, finished,
                                    noend_mask_tensor)
        log_probs = layers.elementwise_add(x=step_log_probs,
                                           y=log_probs,
                                           axis=0)
        # Rank all beam_size * vocab candidates of a batch entry jointly.
        log_probs = layers.reshape(log_probs,
                                   [-1, beam_size * self.trg_vocab_size])
        scores = log_probs
        topk_scores, topk_indices = fluid.layers.topk(input=scores,
                                                      k=beam_size)
        # Decode flat candidate ids into (parent beam, token id).
        beam_indices = fluid.layers.elementwise_floordiv(
            topk_indices, vocab_size_tensor)
        token_indices = fluid.layers.elementwise_mod(
            topk_indices, vocab_size_tensor)

        # update states
        caches = map_structure(
            lambda x: gather(x, beam_indices, batch_pos), caches)
        log_probs = gather(log_probs, topk_indices, batch_pos)
        finished = gather(finished, beam_indices, batch_pos)
        finished = layers.logical_or(
            finished, layers.equal(token_indices, end_token_tensor))
        trg_word = layers.reshape(token_indices, [-1, 1])

        predict_ids.append(token_indices)
        parent_ids.append(beam_indices)

        if layers.reduce_all(finished).numpy():
            break

    predict_ids = layers.stack(predict_ids, axis=0)
    parent_ids = layers.stack(parent_ids, axis=0)
    # Back-trace parent pointers to recover full sequences per beam.
    finished_seq = layers.transpose(
        layers.gather_tree(predict_ids, parent_ids), [1, 2, 0])
    finished_scores = topk_scores

    return finished_seq, finished_scores
def tile_beam_merge_with_batch(t, beam_size):
    """Tile every tensor in `t` `beam_size` times along the batch dimension.

    Each tensor in the (possibly nested) structure `t` is passed to
    `nn.decode.BeamSearchDecoder.tile_beam_merge_with_batch` via
    `map_structure`; the result has the same structure as `t`.
    """

    def _tile_one(tensor):
        return nn.decode.BeamSearchDecoder.tile_beam_merge_with_batch(
            tensor, beam_size)

    return map_structure(_tile_one, t)
def forward(self,
            inputs,
            initial_states=None,
            sequence_length=None,
            **kwargs):
    """Run `self.cell` over the time dimension of `inputs`.

    In dygraph mode the recurrence is unrolled in Python: inputs are made
    time-major (if needed), optionally reversed, fed step by step to
    `self.cell`, and per-step outputs are stacked along
    `self.time_step_index`. When `sequence_length` is given, states stop
    updating past each sequence's length (see `_maybe_copy`). Otherwise
    (static graph) `fluid.layers.rnn` is used with the same arguments.

    Args:
        inputs: A (possibly nested) structure of tensors.
        initial_states: Initial cell states; if None they come from
            `self.cell.get_initial_states`.
        sequence_length: Optional lengths used to build a step mask.
        **kwargs: Forwarded to `self.cell` on every step.

    Returns:
        tuple: `(final_outputs, final_states)`.

    NOTE(review): in the dygraph path `batch_size` is computed but unused,
    and `new_states` is unbound when `time_steps` is 0.
    """
    if fluid.in_dygraph_mode():

        class OutputArray(object):
            # Leaf wrapper that accumulates per-step outputs; being a plain
            # object, map_structure treats it as a leaf.
            def __init__(self, x):
                self.array = [x]

            def append(self, x):
                self.array.append(x)

        def _maybe_copy(state, new_state, step_mask):
            # step_mask == 1 -> take new_state; step_mask == 0 -> keep state:
            # new_state * m + state * (1 - m).
            # TODO: use where_op
            new_state = fluid.layers.elementwise_mul(
                new_state, step_mask,
                axis=0) - fluid.layers.elementwise_mul(
                    state, (step_mask - 1), axis=0)
            return new_state

        flat_inputs = flatten(inputs)
        batch_size, time_steps = (
            flat_inputs[0].shape[self.batch_index],
            flat_inputs[0].shape[self.time_step_index])

        if initial_states is None:
            initial_states = self.cell.get_initial_states(
                batch_ref=inputs, batch_dim_idx=self.batch_index)

        if not self.time_major:
            # Swap axes 0 and 1 so that x[i] selects time step i.
            inputs = map_structure(
                lambda x: fluid.layers.transpose(x, [1, 0] + list(
                    range(2, len(x.shape)))), inputs)

        if sequence_length is not None:
            mask = fluid.layers.sequence_mask(
                sequence_length,
                maxlen=time_steps,
                dtype=flatten(initial_states)[0].dtype)
            # Make the mask time-major as well.
            mask = fluid.layers.transpose(mask, [1, 0])

        if self.is_reverse:
            inputs = map_structure(
                lambda x: fluid.layers.reverse(x, axis=[0]), inputs)
            mask = fluid.layers.reverse(
                mask, axis=[0]) if sequence_length is not None else None

        states = initial_states
        outputs = []
        for i in range(time_steps):
            step_inputs = map_structure(lambda x: x[i], inputs)
            step_outputs, new_states = self.cell(step_inputs, states,
                                                 **kwargs)
            if sequence_length is not None:
                new_states = map_structure(
                    partial(_maybe_copy, step_mask=mask[i]), states,
                    new_states)
            states = new_states
            if i == 0:
                outputs = map_structure(lambda x: OutputArray(x),
                                        step_outputs)
            else:
                map_structure(lambda x, x_array: x_array.append(x),
                              step_outputs, outputs)

        final_outputs = map_structure(
            lambda x: fluid.layers.stack(x.array,
                                         axis=self.time_step_index),
            outputs)

        if self.is_reverse:
            # Restore original time order of the outputs.
            final_outputs = map_structure(
                lambda x: fluid.layers.reverse(x,
                                               axis=self.time_step_index),
                final_outputs)

        final_states = new_states
    else:
        final_outputs, final_states = fluid.layers.rnn(
            self.cell,
            inputs,
            initial_states=initial_states,
            sequence_length=sequence_length,
            time_major=self.time_major,
            is_reverse=self.is_reverse,
            **kwargs)
    return final_outputs, final_states
def get_initial_states(self,
                       batch_ref,
                       shape=None,
                       dtype=None,
                       init_value=0,
                       batch_dim_idx=0):
    """
    Generate initialized states according to provided shape, data type and
    value.

    Parameters:
        batch_ref: A (possibly nested structure of) tensor variable[s]. The
            first tensor (in flattened order) is the batch-size reference:
            its dimension `batch_dim_idx` is used as the batch size of the
            initialized states.
        shape: A (possibly nested structure of) shape[s], where a shape is
            represented as a list/tuple of integers. -1 (for batch size) will
            be automatically inserted if a shape does not start with it. If
            None, property `state_shape` will be used. The default value is
            None.
        dtype: A (possibly nested structure of) data type[s]. The structure
            must be the same as that of `shape`, except when all tensors in
            states have the same data type, in which case a single data type
            can be used. If None and property `state_dtype` is not available
            (raises NotImplementedError), float32 will be used. The default
            value is None.
        init_value: A float value used to initialize states.
        batch_dim_idx: Index of the batch dimension in `batch_ref`. The
            default value is 0.

    Returns:
        Variable: tensor variable[s] packed in the same structure provided \
            by shape, representing the initialized states.
    """
    # `collections.Sequence` was removed in Python 3.10; use the ABC from
    # `collections.abc`, falling back for Python 2.
    try:
        from collections.abc import Sequence as _Sequence
    except ImportError:  # Python 2
        from collections import Sequence as _Sequence

    # TODO: use inputs and batch_size
    batch_ref = flatten(batch_ref)[0]

    def _is_shape_sequence(seq):
        """For shape, list/tuple of integer is the finest-grained objection"""
        # A flat list/tuple of ints is one shape, i.e. a leaf, not a nest.
        if isinstance(seq, (list, tuple)) and all(
                isinstance(x, six.integer_types) for x in seq):
            return False
        # TODO: Add check for the illegal
        if isinstance(seq, dict):
            return True
        return (isinstance(seq, _Sequence) and
                not isinstance(seq, six.string_types))

    class Shape(object):
        def __init__(self, shape):
            # Prepend the batch dimension (-1) unless it is already there;
            # an empty shape becomes [-1] instead of raising IndexError.
            if len(shape) > 0 and shape[0] == -1:
                self.shape = shape
            else:
                self.shape = [-1] + list(shape)

    # nested structure of shapes
    states_shapes = self.state_shape if shape is None else shape
    is_sequence_ori = utils.is_sequence
    utils.is_sequence = _is_shape_sequence
    try:
        states_shapes = map_structure(Shape, states_shapes)
    finally:
        # Always restore the global hook, even if wrapping shapes fails.
        utils.is_sequence = is_sequence_ori

    # nested structure of dtypes
    try:
        states_dtypes = self.state_dtype if dtype is None else dtype
    except NotImplementedError:
        # use fp32 as default
        states_dtypes = "float32"
    if len(flatten(states_dtypes)) == 1:
        # Broadcast a single dtype to every state.
        dtype = flatten(states_dtypes)[0]
        states_dtypes = map_structure(lambda shape: dtype, states_shapes)

    init_states = map_structure(
        lambda shape, dtype: fluid.layers.fill_constant_batch_size_like(
            input=batch_ref,
            shape=shape.shape,
            dtype=dtype,
            value=init_value,
            input_dim_idx=batch_dim_idx), states_shapes, states_dtypes)
    return init_states
def _convert_input(self, input):
    """Convert every numpy array in `input` with `to_variable`.

    `input` may be a nested structure; it is walked with `map_structure`.
    Leaves that are not `np.ndarray` are returned unchanged.
    """

    def _maybe_to_variable(item):
        if not isinstance(item, np.ndarray):
            return item
        return to_variable(item)

    return map_structure(_maybe_to_variable, input)