# Shared imports for the snippets below (standard PaddlePaddle 1.x fluid
# paths; F is the project's own helper module providing ops such as
# unsqueeze / equal / not_equal / gumbel_softmax):
import math

import numpy as np
import paddle.fluid.layers as layers
from paddle.fluid.framework import Variable


    def _ranking(self, inputs, predictions):
        """ Reranking generated responses. """
        src_token = inputs["src_token"]
        src_mask = inputs["src_mask"]
        src_pos = inputs["src_pos"]
        src_type = inputs["src_type"]
        src_turn = inputs["src_turn"]
        src_embed = self.embedder(src_token, src_pos, src_type, src_turn)

        batch_size, num_latent, tgt_seq_len = predictions.shape

        # shape: [batch_size, num_latent, seq_len, 1]
        preds_token = F.unsqueeze(predictions, [3])
        preds_mask = F.not_equal(preds_token, self.padding_idx, "int64")
        preds_pos = layers.range(0, tgt_seq_len, 1, dtype="float32")
        preds_pos = F.unsqueeze(preds_pos, [0, 0, 1])
        preds_pos = layers.expand(preds_pos, [batch_size, num_latent, 1, 1])
        preds_pos = layers.cast(preds_pos, "int64")
        preds_type = layers.zeros_like(preds_token)
        preds_turn = layers.zeros_like(preds_token)

        scores = []
        for i in range(num_latent):
            pred_token = preds_token[:, i]
            pred_mask = preds_mask[:, i]
            pred_pos = preds_pos[:, i]
            pred_type = preds_type[:, i]
            pred_turn = preds_turn[:, i]

            input_mask = layers.concat([src_mask, pred_mask], axis=1)
            input_mask.stop_gradient = True
            pred_embed = self.embedder(pred_token, pred_pos, pred_type,
                                       pred_turn)
            embed = layers.concat([src_embed, pred_embed], axis=1)
            embed = self.embed_layer_norm(embed)

            mask_embed = self.mask_embed
            mask_embed = layers.expand(mask_embed, [batch_size, 1, 1])
            mask_embed = self.embed_layer_norm(mask_embed)

            out = layers.concat([mask_embed, embed], axis=1)
            mask = self._create_mask(input_mask, append_head=True)

            for layer in self.layers:
                out = layer(out, mask, None)

            mask_embed = out[:, 0]
            score = self.discriminator(mask_embed)
            scores.append(score[:, 0])
        scores = layers.stack(scores, axis=1)
        return scores
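
A minimal usage sketch (assumed, not from the repo): _ranking returns coherence scores of shape [batch_size, num_latent], one per candidate response, so picking the final response is an argmax over the latent axis:

import numpy as np

def select_best(candidates, scores):
    """candidates: [batch, num_latent, seq_len] token ids (hypothetical),
    scores: [batch, num_latent] as returned by _ranking."""
    best = np.argmax(scores, axis=1)               # [batch]
    return candidates[np.arange(len(best)), best]  # [batch, seq_len]
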
    def _collect_metrics(self, inputs, outputs):
        """ Calculate loss function by using inputs and outputs. """
        metrics = {}

        tgt_len = layers.reduce_sum(
            layers.reduce_sum(inputs["tgt_mask"], dim=1) - 1)
        tgt_len.stop_gradient = True

        label = inputs["tgt_token"][:, 1:]
        if self.label_smooth > 0:
            one_hot_label = layers.one_hot(label, self.num_token_embeddings)
            smooth_label = layers.label_smooth(one_hot_label,
                                               epsilon=self.label_smooth,
                                               dtype=self._dtype)
            nll = layers.cross_entropy(outputs["dec_pred"],
                                       smooth_label,
                                       soft_label=True,
                                       ignore_index=self.padding_idx)
        else:
            nll = layers.cross_entropy(outputs["dec_probs"],
                                       label,
                                       ignore_index=self.padding_idx)
        nll = layers.reduce_sum(nll, dim=1)
        token_nll = layers.reduce_sum(nll) / tgt_len
        nll = layers.reduce_mean(nll)
        metrics["nll"] = nll
        metrics["token_nll"] = token_nll
        loss = nll

        if self.num_latent > 0 and self.with_bow:
            bow_probs = F.unsqueeze(outputs["bow_probs"], [1])
            bow_probs = layers.expand(bow_probs, [1, label.shape[1], 1])
            if self.label_smooth > 0:
                bow = layers.cross_entropy(bow_probs,
                                           smooth_label,
                                           soft_label=True,
                                           ignore_index=self.padding_idx)
            else:
                bow = layers.cross_entropy(bow_probs,
                                           label,
                                           ignore_index=self.padding_idx)
            bow = layers.reduce_sum(bow, dim=1)
            token_bow = layers.reduce_sum(bow) / tgt_len
            bow = layers.reduce_mean(bow)
            metrics["bow"] = bow
            metrics["token_bow"] = token_bow
            loss = loss + bow

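        # Discriminator loss: binary cross-entropy pushing the probability
        # of the true (context, response) pair toward 1 and that of the
        # mismatched negative pair toward 0.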
        if self.num_latent > 0 and self.use_discriminator:
            dis = 0.0 - (layers.log(outputs["pos_probs"]) +
                         layers.log(1.0 - outputs["neg_probs"]))
            dis = layers.reduce_mean(dis)
            metrics["dis"] = dis
            loss = loss + dis * self.dis_ratio

        metrics["loss"] = loss
        metrics["token_num"] = tgt_len
        return metrics
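
Note the two normalizations above: "nll" averages the summed per-sentence NLL over the batch, while "token_nll" divides the total NLL by the total number of predicted tokens (tgt_len). A tiny worked example with made-up numbers:

import numpy as np

# Two targets with 3 and 1 predicted tokens and these per-token NLLs:
per_token = [np.array([1.0, 2.0, 3.0]), np.array([4.0])]
sent_sums = np.array([t.sum() for t in per_token])  # [6.0, 4.0]
nll = sent_sums.mean()                              # 5.0, sentence-level mean
token_nll = sent_sums.sum() / 4                     # 2.5, per-token mean
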
def repeat(var, times):
    if isinstance(var, list):
        return [repeat(x, times) for x in var]
    elif isinstance(var, dict):
        return {k: repeat(v, times) for k, v in var.items()}
    elif isinstance(var, Variable):
        var = F.unsqueeze(var, [1])
        expand_times = [1] * len(var.shape)
        expand_times[1] = times
        dtype = var.dtype
        var = layers.cast(var, "float32")
        var = layers.expand(var, expand_times)
        shape = [var.shape[0] * var.shape[1]] + var.shape[2:]
        var = layers.reshape(var, shape)
        var = layers.cast(var, dtype)
        return var
    else:
        return var
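
repeat() turns a [batch, ...] tensor into [batch * times, ...] with the copies of each example kept adjacent (via the unsqueeze/expand/reshape sequence; the float32 round-trip appears to work around dtype limits of layers.expand). Its effect matches np.repeat along axis 0:

import numpy as np

x = np.array([[1, 2], [3, 4]])    # [2, 2]
tiled = np.repeat(x, 3, axis=0)   # [6, 2]: [[1,2],[1,2],[1,2],[3,4],[3,4],[3,4]]
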
    def _attn(self, query, key, value, mask):
        # shape: [batch_size, num_head, seq_len, seq_len]
        scores = layers.matmul(x=query, y=key, alpha=self.scale)

        if mask is not None:
            mask = F.unsqueeze(mask, [1])
            mask = layers.expand(mask, [1, self.num_heads, 1, 1])
            mask.stop_gradient = True
            scores = (1 - mask) * scores + layers.scale(mask, scale=-1e10)

        attn = layers.softmax(scores, axis=-1)
        attn = F.dropout(attn, self.dropout)

        if mask is not None:
            attn = (1 - mask) * attn

        out = layers.matmul(x=attn, y=value)
        return out
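
The mask convention here is mask == 1 for positions that must be hidden: their scores are pushed to -1e10 so they receive near-zero attention weight, and the attention rows are re-masked after dropout. A self-contained numpy sketch of the same trick:

import numpy as np

def masked_softmax(scores, mask):
    # mask == 1 hides a position: its score becomes -1e10, so its
    # post-softmax weight is effectively zero.
    scores = (1 - mask) * scores + mask * -1e10
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)
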
    def _generation_network(self, input_mask, embed, batch_size, src_len,
                            tgt_len, latent_embed):
        """ Basic generation network implement. """
        if self.num_latent > 0:
            latent_embed = F.unsqueeze(latent_embed, [1])
            latent_embed = self.embed_layer_norm(latent_embed)
            dec_embed = layers.concat([latent_embed, embed], axis=1)
        else:
            dec_embed = embed

        # Create generation network mask
        src_mask = input_mask[:, :src_len]
        tgt_mask = input_mask[:, src_len:]
        enc_mask = self._create_mask(
            src_mask,
            auto_regressive=not self.bidirectional_context,
            append_head=self.num_latent > 0)
        dec_mask = self._create_mask(tgt_mask, auto_regressive=True)
        mask = self._join_mask(enc_mask, dec_mask)

        for layer in self.layers:
            dec_embed = layer(dec_embed, mask, None)

        if self.num_latent > 0:
            latent_embed = dec_embed[:, 0]
        else:
            latent_embed = None
        dec_embed = dec_embed[:, -tgt_len:]
        if self.two_layer_predictor:
            dec_embed = self.pre_predictor(dec_embed)
        if self.weight_sharing:
            token_embedding = self.embedder.token_embedding._w
            dec_logits = layers.matmul(x=dec_embed,
                                       y=token_embedding,
                                       transpose_y=True)
        else:
            dec_logits = self.predictor(dec_embed)

        dec_probs = layers.softmax(dec_logits, axis=-1)

        return latent_embed, dec_probs
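
With weight_sharing enabled, the output projection reuses the token embedding table, which is what the matmul with transpose_y=True computes. The equivalent numpy operation:

import numpy as np

hidden = np.random.rand(2, 5, 16)  # [batch, tgt_len, hidden_dim]
table = np.random.rand(100, 16)    # [vocab, hidden_dim] embedding table
logits = hidden @ table.T          # [batch, tgt_len, vocab]
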
    def __call__(self, step_fn, state):
        """
        Running beam search.

        @param : step_fn : decoding one step
        @type : function

        @param : state : initial state
        @type : dict
        """
        batch_size = state["batch_size"]
        beam_size = self.beam_size

        # shape: [batch_size, 1]
        pos_index = layers.range(0, batch_size, 1, dtype="int64")
        pos_index = layers.scale(pos_index, beam_size)
        pos_index = F.unsqueeze(pos_index, [1])

        # shape: [batch_size, beam_size, 1]
        predictions = layers.fill_constant(shape=[batch_size, beam_size, 1],
                                           dtype="int64",
                                           value=self.bos_id)

        # initial input
        state["pred_token"] = predictions[:, :1]
        # shape: [batch_size, vocab_size]
        scores, state = step_fn(state)

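        # Score-shaping vectors: unk_penalty suppresses [UNK] when
        # ignore_unk is set; eos_penalty blocks [EOS] before min_gen_len;
        # scores_after_end forces finished hypotheses to extend with [PAD]
        # at zero cost so their cumulative score stays frozen.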
        unk_penalty = np.zeros(self.vocab_size, dtype="float32")
        unk_penalty[self.unk_id] = -1e10
        unk_penalty = layers.assign(unk_penalty)

        eos_penalty = np.zeros(self.vocab_size, dtype="float32")
        eos_penalty[self.eos_id] = -1e10
        eos_penalty = layers.assign(eos_penalty)

        scores_after_end = np.full(self.vocab_size, -1e10, dtype="float32")
        scores_after_end[self.pad_id] = 0
        scores_after_end = layers.assign(scores_after_end)

        if self.ignore_unk:
            scores = scores + unk_penalty
        scores = scores + eos_penalty

        # shape: [batch_size, beam_size]
        sequence_scores, preds = layers.topk(scores, self.beam_size)

        predictions = layers.concat(
            [predictions, F.unsqueeze(preds, [2])], axis=2)
        state = repeat(state, beam_size)

        parent_idx_list = []
        pred_list = []

        for step in range(2, self.max_gen_len + 1):
            pre_ids = predictions[:, :, -1:]
            state["pred_token"] = layers.reshape(
                pre_ids, shape=[batch_size * beam_size, 1, 1])
            state["pred_mask"] = 1 - F.equal(state["pred_token"], self.pad_id)
            state["pred_pos"] = state["pred_pos"] + 1
            scores, state = step_fn(state)

            # Generate next
            # scores shape: [batch_size, beam_size, vocab_size]
            if self.ignore_unk:
                scores = scores + unk_penalty

            if step <= self.min_gen_len:
                scores = scores + eos_penalty

            scores = layers.reshape(
                scores, shape=[batch_size, beam_size, self.vocab_size])

            # previous token is [PAD] or [EOS]
            pre_eos_mask = F.equal(pre_ids, self.eos_id) + F.equal(
                pre_ids, self.pad_id)

            scores = scores * (1 - pre_eos_mask) + \
                layers.expand(pre_eos_mask, [1, 1, self.vocab_size]) * scores_after_end
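            # length_average keeps a running per-token average:
            # new = old * (1 - 1/step) + step_score / step for live beams,
            # while finished beams (pre_eos_mask == 1) keep their old score
            # and only append a zero-cost [PAD].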
            if self.length_average:
                scaled_value = pre_eos_mask + (1 - pre_eos_mask) * (1 -
                                                                    1 / step)
                sequence_scores = F.unsqueeze(sequence_scores,
                                              [2]) * scaled_value
                scaled_value = pre_eos_mask + (1 - pre_eos_mask) * (1 / step)
                scores = scores * scaled_value
            elif self.length_penalty >= 0.0:
                scaled_value = pre_eos_mask + (1 - pre_eos_mask) * \
                    (math.pow((4 + step) / (5 + step), self.length_penalty))
                sequence_scores = layers.elementwise_mul(scaled_value,
                                                         sequence_scores,
                                                         axis=0)
                scaled_value = pre_eos_mask + (1 - pre_eos_mask) * \
                    (math.pow(1 / (5 + step), self.length_penalty))
                scores = scores * scaled_value
            scores = layers.elementwise_add(scores, sequence_scores, axis=0)
            scores = layers.reshape(
                scores, shape=[batch_size, beam_size * self.vocab_size])

            topk_scores, topk_indices = layers.topk(scores, beam_size)
            vocab_size = layers.fill_constant(shape=[1],
                                              dtype="int64",
                                              value=self.vocab_size)
            parent_idx = layers.elementwise_floordiv(topk_indices, vocab_size)
            preds = layers.elementwise_mod(topk_indices, vocab_size)

            # Gather state / sequence_scores
            parent_idx = layers.elementwise_add(parent_idx, pos_index, axis=0)
            parent_idx = layers.reshape(parent_idx, [batch_size * beam_size])
            state = gather(state, parent_idx)
            sequence_scores = topk_scores

            predictions = layers.reshape(predictions,
                                         shape=[batch_size * beam_size, step])
            predictions = gather(predictions, parent_idx)
            predictions = layers.reshape(predictions,
                                         shape=[batch_size, beam_size, step])
            predictions = layers.concat(
                [predictions, F.unsqueeze(preds, [2])], axis=2)

        pre_ids = predictions[:, :, -1]
        pre_eos_mask = F.equal(pre_ids, self.eos_id) + F.equal(
            pre_ids, self.pad_id)
        sequence_scores = sequence_scores * pre_eos_mask + layers.scale(
            1 - pre_eos_mask, -1e10)

        _, indices = layers.argsort(sequence_scores, axis=1)
        indices = indices + pos_index
        indices = layers.reshape(indices, [-1])
        sequence_scores = layers.reshape(sequence_scores,
                                         [batch_size * beam_size])
        predictions = layers.reshape(predictions, [batch_size * beam_size, -1])
        sequence_scores = gather(sequence_scores, indices)
        predictions = layers.gather(predictions, indices)
        sequence_scores = layers.reshape(sequence_scores,
                                         [batch_size, beam_size])
        predictions = layers.reshape(predictions, [batch_size, beam_size, -1])

        results = {
            "preds": predictions[:, -1],
            "scores": sequence_scores[:, -1]
        }
        return results
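
The core trick in the loop above is running top-k over the flattened [beam_size * vocab_size] scores and then splitting each flat index back into a parent beam (floordiv) and a token id (mod). A toy illustration:

import numpy as np

# beam_size=2, vocab_size=5: flat index 7 means parent beam 7 // 5 = 1
# continuing with token 7 % 5 = 2.
flat = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9, 0.8, 0.0])
topk = np.argsort(flat)[-2:]                  # [8, 7] (ascending order)
parent_beam, token_id = topk // 5, topk % 5   # [1, 1], [3, 2]
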
    def __call__(self, step_fn, state):
        """
        Running generation.

        @param : step_fn : decoding one step
        @type : function

        @param : state : initial state
        @type : dict
        """
        batch_size = state["batch_size"]
        vocab_size = self.vocab_size

        pos_index = layers.range(0, batch_size, 1, dtype="int64")
        pos_index = layers.scale(pos_index, vocab_size)

        # shape: [batch_size, 1]
        predictions = layers.fill_constant(shape=[batch_size, 1],
                                           dtype="int64",
                                           value=self.bos_id)
        sequence_scores = layers.fill_constant(shape=[batch_size],
                                               dtype="float32",
                                               value=0.0)

        unk_penalty = np.zeros(vocab_size, dtype="float32")
        unk_penalty[self.unk_id] = -1e10
        unk_penalty = layers.assign(unk_penalty)

        eos_penalty = np.zeros(vocab_size, dtype="float32")
        eos_penalty[self.eos_id] = -1e10
        eos_penalty = layers.assign(eos_penalty)

        scores_after_end = np.full(vocab_size, -1e10, dtype="float32")
        scores_after_end[self.pad_id] = 0
        scores_after_end = layers.assign(scores_after_end)

        # decoding loop; step 1 feeds the initial [BOS] token
        for step in range(1, self.max_gen_len + 1):
            pre_ids = predictions[:, -1:]
            state["pred_token"] = F.unsqueeze(pre_ids, [2])
            if step > 1:
                state["pred_mask"] = 1 - F.equal(state["pred_token"],
                                                 self.pad_id)
                state["pred_pos"] = state["pred_pos"] + 1
            scores, state = step_fn(state)

            # Generate next
            # scores shape: [batch_size, vocab_size]
            if self.ignore_unk:
                scores = scores + unk_penalty

            if step <= self.min_gen_len:
                scores = scores + eos_penalty

            # previous token is [PAD] or [EOS]
            # shape: [batch_size, 1]
            pre_eos_mask = F.equal(pre_ids, self.eos_id) + F.equal(
                pre_ids, self.pad_id)
            scores = scores * (1 - pre_eos_mask) + \
                layers.expand(pre_eos_mask, [1, vocab_size]) * scores_after_end

            scores = scores / self.temperature
            preds = self._sampling(scores)

            predictions = layers.concat(
                [predictions, F.unsqueeze(preds, [1])], axis=1)

            scores = layers.reshape(scores, [batch_size * vocab_size])
            preds = preds + pos_index
            scores = gather(scores, preds)
            sequence_scores = sequence_scores + scores

        results = {"preds": predictions, "scores": sequence_scores}
        return results
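
Dividing the scores by the temperature before sampling controls randomness: tau < 1 sharpens the distribution toward greedy decoding, tau > 1 flattens it. For example:

import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

logits = np.array([2.0, 1.0, 0.0])
print(softmax(logits))        # ~[0.665, 0.245, 0.090]
print(softmax(logits / 0.5))  # ~[0.867, 0.117, 0.016]  (tau = 0.5, sharper)
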
    def _init_state(self, inputs):
        """ Initialize decode state. """
        state = {}

        src_token = inputs["src_token"]
        src_mask = inputs["src_mask"]
        src_pos = inputs["src_pos"]
        src_type = inputs["src_type"]
        src_turn = inputs["src_turn"]

        batch_size = src_token.shape[0]
        seq_len = src_token.shape[1]

        src_embed = self.embedder(src_token, src_pos, src_type, src_turn)
        src_embed = self.embed_layer_norm(src_embed)

        mask = self._create_mask(src_mask, append_head=self.num_latent > 0)

        if self.num_latent > 0:
            src_embed = F.unsqueeze(src_embed, [1])
            src_embed = layers.expand(src_embed, [1, self.num_latent, 1, 1])
            src_embed = layers.reshape(src_embed, [-1, seq_len, self.hidden_dim])

            latent_embed = self.latent_embeddings
            latent_embed = F.unsqueeze(latent_embed, [1])
            latent_embed = layers.expand(latent_embed, [batch_size, 1, 1])
            latent_embed = self.embed_layer_norm(latent_embed)

            enc_out = layers.concat([latent_embed, src_embed], axis=1)

            mask = F.unsqueeze(mask, [1])
            mask = layers.expand(mask, [1, self.num_latent, 1, 1])
            mask = layers.reshape(mask, [-1, seq_len + 1, seq_len + 1])
        else:
            enc_out = src_embed

        cache = {}
        for l, layer in enumerate(self.layers):
            cache[f"layer_{l}"] = {}
            enc_out = layer(enc_out, mask, cache[f"layer_{l}"])

        state["cache"] = cache
        state["mask"] = mask[:, :1]
        if self.num_latent > 0:
            state["batch_size"] = batch_size * self.num_latent
            shape = [batch_size * self.num_latent, 1, 1]
        else:
            state["batch_size"] = batch_size
            shape = [batch_size, 1, 1]
        state["pred_mask"] = layers.ones(shape, self._dtype)
        state["pred_pos"] = layers.zeros(shape, "int64")
        state["pred_type"] = layers.zeros(shape, "int64")
        state["pred_turn"] = layers.zeros(shape, "int64")

        if "tgt_token" in inputs and self.num_latent > 0:
            tgt_token = inputs["tgt_token"][:, :-1]
            tgt_mask = inputs["tgt_mask"][:, :-1]
            tgt_pos = inputs["tgt_pos"][:, :-1]
            tgt_type = inputs["tgt_type"][:, :-1]
            tgt_turn = inputs["tgt_turn"][:, :-1]

            input_mask = layers.concat([src_mask, tgt_mask], axis=1)
            input_mask.stop_gradient = True
            src_embed = self.embedder(src_token, src_pos, src_type, src_turn)
            tgt_embed = self.embedder(tgt_token, tgt_pos, tgt_type, tgt_turn)
            embed = layers.concat([src_embed, tgt_embed], axis=1)
            embed = self.embed_layer_norm(embed)

            batch_size = src_token.shape[0]
            src_len = src_token.shape[1]
            tgt_len = tgt_token.shape[1]

            post_embed, post_probs, post_logits = self._posteriori_network(
                input_mask, embed, batch_size, src_len, tgt_len)
            state["post_probs"] = post_probs

        return state
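
The per-layer cache dict is what keeps step-wise decoding cheap: the source prefix is encoded once, and each later step only attends over cached states plus the newly predicted token. A sketch of the assumed contract (the layer implementation is not shown here):

num_layers = 2  # hypothetical depth
cache = {f"layer_{l}": {} for l in range(num_layers)}
# Each layer stores its keys/values for the processed prefix in its dict,
# and extends them with the new token's states on every decode step.
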
    def _forward(self, inputs, is_training):
        """ Real forward process of model in different mode(train/test). """
        outputs = {}

        src_token = inputs["src_token"]
        src_mask = inputs["src_mask"]
        src_pos = inputs["src_pos"]
        src_type = inputs["src_type"]
        src_turn = inputs["src_turn"]

        tgt_token = inputs["tgt_token"][:, :-1]
        tgt_mask = inputs["tgt_mask"][:, :-1]
        tgt_pos = inputs["tgt_pos"][:, :-1]
        tgt_type = inputs["tgt_type"][:, :-1]
        tgt_turn = inputs["tgt_turn"][:, :-1]

        input_mask = layers.concat([src_mask, tgt_mask], axis=1)
        input_mask.stop_gradient = True
        src_embed = self.embedder(src_token, src_pos, src_type, src_turn)
        tgt_embed = self.embedder(tgt_token, tgt_pos, tgt_type, tgt_turn)
        embed = layers.concat([src_embed, tgt_embed], axis=1)
        embed = self.embed_layer_norm(embed)

        batch_size = src_token.shape[0]
        src_len = src_token.shape[1]
        tgt_len = tgt_token.shape[1]

        if self.num_latent > 0:
            post_embed, post_probs, post_logits = self._posteriori_network(
                input_mask, embed, batch_size, src_len, tgt_len)
            outputs["post_logits"] = post_logits

            if self.use_discriminator:
                pos_probs, neg_probs = self._discriminator_network(
                    input_mask, embed, batch_size, src_len, tgt_len, post_embed)
                outputs["pos_probs"] = pos_probs
                outputs["neg_probs"] = neg_probs

            if is_training:
                z = F.gumbel_softmax(post_logits, self.tau)
            else:
                indices = layers.argmax(post_logits, axis=1)
                z = layers.one_hot(F.unsqueeze(indices, [1]), self.num_latent)
            latent_embeddings = self.latent_embeddings
            latent_embed = layers.matmul(z, latent_embeddings)
            outputs["latent_embed"] = latent_embed
        else:
            latent_embed = None

        latent_embed, dec_probs = self._generation_network(
            input_mask, embed, batch_size, src_len, tgt_len, latent_embed)
        outputs["dec_probs"] = dec_probs

        if self.num_latent > 0 and self.with_bow:
            if self.two_layer_predictor:
                latent_embed = self.pre_bow_predictor(latent_embed)
            bow_logits = self.bow_predictor(latent_embed)
            bow_probs = layers.softmax(bow_logits)
            outputs["bow_probs"] = bow_probs

        return outputs
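
During training the latent assignment z is drawn with F.gumbel_softmax, which keeps the sample differentiable; at inference it falls back to the argmax of the posterior logits. A minimal numpy sketch of the Gumbel-softmax estimator (not the repo's implementation):

import numpy as np

def gumbel_softmax(logits, tau):
    # Add Gumbel(0, 1) noise, then take a temperature-scaled softmax;
    # as tau -> 0 the result approaches a one-hot sample of softmax(logits).
    g = -np.log(-np.log(np.random.uniform(1e-10, 1.0, size=logits.shape)))
    y = (logits + g) / tau
    e = np.exp(y - y.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)
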