def test_half_padding_args(self):
    """A default Conv and an explicit Conv(2, p=1, s=2) resolve to identical args.

    Both layers are run once on the same 1-channel 28x28 batch before their
    ``_args`` are compared.
    """
    batch = torch.ones(4, 1, 28, 28)

    default_layer = Conv().eval()
    default_layer(batch)

    explicit_layer = Conv(2, p=1, s=2).eval()
    explicit_layer(batch)

    assert explicit_layer._args == default_layer._args
def test_mul_list(self):
    """Multiplying a Conv by a sequence of channel counts yields one Conv per count.

    Checks that:
    * exactly one layer is produced per entry of ``cs``;
    * the first produced layer is the original object itself;
    * every produced layer shares the original's args except ``'c'``;
    * each layer's ``'c'`` matches the corresponding entry of ``cs``
      (so the original layer ends up with ``'c' == cs[0]``).
    """
    conv_layer = Conv().eval()
    cs = (5, 3, 2)
    convs = conv_layer * cs

    # zip() below silently truncates, so pin the count explicitly;
    # otherwise a short result would still pass every check.
    assert len(convs) == len(cs)
    assert convs[0] is conv_layer
    assert all(
        conv._args[k] == conv_layer._args[k]
        for conv in convs
        for k in conv_layer._args
        if k != 'c'
    )
    assert all(conv._args['c'] == c for conv, c in zip(convs, cs))
def test_conv_3d(self):
    """Feeding a 5-D input makes Conv build an ``nn.Conv3d`` as its layer."""
    layer = Conv().eval()
    volume = torch.randn(1, 2, 3, 4, 5)
    layer(volume)
    assert isinstance(layer.layer, nn.Conv3d)
def test_double_padding(self):
    """``p='double'`` maps a (4, 2, 28, 28) batch to (4, 1, 56, 56)."""
    upsampler = Conv(p='double').eval()
    output = upsampler(torch.ones(4, 2, 28, 28))
    expected_shape = (4, 1, 56, 56)
    assert output.shape == expected_shape
def test_same_padding(self):
    """``p='same'`` leaves a (4, 1, 28, 28) batch's shape unchanged."""
    in_shape = (4, 1, 28, 28)
    layer = Conv(p='same').eval()
    result = layer(torch.ones(*in_shape))
    assert result.shape == in_shape
def test_half_padding(self):
    """A default Conv maps a (4, 1, 28, 28) batch to (4, 2, 14, 14)."""
    downsampler = Conv().eval()
    output = downsampler(torch.ones(4, 1, 28, 28))
    expected_shape = (4, 2, 14, 14)
    assert output.shape == expected_shape