Esempio n. 1
0
def test_modify_node():
    """Check that ``modify_node`` renames ``distributions`` to ``distribution``.

    The keyword's value (the parameter-server dict) must be preserved verbatim.
    """
    renamer = renamed_params.DistributionParameterRenamer()
    call_node = ast_call("TensorFlow(distributions={'parameter_server': {'enabled': True}})")

    # modify_node rewrites the AST node in place; dump it back to source text.
    renamer.modify_node(call_node)

    assert pasta.dump(call_node) == "TensorFlow(distribution={'parameter_server': {'enabled': True}})"
Esempio n. 2
0
def test_node_should_be_modified_no_distribution():
    """Estimator constructors without a ``distributions`` kwarg are not flagged.

    Covers the bare, package-qualified and module-qualified spellings of the
    TensorFlow and MXNet constructors.
    """
    constructors = (
        "TensorFlow()",
        "sagemaker.tensorflow.TensorFlow()",
        "sagemaker.tensorflow.estimator.TensorFlow()",
        "MXNet()",
        "sagemaker.mxnet.MXNet()",
        "sagemaker.mxnet.estimator.MXNet()",
    )

    modifier = renamed_params.DistributionParameterRenamer()

    for call in constructors:
        # Include the call text so a failure points at the offending spelling.
        assert not modifier.node_should_be_modified(ast_call(call)), call
Esempio n. 3
0
def test_node_should_be_modified_random_function_call():
    """A call to something other than an estimator constructor is not flagged."""
    renamer = renamed_params.DistributionParameterRenamer()
    session_call = ast_call("Session()")
    assert not renamer.node_should_be_modified(session_call)