def node_should_be_modified(self, node):
    """Checks if the ``ast.Call`` node instantiates a framework estimator or model
    without specifying the ``framework_version`` and ``py_version`` parameters,
    as appropriate.

    Calls in either of these forms are considered:

    - ``TensorFlow``
    - ``sagemaker.tensorflow.TensorFlow``

    where "TensorFlow" can be Chainer, MXNet, PyTorch, SKLearn, or TensorFlow.

    Args:
        node (ast.Call): a node that represents a function call. For more,
            see https://docs.python.org/3/library/ast.html#abstract-grammar.

    Returns:
        bool: If the ``ast.Call`` is instantiating a framework class that
            should specify ``framework_version`` (and ``py_version``, as
            appropriate), but doesn't.
    """
    is_framework_call = matching.matches_any(node, ESTIMATORS) or matching.matches_any(
        node, MODELS
    )
    if not is_framework_call:
        # Only estimator and model constructors are candidates for this check.
        return False

    # Deciding which version arguments are missing is delegated to the helper.
    return _version_args_needed(node)
Esempio n. 2
0
    def node_should_be_modified(self, node):
        """Checks if the ``ast.Call`` node instantiates a SerDe class.

        Calls are matched both with and without their namespace prefix. The
        classes of interest are:

        - ``sagemaker.predictor._CsvSerializer``
        - ``sagemaker.predictor._JsonSerializer``
        - ``sagemaker.predictor._NpySerializer``
        - ``sagemaker.predictor._CsvDeserializer``
        - ``sagemaker.predictor.BytesDeserializer``
        - ``sagemaker.predictor.StringDeserializer``
        - ``sagemaker.predictor.StreamDeserializer``
        - ``sagemaker.predictor._NumpyDeserializer``
        - ``sagemaker.predictor._JsonDeserializer``
        - ``sagemaker.amazon.common.numpy_to_record_serializer``
        - ``sagemaker.amazon.common.record_deserializer``

        Args:
            node (ast.Call): a node that represents a function call. For more,
                see https://docs.python.org/3/library/ast.html#abstract-grammar.

        Returns:
            bool: If the ``ast.Call`` instantiates one of the SerDe classes above.
        """
        # NOTE(review): OLD_CLASS_NAME_TO_NAMESPACES is presumably the
        # class-name -> accepted-namespaces mapping for the list above; confirm.
        instantiates_serde = matching.matches_any(node, OLD_CLASS_NAME_TO_NAMESPACES)
        return instantiates_serde
def test_matches_any():
    """``matching.matches_any`` accepts bare and namespaced calls and rejects others."""
    name_to_namespaces = {
        "KMeansPredictor": ("sagemaker", "sagemaker.amazon.kmeans"),
        "Predictor": ("sagemaker.tensorflow.serving",),
    }

    # Each entry pairs a call expression with whether it is expected to match.
    cases = [
        ("KMeansPredictor()", True),
        ("sagemaker.KMeansPredictor()", True),
        ("sagemaker.amazon.kmeans.KMeansPredictor()", True),
        ("Predictor()", True),
        ("sagemaker.tensorflow.serving.Predictor()", True),
        ("MXNet()", False),
        ("sagemaker.mxnet.MXNet()", False),
    ]

    for source, should_match in cases:
        matched = matching.matches_any(ast_call(source), name_to_namespaces)
        # Compare truthiness, exactly as a bare ``assert`` / ``assert not`` would.
        assert bool(matched) is should_match
    def node_should_be_modified(self, node):
        """Checks if the node is a relevant function call that passes the old parameter.

        Args:
            node (ast.Call): a node that represents a function call. For more,
                see https://docs.python.org/3/library/ast.html#abstract-grammar.

        Returns:
            bool: If the ``ast.Call`` matches one of ``self.calls_to_modify`` and
                ``matching.has_arg`` finds ``self.old_param_name`` on it.
        """
        call_matches = matching.matches_any(node, self.calls_to_modify)
        if not call_matches:
            # Return the falsy match result itself, as the short-circuit ``and`` did.
            return call_matches
        return matching.has_arg(node, self.old_param_name)
    def node_should_be_modified(self, node):
        """Checks if the ``ast.Call`` node instantiates a predictor class of interest.

        Any of these call shapes is recognized:

        - ``sagemaker.<my>.<namespace>.<MyPredictor>``
        - ``sagemaker.<namespace>.<MyPredictor>``
        - ``<MyPredictor>``

        Args:
            node (ast.Call): a node that represents a function call. For more,
                see https://docs.python.org/3/library/ast.html#abstract-grammar.

        Returns:
            bool: If the ``ast.Call`` instantiates a class of interest.
        """
        instantiates_predictor = matching.matches_any(node, PREDICTORS)
        return instantiates_predictor
Esempio n. 6
0
    def node_should_be_modified(self, node):
        """Checks if the node is an Airflow model-config call with positional arguments.

        The ``ast.Call`` node is matched against the following forms:

        - ``model_config``
        - ``airflow.model_config``
        - ``workflow.airflow.model_config``
        - ``sagemaker.workflow.airflow.model_config``

        where ``model_config`` is either ``model_config`` or ``model_config_from_estimator``.

        Args:
            node (ast.Call): a node that represents a function call. For more,
                see https://docs.python.org/3/library/ast.html#abstract-grammar.

        Returns:
            bool: If the ``ast.Call`` is either a ``model_config`` call or
                a ``model_config_from_estimator`` call and has positional arguments.
        """
        is_model_config_call = matching.matches_any(node, FUNCTIONS)
        if not is_model_config_call:
            # Return the falsy match result itself, as the short-circuit ``and`` did.
            return is_model_config_call
        # ``ast.Call.args`` holds only positional arguments; keywords are separate.
        return len(node.args) != 0
    def node_should_be_modified(self, node):
        """Checks if the node is an estimator constructor that passes a ``train_`` parameter.

        The parameters of interest are:

        - ``train_instance_count``
        - ``train_instance_type``
        - ``train_max_run``
        - ``train_max_wait``
        - ``train_use_spot_instances``
        - ``train_volume_kms_key``
        - ``train_volume_size``

        Args:
            node (ast.Call): a node that represents a function call. For more,
                see https://docs.python.org/3/library/ast.html#abstract-grammar.

        Returns:
            bool: If the ``ast.Call`` matches an estimator constructor and
                ``self._has_train_parameter`` reports a parameter to be renamed.
        """
        is_estimator_call = matching.matches_any(node, ESTIMATORS)
        if not is_estimator_call:
            # Return the falsy match result itself, as the short-circuit ``and`` did.
            return is_estimator_call
        return self._has_train_parameter(node)
Esempio n. 8
0
    def node_should_be_modified(self, node):
        """Checks if the ``ast.Call`` node instantiates a TensorFlow Serving class.

        The recognized calls are:

        - ``sagemaker.tensorflow.serving.Model``
        - ``sagemaker.tensorflow.serving.Predictor``
        - ``Predictor``

        A bare ``Model`` is deliberately not recognized. It could be either
        ``sagemaker.tensorflow.serving.Model`` or :class:`~sagemaker.model.Model`,
        so on its own it does not indicate a TFS Model object.

        Args:
            node (ast.Call): a node that represents a function call. For more,
                see https://docs.python.org/3/library/ast.html#abstract-grammar.

        Returns:
            bool: If the ``ast.Call`` instantiates a TensorFlow Serving class.
        """
        callee = node.func
        if not isinstance(callee, ast.Name):
            # Attribute-style calls (``a.b.C()``) are resolved through the namespace table.
            return matching.matches_any(node, TFS_CLASSES)

        # Among bare names, only ``Predictor`` is unambiguous.
        return callee.id == "Predictor"