class SNLIModel(nn.Module):
    """Sentence-pair entailment model for SNLI.

    Embeds premise and hypothesis token ids, encodes each with a shared
    latent-tree BinaryTreeLSTM, and classifies the pair with an MLP head.
    """

    def __init__(self, num_classes, num_words, word_dim, hidden_dim,
                 clf_hidden_dim, clf_num_layers, use_leaf_rnn,
                 intra_attention, use_batchnorm, dropout_prob,
                 bidirectional):
        super(SNLIModel, self).__init__()
        # Hyperparameters kept on the instance for later inspection.
        self.num_classes = num_classes
        self.word_dim = word_dim
        self.hidden_dim = hidden_dim
        self.clf_hidden_dim = clf_hidden_dim
        self.clf_num_layers = clf_num_layers
        self.use_leaf_rnn = use_leaf_rnn
        self.intra_attention = intra_attention
        self.use_batchnorm = use_batchnorm
        self.dropout_prob = dropout_prob
        self.bidirectional = bidirectional

        self.word_embedding = nn.Embedding(
            num_embeddings=num_words, embedding_dim=word_dim)
        self.encoder = BinaryTreeLSTM(
            word_dim=word_dim, hidden_dim=hidden_dim,
            use_leaf_rnn=use_leaf_rnn, intra_attention=intra_attention,
            gumbel_temperature=1, bidirectional=bidirectional)
        # A bidirectional encoder yields concatenated forward/backward
        # states, doubling the classifier's input width.
        clf_input_dim = 2 * hidden_dim if bidirectional else hidden_dim
        self.classifier = SNLIClassifier(
            num_classes=num_classes, input_dim=clf_input_dim,
            hidden_dim=clf_hidden_dim, num_layers=clf_num_layers,
            use_batchnorm=use_batchnorm, dropout_prob=dropout_prob)
        self.dropout = nn.Dropout(dropout_prob)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize embeddings (N(0, 0.01)) and submodule weights."""
        init.normal_(self.word_embedding.weight.data, mean=0, std=0.01)
        self.encoder.reset_parameters()
        self.classifier.reset_parameters()

    def forward(self, pre, pre_length, hyp, hyp_length):
        """Return classification logits for a premise/hypothesis batch.

        Args:
            pre, hyp: LongTensor token-id batches for premise / hypothesis.
            pre_length, hyp_length: per-sample sequence lengths.
        """
        pre_embed = self.dropout(self.word_embedding(pre))
        hyp_embed = self.dropout(self.word_embedding(hyp))
        pre_vec, _ = self.encoder(input=pre_embed, length=pre_length)
        hyp_vec, _ = self.encoder(input=hyp_embed, length=hyp_length)
        return self.classifier(pre=pre_vec, hyp=hyp_vec)
class SSTModel(nn.Module):
    """Single-sentence classifier for SST.

    Embeds token ids, encodes the sentence with a latent-tree
    BinaryTreeLSTM, and classifies the resulting sentence vector
    with an MLP head.
    """

    def __init__(self, num_classes, num_words, word_dim, hidden_dim,
                 clf_hidden_dim, clf_num_layers, use_leaf_rnn,
                 use_leaf_birnn, intra_attention, use_batchnorm,
                 dropout_prob):
        super(SSTModel, self).__init__()
        # Hyperparameters kept on the instance for later inspection.
        self.num_classes = num_classes
        self.word_dim = word_dim
        self.hidden_dim = hidden_dim
        self.clf_hidden_dim = clf_hidden_dim
        self.clf_num_layers = clf_num_layers
        self.use_leaf_rnn = use_leaf_rnn
        self.use_leaf_birnn = use_leaf_birnn
        self.intra_attention = intra_attention
        self.use_batchnorm = use_batchnorm
        self.dropout_prob = dropout_prob

        self.dropout = nn.Dropout(dropout_prob)
        self.word_embedding = nn.Embedding(
            num_embeddings=num_words, embedding_dim=word_dim)
        self.encoder = BinaryTreeLSTM(
            word_dim=word_dim, hidden_dim=hidden_dim,
            use_leaf_rnn=use_leaf_rnn, use_leaf_birnn=use_leaf_birnn,
            intra_attention=intra_attention, gumbel_temperature=1)
        self.classifier = SSTClassifier(
            num_classes=num_classes, input_dim=hidden_dim,
            hidden_dim=clf_hidden_dim, num_layers=clf_num_layers,
            use_batchnorm=use_batchnorm, dropout_prob=dropout_prob)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize embeddings (N(0, 0.01)) and submodule weights."""
        # FIX: was `init.normal` — deprecated alias removed in modern
        # PyTorch; use the in-place `init.normal_` as SNLIModel does.
        init.normal_(self.word_embedding.weight.data, mean=0, std=0.01)
        self.encoder.reset_parameters()
        self.classifier.reset_parameters()

    def forward(self, words, length):
        """Return classification logits for a batch of sentences.

        Args:
            words: LongTensor batch of token ids.
            length: per-sample sequence lengths.
        """
        words_embed = self.word_embedding(words)
        words_embed = self.dropout(words_embed)
        sentence_vector, _ = self.encoder(input=words_embed, length=length)
        logits = self.classifier(sentence_vector)
        return logits