Example no. 1
0
    def __init__(self,
                 input_size,
                 hidden_size,
                 output_size,
                 embedder=None,
                 num_layers=1,
                 attn_mode='mlp',
                 dropout=0.0):
        super(PointerDecoder, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.embedder = embedder
        self.num_layers = num_layers
        self.attn_mode = None if attn_mode == 'none' else attn_mode
        self.dropout = dropout

        self.memory_size = hidden_size

        self.rnn_input_size = self.input_size
        self.out_input_size = self.hidden_size + self.input_size

        if self.attn_mode is not None:
            self.hist_attention = Attention(query_size=self.hidden_size,
                                            memory_size=self.hidden_size,
                                            hidden_size=self.hidden_size,
                                            mode=self.attn_mode,
                                            project=False)
            self.fact_attention = Attention(query_size=self.hidden_size,
                                            memory_size=self.hidden_size,
                                            hidden_size=self.hidden_size,
                                            mode=self.attn_mode,
                                            project=False)
            self.rnn_input_size += self.memory_size * 2
            self.out_input_size += self.memory_size * 2

        self.rnn = nn.GRU(input_size=self.rnn_input_size,
                          hidden_size=self.hidden_size,
                          num_layers=self.num_layers,
                          dropout=self.dropout if self.num_layers > 1 else 0,
                          batch_first=True)

        self.ff = nn.Sequential(nn.Linear(self.out_input_size, 3),
                                nn.Softmax(dim=-1))

        if self.out_input_size > self.hidden_size:
            self.output_layer = nn.Sequential(
                nn.Dropout(p=self.dropout),
                nn.Linear(self.out_input_size, self.hidden_size),
                nn.Linear(self.hidden_size, self.output_size),
                nn.Softmax(dim=-1),
            )
        else:
            self.output_layer = nn.Sequential(
                nn.Dropout(p=self.dropout),
                nn.Linear(self.out_input_size, self.output_size),
                nn.Softmax(dim=-1),
            )
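A minimal usage sketch for Example no. 1 (all sizes are placeholders, and PointerDecoder plus its Attention dependency are assumed importable from the surrounding project). With attention enabled, the history and fact attentions each contribute a memory_size-wide context, so both the GRU input and the output head grow by memory_size * 2.

    # Hypothetical sizes, for illustration only.
    decoder = PointerDecoder(input_size=300,
                             hidden_size=512,
                             output_size=30000,
                             attn_mode='mlp',
                             dropout=0.2)
    assert decoder.rnn_input_size == 300 + 512 * 2        # input + two attention contexts
    assert decoder.out_input_size == 512 + 300 + 512 * 2  # hidden + input + two contexts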
Example no. 2
0
    def __init__(self, config, is_train=True):
        super(RNNDecoder, self).__init__()
        self.config = config
        self.is_train = is_train
        self.input_size = self.config.input_size
        self.output_size = self.config.cn_vocab_size + 4
        self.hidden_units = self.config.hidden_units
        self.num_layers = self.config.num_layers
        self.dropout = self.config.dropout
        self.embedder = self.config.embedder
        self.memory_size = self.config.memory_size or self.config.hidden_units

        self.attention_mode = self.config.attention_mode
        self.rnn_input_size = self.input_size
        self.out_input_size = self.hidden_units

        if self.attention_mode:
            self.attention = Attention(config)
            self.rnn_input_size += self.memory_size

        self.rnn = nn.GRU(input_size=self.rnn_input_size,
                          hidden_size=self.hidden_units,
                          num_layers=self.num_layers,
                          batch_first=True,
                          dropout=self.dropout if self.num_layers > 1 else 0)

        self.output_layer = nn.Sequential(
            nn.Dropout(p=self.dropout),
            nn.Linear(self.out_input_size, self.hidden_units),
            nn.Linear(self.hidden_units, self.output_size),
            nn.LogSoftmax(dim=-1))
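Example no. 2 reads everything from a config object. A hedged sketch of the fields that object must expose, inferred only from the attribute accesses above (all values are placeholders, and the Attention built inside may require further config fields that are not visible here):

    from types import SimpleNamespace

    config = SimpleNamespace(input_size=300,
                             cn_vocab_size=20000,   # output_size becomes cn_vocab_size + 4
                             hidden_units=512,
                             num_layers=1,
                             dropout=0.2,
                             embedder=None,
                             memory_size=None,      # falls back to hidden_units
                             attention_mode='mlp')  # truthy, so Attention(config) is built

    decoder = RNNDecoder(config, is_train=True)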
Example no. 3
0
    def __init__(self,
                 vocab,
                 hidden_size,
                 hop=1,
                 attn_mode='dot',
                 padding_idx=None):
        super(EncoderMemNN, self).__init__()
        self.num_vocab = vocab
        self.max_hops = hop
        self.hidden_size = hidden_size
        self.attn_mode = attn_mode
        self.padding_idx = padding_idx

        for hop in range(self.max_hops + 1):
            C = Embedder(self.num_vocab,
                         self.hidden_size,
                         padding_idx=self.padding_idx)
            C.weight.data.normal_(0, 0.1)
            self.add_module("C_{}".format(hop), C)

        for hop in range(self.max_hops):
            A = Attention(query_size=self.hidden_size,
                          memory_size=self.hidden_size,
                          hidden_size=self.hidden_size,
                          mode=self.attn_mode,
                          return_attn_only=True)
            self.add_module("A_{}".format(hop), A)

        self.C = AttrProxy(self, "C_")
        self.A = AttrProxy(self, "A_")
        self.softmax = nn.Softmax(dim=1)
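Examples no. 3 and no. 5 index their hop-wise modules through an AttrProxy helper that is not shown here. A minimal sketch of the usual implementation of such a proxy (an assumption based on the common Mem2Seq-style helper, not code taken from this project):

    class AttrProxy(object):
        """Translate indexing into attribute access: proxy[2] -> getattr(module, prefix + '2')."""

        def __init__(self, module, prefix):
            self.module = module
            self.prefix = prefix

        def __getitem__(self, i):
            return getattr(self.module, self.prefix + str(i))

With this, self.C[hop] resolves to the submodule registered above as "C_{hop}".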
Example no. 4
0
    def __init__(self,
                 input_size,
                 hidden_size,
                 embedder=None,
                 num_layers=1,
                 attn_mode=None,
                 memory_size=None,
                 dropout=0.0):
        super(RNNDecoder, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.embedder = embedder
        self.num_layers = num_layers
        self.attn_mode = None if attn_mode == 'none' else attn_mode
        self.memory_size = memory_size or hidden_size
        self.dropout = dropout
        self.out_input_size = hidden_size
        self.class_num = 3
        self.rnn_input_size = input_size + hidden_size

        if self.attn_mode is not None:
            self.attention = Attention(query_size=self.hidden_size,
                                       memory_size=self.memory_size,
                                       hidden_size=self.memory_size,
                                       mode=self.attn_mode,
                                       project=False)

        self.rnn = nn.LSTM(input_size=self.rnn_input_size,
                           hidden_size=self.hidden_size,
                           num_layers=self.num_layers,
                           dropout=self.dropout if self.num_layers > 1 else 0,
                           batch_first=True)

        self.fc1 = nn.Linear(2 * self.hidden_size, self.hidden_size)

        self.output_layer = nn.Sequential(
            nn.Linear(self.out_input_size, self.out_input_size // 2),
            nn.Dropout(p=self.dropout),
            nn.ReLU(),
            nn.Linear(self.out_input_size // 2, self.class_num),
            nn.Softmax(dim=-1),
        )

        self.fusion = nn.Sequential(
            nn.Linear(2 * self.hidden_size, 1),
            nn.Sigmoid(),
        )
Example no. 5
0
    def __init__(self, vocab, embedding_dim, hidden_size, hop,
                 dropout=0.0, num_layers=1, padding_idx=None, attn_mode=None):
        super(DecoderMemNN, self).__init__()
        self.num_vocab = vocab
        self.max_hops = hop
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.num_layers = num_layers
        self.padding_idx = padding_idx
        self.attn_mode = attn_mode

        self.rnn_input_size = self.embedding_dim
        self.out_input_size = self.hidden_size

        for hop in range(self.max_hops + 1):
            C = Embedder(self.num_vocab, embedding_dim,
                         padding_idx=self.padding_idx)
            C.weight.data.normal_(0, 0.1)
            self.add_module("C_{}".format(hop), C)
        self.C = AttrProxy(self, "C_")

        if self.attn_mode is not None:
            self.attention = Attention(query_size=self.hidden_size,
                                       memory_size=self.hidden_size,
                                       hidden_size=self.hidden_size,
                                       mode=self.attn_mode,
                                       project=False)
            self.rnn_input_size += self.hidden_size

        self.softmax = nn.Softmax(dim=1)
        self.log_softmax = nn.LogSoftmax(dim=-1)
        self.W = nn.Linear(self.embedding_dim, 1)
        self.W1 = nn.Linear(2 * self.embedding_dim, self.num_vocab)
        self.gru = nn.GRU(input_size=self.rnn_input_size,
                          hidden_size=self.embedding_dim,
                          num_layers=self.num_layers,
                          dropout=self.dropout if self.num_layers > 1 else 0,
                          batch_first=True)
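A quick construction sketch for Example no. 5 (sizes are hypothetical; Embedder, Attention and AttrProxy are assumed importable). Note that the GRU hidden size here is embedding_dim rather than hidden_size, so keeping the two equal is the simplest configuration:

    decoder = DecoderMemNN(vocab=20000,
                           embedding_dim=300,
                           hidden_size=300,   # kept equal to embedding_dim
                           hop=3,
                           dropout=0.2,
                           attn_mode='mlp')
    # Attention is enabled, so the GRU consumes the embedding plus an attended context.
    assert decoder.rnn_input_size == 300 + 300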
Example no. 6
0
    def __init__(self,
                 src_vocab_size,
                 tgt_vocab_size,
                 embed_size,
                 hidden_size,
                 padding_idx=None,
                 num_layers=1,
                 bidirectional=True,
                 attn_mode="mlp",
                 attn_hidden_size=None,
                 with_bridge=False,
                 tie_embedding=False,
                 dropout=0.0,
                 use_gpu=False,
                 use_bow=False,
                 use_kd=False,
                 use_dssm=False,
                 use_posterior=False,
                 weight_control=False,
                 use_pg=False,
                 use_gs=False,
                 concat=False,
                 pretrain_epoch=0):
        super(KnowledgeSeq2Seq, self).__init__()

        self.src_vocab_size = src_vocab_size
        self.tgt_vocab_size = tgt_vocab_size
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.padding_idx = padding_idx
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.attn_mode = attn_mode
        self.attn_hidden_size = attn_hidden_size
        self.with_bridge = with_bridge
        self.tie_embedding = tie_embedding
        self.dropout = dropout
        self.use_gpu = use_gpu
        self.use_bow = use_bow
        self.use_dssm = use_dssm
        self.weight_control = weight_control
        self.use_kd = use_kd
        self.use_pg = use_pg
        self.use_gs = use_gs
        self.use_posterior = use_posterior
        self.pretrain_epoch = pretrain_epoch
        self.baseline = 0

        enc_embedder = Embedder(num_embeddings=self.src_vocab_size,
                                embedding_dim=self.embed_size,
                                padding_idx=self.padding_idx)

        self.encoder = RNNEncoder(input_size=self.embed_size,
                                  hidden_size=self.hidden_size,
                                  embedder=enc_embedder,
                                  num_layers=self.num_layers,
                                  bidirectional=self.bidirectional,
                                  dropout=self.dropout)

        if self.with_bridge:
            self.bridge = nn.Sequential(
                nn.Linear(self.hidden_size, self.hidden_size), nn.Tanh())

        if self.tie_embedding:
            assert self.src_vocab_size == self.tgt_vocab_size
            dec_embedder = enc_embedder
            knowledge_embedder = enc_embedder
        else:
            dec_embedder = Embedder(num_embeddings=self.tgt_vocab_size,
                                    embedding_dim=self.embed_size,
                                    padding_idx=self.padding_idx)
            knowledge_embedder = Embedder(num_embeddings=self.tgt_vocab_size,
                                          embedding_dim=self.embed_size,
                                          padding_idx=self.padding_idx)

        self.knowledge_encoder = RNNEncoder(input_size=self.embed_size,
                                            hidden_size=self.hidden_size,
                                            embedder=knowledge_embedder,
                                            num_layers=self.num_layers,
                                            bidirectional=self.bidirectional,
                                            dropout=self.dropout)

        self.prior_attention = Attention(query_size=self.hidden_size,
                                         memory_size=self.hidden_size,
                                         hidden_size=self.hidden_size,
                                         mode="dot")

        self.posterior_attention = Attention(query_size=self.hidden_size,
                                             memory_size=self.hidden_size,
                                             hidden_size=self.hidden_size,
                                             mode="dot")

        self.decoder = RNNDecoder(input_size=self.embed_size,
                                  hidden_size=self.hidden_size,
                                  output_size=self.tgt_vocab_size,
                                  embedder=dec_embedder,
                                  num_layers=self.num_layers,
                                  attn_mode=self.attn_mode,
                                  memory_size=self.hidden_size,
                                  feature_size=None,
                                  dropout=self.dropout,
                                  concat=concat)
        self.log_softmax = nn.LogSoftmax(dim=-1)
        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()
        self.softplus = nn.Softplus()

        if self.use_bow:
            self.bow_output_layer = nn.Sequential(
                nn.Linear(in_features=self.hidden_size,
                          out_features=self.hidden_size), nn.Tanh(),
                nn.Linear(in_features=self.hidden_size,
                          out_features=self.tgt_vocab_size),
                nn.LogSoftmax(dim=-1))

        if self.use_dssm:
            self.dssm_project = nn.Linear(in_features=self.hidden_size,
                                          out_features=self.hidden_size)
            self.mse_loss = torch.nn.MSELoss(reduction='mean')

        if self.use_kd:
            self.knowledge_dropout = nn.Dropout()

        if self.padding_idx is not None:
            self.weight = torch.ones(self.tgt_vocab_size)
            self.weight[self.padding_idx] = 0
        else:
            self.weight = None
        self.nll_loss = NLLLoss(weight=self.weight,
                                ignore_index=self.padding_idx,
                                reduction='mean')
        self.kl_loss = torch.nn.KLDivLoss(reduction='mean')

        if self.use_gpu:
            self.cuda()
            self.weight = self.weight.cuda()
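The weight vector built above zeroes out the padding class so padded targets do not contribute to the loss. The NLLLoss used in this model looks like a project-specific wrapper; a small runnable sketch of the same weight / ignore_index combination with the standard torch.nn.NLLLoss:

    import torch
    import torch.nn as nn

    vocab_size, padding_idx = 10, 0
    weight = torch.ones(vocab_size)
    weight[padding_idx] = 0

    log_probs = torch.log_softmax(torch.randn(4, vocab_size), dim=-1)
    targets = torch.tensor([3, 5, padding_idx, 7])   # one padded position

    criterion = nn.NLLLoss(weight=weight, ignore_index=padding_idx, reduction='mean')
    loss = criterion(log_probs, targets)             # the padded target is ignored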
Example no. 7
0
    def __init__(self,
                 input_size,
                 hidden_size,
                 output_size,
                 embedder=None,
                 num_layers=1,
                 attn_mode=None,
                 attn_hidden_size=None,
                 memory_size=None,
                 feature_size=None,
                 dropout=0.0,
                 concat=False):
        super(RNNDecoder, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.embedder = embedder
        self.num_layers = num_layers
        self.attn_mode = None if attn_mode == 'none' else attn_mode
        self.attn_hidden_size = attn_hidden_size or hidden_size // 2
        self.memory_size = memory_size or hidden_size
        self.feature_size = feature_size
        self.dropout = dropout
        self.concat = concat

        self.rnn_input_size = self.input_size
        self.out_input_size = self.hidden_size
        self.cue_input_size = self.hidden_size

        if self.feature_size is not None:
            self.rnn_input_size += self.feature_size
            self.cue_input_size += self.feature_size

        if self.attn_mode is not None:
            self.attention = Attention(query_size=self.hidden_size,
                                       memory_size=self.memory_size,
                                       hidden_size=self.attn_hidden_size,
                                       mode=self.attn_mode,
                                       project=False)
            self.rnn_input_size += self.memory_size
            self.cue_input_size += self.memory_size
            self.out_input_size += self.memory_size

        self.rnn = nn.GRU(input_size=self.rnn_input_size,
                          hidden_size=self.hidden_size,
                          num_layers=self.num_layers,
                          dropout=self.dropout if self.num_layers > 1 else 0,
                          batch_first=True)

        self.cue_rnn = nn.GRU(
            input_size=self.cue_input_size,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            dropout=self.dropout if self.num_layers > 1 else 0,
            batch_first=True)

        self.fc1 = nn.Linear(self.hidden_size, self.hidden_size)
        self.fc2 = nn.Linear(self.hidden_size, self.hidden_size)
        if self.concat:
            self.fc3 = nn.Linear(self.hidden_size * 2, self.hidden_size)
        else:
            self.fc3 = nn.Linear(self.hidden_size * 2, 1)
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

        if self.out_input_size > self.hidden_size:
            self.output_layer = nn.Sequential(
                nn.Dropout(p=self.dropout),
                nn.Linear(self.out_input_size, self.hidden_size),
                nn.Linear(self.hidden_size, self.output_size),
                nn.LogSoftmax(dim=-1),
            )
        else:
            self.output_layer = nn.Sequential(
                nn.Dropout(p=self.dropout),
                nn.Linear(self.out_input_size, self.output_size),
                nn.LogSoftmax(dim=-1),
            )
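Which output head this decoder builds depends on whether attention pushes out_input_size past hidden_size. A hedged check under assumed sizes (the class and its Attention dependency are assumed importable):

    decoder = RNNDecoder(input_size=300,
                         hidden_size=512,
                         output_size=30000,
                         attn_mode='mlp',
                         memory_size=512,
                         dropout=0.2)
    # With attention, out_input_size = 512 + 512 > hidden_size, so the head is
    # Dropout -> Linear(1024, 512) -> Linear(512, 30000) -> LogSoftmax.
    assert decoder.out_input_size == 1024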
Example no. 8
0
    def __init__(
        self,
        input_size,
        hidden_size,
        topic_size,
        output_size,
        trans_mat,
        embedder=None,
        num_layers=1,
        attn_mode=None,
        attn_hidden_size=None,
        memory_size=None,
        feature_size=None,
        dropout=0.0,
        tgt_unk_idx=3,
        attention_channels='ST',
    ):
        super(RNNDecoder, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.topic_size = topic_size
        self.output_size = output_size
        self.trans_mat = trans_mat.transpose(0, 1)
        self.embedder = embedder
        self.num_layers = num_layers
        self.attn_mode = None if attn_mode == 'none' else attn_mode
        self.attn_hidden_size = attn_hidden_size or hidden_size // 2
        self.memory_size = memory_size or hidden_size
        self.feature_size = feature_size
        self.dropout = dropout
        self.tgt_unk_idx = tgt_unk_idx
        self.attention_channels = attention_channels

        self.rnn_input_size = self.input_size + self.memory_size * 2

        self.out_input_size = self.memory_size + self.hidden_size
        self.topic_input_size = self.hidden_size + self.memory_size

        self.soft_prob_layer = nn.Linear(self.hidden_size, 2)

        # Note: tgv_layer is redefined below once the number of attention
        # channels is known, so this first definition is overwritten.
        self.tgv_layer = nn.Sequential(
            nn.Linear(self.memory_size * 2, self.memory_size), nn.Tanh())
        self.tgv_fc = nn.Linear(self.memory_size * 2, self.memory_size)

        self.attention = Attention(query_size=self.hidden_size,
                                   memory_size=self.memory_size,
                                   hidden_size=self.attn_hidden_size,
                                   mode=self.attn_mode,
                                   project=False)

        self.rnn = nn.GRU(input_size=self.rnn_input_size,
                          hidden_size=self.hidden_size,
                          num_layers=self.num_layers,
                          dropout=self.dropout if self.num_layers > 1 else 0,
                          batch_first=True)

        attention_channels_length = len(self.attention_channels)
        self.tgv_layer = nn.Sequential(
            nn.Linear(self.memory_size * attention_channels_length,
                      self.memory_size), nn.Tanh())

        self.output_out_layer = nn.Sequential(
            nn.Dropout(p=self.dropout),
            nn.Linear(self.out_input_size, self.hidden_size),
            nn.Linear(self.hidden_size, self.output_size),
        )

        if self.topic_size is not None:
            self.topic_out_layer = nn.Sequential(
                nn.Dropout(p=self.dropout),
                nn.Linear(self.topic_input_size, self.hidden_size),
                nn.Linear(self.hidden_size, self.topic_size),
            )
        self.lsf = nn.LogSoftmax(dim=-1)
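In Example no. 8 the width of tgv_layer is driven by the number of attention channels. A construction sketch with placeholder sizes (the class, its Attention dependency and torch are assumed importable; trans_mat is only a dummy matrix here):

    import torch

    decoder = RNNDecoder(input_size=300,
                         hidden_size=512,
                         topic_size=2000,
                         output_size=30000,
                         trans_mat=torch.zeros(2000, 30000),
                         attn_mode='mlp',
                         attention_channels='ST')
    # len('ST') == 2, so tgv_layer maps memory_size * 2 -> memory_size,
    # and the GRU input is input_size + memory_size * 2.
    assert decoder.rnn_input_size == 300 + 512 * 2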
Example no. 9
0
    def __init__(self,
                 src_vocab_size,
                 tgt_vocab_size,
                 embed_size,
                 hidden_size,
                 padding_idx=None,
                 num_layers=1,
                 bidirectional=True,
                 attn_mode="mlp",
                 attn_hidden_size=None,
                 with_bridge=False,
                 tie_embedding=False,
                 dropout=0.0,
                 use_gpu=False,
                 use_bow=False,
                 use_kd=False,
                 use_dssm=False,
                 use_posterior=False,
                 weight_control=False,
                 use_pg=False,
                 use_gs=False,
                 concat=False,
                 pretrain_epoch=0):
        super(KnowledgeSeq2Seq, self).__init__()

        self.src_vocab_size = src_vocab_size
        self.tgt_vocab_size = tgt_vocab_size
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.padding_idx = padding_idx
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.attn_mode = attn_mode
        self.attn_hidden_size = attn_hidden_size
        self.with_bridge = with_bridge
        self.tie_embedding = tie_embedding
        self.dropout = dropout
        self.use_gpu = use_gpu
        self.use_bow = use_bow
        self.use_dssm = use_dssm
        self.weight_control = weight_control
        self.use_kd = use_kd
        self.use_pg = use_pg
        self.use_gs = use_gs
        self.use_posterior = use_posterior
        self.pretrain_epoch = pretrain_epoch
        self.baseline = 0
        bc = BertClient()
        # BertClient.encode returns a NumPy array of sentence vectors,
        # unlike the nn.Module-style Embedder commented out below.
        enc_embedder = bc.encode(['你好', '吃饭了么'])
        # enc_embedder = Embedder(num_embeddings=self.src_vocab_size,
        #                         embedding_dim=self.embed_size,
        #                         padding_idx=self.padding_idx)   # Embedder(30004, 300, padding_idx=0)

        self.encoder = RNNEncoder(input_size=self.embed_size,
                                  hidden_size=self.hidden_size,
                                  embedder=enc_embedder,
                                  num_layers=self.num_layers,
                                  bidirectional=self.bidirectional,
                                  dropout=self.dropout)

        if self.with_bridge:
            self.bridge = nn.Sequential(
                nn.Linear(self.hidden_size, self.hidden_size),
                nn.Tanh())  # A linear bridge layer; nn.Sequential chains it with the Tanh activation.

        if self.tie_embedding:
            assert self.src_vocab_size == self.tgt_vocab_size
            dec_embedder = enc_embedder
            knowledge_embedder = enc_embedder
        else:
            dec_embedder = Embedder(num_embeddings=self.tgt_vocab_size,
                                    embedding_dim=self.embed_size,
                                    padding_idx=self.padding_idx)
            knowledge_embedder = Embedder(num_embeddings=self.tgt_vocab_size,
                                          embedding_dim=self.embed_size,
                                          padding_idx=self.padding_idx)

        self.knowledge_encoder = RNNEncoder(input_size=self.embed_size,
                                            hidden_size=self.hidden_size,
                                            embedder=knowledge_embedder,
                                            num_layers=self.num_layers,
                                            bidirectional=self.bidirectional,
                                            dropout=self.dropout)

        self.prior_attention = Attention(query_size=self.hidden_size,
                                         memory_size=self.hidden_size,
                                         hidden_size=self.hidden_size,
                                         mode="dot")

        self.posterior_attention = Attention(query_size=self.hidden_size,
                                             memory_size=self.hidden_size,
                                             hidden_size=self.hidden_size,
                                             mode="dot")

        self.decoder = RNNDecoder(input_size=self.embed_size,
                                  hidden_size=self.hidden_size,
                                  output_size=self.tgt_vocab_size,
                                  embedder=dec_embedder,
                                  num_layers=self.num_layers,
                                  attn_mode=self.attn_mode,
                                  memory_size=self.hidden_size,
                                  feature_size=None,
                                  dropout=self.dropout,
                                  concat=concat)
        self.log_softmax = nn.LogSoftmax(dim=-1)
        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()
        self.softplus = nn.Softplus()
        """
        Softplus(): a smooth approximation to the ReLU function; it can be used to constrain an output to always be positive.
        """
        if self.use_bow:
            self.bow_output_layer = nn.Sequential(
                nn.Linear(in_features=self.hidden_size,
                          out_features=self.hidden_size), nn.Tanh(),
                nn.Linear(in_features=self.hidden_size,
                          out_features=self.tgt_vocab_size),
                nn.LogSoftmax(dim=-1))

        if self.use_dssm:
            self.dssm_project = nn.Linear(in_features=self.hidden_size,
                                          out_features=self.hidden_size)
            self.mse_loss = torch.nn.MSELoss(reduction='mean')

        if self.use_kd:
            self.knowledge_dropout = nn.Dropout()

        if self.padding_idx is not None:
            self.weight = torch.ones(self.tgt_vocab_size)
            self.weight[self.padding_idx] = 0
        else:
            self.weight = None
        self.nll_loss = NLLLoss(
            weight=self.weight,
            ignore_index=self.padding_idx,
            reduction='mean')  # Negative log-likelihood: measures how far generated replies are from the ground-truth replies.
        self.kl_loss = torch.nn.KLDivLoss(
            reduction='mean'
        )  # KLDivLoss: guides selection of the related background knowledge that leads the conversation.

        if self.use_gpu:
            self.cuda()
            self.weight = self.weight.cuda()
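In this variant enc_embedder is the raw output of BertClient.encode, which in bert-serving-client is a NumPy array of sentence vectors rather than an nn.Module like the commented-out Embedder. If the goal is to feed fixed pretrained vectors through the usual lookup interface, one hedged option (an assumption, not what the original code does) is to wrap a precomputed token-embedding matrix with nn.Embedding.from_pretrained:

    import numpy as np
    import torch
    import torch.nn as nn

    # Hypothetical (vocab_size, embed_size) matrix of precomputed vectors,
    # e.g. gathered offline for every token in the vocabulary.
    pretrained = np.random.randn(30004, 300).astype('float32')
    enc_embedder = nn.Embedding.from_pretrained(torch.from_numpy(pretrained),
                                                freeze=True,
                                                padding_idx=0)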
Example no. 10
0
    def __init__(self,
                 emb_size=1024,
                 n_layer=12,
                 n_head=1,
                 voc_size=10005,
                 max_position_seq_len=1024,
                 sent_types=2,
                 num_labels=2,
                 dropout=0.3,
                 use_knowledge=False,
                 share_embedding=False,
                 padding_idx=0,
                 use_gpu=True):

        super(RetrievalModel, self).__init__()

        self.emb_size = emb_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.voc_size = voc_size
        self.max_position_seq_len = max_position_seq_len
        self.sent_types = sent_types
        self.num_labels = num_labels
        self.dropout = dropout
        self.use_knowledge = use_knowledge
        self.share_embedding = share_embedding
        self.padding_idx = padding_idx
        self.use_gpu = use_gpu

        # nn.ModuleList registers these embedding tables as submodules of this model.
        self.embeddings = nn.ModuleList([
            nn.Embedding(self.voc_size, self.emb_size, self.padding_idx),
            nn.Embedding(self.max_position_seq_len, self.emb_size),
            nn.Embedding(self.sent_types, self.emb_size)
        ])

        self.transformer_encoder = TransformerEncoder(self.n_layer,
                                                      self.emb_size,
                                                      self.n_head,
                                                      self.emb_size * 4,
                                                      self.dropout,
                                                      self.embeddings)
        '''
        Initialization only. Defines:
        - embedding
        - transformer: voc + position + sent_types
        - GRU
        - knowledge embedding
        '''
        if self.use_knowledge:
            if self.share_embedding:
                self.knowledge_embeddings = self.embeddings[0]
            else:
                self.knowledge_embeddings = \
                    nn.Embedding(self.voc_size, self.emb_size, self.padding_idx)

            self.rnn_encoder = RNNEncoder(rnn_type="GRU",
                                          bidirectional=True,
                                          num_layers=1,
                                          input_size=self.emb_size,
                                          hidden_size=self.emb_size,
                                          dropout=self.dropout,
                                          embeddings=self.knowledge_embeddings,
                                          use_bridge=True)

            self.attention = Attention(query_size=self.emb_size,
                                       memory_size=self.emb_size,
                                       hidden_size=self.emb_size,
                                       mode="dot",
                                       project=True)

        self.middle_linear = nn.Sequential(
            nn.Linear(self.emb_size, self.emb_size),
            nn.Tanh()
        )

        self.final_linear = nn.Sequential(
            nn.Dropout(p=self.dropout),
            nn.Linear(self.emb_size, self.num_labels)
        )

        self.softmax = nn.LogSoftmax(dim=-1)
        self.criterion = nn.NLLLoss()

        if self.use_gpu:
            self.cuda()
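The three embedding tables above follow the usual token / position / sentence-type scheme. A hedged sketch of how such inputs are typically embedded and summed before a transformer encoder (how this project's TransformerEncoder actually combines them is not shown here; model stands for an already constructed RetrievalModel with use_gpu=False):

    import torch

    tokens = torch.randint(0, 10005, (2, 16))                 # (batch, seq_len) token ids
    positions = torch.arange(16).unsqueeze(0).expand(2, -1)   # position ids
    segments = torch.zeros(2, 16, dtype=torch.long)           # sentence-type ids

    word_emb, pos_emb, type_emb = model.embeddings            # the three tables above
    x = word_emb(tokens) + pos_emb(positions) + type_emb(segments)  # (2, 16, emb_size)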
Example no. 11
0
    def __init__(self,
                 src_vocab_size,
                 tgt_vocab_size,
                 embed_size,
                 hidden_size,
                 padding_idx=None,
                 num_layers=1,
                 bidirectional=True,
                 attn_mode="mlp",
                 attn_hidden_size=None,
                 with_bridge=False,
                 tie_embedding=False,
                 dropout=0.0,
                 use_gpu=False,
                 use_dssm=False,
                 weight_control=False,
                 use_pg=False,
                 concat=False,
                 pretrain_epoch=0,
                 with_label=False):
        super(TwoStagePersonaSeq2Seq, self).__init__()

        self.src_vocab_size = src_vocab_size
        self.tgt_vocab_size = tgt_vocab_size
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.padding_idx = padding_idx
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.attn_mode = attn_mode
        self.attn_hidden_size = attn_hidden_size
        self.with_bridge = with_bridge
        self.tie_embedding = tie_embedding
        self.dropout = dropout
        self.use_gpu = use_gpu
        self.use_dssm = use_dssm
        self.weight_control = weight_control
        self.use_pg = use_pg
        self.pretrain_epoch = pretrain_epoch
        self.baseline = 0
        self.with_label = with_label
        self.task_id = 1

        enc_embedder = Embedder(num_embeddings=self.src_vocab_size,
                                embedding_dim=self.embed_size,
                                padding_idx=self.padding_idx)

        self.encoder = RNNEncoder(input_size=self.embed_size,
                                  hidden_size=self.hidden_size,
                                  embedder=enc_embedder,
                                  num_layers=self.num_layers,
                                  bidirectional=self.bidirectional,
                                  dropout=self.dropout)

        if self.with_bridge:
            self.bridge = nn.Sequential(
                nn.Linear(self.hidden_size, self.hidden_size), nn.Tanh())

        if self.tie_embedding:
            assert self.src_vocab_size == self.tgt_vocab_size
            dec_embedder = enc_embedder
            persona_embedder = enc_embedder
        else:
            dec_embedder = Embedder(num_embeddings=self.tgt_vocab_size,
                                    embedding_dim=self.embed_size,
                                    padding_idx=self.padding_idx)
            persona_embedder = Embedder(num_embeddings=self.tgt_vocab_size,
                                        embedding_dim=self.embed_size,
                                        padding_idx=self.padding_idx)

        self.persona_encoder = RNNEncoder(input_size=self.embed_size,
                                          hidden_size=self.hidden_size,
                                          embedder=persona_embedder,
                                          num_layers=self.num_layers,
                                          bidirectional=self.bidirectional,
                                          dropout=self.dropout)

        self.persona_attention = Attention(query_size=self.hidden_size,
                                           memory_size=self.hidden_size,
                                           hidden_size=self.hidden_size,
                                           mode="general")

        self.decoder = RNNDecoder(input_size=self.embed_size,
                                  hidden_size=self.hidden_size,
                                  output_size=self.tgt_vocab_size,
                                  embedder=dec_embedder,
                                  num_layers=self.num_layers,
                                  attn_mode=self.attn_mode,
                                  memory_size=self.hidden_size,
                                  feature_size=None,
                                  dropout=self.dropout,
                                  concat=concat,
                                  with_label=self.with_label)
        self.key_linear = nn.Linear(in_features=self.embed_size,
                                    out_features=self.hidden_size)

        if self.use_dssm:
            self.dssm_project = nn.Linear(in_features=self.hidden_size,
                                          out_features=self.hidden_size)
            self.mse_loss = torch.nn.MSELoss(reduction='mean')

        self.mse_loss = torch.nn.MSELoss(reduction='mean')

        if self.padding_idx is not None:
            self.weight = torch.ones(self.tgt_vocab_size)
            self.weight[self.padding_idx] = 0
        else:
            self.weight = None
        self.nll_loss = NLLLoss(weight=self.weight,
                                ignore_index=self.padding_idx,
                                reduction='mean')

        self.persona_loss = NLLLoss(weight=None, reduction='mean')
        self.eps = 1e-7

        if self.use_gpu:
            self.cuda()
            self.weight = self.weight.cuda()
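A minimal construction sketch for this two-stage persona model (all sizes are placeholders; Embedder, RNNEncoder, RNNDecoder, Attention and NLLLoss are assumed importable from the surrounding project):

    model = TwoStagePersonaSeq2Seq(src_vocab_size=30000,
                                   tgt_vocab_size=30000,
                                   embed_size=300,
                                   hidden_size=512,
                                   padding_idx=0,
                                   attn_mode='mlp',
                                   tie_embedding=True,   # requires matching src/tgt vocabularies
                                   dropout=0.3,
                                   use_gpu=False)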