def test_transformer_tsp_multisamples(device):
    """Fit a Transformer-based predictor on the flight series and check R2.

    A ``TimeSeriesPredictor`` wrapping ``Transformer(d_model=12)`` is fitted
    on a ``FlightSeriesDataset`` built with ``generate_test_dataset=True``.
    The fitting time is printed and two loose lower bounds are asserted: the
    score over the fitted dataset and the R2 score of one randomly chosen
    sample of the generated test set.

    Args:
        device: device handed to the predictor. It is first passed to
            ``cuda_check`` (presumably to guard CUDA-only runs -- confirm
            against that helper).
    """
    cuda_check(device)

    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # system clock adjustments the way a time.time() difference can.
    start = time.perf_counter()
    tsp = TimeSeriesPredictor(
        Transformer(d_model=12),
        lr=1e-5,
        lambda1=1e-8,
        optimizer__weight_decay=1e-8,
        iterator_train__shuffle=True,
        early_stopping=EarlyStopping(patience=100),
        max_epochs=500,
        train_split=CVSplit(10),
        optimizer=Adam,
        device=device,
    )

    past_pattern_length = 24
    future_pattern_length = 12
    pattern_length = past_pattern_length + future_pattern_length
    # NOTE(review): the third positional argument appears to be
    # ``except_last_n`` (test_main passes context['except_last_n'] in the
    # same position) -- confirm against FlightSeriesDataset.
    fsd = FlightSeriesDataset(pattern_length,
                              future_pattern_length,
                              pattern_length,
                              stride=1,
                              generate_test_dataset=True)
    tsp.fit(fsd)
    end = time.perf_counter()
    elapsed = timedelta(seconds=end - start)
    print(f"Fitting in {device} time delta: {elapsed}")

    # Loose lower bound on the score over the dataset the predictor was
    # fitted on.
    mean_r2_score = tsp.score(tsp.dataset)
    assert mean_r2_score > -0.5

    netout = tsp.predict(fsd.test.x)

    # Spot-check a single, randomly selected sample of the test set.
    idx = np.random.randint(0, len(fsd.test.x))

    y_true = fsd.test.y[idx, :, :]
    y_hat = netout[idx, :, :]
    r2s = r2_score(y_true, y_hat)
    # Report the score before asserting so it is visible when the check fails.
    print(f"Final R2 score: {r2s}")
    assert r2s > -1
def test_main(stride, test_main_context):
    """Fit a BenchmarkLSTM-based predictor for one stride and check R2.

    Args:
        stride: stride forwarded to ``FlightSeriesDataset``; also the argument
            used to look up the expectations in ``test_main_context``.
        test_main_context: callable that returns, for ``stride``, a mapping
            with the keys ``past_pattern_length``, ``future_pattern_length``,
            ``lr``, ``n_cv_splits``, ``except_last_n``, ``mean_r2_score`` and
            ``final_r2_score``.
    """
    context = test_main_context(stride)
    past_pattern_length = context['past_pattern_length']
    future_pattern_length = context['future_pattern_length']
    pattern_length = past_pattern_length + future_pattern_length
    tsp = TimeSeriesPredictor(
        BenchmarkLSTM(
            initial_forget_gate_bias=1,
            hidden_dim=7,
            num_layers=1,
        ),
        lr=context['lr'],
        lambda1=1e-8,
        optimizer__weight_decay=1e-8,
        iterator_train__shuffle=True,
        early_stopping=EarlyStopping(patience=100),
        max_epochs=500,
        train_split=CVSplit(context['n_cv_splits']),
        optimizer=Adam,
    )
    fsd = FlightSeriesDataset(pattern_length,
                              future_pattern_length,
                              context['except_last_n'],
                              stride=stride,
                              generate_test_dataset=True)
    tsp.fit(fsd)

    # The score over the fitted dataset must beat the stride-specific
    # threshold supplied by the context.
    mean_r2_score = tsp.score(tsp.dataset)
    assert mean_r2_score > context['mean_r2_score']

    netout = tsp.predict(fsd.test.x)

    # Spot-check a single, randomly selected sample of the test set.
    idx = np.random.randint(0, len(fsd.test.x))

    y_true = fsd.test.y[idx, :, :]
    y_hat = netout[idx, :, :]
    r2s = r2_score(y_true, y_hat)
    # Printed before the assert so the value is visible when the check fails.
    print(f"Final R2 score: {r2s}")
    assert r2s > context['final_r2_score']