예제 #1
0
def test_wave_net_bn(loader):
    """Check that, in eval mode, WaveNet's outputs for a sub-batch match the full batch.

    A deep copy of the first loader batch is trimmed to its rows from index 50
    onward. Each of the three network outputs computed on that tail must have
    50 rows and equal the corresponding rows of the output computed on the
    whole batch (asserted to have 100 rows).
    """
    full_batch = next(iter(loader))
    tail_batch = copy.deepcopy(full_batch)
    # Only these inputs are trimmed; any other keys keep their full-batch values.
    for key in ("Prices", "Dividends", "DayOfYear"):
        tail_batch[key] = tail_batch[key][50:, :]
    tail_batch["Ticker"] = tail_batch["Ticker"][50:]

    model = wave_net.WaveNet(
        loader.history_days, loader.features_description, **NET_PARAMS
    )
    # Compare in eval mode so train-only behavior does not affect the outputs.
    model.eval()
    full_first, full_second, full_third = model(full_batch)
    tail_first, tail_second, tail_third = model(tail_batch)

    full_outputs = (full_first, full_second, full_third)
    tail_outputs = (tail_first, tail_second, tail_third)

    for output in full_outputs:
        assert output.shape == (100, 1, 3)
    for output in tail_outputs:
        assert output.shape == (50, 1, 3)

    for full, tail in zip(full_outputs, tail_outputs):
        assert tail.allclose(full[50:, :])
예제 #2
0
def test_dist(loader):
    """Check the distribution object that ``WaveNet.dist`` builds for one batch.

    The result must be a ``MixtureSameFamily``. Its mean and variance must
    have shape (100, 1), and ``log_prob`` of the labels shifted by +1 must
    have that same shape.
    """
    first_batch = next(iter(loader))

    model = wave_net.WaveNet(
        loader.history_days, loader.features_description, **NET_PARAMS
    )
    mixture = model.dist(first_batch)

    assert isinstance(mixture, distributions.MixtureSameFamily)

    expected_shape = (100, 1)
    assert mixture.mean.shape == expected_shape
    assert mixture.variance.shape == expected_shape

    # NOTE(review): the +1 shift presumably keeps labels inside the
    # distribution's support (e.g. strictly positive) — confirm.
    shifted_labels = first_batch["Label"] + torch.tensor(1.0)
    log_likelihood = mixture.log_prob(shifted_labels)
    assert log_likelihood.shape == expected_shape
예제 #3
0
def test_wave_net_no_embedding(loader_no_emb):
    """Check that a WaveNet built without embeddings is consistent across batch slices.

    A deep copy of the first batch is trimmed to its rows from index 60
    onward. In eval mode, the two outputs for that tail must have 40 rows
    and match the corresponding rows of the full-batch outputs (100 rows).
    """
    full_batch = next(iter(loader_no_emb))
    tail_batch = copy.deepcopy(full_batch)
    # Only these inputs are trimmed; other keys keep their full-batch values.
    for key in ("Prices", "Dividends"):
        tail_batch[key] = tail_batch[key][60:, :]

    model = wave_net.WaveNet(loader_no_emb.features_description, **NET_PARAMS)
    model.eval()
    full_first, full_second = model(full_batch)
    tail_first, tail_second = model(tail_batch)

    full_outputs = (full_first, full_second)
    tail_outputs = (tail_first, tail_second)

    for output in full_outputs:
        assert output.shape == (100, 1)
    for output in tail_outputs:
        assert output.shape == (40, 1)

    for full, tail in zip(full_outputs, tail_outputs):
        assert tail.allclose(full[60:, :])
예제 #4
0
def test_wave_net_no_bn(loader):
    """Check that a WaveNet built with ``start_bn=False`` is consistent across batch slices.

    A deep copy of the first batch is trimmed to its first 40 rows. The two
    outputs for that head must have 40 rows and match the first 40 rows of
    the full-batch outputs (100 rows).

    The network is not switched to eval mode here.
    NOTE(review): this presumably relies on there being no other
    train-only layers (e.g. dropout) in NET_PARAMS — confirm.
    """
    batch = next(iter(loader))
    head_batch = copy.deepcopy(batch)
    # Only these inputs are trimmed; other keys keep their full-batch values.
    for key in ("Prices", "Dividends", "DayOfYear"):
        head_batch[key] = head_batch[key][:40, :]
    head_batch["Ticker"] = head_batch["Ticker"][:40]

    # Use a local copy rather than mutating the shared module-level
    # NET_PARAMS. Mutating it would leak start_bn=False into every test
    # that runs afterwards and make results depend on test order.
    params = dict(NET_PARAMS, start_bn=False)
    net = wave_net.WaveNet(loader.features_description, **params)
    m1, s1 = net(batch)
    m2, s2 = net(head_batch)

    assert m1.shape == (100, 1)
    assert s1.shape == (100, 1)

    assert m2.shape == (40, 1)
    assert s2.shape == (40, 1)

    assert m2.allclose(m1[:40, :])
    assert s2.allclose(s1[:40, :])