Example #1
0
def test_forward_lstm(input_size, hidden_size, num_layers, bidirectional):
    """Check DistillerLSTM and nn.LSTM agree in both conversion directions.

    First an ``nn.LSTM`` is converted with ``DistillerLSTM.from_pytorch_impl``;
    then a fresh ``DistillerLSTM`` is converted with ``to_pytorch_impl``. In
    each case both modules are run on the same random input and initial
    hidden state, and their outputs are compared with ``assert_output``.
    """
    # Test conversion from pytorch implementation
    lstm = nn.LSTM(input_size,
                   hidden_size,
                   num_layers,
                   bidirectional=bidirectional)
    lstm_man = DistillerLSTM.from_pytorch_impl(lstm)
    _assert_forward_match(lstm, lstm_man, input_size)

    # Test conversion to pytorch implementation
    lstm_man = DistillerLSTM(input_size,
                             hidden_size,
                             num_layers,
                             bidirectional=bidirectional)
    lstm = lstm_man.to_pytorch_impl()
    _assert_forward_match(lstm, lstm_man, input_size)


def _assert_forward_match(lstm, lstm_man, input_size):
    """Run both modules in eval mode on one random batch and compare outputs.

    Uses the module-level BATCH_SIZE and SEQUENCE_SIZE constants. The input
    tensor is created with shape (SEQUENCE_SIZE, BATCH_SIZE, input_size), and
    the initial hidden state comes from ``lstm_man.init_hidden``. Both modules
    receive the same ``x`` and ``h``.
    """
    # Switch to eval mode so train-only behavior cannot make the two forward
    # passes diverge.
    lstm.eval()
    lstm_man.eval()

    h = lstm_man.init_hidden(BATCH_SIZE)
    x = torch.rand(SEQUENCE_SIZE, BATCH_SIZE, input_size)

    out_true = lstm(x, h)
    out_pred = lstm_man(x, h)
    assert_output(out_true, out_pred)
Example #2
0
def test_conversion():
    """Round-trip Distiller LSTM modules through their PyTorch counterparts.

    A ``DistillerLSTMCell`` and a 2-layer ``DistillerLSTM`` are each converted
    with ``to_pytorch_impl`` and back with ``from_pytorch_impl``. The gate
    weights and biases of every cell must be unchanged by the round trip.
    """
    # Single cell: Distiller -> PyTorch -> Distiller.
    lc_man = DistillerLSTMCell(3, 5)
    lc_pth = lc_man.to_pytorch_impl()
    lc_man1 = DistillerLSTMCell.from_pytorch_impl(lc_pth)
    # Compare biases as well as weights, matching the per-layer check below.
    _assert_cell_params_equal(lc_man, lc_man1)

    # Multi-layer LSTM: Distiller -> PyTorch -> Distiller.
    l_man = DistillerLSTM(3, 5, 2)
    l_pth = l_man.to_pytorch_impl()
    l_man1 = DistillerLSTM.from_pytorch_impl(l_pth)

    for i in range(l_man.num_layers):
        _assert_cell_params_equal(l_man.cells[i], l_man1.cells[i])


def _assert_cell_params_equal(cell_a, cell_b):
    """Assert two Distiller LSTM cells have identical gate weights and biases.

    Compares ``weight`` and ``bias`` of both ``fc_gate_x`` and ``fc_gate_h``.
    """
    for gate in ("fc_gate_x", "fc_gate_h"):
        fc_a = getattr(cell_a, gate)
        fc_b = getattr(cell_b, gate)
        # torch.equal also requires equal shapes; (a == b).all() broadcasts
        # and could pass for tensors of mismatched shape.
        assert torch.equal(fc_a.weight, fc_b.weight), f"{gate}.weight mismatch"
        assert torch.equal(fc_a.bias, fc_b.bias), f"{gate}.bias mismatch"