def test_basic(self, Xy_dummy, transform_type, hidden_size):
    """Run a forward pass and check that the output is a valid allocation.

    Verifies that the network exposes non-empty ``hparams``, and that the
    returned weights are a tensor of shape ``(n_samples, n_assets)`` living on
    the same device and having the same dtype as the input, with each row
    summing to one (within ``eps``).

    Parameters
    ----------
    Xy_dummy : tuple
        Fixture whose first element is the 4D feature tensor ``X``; the
        remaining three elements are not used here.
    transform_type : str
        Forwarded to ``KeynesNet``.
    hidden_size : int
        Forwarded to ``KeynesNet``.
    """
    eps = 1e-4

    X, _, _, _ = Xy_dummy
    # The third dimension (lookback) is not needed by any assertion below.
    n_samples, n_channels, _, n_assets = X.shape
    dtype = X.dtype
    device = X.device

    network = KeynesNet(n_channels, hidden_size=hidden_size, transform_type=transform_type, n_groups=2)
    # Align the parameters with the input tensor; otherwise the forward pass
    # fails whenever the fixture yields a non-default dtype or device. This is
    # a no-op when they already match.
    network.to(device=device, dtype=dtype)

    weights = network(X)

    assert isinstance(network.hparams, dict)
    assert network.hparams
    assert torch.is_tensor(weights)
    assert weights.shape == (n_samples, n_assets)
    assert X.device == weights.device
    assert X.dtype == weights.dtype

    # Each sample's weights are expected to sum to one.
    expected_sums = torch.ones(n_samples, dtype=dtype, device=device)
    assert torch.allclose(weights.sum(dim=1), expected_sums, atol=eps)
def test_n_params(self, n_input_channels, hidden_size, n_groups, transform_type):
    """Compare the number of trainable parameters against a hand-computed total.

    The expected total is the sum of four contributions: instance norm, the
    transform layer (formula depends on ``transform_type``), group norm and a
    single temperature parameter.
    """
    model = KeynesNet(
        n_input_channels=n_input_channels,
        hidden_size=hidden_size,
        n_groups=n_groups,
        transform_type=transform_type,
    )

    instance_norm_count = n_input_channels * 2

    if transform_type == 'Conv':
        # NOTE(review): the factor 3 is presumably the kernel size - confirm.
        transform_count = n_input_channels * 3 * hidden_size + hidden_size
    else:
        # NOTE(review): the factor 4 presumably reflects the gates of the
        # recurrent layer - confirm against the KeynesNet implementation.
        per_gate = (n_input_channels * hidden_size) + (hidden_size * hidden_size) + 2 * hidden_size
        transform_count = 4 * per_gate

    group_norm_count = 2 * hidden_size
    temperature_count = 1

    expected = instance_norm_count + transform_count + group_norm_count + temperature_count

    trainable = [param for param in model.parameters() if param.requires_grad]
    actual = sum(param.numel() for param in trainable)

    assert expected == actual
def test_error(self):
    """Check that invalid constructor arguments raise ``ValueError``."""
    invalid_kwargs = (
        # Transform type this test expects to be rejected.
        {'transform_type': 'FAKE', 'hidden_size': 10, 'n_groups': 2},
        # NOTE(review): presumably rejected because 10 is not divisible by 3 -
        # confirm against the KeynesNet constructor.
        {'hidden_size': 10, 'n_groups': 3},
    )

    for kwargs in invalid_kwargs:
        with pytest.raises(ValueError):
            KeynesNet(2, **kwargs)