示例#1
0
    def test_wrong_construction_1(self, dataloader_dummy):
        """Wrong positional arguments.

        The network, the loss and the training dataloader are each replaced,
        one at a time, by a plain string; every such construction of ``Run``
        must fail with ``TypeError``.
        """
        # Lambdas keep construction lazy so each attempt happens inside its
        # own ``pytest.raises`` context, exactly as with separate blocks.
        bad_constructions = (
            lambda: Run('this_is_fake', MeanReturns(), dataloader_dummy),
            lambda: Run(DummyNet(), 'this_is_fake', dataloader_dummy),
            lambda: Run(DummyNet(), MeanReturns(), 'this_is_fake'),
        )

        for construct in bad_constructions:
            with pytest.raises(TypeError):
                construct()
示例#2
0
    def test_wrong_construction_2(self, dataloader_dummy):
        """Wrong keyword arguments.

        A wrongly typed ``metrics``, ``val_dataloaders`` or ``benchmarks``
        argument (either the container itself or one of its values) must raise
        ``TypeError``. A ``'loss'`` entry in ``metrics`` or a ``'main'`` entry
        in ``benchmarks`` must raise ``ValueError`` (presumably names that
        ``Run`` reserves for itself; confirm against ``Run``).
        """
        cases = [
            (TypeError, {'metrics': 'this_is_fake'}),
            (TypeError, {'metrics': {'a': 'this_is_fake'}}),
            (ValueError, {'metrics': {'loss': MeanReturns()}}),
            (TypeError, {'val_dataloaders': 'this_is_fake'}),
            (TypeError, {'val_dataloaders': {'val': 'this_is_fake'}}),
            (TypeError, {'benchmarks': 'this_is_fake'}),
            (TypeError, {'benchmarks': {'uniform': 'this_is_fake'}}),
            (ValueError, {'benchmarks': {'main': OneOverN()}}),
        ]

        for expected_error, bad_kwargs in cases:
            with pytest.raises(expected_error):
                Run(DummyNet(), MeanReturns(), dataloader_dummy, **bad_kwargs)
示例#3
0
def test_basic():
    """Run ``DummyNet`` on a random batch and check the weights it returns.

    The input has shape ``(n_samples, n_channels, lookback, n_assets)``. The
    network must return a tensor of shape ``(n_samples, n_assets)`` whose rows
    each sum to one.
    """
    n_samples, n_channels, lookback, n_assets = 10, 2, 4, 5
    x = torch.rand(n_samples, n_channels, lookback, n_assets)
    network = DummyNet(n_channels=n_channels)

    weights = network(x)

    # Assert rather than print: a bare ``print`` can never make the test fail.
    assert torch.is_tensor(weights)
    assert weights.shape == (n_samples, n_assets)
    assert torch.allclose(weights.sum(dim=1), torch.ones(n_samples), atol=1e-4)
示例#4
0
    def test_basic(self, Xy_dummy):
        """``DummyNet`` returns per-sample weights that sum to one.

        The network is moved to the dtype and device of the dummy features.
        Its output must be a tensor of shape ``(n_samples, n_assets)`` on that
        same device and with that same dtype.
        """
        features, _, _, _ = Xy_dummy
        n_samples, n_channels, _, n_assets = features.shape
        dtype, device = features.dtype, features.device

        network = DummyNet(n_channels=n_channels)
        network.to(device=device, dtype=dtype)
        weights = network(features)

        assert torch.is_tensor(weights)
        assert weights.shape == (n_samples, n_assets)
        assert weights.device == features.device
        assert weights.dtype == features.dtype

        expected_totals = torch.ones(n_samples, dtype=dtype, device=device)
        assert torch.allclose(weights.sum(dim=1), expected_totals, atol=1e-4)
示例#5
0
    def test_launch_interrupt(self, dataloader_dummy, monkeypatch):
        """``Run.launch`` returns normally when a callback raises ``KeyboardInterrupt``.

        A callback raises ``KeyboardInterrupt`` as training begins. The test
        passes only if ``launch`` returns without propagating it.
        """
        # Stub out ``time.sleep`` (presumably ``launch`` pauses while handling
        # the interrupt) so the test does not actually wait.
        monkeypatch.setattr('time.sleep', lambda x: None)

        class InterruptAtTrainBegin(Callback):
            def on_train_begin(self, metadata):
                raise KeyboardInterrupt()

        n_channels = dataloader_dummy.dataset.X.shape[1]
        run = Run(DummyNet(n_channels=n_channels),
                  MeanReturns(),
                  dataloader_dummy,
                  callbacks=[InterruptAtTrainBegin()])

        run.launch(n_epochs=1)
示例#6
0
    def test_attributes_after_construction(self, dataloader_dummy, additional_kwargs):
        """``Run`` exposes its constructor arguments as attributes.

        The network, loss and training dataloader are stored by identity.
        ``metrics`` and ``val_dataloaders`` are dictionaries whether or not the
        optional keyword arguments (toggled by ``additional_kwargs``) are given.
        """
        network, loss = DummyNet(), MeanReturns()

        optional = ({'metrics': {'std': StandardDeviation()},
                     'val_dataloaders': {'val': dataloader_dummy},
                     'benchmarks': {'whatever': OneOverN()}}
                    if additional_kwargs else {})

        run = Run(network, loss, dataloader_dummy, **optional)

        assert run.network is network
        assert run.loss is loss
        assert run.train_dataloader is dataloader_dummy
        assert isinstance(run.metrics, dict)
        assert isinstance(run.val_dataloaders, dict)
示例#7
0
    def test_launch(self, dataloader_dummy):
        """A one-epoch ``Run.launch`` on the dummy dataloader completes without raising."""
        n_channels = dataloader_dummy.dataset.X.shape[1]
        Run(DummyNet(n_channels=n_channels), MeanReturns(), dataloader_dummy).launch(n_epochs=1)
示例#8
0
def network_dummy(dataset_dummy):
    """Return a ``DummyNet`` whose channel count matches ``dataset_dummy.n_channels``."""
    # NOTE(review): the name suggests a pytest fixture, but no
    # ``@pytest.fixture`` decorator is visible here; confirm it is registered.
    channels = dataset_dummy.n_channels
    return DummyNet(n_channels=channels)