class RNNDotJudgeNet(nn.Module):

    def __init__(self, word_vec_dim, bidir=True, rnn_cell='LSTM'):
        super().__init__()
        self.trainable = True
        self.word_vec_dim = word_vec_dim
        self.hidden_state_size = word_vec_dim
        self.encoder = EncoderRNN(self.word_vec_dim,
                                  self.word_vec_dim,
                                  bidir=bidir,
                                  rnn_cell=rnn_cell)
        self.encoder.apply(util.weight_init)

    def forward(self, Ks: torch.Tensor, Cs: torch.Tensor, *args):
        """
        :param Ks, keywords used to expand: (batch_size, n_keys, word_vector_dim)
        :param Cs, candidates searched by Ks: (batch_size, n_candidates, word_vector_dim)
        :return: scores as good / bad candidates: (batch_size, n_candidates, 2)
        """
        batch_size = Ks.shape[0]
        n_candidates = Cs.shape[1]
        # Zero vector used to separate keywords from candidates; keep it on the
        # same device as the inputs.
        sep = torch.zeros(batch_size, 1, self.word_vec_dim, device=Ks.device)
        query_string = torch.cat(
            [Ks, sep, Cs],
            dim=1)  # (batch_size, n_keys + 1 + n_candidates, word_vector_dim)
        query_string_transposed = query_string.transpose(
            0, 1)  # (n_keys + 1 + n_candidates, batch_size, word_vector_dim)
        # Every sequence in the batch has the same length: n_keys + 1 + n_candidates.
        lengths = [query_string_transposed.shape[0]]
        encoder_outputs, encoder_states = self.encoder(
            query_string_transposed,
            torch.tensor(lengths).long().cpu())
        # encoder_outputs: (n_keys + 1 + n_candidates, batch_size, hidden_state_size)
        # encoder_states[0]: (n_layers=1, batch_size, hidden_state_size)
        # Sum over layers / directions and reshape for bmm:
        # (batch_size, hidden_state_size, 1)
        encoder_hidden = torch.sum(encoder_states[0],
                                   dim=0).view(batch_size,
                                               self.hidden_state_size, 1)
        products = torch.bmm(Cs, encoder_hidden)  # (batch_size, n_candidates, 1)
        rest = -1 * products
        # Two-way logits per candidate: [score, -score].
        result = torch.cat([products, rest], dim=-1)  # (batch_size, n_candidates, 2)
        return result
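
# A minimal smoke-test sketch for RNNDotJudgeNet (not part of the original
# module). It assumes EncoderRNN accepts a (seq_len, batch, dim) input plus a
# lengths tensor, as used in forward() above; random vectors stand in for real
# word embeddings, and every size below is illustrative only.
def _demo_rnn_dot_judge_net():
    word_vec_dim = 64
    net = RNNDotJudgeNet(word_vec_dim, bidir=True, rnn_cell='LSTM')
    Ks = torch.randn(2, 3, word_vec_dim)  # 2 queries, 3 keywords each
    Cs = torch.randn(2, 5, word_vec_dim)  # 5 candidates per query
    scores = net(Ks, Cs)                  # (2, 5, 2): [score, -score] per candidate
    # Softmax over the last dim turns the logits into good/bad probabilities;
    # taking index 0 as "good" follows the [products, -products] ordering above.
    good_probs = torch.softmax(scores, dim=-1)[..., 0]
    print(good_probs.shape)               # torch.Size([2, 5])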
class RNNJudgeNet(nn.Module):
    """
    keys: (n_keys, word_vec_dim)
    candidates: (n_candidates, word_vec_dim)
    query = [keys; 0; candidates]: (n_keys + 1 + n_candidates, word_vec_dim),
        where 0 is used to separate keys and candidates
    result = RNN-Encoder-Decoder-with-Attention(query): (n_candidates, 2),
        which indicates the probability of the i-th candidate being good
    """

    def __init__(
            self,
            word_vec_dim,
            hidden_state_size,
            bidir=True,
            rnn_cell='LSTM',
    ):
        super().__init__()
        self.trainable = True
        self.word_vec_dim = word_vec_dim
        self.hidden_state_size = hidden_state_size
        self.encoder = EncoderRNN(self.word_vec_dim,
                                  self.hidden_state_size,
                                  bidir=bidir,
                                  rnn_cell=rnn_cell)
        self.decoder = AttnDecoderRNN(self.word_vec_dim,
                                      self.hidden_state_size,
                                      2,
                                      rnn_cell=rnn_cell)
        self.encoder.apply(util.weight_init)
        self.decoder.apply(util.weight_init)

    def forward(self, Ks: torch.Tensor, Cs: torch.Tensor, *args):
        """
        :param Ks, keywords used to expand: (batch_size, n_keys, word_vector_dim)
        :param Cs, candidates searched by Ks: (batch_size, n_candidates, word_vector_dim)
        :return: probs as good / bad candidates: (batch_size, n_candidates, 2)
        """
        batch_size = Ks.shape[0]
        n_candidates = Cs.shape[1]
        # Zero vector used to separate keywords from candidates; keep it on the
        # same device as the inputs.
        sep = torch.zeros(batch_size, 1, self.word_vec_dim, device=Ks.device)
        query_string = torch.cat(
            [Ks, sep, Cs],
            dim=1)  # (batch_size, n_keys + 1 + n_candidates, word_vector_dim)
        query_string_transposed = query_string.transpose(
            0, 1)  # (n_keys + 1 + n_candidates, batch_size, word_vector_dim)
        # Every sequence in the batch has the same length: n_keys + 1 + n_candidates.
        lengths = [query_string_transposed.shape[0]]
        encoder_outputs, encoder_hidden = self.encoder(
            query_string_transposed,
            torch.tensor(lengths).long().cpu())
        # encoder_outputs: (n_keys + 1 + n_candidates, batch_size, hidden_state_size)
        # encoder_hidden: (n_layers=1, batch_size, hidden_state_size)
        decoder_hidden = encoder_hidden
        answers = []
        for i in range(n_candidates):
            # logger.debug(f"decoder_hidden: {decoder_hidden[:, :, 0:10]}")
            # Feed the i-th candidate as the decoder input:
            # (seq_len=1, batch_size, word_vector_dim)
            decoder_input = Cs[:, i].unsqueeze(0)
            # decoder_hidden: (1, batch_size, hidden_state_size);
            # note that "batch" here no longer refers to the earlier one.
            output, decoder_hidden, _ = self.decoder(decoder_input,
                                                     decoder_hidden,
                                                     encoder_outputs)
            # output: (1, batch_size, 2)
            # decoder_hidden: (n_layers=1, batch_size, hidden_state_size)
            answers.append(output)
        probs = torch.cat(answers, dim=0)  # (n_candidates, batch_size, 2)
        probs = probs.transpose(0, 1)  # (batch_size, n_candidates, 2)
        # probs = torch.softmax(probs, dim=-1)
        return probs
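
# A minimal usage sketch for RNNJudgeNet (not part of the original module).
# The forward pass returns unnormalised 2-way scores per candidate (the final
# softmax is commented out above), so a standard cross-entropy loss over the
# flattened candidates can train it. EncoderRNN / AttnDecoderRNN are assumed
# to follow the call signatures used in forward(); the dummy labels and the
# "index 1 = good" convention are illustrative assumptions only.
def _demo_rnn_judge_net():
    word_vec_dim, hidden_size = 64, 128
    net = RNNJudgeNet(word_vec_dim, hidden_size, bidir=True, rnn_cell='GRU')
    Ks = torch.randn(2, 3, word_vec_dim)  # 2 queries, 3 keywords each
    Cs = torch.randn(2, 5, word_vec_dim)  # 5 candidates per query
    labels = torch.randint(0, 2, (2, 5))  # dummy good/bad labels per candidate
    logits = net(Ks, Cs)                  # (2, 5, 2)
    loss = nn.functional.cross_entropy(logits.reshape(-1, 2),
                                       labels.reshape(-1))
    loss.backward()                       # gradients flow into encoder and decoder
    print(float(loss))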