def test_train_sqrt_mean_array():
    """Round-trip check for the "sqrt_mean" transformation.

    Train on ``data``, transform it, back-transform the result, and require
    the summed absolute reconstruction error to stay below 1e-7.
    """
    transformer = DataTransformer(transformations="sqrt_mean")
    transformer.train(data)
    forward = transformer.transform(data)
    recovered = transformer.back_transform(forward)
    # Summed (not per-element) error, so this bound is strict for large arrays.
    total_error = np.sum(np.abs(data - recovered))
    assert total_error < 1e-7
def test_train_identity_array():
    """The "identity" transformation must leave ``data`` unchanged both ways.

    Exact equality is used deliberately: an identity mapping is expected to
    reproduce every value bit-for-bit, in the forward and backward directions.
    """
    transformer = DataTransformer(transformations="identity")
    transformer.train(data)
    forward = transformer.transform(data)
    assert np.all(data == forward)
    recovered = transformer.back_transform(forward)
    assert np.all(data == recovered)
def test_train_log_mean_array():
    """Round-trip check for the "log_mean" transformation.

    Train on ``data``, transform it, and require that back-transforming the
    result recovers ``data``. The comparison uses a tight tolerance instead of
    exact equality. A round trip that (presumably) goes through log/exp
    arithmetic is not bit-exact in floating point: ``exp(log(x))`` can differ
    from ``x`` in the last bits, which would make an ``==`` check flaky.
    """
    transformer = DataTransformer(transformations="log_mean")
    transformer.train(data)
    transformed = transformer.transform(data)
    recovered = transformer.back_transform(transformed)
    assert np.allclose(data, recovered, rtol=1e-7, atol=1e-7)
def test_train_mean_array():
    """Check the "mean" transformation's output and its round trip.

    After training on ``data``:
    * every column (axis 0) of the transformed data must have a mean within
      1e-7 of 1.0;
    * back-transforming must recover ``data``. A tolerance is used rather than
      exact equality, because scaling by a mean and undoing it (x / m * m) is
      not guaranteed to be bit-exact in floating point.
    """
    transformer = DataTransformer(transformations="mean")
    transformer.train(data)
    transformed = transformer.transform(data)
    assert np.all(np.abs(np.mean(transformed, axis=0) - 1.0) < 1e-7)
    recovered = transformer.back_transform(transformed)
    assert np.allclose(data, recovered, rtol=1e-7, atol=1e-7)