def test_jit(self, device, dtype):
    """TorchScript-compiled Whitening must match eager Whitening on the same input.

    The ``dtype`` fixture is applied to the input tensor, and both modules are cast to
    the input's device/dtype. Previously the input was always created with the
    default dtype, so the ``dtype`` parametrization had no effect.
    """
    batch_size, in_dims = 1, 175
    patches = torch.rand(batch_size, in_dims, device=device, dtype=dtype)

    def _make_model():
        # Fresh, identically configured module on the input's device/dtype, in eval mode.
        return (
            Whitening(xform='lw', whitening_model=None, in_dims=in_dims)
            .to(patches.device, patches.dtype)
            .eval()
        )

    model = _make_model()
    model_jit = torch.jit.script(_make_model())
    assert_close(model(patches), model_jit(patches))
def test_batch_shape(self, bs, device):
    """A batch of ``bs`` 175-wide inputs is mapped to a ``(bs, 128)`` output."""
    in_dims, output_dims = 175, 128
    whitener = Whitening(
        xform='lw', whitening_model=None, in_dims=in_dims, output_dims=output_dims
    ).to(device)
    descriptors = torch.ones(bs, in_dims, device=device)
    result = whitener(descriptors)
    assert result.shape == (bs, output_dims)
def test_toy(self, device):
    """An all-ones input with ``whitening_model=None`` yields a constant output of 1/sqrt(in_dims)."""
    in_dims = 175
    wh = Whitening(xform='lw', whitening_model=None, in_dims=in_dims, output_dims=in_dims).to(device)
    inp = torch.ones(1, in_dims, device=device, dtype=torch.float32)
    out = wh(inp)
    # 175 ** -0.5 ~= 0.0756, the value this test previously hard-coded.
    # NOTE(review): presumably the output is L2-normalised, so 175 equal entries
    # each become 1/sqrt(175) -- confirm against Whitening.forward.
    expected = torch.full_like(inp, in_dims ** -0.5)
    assert_close(out, expected, atol=1e-3, rtol=1e-3)
def test_shape(self, kernel_type, xform, output_dims, device):
    """The output width equals ``output_dims`` for each kernel type / xform combination.

    Input width depends on the kernel type: 63 for 'cart', 175 otherwise.
    """
    if kernel_type == 'cart':
        in_dims = 63
    else:
        in_dims = 175
    whitener = Whitening(
        xform=xform, whitening_model=None, in_dims=in_dims, output_dims=output_dims
    ).to(device)
    result = whitener(torch.ones(1, in_dims, device=device))
    assert result.shape == (1, output_dims)
def whitening_describe(patches, in_dims=175):
    """Apply a float64 'lw' Whitening module, placed on ``device``, to ``patches`` cast to float64.

    ``device`` is not a parameter here; it is resolved from an enclosing scope.
    """
    # nn.Module.to moves the module in place and returns it, so chaining is equivalent.
    module = Whitening(xform='lw', whitening_model=None, in_dims=in_dims).double().to(device)
    return module(patches.double())
def test_print(self, device):
    """``repr`` of a Whitening module must succeed and produce a non-empty string."""
    wh = Whitening(xform='lw', whitening_model=None, in_dims=175, output_dims=128).to(device)
    # The builtin repr() also raises TypeError if __repr__ returns a non-str,
    # which calling the dunder directly would not catch.
    text = repr(wh)
    assert text