def test_modify_node():
    """``modify_node`` rewrites the ``distributions`` keyword to ``distribution``.

    The dict literal passed as the argument value must survive the rename
    unchanged; only the keyword name differs between input and output.
    """
    call_node = ast_call(
        "TensorFlow(distributions={'parameter_server': {'enabled': True}})"
    )
    renamer = renamed_params.DistributionParameterRenamer()

    # Mutates the parsed node in place; the result is read back via pasta.dump.
    renamer.modify_node(call_node)

    rewritten_source = pasta.dump(call_node)
    assert rewritten_source == (
        "TensorFlow(distribution={'parameter_server': {'enabled': True}})"
    )
def test_node_should_be_modified_no_distribution():
    """Estimator constructor calls without arguments are not rename targets.

    Covers ``TensorFlow`` and ``MXNet`` referenced by bare name, by package
    path and by full ``estimator`` module path.
    """
    candidate_sources = (
        "TensorFlow()",
        "sagemaker.tensorflow.TensorFlow()",
        "sagemaker.tensorflow.estimator.TensorFlow()",
        "MXNet()",
        "sagemaker.mxnet.MXNet()",
        "sagemaker.mxnet.estimator.MXNet()",
    )
    renamer = renamed_params.DistributionParameterRenamer()

    for source in candidate_sources:
        parsed_call = ast_call(source)
        assert not renamer.node_should_be_modified(parsed_call)
def test_node_should_be_modified_random_function_call():
    """An unrelated call such as ``Session()`` is not flagged for modification."""
    renamer = renamed_params.DistributionParameterRenamer()
    session_call = ast_call("Session()")

    assert not renamer.node_should_be_modified(session_call)