def test_asset_specific(self):
    """Check that ``Warp`` keeps the input shape for a 3-D transformation.

    The transformation built here has shape
    ``(n_samples, lookback, n_assets)``, i.e. it carries one more trailing
    axis than the 2-D transformations used by the other tests in this file.
    """
    dims = (2, 3, 4, 5)  # n_samples, n_channels, lookback, n_assets
    batch, _, window, assets = dims

    X = torch.rand(*dims)
    per_asset_tform = torch.randn(batch, window, assets, dtype=X.dtype)

    warped = Warp()(X, per_asset_tform)

    assert warped.shape == X.shape
def test_no_change(self, mode, padding_mode):
    """Check that an evenly spaced grid from -1 to 1 leaves the input intact.

    ``mode`` and ``padding_mode`` are forwarded unchanged to the ``Warp``
    constructor. They are injected by pytest; where they come from is not
    visible in this part of the file.
    """
    batch, channels, window, assets = 2, 3, 4, 5
    X = torch.rand(batch, channels, window, assets)

    # Every sample gets the very same linspace(-1, 1) row.
    grid_row = torch.linspace(-1, 1, steps=window)
    tform = grid_row.repeat(batch, 1)

    warp = Warp(mode=mode, padding_mode=padding_mode)
    result = warp(X, tform)

    # Interpolation may introduce tiny numerical noise, hence the tolerance.
    assert torch.allclose(result, X, atol=1e-5)
def test_equality_with_warp(self):
    """Check that ``Zoom`` with scale 0.5 matches ``Warp`` on a 0..1 grid.

    Both layers receive the same random input; ``Zoom`` is driven by a
    per-sample scale of 0.5, ``Warp`` by a per-sample ``linspace(0, 1)``
    transformation, and the two outputs are required to be close.
    """
    batch, channels, window, assets = 2, 3, 4, 5
    X = torch.rand(batch, channels, window, assets)

    half_scale = torch.full((batch,), 0.5, dtype=X.dtype)
    upper_half_grid = torch.linspace(0, 1, steps=window).repeat(batch, 1)

    zoom, warp = Zoom(), Warp()

    zoomed = zoom(X, half_scale)
    warped = warp(X, upper_half_grid)

    assert torch.allclose(zoomed, warped)
def test_basic(self, Xy_dummy, mode, padding_mode):
    """Check that ``Warp`` preserves type, shape, dtype and device.

    ``Xy_dummy`` is unpacked as a 4-tuple whose first element is the input
    tensor. ``mode`` and ``padding_mode`` are forwarded to the ``Warp``
    constructor. All three are injected by pytest; their definitions are not
    visible in this part of the file.
    """
    X, _, _, _ = Xy_dummy
    batch, _, window, _ = X.shape

    warp = Warp(mode=mode, padding_mode=padding_mode)

    # Random positive steps -> cumulative sum gives a non-decreasing curve,
    # which is then rescaled so that its maximum lands exactly on 1.
    steps = torch.rand(batch, window, dtype=X.dtype, device=X.device)
    curve = steps.cumsum(dim=-1)
    peak = curve.max(dim=1, keepdim=True).values
    tform = 2 * (curve / peak - 0.5)

    out = warp(X, tform)

    assert torch.is_tensor(out)
    assert out.shape == X.shape
    assert out.dtype == X.dtype
    assert out.device == X.device
def test_n_parameters(self):
    """Check that a default ``Warp`` layer has no trainable parameters."""
    layer = Warp()

    trainable_sizes = [
        param.numel() for param in layer.parameters() if param.requires_grad
    ]

    assert sum(trainable_sizes) == 0
def test_error(self):
    """Check that ``Warp`` raises ``ValueError`` for a 1-D transformation.

    The other tests in this file pass 2-D or 3-D transformations; here a
    flat tensor of four ones is supplied instead.
    """
    with pytest.raises(ValueError):
        layer = Warp()
        x = torch.rand(1, 2, 3, 4)
        flat_tform = torch.ones(4)
        layer(x, flat_tform)
sin_single(lookback, freq=4 / lookback))[None, None, :, None] x = torch.as_tensor(x_np) grid = torch.linspace(0, end=1, steps=lookback)[None, :].to(dtype=x.dtype) transform_dict = { 'identity': lambda x: 2 * (x - 0.5), 'zoom': lambda x: x, 'backwards': lambda x: -2 * (x - 0.5), 'slowdown\_start': lambda x: 2 * (x**3 - 0.5), 'slowdown\_end': lambda x: 2 * (x**(1 / 3) - 0.5), } n_tforms = len(transform_dict) _, axs = plt.subplots(n_tforms, 2, figsize=(16, 3 * n_tforms), sharex=True, sharey=True) layer = Warp() for i, (tform_name, tform_lambda) in enumerate(transform_dict.items()): tform = tform_lambda(grid) x_warped = layer(x, tform) axs[i, 0].plot(tform.numpy().squeeze(), linewidth=3, color='red') axs[i, 1].plot(x_warped.numpy().squeeze(), linewidth=3, color='blue') axs[i, 0].set_title(r'$\bf{}$ tform'.format(tform_name)) axs[i, 1].set_title(r'$\bf{}$ warped'.format(tform_name))