def test_node_should_be_modified():
    """Each estimator call built with an entry of PARAMS_WITH_VALUES must be flagged.

    Every (estimator, parameter-string) pairing is rendered as call source,
    parsed with ``ast_call``, and checked against the remover's predicate.
    """
    remover = training_params.TrainPrefixRemover()
    for estimator_name in _estimators():
        for param_src in PARAMS_WITH_VALUES:
            call_node = ast_call("{}({})".format(estimator_name, param_src))
            assert remover.node_should_be_modified(call_node)
def test_modify_node():
    """modify_node must rewrite an ``Estimator(...)`` call so it dumps without "train_".

    For each argument combination, the expected source is the original call
    text with every occurrence of "train_" deleted.
    """
    remover = training_params.TrainPrefixRemover()
    for arg_src in _parameter_combinations():
        call_src = "Estimator({})".format(arg_src)
        call_node = ast_call(call_src)
        remover.modify_node(call_node)
        assert call_src.replace("train_", "") == pasta.dump(call_node)
def test_node_should_be_modified_random_function_call():
    """An unrelated call such as ``Session()`` must not be flagged."""
    remover = training_params.TrainPrefixRemover()
    session_call = ast_call("Session()")
    assert not remover.node_should_be_modified(session_call)
def test_node_should_be_modified_no_params():
    """An estimator call with an empty argument list must not be flagged."""
    remover = training_params.TrainPrefixRemover()
    for estimator_name in _estimators():
        bare_call = ast_call("{}()".format(estimator_name))
        assert not remover.node_should_be_modified(bare_call)