Ejemplo n.º 1
0
    def _lstm_test(self, layers, bidirectional, initial_state, packed_sequence,
                   dropout):
        """Build an LSTM from the given options and check it with run_model_test.

        Args:
            layers: number of stacked LSTM layers, passed to
                LstmDiscardingCellState.
            bidirectional: whether the LSTM is bidirectional; doubles the
                leading dimension of the random initial (h0, c0) state.
            initial_state: when truthy, random ``(h0, c0)`` tensors are
                appended to the model inputs.
            packed_sequence: when truthy, the model is wrapped in
                RnnModelWithPackedSequence and the per-example sequence
                lengths are appended as an extra (int) input.
            dropout: dropout value forwarded to LstmDiscardingCellState.
        """
        model = LstmDiscardingCellState(RNN_INPUT_SIZE,
                                        RNN_HIDDEN_SIZE,
                                        layers,
                                        bidirectional=bidirectional,
                                        dropout=dropout)
        if packed_sequence:
            model = RnnModelWithPackedSequence(model)

        # One random length in [1, RNN_SEQUENCE_LENGTH] per batch element
        # (randint's upper bound is exclusive).
        seq_lengths = np.random.randint(1,
                                        RNN_SEQUENCE_LENGTH + 1,
                                        size=RNN_BATCH_SIZE)
        # Convert numpy ints to Python ints and order longest-first;
        # presumably the packed-sequence wrapper expects decreasing lengths
        # (classic pack_padded_sequence contract) -- confirm.
        seq_lengths = sorted(map(int, seq_lengths), reverse=True)

        sequences = [
            Variable(torch.randn(length, RNN_INPUT_SIZE))
            for length in seq_lengths
        ]
        # Pad the variable-length sequences into a single tensor
        # (pad_sequence defaults to batch_first=False, i.e. time-major).
        inputs = [rnn_utils.pad_sequence(sequences)]

        if initial_state:
            directions = 2 if bidirectional else 1
            state_shape = (directions * layers, RNN_BATCH_SIZE,
                           RNN_HIDDEN_SIZE)
            h0 = Variable(torch.randn(*state_shape))
            c0 = Variable(torch.randn(*state_shape))
            inputs.append((h0, c0))
        if packed_sequence:
            inputs.append(Variable(torch.IntTensor(seq_lengths)))

        # A single input is passed bare; several inputs are passed as a tuple.
        model_input = inputs[0] if len(inputs) == 1 else tuple(inputs)
        self.run_model_test(model,
                            train=False,
                            batch_size=RNN_BATCH_SIZE,
                            input=model_input,
                            use_gpu=False)
Ejemplo n.º 2
0
 def test_lstm(self):
     """Check the ONNX export of a 2-layer LstmDiscardingCellState fed an explicit (h0, c0)."""
     seq_len, batch, input_size = 5, 3, 10
     hidden_size, num_layers = 20, 2
     lstm = LstmDiscardingCellState(input_size, hidden_size, num_layers)
     # Draw order (input, then h0, then c0) is kept stable for reproducibility.
     x = Variable(torch.randn(seq_len, batch, input_size))
     state_shape = (num_layers, batch, hidden_size)
     h0 = Variable(torch.randn(*state_shape))
     c0 = Variable(torch.randn(*state_shape))
     self.assertONNX(lstm, x, (h0, c0))