def test_estimator_modify_node():
    """Check that ``modify_node`` rewrites ``image_name`` to ``image_uri``.

    A ``TensorFlow(image_name=my_image)`` call is parsed, passed through
    ``EstimatorImageURIRenamer.modify_node``, and the dumped source is
    compared against the same call using the ``image_uri`` keyword.
    """
    call_node = ast_call("TensorFlow(image_name=my_image)")

    # modify_node mutates the node in place; its return value is not used.
    renamed_params.EstimatorImageURIRenamer().modify_node(call_node)

    assert "TensorFlow(image_uri=my_image)" == pasta.dump(call_node)
def test_estimator_node_should_be_modified_no_distribution():
    """Check that estimator calls with no arguments are not flagged.

    For every entry in ``ESTIMATORS`` the bare call (``Name()``) and each
    namespace-qualified call (``namespace.Name()``) must be rejected by
    ``node_should_be_modified``.
    """
    # NOTE(review): the "no_distribution" suffix does not match what is
    # checked here (a call without ``image_name``) -- possibly carried over
    # from another test module; confirm before renaming.
    renamer = renamed_params.EstimatorImageURIRenamer()

    for estimator, namespaces in ESTIMATORS.items():
        # Bare call first, then one qualified call per namespace.
        sources = ["{}()".format(estimator)]
        sources.extend("{}.{}()".format(prefix, estimator) for prefix in namespaces)

        for source in sources:
            assert not renamer.node_should_be_modified(ast_call(source))
def test_estimator_node_should_be_modified():
    """Check that estimator calls passing ``image_name`` are flagged.

    For every entry in ``ESTIMATORS`` the bare call and each
    namespace-qualified call, all given ``image_name='my-image:latest'``,
    must be accepted by ``node_should_be_modified``.
    """
    renamer = renamed_params.EstimatorImageURIRenamer()

    for estimator, namespaces in ESTIMATORS.items():
        # Bare call first, then one qualified call per namespace.
        sources = ["{}(image_name='my-image:latest')".format(estimator)]
        sources.extend(
            "{}.{}(image_name='my-image:latest')".format(prefix, estimator)
            for prefix in namespaces
        )

        for source in sources:
            assert renamer.node_should_be_modified(ast_call(source))
def test_estimator_node_should_be_modified_random_function_call():
    """Check that a ``Session()`` call is not flagged for modification."""
    renamer = renamed_params.EstimatorImageURIRenamer()
    unrelated_call = ast_call("Session()")

    assert not renamer.node_should_be_modified(unrelated_call)