Example #1
from torch import nn
from torch.nn import init

# BinaryTreeLSTM and SNLIClassifier are assumed to be defined elsewhere
# in this repository.
class SNLIModel(nn.Module):
    def __init__(self, num_classes, num_words, word_dim, hidden_dim,
                 clf_hidden_dim, clf_num_layers, use_leaf_rnn, intra_attention,
                 use_batchnorm, dropout_prob, bidirectional):
        super(SNLIModel, self).__init__()
        self.num_classes = num_classes
        self.word_dim = word_dim
        self.hidden_dim = hidden_dim
        self.clf_hidden_dim = clf_hidden_dim
        self.clf_num_layers = clf_num_layers
        self.use_leaf_rnn = use_leaf_rnn
        self.intra_attention = intra_attention
        self.use_batchnorm = use_batchnorm
        self.dropout_prob = dropout_prob
        self.bidirectional = bidirectional

        self.word_embedding = nn.Embedding(num_embeddings=num_words,
                                           embedding_dim=word_dim)
        # Tree-LSTM encoder; gumbel_temperature controls the Gumbel-softmax
        # used when sampling binary tree structures (fixed at 1 here).
        self.encoder = BinaryTreeLSTM(word_dim=word_dim,
                                      hidden_dim=hidden_dim,
                                      use_leaf_rnn=use_leaf_rnn,
                                      intra_attention=intra_attention,
                                      gumbel_temperature=1,
                                      bidirectional=bidirectional)
        if bidirectional:
            # A bidirectional encoder concatenates forward and backward
            # states, so the classifier input is twice the hidden size.
            clf_input_dim = 2 * hidden_dim
        else:
            clf_input_dim = hidden_dim
        self.classifier = SNLIClassifier(num_classes=num_classes,
                                         input_dim=clf_input_dim,
                                         hidden_dim=clf_hidden_dim,
                                         num_layers=clf_num_layers,
                                         use_batchnorm=use_batchnorm,
                                         dropout_prob=dropout_prob)
        self.dropout = nn.Dropout(dropout_prob)
        self.reset_parameters()

    def reset_parameters(self):
        init.normal_(self.word_embedding.weight.data, mean=0, std=0.01)
        self.encoder.reset_parameters()
        self.classifier.reset_parameters()

    def forward(self, pre, pre_length, hyp, hyp_length):
        # Embed premise and hypothesis tokens, then apply dropout.
        pre_embeddings = self.word_embedding(pre)
        hyp_embeddings = self.word_embedding(hyp)
        pre_embeddings = self.dropout(pre_embeddings)
        hyp_embeddings = self.dropout(hyp_embeddings)
        # Encode each sentence into a single vector with the Tree-LSTM.
        pre_h, _ = self.encoder(input=pre_embeddings, length=pre_length)
        hyp_h, _ = self.encoder(input=hyp_embeddings, length=hyp_length)
        # Classify the sentence-pair representation into entailment classes.
        logits = self.classifier(pre=pre_h, hyp=hyp_h)
        return logits
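
A minimal usage sketch for SNLIModel (assumptions: the constructor signature above; all hyperparameter values and the random token-ID batches below are illustrative dummy data, not real SNLI input):

import torch

# Hypothetical hyperparameters, for illustration only.
model = SNLIModel(num_classes=3, num_words=10000, word_dim=300,
                  hidden_dim=300, clf_hidden_dim=1024, clf_num_layers=1,
                  use_leaf_rnn=True, intra_attention=False,
                  use_batchnorm=True, dropout_prob=0.1, bidirectional=False)

batch_size, max_len = 4, 12
pre = torch.randint(0, 10000, (batch_size, max_len))  # dummy premise token IDs
hyp = torch.randint(0, 10000, (batch_size, max_len))  # dummy hypothesis token IDs
pre_length = torch.full((batch_size,), max_len, dtype=torch.long)
hyp_length = torch.full((batch_size,), max_len, dtype=torch.long)

logits = model(pre=pre, pre_length=pre_length, hyp=hyp, hyp_length=hyp_length)
print(logits.shape)  # expected: (4, 3), one score per SNLI class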
Example #2
from torch import nn
from torch.nn import init

# BinaryTreeLSTM and SSTClassifier are assumed to be defined elsewhere
# in this repository.
class SSTModel(nn.Module):
    def __init__(self, num_classes, num_words, word_dim, hidden_dim,
                 clf_hidden_dim, clf_num_layers, use_leaf_rnn, use_leaf_birnn,
                 intra_attention, use_batchnorm, dropout_prob):
        super(SSTModel, self).__init__()
        self.num_classes = num_classes
        self.word_dim = word_dim
        self.hidden_dim = hidden_dim
        self.clf_hidden_dim = clf_hidden_dim
        self.clf_num_layers = clf_num_layers
        self.use_leaf_rnn = use_leaf_rnn
        self.use_leaf_birnn = use_leaf_birnn
        self.intra_attention = intra_attention
        self.use_batchnorm = use_batchnorm
        self.dropout_prob = dropout_prob

        self.dropout = nn.Dropout(dropout_prob)
        self.word_embedding = nn.Embedding(num_embeddings=num_words,
                                           embedding_dim=word_dim)
        # Tree-LSTM encoder; gumbel_temperature controls the Gumbel-softmax
        # used when sampling binary tree structures (fixed at 1 here).
        self.encoder = BinaryTreeLSTM(word_dim=word_dim,
                                      hidden_dim=hidden_dim,
                                      use_leaf_rnn=use_leaf_rnn,
                                      use_leaf_birnn=use_leaf_birnn,
                                      intra_attention=intra_attention,
                                      gumbel_temperature=1)
        self.classifier = SSTClassifier(num_classes=num_classes,
                                        input_dim=hidden_dim,
                                        hidden_dim=clf_hidden_dim,
                                        num_layers=clf_num_layers,
                                        use_batchnorm=use_batchnorm,
                                        dropout_prob=dropout_prob)
        self.reset_parameters()

    def reset_parameters(self):
        init.normal_(self.word_embedding.weight.data, mean=0, std=0.01)
        self.encoder.reset_parameters()
        self.classifier.reset_parameters()

    def forward(self, words, length):
        # Embed tokens and apply dropout before encoding.
        words_embed = self.word_embedding(words)
        words_embed = self.dropout(words_embed)
        # Encode the sentence into a single vector with the Tree-LSTM.
        sentence_vector, _ = self.encoder(input=words_embed, length=length)
        # Predict sentiment classes from the sentence representation.
        logits = self.classifier(sentence_vector)
        return logits
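
A minimal usage sketch for SSTModel (assumptions: the constructor signature above; the hyperparameter values and the random token-ID batch below are illustrative dummy data, not real SST input):

import torch

# Hypothetical hyperparameters, for illustration only.
model = SSTModel(num_classes=5, num_words=10000, word_dim=300,
                 hidden_dim=300, clf_hidden_dim=300, clf_num_layers=1,
                 use_leaf_rnn=True, use_leaf_birnn=False,
                 intra_attention=False, use_batchnorm=False,
                 dropout_prob=0.5)

batch_size, max_len = 4, 10
words = torch.randint(0, 10000, (batch_size, max_len))  # dummy token IDs
length = torch.full((batch_size,), max_len, dtype=torch.long)

logits = model(words=words, length=length)
print(logits.shape)  # expected: (4, 5), one score per sentiment class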