Esempio n. 1
0
    def test_basic(self, Xy_dummy, transform_type, hidden_size):
        """Check that a KeynesNet forward pass produces valid portfolio weights.

        Builds a ``KeynesNet`` whose input channel count matches the dummy
        features, moves it to the features' device and dtype, runs a forward
        pass and checks that:

        * ``network.hparams`` is a non-empty dict,
        * the output is a tensor of shape ``(n_samples, n_assets)``,
        * the output shares the input's device and dtype,
        * every row of weights sums to one within ``eps``.
        """
        eps = 1e-4

        X, _, _, _ = Xy_dummy
        # X.shape is unpacked as (n_samples, n_channels, lookback, n_assets);
        # the lookback length is not needed by this test.
        n_samples, n_channels, _, n_assets = X.shape
        dtype = X.dtype
        device = X.device

        # NOTE(review): n_groups=2 presumably requires every parametrized
        # hidden_size to be divisible by 2 (hidden_size=10 with n_groups=3 is
        # expected to raise in test_error) -- confirm against the parametrization.
        network = KeynesNet(n_channels,
                            hidden_size=hidden_size,
                            transform_type=transform_type,
                            n_groups=2)
        network.to(device=device, dtype=dtype)

        weights = network(X)

        assert isinstance(network.hparams, dict)
        assert network.hparams
        assert torch.is_tensor(weights)
        assert weights.shape == (n_samples, n_assets)
        assert X.device == weights.device
        assert X.dtype == weights.dtype
        # Allocate the reference directly with the target dtype/device rather
        # than creating it on the CPU and copying it over.
        assert torch.allclose(weights.sum(dim=1),
                              torch.ones(n_samples, dtype=dtype, device=device),
                              atol=eps)
Esempio n. 2
0
    def test_n_params(self, n_input_channels, hidden_size, n_groups, transform_type):
        """Compare KeynesNet's trainable parameter count with a hand-derived total.

        The expected total is the sum of the per-component counts below. The
        breakdown reflects the architecture this test assumes KeynesNet has;
        update it if that architecture changes.
        """
        network = KeynesNet(n_input_channels=n_input_channels,
                            hidden_size=hidden_size,
                            n_groups=n_groups,
                            transform_type=transform_type)

        if transform_type == 'Conv':
            # kernel of width 3: (in * 3 * hidden) weights plus one bias per hidden channel
            n_transform = hidden_size * (3 * n_input_channels + 1)
        else:
            # four gates, each with input weights, recurrent weights and two bias
            # vectors (an LSTM-style layout assumed by this test)
            n_transform = 4 * hidden_size * (n_input_channels + hidden_size + 2)

        expected_by_component = {
            'instance_norm': 2 * n_input_channels,  # two params per input channel (presumably affine scale + shift)
            'transform': n_transform,
            'group_norm': 2 * hidden_size,  # two params per hidden channel (presumably affine scale + shift)
            'temperature': 1,
        }

        n_trainable = sum(param.numel() for param in network.parameters() if param.requires_grad)

        assert sum(expected_by_component.values()) == n_trainable
Esempio n. 3
0
    def test_error(self):
        """Check that invalid KeynesNet constructor arguments raise ``ValueError``."""
        invalid_kwargs = (
            # unrecognised transform type
            dict(transform_type='FAKE', hidden_size=10, n_groups=2),
            # default transform type; presumably rejected because 10 is not
            # divisible by 3 -- confirm against KeynesNet
            dict(hidden_size=10, n_groups=3),
        )

        for kwargs in invalid_kwargs:
            with pytest.raises(ValueError):
                KeynesNet(2, **kwargs)