Example 1
def test_forward_lstm(input_size, hidden_size, num_layers, bidirectional):
    # Test conversion from the PyTorch implementation
    lstm = nn.LSTM(input_size,
                   hidden_size,
                   num_layers,
                   bidirectional=bidirectional)
    lstm_man = DistillerLSTM.from_pytorch_impl(lstm)
    lstm.eval()
    lstm_man.eval()

    h = lstm_man.init_hidden(BATCH_SIZE)
    x = torch.rand(SEQUENCE_SIZE, BATCH_SIZE, input_size)

    out_true = lstm(x, h)
    out_pred = lstm_man(x, h)
    assert_output(out_true, out_pred)
    # Test conversion to the PyTorch implementation
    lstm_man = DistillerLSTM(input_size,
                             hidden_size,
                             num_layers,
                             bidirectional=bidirectional)
    lstm = lstm_man.to_pytorch_impl()

    lstm.eval()
    lstm_man.eval()

    h = lstm_man.init_hidden(BATCH_SIZE)
    x = torch.rand(SEQUENCE_SIZE, BATCH_SIZE, input_size)

    out_true = lstm(x, h)
    out_pred = lstm_man(x, h)
    assert_output(out_true, out_pred)
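These tests rely on module-level imports and fixtures that are not part of the snippet: the function arguments (input_size, hidden_size, num_layers, bidirectional) come from pytest parametrization, and BATCH_SIZE, SEQUENCE_SIZE and assert_output are defined elsewhere in the test module. A minimal sketch of what those definitions could look like follows; the constant values, the import path and the comparison tolerance are assumptions, not taken from the original file.

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_sequence, PackedSequence
from distiller.modules import DistillerLSTM, DistillerLSTMCell  # assumed import path

# Assumed values -- the real test module defines its own constants.
BATCH_SIZE = 32
SEQUENCE_SIZE = 35


def assert_output(out_true, out_pred):
    # Both implementations return (output, (h_n, c_n)). Compare every tensor
    # within a small tolerance, since the gate math is evaluated differently
    # and the results can differ in the least significant bits.
    y_true, (h_true, c_true) = out_true
    y_pred, (h_pred, c_pred) = out_pred
    if isinstance(y_true, PackedSequence):
        y_true, y_pred = y_true.data, y_pred.data
    assert torch.allclose(y_true, y_pred, atol=1e-5)
    assert torch.allclose(h_true, h_pred, atol=1e-5)
    assert torch.allclose(c_true, c_pred, atol=1e-5)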
Example 2
def test_basic():
    # A cell computes all four LSTM gates with one fused linear layer for the
    # input and one for the recurrent state, so weights have 4 * hidden_size rows.
    lstmcell = DistillerLSTMCell(3, 5)
    assert lstmcell.fc_gate_x.weight.shape == (5 * 4, 3)
    assert lstmcell.fc_gate_h.weight.shape == (5 * 4, 5)
    assert lstmcell.fc_gate_x.bias.shape == (5 * 4, )
    assert lstmcell.fc_gate_h.bias.shape == (5 * 4, )

    # Bidirectional DistillerLSTM with input_size=3, hidden_size=5, num_layers=4.
    lstm = DistillerLSTM(3, 5, 4, False, False, 0.0, True)
    assert lstm.bidirectional_type == 2
    assert lstm.cells[0].fc_gate_x.weight.shape == (5 * 4, 3)
    # Layers after the first consume the concatenated forward/backward hidden
    # states, hence 5 * 2 input features.
    assert lstm.cells[1].fc_gate_x.weight.shape == (5 * 4, 5 * 2)
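The 5 * 4 rows in these shapes come from computing all four LSTM gates (input, forget, cell, output) in one fused matrix multiply. Below is a minimal sketch of the standard gate equations such a layout implies, written with plain tensor ops rather than the DistillerLSTMCell API; the gate ordering is an assumption, as implementations differ.

import torch

input_size, hidden_size = 3, 5
# One fused weight matrix for the input and one for the recurrent state,
# each producing 4 * hidden_size pre-activations (one block per gate).
w_x = torch.randn(4 * hidden_size, input_size)
w_h = torch.randn(4 * hidden_size, hidden_size)

x = torch.randn(1, input_size)           # one batch element, one time step
h_prev = torch.zeros(1, hidden_size)
c_prev = torch.zeros(1, hidden_size)

gates = x @ w_x.t() + h_prev @ w_h.t()   # shape (1, 4 * hidden_size)
i, f, g, o = gates.chunk(4, dim=1)       # split into the four gate blocks
c_next = torch.sigmoid(f) * c_prev + torch.sigmoid(i) * torch.tanh(g)
h_next = torch.sigmoid(o) * torch.tanh(c_next)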
Example 3
def test_conversion():
    # Round-trip a single cell: Distiller -> PyTorch -> Distiller.
    lc_man = DistillerLSTMCell(3, 5)
    lc_pth = lc_man.to_pytorch_impl()
    lc_man1 = DistillerLSTMCell.from_pytorch_impl(lc_pth)

    assert (lc_man.fc_gate_x.weight == lc_man1.fc_gate_x.weight).all()
    assert (lc_man.fc_gate_h.weight == lc_man1.fc_gate_h.weight).all()

    # Round-trip a full 2-layer LSTM and check that every layer's weights and
    # biases survive the conversion unchanged.
    l_man = DistillerLSTM(3, 5, 2)
    l_pth = l_man.to_pytorch_impl()
    l_man1 = DistillerLSTM.from_pytorch_impl(l_pth)

    for i in range(l_man.num_layers):
        assert (l_man1.cells[i].fc_gate_x.weight ==
                l_man.cells[i].fc_gate_x.weight).all()
        assert (l_man1.cells[i].fc_gate_h.weight ==
                l_man.cells[i].fc_gate_h.weight).all()
        assert (l_man1.cells[i].fc_gate_x.bias == l_man.cells[i].fc_gate_x.bias
                ).all()
        assert (l_man1.cells[i].fc_gate_h.bias == l_man.cells[i].fc_gate_h.bias
                ).all()
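At the cell level the same round trip can also be checked against the attributes a stock nn.LSTMCell exposes. This is only a sketch under the assumption that to_pytorch_impl returns an nn.LSTMCell and copies the fused gate weights verbatim; the snippet above only shows the Distiller-to-Distiller comparison.

import torch.nn as nn
from distiller.modules import DistillerLSTMCell  # assumed import path

cell = DistillerLSTMCell(3, 5)
pth_cell = cell.to_pytorch_impl()   # assumed to be an nn.LSTMCell

# nn.LSTMCell keeps its fused gate parameters as weight_ih / weight_hh,
# both with 4 * hidden_size rows, mirroring fc_gate_x / fc_gate_h.
assert pth_cell.weight_ih.shape == cell.fc_gate_x.weight.shape
assert pth_cell.weight_hh.shape == cell.fc_gate_h.weight.shape
# Value equality additionally assumes the conversion keeps the same gate ordering.
assert (pth_cell.weight_ih == cell.fc_gate_x.weight).all()
assert (pth_cell.weight_hh == cell.fc_gate_h.weight).all()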
Example 4
def test_packed_sequence(input_size, hidden_size, num_layers, input_lengths,
                         bidirectional):
    lstm = nn.LSTM(input_size,
                   hidden_size,
                   num_layers,
                   bidirectional=bidirectional)
    lstm_man = DistillerLSTM.from_pytorch_impl(lstm)
    lstm.eval()
    lstm_man.eval()

    # Variable-length sequences are packed before being fed to both models;
    # no explicit initial hidden state is passed here.
    x = pack_sequence(
        [torch.rand(length, input_size) for length in input_lengths])
    out_true = lstm(x)
    out_pred = lstm_man(x)
    assert_output(out_true, out_pred)
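Here input_lengths (another parametrized argument) holds one length per sequence in the batch, and both models receive a PackedSequence instead of a padded tensor, so their outputs come back packed as well. Below is a short, self-contained usage sketch with assumed example values; note that pack_sequence expects the sequences in decreasing length order unless enforce_sorted=False is passed (on recent PyTorch versions).

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence

input_size, hidden_size, bidirectional = 3, 5, False    # assumed example values
lstm = nn.LSTM(input_size, hidden_size, bidirectional=bidirectional)

input_lengths = [5, 4, 2]                               # one length per sequence
x = pack_sequence([torch.rand(length, input_size) for length in input_lengths])

out, (h_n, c_n) = lstm(x)
# The output for a packed input is itself a PackedSequence;
# pad_packed_sequence recovers a (seq_len, batch, features) tensor plus lengths.
padded, lengths = pad_packed_sequence(out)
assert lengths.tolist() == input_lengths
assert padded.shape == (max(input_lengths), len(input_lengths),
                        hidden_size * (2 if bidirectional else 1))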