def test_reshaping_of_training_data():
    """A 1-D ``x`` must come out of ImplicitTrainingData as a 2-D array."""
    flat_x = np.zeros(5)
    column_dx_dt = np.zeros((5, 1))
    # Building from 1-D input presumably emits a warning; silence it so the
    # test only checks the resulting dimensionality.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        data = ImplicitTrainingData(flat_x, column_dx_dt)
    assert data.x.ndim == 2
def test_correct_partial_calculation_in_training_data_2():
    """Derivative of two slope-2 segments separated by a NaN row.

    ``x`` is a (41, 1) column: a linear ramp of slope 2, a single NaN row,
    then the same ramp again.  The computed ``dx_dt`` is expected to be a
    (26, 1) column of 2.0 (the NaN row and some edge samples of each segment
    are presumably dropped by the derivative calculation).
    """
    ramp = 2.0 * np.arange(20, dtype=float).reshape((20, 1))
    nan_separator = [np.nan]
    x_with_gap = np.vstack((ramp, nan_separator, ramp))
    data = ImplicitTrainingData(x_with_gap)
    np.testing.assert_array_almost_equal(data.dx_dt, np.full((26, 1), 2.0))
def test_getting_subset_of_training_data():
    """Indexing with a list of row indices subsets both ``x`` and ``dx_dt``."""
    column = np.arange(5).reshape((5, 1))
    full_data = ImplicitTrainingData(column, column)
    picked_rows = [0, 2, 3]
    subset = full_data[picked_rows]

    wanted = np.array([[0], [2], [3]])
    np.testing.assert_array_equal(subset.x, wanted)
    np.testing.assert_array_equal(subset.dx_dt, wanted)
def test_correct_partial_calculation_in_training_data():
    """Each column of a (20, 3) linear input yields its own constant slope.

    Columns have slopes 0, 1 and 2; the computed ``dx_dt`` is expected to be
    a (13, 3) array whose columns are those constants.
    """
    base = np.arange(20, dtype=float).reshape((20, 1))
    slopes = (0, 1, 2)
    x = np.hstack([base * slope for slope in slopes])
    data = ImplicitTrainingData(x)
    wanted = np.column_stack([np.ones(13) * slope for slope in slopes])
    np.testing.assert_array_almost_equal(data.dx_dt, wanted)
def test_correct_training_data_length(input_size):
    """``len()`` equals the number of rows the data was built from.

    ``input_size`` is supplied by a fixture/parametrization defined outside
    this view.
    """
    samples = np.arange(input_size).reshape((-1, 1))
    assert len(ImplicitTrainingData(samples, samples)) == input_size
def test_poorly_shaped_input_dx_dt_of_training_data():
    """A 1-D ``dx_dt`` paired with a 2-D ``x`` is rejected with ValueError."""
    column_x = np.zeros((5, 1))
    flat_dx_dt = np.zeros(5)
    with pytest.raises(ValueError):
        ImplicitTrainingData(column_x, flat_dx_dt)