Beispiel #1
0
def test_raise_long_history(org):
    """Asking for 2000 days of history must make ``llh`` fail.

    A deep copy of the fixture genotype gets ``history_days`` set to 2000.
    Reading ``net.llh`` on a model built from it (with no pickled state) is
    expected to raise a ``model.ModelError`` whose concrete type is
    ``model.TooLongHistoryError``.
    """
    genotype = copy.deepcopy(org.genotype)
    genotype["Data"]["history_days"] = 2000
    phenotype = genotype.get_phenotype()

    tickers = tuple(org._data.tickers)
    net = model.Model(tickers, org._data.date, phenotype, None)

    with pytest.raises(model.ModelError) as exc_info:
        # The property access itself is what is expected to raise.
        # noinspection PyStatementEffect
        net.llh
    assert issubclass(exc_info.type, model.TooLongHistoryError)
Beispiel #2
0
def test_raise_gradient_error(org):
    """A scheduler with ``max_lr`` set to 10 must make ``llh`` fail.

    A deep copy of the fixture genotype has its epochs cut to a tenth and its
    ``max_lr`` set to 10. Reading ``net.llh`` on a model built from it (with no
    pickled state) is expected to raise a ``model.ModelError`` whose concrete
    type is ``model.GradientsError``.
    """
    genotype = copy.deepcopy(org.genotype)
    scheduler = genotype["Scheduler"]
    scheduler["epochs"] /= 10
    scheduler["max_lr"] = 10
    phenotype = genotype.get_phenotype()

    tickers = tuple(org._data.tickers)
    net = model.Model(tickers, org._data.date, phenotype, None)

    with pytest.raises(model.ModelError) as exc_info:
        # The property access itself is what is expected to raise.
        # noinspection PyStatementEffect
        net.llh
    assert issubclass(exc_info.type, model.GradientsError)
Beispiel #3
0
def test_llh_from_trained_and_reloaded_model(org):
    """Check that ``llh`` survives a serialize/reload round trip.

    A model built without pickled state serializes to empty bytes. After its
    ``llh`` is read, the model is serialized with ``bytes(net)``. A second
    model built from those bytes must report the same ``llh`` and serialize
    back to identical bytes. Reading ``llh`` again, and recomputing it via
    ``_eval_llh()``, must also give the same value.
    """
    gen = copy.deepcopy(org.genotype)
    # Cut the number of epochs to speed up training
    gen["Scheduler"]["epochs"] /= 10
    phenotype = gen.get_phenotype()

    net = model.Model(tuple(org._data.tickers), org._data.date, phenotype, None)
    # A model created without pickled state serializes to empty bytes.
    assert bytes(net) == bytes()

    # Presumably this triggers training of the fresh model before evaluation
    # (suggested by the test name and the epoch cut above) — TODO confirm.
    llh = net.llh
    pickled_model = bytes(net)

    # Rebuild the model from the serialized state; results must match exactly.
    net = model.Model(tuple(org._data.tickers), org._data.date, phenotype, pickled_model)
    assert llh == net.llh
    assert bytes(net) == pickled_model

    # From cache
    assert llh == net.llh
    # Explicit recomputation must agree with the cached value.
    assert llh == net._eval_llh()
Beispiel #4
0
def test_forecast(org):
    """Check the metadata and layout of a forecast from the fixture's model.

    The model is built from a genotype copy with a tenth of the epochs and
    from the pickled state in ``org._data.model``. The resulting ``Forecast``
    must carry the same tickers, date and history length as its inputs. Its
    ``mean`` and ``std`` must be ``pd.Series`` indexed by those tickers in
    order.
    """
    genotype = copy.deepcopy(org.genotype)
    genotype["Scheduler"]["epochs"] /= 10
    phenotype = genotype.get_phenotype()

    data = org._data
    tickers = tuple(data.tickers)
    net = model.Model(tickers, data.date, phenotype, data.model)
    forecast = net.forecast()

    assert isinstance(forecast, Forecast)
    assert forecast.tickers == tickers
    assert forecast.date == data.date
    assert forecast.history_days == phenotype["data"]["history_days"]

    # Same checks for both statistics, in the original order: mean first, then std.
    expected_index = list(tickers)
    for attr in ("mean", "std"):
        series = getattr(forecast, attr)
        assert isinstance(series, pd.Series)
        assert series.index.tolist() == expected_index