def test_basic(self, Xy_dummy, dim_ix):
    """Check that the layer's output keeps the input's type, device, dtype and shape.

    ``Xy_dummy`` is a fixture that unpacks into four values, of which only
    the first (the input tensor) is used here. ``dim_ix`` is the dimension
    index handed to ``MultiplyByConstant``.
    """
    inputs, _, _, _ = Xy_dummy

    # Size the layer from the input along the dimension under test.
    size_along_dim = inputs.shape[dim_ix]
    layer = MultiplyByConstant(dim_ix=dim_ix, dim_size=size_along_dim)

    # Align the layer with the input's device and dtype before the call.
    layer.to(device=inputs.device, dtype=inputs.dtype)

    output = layer(inputs)

    assert torch.is_tensor(output)
    assert inputs.device == output.device
    assert inputs.dtype == output.dtype
    assert output.shape == inputs.shape
def test_error(self):
    """Check that a ``ValueError`` is raised for a layer/input size mismatch.

    The layer is built with ``dim_size=2`` for ``dim_ix=1``, while the input
    of shape ``(2, 3)`` has size 3 along that dimension (presumably the
    reason for the expected error).
    """
    mismatched_input = torch.ones((2, 3))

    # Construction and the call both stay inside the context manager, so a
    # ValueError raised by either one satisfies the test.
    with pytest.raises(ValueError):
        layer = MultiplyByConstant(dim_ix=1, dim_size=2)
        layer(mismatched_input)