def test_wrong_construction_1(self, dataloader_dummy):
    """Wrong positional arguments.

    Each positional slot of ``Run`` (network, loss, train dataloader) is
    replaced in turn by a plain string while the others stay valid; every
    such construction must raise ``TypeError``.
    """
    fake = 'this_is_fake'
    # Factories (not prebuilt tuples) so the valid objects are still created
    # freshly, inside the ``pytest.raises`` context, for every case.
    bad_positional_args = (
        lambda: (fake, MeanReturns(), dataloader_dummy),
        lambda: (DummyNet(), fake, dataloader_dummy),
        lambda: (DummyNet(), MeanReturns(), fake),
    )
    for make_args in bad_positional_args:
        with pytest.raises(TypeError):
            Run(*make_args())
def test_wrong_construction_2(self, dataloader_dummy):
    """Wrong keyword arguments.

    Invalid ``metrics``, ``val_dataloaders`` and ``benchmarks`` values are
    rejected: values of the wrong type raise ``TypeError``, while the reserved
    names exercised here (``'loss'`` in metrics, ``'main'`` in benchmarks)
    raise ``ValueError``.
    """
    # (expected exception, factory producing the offending keyword arguments)
    cases = (
        (TypeError, lambda: {'metrics': 'this_is_fake'}),
        (TypeError, lambda: {'metrics': {'a': 'this_is_fake'}}),
        (ValueError, lambda: {'metrics': {'loss': MeanReturns()}}),
        (TypeError, lambda: {'val_dataloaders': 'this_is_fake'}),
        (TypeError, lambda: {'val_dataloaders': {'val': 'this_is_fake'}}),
        (TypeError, lambda: {'benchmarks': 'this_is_fake'}),
        (TypeError, lambda: {'benchmarks': {'uniform': 'this_is_fake'}}),
        (ValueError, lambda: {'benchmarks': {'main': OneOverN()}}),
    )
    for expected_error, make_kwargs in cases:
        with pytest.raises(expected_error):
            # Positional args are evaluated before ``**make_kwargs()``, matching
            # the construction order of the explicit calls.
            Run(DummyNet(), MeanReturns(), dataloader_dummy, **make_kwargs())
def test_basic():
    """Smoke-test ``DummyNet`` on a random batch and check its output contract.

    The input follows the ``(n_samples, n_channels, lookback, n_assets)``
    layout used elsewhere in this file. The network must return one weight
    vector per sample, i.e. a ``(n_samples, n_assets)`` tensor whose rows sum
    to one. The original version only printed the output, so it could never
    fail on a wrong result.
    """
    n_samples, n_channels, lookback, n_assets = 10, 2, 4, 5
    x = torch.rand(n_samples, n_channels, lookback, n_assets)
    network = DummyNet(n_channels=n_channels)

    weights = network(x)

    assert torch.is_tensor(weights)
    assert weights.shape == (n_samples, n_assets)
    # Match dtype explicitly so allclose never trips over a dtype mismatch.
    expected_sums = torch.ones(n_samples, dtype=weights.dtype)
    assert torch.allclose(weights.sum(dim=1), expected_sums, atol=1e-4)
def test_basic(self, Xy_dummy):
    """``DummyNet`` keeps the input's device/dtype and outputs per-sample
    weights of shape ``(n_samples, n_assets)`` that sum to one."""
    X, _, _, _ = Xy_dummy
    n_samples, n_channels, _, n_assets = X.shape

    network = DummyNet(n_channels=n_channels)
    network.to(device=X.device, dtype=X.dtype)
    weights = network(X)

    assert torch.is_tensor(weights)
    assert weights.shape == (n_samples, n_assets)
    assert weights.device == X.device
    assert weights.dtype == X.dtype
    row_sums_expected = torch.ones(n_samples, dtype=X.dtype, device=X.device)
    assert torch.allclose(weights.sum(dim=1), row_sums_expected, atol=1e-4)
def test_launch_interrupt(self, dataloader_dummy, monkeypatch):
    """A ``KeyboardInterrupt`` raised from a callback at the start of training
    must not escape ``Run.launch`` (there is no ``pytest.raises`` here)."""
    n_channels = dataloader_dummy.dataset.X.shape[1]
    network = DummyNet(n_channels=n_channels)
    loss = MeanReturns()

    class InterruptOnTrainBegin(Callback):
        def on_train_begin(self, metadata):
            raise KeyboardInterrupt()

    # NOTE(review): presumably Run sleeps while handling the interrupt;
    # stubbing time.sleep keeps the test fast — confirm against Run.launch.
    monkeypatch.setattr('time.sleep', lambda _seconds: None)

    run = Run(network, loss, dataloader_dummy,
              callbacks=[InterruptOnTrainBegin()])
    run.launch(n_epochs=1)
def test_attributes_after_construction(self, dataloader_dummy, additional_kwargs):
    """``Run`` keeps references to the objects it was built with and exposes
    ``metrics`` and ``val_dataloaders`` as dicts.

    ``additional_kwargs`` (presumably a parametrized boolean fixture — confirm)
    toggles whether the optional keyword arguments are supplied at all.
    """
    network = DummyNet()
    loss = MeanReturns()
    kwargs = (
        {
            'metrics': {'std': StandardDeviation()},
            'val_dataloaders': {'val': dataloader_dummy},
            'benchmarks': {'whatever': OneOverN()},
        }
        if additional_kwargs
        else {}
    )

    run = Run(network, loss, dataloader_dummy, **kwargs)

    assert run.network is network
    assert run.loss is loss
    assert run.train_dataloader is dataloader_dummy
    assert isinstance(run.metrics, dict)
    assert isinstance(run.val_dataloaders, dict)
def test_launch(self, dataloader_dummy):
    """Smoke test: a single-epoch ``Run.launch`` completes without raising."""
    n_channels = dataloader_dummy.dataset.X.shape[1]
    run = Run(DummyNet(n_channels=n_channels), MeanReturns(), dataloader_dummy)
    run.launch(n_epochs=1)
def network_dummy(dataset_dummy):
    """Return a ``DummyNet`` whose channel count matches ``dataset_dummy``."""
    n_channels = dataset_dummy.n_channels
    return DummyNet(n_channels=n_channels)