def node_should_be_modified(self, node):
    """Checks if the ``ast.Call`` node instantiates a framework estimator or model
    without specifying the ``framework_version`` and ``py_version`` parameters,
    as appropriate.

    This looks for the following formats:

    - ``TensorFlow``
    - ``sagemaker.tensorflow.TensorFlow``

    where "TensorFlow" can be Chainer, MXNet, PyTorch, SKLearn, or TensorFlow.

    Args:
        node (ast.Call): a node that represents a function call. For more,
            see https://docs.python.org/3/library/ast.html#abstract-grammar.

    Returns:
        bool: If the ``ast.Call`` instantiates a framework estimator or model
            and ``_version_args_needed`` reports that the version arguments
            (``framework_version`` and/or ``py_version``) are missing.
            ``False`` for any call that is not a framework estimator or model.
    """
    # Only framework estimator/model constructors are candidates; deciding
    # which version arguments are required is delegated to
    # ``_version_args_needed``.
    if matching.matches_any(node, ESTIMATORS) or matching.matches_any(node, MODELS):
        return _version_args_needed(node)
    return False
def node_should_be_modified(self, node):
    """Determines whether an ``ast.Call`` node creates one of the legacy SerDe objects.

    Calls are recognized both bare and qualified with their module path.
    The legacy names are:

    - ``sagemaker.predictor._CsvSerializer``
    - ``sagemaker.predictor._JsonSerializer``
    - ``sagemaker.predictor._NpySerializer``
    - ``sagemaker.predictor._CsvDeserializer``
    - ``sagemaker.predictor.BytesDeserializer``
    - ``sagemaker.predictor.StringDeserializer``
    - ``sagemaker.predictor.StreamDeserializer``
    - ``sagemaker.predictor._NumpyDeserializer``
    - ``sagemaker.predictor._JsonDeserializer``
    - ``sagemaker.amazon.common.numpy_to_record_serializer``
    - ``sagemaker.amazon.common.record_deserializer``

    Args:
        node (ast.Call): a node that represents a function call. For more,
            see https://docs.python.org/3/library/ast.html#abstract-grammar.

    Returns:
        bool: Whether the ``ast.Call`` instantiates a legacy SerDe class.
    """
    # The name -> namespaces lookup (bare or qualified) is done by
    # ``matching.matches_any`` against ``OLD_CLASS_NAME_TO_NAMESPACES``.
    is_legacy_serde = matching.matches_any(node, OLD_CLASS_NAME_TO_NAMESPACES)
    return is_legacy_serde
def test_matches_any():
    """``matching.matches_any`` accepts bare and namespace-qualified names from the mapping."""
    namespaces_by_name = {
        "KMeansPredictor": ("sagemaker", "sagemaker.amazon.kmeans"),
        "Predictor": ("sagemaker.tensorflow.serving",),
    }

    # A bare class name, or the name under any of its listed namespaces, matches.
    expected_matches = [
        "KMeansPredictor()",
        "sagemaker.KMeansPredictor()",
        "sagemaker.amazon.kmeans.KMeansPredictor()",
        "Predictor()",
        "sagemaker.tensorflow.serving.Predictor()",
    ]
    for source in expected_matches:
        assert matching.matches_any(ast_call(source), namespaces_by_name)

    # Names absent from the mapping never match, qualified or not.
    expected_misses = ["MXNet()", "sagemaker.mxnet.MXNet()"]
    for source in expected_misses:
        assert not matching.matches_any(ast_call(source), namespaces_by_name)
def node_should_be_modified(self, node):
    """Determines whether the call targets a relevant function and passes the old parameter.

    Args:
        node (ast.Call): a node that represents a function call. For more,
            see https://docs.python.org/3/library/ast.html#abstract-grammar.

    Returns:
        bool: Whether the ``ast.Call`` is one of ``self.calls_to_modify`` and
            includes the argument named by ``self.old_param_name``.
    """
    is_target_call = matching.matches_any(node, self.calls_to_modify)
    # Short-circuit: the argument lookup only runs for calls of interest.
    return is_target_call and matching.has_arg(node, self.old_param_name)
def node_should_be_modified(self, node):
    """Determines whether the ``ast.Call`` node constructs one of the tracked predictor classes.

    Recognized call shapes:

    - ``sagemaker.<my>.<namespace>.<MyPredictor>``
    - ``sagemaker.<namespace>.<MyPredictor>``
    - ``<MyPredictor>``

    Args:
        node (ast.Call): a node that represents a function call. For more,
            see https://docs.python.org/3/library/ast.html#abstract-grammar.

    Returns:
        bool: Whether the ``ast.Call`` instantiates a class listed in ``PREDICTORS``.
    """
    is_predictor = matching.matches_any(node, PREDICTORS)
    return is_predictor
def node_should_be_modified(self, node):
    """Checks if the ``ast.Call`` node is an Airflow model-config call made with positional arguments.

    Recognized call shapes:

    - ``model_config``
    - ``airflow.model_config``
    - ``workflow.airflow.model_config``
    - ``sagemaker.workflow.airflow.model_config``

    where ``model_config`` may also be ``model_config_from_estimator``.

    Args:
        node (ast.Call): a node that represents a function call. For more,
            see https://docs.python.org/3/library/ast.html#abstract-grammar.

    Returns:
        bool: Whether the ``ast.Call`` is a ``model_config`` or
            ``model_config_from_estimator`` call with at least one positional
            argument.
    """
    # ``ast.Call.args`` holds only positional arguments (keywords are kept in
    # ``ast.Call.keywords``), so a non-empty list means positional usage.
    return matching.matches_any(node, FUNCTIONS) and bool(node.args)
def node_should_be_modified(self, node):
    """Checks if the node is an estimator constructor call that passes a training parameter.

    The parameters of interest are:

    - ``train_instance_count``
    - ``train_instance_type``
    - ``train_max_run``
    - ``train_max_run_wait``
    - ``train_use_spot_instances``
    - ``train_volume_kms_key``
    - ``train_volume_size``

    NOTE(review): v1 estimators name the spot-training wait limit
    ``train_max_wait``; confirm whether ``train_max_run_wait`` (here and in
    the list ``_has_train_parameter`` checks) is intended.

    Args:
        node (ast.Call): a node that represents a function call. For more,
            see https://docs.python.org/3/library/ast.html#abstract-grammar.

    Returns:
        bool: Whether the ``ast.Call`` constructs an estimator and
            ``_has_train_parameter`` finds one of the parameters above.
    """
    is_estimator = matching.matches_any(node, ESTIMATORS)
    # Only inspect the arguments once the call is known to be an estimator.
    return is_estimator and self._has_train_parameter(node)
def node_should_be_modified(self, node):
    """Determines whether the ``ast.Call`` node constructs a TensorFlow Serving class.

    Recognized calls:

    - ``sagemaker.tensorflow.serving.Model``
    - ``sagemaker.tensorflow.serving.Predictor``
    - ``Predictor``

    A bare ``Model`` is deliberately not accepted: it could equally be
    :class:`~sagemaker.model.Model`, so only the fully qualified
    ``sagemaker.tensorflow.serving.Model`` identifies a TFS Model.

    Args:
        node (ast.Call): a node that represents a function call. For more,
            see https://docs.python.org/3/library/ast.html#abstract-grammar.

    Returns:
        bool: Whether the ``ast.Call`` instantiates a TensorFlow Serving class.
    """
    callee = node.func
    if not isinstance(callee, ast.Name):
        # Qualified callees (and any other non-``Name`` callee) go to the
        # namespace-aware matcher.
        return matching.matches_any(node, TFS_CLASSES)
    # For an unqualified name only ``Predictor`` is unambiguous.
    return callee.id == "Predictor"