Example #1
    def __init__(self,
                 state_num,
                 hidden_dim,
                 know_vocab_size,
                 embed_dim,
                 embedder,
                 lg_interpreter: LocGloInterpreter,
                 gen_strategy,
                 know2word_tensor,
                 with_copy=True):
        super(PriorStateTracker, self).__init__()
        self.state_num = state_num
        self.hidden_dim = hidden_dim
        self.know_vocab_size = know_vocab_size
        self.embed_dim = embed_dim
        self.embedder = embedder
        self.lg_interpreter = lg_interpreter
        self.gen_strategy = gen_strategy
        self.know2word_tensor = know2word_tensor
        self.with_copy = with_copy

        self.embed_attn = Attention(embed_dim, hidden_dim)
        self.prior_basic_state_tracker = BasicStateTracker(
            self.know2word_tensor, self.state_num, self.hidden_dim,
            self.know_vocab_size, self.embed_dim, self.embedder,
            self.lg_interpreter, self.gen_strategy, self.with_copy)

        self.rnn_cell = nn.GRU(input_size=self.embed_dim,
                               hidden_size=self.embed_dim,
                               num_layers=1,
                               batch_first=True,
                               bidirectional=False)
Example #2
class IntentionDetector(nn.Module):
    def __init__(self, know_vocab_size, intention_cate=4, hidden_dim=512, embed_dim=300, graph_dim=128, is_prior=True):
        super(IntentionDetector, self).__init__()
        self.know_vocab_size = know_vocab_size
        self.intention_cate = intention_cate
        self.hidden_dim = hidden_dim
        self.embed_dim = embed_dim
        self.is_prior = is_prior
        self.graph_dim = graph_dim if self.is_prior else 0

        if not self.is_prior:
            self.res_attn = Attention(self.hidden_dim, self.hidden_dim)
        self.his_attn = Attention(self.hidden_dim, self.hidden_dim)
        self.state_attn = Attention(self.hidden_dim, self.embed_dim)

        intention_mlp_input_size = self.hidden_dim + self.hidden_dim + self.embed_dim

        if not self.is_prior:
            intention_mlp_input_size += self.hidden_dim

        if self.is_prior:  # prior policy network, use the graph embedding
            intention_mlp_input_size += self.graph_dim

        self.intention_mlp = nn.Sequential(nn.Linear(intention_mlp_input_size, self.hidden_dim),
                                           nn.ReLU(),
                                           nn.Linear(self.hidden_dim, self.intention_cate),
                                           nn.Softmax(dim=-1))

    def forward(self, state_emb, hidden, pv_r_u_enc, pv_r_u_len=None, r_enc=None, graph_context=None):
        intention_input = []

        # r
        if not self.is_prior:
            r_inp, _ = self.res_attn.forward(hidden.permute(1, 0, 2), r_enc)
            intention_input.append(r_inp)

        # pv_r_u
        h_inp, _ = self.his_attn.forward(hidden.permute(1, 0, 2),
                                         pv_r_u_enc,
                                         mask=reverse_sequence_mask(pv_r_u_len, max_len=pv_r_u_enc.size(1)))
        intention_input.append(h_inp)

        # state
        s_inp, _ = self.state_attn.forward(hidden.permute(1, 0, 2), state_emb)
        intention_input.append(s_inp)

        # question_hidden
        intention_input.append(hidden.permute(1, 0, 2))
        intention_input = torch.cat(intention_input, dim=-1).squeeze(1)  # B, E + H + H [+ H]

        if graph_context is not None and self.is_prior:
            intention_input = torch.cat([intention_input, graph_context.squeeze(1)], dim=-1)

        # intention project
        intention = self.intention_mlp.forward(intention_input)  # B, I
        return intention
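
A minimal, self-contained sketch of the pattern above (made-up sizes, plain torch, not repo code): the attended history, state, and question contexts, plus the optional graph context, are concatenated along the last dimension and passed through the intention MLP to produce a distribution over intention categories.

import torch
import torch.nn as nn

B, H, E, G, I = 2, 512, 300, 128, 4              # batch, hidden, embed, graph, intention categories
h_inp = torch.randn(B, H)                        # attended dialogue-history context
s_inp = torch.randn(B, E)                        # attended state context
q_hidden = torch.randn(B, H)                     # question hidden state
graph_ctx = torch.randn(B, G)                    # graph context (prior network only)

intention_mlp = nn.Sequential(nn.Linear(H + E + H + G, H),
                              nn.ReLU(),
                              nn.Linear(H, I),
                              nn.Softmax(dim=-1))
intention = intention_mlp(torch.cat([h_inp, s_inp, q_hidden, graph_ctx], dim=-1))  # B, I
print(intention.shape)                           # torch.Size([2, 4])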
Example #3
    def __init__(self,
                 action_num,
                 hidden_dim,
                 know_vocab_size,
                 embed_dim,
                 embedder,
                 lg_interpreter,
                 gen_strategy,
                 with_copy,
                 know2word_tensor,
                 know_encoder: RNNEncoder = None):
        super(PosteriorPolicyNetwork, self).__init__()
        self.action_num = action_num
        self.hidden_dim = hidden_dim
        self.know_vocab_size = know_vocab_size
        self.embed_dim = embed_dim
        self.embedder = embedder
        self.lg_interpreter = lg_interpreter
        self.gen_strategy = gen_strategy
        self.with_copy = with_copy
        self.know2word_tensor = know2word_tensor
        self.embed_attn = Attention(hidden_dim, embed_dim)
        self.hidden_attn = Attention(hidden_dim, hidden_dim)

        if know_encoder is not None:
            self.know_encoder = know_encoder
        else:
            self.know_encoder = RNNEncoder(self.embed_dim,
                                           self.hidden_dim,
                                           self.hidden_dim // 2,
                                           num_layers=1,
                                           bidirectional=True,
                                           dropout=0,
                                           embedder=self.embedder)

        self.embed2hidden_linear = nn.Linear(self.embed_dim, self.hidden_dim)
        self.intention_detector = IntentionDetector(know_vocab_size,
                                                    hidden_dim=self.hidden_dim,
                                                    embed_dim=self.embed_dim,
                                                    is_prior=False)
        self.basic_policy_network = BasicPolicyNetwork(
            self.action_num, self.hidden_dim, self.embed_dim,
            self.know_vocab_size, self.embedder, self.lg_interpreter,
            self.know2word_tensor, self.gen_strategy, self.with_copy)

        self.rnn_cell = nn.GRU(input_size=self.embed_dim,
                               hidden_size=self.embed_dim,
                               num_layers=1,
                               batch_first=True,
                               bidirectional=False)
Example #4
    def __init__(self, action_num, hidden_dim, know_vocab_size, embed_dim,
                 embedder, lg_interpreter, gen_strategy, with_copy,
                 graph_db: GraphDB, know2word_tensor,
                 intention_gumbel_softmax):
        super(PriorPolicyNetwork, self).__init__()
        self.action_num = action_num
        self.hidden_dim = hidden_dim
        self.know_vocab_size = know_vocab_size
        self.embed_dim = embed_dim
        self.embedder = embedder
        self.lg_interpreter = lg_interpreter
        self.gen_strategy = gen_strategy
        self.with_copy = with_copy
        self.graph_db = graph_db
        self.intention_gumbel_softmax = intention_gumbel_softmax
        self.know2word_tensor = know2word_tensor
        self.embed_attn = Attention(hidden_dim, embed_dim)
        self.hidden_attn = Attention(hidden_dim, hidden_dim)
        self.hidden_type_linear = nn.Linear(hidden_dim + 4, hidden_dim)

        self.embed2hidden_linear = nn.Linear(self.embed_dim, self.hidden_dim)
        self.intention_detector = IntentionDetector(
            know_vocab_size,
            hidden_dim=self.hidden_dim,
            embed_dim=self.embed_dim,
            graph_dim=VO.GATConfig.embed_dim,
            is_prior=True)

        self.basic_policy_network = BasicPolicyNetwork(
            self.action_num, self.hidden_dim, self.embed_dim,
            self.know_vocab_size, self.embedder, self.lg_interpreter,
            self.know2word_tensor, self.gen_strategy, self.with_copy)

        self.graph_attn = GraphAttn(VO.node_embed_dim, VO.hidden_dim)
        self.rnn_enc_cell = nn.GRU(input_size=self.embed_dim,
                                   hidden_size=self.embed_dim,
                                   num_layers=1,
                                   batch_first=True,
                                   bidirectional=False)

        gatc = VO.GATConfig
        self.r_gat = MyGAT(gatc.embed_dim, gatc.edge_embed_dim,
                           gatc.flag_embed_dim, gatc.node_num, gatc.edge_num,
                           gatc.flag_num)
Example #5
    def __init__(self,
                 know2word_tensor,
                 state_num,
                 hidden_dim,
                 know_vocab_size,
                 embed_dim,
                 embedder,
                 lg_interpreter: LocGloInterpreter,
                 gen_strategy="gru",
                 with_copy=True):
        super(BasicStateTracker, self).__init__()
        self.know2word_tensor = know2word_tensor
        self.state_num = state_num
        self.hidden_dim = hidden_dim
        self.know_vocab_size = know_vocab_size
        self.embed_dim = embed_dim
        self.embedder = embedder
        self.lg_interpreter = lg_interpreter
        self.gen_strategy = gen_strategy
        self.with_copy = with_copy
        self.gumbel_softmax = GumbelSoftmax(normed=True)
        assert self.gen_strategy in ("gru", "mlp") or self.gen_strategy is None

        self.embed_attn = Attention(self.hidden_dim, self.embed_dim)
        self.hidden_attn = Attention(self.hidden_dim, self.hidden_dim)

        if self.with_copy:
            self.embed_copy_attn = Attention(self.hidden_dim, self.embed_dim)
            self.hidden_copy_attn = Attention(self.hidden_dim, self.hidden_dim)

        if self.gen_strategy == "gru":
            self.gru = nn.GRU(input_size=self.embed_dim + self.hidden_dim +
                              self.embed_dim,
                              hidden_size=self.hidden_dim,
                              num_layers=1,
                              dropout=0.0,
                              batch_first=True)

        elif self.gen_strategy == "mlp":
            self.step_linear = nn.Sequential(
                nn.Linear(self.hidden_dim, self.state_num * self.hidden_dim),
                Reshape(1, [self.state_num, self.hidden_dim]), nn.Dropout(0.2),
                nn.Linear(self.hidden_dim, self.hidden_dim))
        else:
            raise NotImplementedError

        self.hidden_projection = nn.Linear(self.hidden_dim,
                                           self.know_vocab_size)
        self.word_softmax = nn.Softmax(-1)

        self.hidden2word_projection = nn.Sequential(self.hidden_projection,
                                                    self.word_softmax)
Example #6
    def __init__(self, know_vocab_size, intention_cate=4, hidden_dim=512, embed_dim=300, graph_dim=128, is_prior=True):
        super(IntentionDetector, self).__init__()
        self.know_vocab_size = know_vocab_size
        self.intention_cate = intention_cate
        self.hidden_dim = hidden_dim
        self.embed_dim = embed_dim
        self.is_prior = is_prior
        self.graph_dim = graph_dim if self.is_prior else 0

        if not self.is_prior:
            self.res_attn = Attention(self.hidden_dim, self.hidden_dim)
        self.his_attn = Attention(self.hidden_dim, self.hidden_dim)
        self.state_attn = Attention(self.hidden_dim, self.embed_dim)

        intention_mlp_input_size = self.hidden_dim + self.hidden_dim + self.embed_dim

        if not self.is_prior:
            intention_mlp_input_size += self.hidden_dim

        if self.is_prior:  # prior policy network, use the graph embedding
            intention_mlp_input_size += self.graph_dim

        self.intention_mlp = nn.Sequential(nn.Linear(intention_mlp_input_size, self.hidden_dim),
                                           nn.ReLU(),
                                           nn.Linear(self.hidden_dim, self.intention_cate),
                                           nn.Softmax(dim=-1))
Example #7
class GraphAttn(nn.Module):
    def __init__(self, node_embed_dim, hidden_dim):
        super(GraphAttn, self).__init__()
        self.node_embed_dim = node_embed_dim
        self.hidden_dim = hidden_dim
        # attention read context
        self.attn = Attention(self.hidden_dim, self.node_embed_dim)

    def forward(self, node_embedding, node_efficient, head_flag_bit_matrix, h_c_t):
        efficient_mask = (1 - node_efficient)  # B, N
        efficient_mask = efficient_mask | head_flag_bit_matrix
        node_context, _ = self.attn.forward(h_c_t,
                                            node_embedding,
                                            mask=efficient_mask.bool())  # B, 1, E
        return node_context
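
A tiny stand-alone sketch (dummy tensors, not repo code) of the efficient-mask construction used by GraphAttn and GraphCopy: nodes marked as inactive (node_efficient == 0) and head nodes are both masked out of the attention read.

import torch

node_efficient = torch.tensor([[1, 1, 0, 1]])    # B=1, N=4; 1 means the node is usable
head_flag_bit = torch.tensor([[1, 0, 0, 0]])     # 1 marks a head node, also excluded
efficient_mask = (1 - node_efficient) | head_flag_bit
print(efficient_mask.bool())                     # tensor([[ True, False,  True, False]])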
Example #8
class GraphCopy(nn.Module):
    def __init__(self, node_embed_dim, hidden_dim):
        super(GraphCopy, self).__init__()
        self.node_embed_dim = node_embed_dim
        self.hidden_dim = hidden_dim
        self.copy_attn = Attention(self.hidden_dim, self.node_embed_dim)

    def forward(self, node_embedding, node_efficient, head_flag_bit_matrix, h_c_t):
        efficient_mask = (1 - node_efficient)  # B, N
        efficient_mask = efficient_mask | head_flag_bit_matrix
        node_logits = self.copy_attn.forward(h_c_t,
                                             node_embedding,
                                             mask=efficient_mask.bool(),
                                             not_softmax=True,
                                             return_weight_only=True)  # B, 1, N
        return node_logits
Example #9
    def __init__(self, node_embed_dim, hidden_dim):
        super(GraphAttn, self).__init__()
        self.node_embed_dim = node_embed_dim
        self.hidden_dim = hidden_dim
        # attention read context
        self.attn = Attention(self.hidden_dim, self.node_embed_dim)
Example #10
    def __init__(self, node_embed_dim, hidden_dim):
        super(GraphCopy, self).__init__()
        self.node_embed_dim = node_embed_dim
        self.hidden_dim = hidden_dim
        self.copy_attn = Attention(self.hidden_dim, self.node_embed_dim)
Example #11
class PriorStateTracker(nn.Module):
    def __init__(self,
                 state_num,
                 hidden_dim,
                 know_vocab_size,
                 embed_dim,
                 embedder,
                 lg_interpreter: LocGloInterpreter,
                 gen_strategy,
                 know2word_tensor,
                 with_copy=True):
        super(PriorStateTracker, self).__init__()
        self.state_num = state_num
        self.hidden_dim = hidden_dim
        self.know_vocab_size = know_vocab_size
        self.embed_dim = embed_dim
        self.embedder = embedder
        self.lg_interpreter = lg_interpreter
        self.gen_strategy = gen_strategy
        self.know2word_tensor = know2word_tensor
        self.with_copy = with_copy

        self.embed_attn = Attention(embed_dim, hidden_dim)
        self.prior_basic_state_tracker = BasicStateTracker(
            self.know2word_tensor, self.state_num, self.hidden_dim,
            self.know_vocab_size, self.embed_dim, self.embedder,
            self.lg_interpreter, self.gen_strategy, self.with_copy)

        self.rnn_cell = nn.GRU(input_size=self.embed_dim,
                               hidden_size=self.embed_dim,
                               num_layers=1,
                               batch_first=True,
                               bidirectional=False)

    def forward(self,
                hidden,
                pv_state,
                pv_r_u,
                pv_r_u_enc,
                gth_state=None,
                supervised=False):
        # B, S, E
        pv_state_emb = self.know_prob_embed(pv_state)
        tmp = []
        s_hidden = None
        for i in range(pv_state_emb.size(1)):
            _, s_hidden = self.rnn_cell.forward(pv_state_emb[:, i:i + 1, :],
                                                s_hidden)  # [1, B, H]
            tmp.append(s_hidden.permute(1, 0, 2))  # [B, 1, E] * S
        pv_state_emb = torch.cat(tmp, 1)  # B, S, E

        # B, 1, E
        pv_state_emb_mean = pv_state_emb.mean(1).unsqueeze(1)
        # B, 1, T
        weight = self.embed_attn.forward(
            pv_state_emb_mean,
            pv_r_u_enc,
            mask=pv_r_u <= DO.PreventWord.RESERVED_MAX_INDEX,
            not_softmax=True,
            return_weight_only=True)
        # B, 1, T
        weight = torch.softmax(weight, -1)
        # B, 1, H ==> B, H
        hidden = hidden.squeeze(0) + torch.bmm(weight, pv_r_u_enc).squeeze(1)

        states, gumbel_states = self.prior_basic_state_tracker.forward(
            hidden,
            pv_state,
            pv_state_emb,
            pv_r_u,
            pv_r_u_enc,
            gth_state=gth_state,
            supervised=supervised)
        return states, gumbel_states

    def know_prob_embed(self, state_gumbel_prob):
        B, S, K = state_gumbel_prob.shape
        # K, E
        know_embedding = self.embedder(self.know2word_tensor)
        state_gumbel_embed = torch.bmm(
            state_gumbel_prob.reshape(B * S, 1, K),
            know_embedding.unsqueeze(0).expand(B * S, K, self.embed_dim))
        state_gumbel_embed = state_gumbel_embed.reshape(B, S, self.embed_dim)
        return state_gumbel_embed
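
know_prob_embed above turns a (B, S, K) distribution over the knowledge vocabulary into expected embeddings via a batched matmul against the K knowledge-word embeddings. A self-contained sketch of the same computation with random tensors (not repo code), including the equivalent broadcasted matmul:

import torch

B, S, K, E = 2, 3, 10, 300
state_prob = torch.softmax(torch.randn(B, S, K), dim=-1)   # distribution over knowledge items
know_embedding = torch.randn(K, E)                         # one embedding per knowledge item

# bmm form, as in know_prob_embed
out_bmm = torch.bmm(state_prob.reshape(B * S, 1, K),
                    know_embedding.unsqueeze(0).expand(B * S, K, E)).reshape(B, S, E)
# equivalent one-liner
out_mm = state_prob @ know_embedding                       # B, S, E
print(torch.allclose(out_bmm, out_mm, atol=1e-5))          # True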
Example #12
    def __init__(self,
                 embed_size,
                 hidden_size,
                 vocab_size,
                 target_max_len,
                 embedder=None,
                 num_layers=1,
                 dropout=0.0,
                 attention_mode="mlp",
                 max_pooling_k=10,
                 attn_history_sentence=True,
                 with_state_know=True,
                 with_action_know=True,
                 with_copy=True):
        super(RNNDecoder, self).__init__()
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.target_max_len = target_max_len
        self.embedder = embedder
        self.max_pooling_k = max_pooling_k
        self.num_layers = num_layers
        self.dropout = dropout if (self.num_layers > 1) else 0.0

        self.attn_history_sentence = attn_history_sentence

        # NOTE: attention over state/action knowledge triples is hard-coded off
        # here, regardless of the with_state_know / with_action_know arguments.
        self.with_state_know = False
        self.with_action_know = False
        self.with_copy = with_copy and (with_state_know or with_action_know)

        self.rnn_input_size = self.embed_size  # r_{t-1} input
        self.rnn_input_size += (self.hidden_size + self.embed_size +
                                self.embed_size
                                )  # history attn, state attn, action attn
        self.out_input_size = self.hidden_size

        if self.attn_history_sentence:
            # SENTENCE & WORD LEVEL ATTENTION
            self.history_attn_word = Attention(query_size=self.hidden_size,
                                               mode=attention_mode)

        if self.with_state_know:
            self.state_know_attn = Attention(query_size=self.hidden_size,
                                             mode=attention_mode)
            self.rnn_input_size += self.hidden_size

        if self.with_action_know:
            self.action_know_attn = Attention(query_size=self.hidden_size,
                                              mode=attention_mode)
            self.rnn_input_size += self.hidden_size

        if self.with_copy:
            # copy from knowledge
            self.copy_attn = Attention(query_size=self.hidden_size, mode="dot")
            self.sparse_copy_attn = Attention(query_size=self.hidden_size,
                                              memory_size=self.embed_size + 1,
                                              hidden_size=self.hidden_size)
            self.hm_linear = nn.Linear(self.hidden_size * 2, self.hidden_size)
            # select the most relevant items
            self.k_max_pooling = KMaxPooling(self.max_pooling_k)

        # DECODER CELL
        self.rnn = nn.GRU(input_size=self.rnn_input_size,
                          hidden_size=self.hidden_size,
                          num_layers=self.num_layers,
                          dropout=self.dropout,
                          batch_first=True)

        self.init_s_attn = Attention(self.hidden_size, self.embed_size)
        self.init_a_attn = Attention(self.hidden_size, self.embed_size)
        self.s_linear = nn.Linear(self.embed_size, self.hidden_size)
        self.a_linear = nn.Linear(self.embed_size, self.hidden_size)

        if not self.with_copy:
            self.output_layer = nn.Sequential(
                nn.Dropout(p=self.dropout),
                nn.Linear(self.out_input_size, self.vocab_size),
                nn.LogSoftmax(dim=-1))
        else:
            self.output_layer = nn.Sequential(
                nn.Dropout(p=self.dropout),
                nn.Linear(self.out_input_size, self.vocab_size))

            self.softmax = nn.Softmax(dim=-1)
Example #13
class RNNDecoder(nn.Module):
    def __init__(self,
                 embed_size,
                 hidden_size,
                 vocab_size,
                 target_max_len,
                 embedder=None,
                 num_layers=1,
                 dropout=0.0,
                 attention_mode="mlp",
                 max_pooling_k=10,
                 attn_history_sentence=True,
                 with_state_know=True,
                 with_action_know=True,
                 with_copy=True):
        super(RNNDecoder, self).__init__()
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.target_max_len = target_max_len
        self.embedder = embedder
        self.max_pooling_k = max_pooling_k
        self.num_layers = num_layers
        self.dropout = dropout if (self.num_layers > 1) else 0.0

        self.attn_history_sentence = attn_history_sentence

        # NOTE: attention over state/action knowledge triples is hard-coded off
        # here, regardless of the with_state_know / with_action_know arguments.
        self.with_state_know = False
        self.with_action_know = False
        self.with_copy = with_copy and (with_state_know or with_action_know)

        self.rnn_input_size = self.embed_size  # r_{t-1} input
        self.rnn_input_size += (self.hidden_size + self.embed_size +
                                self.embed_size
                                )  # history attn, state attn, action attn
        self.out_input_size = self.hidden_size

        if self.attn_history_sentence:
            # SENTENCE & WORD LEVEL ATTENTION
            self.history_attn_word = Attention(query_size=self.hidden_size,
                                               mode=attention_mode)

        if self.with_state_know:
            self.state_know_attn = Attention(query_size=self.hidden_size,
                                             mode=attention_mode)
            self.rnn_input_size += self.hidden_size

        if self.with_action_know:
            self.action_know_attn = Attention(query_size=self.hidden_size,
                                              mode=attention_mode)
            self.rnn_input_size += self.hidden_size

        if self.with_copy:
            # copy from knowledge
            self.copy_attn = Attention(query_size=self.hidden_size, mode="dot")
            self.sparse_copy_attn = Attention(query_size=self.hidden_size,
                                              memory_size=self.embed_size + 1,
                                              hidden_size=self.hidden_size)
            self.hm_linear = nn.Linear(self.hidden_size * 2, self.hidden_size)
            # select the most relevant items
            self.k_max_pooling = KMaxPooling(self.max_pooling_k)

        # DECODER CELL
        self.rnn = nn.GRU(input_size=self.rnn_input_size,
                          hidden_size=self.hidden_size,
                          num_layers=self.num_layers,
                          dropout=self.dropout,
                          batch_first=True)

        self.init_s_attn = Attention(self.hidden_size, self.embed_size)
        self.init_a_attn = Attention(self.hidden_size, self.embed_size)
        self.s_linear = nn.Linear(self.embed_size, self.hidden_size)
        self.a_linear = nn.Linear(self.embed_size, self.hidden_size)

        if not self.with_copy:
            self.output_layer = nn.Sequential(
                nn.Dropout(p=self.dropout),
                nn.Linear(self.out_input_size, self.vocab_size),
                nn.LogSoftmax(dim=-1))
        else:
            self.output_layer = nn.Sequential(
                nn.Dropout(p=self.dropout),
                nn.Linear(self.out_input_size, self.vocab_size))

            self.softmax = nn.Softmax(dim=-1)

    def decode(self,
               inp,
               hidden,
               hs_vectors,
               state_gumbel_prob,
               state_gumbel_embed,
               action_gumbel_prob,
               action_gumbel_embed,
               history_lens=None,
               history_word_indices=None,
               mask_action_prob=False,
               mask_state_prob=False,
               state_detach=False):
        if self.embedder is not None:
            inp = self.embedder(inp)
        # B, 1, E
        rnn_input = inp

        if self.attn_history_sentence:
            word_mask = None
            B, L_sent, H = hs_vectors.shape
            if history_lens is not None:
                # B, L_sent
                word_mask = reverse_sequence_mask(history_lens,
                                                  max_len=L_sent).reshape(
                                                      B, -1)
            # B, 1, H
            c_his_word, _ = self.history_attn_word.forward(hidden.transpose(
                0, 1),
                                                           hs_vectors,
                                                           mask=word_mask)
            rnn_input = torch.cat([rnn_input, c_his_word], dim=-1)

        # attn state and action embed representation
        c_state = self.init_consult(hidden, state_gumbel_embed,
                                    self.init_s_attn)
        c_action = self.init_consult(hidden, action_gumbel_embed,
                                     self.init_a_attn)
        rnn_input = torch.cat([rnn_input, c_state, c_action], dim=-1)

        rnn_output, new_hidden = self.rnn(rnn_input, hidden)

        gene_output = self.output_layer(rnn_output)

        if not self.with_copy:
            return gene_output, new_hidden

        copy_query = new_hidden.permute(1, 0, 2)
        copy_logits = []
        copy_word_index = []

        # copy from post
        hs_logits = self.copy_attn.forward(
            copy_query, hs_vectors, not_softmax=True,
            return_weight_only=True)  # B, 1, L_sentence
        copy_logits.append(hs_logits)
        copy_word_index.append(history_word_indices.unsqueeze(1))

        s_tag = torch.zeros(list(state_gumbel_prob.shape)[:-1] + [1],
                            device=state_gumbel_prob.device,
                            dtype=torch.float)
        a_tag = torch.ones(list(action_gumbel_prob.shape)[:-1] + [1],
                           device=action_gumbel_prob.device,
                           dtype=torch.float)
        # COPY FROM PV_STATE & ACTION
        # B, 1, S
        state_copy_weight = self.sparse_copy(
            copy_query, torch.cat([state_gumbel_embed, s_tag], dim=-1))
        # B, 1, A
        action_copy_weight = self.sparse_copy(
            copy_query, torch.cat([action_gumbel_embed, a_tag], dim=-1))

        if mask_state_prob:
            state_copy_weight.masked_fill_(
                mask=torch.ones_like(state_copy_weight).bool(), value=-1e24)
        if mask_action_prob:
            action_copy_weight.masked_fill_(
                mask=torch.ones_like(action_copy_weight).bool(), value=-1e24)

        copy_logits.append(state_copy_weight)
        copy_logits.append(action_copy_weight)

        # B, V [+ (k * L_triple) [+ (k * L_triple)]] + S + A
        word_proba = self.softmax(
            torch.cat([gene_output] + copy_logits, dim=-1)).squeeze(1)
        # B, V
        gene_proba = word_proba[:, :self.vocab_size]

        state_gumbel_weight = word_proba[:, -(
            DO.state_num + DO.action_num):-DO.action_num].unsqueeze(
                1)  # B, 1, S
        action_gumbel_weight = word_proba[:, -DO.action_num:].unsqueeze(
            1)  # B, 1, A
        # B, V
        state_copy_prob = torch.bmm(state_gumbel_weight,
                                    state_gumbel_prob).squeeze(1)
        if state_detach:
            state_copy_prob = state_copy_prob.detach()

        action_copy_prob = torch.bmm(action_gumbel_weight,
                                     action_gumbel_prob).squeeze(1)

        l_s = vrbot_train_stage.s_copy_lambda
        l_a = vrbot_train_stage.a_copy_lambda if VO.train_stage == "action" else vrbot_train_stage.a_copy_lambda_mini
        gene_proba = gene_proba + l_s * state_copy_prob + l_a * action_copy_prob

        if len(copy_word_index) == 0:
            output = gene_proba
        else:
            if len(copy_word_index) > 1:
                copy_word_index = torch.cat(copy_word_index, dim=-1)
            else:
                copy_word_index = copy_word_index[0]
            # B, V
            B = word_proba.size(0)
            copy_proba = torch.zeros(B,
                                     self.vocab_size,
                                     device=word_proba.device)
            copy_proba = copy_proba.scatter_add(
                1, copy_word_index.squeeze(1),
                word_proba[:, self.vocab_size:-(DO.state_num + DO.action_num)])
            output = gene_proba + copy_proba

        output = torch.log(output.unsqueeze(1))
        return output, new_hidden

    def copy(self, query, memory_word_enc, memory_word_index):
        B, t_num, l_trip, H = memory_word_enc.shape
        flatten_memory_word_enc = memory_word_enc.reshape(B, t_num * l_trip, H)
        flatten_selected_word_index = memory_word_index.reshape(B, -1)
        flatten_selected_mask = flatten_selected_word_index <= DO.PreventWord.RESERVED_MAX_INDEX
        match_logits = self.copy_attn.forward(query,
                                              flatten_memory_word_enc,
                                              not_softmax=True)[1].squeeze(1)
        match_logits = match_logits.masked_fill(flatten_selected_mask, -1e24)
        return match_logits, flatten_selected_word_index

    def sparse_copy(self, query, word_enc):
        # B, 1, S
        copy_weight = self.sparse_copy_attn.forward(query,
                                                    word_enc,
                                                    not_softmax=True,
                                                    return_weight_only=True)
        return copy_weight

    @staticmethod
    def init_consult(hidden, state_emb, attn, state_linear=None):
        state_attn_emb, _ = attn.forward(hidden.permute(1, 0, 2),
                                         state_emb)  # B, 1, E
        if state_linear is not None:
            state_attn_emb = state_linear(state_attn_emb)  # B, 1, H
            return state_attn_emb.permute(1, 0, 2)  # 1, B, H
        else:
            return state_attn_emb

    def forward(self,
                hidden,
                inputs,
                hs_vectors,
                state_gumbel_prob,
                state_gumbel_embed,
                action_gumbel_prob,
                action_gumbel_embed,
                history_lens=None,
                history_word_indices=None,
                mask_action_prob=False,
                mask_state_prob=False,
                state_detach=False):
        if self.training:
            assert inputs is not None, "In training stage, inputs should not be None"
            batch_size, s_max_len = inputs.shape
        else:
            batch_size = hidden.size(1)

        rnn_input = hidden.new_ones(
            batch_size, 1, dtype=torch.long) * DO.PreventWord.SOS_ID  # SOS

        # consult the state and action
        hidden = self.init_consult_state_action(action_gumbel_embed, hidden,
                                                mask_action_prob,
                                                state_gumbel_embed)

        if self.training:
            valid_lengths = s_max_len - (
                inputs == DO.PreventWord.EOS_ID).long().cumsum(1).sum(1)
            valid_lengths, indices = valid_lengths.sort(descending=True)
            lengths_tag = torch.arange(
                0, s_max_len,
                device=TrainOption.device).unsqueeze(0).expand(batch_size, -1)

            # max_len
            batch_num_valid = (valid_lengths.unsqueeze(1).expand(
                -1, s_max_len) > lengths_tag).long().sum(0)
            batch_num_valid = batch_num_valid.cpu().numpy().tolist()
            # B, max_len, vocab_size
            output_placeholder = torch.zeros(batch_size,
                                             self.target_max_len,
                                             self.vocab_size,
                                             dtype=torch.float,
                                             device=TrainOption.device)

            # input Tensor index_select
            hidden = index_select_if_not_none(hidden, indices, 1)
            inputs = index_select_if_not_none(inputs, indices, 0)
            hs_vectors = index_select_if_not_none(hs_vectors, indices, 0)

            state_gumbel_embed = index_select_if_not_none(
                state_gumbel_embed, indices, 0)
            state_gumbel_prob = index_select_if_not_none(
                state_gumbel_prob, indices, 0)
            action_gumbel_embed = index_select_if_not_none(
                action_gumbel_embed, indices, 0)
            action_gumbel_prob = index_select_if_not_none(
                action_gumbel_prob, indices, 0)
            history_word_indices = index_select_if_not_none(
                history_word_indices, indices, 0)
            history_lens = index_select_if_not_none(history_lens, indices, 0)

            for dec_step, vb in enumerate(batch_num_valid):
                if vb <= 0:
                    break

                if inputs is not None and self.training:  # teacher forcing in training stage
                    # B, 1
                    rnn_input = inputs[:, dec_step].unsqueeze(1)

                rnn_dec = self.decode(
                    rnn_input[:vb, :], hidden[:, :vb, :],
                    hs_vectors[:vb, ...] if hs_vectors is not None else None,
                    state_gumbel_prob[:vb, ...] if state_gumbel_prob
                    is not None else None, state_gumbel_embed[:vb, ...]
                    if state_gumbel_embed is not None else None,
                    action_gumbel_prob[:vb, ...] if action_gumbel_prob
                    is not None else None, action_gumbel_embed[:vb, ...]
                    if action_gumbel_embed is not None else None,
                    history_lens[:vb] if history_lens is not None else None,
                    history_word_indices[:vb]
                    if history_word_indices is not None else None,
                    mask_action_prob, mask_state_prob, state_detach)

                word_output, hidden = rnn_dec
                # B, max_len, vocab_size
                output_placeholder[:vb, dec_step:dec_step + 1, :] = word_output

                if not (inputs is not None and
                        self.training):  # inference in testing or valid stage
                    rnn_input = word_output.argmax(dim=-1)

            _, rev_indices = indices.sort()
            output_placeholder = output_placeholder.index_select(
                0, rev_indices)
            return output_placeholder
        else:
            global_indices = torch.arange(0,
                                          batch_size,
                                          dtype=torch.long,
                                          device=TrainOption.device)

            output_placeholder = torch.zeros(batch_size,
                                             self.target_max_len,
                                             self.vocab_size,
                                             dtype=torch.float,
                                             device=TrainOption.device)

            # decode
            for i in range(self.target_max_len):
                word_output, hidden = self.decode(
                    rnn_input, hidden, hs_vectors, state_gumbel_prob,
                    state_gumbel_embed, action_gumbel_prob,
                    action_gumbel_embed, history_lens=history_lens,
                    history_word_indices=history_word_indices)
                # B, 1
                next_step_input = word_output.argmax(dim=-1)
                # B,
                continue_tag = (next_step_input.squeeze(1) !=
                                DO.PreventWord.EOS_ID).long()
                _, local_indices = continue_tag.sort(descending=True)
                output_placeholder[global_indices, i:i + 1, :] = word_output
                b_case = continue_tag.sum().item()
                if b_case <= 0:
                    break

                # B',
                local_indices = local_indices[:b_case]
                global_indices = global_indices.index_select(0, local_indices)
                rnn_input = next_step_input

                # input next time
                rnn_input = rnn_input.index_select(0, local_indices)
                hidden = hidden.index_select(1, local_indices)

                # history info & knowledge carried to the next step
                hs_vectors = index_select_if_not_none(hs_vectors,
                                                      local_indices, 0)
                state_gumbel_prob = index_select_if_not_none(
                    state_gumbel_prob, local_indices, 0)
                state_gumbel_embed = index_select_if_not_none(
                    state_gumbel_embed, local_indices, 0)
                action_gumbel_prob = index_select_if_not_none(
                    action_gumbel_prob, local_indices, 0)
                action_gumbel_embed = index_select_if_not_none(
                    action_gumbel_embed, local_indices, 0)
                history_word_indices = index_select_if_not_none(
                    history_word_indices, local_indices, 0)
                history_lens = index_select_if_not_none(
                    history_lens, local_indices, 0)

            return output_placeholder

    def init_consult_state_action(self, action_gumbel_embed, hidden,
                                  mask_action_prob, state_gumbel_embed):
        if mask_action_prob:
            s_emb = self.init_consult(hidden, state_gumbel_embed,
                                      self.init_s_attn, self.s_linear)
            hidden = hidden + s_emb
        else:
            s_emb = self.init_consult(hidden, state_gumbel_embed,
                                      self.init_s_attn, self.s_linear)
            a_emb = self.init_consult(hidden, action_gumbel_embed,
                                      self.init_a_attn, self.a_linear)
            hidden = hidden + (s_emb + a_emb) / 2.0
        return hidden
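
In decode above, generation logits and copy logits go through one joint softmax, and the copy mass is scattered back onto vocabulary ids with scatter_add before being added to the generation distribution. A toy version of that step with made-up sizes (not repo code):

import torch

B, V, L = 2, 8, 4                                 # batch, vocab size, copied source positions
logits = torch.randn(B, V + L)                    # [generation logits | copy logits]
word_proba = torch.softmax(logits, dim=-1)
gene_proba = word_proba[:, :V]
copy_weight = word_proba[:, V:]
copy_word_index = torch.randint(0, V, (B, L))     # vocab id of each source token

copy_proba = torch.zeros(B, V).scatter_add(1, copy_word_index, copy_weight)
output = gene_proba + copy_proba                  # still sums to 1 per row
print(output.sum(-1))                             # ~tensor([1., 1.])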
Example #14
    def __init__(self,
                 action_num,
                 hidden_dim,
                 embed_dim,
                 know_vocab_size,
                 embedder,
                 kw_interpreter: LocGloInterpreter,
                 know2word_tensor,
                 gen_strategy=None,
                 with_copy=True):
        super(BasicPolicyNetwork, self).__init__()
        self.action_num = action_num
        self.hidden_dim = hidden_dim
        self.embed_dim = embed_dim
        self.know_vocab_size = know_vocab_size
        self.embedder = embedder
        self.kw_interpreter = kw_interpreter
        self.gen_strategy = gen_strategy
        self.with_copy = with_copy
        self.know2word_tensor = know2word_tensor
        self.gumbel_softmax = GumbelSoftmax(normed=True)
        self.graph_copy = GraphCopy(VO.node_embed_dim, VO.hidden_dim)
        assert self.gen_strategy in ("gru", "mlp")

        if self.with_copy:
            self.embed_copy_attn = Attention(self.hidden_dim, self.hidden_dim)
            self.hidden_copy_attn = Attention(self.hidden_dim, self.hidden_dim)

        if self.gen_strategy == "gru":
            self.embed_attn = Attention(self.hidden_dim, self.embed_dim)
            self.hidden_attn = Attention(self.hidden_dim, self.hidden_dim)
            self.embed2hidden_linear = nn.Linear(self.embed_dim,
                                                 self.hidden_dim)
            self.gru = nn.GRU(input_size=self.embed_dim + self.embed_dim +
                              self.hidden_dim,
                              hidden_size=self.hidden_dim,
                              num_layers=1,
                              dropout=0.0,
                              batch_first=True)

        elif self.gen_strategy == "mlp":
            self.eh_linear = nn.Sequential(
                nn.Linear(self.hidden_dim + self.embed_dim, self.hidden_dim),
                nn.ReLU(), nn.Linear(self.hidden_dim, self.hidden_dim))
            self.trans_attn = Attention(
                self.hidden_dim,
                self.hidden_dim)  # for interaction among state slots
            self.action_attn = Attention(
                self.hidden_dim, self.hidden_dim)  # for attentively reading the state
            self.pred_proj = nn.Sequential(
                nn.Linear(self.hidden_dim + self.hidden_dim + self.embed_dim,
                          self.hidden_dim * self.action_num),
                Reshape(1, [self.action_num, self.hidden_dim]))

        else:
            raise NotImplementedError

        self.gen_linear = nn.Linear(self.hidden_dim, self.know_vocab_size)
        self.hidden_projection = nn.Linear(self.hidden_dim,
                                           self.know_vocab_size)
        self.word_softmax = nn.Softmax(-1)
        self.hidden2word_projection = nn.Sequential(self.hidden_projection,
                                                    self.word_softmax)
Example #15
class BasicPolicyNetwork(nn.Module):
    def __init__(self,
                 action_num,
                 hidden_dim,
                 embed_dim,
                 know_vocab_size,
                 embedder,
                 kw_interpreter: LocGloInterpreter,
                 know2word_tensor,
                 gen_strategy=None,
                 with_copy=True):
        super(BasicPolicyNetwork, self).__init__()
        self.action_num = action_num
        self.hidden_dim = hidden_dim
        self.embed_dim = embed_dim
        self.know_vocab_size = know_vocab_size
        self.embedder = embedder
        self.kw_interpreter = kw_interpreter
        self.gen_strategy = gen_strategy
        self.with_copy = with_copy
        self.know2word_tensor = know2word_tensor
        self.gumbel_softmax = GumbelSoftmax(normed=True)
        self.graph_copy = GraphCopy(VO.node_embed_dim, VO.hidden_dim)
        assert self.gen_strategy in ("gru", "mlp")

        if self.with_copy:
            self.embed_copy_attn = Attention(self.hidden_dim, self.hidden_dim)
            self.hidden_copy_attn = Attention(self.hidden_dim, self.hidden_dim)

        if self.gen_strategy == "gru":
            self.embed_attn = Attention(self.hidden_dim, self.embed_dim)
            self.hidden_attn = Attention(self.hidden_dim, self.hidden_dim)
            self.embed2hidden_linear = nn.Linear(self.embed_dim,
                                                 self.hidden_dim)
            self.gru = nn.GRU(input_size=self.embed_dim + self.embed_dim +
                              self.hidden_dim,
                              hidden_size=self.hidden_dim,
                              num_layers=1,
                              dropout=0.0,
                              batch_first=True)

        elif self.gen_strategy == "mlp":
            self.eh_linear = nn.Sequential(
                nn.Linear(self.hidden_dim + self.embed_dim, self.hidden_dim),
                nn.ReLU(), nn.Linear(self.hidden_dim, self.hidden_dim))
            self.trans_attn = Attention(
                self.hidden_dim,
                self.hidden_dim)  # for interaction among state slots
            self.action_attn = Attention(
                self.hidden_dim, self.hidden_dim)  # for attentively reading the state
            self.pred_proj = nn.Sequential(
                nn.Linear(self.hidden_dim + self.hidden_dim + self.embed_dim,
                          self.hidden_dim * self.action_num),
                Reshape(1, [self.action_num, self.hidden_dim]))

        else:
            raise NotImplementedError

        self.gen_linear = nn.Linear(self.hidden_dim, self.know_vocab_size)
        self.hidden_projection = nn.Linear(self.hidden_dim,
                                           self.know_vocab_size)
        self.word_softmax = nn.Softmax(-1)
        self.hidden2word_projection = nn.Sequential(self.hidden_projection,
                                                    self.word_softmax)

    def forward(self,
                hidden,
                state_emb,
                state,
                pv_r_u_enc,
                pv_r_u_len,
                r=None,
                r_enc=None,
                r_mask=None,
                mask_gen=False,
                gth_action=None,
                supervised=False,
                node_embedding=None,
                head_nodes=None,
                node_efficient=None,
                head_flag_bit=None):

        if self.gen_strategy == "gru":
            return self.forward_gru(
                hidden,
                state_emb,
                pv_r_u_enc,
                pv_r_u_len,
                r=r,
                r_enc=r_enc,
                r_mask=r_mask,
                gth_action=gth_action,
                mask_gen=mask_gen,
                supervised=supervised,
                node_embedding=node_embedding,  # B, N, E
                head_nodes=head_nodes,  # B, N
                node_efficient=node_efficient,  # B, N
                head_flag_bit=head_flag_bit)
        elif self.gen_strategy == "mlp":
            return self.forward_mlp(hidden,
                                    state_emb,
                                    r=r,
                                    r_enc=r_enc,
                                    r_mask=r_mask,
                                    mask_gen=mask_gen), None
        else:
            raise NotImplementedError

    def forward_gru(
            self,
            hidden,
            state_emb,
            pv_r_u_enc,
            pv_r_u_len,
            r=None,
            r_enc=None,
            r_mask=None,
            gth_action=None,
            mask_gen=False,
            supervised=False,
            node_embedding=None,  # B, N, E
            head_nodes=None,  # B, N
            node_efficient=None,  # B, N
            head_flag_bit=None):
        batch_size = hidden.size(0)

        # B, 1, E
        state_context, _ = self.embed_attn.forward(hidden.unsqueeze(1),
                                                   state_emb)
        hidden = hidden + self.embed2hidden_linear(state_context).squeeze(1)
        hidden = hidden.unsqueeze(0)  # 1, B, H

        # init input
        step_input = torch.zeros(batch_size,
                                 1,
                                 self.know_vocab_size,
                                 dtype=torch.float,
                                 device=TO.device)
        step_input[:, :, 0] = 1.0  # B, 1, K
        step_input_embed = self.know_prob_embed(step_input)  # B, 1, E

        actions = []
        gumbel_actions = []

        for i in range(self.action_num):
            # B, 1, E + H
            state_context, _ = self.embed_attn.forward(hidden.permute(1, 0, 2),
                                                       state_emb)
            pv_r_u_mask = reverse_sequence_mask(pv_r_u_len, pv_r_u_enc.size(1))
            post_context, _ = self.hidden_attn.forward(hidden.permute(1, 0, 2),
                                                       pv_r_u_enc,
                                                       mask=pv_r_u_mask)
            pv_s_input = torch.cat(
                [step_input_embed, state_context, post_context], dim=-1)

            next_action_hidden, hidden = self.gru.forward(pv_s_input, hidden)

            probs = self.action_pred(batch_size,
                                     next_action_hidden,
                                     r=r,
                                     r_enc=r_enc,
                                     r_mask=r_mask,
                                     mask_gen=mask_gen,
                                     node_embedding=node_embedding,
                                     head_nodes=head_nodes,
                                     node_efficient=node_efficient,
                                     head_flag_bit=head_flag_bit)
            actions.append(probs)

            if self.training and TO.auto_regressive and (
                    gth_action
                    is not None) and supervised and (not TO.no_action_super):
                gth_step_input = gth_action[:, i:i + 1]
                gth_step_input = one_hot_scatter(gth_step_input,
                                                 self.know_vocab_size,
                                                 dtype=torch.float)
                step_input_embed = self.know_prob_embed(gth_step_input)
            else:
                if TO.auto_regressive:
                    if self.training:
                        gumbel_probs = self.action_gumbel_softmax_sampling(
                            probs)
                    else:
                        max_indices = probs.argmax(-1)
                        gumbel_probs = one_hot_scatter(max_indices,
                                                       probs.size(2),
                                                       dtype=torch.float)

                    step_input_embed = self.know_prob_embed(gumbel_probs)
                    gumbel_actions.append(gumbel_probs)
                else:
                    step_input_embed = self.know_prob_embed(probs)

        actions = torch.cat(actions, dim=1)
        if len(gumbel_actions) == 0:
            return actions, None

        gumbel_actions = torch.cat(gumbel_actions, dim=1)
        return actions, gumbel_actions

    def forward_mlp(self,
                    hidden,
                    state_emb,
                    r=None,
                    r_enc=None,
                    r_mask=None,
                    mask_gen=False):
        batch_size, state_num = state_emb.size(0), state_emb.size(1)
        # B, S, H
        expanded_hidden = hidden.unsqueeze(0).permute(1, 0, 2).expand(
            batch_size, state_num, self.hidden_dim)
        # B, S, H
        deep_input = self.eh_linear(
            torch.cat([expanded_hidden, state_emb], dim=2))
        # B, S, H
        deep_inner, _ = self.trans_attn.forward(deep_input, deep_input)
        # B, 1, H
        deep_output, _ = self.action_attn.forward(hidden.unsqueeze(1),
                                                  deep_inner)
        deep_output = deep_output.squeeze(1)
        # B, H + H + E
        proj_input = torch.cat([deep_output, hidden, state_emb.sum(1)], dim=-1)
        # B, A, H (for generation or for copy attention)
        pred_action_hidden = self.pred_proj.forward(proj_input)

        return self.action_pred(batch_size,
                                pred_action_hidden,
                                r,
                                r_enc,
                                r_mask,
                                mask_gen=mask_gen)

    def action_pred(
            self,
            batch_size,
            pred_action_hidden,
            r=None,
            r_enc=None,
            r_mask=None,
            mask_gen=False,
            node_embedding=None,  # B, N, E
            head_nodes=None,  # B, N
            node_efficient=None,  # B, N
            head_flag_bit=None):  # B, N
        logits = []
        indexs = None

        action_num = pred_action_hidden.size(1)
        action_gen_logits = self.gen_linear(pred_action_hidden)
        logits.append(action_gen_logits)

        if r is not None:
            r_know_index = self.kw_interpreter.glo2loc(r)
            r_inner_mask = (r_know_index == 0)
            if r_mask is not None:
                r_mask = r_inner_mask | r_mask
            else:
                r_mask = r_inner_mask

            r_logits = self.hidden_copy_attn.forward(pred_action_hidden,
                                                     r_enc,
                                                     mask=r_mask,
                                                     not_softmax=True,
                                                     return_weight_only=True)
            logits.append(r_logits)
            indexs = r_know_index  # B, Tr

        if node_embedding is not None:
            node_copy_logits = self.graph_copy.forward(node_embedding,
                                                       node_efficient,
                                                       head_flag_bit,
                                                       pred_action_hidden)
            logits.append(node_copy_logits)
            indexs = head_nodes  # B, N

        if (r is None) or (not VO.ppn_dq):
            if len(logits) >= 1:
                logits = torch.cat(logits, -1)
            else:
                raise RuntimeError

            probs = self.word_softmax(logits)

            if indexs is None:
                return probs

            if not mask_gen:
                gen_probs = probs[:, :, :self.know_vocab_size]
                copy_probs = probs[:, :, self.know_vocab_size:]
            else:
                gen_probs = 0.0
                copy_probs = probs
        else:
            g_lambda = 0.05
            gen_probs = self.word_softmax(logits[0]) * (1 - g_lambda)
            copy_probs = self.word_softmax(logits[1]) * g_lambda

        copy_probs_placeholder = torch.zeros(batch_size,
                                             action_num,
                                             self.know_vocab_size,
                                             device=TO.device)
        if indexs is not None:
            expand_indexs = indexs.unsqueeze(1).expand(-1, action_num, -1)
            copy_probs = copy_probs_placeholder.scatter_add(
                2, expand_indexs, copy_probs)  # B, A, K

        action_probs = gen_probs + copy_probs  # B, A, K
        return action_probs

    def action_gumbel_softmax_sampling(self, probs):
        gumbel_probs = self.gumbel_softmax.forward(probs,
                                                   vrbot_train_stage.a_tau)
        return gumbel_probs

    def know_prob_embed(self, action_prob):
        B, A, K = action_prob.shape
        # K, E
        know_embedding = self.embedder(self.know2word_tensor)
        action_embed = torch.bmm(
            action_prob.reshape(B * A, 1, K),
            know_embedding.unsqueeze(0).expand(B * A, K, self.embed_dim))
        action_embed = action_embed.reshape(B, A, self.embed_dim)
        return action_embed
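
action_pred scatters copy probabilities (over response tokens or graph nodes) onto the knowledge vocabulary, with the source indices broadcast across all action slots. A small stand-alone sketch of that 3-D scatter_add (arbitrary sizes, not repo code):

import torch

B, A, K, T = 2, 3, 12, 5                              # batch, action slots, knowledge vocab, source tokens
copy_probs = torch.softmax(torch.randn(B, A, T), -1)  # copy weights per action slot
indexs = torch.randint(0, K, (B, T))                  # knowledge id of each source token

expand_indexs = indexs.unsqueeze(1).expand(-1, A, -1)                              # B, A, T
copy_over_vocab = torch.zeros(B, A, K).scatter_add(2, expand_indexs, copy_probs)   # B, A, K
print(copy_over_vocab.shape)                          # torch.Size([2, 3, 12])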
Example #16
class BasicStateTracker(nn.Module):
    def __init__(self,
                 know2word_tensor,
                 state_num,
                 hidden_dim,
                 know_vocab_size,
                 embed_dim,
                 embedder,
                 lg_interpreter: LocGloInterpreter,
                 gen_strategy="gru",
                 with_copy=True):
        super(BasicStateTracker, self).__init__()
        self.know2word_tensor = know2word_tensor
        self.state_num = state_num
        self.hidden_dim = hidden_dim
        self.know_vocab_size = know_vocab_size
        self.embed_dim = embed_dim
        self.embedder = embedder
        self.lg_interpreter = lg_interpreter
        self.gen_strategy = gen_strategy
        self.with_copy = with_copy
        self.gumbel_softmax = GumbelSoftmax(normed=True)
        assert self.gen_strategy in ("gru", "mlp") or self.gen_strategy is None

        self.embed_attn = Attention(self.hidden_dim, self.embed_dim)
        self.hidden_attn = Attention(self.hidden_dim, self.hidden_dim)

        if self.with_copy:
            self.embed_copy_attn = Attention(self.hidden_dim, self.embed_dim)
            self.hidden_copy_attn = Attention(self.hidden_dim, self.hidden_dim)

        if self.gen_strategy == "gru":
            self.gru = nn.GRU(input_size=self.embed_dim + self.hidden_dim +
                              self.embed_dim,
                              hidden_size=self.hidden_dim,
                              num_layers=1,
                              dropout=0.0,
                              batch_first=True)

        elif self.gen_strategy == "mlp":
            self.step_linear = nn.Sequential(
                nn.Linear(self.hidden_dim, self.state_num * self.hidden_dim),
                Reshape(1, [self.state_num, self.hidden_dim]), nn.Dropout(0.2),
                nn.Linear(self.hidden_dim, self.hidden_dim))
        else:
            raise NotImplementedError

        self.hidden_projection = nn.Linear(self.hidden_dim,
                                           self.know_vocab_size)
        self.word_softmax = nn.Softmax(-1)

        self.hidden2word_projection = nn.Sequential(self.hidden_projection,
                                                    self.word_softmax)

    def state_gumbel_softmax_sampling(self, probs):
        gumbel_probs = self.gumbel_softmax.forward(probs,
                                                   vrbot_train_stage.s_tau)
        return gumbel_probs

    def forward(self,
                hidden,
                pv_state,
                pv_state_emb,
                pv_r_u,
                pv_r_u_enc,
                gth_state=None,
                supervised=False):
        batch_size = pv_state_emb.size(0)
        states = []
        gumbel_states = []

        multi_hiddens = None
        step_input_embed = None
        if self.gen_strategy == "gru":
            hidden = hidden.unsqueeze(0)  # 1, B, H
            step_input = torch.zeros(batch_size,
                                     1,
                                     self.know_vocab_size,
                                     dtype=torch.float,
                                     device=TO.device)
            step_input[:, :, 0] = 1.0  # B, 1, K
            step_input_embed = self.know_prob_embed(step_input)  # B, 1, E
        elif self.gen_strategy == "mlp":
            multi_hiddens = self.step_linear(hidden)  # B, S, H
        else:
            raise NotImplementedError

        for i in range(self.state_num):
            if self.gen_strategy == "gru":
                pv_s_context, _ = self.embed_attn.forward(
                    hidden.permute(1, 0, 2), pv_state_emb)
                pv_r_u_context, _ = self.hidden_attn.forward(
                    hidden.permute(1, 0, 2), pv_r_u_enc)
                pv_s_input = torch.cat(
                    [pv_s_context, pv_r_u_context, step_input_embed],
                    dim=-1)  # B, 1, E + H + E
                next_state_hidden, hidden = self.gru.forward(
                    pv_s_input, hidden)  # B, 1, H | 1, B, H
            elif self.gen_strategy == "mlp":
                next_state_hidden = multi_hiddens[:, i:i + 1, :]  # B, 1, H
            else:
                raise NotImplementedError

            next_state = self.hidden_projection(next_state_hidden)
            logits = [next_state]
            indexs = []

            if self.with_copy:
                pv_state_weight = self.embed_copy_attn.forward(
                    next_state_hidden,
                    pv_state_emb,
                    mask=None,
                    not_softmax=True,
                    return_weight_only=True)
                logits.append(pv_state_weight)

                pv_r_u_know_index = self.lg_interpreter.glo2loc(pv_r_u)
                pv_r_u_mask = (pv_r_u_know_index == 0)
                pv_r_u_weight = self.hidden_copy_attn.forward(
                    next_state_hidden,
                    pv_r_u_enc,
                    mask=pv_r_u_mask,
                    not_softmax=True,
                    return_weight_only=True)
                logits.append(pv_r_u_weight)
                indexs.append(pv_r_u_know_index)

                logits = torch.cat(logits, -1)
                indexs = torch.cat(indexs, -1).unsqueeze(1)

            probs = self.word_softmax(logits)

            if self.with_copy:
                gen_probs = probs[:, :, :self.know_vocab_size]

                pv_state_copy_probs = probs[:, :, self.know_vocab_size:self.
                                            know_vocab_size + DO.state_num]
                pv_state_copy_probs = torch.bmm(pv_state_copy_probs, pv_state)

                copy_probs = probs[:, :, self.know_vocab_size + DO.state_num:]
                copy_probs_placeholder = torch.zeros(batch_size,
                                                     1,
                                                     self.know_vocab_size,
                                                     device=TO.device)
                copy_probs = copy_probs_placeholder.scatter_add(
                    2, indexs, copy_probs)

                probs = gen_probs + pv_state_copy_probs + copy_probs

            states.append(probs)

            if self.training and TO.auto_regressive and (
                    gth_state is not None) and supervised:
                gth_step_input = gth_state[:, i:i + 1]  # B, 1
                gth_step_input = one_hot_scatter(gth_step_input,
                                                 self.know_vocab_size,
                                                 dtype=torch.float)
                step_input_embed = self.know_prob_embed(gth_step_input)
            else:
                if TO.auto_regressive:
                    if self.training:
                        gumbel_probs = self.state_gumbel_softmax_sampling(
                            probs)  # gumbel-softmax
                    else:
                        max_indices = probs.argmax(-1)  # B, S
                        gumbel_probs = one_hot_scatter(max_indices,
                                                       probs.size(2),
                                                       dtype=torch.float)

                    step_input_embed = self.know_prob_embed(gumbel_probs)
                    gumbel_states.append(gumbel_probs)
                else:
                    step_input_embed = self.know_prob_embed(probs)  # B, 1, E

        states = torch.cat(states, dim=1)
        if len(gumbel_states) == 0:
            return states, None
        gumbel_states = torch.cat(gumbel_states, dim=1)
        return states, gumbel_states

    def know_prob_embed(self, state_prob):
        B, S, K = state_prob.shape
        know_embedding = self.embedder(self.know2word_tensor)
        state_embed = torch.bmm(
            state_prob.reshape(B * S, 1, K),
            know_embedding.unsqueeze(0).expand(B * S, K, self.embed_dim))
        state_embed = state_embed.reshape(B, S, self.embed_dim)
        return state_embed
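
The tracker samples a discrete state with a Gumbel-Softmax relaxation during training and switches to an argmax one-hot at inference. The repo's GumbelSoftmax(normed=True) and one_hot_scatter are not shown in these snippets; the sketch below uses torch.nn.functional as a rough stand-in (an assumption, not the project's exact implementation):

import torch
import torch.nn.functional as F

B, S, K = 2, 3, 10
probs = torch.softmax(torch.randn(B, S, K), dim=-1)    # tracker output distribution
tau = 1.0                                              # temperature (vrbot_train_stage.s_tau in the repo)

training = True
if training:
    # relaxed one-hot sample; gradients flow through the soft sample
    gumbel_probs = F.gumbel_softmax(probs.log(), tau=tau, hard=False, dim=-1)
else:
    # greedy one-hot, comparable to one_hot_scatter(probs.argmax(-1), K)
    gumbel_probs = F.one_hot(probs.argmax(-1), num_classes=K).float()
print(gumbel_probs.shape)                              # torch.Size([2, 3, 10])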