Example 1
    def run_inference(self, sources, targets, teach_ratio):
        loss = 0
        encoder_hidden = self.encoder.initHidden()
        encoder_length = sources.size()[0]
        encoder_outputs, encoder_hidden = self.encoder(sources, encoder_hidden)

        decoder_hidden = encoder_hidden
        decoder_length = targets.size()[0]
        decoder_input = basic_variable([self.start_token])
        decoder_context = smart_variable(
            torch.zeros(1, 1, self.decoder.hidden_size))
        # visual = torch.zeros(encoder_length, decoder_length)
        predictions = []

        for di in range(decoder_length):
            # Decide at each step whether to feed the ground-truth token (teacher forcing)
            use_teacher_forcing = random.random() < teach_ratio
            decoder_output, decoder_context, decoder_hidden, attn_weights = self.decoder(
                decoder_input, decoder_context, decoder_hidden,
                encoder_outputs)

            # visual[:, di] = attn_weights.squeeze(0).squeeze(0).cpu().data
            loss += self.criterion(decoder_output, targets[di])

            if use_teacher_forcing:
                decoder_input = targets[di]
            else:  # Use the predicted word as the next input
                topv, topi = decoder_output.data.topk(1)
                ni = topi[0][0]
                predictions.append(ni)
                if ni == self.end_token:
                    break
                decoder_input = smart_variable([ni], "list")

        return loss, predictions
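
All of these snippets rely on two project-specific helpers, smart_variable and basic_variable, whose definitions are not shown. Below is a minimal sketch of what they plausibly do, assuming the pre-0.4 PyTorch Variable API that the examples use; the dtype handling and the CUDA check are assumptions, not the repository's actual code:

import torch
from torch.autograd import Variable

def smart_variable(data, dtype="tensor"):
    # Hypothetical helper: wrap data in an autograd Variable (old PyTorch API)
    # and move it to the GPU when one is available.
    if dtype == "list":
        data = torch.LongTensor(data)
    elif dtype == "var":
        data = data.data  # unwrap an existing Variable before re-wrapping
    var = Variable(data)
    return var.cuda() if torch.cuda.is_available() else var

def basic_variable(token_ids):
    # Hypothetical helper: a LongTensor Variable built from a plain list of ids,
    # e.g. basic_variable([self.start_token]) above.
    var = Variable(torch.LongTensor(token_ids))
    return var.cuda() if torch.cuda.is_available() else var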
Example 2
 def forward(self, decoder_hidden, encoder_outputs):
     # Create a variable to hold one attention score per encoder output
     seq_len = len(encoder_outputs)
     attn_scores = smart_variable(torch.zeros(seq_len))     # S (encoder seq_len)
     # Score each encoder output against the current decoder hidden state
     for i in range(seq_len):
         attn_scores[i] = self.score(decoder_hidden, encoder_outputs[i]).squeeze(0)
     # Normalize scores into weights in range 0 to 1, resize to 1 x 1 x S
     attn_weights = F.softmax(attn_scores, dim=0).unsqueeze(0).unsqueeze(0)
     return attn_weights
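
The scoring function self.score is defined elsewhere in the module. A minimal sketch of a Luong-style "general" score that would match the call above; the self.attn_W linear layer and the choice of variant (dot, general, or concat) are assumptions:

 def score(self, decoder_hidden, encoder_output):
     # decoder_hidden: 1 x N, encoder_output: 1 x N
     # Assumes self.attn_W = nn.Linear(hidden_size, hidden_size) defined in __init__
     energy = self.attn_W(encoder_output)                  # 1 x N
     return decoder_hidden.mm(energy.transpose(0, 1))      # (1 x N)(N x 1) = 1 x 1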
Example 3
    def forward(self, word_input, last_context, prev_hidden, encoder_outputs):
        # Get the embedding of the current input word (i.e. last output word)
        embedded = self.embedding(word_input).view(1, 1, -1)        # 1 x 1 x N
        embedded = self.dropout(embedded)
        embedded = smart_variable(embedded, "var")
        # Combine input word embedding and previous hidden state, run through RNN
        rnn_input = torch.cat((embedded, last_context), dim=2)
        # pdb.set_trace()
        rnn_output, current_hidden = self.gru(rnn_input, prev_hidden)

        # Calculate attention from current RNN state and encoder outputs, then apply
        # Drop first dimension to line up with single encoder_output
        decoder_hidden = current_hidden.squeeze(0)    # (1 x 1 x N) --> 1 x N
        attn_weights = self.attn(decoder_hidden, encoder_outputs)  # 1 x 1 x S
        # (1 x 1 x S) bmm (1 x S x N) = 1 x 1 x N, where S is the encoder seq_len
        attn_context = attn_weights.bmm(encoder_outputs.transpose(0, 1))

        # Predict next word using the decoder hidden state and context vector
        joined_hidden = torch.cat((current_hidden, attn_context), dim=2).squeeze(0)
        output = F.log_softmax(self.out(joined_hidden), dim=1)  # (1x2N) (2NxV) = 1xV
        return output, attn_context, current_hidden, attn_weights
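
For context, here is a plausible constructor for the attentional decoder in Example 3. The class name, layer names, and sizes are inferred from how forward() uses them (the GRU and the output layer both see 2 x hidden_size after concatenation); this is an assumption, not the original __init__:

import torch.nn as nn

class AttnDecoderRNN(nn.Module):
    def __init__(self, vocab_size, hidden_size, dropout_p=0.1):
        super(AttnDecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.dropout = nn.Dropout(dropout_p)
        # The GRU consumes the word embedding concatenated with the previous context (2N)
        self.gru = nn.GRU(hidden_size * 2, hidden_size)
        self.attn = Attn(hidden_size)   # attention module whose forward is Example 2 (name assumed)
        # The output layer sees the hidden state concatenated with the attention context (2N x V)
        self.out = nn.Linear(hidden_size * 2, vocab_size)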
Example 4
 def initHidden(self):
     # LSTM state is a (hidden, cell) pair, both zero-initialized: 1 x 1 x hidden_size
     hidden = smart_variable(torch.zeros(1, 1, self.hidden_size))
     cell = smart_variable(torch.zeros(1, 1, self.hidden_size))
     return (hidden, cell)
Example 5
 def forward(self, word_inputs, hidden):
     # Embed a single source token: 1 x 1 x N
     embedded = self.embedding(word_inputs).view(1, 1, -1)
     embedded = smart_variable(embedded, "var")
     output, hidden = self.lstm(embedded, hidden)
     return output, hidden
Example 6
 def initHidden(self):
     # GRU state is a single zero tensor: 1 x 1 x hidden_size
     return smart_variable(torch.zeros(1, 1, self.hidden_size))
Example 7
 def forward(self, word_inputs, hidden):
     seq_len = len(word_inputs)
     # Embed the whole source sequence at once: seq_len x 1 x N
     embedded = self.embedding(word_inputs).view(seq_len, 1, -1)
     embedded = smart_variable(embedded, "var")
     output, hidden = self.gru(embedded, hidden)
     return output, hidden
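
Examples 6 and 7 together form the GRU encoder that run_inference in Example 1 calls. A possible constructor, with the class name and sizes assumed rather than taken from the repository:

import torch.nn as nn

class GruEncoder(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super(GruEncoder, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)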
Example 8
 def forward(self, word_input, prev_hidden):
     # Embed the previous output word: 1 x 1 x N
     embedded = self.embedding(word_input).view(1, 1, -1)
     embedded = smart_variable(embedded, "var")
     rnn_output, current_hidden = self.lstm(embedded, prev_hidden)
     # Project the LSTM output to the vocabulary and take log-probabilities
     output = F.log_softmax(self.out(rnn_output[0]), dim=1)
     return output, current_hidden
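
Examples 4, 5, and 8 are the LSTM counterpart: an encoder fed one token at a time and a plain, attention-free decoder. A rough greedy-decoding sketch of how they might fit together; encoder, decoder, source_ids, start_token, end_token, and max_length are stand-in names, not symbols from the repository:

hidden = encoder.initHidden()                    # (h0, c0), as in Example 4
for token in source_ids:                         # feed the source one token at a time
    _, hidden = encoder(token, hidden)

decoder_input = basic_variable([start_token])
decoder_hidden = hidden                          # seed the decoder with the encoder state
predictions = []
for _ in range(max_length):
    output, decoder_hidden = decoder(decoder_input, decoder_hidden)
    topv, topi = output.data.topk(1)             # greedy choice of the next word
    next_id = topi[0][0]
    if next_id == end_token:
        break
    predictions.append(next_id)
    decoder_input = smart_variable([next_id], "list")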