def test_wave_net_bn(loader):
    """Check that, in eval mode, a sliced batch reproduces the full-batch outputs.

    The net is run once on a full batch and once on a copy holding only the
    rows from index 50 onwards; each of the three outputs for the copy must
    match the corresponding rows of the full-batch output.
    """
    full = next(iter(loader))

    # Deep copy so that re-binding the sliced entries never touches `full`.
    tail = copy.deepcopy(full)
    for key in ("Prices", "Dividends", "DayOfYear"):
        tail[key] = tail[key][50:, :]
    tail["Ticker"] = tail["Ticker"][50:]

    model = wave_net.WaveNet(
        loader.history_days,
        loader.features_description,
        **NET_PARAMS,
    )
    model.eval()

    full_out = model(full)
    tail_out = model(tail)
    l1, m1, s1 = full_out
    l2, m2, s2 = tail_out

    for out in (l1, m1, s1):
        assert out.shape == (100, 1, 3)
    for out in (l2, m2, s2):
        assert out.shape == (50, 1, 3)

    for part, whole in ((l2, l1), (m2, m1), (s2, s1)):
        assert part.allclose(whole[50:, :])
def test_dist(loader):
    """Check the type and output shapes of the distribution built by the net.

    `net.dist` must return a `MixtureSameFamily`; its mean, variance and the
    log-probability of the shifted labels must all have shape (100, 1).
    """
    expected_shape = (100, 1)

    batch = next(iter(loader))
    model = wave_net.WaveNet(
        loader.history_days,
        loader.features_description,
        **NET_PARAMS,
    )

    mixture = model.dist(batch)
    assert isinstance(mixture, distributions.MixtureSameFamily)

    assert mixture.mean.shape == expected_shape
    assert mixture.variance.shape == expected_shape

    shifted_labels = batch["Label"] + torch.tensor(1.0)
    log_likelihood = mixture.log_prob(shifted_labels)
    assert log_likelihood.shape == expected_shape
def test_wave_net_no_embedding(loader_no_emb):
    """Check sliced-batch consistency for the loader passed as `loader_no_emb`.

    The net is run in eval mode on a full batch and on a copy holding only
    the rows from index 60 onwards; both outputs for the copy must match the
    corresponding rows of the full-batch output.
    """
    # NOTE(review): sibling tests in this file construct WaveNet with
    # `history_days` as the first argument and unpack three outputs, while
    # this test passes only `features_description` and unpacks two —
    # confirm against the current WaveNet signature.
    full = next(iter(loader_no_emb))

    # Deep copy so that re-binding the sliced entries never touches `full`.
    tail = copy.deepcopy(full)
    for key in ("Prices", "Dividends"):
        tail[key] = tail[key][60:, :]

    model = wave_net.WaveNet(loader_no_emb.features_description, **NET_PARAMS)
    model.eval()

    m1, s1 = model(full)
    m2, s2 = model(tail)

    for out in (m1, s1):
        assert out.shape == (100, 1)
    for out in (m2, s2):
        assert out.shape == (40, 1)

    for part, whole in ((m2, m1), (s2, s1)):
        assert part.allclose(whole[60:, :])
def test_wave_net_no_bn(loader):
    """Check sliced-batch consistency with the `start_bn` parameter disabled.

    The net is built with `start_bn=False`, then run on a full batch and on a
    copy holding only the first 40 rows; both outputs for the copy must match
    the corresponding rows of the full-batch output.
    """
    # NOTE(review): other tests in this file construct WaveNet with
    # `history_days` as the first argument and unpack three outputs; this
    # test passes only `features_description`, unpacks two, and does not
    # call `eval()` — confirm against the current WaveNet API.
    batch = next(iter(loader))

    # Deep copy so that re-binding the sliced entries never touches `batch`.
    batch2 = copy.deepcopy(batch)
    for key in ("Prices", "Dividends", "DayOfYear"):
        batch2[key] = batch2[key][:40, :]
    batch2["Ticker"] = batch2["Ticker"][:40]

    # Override `start_bn` on a private copy: assigning into the shared
    # module-level NET_PARAMS would leak the setting into every test that
    # runs afterwards and make results depend on test execution order.
    params = {**NET_PARAMS, "start_bn": False}
    net = wave_net.WaveNet(loader.features_description, **params)

    m1, s1 = net(batch)
    m2, s2 = net(batch2)

    assert m1.shape == (100, 1)
    assert s1.shape == (100, 1)
    assert m2.shape == (40, 1)
    assert s2.shape == (40, 1)

    assert m2.allclose(m1[:40, :])
    assert s2.allclose(s1[:40, :])