def solver(generators):
    """Construct the default ``GenericSolver`` used by tests taking ``solver``.

    The solver is built from the module-level ``DIFF_EQS`` and ``CONDITIONS``,
    the ``'train'`` / ``'valid'`` entries of ``generators``, and a network
    shape of one input unit and one output unit.
    """
    solver_kwargs = dict(
        diff_eqs=DIFF_EQS,
        conditions=CONDITIONS,
        train_generator=generators['train'],
        valid_generator=generators['valid'],
        n_input_units=1,
        n_output_units=1,
    )
    return GenericSolver(**solver_kwargs)
def test_legacies(solver, generators):
    """Exercise the legacy (deprecated) solver interfaces.

    Covers:
      * reading ``_batch_examples`` (expected to emit ``FutureWarning``);
      * passing a two-argument ``criterion(residuals, zeros)`` callable;
      * subclassing with an ``additional_loss(funcs, key)`` override;
      * passing ``shuffle=True`` to the constructor.
    The criterion and additional_loss cases are expected to both warn with
    ``FutureWarning`` and raise ``TypeError`` once ``fit`` runs.
    """
    def make_kwargs():
        # Keyword arguments shared by every solver constructed in this test.
        return dict(
            diff_eqs=DIFF_EQS,
            conditions=CONDITIONS,
            train_generator=generators['train'],
            valid_generator=generators['valid'],
            n_input_units=1,
            n_output_units=1,
        )

    solver.fit(1)

    # The public `batch` accessor and the legacy `_batch_examples` alias must
    # both expose the same underlying `_batch`; only the alias warns.
    assert solver.batch == solver._batch
    with pytest.warns(FutureWarning):
        assert solver._batch_examples == solver._batch

    # Legacy two-argument criterion signature.
    with pytest.raises(TypeError), pytest.warns(FutureWarning):
        legacy_criterion_solver = GenericSolver(
            criterion=lambda residuals, zeros: (residuals ** 2).mean(),
            **make_kwargs()
        )
        legacy_criterion_solver.fit(1)

    class SolverWithLegacyAdditionalLoss(BaseSolver):
        # Overrides additional_loss with the legacy (funcs, key) signature.
        def additional_loss(self, funcs, key):
            return 0

    with pytest.raises(TypeError), pytest.warns(FutureWarning):
        legacy_loss_solver = SolverWithLegacyAdditionalLoss(**make_kwargs())
        legacy_loss_solver.fit(1)

    # The `shuffle` constructor argument alone should trigger the warning.
    with pytest.warns(FutureWarning):
        GenericSolver(shuffle=True, **make_kwargs())
def test_lbfgs(generators):
    """Fit for one epoch using a caller-supplied ``torch.optim.LBFGS``.

    The optimizer is built over the parameters of the single ``FCNN`` that is
    also handed to the solver via ``nets``.
    """
    network = FCNN()
    optimizer = torch.optim.LBFGS(params=network.parameters(), lr=1e-3)

    lbfgs_solver = GenericSolver(
        diff_eqs=DIFF_EQS,
        conditions=CONDITIONS,
        train_generator=generators['train'],
        valid_generator=generators['valid'],
        nets=[network],
        optimizer=optimizer,
        n_input_units=1,
        n_output_units=1,
    )
    lbfgs_solver.fit(1)
def test_missing_generator(generators):
    """``GenericSolver`` must raise ``ValueError`` when a generator is missing.

    Three cases are checked, in order: only the train generator is given,
    only the valid generator is given, and neither is given.
    """
    # Each entry names the splits whose generator IS supplied; the generator
    # objects are looked up lazily, one case at a time.
    supplied_splits_per_case = (('train',), ('valid',), ())

    for supplied_splits in supplied_splits_per_case:
        generator_kwargs = {
            split + '_generator': generators[split]
            for split in supplied_splits
        }
        with pytest.raises(ValueError):
            GenericSolver(
                diff_eqs=DIFF_EQS,
                conditions=CONDITIONS,
                n_input_units=1,
                n_output_units=1,
                **generator_kwargs
            )