Example #1
    def sample_from_mog(self, y):
        """Sample from the output distribution where the output distribution is a mixture of Gaussians.
        Args:
            y (Variable): shape(B, T, C_output), dtype float32, the parameterd of the output distribution. It is the concatenation of 3 parts, the logits of every distribution, the mean of each distribution and the log standard deviation of each distribution. Each part's shape is (B, T, n_mixture), where `n_mixture` means the number of Gaussians in the mixture.

        Returns:
            Variable: shape(B, T), waveform sampled from the output distribution.
        """
        batch_size, time_steps, output_dim = y.shape
        n_mixture = output_dim // 3

        w, mu, log_std = F.split(y, 3, dim=-1)

        reshaped_w = F.reshape(w, (batch_size * time_steps, n_mixture))
        prob_ids = F.sampling_id(F.softmax(reshaped_w))
        prob_ids = F.reshape(prob_ids, (batch_size, time_steps))
        prob_ids = prob_ids.numpy()

        index = np.array([[[b, t, prob_ids[b, t]] for t in range(time_steps)]
                          for b in range(batch_size)]).astype("int32")
        index_var = dg.to_variable(index)

        mu_ = F.gather_nd(mu, index_var)
        log_std_ = F.gather_nd(log_std, index_var)

        dist = D.Normal(mu_, F.exp(log_std_))
        samples = dist.sample(shape=[])
        samples = F.clip(samples, min=-1., max=1.)
        return samples
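
For intuition, here is a numpy-only sketch of the same mixture-of-Gaussians sampling, with the `F.sampling_id` draw replaced by inverse-CDF sampling (`sample_from_mog_np` is a hypothetical name, not part of the original code):

import numpy as np

def sample_from_mog_np(y, rng=np.random.default_rng()):
    # y: (B, T, 3 * n_mixture), the concatenation of logits, means, log-stds
    batch_size, time_steps, output_dim = y.shape
    n_mixture = output_dim // 3
    w, mu, log_std = np.split(y, 3, axis=-1)

    # categorical draw over mixture components (the role of F.sampling_id)
    logits = w - w.max(axis=-1, keepdims=True)
    probs = np.exp(logits)
    probs /= probs.sum(axis=-1, keepdims=True)
    u = rng.random((batch_size, time_steps, 1))
    ids = (u > probs.cumsum(axis=-1)).sum(axis=-1)   # component index per (b, t)

    # gather the chosen component's parameters, then draw the Gaussian sample
    b_idx = np.arange(batch_size)[:, None]
    t_idx = np.arange(time_steps)[None, :]
    mu_ = mu[b_idx, t_idx, ids]
    std_ = np.exp(log_std[b_idx, t_idx, ids])
    return np.clip(rng.normal(mu_, std_), -1.0, 1.0)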
Example #2
    def test_sampling_id(self):
        program = Program()
        with program_guard(program):
            x = layers.data(name="X",
                            shape=[13, 11],
                            dtype='float32',
                            append_batch_size=False)

            out = layers.sampling_id(x)
            self.assertIsNotNone(out)
        print(str(program))
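
Note: `sampling_id` draws one multinomial sample per row of a (batch, vocab) probability matrix. A numpy equivalent for illustration (`sampling_id_np` is a hypothetical helper, not a Paddle API):

import numpy as np

def sampling_id_np(probs, rng=np.random.default_rng()):
    # one inverse-CDF multinomial draw per row, like fluid.layers.sampling_id
    cum = probs.cumsum(axis=-1)
    u = rng.random((probs.shape[0], 1))
    return (u > cum).sum(axis=-1)

probs = np.random.dirichlet(np.ones(11), size=13).astype("float32")
print(sampling_id_np(probs).shape)  # (13,) - one sampled column index per row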
Example #3
    def topk_sampling(self, probs):
        topk_probs, _ = paddle.topk(probs, self.topk)
        ge_cond = paddle.cast(
            paddle.greater_equal(probs,
                                 paddle.unsqueeze(topk_probs[:, -1], [1])),
            "float32")
        old_probs = probs
        probs = probs * ge_cond / paddle.sum(topk_probs, axis=-1, keepdim=True)
        sampling_ids = layers.sampling_id(probs, dtype="int")
        probs = old_probs
        return probs, sampling_ids
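
The same top-k filtering and renormalization in plain numpy (a sketch; `topk_sampling_np` is a hypothetical name). Like the `greater_equal` comparison above, ties at the k-th largest probability keep all tied entries:

import numpy as np

def topk_sampling_np(probs, topk, rng=np.random.default_rng()):
    # zero out everything below the k-th largest per row, renormalize, sample
    kth = np.sort(probs, axis=-1)[:, -topk][:, None]
    masked = np.where(probs >= kth, probs, 0.0)
    masked /= masked.sum(axis=-1, keepdims=True)
    u = rng.random((probs.shape[0], 1))
    return (u > masked.cumsum(axis=-1)).sum(axis=-1)   # sampled ids, (batch,)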
Example #4
    def sample_from_softmax(self, y):
        """Sample from the output distribution where the output distribution is a categorical distriobution.

        Args:
            y (Variable): shape(B, T, C_output), the logits of the output distribution

        Returns:
            Variable: shape(B, T), waveform sampled from the output distribution.
        """
        # reshape to (B*T, C), sample token ids from the softmax, then dequantize
        batch_size, time_steps, output_dim = y.shape
        y = F.reshape(y, (batch_size * time_steps, output_dim))
        prob = F.softmax(y)
        quantized = F.sampling_id(prob)
        samples = dequantize(quantized, n_bands=self.output_dim)
        samples = F.reshape(samples, (batch_size, -1))
        return samples
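
The snippet assumes a `dequantize` helper from the surrounding project. One plausible implementation, assuming uniform quantization of waveforms into `n_bands` bins over [-1, 1); the real helper may instead undo mu-law companding:

import numpy as np

def dequantize(quantized, n_bands):
    # hypothetical: map integer bin ids in [0, n_bands) back to bin centers in [-1, 1)
    return (np.asarray(quantized).astype("float32") + 0.5) * (2.0 / n_bands) - 1.0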
Example #5
    def topp_sampling(self, probs):
        sorted_probs, sorted_idx = layers.argsort(probs, descending=True)
        cum_sorted_probs = layers.cumsum(sorted_probs, axis=1, exclusive=True)
        lt_cond = paddle.cast(
            paddle.less_than(
                cum_sorted_probs,
                layers.fill_constant_batch_size_like(cum_sorted_probs,
                                                     cum_sorted_probs.shape,
                                                     cum_sorted_probs.dtype,
                                                     self.topp)), "float32")
        old_probs = probs
        candidate_probs = sorted_probs * lt_cond
        probs = candidate_probs / paddle.sum(
            candidate_probs, axis=-1, keepdim=True)
        sampling_ids = layers.sampling_id(probs, dtype="int")
        sampling_ids = paddle.index_sample(sorted_idx,
                                           paddle.unsqueeze(sampling_ids, [1]))
        sampling_ids = paddle.squeeze(sampling_ids, [1])
        probs = old_probs
        return probs, sampling_ids
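
And a numpy sketch of this nucleus (top-p) sampling, mirroring the exclusive cumulative sum and the `index_sample` lookup back into the original vocabulary order (`topp_sampling_np` is a hypothetical name):

import numpy as np

def topp_sampling_np(probs, topp, rng=np.random.default_rng()):
    # sort descending; keep the prefix whose *exclusive* cumsum is < topp
    # (the first token always survives), renormalize, sample, map back
    order = np.argsort(-probs, axis=-1)
    sorted_probs = np.take_along_axis(probs, order, axis=-1)
    cum_excl = np.cumsum(sorted_probs, axis=-1) - sorted_probs
    candidate = sorted_probs * (cum_excl < topp)
    candidate /= candidate.sum(axis=-1, keepdims=True)
    u = rng.random((probs.shape[0], 1))
    pos = (u > candidate.cumsum(axis=-1)).sum(axis=-1)  # position in sorted order
    return np.take_along_axis(order, pos[:, None], axis=-1).squeeze(-1)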
Example #6
    def inference(self, model, inputs, outputs):
        """
        Run inference.

        Args:
            model(object): A generation model. It must implement `_generation_network` and `_calc_logits`.
            inputs(dict): Maps each input name (str) to its input Variable.

        Returns:
            dict: Maps each output name (str) to its output Variable.
        """
        # prepare while loop
        max_len = layers.fill_constant(
            shape=[1], dtype="int64", value=self.max_dec_len, force_cpu=True)
        min_len = layers.fill_constant(
            shape=[1], dtype="int64", value=self.min_dec_len, force_cpu=True)
        step_idx = layers.fill_constant(
            shape=[1], dtype="int64", value=0, force_cpu=True)

        ids = layers.array_write(layers.reshape(inputs["tgt_ids"], (-1, 1)), step_idx)
        pos_biases = layers.array_write(layers.reshape(inputs["tgt_pos"], (-1, 1)), step_idx)
        scores = layers.array_write(inputs["init_score"], step_idx)
        tgt_generation_mask = layers.array_write(inputs["tgt_generation_mask"], step_idx)
        parent_idx = inputs["parent_idx"]

        if self.decoding_strategy == "beam_search":
            beam_size = self.beam_size
        else:
            beam_size = 1

        eos_penalty = np.zeros(self.vocab_size, dtype="float32")
        eos_penalty[self.eos_id] = -1e9
        eos_penalty = layers.assign(eos_penalty)

        token_penalty = np.zeros(self.vocab_size, dtype="float32")
        token_penalty[self.unk_id] = -1e9
        if self.mask_id >= 0:
            token_penalty[self.mask_id] = -1e9
        token_penalty = layers.assign(token_penalty)

        # start while loop
        cond = layers.less_than(x=step_idx, y=max_len)
        while_op = layers.While(cond)
        with while_op.block():
            pre_ids = layers.array_read(array=ids, i=step_idx)
            pre_ids = layers.reshape(pre_ids, (-1, 1, 1), inplace=True)
            pre_scores = layers.array_read(array=scores, i=step_idx)
            pos_bias = layers.array_read(array=pos_biases, i=step_idx)
            pos_bias = layers.gather(input=pos_bias, index=parent_idx)

            tmp_tgt_generation_mask = layers.array_read(tgt_generation_mask, i=step_idx)
            dtype = tmp_tgt_generation_mask.dtype

            append_mask = layers.fill_constant_batch_size_like(
                    input=pre_ids,
                    value=1.0,
                    shape=[-1, 1, 1],
                    dtype=dtype)
            tmp_tgt_generation_mask = layers.concat([tmp_tgt_generation_mask, append_mask], axis=2)
            pre_mask = tmp_tgt_generation_mask = layers.gather(input=tmp_tgt_generation_mask, index=parent_idx)

            pre_sent = layers.fill_constant_batch_size_like(
                    input=pre_mask,
                    value=1,
                    shape=[-1, 1, 1],
                    dtype=pre_ids.dtype)

            if self.continuous_position:
                pre_pos = layers.elementwise_mul(
                    x=layers.fill_constant_batch_size_like(
                        input=pre_mask,
                        value=1,
                        shape=[-1, 1, 1],
                        dtype=pre_ids.dtype), y=step_idx, axis=0) + pos_bias
            else:
                pre_pos = layers.elementwise_mul(
                    x=layers.fill_constant_batch_size_like(
                        input=pre_mask,
                        value=1,
                        shape=[-1, 1, 1],
                        dtype=pre_ids.dtype), y=step_idx, axis=0)

            if self.use_role:
                pre_role = layers.fill_constant_batch_size_like(
                        input=pre_mask,
                        value=0,
                        shape=[-1, 1, 1],
                        dtype=pre_ids.dtype)
            else:
                pre_role = None

            dec_out, _ = model._generation_network(
                token_ids=pre_ids,
                type_ids=pre_sent,
                pos_ids=pre_pos,
                role_ids=pre_role,
                generation_mask=tmp_tgt_generation_mask,
                gather_idx=parent_idx)
            logits = model._calc_logits(dec_out)

            # ignore unk and mask token
            if self.ignore_unk:
                logits = layers.elementwise_add(logits, token_penalty, axis=1)

            # min dec length
            min_len_cond = layers.less_than(x=step_idx, y=min_len)
            def min_len_penalty():
                """Plus minimum length penalty."""
                return layers.elementwise_add(logits, eos_penalty, axis=1)
            def no_penalty():
                """No penalty."""
                return logits
            logits = layers.case([(min_len_cond, min_len_penalty)], default=no_penalty)

            # get probs
            probs = layers.softmax(logits / self.temperature)

            if self.decoding_strategy == "beam_search":
                topk_scores, topk_indices = layers.topk(
                    input=probs, k=beam_size)
            else:
                if self.decoding_strategy.startswith("sampling"):
                    sampling_ids = layers.sampling_id(probs, dtype="int")
                elif self.decoding_strategy.startswith("topk_sampling"):
                    topk_probs, _ = layers.topk(input=probs, k=self.topk)
                    ge_cond = layers.cast(
                        layers.greater_equal(
                            probs,
                            layers.unsqueeze(topk_probs[:, -1], [1])),
                        "float32")
                    old_probs = probs
                    probs = probs * ge_cond / layers.reduce_sum(topk_probs, dim=-1, keep_dim=True)
                    sampling_ids = layers.sampling_id(probs, dtype="int")
                    probs = old_probs
                else:
                    raise ValueError(self.decoding_strategy)

                sampling_scores = layers.one_hot(
                    layers.unsqueeze(sampling_ids, [1]), probs.shape[1]
                )
                sampling_scores = sampling_scores * probs - (1 - sampling_scores) * 1e3
                topk_scores, topk_indices = layers.topk(
                    input=sampling_scores, k=1)

            pre_len = layers.cast(step_idx, "float32")
            layers.increment(x=step_idx, value=1.0, in_place=True)
            cur_len = layers.cast(step_idx, "float32")

            # update scores
            if self.length_average:
                accu_scores = layers.elementwise_add(
                    x=layers.log(topk_scores), y=pre_scores * pre_len, axis=0) / cur_len
            elif self.length_penalty > 0:
                pre_lp = layers.pow((5 + pre_len) / 6, self.length_penalty)
                cur_lp = layers.pow((5 + cur_len) / 6, self.length_penalty)
                accu_scores = layers.elementwise_add(
                    x=layers.log(topk_scores), y=pre_scores * pre_lp, axis=0) / cur_lp
            else:
                accu_scores = layers.elementwise_add(
                    x=layers.log(topk_scores), y=pre_scores, axis=0)
            topk_indices = layers.lod_reset(topk_indices, pre_ids)
            accu_scores = layers.lod_reset(accu_scores, pre_ids)
            selected_ids, selected_scores, gather_idx = layers.beam_search(
                pre_ids=pre_ids,
                pre_scores=pre_scores,
                ids=topk_indices,
                scores=accu_scores,
                beam_size=beam_size,
                end_id=self.eos_id,
                return_parent_idx=True)

            layers.array_write(selected_ids, i=step_idx, array=ids)
            layers.array_write(selected_scores, i=step_idx, array=scores)
            layers.array_write(pre_mask, i=step_idx, array=tgt_generation_mask)
            layers.array_write(pos_bias, i=step_idx, array=pos_biases)

            layers.assign(gather_idx, parent_idx)

            length_cond = layers.less_than(x=step_idx, y=max_len)
            finish_cond = layers.logical_not(layers.is_empty(x=selected_ids))
            layers.logical_and(x=length_cond, y=finish_cond, out=cond)

        finished_ids, finished_scores = layers.beam_search_decode(
            ids, scores, beam_size=beam_size, end_id=self.eos_id)

        predictions = {
            "finished_ids": finished_ids,
            "finished_scores": finished_scores,
            "token_ids": inputs["token_ids"],
            "data_id": inputs["data_id"]
        }
        return predictions
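
A side note on the one-hot trick near the end of the sampling branch: the sampled id is re-encoded as scores so the shared `topk`/`beam_search` machinery can consume it. In numpy terms:

import numpy as np

vocab, sampled_id = 5, 3
probs = np.array([0.1, 0.2, 0.15, 0.4, 0.15])
one_hot = np.eye(vocab)[sampled_id]
# the sampled entry keeps its probability; every other entry gets a large
# negative score, so topk(k=1) recovers exactly (probs[sampled_id], sampled_id)
scores = one_hot * probs - (1 - one_hot) * 1e3
print(scores.argmax(), scores.max())  # 3 0.4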
Example #7
    def _build_decoder(self,
                       z_mean=None,
                       z_log_var=None,
                       enc_output=None,
                       mode='train',
                       beam_size=10):
        dec_input = layers.dropout(self.tar_emb,
                                   dropout_prob=self.dec_dropout_in,
                                   dropout_implementation="upscale_in_train")

        # `output_layer` will be used within BeamSearchDecoder
        output_layer = lambda x: layers.fc(x,
                                           size=self.tar_vocab_size,
                                           num_flatten_dims=len(x.shape) - 1,
                                           name="output_w")

        # `sample_output_layer` samples an id from the logits distribution instead of argmax(logits)
        # it will be used within BeamSearchDecoder
        sample_output_layer = lambda x: layers.unsqueeze(
            fluid.one_hot(layers.unsqueeze(
                layers.sampling_id(layers.softmax(
                    layers.squeeze(output_layer(x), [1])),
                                   dtype='int'), [1]),
                          depth=self.tar_vocab_size), [1])

        if mode == 'train':
            latent_z = self._sampling(z_mean, z_log_var)
        else:
            latent_z = layers.gaussian_random_batch_size_like(
                self.tar, shape=[-1, self.latent_size])
        dec_first_hidden_cell = layers.fc(latent_z,
                                          2 * self.hidden_size *
                                          self.num_layers,
                                          name='fc_hc')
        dec_first_hidden, dec_first_cell = layers.split(
            dec_first_hidden_cell, 2)
        if self.num_layers > 1:
            dec_first_hidden = layers.split(dec_first_hidden, self.num_layers)
            dec_first_cell = layers.split(dec_first_cell, self.num_layers)
        else:
            dec_first_hidden = [dec_first_hidden]
            dec_first_cell = [dec_first_cell]
        dec_initial_states = [[h, c]
                              for h, c in zip(dec_first_hidden, dec_first_cell)
                              ]
        dec_cell = DecoderCell(self.num_layers, self.hidden_size, latent_z,
                               self.param_attr_initializer,
                               self.param_attr_scale, self.dec_dropout_out)

        if mode == 'train':
            dec_output, _ = rnn(cell=dec_cell,
                                inputs=dec_input,
                                initial_states=dec_initial_states,
                                sequence_length=self.tar_sequence_length)
            dec_output = output_layer(dec_output)

            return dec_output
        elif mode == 'greedy':
            start_token = 1
            end_token = 2
            max_length = 100
            beam_search_decoder = BeamSearchDecoder(
                dec_cell,
                start_token,
                end_token,
                beam_size=1,
                embedding_fn=self.tar_embeder,
                output_fn=output_layer)
            outputs, _ = dynamic_decode(beam_search_decoder,
                                        inits=dec_initial_states,
                                        max_step_num=max_length)
            return outputs

        elif mode == 'sampling':
            start_token = 1
            end_token = 2
            max_length = 100
            beam_search_decoder = BeamSearchDecoder(
                dec_cell,
                start_token,
                end_token,
                beam_size=1,
                embedding_fn=self.tar_embeder,
                output_fn=sample_output_layer)

            outputs, _ = dynamic_decode(beam_search_decoder,
                                        inits=dec_initial_states,
                                        max_step_num=max_length)
            return outputs
        else:
            print("mode not supprt", mode)
Example #8
    def forward(self, inputs, use_cache=False, cache=None):
        """
        Args:
            inputs (dict): must include src_ids;
                pos_ids, input_mask and max_dec_len are optional.
        """
        ######### forward context #########
        input_ids = inputs['src_ids']
        position_ids = inputs['pos_ids'] if 'pos_ids' in inputs else None
        attention_mask = inputs[
            'input_mask'] if 'input_mask' in inputs else None

        causal_mask = paddle.tensor.triu(paddle.ones(
            (paddle.shape(input_ids)[-1], paddle.shape(input_ids)[-1])) * -1e4,
                                         diagonal=1)
        if attention_mask is not None:
            tgt_pos = paddle.sum(attention_mask, axis=-1,
                                 keepdim=True).astype('int64')
            if len(attention_mask.shape) == 2:
                attention_mask = paddle.unsqueeze(attention_mask, axis=[1, 2])
            encode_mask = attention_mask + causal_mask
        else:
            encode_mask = causal_mask

        # cached kvs are assigned to the next step in _prepare_qkv of
        # MultiHeadAttention, so the global caches must be initialized here
        gen_caches = self._init_generation_caches(input_ids)

        logits, cached_kvs = self.model(input_ids,
                                        position_ids,
                                        encode_mask,
                                        use_cache=True,
                                        cache=gen_caches)

        next_id = paddle.argmax(logits[:, -1, :], axis=-1).reshape([-1, 1])
        ####################################

        if 'max_dec_len' not in inputs:
            max_len = layers.fill_constant([1],
                                           dtype=int_type,
                                           value=self.max_dec_len,
                                           force_cpu=True)
        else:
            max_len = inputs['max_dec_len']
        min_len = layers.fill_constant(shape=[1],
                                       dtype=int_type,
                                       value=self.min_dec_len,
                                       force_cpu=True)
        step_idx = layers.fill_constant(shape=[1],
                                        value=0,
                                        dtype='int64',
                                        force_cpu=True)

        placehold_ids = layers.fill_constant_batch_size_like(
            input=inputs["src_ids"],
            value=0,
            shape=[-1, 1],
            dtype=next_id.dtype)
        ids = layers.array_write(next_id, step_idx)

        if 'max_dec_len' in inputs:
            max_len = paddle.tensor.creation._memcpy(max_len,
                                                     place=paddle.CPUPlace())
        cond_int = paddle.full([1], 0, dtype=int_type, name="cond_int")
        cond = paddle.less_than(step_idx, max_len)

        if attention_mask is not None:
            append_mask = layers.fill_constant_batch_size_like(
                input=next_id,
                value=1,
                shape=[-1, 1, 1, 1],
                dtype=attention_mask.dtype)

        while_op = layers.While(cond, is_test=True)
        with while_op.block():
            pre_ids = layers.array_read(array=ids, i=step_idx)
            if attention_mask is not None:
                decode_mask = paddle.concat([attention_mask, append_mask],
                                            axis=-1)
                tgt_pos = tgt_pos + step_idx
                att_mask = (1 - decode_mask) * -1e4
            else:
                att_mask = None
                tgt_pos = None

            layers.increment(x=step_idx, value=1.0, in_place=True)
            layers.array_write(placehold_ids, i=step_idx, array=ids)

            logits, decode_cached_kvs = self.model(pre_ids,
                                                   tgt_pos,
                                                   att_mask,
                                                   use_cache=True,
                                                   cache=cached_kvs)

            logits = paddle.reshape(logits, shape=(-1, self.vocab_size))
            probs = F.softmax(logits / self.temperature)

            if self.decoding_strategy.startswith("sampling"):
                sampling_ids = layers.sampling_id(probs, dtype="int")
            elif self.decoding_strategy.startswith("topk_sampling"):
                probs, sampling_ids = self.topk_sampling(probs)
            elif self.decoding_strategy.startswith("topp_sampling"):
                probs, sampling_ids = self.topp_sampling(probs)
            else:
                raise ValueError(self.decoding_strategy)

            selected_ids = paddle.unsqueeze(sampling_ids, -1)
            layers.array_write(selected_ids, i=step_idx, array=ids)

            length_cond = paddle.less_than(x=step_idx,
                                           y=max_len,
                                           name="length_cond")
            finish_cond = paddle.logical_not(paddle.is_empty(x=selected_ids),
                                             name="finish_cond")
            paddle.logical_and(x=length_cond,
                               y=finish_cond,
                               out=cond,
                               name="logical_and_cond")

            paddle.assign(layers.cast(cond, dtype='bool'), cond)
            if attention_mask is not None:
                paddle.assign(decode_mask, attention_mask)
            for i in range(len(decode_cached_kvs)):
                if self._fuse:
                    paddle.assign(decode_cached_kvs[i].kv, cached_kvs[i].kv)
                else:
                    paddle.assign(decode_cached_kvs[i].k, cached_kvs[i].k)
                    paddle.assign(decode_cached_kvs[i].v, cached_kvs[i].v)

        ids, _ = layers.tensor_array_to_tensor(ids)
        return ids
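
For comparison, the per-step sampling above in imperative Paddle 2.x, as a minimal sketch assuming `paddle.multinomial` as the modern counterpart of `layers.sampling_id`:

import paddle
import paddle.nn.functional as F

logits = paddle.randn([4, 100])                      # (batch, vocab)
probs = F.softmax(logits / 1.0)                      # temperature = 1.0
next_id = paddle.multinomial(probs, num_samples=1)   # (batch, 1) sampled ids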