def test_node_should_be_modified_tf_constructor_script_mode():
    """Script-mode style TensorFlow constructors must not be flagged for upgrade.

    Every constructor form below (``script_mode=True``, a ``'py3'``/``'py37'``
    ``py_version``, or the listed ``py_version``/``script_mode`` combinations)
    is checked under all three spellings of the ``TensorFlow`` name, and
    ``node_should_be_modified`` must return ``False`` for each one.
    """
    tf_script_mode_constructors = (
        "TensorFlow(script_mode=True)",
        "TensorFlow(py_version='py3')",
        "TensorFlow(py_version='py37')",
        "TensorFlow(py_version='py3', script_mode=False)",
        "TensorFlow(py_version=py_version, script_mode=False)",
        "TensorFlow(py_version='py3', script_mode=script_mode)",
        "sagemaker.tensorflow.TensorFlow(script_mode=True)",
        "sagemaker.tensorflow.TensorFlow(py_version='py3')",
        "sagemaker.tensorflow.TensorFlow(py_version='py37')",
        "sagemaker.tensorflow.TensorFlow(py_version='py3', script_mode=False)",
        "sagemaker.tensorflow.TensorFlow(py_version=py_version, script_mode=False)",
        "sagemaker.tensorflow.TensorFlow(py_version='py3', script_mode=script_mode)",
        "sagemaker.tensorflow.estimator.TensorFlow(script_mode=True)",
        "sagemaker.tensorflow.estimator.TensorFlow(py_version='py3')",
        "sagemaker.tensorflow.estimator.TensorFlow(py_version='py37')",
        "sagemaker.tensorflow.estimator.TensorFlow(py_version='py3', script_mode=False)",
        "sagemaker.tensorflow.estimator.TensorFlow(py_version=py_version, script_mode=False)",
        "sagemaker.tensorflow.estimator.TensorFlow(py_version='py3', script_mode=script_mode)",
    )

    modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()

    for constructor in tf_script_mode_constructors:
        node = ast_call(constructor)
        # Many forms share this one loop; report the offending constructor
        # so a failure identifies which form was wrongly flagged.
        assert modifier.node_should_be_modified(node) is False, constructor
def test_modify_node_set_model_dir_and_image_name(retrieve_image_uri,
                                                  boto_session):
    """Legacy-mode constructors are rewritten to set ``image_uri`` and ``model_dir=False``.

    For each input (bare, ``script_mode=False``, or an explicit ``model_dir``)
    the dumped node must be exactly
    ``TensorFlow(image_uri=<IMAGE_URI>, model_dir=False)``. None of the
    inputs specify instance type, framework version or Python version, so the
    image lookup mock is expected to receive ``ml.m4.xlarge``, ``1.11.0`` and
    ``py2``.
    """
    # NOTE(review): boto_session presumably patches the boto3 session so the
    # upgrader sees REGION_NAME as its region; confirm against the fixtures.
    boto_session.return_value.region_name = REGION_NAME

    tf_constructors = (
        "TensorFlow()",
        "TensorFlow(script_mode=False)",
        "TensorFlow(model_dir='s3//bucket/model')",
    )
    modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
    # Same expected output for every input, so build it once.
    expected = "TensorFlow(image_uri='{}', model_dir=False)".format(IMAGE_URI)

    for constructor in tf_constructors:
        node = ast_call(constructor)
        modifier.modify_node(node)

        # Name the input so a failure says which constructor was mis-rewritten.
        assert expected == pasta.dump(node), constructor
        retrieve_image_uri.assert_called_with(
            "tensorflow",
            REGION_NAME,
            instance_type="ml.m4.xlarge",
            version="1.11.0",
            py_version="py2",
            image_scope="training",
        )
def test_modify_node_prefer_param_over_hyperparameter(retrieve_image_uri):
    """Explicit legacy parameters override same-keyed ``hyperparameters`` entries.

    The constructor passes ``training_steps=100`` and ``requirements_file``
    while its ``hyperparameters`` literal already holds ``training_steps: 10``
    and ``sagemaker_requirements: 'foo.txt'``. After the upgrade, the explicit
    parameter values must win.
    """
    constructor_src = """sagemaker.tensorflow.TensorFlow(
        training_steps=100,
        requirements_file='source/requirements.txt',
        hyperparameters={'training_steps': 10, 'sagemaker_requirements': 'foo.txt'},
    )"""

    call_node = ast_call(constructor_src)
    upgrader = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
    upgrader.modify_node(call_node)

    actual = _hyperparameters_from_node(call_node)
    assert actual == {
        "training_steps": 100,
        "sagemaker_requirements": "source/requirements.txt",
    }
def test_modify_node_set_hyperparameters(retrieve_image_uri):
    """Legacy-mode keyword arguments are collected into ``hyperparameters``.

    ``checkpoint_path``, ``training_steps`` and ``evaluation_steps`` keep their
    own names as keys. ``requirements_file`` is expected under the
    ``sagemaker_requirements`` key.
    """
    constructor_src = """TensorFlow(
        checkpoint_path='s3://foo/bar',
        training_steps=100,
        evaluation_steps=10,
        requirements_file='source/requirements.txt',
    )"""

    call_node = ast_call(constructor_src)
    upgrader = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
    upgrader.modify_node(call_node)

    actual = _hyperparameters_from_node(call_node)
    assert actual == {
        "training_steps": 100,
        "evaluation_steps": 10,
        "checkpoint_path": "s3://foo/bar",
        "sagemaker_requirements": "source/requirements.txt",
    }
def test_modify_node_preserve_other_hyperparameters(retrieve_image_uri):
    """Pre-existing ``hyperparameters`` entries survive the upgrade.

    The user-supplied ``optimizer``, ``lr`` and ``checkpoint_path`` entries
    must appear alongside the entries derived from the legacy keyword
    arguments (``training_steps``, ``evaluation_steps``, and
    ``requirements_file`` as ``sagemaker_requirements``).
    """
    constructor_src = """sagemaker.tensorflow.TensorFlow(
        training_steps=100,
        evaluation_steps=10,
        requirements_file='source/requirements.txt',
        hyperparameters={'optimizer': 'sgd', 'lr': 0.1, 'checkpoint_path': 's3://foo/bar'},
    )"""

    call_node = ast_call(constructor_src)
    upgrader = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
    upgrader.modify_node(call_node)

    user_supplied = {
        "optimizer": "sgd",
        "lr": 0.1,
        "checkpoint_path": "s3://foo/bar",
    }
    from_legacy_args = {
        "training_steps": 100,
        "evaluation_steps": 10,
        "sagemaker_requirements": "source/requirements.txt",
    }
    merged = dict(user_supplied)
    merged.update(from_legacy_args)

    assert _hyperparameters_from_node(call_node) == merged
def test_modify_node_set_image_name_from_args(retrieve_image_uri,
                                              boto_session):
    """The image lookup uses the instance type and framework version given.

    ``train_instance_type`` and ``framework_version`` from the constructor must
    be forwarded to the image lookup mock. The original keyword arguments must
    stay in place, with ``image_uri`` and ``model_dir=False`` appended.
    """
    # NOTE(review): boto_session presumably patches the boto3 session so the
    # upgrader sees REGION_NAME as its region; confirm against the fixtures.
    boto_session.return_value.region_name = REGION_NAME

    call_node = ast_call(
        "TensorFlow(train_instance_type='ml.p2.xlarge', framework_version='1.4.0')")
    tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader().modify_node(call_node)

    lookup_kwargs = dict(
        instance_type="ml.p2.xlarge",
        version="1.4.0",
        py_version="py2",
        image_scope="training",
    )
    retrieve_image_uri.assert_called_with("tensorflow", REGION_NAME, **lookup_kwargs)

    template = (
        "TensorFlow(train_instance_type='ml.p2.xlarge', framework_version='1.4.0', "
        "image_uri='{}', model_dir=False)")
    assert template.format(IMAGE_URI) == pasta.dump(call_node)
def test_node_should_be_modified_random_function_call():
    """A non-TensorFlow call (here ``MXNet``) is not flagged for upgrade."""
    mxnet_call = ast_call("MXNet(py_version='py3')")
    upgrader = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
    flagged = upgrader.node_should_be_modified(mxnet_call)
    assert flagged is False