Example #1
0
    def decode(self, dec_state):
        """
        Beam-search decode.

        Args:
            dec_state: initial decoder state; must provide get_batch_size(),
                inflate(), and index_select().

        Returns:
            predicts: (b, 1, max_length) best hypothesis per example, padded
                with self.PAD beyond each hypothesis length.
            lengths: (b, 1) length of each returned hypothesis.
            scores: (b, 1) sequence score of each returned hypothesis.
        """
        long_tensor_type = torch.cuda.LongTensor if self.use_gpu else torch.LongTensor

        b = dec_state.get_batch_size()

        # Row offsets into the inflated (b*k) batch:
        # [[0], [k*1], [k*2], ..., [k*(b-1)]]
        self.pos_index = (long_tensor_type(range(b)) * self.k).view(-1, 1)

        # Inflate the initial hidden states to be of size: (b*k, H)
        dec_state = dec_state.inflate(self.k)

        # Initialize the scores; for the first step,
        # ignore the inflated copies to avoid duplicate entries in the top k
        sequence_scores = long_tensor_type(b * self.k).float()
        sequence_scores.fill_(-float('inf'))
        sequence_scores.index_fill_(
            0, long_tensor_type([i * self.k for i in range(b)]), 0.0)

        # Initialize the input vector
        input_var = long_tensor_type([self.BOS] * b * self.k)

        # Store decisions for backtracking
        stored_scores = list()
        stored_predecessors = list()
        stored_emitted_symbols = list()

        for t in range(1, self.max_length + 1):
            # Run the RNN one step forward
            output, dec_state, attn = self.model.decode(input_var, dec_state)

            log_softmax_output = output.squeeze(1)

            # To get the full sequence scores for the new candidates, add the
            # local scores for t_i to the predecessor scores for t_(i-1)
            sequence_scores = sequence_scores.unsqueeze(1).repeat(1, self.V)
            if self.length_average and t > 1:
                # Running average of per-step log-probs instead of their sum.
                sequence_scores = sequence_scores * \
                    (1 - 1 / t) + log_softmax_output / t
            else:
                sequence_scores += log_softmax_output

            scores, candidates = sequence_scores.view(b, -1).topk(self.k,
                                                                  dim=1)

            # Reshape input = (b*k, 1) and sequence_scores = (b*k)
            input_var = (candidates % self.V)
            sequence_scores = scores.view(b * self.k)

            input_var = input_var.view(b * self.k)

            # Update fields for next timestep. candidates // V recovers the
            # beam each candidate was expanded from; pos_index shifts it to
            # an absolute row in the inflated (b*k) batch.
            if torch.__version__ == '1.2.0':
                # On torch 1.2.0, '/' on LongTensors is integer division.
                predecessors = (candidates / self.V +
                                self.pos_index.expand_as(candidates)).view(
                                    b * self.k)
            else:
                # true_divide + long() truncates toward zero, equivalent to
                # floor division here since candidates are non-negative.
                predecessors = (torch.true_divide(candidates, self.V) +
                                self.pos_index.expand_as(candidates)).view(
                                    b * self.k).long()

            dec_state = dec_state.index_select(predecessors)

            # Update sequence scores and erase scores for end-of-sentence
            # symbol so that they aren't expanded.
            stored_scores.append(sequence_scores.clone())
            eos_indices = input_var.data.eq(self.EOS)
            # BUGFIX: nonzero() always returns a 2-D tensor, so .dim() > 0
            # was always true; test numel() to detect whether any EOS was
            # actually emitted this step.
            if eos_indices.nonzero(as_tuple=False).numel() > 0:
                sequence_scores.data.masked_fill_(eos_indices, -float('inf'))

            if self.ignore_unk:
                # Erase scores for UNK symbol so that they aren't expanded
                unk_indices = input_var.data.eq(self.UNK)
                # Same fix as the EOS check above.
                if unk_indices.nonzero(as_tuple=False).numel() > 0:
                    sequence_scores.data.masked_fill_(unk_indices,
                                                      -float('inf'))

            # Cache results for backtracking
            stored_predecessors.append(predecessors)
            stored_emitted_symbols.append(input_var)

        predicts, scores, lengths = self._backtrack(stored_predecessors,
                                                    stored_emitted_symbols,
                                                    stored_scores, b)

        # Keep only the best (first) hypothesis of each beam and pad every
        # position past the hypothesis length.
        predicts = predicts[:, :1]
        scores = scores[:, :1]
        lengths = long_tensor_type(lengths)[:, :1]
        mask = sequence_mask(lengths, max_len=self.max_length).eq(0)
        predicts[mask] = self.PAD

        return predicts, lengths, scores
Example #2
0
    def encode(self, enc_inputs, hidden=None):
        """
        Encode one dialogue turn and fold it into the persistent dialog
        memories, then build the decoder's initial state from the KB,
        situation, and user-profile memories.

        Args:
            enc_inputs: tuple (inputs, lengths) of token ids and lengths.
            hidden: optional initial encoder hidden state.

        Returns:
            (outputs, dec_init_state): an (empty) Pack and the initialized
            decoder state.
        """
        outputs = Pack()
        enc_outputs, enc_hidden = self.encoder(enc_inputs, hidden)
        inputs, lengths = enc_inputs
        batch_size = enc_outputs.size(0)
        max_len = enc_outputs.size(1)
        # eq(0) inverts the validity mask into a padding mask.
        attn_mask = sequence_mask(lengths, max_len).eq(0)

        if self.with_bridge:
            enc_hidden = self.bridge(enc_hidden)

        # insert dialog memory: the first turn initializes the memories,
        # later turns append this turn's encodings along the time axis.
        if self.dialog_state_memory is None:
            assert self.dialog_history_memory is None
            assert self.history_index is None
            assert self.memory_masks is None
            self.dialog_state_memory = enc_outputs
            self.dialog_history_memory = enc_outputs
            self.history_index = inputs
            self.memory_masks = attn_mask
        else:
            # Slice stored memories to the current batch size (presumably
            # later turns of a dialogue batch can have fewer dialogues —
            # TODO confirm against the batching code) before concatenating.
            batch_state_memory = self.dialog_state_memory[:batch_size, :, :]
            self.dialog_state_memory = torch.cat(
                [batch_state_memory, enc_outputs], dim=1)
            batch_history_memory = self.dialog_history_memory[:
                                                              batch_size, :, :]
            self.dialog_history_memory = torch.cat(
                [batch_history_memory, enc_outputs], dim=1)
            batch_history_index = self.history_index[:batch_size, :]
            self.history_index = torch.cat([batch_history_index, inputs],
                                           dim=-1)
            batch_memory_masks = self.memory_masks[:batch_size, :]
            self.memory_masks = torch.cat([batch_memory_masks, attn_mask],
                                          dim=-1)

        # Batch-sized views of the preloaded KB / situation / profile data.
        batch_kb_inputs = self.kbs[:batch_size, :, :]
        batch_kb_state_memory = self.kb_state_memory[:batch_size, :, :]
        batch_kb_slot_memory = self.kb_slot_memory[:batch_size, :, :]
        batch_kb_slot_index = self.kb_slot_index[:batch_size, :]
        kb_mask = self.kb_mask[:batch_size, :]
        selector_mask = self.selector_mask[:batch_size, :]
        # NOTE(review): situation is sliced on dim 1, unlike the others —
        # presumably stored layer-first like a hidden state; confirm.
        batch_situation = self.situation[:, :batch_size, :]
        batch_user_profile = self.user_profile[:batch_size, :, :]
        batch_user_profile_mask = self.user_profile_mask[:batch_size, :]

        # Successively enrich the encoder hidden state with situation,
        # user-profile, and KB readouts via dedicated bridge layers.
        enc_hidden = self.situation_bridge(
            torch.cat([enc_hidden, batch_situation], dim=-1))

        up_memory, up_readout = self.decoder.initialize_user_profile(
            batch_user_profile, enc_hidden, batch_user_profile_mask)
        enc_hidden = self.user_profile_bridge(
            torch.cat([enc_hidden, up_readout.unsqueeze(0)], dim=-1))

        kb_memory, selector, kb_readout = self.decoder.initialize_kb_v3(
            batch_kb_inputs, enc_hidden, kb_mask)
        enc_hidden = self.kb_readout_bridge(
            torch.cat([enc_hidden, kb_readout.unsqueeze(0)], dim=-1))

        dec_init_state = self.decoder.initialize_state(
            hidden=enc_hidden,
            state_memory=self.dialog_state_memory,
            history_memory=self.dialog_history_memory,
            kb_memory=kb_memory,
            kb_state_memory=batch_kb_state_memory,
            kb_slot_memory=batch_kb_slot_memory,
            history_index=self.history_index,
            kb_slot_index=batch_kb_slot_index,
            attn_mask=self.memory_masks,
            attn_kb_mask=kb_mask,
            selector=selector,
            selector_mask=selector_mask,
            up_readout=up_readout)

        return outputs, dec_init_state
Example #3
0
    def iterate(self,
                turn_inputs,
                kb_inputs,
                situation_inputs=None,
                user_profile_inputs=None,
                optimizer=None,
                grad_clip=None,
                use_rl=False,
                is_training=True):
        """
        Run one pass over a multi-turn dialogue batch, collecting metrics
        per turn and (when training) back-propagating the summed loss.

        Returns:
            list of per-turn metric Packs.
        """
        # Start each dialogue batch from fresh memories.
        self.reset_memory()
        self.load_kb_memory(kb_inputs)
        self.load_situation_memory(situation_inputs)
        self.load_user_profile_memory(kb_inputs, user_profile_inputs)

        metrics_list = []
        total_loss = 0

        for turn in turn_inputs:
            if self.use_gpu:
                turn = turn.cuda()
            src, src_lengths = turn.src
            tgt, tgt_lengths = turn.tgt
            task_label = turn.task
            gold_entity = turn.gold_entity
            ptr_index, ptr_lengths = None, None
            kb_index, kb_index_lengths = turn.kb_index
            enc_inputs = src[:, 1:-1], src_lengths - 2  # filter <bos> <eos>
            dec_inputs = tgt[:, :-1], tgt_lengths - 1  # filter <eos>
            target = tgt[:, 1:]  # filter <bos>
            target_mask = sequence_mask(tgt_lengths - 1)

            if use_rl:
                # REINFORCE-style path: sampled rollout with gradients,
                # greedy baseline and supervised forward without.
                sample_outputs = self.sample(enc_inputs,
                                             dec_inputs,
                                             random_sample=True)
                with torch.no_grad():
                    greedy_outputs = self.sample(enc_inputs,
                                                 dec_inputs,
                                                 random_sample=False)
                    outputs = self.forward(enc_inputs, dec_inputs)
                metrics = self.collect_rl_metrics(sample_outputs,
                                                  greedy_outputs, target,
                                                  gold_entity, ptr_index,
                                                  kb_index, target_mask,
                                                  task_label)
            else:
                outputs = self.forward(enc_inputs, dec_inputs)
                metrics = self.collect_metrics(outputs, target, ptr_index,
                                               kb_index)

            metrics_list.append(metrics)
            total_loss = total_loss + metrics.loss

            # Carry the updated memories over to the next turn.
            self.update_memory(dialog_state_memory=outputs.dialog_state_memory,
                               kb_state_memory=outputs.kb_state_memory)

        if torch.isnan(total_loss):
            raise ValueError("NAN loss encountered!")

        if is_training:
            assert optimizer is not None
            optimizer.zero_grad()
            total_loss.backward()
            if grad_clip is not None and grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(parameters=self.parameters(),
                                               max_norm=grad_clip)
            optimizer.step()

        return metrics_list
Example #4
0
    def iterate(self,
                turn_inputs,
                kb_inputs,
                optimizer=None,
                grad_clip=None,
                is_training=True,
                method="GAN",
                mask=False):
        """
        Run one adversarial training pass over a dialogue batch for the
        whole multi-agent model (student S, teachers TB/TE, discriminators
        DB/DE) instead of a single sub-model.

        Args:
            turn_inputs: list of per-turn input packs for the dialogue batch.
            kb_inputs: knowledge-base inputs shared by all turns.
            optimizer: tuple (optimizerG, optimizerDB, optimizerDE);
                required when is_training is True.
            grad_clip: optional max gradient norm for clipping.
            is_training: whether to back-propagate and step the optimizers.
            method: training scheme; '1-3' / 'GAN' skip the per-batch
                metric comparison between generators.
            mask: if True, use the entity-aware kd mask for the
                discriminators instead of the plain target mask.

        Returns:
            (metrics_list_S, metrics_list_G, metrics_list_DB,
             metrics_list_DE): per-turn metric lists.
        """

        if isinstance(optimizer, tuple):
            optimizerG, optimizerDB, optimizerDE = optimizer

        # clear all memory before the begin of a new batch computation
        for name, model in self.named_children():
            if name.startswith("model_"):
                model.reset_memory()
                model.load_kb_memory(kb_inputs)

        # store the whole model (muti_agent)'s metric
        metrics_list_S, metrics_list_TB, metrics_list_TE = [], [], []
        metrics_list_G, metrics_list_DB, metrics_list_DE = [], [], []
        mask_list_S, length_list = [], []
        # store the whole model (muti_agent)'s loss
        total_loss_DB, total_loss_DE, total_loss_G = 0, 0, 0
        # use to compute final loss (sum of each agent's loss) per turn for the cumulated total_loss in a batch
        loss = Pack()
        # use to store kb_mask for three single model
        kd_masks = Pack()

        # compare evaluation metric (bleu/f1score) among models
        if method in ('1-3', 'GAN'):
            # TODO complete
            bleu_ENS_gt_S, bleu_ENS_gt_TB, f1score_ENS_gt_TE = True, True, True
        else:
            # compute bleu_S_gt_TB per batch (compute metric for the following training batch)
            # (key: batch/following/training)
            res_bleu = self.compare_metric(generator_1=self.generator_S,
                                           generator_2=self.generator_TB,
                                           turn_inputs=turn_inputs,
                                           kb_inputs=kb_inputs,
                                           type='bleu',
                                           data_name=self.data_name)
            # compare_metric may return (flag, detail_str) or a bare bool.
            if isinstance(res_bleu, tuple):
                bleu_S_gt_TB, bleu_S_gt_TB_str = res_bleu
            else:
                assert isinstance(res_bleu, bool)
                bleu_S_gt_TB, bleu_S_gt_TB_str = res_bleu, ''
            if self.model_TE is not None:
                res_f1score = self.compare_metric(
                    generator_1=self.generator_S,
                    generator_2=self.generator_TE,
                    turn_inputs=turn_inputs,
                    kb_inputs=kb_inputs,
                    type='f1score',
                    data_name=self.data_name)
                if isinstance(res_f1score, tuple):
                    f1score_S_gt_TE, f1score_S_gt_TE_str = res_f1score
                else:
                    assert isinstance(res_f1score, bool)
                    f1score_S_gt_TE, f1score_S_gt_TE_str = res_f1score, ''
        """ update discriminator """

        # clear all memory again because of cumulation of the memory in the computation of the above generator
        for name, model in self.named_children():
            if name.startswith("model_"):
                model.reset_memory()
                model.load_kb_memory(kb_inputs)

        # begin iterate (a dialogue batch)
        for i, inputs in enumerate(turn_inputs):

            # Run every sub-model (model_S / model_TB / model_TE) on this
            # turn, collecting its metrics and per-model kd mask.
            for name, model in self.named_children():
                if name.startswith("model_"):
                    if model.use_gpu:
                        inputs = inputs.cuda()
                    src, src_lengths = inputs.src
                    tgt, tgt_lengths = inputs.tgt
                    task_label = inputs.task
                    gold_entity = inputs.gold_entity
                    ptr_index, ptr_lengths = inputs.ptr_index
                    kb_index, kb_index_lengths = inputs.kb_index
                    enc_inputs = src[:, 1:
                                     -1], src_lengths - 2  # filter <bos> <eos>
                    dec_inputs = tgt[:, :-1], tgt_lengths - 1  # filter <eos>
                    target = tgt[:, 1:]  # filter <bos>
                    target_mask = sequence_mask(tgt_lengths - 1)
                    kd_mask = sequence_kd_mask(tgt_lengths - 1, target, name,
                                               self.ent_idx, self.nen_idx)

                    outputs = model.forward(enc_inputs, dec_inputs)
                    metrics = model.collect_metrics(outputs, target, ptr_index,
                                                    kb_index)

                    if name == "model_S":
                        metrics_list_S.append(metrics)
                    elif name == "model_TB":
                        metrics_list_TB.append(metrics)
                    else:
                        metrics_list_TE.append(metrics)

                    kd_masks[name] = kd_mask if mask else target_mask
                    loss[name] = metrics

                    model.update_memory(
                        dialog_state_memory=outputs.dialog_state_memory,
                        kb_state_memory=outputs.kb_state_memory)

            # store necessary data for three single model
            # NOTE(review): kd_mask_e (and loss.model_TE below) is only
            # bound when self.model_TE is not None — this path appears to
            # assume a TE model is always present; confirm before calling
            # without one.
            if self.model_TE is not None:
                kd_mask_e = kd_masks.model_TE
            kd_mask_s = kd_masks.model_S
            kd_mask_b = kd_masks.model_TB
            mask_list_S.append(kd_mask_s)
            length_list.append(tgt_lengths - 1)

            assert False not in (kd_mask_b == kd_mask_e)

            # Train each discriminator to separate its teacher's output
            # distribution (real) from the student's (fake).
            errD_B = self.discriminator_update(netD=self.discriminator_B,
                                               real_data=loss.model_TB.prob,
                                               fake_data=loss.model_S.prob,
                                               lengths=tgt_lengths - 1,
                                               mask=kd_mask_b)
            errD_E = self.discriminator_update(netD=self.discriminator_E,
                                               real_data=loss.model_TE.prob,
                                               fake_data=loss.model_S.prob,
                                               lengths=tgt_lengths - 1,
                                               mask=kd_mask_e)
            # collect discriminator's total loss
            metrics_DB = Pack(num_samples=metrics.num_samples)
            metrics_DE = Pack(num_samples=metrics.num_samples)
            metrics_DB.add(loss=errD_B, logits=0.0, prob=0.0)
            metrics_DE.add(loss=errD_E, logits=0.0, prob=0.0)
            metrics_list_DB.append(metrics_DB)
            metrics_list_DE.append(metrics_DE)

            # update in a batch
            total_loss_DB = total_loss_DB + errD_B
            total_loss_DE = total_loss_DE + errD_E
            loss.clear()
            kd_masks.clear()

        # check loss
        if torch.isnan(total_loss_DB) or torch.isnan(total_loss_DE):
            raise ValueError("NAN loss encountered!")

        # compute and update gradient
        if is_training:
            assert not None in (optimizerDB, optimizerDE)
            optimizerDB.zero_grad()
            optimizerDE.zero_grad()
            total_loss_DB.backward()
            total_loss_DE.backward()
            if grad_clip is not None and grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(
                    parameters=self.discriminator_B.parameters(),
                    max_norm=grad_clip)
                torch.nn.utils.clip_grad_norm_(
                    parameters=self.discriminator_E.parameters(),
                    max_norm=grad_clip)
            optimizerDB.step()
            optimizerDE.step()
        """ update generator """

        # begin iterate (a dialogue batch)
        n_turn = len(metrics_list_S)
        assert n_turn == len(turn_inputs) == len(mask_list_S)
        for i in range(n_turn):
            # Generator loss: fool both discriminators + NLL regularizer.
            errG, errG_B, errG_E, nll = self.generator_update(
                netG=self.model_S,
                netDB=self.discriminator_B,
                netDE=self.discriminator_E,
                fake_data=metrics_list_S[i].prob,
                length=length_list[i],
                mask=mask_list_S[i],
                nll=metrics_list_S[i].loss,
                lambda_g=self.lambda_g)

            # collect generator's total loss
            metrics_G = Pack(num_samples=metrics_list_S[i].num_samples)
            metrics_G.add(loss=errG,
                          loss_gb=errG_B,
                          loss_ge=errG_E,
                          loss_nll=nll,
                          logits=0.0,
                          prob=0.0)
            metrics_list_G.append(metrics_G)

            # update in a batch
            total_loss_G += errG

        # check loss
        if torch.isnan(total_loss_G):
            raise ValueError("NAN loss encountered!")

        # compute and update gradient
        if is_training:
            assert optimizerG is not None
            optimizerG.zero_grad()
            total_loss_G.backward()
            if grad_clip is not None and grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(
                    parameters=self.model_S.parameters(), max_norm=grad_clip)
            optimizerG.step()

        return metrics_list_S, metrics_list_G, metrics_list_DB, metrics_list_DE
Example #5
0
    def encode(self, enc_inputs, hidden=None):
        """
        Encode one dialogue turn and fold it into the persistent dialog
        memories, then build the decoder's initial state using the v2
        KB selector (no kb_memory is materialized here).

        Args:
            enc_inputs: tuple (inputs, lengths) of token ids and lengths.
            hidden: optional initial encoder hidden state.

        Returns:
            (outputs, dec_init_state): an (empty) Pack and the initialized
            decoder state.
        """
        outputs = Pack()
        enc_outputs, enc_hidden = self.encoder(enc_inputs, hidden)
        inputs, lengths = enc_inputs
        batch_size = enc_outputs.size(0)
        max_len = enc_outputs.size(1)
        # eq(0) inverts the validity mask into a padding mask.
        attn_mask = sequence_mask(lengths, max_len).eq(0)

        if self.with_bridge:
            enc_hidden = self.bridge(enc_hidden)

        # insert dialog memory: the first turn initializes the memories,
        # later turns append this turn's encodings along the time axis.
        if self.dialog_state_memory is None:
            assert self.dialog_history_memory is None
            assert self.history_index is None
            assert self.memory_masks is None
            self.dialog_state_memory = enc_outputs
            self.dialog_history_memory = enc_outputs
            self.history_index = inputs
            self.memory_masks = attn_mask
        else:
            # Slice stored memories to the current batch size (presumably
            # later turns can have fewer dialogues — TODO confirm against
            # the batching code) before concatenating.
            batch_state_memory = self.dialog_state_memory[:batch_size, :, :]
            self.dialog_state_memory = torch.cat(
                [batch_state_memory, enc_outputs], dim=1)
            batch_history_memory = self.dialog_history_memory[:
                                                              batch_size, :, :]
            self.dialog_history_memory = torch.cat(
                [batch_history_memory, enc_outputs], dim=1)
            batch_history_index = self.history_index[:batch_size, :]
            self.history_index = torch.cat([batch_history_index, inputs],
                                           dim=-1)
            batch_memory_masks = self.memory_masks[:batch_size, :]
            self.memory_masks = torch.cat([batch_memory_masks, attn_mask],
                                          dim=-1)

        # Batch-sized views of the preloaded KB data.
        batch_kb_inputs = self.kbs[:batch_size, :, :]
        batch_kb_state_memory = self.kb_state_memory[:batch_size, :, :]
        batch_kb_slot_memory = self.kb_slot_memory[:batch_size, :, :]
        batch_kb_slot_index = self.kb_slot_index[:batch_size, :]
        kb_mask = self.kb_mask[:batch_size, :]
        selector_mask = self.selector_mask[:batch_size, :]

        selector = self.decoder.initialize_kb_v2(
            enc_hidden=enc_hidden,
            kb_state_memory=batch_kb_state_memory,
            attn_kb_mask=kb_mask)
        # kb_memory, selector = self.decoder.initialize_kb_v3(kb_inputs=batch_kb_inputs, enc_hidden=enc_hidden)
        kb_memory = None
        dec_init_state = self.decoder.initialize_state(
            hidden=enc_hidden,
            state_memory=self.dialog_state_memory,
            history_memory=self.dialog_history_memory,
            kb_memory=kb_memory,
            kb_state_memory=batch_kb_state_memory,
            kb_slot_memory=batch_kb_slot_memory,
            history_index=self.history_index,
            kb_slot_index=batch_kb_slot_index,
            attn_mask=self.memory_masks,
            attn_kb_mask=kb_mask,
            selector=selector,
            selector_mask=selector_mask)

        return outputs, dec_init_state
Example #6
0
    def forward(self, inputs, state):
        """
        Teacher-forced decoding over a full target sequence.

        Returns:
            (log_probs, state, output): mixed vocab/fact/history log
            probabilities of shape (batch, max_len, vocab), the updated
            state, and the raw output of the final decode step.
        """
        inputs, lengths = inputs
        batch_size, max_len = inputs.size()

        fact_len = state.fact.size(1)
        hist_len = state.hist.size(1)

        # Step-by-step accumulation buffers (float, same device as inputs).
        out_inputs = inputs.new_zeros(
            size=(batch_size, max_len, self.out_input_size),
            dtype=torch.float)
        out_facts = inputs.new_zeros(
            size=(batch_size, max_len, fact_len), dtype=torch.float)
        out_hists = inputs.new_zeros(
            size=(batch_size, max_len, hist_len), dtype=torch.float)

        # Sort longest-first so each time step operates on a batch prefix.
        sorted_lengths, indices = lengths.sort(descending=True)
        inputs = inputs.index_select(0, indices)
        state = state.index_select(indices)

        # Number of non-padding entries at every time step.
        num_valid_list = sequence_mask(sorted_lengths).int().sum(dim=0)

        for step, num_valid in enumerate(num_valid_list):
            dec_input = inputs[:num_valid, step]
            valid_state = state.slice_select(num_valid)
            out_input, valid_state, output = self.decode(dec_input,
                                                         valid_state,
                                                         is_training=True)
            state.hidden[:, :num_valid] = valid_state.hidden
            out_inputs[:num_valid, step] = out_input.squeeze(1)
            out_facts[:num_valid, step] = output.attn_f.squeeze(1)
            out_hists[:num_valid, step] = output.attn_h.squeeze(1)

        # Undo the length sort.
        _, inv_indices = indices.sort()
        state = state.index_select(inv_indices)
        out_inputs = out_inputs.index_select(0, inv_indices)
        out_facts = out_facts.index_select(0, inv_indices)
        out_hists = out_hists.index_select(0, inv_indices)

        # Mixture weights over the (vocab, fact, history) modes.
        p_modes = self.ff(out_inputs)

        # (batch_size, max_len, vocab_size)
        prob_vocab = self.output_layer(out_inputs)

        # Weight each mode, then scatter the copy attentions onto the
        # vocabulary distribution.
        weighted_prob = prob_vocab * p_modes[:, :, 0].unsqueeze(2)
        weighted_f = out_facts * p_modes[:, :, 1].unsqueeze(2)
        weighted_h = out_hists * p_modes[:, :, 2].unsqueeze(2)
        weighted_prob = convert_dist(weighted_h, state.hist, weighted_prob)
        weighted_prob = convert_dist(weighted_f, state.fact, weighted_prob)

        # NOTE: `output` is the result of the last loop iteration above.
        log_probs = torch.log(weighted_prob + 1e-10)
        return log_probs, state, output
Example #7
0
    def forward(self, dec_inputs, state):
        """
        Teacher-forced decoding over a full target sequence with history
        and KB copy attentions.

        Args:
            dec_inputs: tuple (inputs, lengths) of decoder token ids and
                lengths.
            state: decoder state holding hidden, state_memory,
                kb_state_memory, history_memory and kb_slot_memory.

        Returns:
            (probs, attn_probs, kb_probs, p_gen, p_con, state): vocabulary
            distribution, history attention, KB attention, generate gate,
            copy gate, and the updated state.
        """
        inputs, lengths = dec_inputs
        batch_size, max_len = inputs.size()

        # Step-by-step accumulation buffers (float, same device as inputs).
        out_inputs = inputs.new_zeros(size=(batch_size, max_len,
                                            self.out_input_size),
                                      dtype=torch.float)

        kb_inputs = inputs.new_zeros(size=(batch_size, max_len,
                                           self.out_input_size),
                                     dtype=torch.float)

        out_attn_size = state.history_memory.size(1)
        out_attn_probs = inputs.new_zeros(size=(batch_size, max_len,
                                                out_attn_size),
                                          dtype=torch.float)

        out_kb_size = state.kb_slot_memory.size(1)
        out_kb_probs = inputs.new_zeros(size=(batch_size, max_len,
                                              out_kb_size),
                                        dtype=torch.float)

        # sort by lengths (longest first) so each time step operates on a
        # batch prefix of valid entries
        sorted_lengths, indices = lengths.sort(descending=True)
        inputs = inputs.index_select(0, indices)
        state = state.index_select(indices)

        # number of valid inputs (i.e. not padding index) in each time step
        num_valid_list = sequence_mask(sorted_lengths).int().sum(dim=0)

        for i, num_valid in enumerate(num_valid_list):
            dec_input = inputs[:num_valid, i]
            valid_state = state.slice_select(num_valid)

            # decode for one step
            out_input, kb_input, attn, kb_attn, valid_state = self.decode(
                dec_input, valid_state, is_training=True)

            # Write the updated slices back into the full-batch state.
            state.hidden[:, :num_valid] = valid_state.hidden
            state.state_memory[:num_valid, :, :] = valid_state.state_memory
            state.kb_state_memory[:
                                  num_valid, :, :] = valid_state.kb_state_memory

            out_inputs[:num_valid, i] = out_input.squeeze(1)
            kb_inputs[:num_valid, i] = kb_input.squeeze(1)
            out_attn_probs[:num_valid, i] = attn.squeeze(1)
            out_kb_probs[:num_valid, i] = kb_attn.squeeze(1)

        # Resort: undo the length sort so rows match the caller's order.
        _, inv_indices = indices.sort()
        state = state.index_select(inv_indices)
        out_inputs = out_inputs.index_select(0, inv_indices)
        kb_inputs = kb_inputs.index_select(0, inv_indices)
        attn_probs = out_attn_probs.index_select(0, inv_indices)
        kb_probs = out_kb_probs.index_select(0, inv_indices)

        # Output distribution plus generate/copy gates.
        probs = self.output_layer(out_inputs)
        p_gen = self.gate_layer(out_inputs)
        p_con = self.copy_gate_layer(kb_inputs)

        return probs, attn_probs, kb_probs, p_gen, p_con, state
    def encode(self, inputs, hidden=None, is_training=False):
        """
        encode
        """
        '''
	    #inputs: 嵌套形式为{分离src和target和cue->(分离数据和长度->tensor数据值    
	    #{'src':( 数据值-->shape(batch_size , sen_num , max_len), 句子长度值--> shape(batch_size,sen_num) ),
          'tgt':( 数据值-->shape(batch_size , max_len), 句子长度值-->shape(batch_size) ),
          'cue' :( 数据值-->shape(batch_size, max_len), 句子长度值--> shape(batch_size) ),
          'label':( 数据值-->shape(batch_size , max_len), 句子长度值-->shape(batch_size) ),
          'index': ( 数据值-->shape(batch_size , max_len), 句子长度值-->shape(batch_size) )
          }
	    '''
        outputs = Pack()
        ''' 第二阶段'''
        if self.task_id == 1:

            enc_inputs = inputs.src[0][:, 1:-1], inputs.src[1] - 2
            lengths = inputs.src[1] - 2  # (batch_size)
            enc_outputs, enc_hidden, enc_embedding = self.encoder(
                enc_inputs, hidden)
            # enc_outputs:(batch_size, max_len-2, 2*rnn_hidden_size)
            # enc_hidden:(num_layer , batch_size , 2*rnn_hidden_size)

            if self.with_bridge:
                enc_hidden = self.bridge(enc_hidden)

            # tem_bth,tem_len,tem_hi_size =enc_outputs.size()# batch_size, max_len-2, 2*rnn_hidden_size)
            # --- Knowledge-key / persona-selection branch ---
            # NOTE(review): this span is the interior of a branch whose
            # condition starts above this chunk; code is left byte-identical,
            # only comments were added/translated.
            key_index, len_key_index = inputs.index[0], inputs.index[
                1]  # key_index(batch_size , idx_max_len)
            max_len = key_index.size(1)
            # sequence_mask marks valid positions; .eq(0) flips it so that
            # True = padding, ready for masked_fill below.
            key_mask = sequence_mask(len_key_index, max_len).eq(
                0)  # key_mask(batch_size , idx_max_len)
            # Pick the encoder embeddings at the key-word positions.
            key_hidden = torch.gather(
                enc_embedding, 1,
                key_index.unsqueeze(-1).repeat(1, 1, enc_embedding.size(
                    -1)))  # (batch_size ,idx_max_len, 2*rnn_hidden_size)
            # Masked mean over valid key positions -> one global key vector.
            key_global = key_hidden.masked_fill(
                key_mask.unsqueeze(-1),
                0.0).sum(1) / len_key_index.unsqueeze(1).float()
            key_global = self.key_linear(
                key_global)  # (batch_size, 2*rnn_hidden_size)
            # persona_aware = torch.cat([key_global, enc_hidden[-1]], dim=-1)  # (batch_size ,2*rnn_hidden_size)
            # Additive fusion of the key summary with the last-layer encoder
            # hidden state (the concat variant above is kept for reference).
            persona_aware = key_global + enc_hidden[
                -1]  #(batch_size , 2*rnn_hidden_size)

            # persona
            batch_size, sent_num, sent = inputs.cue[0].size()
            cue_len = inputs.cue[1]  # (batch_size,sen_num)
            # Subtract BOS/EOS from each non-empty cue length.
            # NOTE(review): this mutates inputs.cue[1] in place (cue_len is a
            # view of it) — confirm callers do not reuse inputs afterwards.
            cue_len[cue_len > 0] -= 2  # (batch_size, sen_num)
            # Flatten persona sentences to (batch*sent_num, ...) and strip the
            # BOS/EOS tokens so lengths match the adjusted cue_len.
            cue_inputs = inputs.cue[0].view(-1, sent)[:,
                                                      1:-1], cue_len.view(-1)
            # cue_inputs:((batch_size*sent_num , max_len-2),(batch_size*sent_num))
            cue_enc_outputs, cue_enc_hidden, _ = self.persona_encoder(
                cue_inputs, hidden)
            # cue_enc_outputs:(batch_size*sent_num , max_len-2, 2*rnn_hidden_size)
            # cue_enc_hidden:(num_layers , batch_size*sent_num, 2 * rnn_hidden_size)
            # Last-layer hidden state of each persona sentence, per batch item.
            cue_outputs = cue_enc_hidden[-1].view(batch_size, sent_num, -1)
            cue_enc_outputs = cue_enc_outputs.view(
                batch_size, sent_num, cue_enc_outputs.size(1), -1
            )  # cue_enc_outputs:(batch_size, sent_num , max_len-2, 2*rnn_hidden_size)
            cue_len = cue_len.view(batch_size, sent_num)

            # cue_outputs:(batch_size, sent_num, 2 * rnn_hidden_size)
            # Attention
            # Three-hop attention over the persona sentences: each hop re-queries
            # with (previous weighted result + persona_aware). Empty sentences
            # (length 0) are masked out via inputs.cue[1].eq(0).
            weighted_cue1, cue_attn1 = self.persona_attention(
                query=persona_aware.unsqueeze(1),
                memory=cue_outputs,
                mask=inputs.cue[1].eq(0))
            # weighted_cue:(batch_size , 1 , 2 * rnn_hidden_size)
            persona_memory1 = weighted_cue1 + persona_aware.unsqueeze(1)
            weighted_cue2, cue_attn2 = self.persona_attention(
                query=persona_memory1,
                memory=cue_outputs,
                mask=inputs.cue[1].eq(0))
            persona_memory2 = weighted_cue2 + persona_aware.unsqueeze(1)
            weighted_cue3, cue_attn3 = self.persona_attention(
                query=persona_memory2,
                memory=cue_outputs,
                mask=inputs.cue[1].eq(0))

            # Final hop's attention weights decide which persona sentence wins.
            cue_attn = cue_attn3.squeeze(1)
            # cue_attn:(batch_size, sent_num)
            outputs.add(cue_attn=cue_attn)
            indexs = cue_attn.max(dim=1)[1]  # (batch_size)
            # NOTE(review): the training and inference branches below are
            # currently identical (hard argmax selection); the Gumbel-softmax
            # sampling path is commented out.
            if is_training:
                # gumbel_attn = F.gumbel_softmax(torch.log(cue_attn + 1e-10), 0.1, hard=True)
                # persona = torch.bmm(gumbel_attn.unsqueeze(1), cue_outputs)
                # indexs = gumbel_attn.max(-1)[1]
                # cue_lengths = cue_len.gather(1, indexs.unsqueeze(1)).squeeze(1)  # (batch_size)
                # Gather the token-level encodings of the selected sentence.
                persona = cue_enc_outputs.gather(
                    1,
                    indexs.view(-1, 1, 1, 1).repeat(
                        1, 1, cue_enc_outputs.size(2),
                        cue_enc_outputs.size(3))).squeeze(
                            1)  # (batch_size , max_len-2, 2*rnn_hidden_size)
                cue_lengths = cue_len.gather(1, indexs.unsqueeze(1)).squeeze(
                    1)  # (batch_size)
            else:
                persona = cue_enc_outputs.gather(
                    1,
                    indexs.view(-1, 1, 1, 1).repeat(
                        1, 1, cue_enc_outputs.size(2),
                        cue_enc_outputs.size(3))).squeeze(
                            1)  # (batch_size , max_len-2, 2*rnn_hidden_size)
                cue_lengths = cue_len.gather(1, indexs.unsqueeze(1)).squeeze(
                    1)  # (batch_size)

            outputs.add(indexs=indexs)
            outputs.add(attn_index=inputs.label)  # (batch_size)

            # Decoder starts from the dialogue encoder state, attending over
            # the dialogue encodings and the selected persona sentence.
            dec_init_state = self.decoder.initialize_state(
                hidden=enc_hidden,
                attn_memory=enc_outputs if self.attn_mode else None,
                memory_lengths=lengths
                if self.attn_mode else None,  # (batch_size)
                cue_enc_outputs=
                persona,  # (batch_size, max_len-2, 2*rnn_hidden_size)
                cue_lengths=cue_lengths,  # (batch_size)
                task_id=self.task_id)

            # if 'index' in inputs.keys():
            #     outputs.add(attn_index=inputs.index)

        elif self.task_id == 0:
            # "First stage": encode multi-sentence source context and a single
            # persona sentence; no selection, the decoder attends over both.
            # (The bare string below is a runtime literal; left unchanged.)
            ''' 第一阶段'''
            # enc_inputs:((batch_size,max_len-2), (batch_size-2)) — src with BOS/EOS stripped
            # hidden:None
            batch_size, sent_num, sent_len = inputs.src[0].size()
            src_lengths = inputs.src[1]  # (batch_size,sent_num)
            # Subtract BOS/EOS from each non-empty source length.
            # NOTE(review): mutates inputs.src[1] in place — confirm callers
            # do not reuse inputs afterwards.
            src_lengths[src_lengths > 0] -= 2
            # src_lengths(batch_size, sent_num)
            # Flatten context sentences and strip BOS/EOS tokens.
            src_inputs = inputs.src[0].view(
                -1, sent_len)[:, 1:-1], src_lengths.view(-1)
            # src_inputs:((batch_size*sent_num , max_len-2),(batch_size*sent_num))
            src_enc_outputs, enc_hidden, _ = self.encoder(src_inputs, hidden)

            if self.with_bridge:
                enc_hidden = self.bridge(enc_hidden)

            # src_enc_outputs:(batch_size*sent_num , max_len-2, 2*rnn_hidden_size)
            # enc_hidden:(num_layers , batch_size*sent_num, 2 * rnn_hidden_size)
            # Mean-pool the per-sentence hidden states into one context state.
            src_outputs = torch.mean(
                enc_hidden.view(self.num_layers, batch_size, sent_num, -1),
                2)  # mean-pool over sentences
            # src_outputs:(num_layers, batch_size, 2 * rnn_hidden_size)

            # persona:((batch_size,max_len-2), (batch_size)) — persona tensor with BOS/EOS stripped
            cue_inputs = inputs.cue[0][:, 1:-1], inputs.cue[1] - 2
            cue_lengths = inputs.cue[1] - 2  # (batch_size)
            cue_enc_outputs, cue_enc_hidden, _ = self.persona_encoder(
                cue_inputs, hidden)
            # cue_enc_outputs:(batch_size, max_len-2, 2*rnn_hidden_size)
            # cue_enc_hidden:(num_layer , batch_size , 2*rnn_hidden_size)

            # Decoder starts from the pooled context state, attending over the
            # per-sentence context encodings and the full persona encoding.
            dec_init_state = self.decoder.initialize_state(
                hidden=src_outputs,
                attn_memory=src_enc_outputs.view(
                    batch_size, sent_num, sent_len -
                    2, -1) if self.attn_mode else None,
                # (batch_size, sent_num , max_len-2, 2*rnn_hidden_size)
                memory_lengths=src_lengths
                if self.attn_mode else None,  # (batch_size,sent_num)
                cue_enc_outputs=
                cue_enc_outputs,  # (batch_size, max_len-2, 2*rnn_hidden_size)
                cue_lengths=cue_lengths,
                task_id=self.task_id  # (batch_size)
            )
        return outputs, dec_init_state