Пример #1
0
def test_no_fc():
    """Check a RecurrentNet whose spec has an empty ``fc_hid_layers`` list.

    The net is expected to have no ``fc_model`` submodule, a single
    ``model_tail``, and to map ``x`` to an output of shape
    ``(batch_size, out_dim)``.
    """
    spec = deepcopy(net_spec)
    spec['fc_hid_layers'] = []
    model = RecurrentNet(spec, in_dim, out_dim)

    assert isinstance(model, nn.Module)
    # Expected presence (True) / absence (False) of each submodule attribute.
    expected_attrs = {
        'fc_model': False,    # empty fc_hid_layers -> no FC front-end expected
        'rnn_model': True,
        'model_tail': True,   # scalar out_dim -> a single tail expected
        'model_tails': False,
    }
    for attr_name, should_exist in expected_attrs.items():
        assert hasattr(model, attr_name) == should_exist

    output = model.forward(x)
    assert output.shape == (batch_size, out_dim)
Пример #2
0
def test_multitails():
    """Check that a list-valued ``out_dim`` yields one tail per entry.

    With ``out_dim=[3, 4]`` the net is expected to expose ``model_tails``
    (and no ``model_tail``), and ``forward`` is expected to return one output
    per tail, each of shape ``(batch_size, dim)``.
    """
    tail_dims = (3, 4)
    # Pass a fresh list, as the original literal did, so the constructor
    # never receives (or could mutate) the tuple used in the assertions.
    model = RecurrentNet(net_spec, in_dim, list(tail_dims))

    assert isinstance(model, nn.Module)
    assert hasattr(model, 'fc_model')
    assert hasattr(model, 'rnn_model')
    # Multiple output dims: tails are expected under `model_tails` only.
    assert not hasattr(model, 'model_tail')
    assert hasattr(model, 'model_tails')
    assert len(model.model_tails) == len(tail_dims)

    outputs = model.forward(x)
    assert len(outputs) == len(tail_dims)
    for out, dim in zip(outputs, tail_dims):
        assert out.shape == (batch_size, dim)
Пример #3
0
def test_variant(bidirectional, cell_type):
    """Build and run a RecurrentNet for one (bidirectional, cell_type) variant.

    Args:
        bidirectional: value written to the spec's ``'bidirectional'`` key.
            Presumably supplied by a pytest parametrize decorator that is not
            visible here (TODO confirm).
        cell_type: value written to the spec's ``'cell_type'`` key.
    """
    spec = deepcopy(net_spec)
    spec.update(bidirectional=bidirectional, cell_type=cell_type)
    model = RecurrentNet(spec, in_dim, out_dim)

    assert isinstance(model, nn.Module)
    for attr_name in ('fc_model', 'rnn_model', 'model_tail'):
        assert hasattr(model, attr_name)
    assert not hasattr(model, 'model_tails')
    # The spec flag is expected to reach the underlying recurrent module.
    assert model.rnn_model.bidirectional == bidirectional

    output = model.forward(x)
    assert output.shape == (batch_size, out_dim)
Пример #4
0
def test_init():
    """Check the submodules of a RecurrentNet built from the shared ``net_spec``.

    The net is expected to have ``fc_model``, ``rnn_model`` and a single
    ``model_tail`` (no ``model_tails``), with a unidirectional RNN.
    """
    # Named `model` so it does not shadow a module-level `net` fixture.
    model = RecurrentNet(net_spec, in_dim, out_dim)
    assert isinstance(model, nn.Module)
    assert hasattr(model, 'fc_model')
    assert hasattr(model, 'rnn_model')
    assert hasattr(model, 'model_tail')
    assert not hasattr(model, 'model_tails')
    # Truthiness check instead of `== False` (PEP 8 / pycodestyle E712).
    assert not model.rnn_model.bidirectional
Пример #5
0
    "optim_spec": {
        "name": "Adam",
        "lr": 0.02
    },
    "lr_scheduler_spec": {
        "name": "StepLR",
        "step_size": 30,
        "gamma": 0.1
    },
    "gpu": True
}
# Shared sizes used by the tests in this module.
in_dim, out_dim, batch_size = 10, 3, 16
seq_len = net_spec['seq_len']

# Module-level network and random input shared by tests. `net` is built
# before `x` is drawn; keep this order (weight init may consume the same RNG).
net = RecurrentNet(net_spec, in_dim, out_dim)
# Shape (batch_size, seq_len, in_dim); presumably batch-first (TODO confirm).
x = torch.rand(batch_size, seq_len, in_dim)


def test_init():
    """Check the submodules of a RecurrentNet built from the shared ``net_spec``.

    The net is expected to have ``fc_model``, ``rnn_model`` and a single
    ``model_tail`` (no ``model_tails``), with a unidirectional RNN.
    """
    # Named `model` so it does not shadow the module-level `net` fixture.
    model = RecurrentNet(net_spec, in_dim, out_dim)
    assert isinstance(model, nn.Module)
    assert hasattr(model, 'fc_model')
    assert hasattr(model, 'rnn_model')
    assert hasattr(model, 'model_tail')
    assert not hasattr(model, 'model_tails')
    # Truthiness check instead of `== False` (PEP 8 / pycodestyle E712).
    assert not model.rnn_model.bidirectional


def test_forward():
    y = net.forward(x)