def sample_from_mog(self, y):
    """Sample from a mixture-of-Gaussians output distribution.

    Args:
        y (Variable): shape(B, T, C_output), dtype float32, the parameters
            of the output distribution. It is the concatenation of 3 parts:
            the logits of every distribution, the mean of each distribution
            and the log standard deviation of each distribution. Each
            part's shape is (B, T, n_mixture), where `n_mixture` is the
            number of Gaussians in the mixture.

    Returns:
        Variable: shape(B, T), waveform sampled from the output
            distribution, clipped to [-1, 1].
    """
    batch_size, time_steps, output_dim = y.shape
    n_mixture = output_dim // 3

    w, mu, log_std = F.split(y, 3, dim=-1)

    # Pick one mixture component per (batch, time) position by sampling
    # from the softmax over the mixture logits.
    reshaped_w = F.reshape(w, (batch_size * time_steps, n_mixture))
    prob_ids = F.sampling_id(F.softmax(reshaped_w))
    prob_ids = F.reshape(prob_ids, (batch_size, time_steps))
    prob_ids = prob_ids.numpy()

    # Build gather_nd indices [b, t, component] vectorized with
    # meshgrid + stack instead of the original O(B*T) Python double
    # loop — identical values, computed in C.
    b_idx, t_idx = np.meshgrid(
        np.arange(batch_size), np.arange(time_steps), indexing="ij")
    index = np.stack([b_idx, t_idx, prob_ids], axis=-1).astype("int32")
    index_var = dg.to_variable(index)

    # Parameters of the chosen component at each position.
    mu_ = F.gather_nd(mu, index_var)
    log_std_ = F.gather_nd(log_std, index_var)

    dist = D.Normal(mu_, F.exp(log_std_))
    samples = dist.sample(shape=[])
    samples = F.clip(samples, min=-1., max=1.)
    return samples
def test_sampling_id(self):
    """Building sampling_id on a 2-D float32 input yields a non-None var."""
    main_program = Program()
    with program_guard(main_program):
        probs = layers.data(
            name="X",
            shape=[13, 11],
            dtype='float32',
            append_batch_size=False)
        sampled = layers.sampling_id(probs)
        self.assertIsNotNone(sampled)
    print(str(main_program))
def topk_sampling(self, probs):
    """Sample one id per row from the renormalized top-k slice of `probs`.

    Entries below the k-th largest probability are zeroed, the survivors
    are divided by the top-k mass, and `sampling_id` draws from that
    restricted distribution. The ORIGINAL (unmasked) `probs` is returned
    alongside the sampled ids.
    """
    top_vals, _ = paddle.topk(probs, self.topk)
    # k-th largest value per row, kept as a column for broadcasting
    kth_val = paddle.unsqueeze(top_vals[:, -1], [1])
    keep_mask = paddle.cast(paddle.greater_equal(probs, kth_val), "float32")
    restricted = probs * keep_mask / paddle.sum(
        top_vals, axis=-1, keepdim=True)
    sampling_ids = layers.sampling_id(restricted, dtype="int")
    return probs, sampling_ids
def sample_from_softmax(self, y):
    """Sample from the output distribution where the output distribution
    is a categorical distribution.

    Args:
        y (Variable): shape(B, T, C_output), the logits of the output
            distribution.

    Returns:
        Variable: shape(B, T), waveform sampled from the output
            distribution.
    """
    batch_size, time_steps, output_dim = y.shape
    # Collapse batch and time so sampling_id sees one row per timestep.
    flat_logits = F.reshape(y, (batch_size * time_steps, output_dim))
    quantized = F.sampling_id(F.softmax(flat_logits))
    # Map quantization bins back to amplitudes, then restore the batch axis.
    samples = dequantize(quantized, n_bands=self.output_dim)
    return F.reshape(samples, (batch_size, -1))
def topp_sampling(self, probs):
    """Nucleus (top-p) sampling: draw one id per row from the smallest
    probability mass whose (exclusive) cumulative sum stays below
    ``self.topp``, renormalized over the survivors.

    Args:
        probs (Tensor): shape(B, vocab), per-row probability distribution.

    Returns:
        tuple: (probs, sampling_ids) — the ORIGINAL unmasked `probs` and
            the sampled vocabulary ids mapped back through `sorted_idx`.

    Fix: ``paddle.sum`` takes ``keepdim`` (the legacy fluid kwarg was
    ``keep_dim``); the old spelling raised ``TypeError`` at graph-build
    time. The sibling ``topk_sampling`` already uses ``keepdim``.
    """
    sorted_probs, sorted_idx = layers.argsort(probs, descending=True)
    # Exclusive cumsum: position i holds the mass strictly before i,
    # so the first element (the argmax) is always kept.
    cum_sorted_probs = layers.cumsum(sorted_probs, axis=1, exclusive=True)
    lt_cond = paddle.cast(
        paddle.less_than(
            cum_sorted_probs,
            layers.fill_constant_batch_size_like(cum_sorted_probs,
                                                 cum_sorted_probs.shape,
                                                 cum_sorted_probs.dtype,
                                                 self.topp)), "float32")
    old_probs = probs
    # Zero everything outside the nucleus, renormalize, then sample.
    candidate_probs = sorted_probs * lt_cond
    probs = candidate_probs / paddle.sum(
        candidate_probs, axis=-1, keepdim=True)
    sampling_ids = layers.sampling_id(probs, dtype="int")
    # The draw indexes the SORTED rows; translate back to vocab ids.
    sampling_ids = paddle.index_sample(sorted_idx,
                                       paddle.unsqueeze(sampling_ids, [1]))
    sampling_ids = paddle.squeeze(sampling_ids, [1])
    probs = old_probs
    return probs, sampling_ids
def inference(self, model, inputs, outputs):
    """
    Run inference.

    Builds a static-graph While loop that decodes step by step using
    beam search or sampling (plain / top-k), then unpacks the finished
    beams.

    Args:
        inputs(dict): Its key is input name(str) and its value is a Variable.
        model(object): A generate model. Need to implement
            `_generation_network` and `_calc_logits`.

    Returns:
        dict(str:Variable): Its key is output name(str) and its value is
            a Variable ("finished_ids", "finished_scores", "token_ids",
            "data_id").
    """
    # prepare while loop: CPU-side scalar counters / length bounds
    max_len = layers.fill_constant(
        shape=[1], dtype="int64", value=self.max_dec_len, force_cpu=True)
    min_len = layers.fill_constant(
        shape=[1], dtype="int64", value=self.min_dec_len, force_cpu=True)
    step_idx = layers.fill_constant(
        shape=[1], dtype="int64", value=0, force_cpu=True)

    # Tensor arrays indexed by step_idx: generated ids, position biases,
    # accumulated scores and the growing generation mask.
    ids = layers.array_write(
        layers.reshape(inputs["tgt_ids"], (-1, 1)), step_idx)
    pos_biases = layers.array_write(
        layers.reshape(inputs["tgt_pos"], (-1, 1)), step_idx)
    scores = layers.array_write(inputs["init_score"], step_idx)
    tgt_generation_mask = layers.array_write(
        inputs["tgt_generation_mask"], step_idx)
    parent_idx = inputs["parent_idx"]

    # Sampling strategies run as a degenerate beam of size 1.
    if self.decoding_strategy == "beam_search":
        beam_size = self.beam_size
    else:
        beam_size = 1

    # Additive logit penalties: -1e9 effectively forbids a token.
    eos_penalty = np.zeros(self.vocab_size, dtype="float32")
    eos_penalty[self.eos_id] = -1e9
    eos_penalty = layers.assign(eos_penalty)

    token_penalty = np.zeros(self.vocab_size, dtype="float32")
    token_penalty[self.unk_id] = -1e9
    if self.mask_id >= 0:
        token_penalty[self.mask_id] = -1e9
    token_penalty = layers.assign(token_penalty)

    # start while loop
    cond = layers.less_than(x=step_idx, y=max_len)
    while_op = layers.While(cond)
    with while_op.block():
        pre_ids = layers.array_read(array=ids, i=step_idx)
        pre_ids = layers.reshape(pre_ids, (-1, 1, 1), inplace=True)
        pre_scores = layers.array_read(array=scores, i=step_idx)
        pos_bias = layers.array_read(array=pos_biases, i=step_idx)
        # reorder by parent beam before use
        pos_bias = layers.gather(input=pos_bias, index=parent_idx)

        # Extend the generation mask by one column for the new token,
        # then gather rows to follow the surviving beams.
        tmp_tgt_generation_mask = layers.array_read(
            tgt_generation_mask, i=step_idx)
        dtype = tmp_tgt_generation_mask.dtype
        append_mask = layers.fill_constant_batch_size_like(
            input=pre_ids, value=1.0, shape=[-1, 1, 1], dtype=dtype)
        tmp_tgt_generation_mask = layers.concat(
            [tmp_tgt_generation_mask, append_mask], axis=2)
        pre_mask = tmp_tgt_generation_mask = layers.gather(
            input=tmp_tgt_generation_mask, index=parent_idx)

        # type id is constant 1 for generated tokens
        pre_sent = layers.fill_constant_batch_size_like(
            input=pre_mask, value=1, shape=[-1, 1, 1], dtype=pre_ids.dtype)

        # Position ids: current step index, optionally shifted by the
        # per-sequence bias when positions continue from the context.
        if self.continuous_position:
            pre_pos = layers.elementwise_mul(
                x=layers.fill_constant_batch_size_like(
                    input=pre_mask,
                    value=1,
                    shape=[-1, 1, 1],
                    dtype=pre_ids.dtype),
                y=step_idx,
                axis=0) + pos_bias
        else:
            pre_pos = layers.elementwise_mul(
                x=layers.fill_constant_batch_size_like(
                    input=pre_mask,
                    value=1,
                    shape=[-1, 1, 1],
                    dtype=pre_ids.dtype),
                y=step_idx,
                axis=0)

        if self.use_role:
            # generated tokens carry role id 0
            pre_role = layers.fill_constant_batch_size_like(
                input=pre_mask, value=0, shape=[-1, 1, 1],
                dtype=pre_ids.dtype)
        else:
            pre_role = None

        # One decoder step (incremental, gathered by parent beam).
        dec_out, _ = model._generation_network(
            token_ids=pre_ids,
            type_ids=pre_sent,
            pos_ids=pre_pos,
            role_ids=pre_role,
            generation_mask=tmp_tgt_generation_mask,
            gather_idx=parent_idx)
        logits = model._calc_logits(dec_out)

        # ignore unk and mask token
        if self.ignore_unk:
            logits = layers.elementwise_add(logits, token_penalty, axis=1)

        # min dec length: forbid EOS until min_len steps have run
        min_len_cond = layers.less_than(x=step_idx, y=min_len)

        def min_len_penalty():
            """Plus minimum length penalty."""
            return layers.elementwise_add(logits, eos_penalty, axis=1)

        def no_penalty():
            """No penalty."""
            return logits

        logits = layers.case([(min_len_cond, min_len_penalty)],
                             default=no_penalty)

        # get probs (temperature-scaled softmax)
        probs = layers.softmax(logits / self.temperature)

        if self.decoding_strategy == "beam_search":
            topk_scores, topk_indices = layers.topk(
                input=probs, k=beam_size)
        else:
            if self.decoding_strategy.startswith("sampling"):
                sampling_ids = layers.sampling_id(probs, dtype="int")
            elif self.decoding_strategy.startswith("topk_sampling"):
                # keep only entries >= the k-th largest probability,
                # renormalized by the top-k mass
                topk_probs, _ = layers.topk(input=probs, k=self.topk)
                ge_cond = layers.cast(
                    layers.greater_equal(
                        probs, layers.unsqueeze(topk_probs[:, -1], [1])),
                    "float32")
                old_probs = probs
                probs = probs * ge_cond / layers.reduce_sum(
                    topk_probs, dim=-1, keep_dim=True)
                sampling_ids = layers.sampling_id(probs, dtype="int")
                probs = old_probs
            else:
                raise ValueError(self.decoding_strategy)

            # Encode the sampled id as a top-1 result so the shared
            # beam_search machinery below can consume it: the sampled
            # position keeps its probability, all others get -1e3.
            sampling_scores = layers.one_hot(
                layers.unsqueeze(sampling_ids, [1]), probs.shape[1])
            sampling_scores = sampling_scores * probs - (
                1 - sampling_scores) * 1e3
            topk_scores, topk_indices = layers.topk(
                input=sampling_scores, k=1)

        pre_len = layers.cast(step_idx, "float32")
        layers.increment(x=step_idx, value=1.0, in_place=True)
        cur_len = layers.cast(step_idx, "float32")

        # update scores: running-average, GNMT-style length penalty,
        # or plain log-prob accumulation
        if self.length_average:
            accu_scores = layers.elementwise_add(
                x=layers.log(topk_scores), y=pre_scores * pre_len,
                axis=0) / cur_len
        elif self.length_penalty > 0:
            pre_lp = layers.pow((5 + pre_len) / 6, self.length_penalty)
            cur_lp = layers.pow((5 + cur_len) / 6, self.length_penalty)
            accu_scores = layers.elementwise_add(
                x=layers.log(topk_scores), y=pre_scores * pre_lp,
                axis=0) / cur_lp
        else:
            accu_scores = layers.elementwise_add(
                x=layers.log(topk_scores), y=pre_scores, axis=0)

        topk_indices = layers.lod_reset(topk_indices, pre_ids)
        accu_scores = layers.lod_reset(accu_scores, pre_ids)
        selected_ids, selected_scores, gather_idx = layers.beam_search(
            pre_ids=pre_ids,
            pre_scores=pre_scores,
            ids=topk_indices,
            scores=accu_scores,
            beam_size=beam_size,
            end_id=self.eos_id,
            return_parent_idx=True)

        # persist this step's state for the next iteration
        layers.array_write(selected_ids, i=step_idx, array=ids)
        layers.array_write(selected_scores, i=step_idx, array=scores)
        layers.array_write(pre_mask, i=step_idx, array=tgt_generation_mask)
        layers.array_write(pos_bias, i=step_idx, array=pos_biases)
        layers.assign(gather_idx, parent_idx)

        # loop while under max_len and at least one beam is unfinished
        length_cond = layers.less_than(x=step_idx, y=max_len)
        finish_cond = layers.logical_not(layers.is_empty(x=selected_ids))
        layers.logical_and(x=length_cond, y=finish_cond, out=cond)

    finished_ids, finished_scores = layers.beam_search_decode(
        ids, scores, beam_size=beam_size, end_id=self.eos_id)

    predictions = {
        "finished_ids": finished_ids,
        "finished_scores": finished_scores,
        "token_ids": inputs["token_ids"],
        "data_id": inputs["data_id"]
    }
    return predictions
def _build_decoder(self,
                   z_mean=None,
                   z_log_var=None,
                   enc_output=None,
                   mode='train',
                   beam_size=10):
    """Build the VAE decoder graph.

    Depending on `mode`:
      - 'train': teacher-forced RNN over `self.tar_emb`, returns logits;
        the latent is sampled from (z_mean, z_log_var).
      - 'greedy': BeamSearchDecoder with beam_size=1 (argmax decoding);
        the latent is drawn from a standard Gaussian.
      - 'sampling': like 'greedy' but each step samples an id from the
        softmax instead of taking the argmax.

    NOTE(review): `enc_output` and `beam_size` are unused here — the
    inference branches hard-code beam_size=1; confirm against callers.
    """
    dec_input = layers.dropout(self.tar_emb,
                               dropout_prob=self.dec_dropout_in,
                               dropout_implementation="upscale_in_train")

    # `output_layer` will be used within BeamSearchDecoder
    output_layer = lambda x: layers.fc(x,
                                       size=self.tar_vocab_size,
                                       num_flatten_dims=len(x.shape) - 1,
                                       name="output_w")

    # `sample_output_layer` samples an id from the logits distribution
    # instead of argmax(logits); it will be used within BeamSearchDecoder.
    # The sampled id is re-expanded to a one-hot "logits" tensor so the
    # decoder's argmax then picks exactly the sampled token.
    sample_output_layer = lambda x: layers.unsqueeze(
        fluid.one_hot(layers.unsqueeze(
            layers.sampling_id(layers.softmax(
                layers.squeeze(output_layer(x), [1])), dtype='int'), [1]),
            depth=self.tar_vocab_size), [1])

    if mode == 'train':
        # reparameterized sample from the posterior
        latent_z = self._sampling(z_mean, z_log_var)
    else:
        # inference: sample the latent from the prior N(0, I)
        latent_z = layers.gaussian_random_batch_size_like(
            self.tar, shape=[-1, self.latent_size])

    # Project the latent into initial (hidden, cell) states for every
    # decoder layer: one fc producing 2 * hidden_size * num_layers units.
    dec_first_hidden_cell = layers.fc(latent_z,
                                      2 * self.hidden_size * self.num_layers,
                                      name='fc_hc')
    dec_first_hidden, dec_first_cell = layers.split(
        dec_first_hidden_cell, 2)
    if self.num_layers > 1:
        dec_first_hidden = layers.split(dec_first_hidden, self.num_layers)
        dec_first_cell = layers.split(dec_first_cell, self.num_layers)
    else:
        dec_first_hidden = [dec_first_hidden]
        dec_first_cell = [dec_first_cell]
    dec_initial_states = [[h, c] for h, c in zip(dec_first_hidden,
                                                 dec_first_cell)]

    dec_cell = DecoderCell(self.num_layers, self.hidden_size, latent_z,
                           self.param_attr_initializer,
                           self.param_attr_scale, self.dec_dropout_out)

    if mode == 'train':
        dec_output, _ = rnn(cell=dec_cell,
                            inputs=dec_input,
                            initial_states=dec_initial_states,
                            sequence_length=self.tar_sequence_length)
        dec_output = output_layer(dec_output)
        return dec_output
    elif mode == 'greedy':
        # NOTE(review): start/end token ids and max length are
        # hard-coded (1, 2, 100) — confirm they match the vocabulary.
        start_token = 1
        end_token = 2
        max_length = 100
        beam_search_decoder = BeamSearchDecoder(
            dec_cell,
            start_token,
            end_token,
            beam_size=1,
            embedding_fn=self.tar_embeder,
            output_fn=output_layer)
        outputs, _ = dynamic_decode(beam_search_decoder,
                                    inits=dec_initial_states,
                                    max_step_num=max_length)
        return outputs
    elif mode == 'sampling':
        start_token = 1
        end_token = 2
        max_length = 100
        beam_search_decoder = BeamSearchDecoder(
            dec_cell,
            start_token,
            end_token,
            beam_size=1,
            embedding_fn=self.tar_embeder,
            output_fn=sample_output_layer)
        outputs, _ = dynamic_decode(beam_search_decoder,
                                    inits=dec_initial_states,
                                    max_step_num=max_length)
        return outputs
    else:
        print("mode not supprt", mode)
def forward(self, inputs, use_cache=False, cache=None):
    """
    Run the context pass once, then autoregressively decode with a
    static-graph While loop, reusing KV caches each step.

    Args:
        inputs (dict): include src_ids.
            pos_ids, input_mask and max_dec_len are optional.

    Returns:
        Tensor: the generated ids, concatenated along the step axis
            from the id tensor array.
    """
    ######### forward context #########
    input_ids = inputs['src_ids']
    position_ids = inputs['pos_ids'] if 'pos_ids' in inputs else None
    attention_mask = inputs[
        'input_mask'] if 'input_mask' in inputs else None

    # additive causal mask: -1e4 above the diagonal, 0 elsewhere
    causal_mask = paddle.tensor.triu(paddle.ones(
        (paddle.shape(input_ids)[-1], paddle.shape(input_ids)[-1])) * -1e4,
                                     diagonal=1)
    if attention_mask is not None:
        # tgt_pos: per-sequence count of attended tokens, used as the
        # position of the next generated token
        tgt_pos = paddle.sum(attention_mask, axis=-1,
                             keepdim=True).astype('int64')
        if len(attention_mask.shape) == 2:
            attention_mask = paddle.unsqueeze(attention_mask, axis=[1, 2])
        encode_mask = attention_mask + causal_mask
    else:
        encode_mask = causal_mask

    # if cached_kvs are assigned to next step in _prepare_qkv of
    # MultiHeadAttention, need to init the global caches here
    gen_caches = self._init_generation_caches(input_ids)

    logits, cached_kvs = self.model(input_ids,
                                    position_ids,
                                    encode_mask,
                                    use_cache=True,
                                    cache=gen_caches)

    # first generated token: greedy argmax over the last context position
    next_id = paddle.argmax(logits[:, -1, :], axis=-1).reshape([-1, 1])
    ####################################

    if 'max_dec_len' not in inputs:
        max_len = layers.fill_constant([1],
                                       dtype=int_type,
                                       value=self.max_dec_len,
                                       force_cpu=True)
    else:
        max_len = inputs['max_dec_len']

    # NOTE(review): min_len and cond_int are built but never used below
    # in this view — confirm whether they feed ops elsewhere.
    min_len = layers.fill_constant(shape=[1],
                                   dtype=int_type,
                                   value=self.min_dec_len,
                                   force_cpu=True)
    step_idx = layers.fill_constant(shape=[1],
                                    value=0,
                                    dtype='int64',
                                    force_cpu=True)

    # placeholder written at each step, overwritten by the sampled ids
    placehold_ids = layers.fill_constant_batch_size_like(
        input=inputs["src_ids"],
        value=0,
        shape=[-1, 1],
        dtype=next_id.dtype)
    ids = layers.array_write(next_id, step_idx)

    if 'max_dec_len' in inputs:
        # While condition tensors must live on CPU
        max_len = paddle.tensor.creation._memcpy(max_len,
                                                 place=paddle.CPUPlace())
    cond_int = paddle.full([1], 0, dtype=int_type, name="cond_int")
    cond = paddle.less_than(step_idx, max_len)

    if attention_mask is not None:
        # one extra mask column appended per generated token
        append_mask = layers.fill_constant_batch_size_like(
            input=next_id,
            value=1,
            shape=[-1, 1, 1, 1],
            dtype=attention_mask.dtype)

    while_op = layers.While(cond, is_test=True)
    with while_op.block():
        pre_ids = layers.array_read(array=ids, i=step_idx)
        # NOTE(review): `if attention_mask:` truth-tests a Variable,
        # unlike the `is not None` checks above — confirm intended.
        if attention_mask:
            decode_mask = paddle.concat([attention_mask, append_mask],
                                        axis=-1)
            tgt_pos = tgt_pos + step_idx
            att_mask = (1 - decode_mask) * -1e4
        else:
            att_mask = None
            tgt_pos = None

        layers.increment(x=step_idx, value=1.0, in_place=True)
        layers.array_write(placehold_ids, i=step_idx, array=ids)

        # single-token decoder step against the running KV cache
        logits, decode_cached_kvs = self.model(pre_ids,
                                               tgt_pos,
                                               att_mask,
                                               use_cache=True,
                                               cache=cached_kvs)

        logits = paddle.reshape(logits, shape=(-1, self.vocab_size))
        probs = F.softmax(logits / self.temperature)

        if self.decoding_strategy.startswith("sampling"):
            sampling_ids = layers.sampling_id(probs, dtype="int")
        elif self.decoding_strategy.startswith("topk_sampling"):
            probs, sampling_ids = self.topk_sampling(probs)
        elif self.decoding_strategy.startswith("topp_sampling"):
            probs, sampling_ids = self.topp_sampling(probs)
        else:
            raise ValueError(self.decoding_strategy)

        selected_ids = paddle.unsqueeze(sampling_ids, -1)
        layers.array_write(selected_ids, i=step_idx, array=ids)

        # continue while under max_len and the step produced ids
        length_cond = paddle.less_than(x=step_idx,
                                       y=max_len,
                                       name="length_cond")
        finish_cond = paddle.logical_not(paddle.is_empty(x=selected_ids),
                                         name="finish_cond")
        paddle.logical_and(x=length_cond,
                           y=finish_cond,
                           out=cond,
                           name="logical_and_cond")
        paddle.assign(layers.cast(cond, dtype='bool'), cond)
        if attention_mask:
            paddle.assign(decode_mask, attention_mask)
        # write this step's KV back into the persistent caches
        for i in range(len(decode_cached_kvs)):
            if self._fuse:
                paddle.assign(decode_cached_kvs[i].kv, cached_kvs[i].kv)
            else:
                paddle.assign(decode_cached_kvs[i].k, cached_kvs[i].k)
                paddle.assign(decode_cached_kvs[i].v, cached_kvs[i].v)

    ids, _ = layers.tensor_array_to_tensor(ids)
    return ids