コード例 #1
0
    def predict(self, deployment_name, df):
        """
        Predict on the specified deployment using the provided dataframe.

        Compute predictions on the pandas DataFrame ``df`` using the specified deployment.
        Note that the input/output types of this method matches that of `mlflow pyfunc predict`
        (we accept a pandas DataFrame as input and return either a pandas DataFrame,
        pandas Series, or numpy array as output).

        If the scoring call is rejected with HTTP 401 and the service has key or token auth
        enabled, a fresh credential is fetched and the request is retried once with the same
        payload.

        :param deployment_name: Name of deployment to predict against
        :param df: Pandas DataFrame to use for inference
        :return: A pandas DataFrame, pandas Series, or numpy array
        :raises MlflowException: If the deployment cannot be retrieved, has no scoring URI,
            or scoring does not end with an HTTP 200 response.
        """
        try:
            service = Webservice(self.workspace, deployment_name)
        except Exception as e:
            raise MlflowException(
                'Failure retrieving deployment to predict against') from e

        # Take in DF, parse to json using split orient
        input_data = _get_jsonable_obj(df, pandas_orient='split')

        if not service.scoring_uri:
            raise MlflowException(
                'Error attempting to call webservice, scoring_uri unavailable. '
                'This could be due to a failed deployment, or the service is not ready yet.\n'
                'Current State: {}\n'
                'Errors: {}'.format(service.state, service.error))

        # Serialize once so the first call and the auth retry send the same body
        # (the retry used to post the raw ``input_data`` object instead of this JSON).
        request_body = json.dumps({'input_data': input_data})

        # Pass split orient json to webservice
        # Take records orient json from webservice
        resp = ClientBase._execute_func(service._webservice_session.post,
                                        service.scoring_uri,
                                        data=request_body)

        if resp.status_code == 401:
            auth_header = None
            if service.auth_enabled:
                service_keys = service.get_keys()
                auth_header = {'Authorization': 'Bearer ' + service_keys[0]}
            elif service.token_auth_enabled:
                service_token, refresh_token_time = service.get_access_token()
                service._refresh_token_time = refresh_token_time
                auth_header = {'Authorization': 'Bearer ' + service_token}

            if auth_header is not None:
                # Cache the credential on ``service._session`` as before, and also send it
                # explicitly: the POST goes through ``_webservice_session``, which is not
                # guaranteed to be the session whose headers were just updated.
                service._session.headers.update(auth_header)
                resp = ClientBase._execute_func(service._webservice_session.post,
                                                service.scoring_uri,
                                                data=request_body,
                                                headers=auth_header)

        if resp.status_code == 200:
            # Parse records orient json to df
            return parse_json_input(json.dumps(resp.json()), orient='records')
        else:
            raise MlflowException('Failure during prediction:\n'
                                  'Response Code: {}\n'
                                  'Headers: {}\n'
                                  'Content: {}'.format(resp.status_code,
                                                       resp.headers,
                                                       resp.content))
コード例 #2
0
 def cancel(self, uri=None):
     """
     Mark the current run as canceled, then flush and upload tracked files.

     :param uri: Optional URL to POST the cancellation to (with the run's auth headers).
         When falsy, the run's own canceled event is posted instead.
     """
     if not uri:
         self.run.post_event_canceled()
     else:
         cancel_headers = self.run._service_context.get_auth().get_authentication_header()
         with create_session_with_retry() as http:
             ClientBase._execute_func(http.post, uri, headers=cancel_headers)
     self.flush()
     self.upload_tracked_files()
コード例 #3
0
    def get_credentials(self):
        """Fetch the credentials of this RemoteCompute target.

        POSTs to the target's MLC ``/listKeys`` endpoint and decodes the JSON response body.

        :return: The credentials for the RemoteCompute target.
        :rtype: dict
        :raises azureml.exceptions.ComputeTargetException: If MLC returns an HTTP error status.
        """
        list_keys_url = self._mlc_endpoint + '/listKeys'
        auth_headers = self._auth.get_authentication_header()
        ComputeTarget._add_request_tracking_headers(auth_headers)
        response = ClientBase._execute_func(get_requests_session().post,
                                            list_keys_url,
                                            params={'api-version': MLC_WORKSPACE_API_VERSION},
                                            headers=auth_headers)

        try:
            response.raise_for_status()
        except requests.exceptions.HTTPError:
            raise ComputeTargetException('Received bad response from MLC:\n'
                                         'Response Code: {}\n'
                                         'Headers: {}\n'
                                         'Content: {}'.format(
                                             response.status_code, response.headers,
                                             response.content))

        raw_body = response.content
        text_body = raw_body.decode('utf-8') if isinstance(raw_body, bytes) else raw_body
        return json.loads(text_body)
コード例 #4
0
    def _get(workspace, name):
        """Fetch the Resource Provider description of a compute target.

        :param workspace: Workspace that owns the compute target.
        :type workspace: azureml.core.Workspace
        :param name: Name of the compute target.
        :type name: str
        :return: The decoded JSON body on HTTP 200, or None on HTTP 404.
        :rtype: dict
        :raises azureml.exceptions.ComputeTargetException: For any other response status.
        """
        endpoint = ComputeTarget._get_rp_compute_endpoint(workspace, name)
        headers = workspace._auth.get_authentication_header()
        ComputeTarget._add_request_tracking_headers(headers)
        resp = ClientBase._execute_func(get_requests_session().get, endpoint,
                                        params={'api-version': MLC_WORKSPACE_API_VERSION},
                                        headers=headers)
        status = resp.status_code
        if status == 404:
            return None
        if status != 200:
            raise ComputeTargetException('Received bad response from Resource Provider:\n'
                                         'Response Code: {}\n'
                                         'Headers: {}\n'
                                         'Content: {}'.format(resp.status_code, resp.headers, resp.content))
        body = resp.content
        if isinstance(body, bytes):
            body = body.decode('utf-8')
        return json.loads(body)
コード例 #5
0
    def update(self, tags):
        """Update the image.

        Replaces the image's tags, locally and in Model Management Service, with ``tags``
        (JSON-patch ``replace`` on ``/kvTags``).

        :param tags: A dictionary of tags to update the image with. Will overwrite any existing tags.
        :type tags: dict[str, str]
        :raises: azureml.exceptions.WebserviceException
        """
        headers = {'Content-Type': 'application/json-patch+json'}
        headers.update(self._auth.get_authentication_header())
        params = {}

        # Keep a private copy: tag helpers such as add_tags/remove_tags mutate self.tags in
        # place, which would otherwise silently modify the caller's dictionary too.
        self.tags = dict(tags) if tags is not None else None
        patch_list = [{
            'op': 'replace',
            'path': '/kvTags',
            'value': self.tags
        }]

        resp = ClientBase._execute_func(get_requests_session().patch,
                                        self._mms_endpoint,
                                        headers=headers,
                                        params=params,
                                        json=patch_list,
                                        timeout=MMS_SYNC_TIMEOUT_SECONDS)

        if resp.status_code >= 400:
            raise WebserviceException(
                'Received bad response from Model Management Service:\n'
                'Response Code: {}\n'
                'Headers: {}\n'
                'Content: {}'.format(resp.status_code, resp.headers,
                                     resp.content),
                logger=module_logger)
コード例 #6
0
    def _delete_or_detach(self, underlying_resource_action):
        """Remove the Compute object from its associated workspace.

        If underlying_resource_action is 'delete', the corresponding cloud-based objects will also be deleted.
        If underlying_resource_action is 'detach', no underlying cloud object will be deleted, the association
        will just be removed.

        On success ``provisioning_state`` is set to ``'Deleting'`` and the URL from the
        ``Azure-AsyncOperation`` response header is stored in ``_operation_endpoint``.

        :param underlying_resource_action: whether delete or detach the underlying cloud object
        :type underlying_resource_action: str
        :raises azureml.exceptions.ComputeTargetException: If the service returns an HTTP error status,
            or a success response without an ``Azure-AsyncOperation`` header.
        """
        headers = self._auth.get_authentication_header()
        ComputeTarget._add_request_tracking_headers(headers)
        params = {'api-version': MLC_WORKSPACE_API_VERSION, 'underlyingResourceAction': underlying_resource_action}
        resp = ClientBase._execute_func(get_requests_session().delete, self._mlc_endpoint, params=params,
                                        headers=headers)

        try:
            resp.raise_for_status()
        except requests.exceptions.HTTPError as e:
            raise ComputeTargetException('Received bad response from Resource Provider:\n'
                                         'Response Code: {}\n'
                                         'Headers: {}\n'
                                         'Content: {}'.format(resp.status_code, resp.headers, resp.content)) from e

        self.provisioning_state = 'Deleting'
        # Same check as _attach: raise a descriptive ComputeTargetException instead of a bare
        # KeyError when the response carries no operation URL to poll.
        if 'Azure-AsyncOperation' not in resp.headers:
            raise ComputeTargetException('Error, missing operation location from resp headers:\n'
                                         'Response Code: {}\n'
                                         'Headers: {}\n'
                                         'Content: {}'.format(resp.status_code, resp.headers, resp.content))
        self._operation_endpoint = resp.headers['Azure-AsyncOperation']
コード例 #7
0
    def _get_team_resource(self, arm_scope):
        """GET an ARM resource and return its decoded JSON body.

        :param arm_scope: Resource scope, resolved against the ARM endpoint with ``urljoin``.
        :return: The parsed JSON response.
        :raises requests.exceptions.HTTPError: If ARM answers with an error status.
        """
        auth = self._auth_object
        arm_endpoint = auth._get_arm_end_point()
        auth_headers = auth.get_authentication_header()
        response = ClientBase._execute_func(requests.get,
                                            urljoin(arm_endpoint, arm_scope),
                                            headers=auth_headers,
                                            params={'api-version': VERSION})
        response.raise_for_status()
        return response.json()
コード例 #8
0
    def remove_tags(self, tags):
        """Remove tags from the image.

        Deletes each given key from the local tag dictionary, printing a notice for keys that
        are absent. The remaining tags are then sent to Model Management Service as a JSON-patch
        ``replace`` on ``/kvTags``. Does nothing (besides printing) if the image has no tags.

        :param tags: A list of keys corresponding to tags to be removed. A single key, or a
            tuple/set of keys, is also accepted.
        :type tags: builtin.list[str]
        :raises: azureml.exceptions.WebserviceException
        """
        headers = {'Content-Type': 'application/json-patch+json'}
        headers.update(self._auth.get_authentication_header())
        params = {}

        if self.tags is None:
            print('Image has no tags to remove.')
            return

        # Wrap only scalar keys. A tuple used to be treated as one (non-matching) key and a set
        # as one unhashable key (TypeError), instead of being iterated like a list.
        if not isinstance(tags, (list, tuple, set, frozenset)):
            tags = [tags]
        for key in tags:
            if key in self.tags:
                del self.tags[key]
            else:
                print('Tag with key {} not found.'.format(key))

        patch_list = [{
            'op': 'replace',
            'path': '/kvTags',
            'value': self.tags
        }]

        resp = ClientBase._execute_func(get_requests_session().patch,
                                        self._mms_endpoint,
                                        headers=headers,
                                        params=params,
                                        json=patch_list,
                                        timeout=MMS_SYNC_TIMEOUT_SECONDS)

        if resp.status_code >= 400:
            raise WebserviceException(
                'Received bad response from Model Management Service:\n'
                'Response Code: {}\n'
                'Headers: {}\n'
                'Content: {}'.format(resp.status_code, resp.headers,
                                     resp.content),
                logger=module_logger)

        print('Image tag remove operation complete.')
コード例 #9
0
 def diagnostics(self, uri):
     """
     Fetch the diagnostics for the working directory of the current run.

     :param uri: URL to issue the authenticated GET against.
     :return: The successful HTTP response.
     :raises WebserviceException: When the response carries an HTTP error status.
     """
     auth_headers = self.run._service_context.get_auth().get_authentication_header()
     with create_session_with_retry() as http:
         try:
             diag_resp = ClientBase._execute_func(http.get, uri, headers=auth_headers)
             diag_resp.raise_for_status()
         except requests.exceptions.HTTPError:
             raise WebserviceException(
                 'Received bad response from Execution Service:\n'
                 'Response Code: {}\n'
                 'Headers: {}\n'
                 'Content: {}'.format(diag_resp.status_code, diag_resp.headers,
                                      diag_resp.content),
                 logger=module_logger)
     return diag_resp
コード例 #10
0
    def add_tags(self, tags):
        """Add tags to the image.

        Merges ``tags`` into the local tag dictionary, printing a notice for every key whose
        value is replaced. The merged dictionary is then sent to Model Management Service as a
        JSON-patch ``replace`` on ``/kvTags``.

        :param tags: A dictionary of tags to add.
        :type tags: dict[str, str]
        :raises: azureml.exceptions.WebserviceException
        """
        request_headers = {'Content-Type': 'application/json-patch+json'}
        request_headers.update(self._auth.get_authentication_header())

        if self.tags is not None:
            for tag_name in tags:
                new_value = tags[tag_name]
                if tag_name in self.tags:
                    print("Replacing tag {} -> {} with {} -> {}".format(
                        tag_name, self.tags[tag_name], tag_name, new_value))
                self.tags[tag_name] = new_value
        else:
            self.tags = copy.deepcopy(tags)

        replace_op = {'op': 'replace', 'path': '/kvTags', 'value': self.tags}
        resp = ClientBase._execute_func(get_requests_session().patch,
                                        self._mms_endpoint,
                                        headers=request_headers,
                                        params={},
                                        json=[replace_op],
                                        timeout=MMS_SYNC_TIMEOUT_SECONDS)

        if resp.status_code >= 400:
            raise WebserviceException(
                'Received bad response from Model Management Service:\n'
                'Response Code: {}\n'
                'Headers: {}\n'
                'Content: {}'.format(resp.status_code, resp.headers,
                                     resp.content),
                logger=module_logger)

        print('Image tag add operation complete.')
コード例 #11
0
    def add_properties(self, properties):
        """Add properties to the image.

        First validates the new keys against the existing ones with
        ``check_duplicate_properties`` (presumably rejecting duplicates — confirm in that helper).
        Then merges them into the local properties and sends the result to Model Management
        Service as a JSON-patch ``add`` on ``/properties``.

        :param properties: A dictionary of properties to add.
        :type properties: dict[str, str]
        :raises: azureml.exceptions.WebserviceException
        """
        check_duplicate_properties(self.properties, properties)

        request_headers = {'Content-Type': 'application/json-patch+json'}
        request_headers.update(self._auth.get_authentication_header())

        if self.properties is not None:
            for prop_name in properties:
                self.properties[prop_name] = properties[prop_name]
        else:
            self.properties = copy.deepcopy(properties)

        add_op = {'op': 'add', 'path': '/properties', 'value': self.properties}
        resp = ClientBase._execute_func(get_requests_session().patch,
                                        self._mms_endpoint,
                                        headers=request_headers,
                                        params={},
                                        json=[add_op],
                                        timeout=MMS_SYNC_TIMEOUT_SECONDS)

        if resp.status_code >= 400:
            raise WebserviceException(
                'Received bad response from Model Management Service:\n'
                'Response Code: {}\n'
                'Headers: {}\n'
                'Content: {}'.format(resp.status_code, resp.headers,
                                     resp.content),
                logger=module_logger)

        print('Image properties add operation complete.')
コード例 #12
0
    def _get_operation_state(self):
        """Poll the async operation endpoint and report its status.

        :return: The operation's ``status`` value and its error payload (None when absent).
        :rtype: (str, dict)
        :raises azureml.exceptions.ComputeTargetException: If the endpoint returns an HTTP error status.
        """
        headers = self._auth.get_authentication_header()
        ComputeTarget._add_request_tracking_headers(headers)

        # Operation status URLs may already carry an api-version (ARM behaviour change); only
        # append ours when it is missing, for compatibility with older SDK-generated URLs.
        if 'api-version' in self._operation_endpoint:
            params = {}
        else:
            params = {'api-version': MLC_WORKSPACE_API_VERSION}

        resp = ClientBase._execute_func(get_requests_session().get, self._operation_endpoint, params=params,
                                        headers=headers)

        try:
            resp.raise_for_status()
        except requests.exceptions.HTTPError:
            raise ComputeTargetException('Received bad response from Resource Provider:\n'
                                         'Response Code: {}\n'
                                         'Headers: {}\n'
                                         'Content: {}'.format(resp.status_code, resp.headers, resp.content))

        body = resp.content
        if isinstance(body, bytes):
            body = body.decode('utf-8')
        payload = json.loads(body)
        status = payload['status']
        error = payload.get('error')

        # API versions before 2019-06-01 nested the error one level deeper ({'error': {'error': ...}});
        # unwrap it to stay compatible with 2018-11-19.
        nested_error = error.get('error') if error is not None else None
        if nested_error is not None:
            error = nested_error

        return status, error
コード例 #13
0
    def update_creation_state(self):
        """Refresh the current state of the in-memory object.

        Perform an in-place update of the properties of the object based on the current state of the
        corresponding cloud object. Primarily useful for manual polling of creation state.

        On HTTP 200 every attribute of a freshly loaded ``Image`` is copied onto this object,
        except ``_operation_endpoint``, which keeps its current value.

        :raises: azureml.exceptions.WebserviceException: If the image is not found (HTTP 404) or
            Model Management Service returns any other non-200 status.
        """
        headers = {'Content-Type': 'application/json'}
        headers.update(self._auth.get_authentication_header())
        params = {}

        resp = ClientBase._execute_func(get_requests_session().get,
                                        self._mms_endpoint,
                                        headers=headers,
                                        params=params,
                                        timeout=MMS_SYNC_TIMEOUT_SECONDS)

        if resp.status_code == 200:
            image = Image(self.workspace, id=self.id)
            for key, value in image.__dict__.items():
                # Compare by value: ``is not`` against a str literal tests object identity,
                # which is implementation-dependent (CPython 3.8+ emits a SyntaxWarning for it).
                if key != '_operation_endpoint':
                    self.__dict__[key] = value
        elif resp.status_code == 404:
            raise WebserviceException('Error: image {} not found:\n'
                                      'Response Code: {}\n'
                                      'Headers: {}\n'
                                      'Content: {}'.format(
                                          self.id, resp.status_code,
                                          resp.headers, resp.content),
                                      logger=module_logger)
        else:
            raise WebserviceException(
                'Received bad response from Model Management Service:\n'
                'Response Code: {}\n'
                'Headers: {}\n'
                'Content: {}'.format(resp.status_code, resp.headers,
                                     resp.content),
                logger=module_logger)
コード例 #14
0
    def _get_operation_state(self):
        """Get the current async operation state for the image.

        :return: The operation's ``state`` value and its ``error`` payload (None when absent).
        :rtype: (str, dict)
        :raises: azureml.exceptions.WebserviceException: If no operation endpoint is known, or the
            operation endpoint returns an HTTP error status.
        """
        if not self._operation_endpoint:
            # Refresh the image so the error below reports the latest creation_state. The Image
            # refresh method is update_creation_state; update_deployment_state is the
            # Webservice-style name and is presumably not defined on Image (verify).
            self.update_creation_state()
            raise WebserviceException(
                'Long running operation information not known, unable to poll. '
                'Current state is {}'.format(self.creation_state),
                logger=module_logger)

        headers = {'Content-Type': 'application/json'}
        headers.update(self._auth.get_authentication_header())
        params = {}

        resp = ClientBase._execute_func(get_requests_session().get,
                                        self._operation_endpoint,
                                        headers=headers,
                                        params=params,
                                        timeout=MMS_SYNC_TIMEOUT_SECONDS)
        try:
            resp.raise_for_status()
        except requests.exceptions.HTTPError as e:
            raise WebserviceException(
                'Received bad response from Resource Provider:\n'
                'Response Code: {}\n'
                'Headers: {}\n'
                'Content: {}'.format(resp.status_code, resp.headers,
                                     resp.content),
                logger=module_logger) from e
        content = resp.content
        if isinstance(content, bytes):
            content = content.decode('utf-8')
        content = json.loads(content)
        state = content['state']
        error = content.get('error')
        return state, error
コード例 #15
0
def get_paginated_compute_results(payload, headers):
    """Collect the ``value`` items of a paginated Machine Learning Compute response.

    Follows ``nextLink`` URLs until a page without one is reached, concatenating each
    page's ``value`` list.

    :param payload: Decoded JSON of the first page; must contain ``value``.
    :type payload: dict
    :param headers: Request headers sent with every follow-up page request.
    :type headers: dict
    :return: All items gathered across pages. If a follow-up request times out, a message
        is printed and the items gathered so far are returned.
    :rtype: list
    :raises ComputeTargetException: If a page lacks ``value`` or a follow-up request
        returns a non-200 status.
    """
    if 'value' not in payload:
        raise ComputeTargetException(
            'Error, invalid paginated response payload, missing "value":\n'
            '{}'.format(payload))
    # Copy so that appending later pages does not mutate the caller's first-page list.
    items = list(payload['value'])
    while 'nextLink' in payload:
        next_link = payload['nextLink']

        try:
            resp = ClientBase._execute_func(get_requests_session().get,
                                            next_link,
                                            headers=headers)
        except requests.Timeout:
            print(
                'Error, request to Machine Learning Compute timed out. Returning with items found so far'
            )
            return items
        if resp.status_code == 200:
            content = resp.content
            if isinstance(content, bytes):
                content = content.decode('utf-8')
            payload = json.loads(content)
        else:
            raise ComputeTargetException(
                'Received bad response from Machine Learning Compute while retrieving '
                'paginated results:\n'
                'Response Code: {}\n'
                'Headers: {}\n'
                'Content: {}'.format(resp.status_code, resp.headers,
                                     resp.content))
        if 'value' not in payload:
            raise ComputeTargetException(
                'Error, invalid paginated response payload, missing "value":\n'
                '{}'.format(payload))
        items.extend(payload['value'])

    return items
コード例 #16
0
    def _attach(workspace, name, attach_payload, target_class):
        """Attach a compute resource to ``workspace`` under ``name``.

        Sets ``attach_payload['location']`` to the workspace location (mutating the given dict),
        PUTs the payload to the compute endpoint, and returns a ``target_class`` instance whose
        ``_operation_endpoint`` is the ``Azure-AsyncOperation`` URL from the response.

        :param workspace: Workspace to attach the compute to.
        :type workspace: azureml.core.Workspace
        :param name: Name for the attached compute target.
        :type name: str
        :param attach_payload: Request body describing the resource to attach.
        :type attach_payload: dict
        :param target_class: Compute target class instantiated as ``target_class(workspace, name)``.
        :type target_class:
        :return: The newly created compute target object.
        :rtype:
        :raises azureml.exceptions.ComputeTargetException: On an HTTP error status or when the
            response lacks the ``Azure-AsyncOperation`` header.
        """
        attach_payload['location'] = workspace.location
        endpoint = ComputeTarget._get_compute_endpoint(workspace, name)
        headers = {'Content-Type': 'application/json'}
        headers.update(workspace._auth.get_authentication_header())
        ComputeTarget._add_request_tracking_headers(headers)
        resp = ClientBase._execute_func(get_requests_session().put, endpoint,
                                        params={'api-version': MLC_WORKSPACE_API_VERSION},
                                        headers=headers,
                                        json=attach_payload)

        def _response_details():
            # Shared tail of both error messages below.
            return ('Response Code: {}\n'
                    'Headers: {}\n'
                    'Content: {}'.format(resp.status_code, resp.headers, resp.content))

        try:
            resp.raise_for_status()
        except requests.exceptions.HTTPError:
            raise ComputeTargetException('Received bad response from Resource Provider:\n' + _response_details())
        if 'Azure-AsyncOperation' not in resp.headers:
            raise ComputeTargetException('Error, missing operation location from resp headers:\n' +
                                         _response_details())
        compute_target = target_class(workspace, name)
        compute_target._operation_endpoint = resp.headers['Azure-AsyncOperation']
        return compute_target
コード例 #17
0
    def delete(self):
        """Delete an image from its corresponding workspace.

        .. remarks::

            This method fails if the image has been deployed to a live webservice.

        On success ``creation_state`` is set to ``'Deleting'``.

        :raises: azureml.exceptions.WebserviceException
        """
        headers = self._auth.get_authentication_header()
        params = {}

        resp = ClientBase._execute_func(get_requests_session().delete,
                                        self._mms_endpoint,
                                        headers=headers,
                                        params=params,
                                        timeout=MMS_SYNC_TIMEOUT_SECONDS)

        if resp.status_code >= 400:
            # ``resp.content`` is bytes on Python 3; testing a str against it raised TypeError,
            # so decode before looking for the "image still in use" marker.
            body = resp.content
            if isinstance(body, bytes):
                body = body.decode('utf-8', errors='replace')
            if resp.status_code == 412 and "DeletionRequired" in body:
                raise WebserviceException(
                    'The image cannot be deleted because it is currently being used in one or '
                    'more webservices. To know what webservices contain the image, run '
                    '"Webservice.list(<workspace>, image_id={})"'.format(
                        self.id),
                    logger=module_logger)

            raise WebserviceException(
                'Received bad response from Model Management Service:\n'
                'Response Code: {}\n'
                'Headers: {}\n'
                'Content: {}'.format(resp.status_code, resp.headers,
                                     resp.content),
                logger=module_logger)
        else:
            self.creation_state = 'Deleting'
コード例 #18
0
    def create(workspace, name, models, image_config):
        """Create an image in the provided workspace.

        Deprecated: emits a ``DeprecationWarning`` recommending Environments. The creation
        request is POSTed to Model Management Service; the returned operation is then fetched
        once to discover the new image's id.

        :param workspace: The workspace to associate with this image.
        :type workspace: azureml.core.workspace.Workspace
        :param name: The name to associate with this image.
        :type name: str
        :param models: A list of Model objects to package with this image. Can be an empty list.
        :type models: builtin.list[azureml.core.Model]
        :param image_config: The image config object to use to configure this image.
        :type image_config: azureml.core.image.image.ImageConfig
        :return: The created Image object.
        :rtype: azureml.core.Image
        :raises: azureml.exceptions.WebserviceException
        """
        warnings.warn(
            "Image class has been deprecated and will be removed in a future release. "
            "Please migrate to using Environments. "
            "https://docs.microsoft.com/en-us/azure/machine-learning/how-to-use-environments",
            category=DeprecationWarning,
            stacklevel=2)

        image_name_validation(name)
        model_ids = Model._resolve_to_model_ids(workspace, models, name)

        headers = {'Content-Type': 'application/json'}
        headers.update(workspace._auth.get_authentication_header())
        params = {}
        base_endpoint = _get_mms_url(workspace)
        image_url = base_endpoint + '/images'

        json_payload = image_config.build_create_payload(
            workspace, name, model_ids)

        print('Creating image')
        resp = ClientBase._execute_func(get_requests_session().post,
                                        image_url,
                                        params=params,
                                        headers=headers,
                                        json=json_payload)
        try:
            resp.raise_for_status()
        except requests.exceptions.HTTPError as e:
            raise WebserviceException(
                'Received bad response from Model Management Service:\n'
                'Response Code: {}\n'
                'Headers: {}\n'
                'Content: {}'.format(resp.status_code, resp.headers,
                                     resp.content),
                logger=module_logger) from e
        # raise_for_status only raises for 400-599; this catches any non-standard status above that.
        if resp.status_code >= 400:
            raise WebserviceException('Error occurred creating image:\n'
                                      'Response Code: {}\n'
                                      'Headers: {}\n'
                                      'Content: {}'.format(
                                          resp.status_code, resp.headers,
                                          resp.content),
                                      logger=module_logger)

        if 'Operation-Location' in resp.headers:
            operation_location = resp.headers['Operation-Location']
        else:
            raise WebserviceException(
                'Missing response header key: Operation-Location',
                logger=module_logger)

        create_operation_status_id = operation_location.split('/')[-1]
        operation_url = base_endpoint + '/operations/{}'.format(
            create_operation_status_id)
        operation_headers = workspace._auth.get_authentication_header()

        # The timeout is raised by the request itself, not by raise_for_status(), so the
        # Timeout handler must wrap the call (previously it could never trigger).
        try:
            operation_resp = ClientBase._execute_func(
                get_requests_session().get,
                operation_url,
                params=params,
                headers=operation_headers,
                timeout=MMS_SYNC_TIMEOUT_SECONDS)
        except requests.Timeout as e:
            raise WebserviceException(
                'Error, request to {} timed out.'.format(operation_url),
                logger=module_logger) from e
        try:
            operation_resp.raise_for_status()
        except requests.exceptions.HTTPError as e:
            raise WebserviceException(
                'Received bad response from Model Management Service:\n'
                'Response Code: {}\n'
                'Headers: {}\n'
                'Content: {}'.format(operation_resp.status_code,
                                     operation_resp.headers,
                                     operation_resp.content),
                logger=module_logger) from e

        content = operation_resp.content
        if isinstance(content, bytes):
            content = content.decode('utf-8')
        operation_content = json.loads(content)
        if 'resourceLocation' in operation_content:
            image_id = operation_content['resourceLocation'].split('/')[-1]
        else:
            raise WebserviceException(
                'Invalid operation payload, missing resourceLocation:\n'
                '{}'.format(operation_content),
                logger=module_logger)

        image = Image(workspace, id=image_id)
        image._operation_endpoint = operation_url
        return image
コード例 #19
0
def _start_internal(project_object,
                    run_config_object,
                    prepare_only=False,
                    prepare_check=False,
                    custom_target_dict=None,
                    run_id=None,
                    injected_files=None,
                    telemetry_values=None,
                    parent_run_id=None):
    """Submit a run, or a prepare request, for ``project_object`` to the execution service.

    When ``run_config_object.history.snapshot_project`` is set, a snapshot of the
    project directory is created on a one-worker pool while the request is built,
    and the request references that snapshot. If the environment sets
    ``AML_FORCE_EXECUTE_WITH_ZIP`` (and not ``AML_FORCE_EXECUTE_WITH_SNAPSHOT``),
    the project is instead zipped and uploaded together with the request.

    :param project_object: Project object
    :type project_object: azureml.core.project.Project
    :param run_config_object: The run configuration object.
    :type run_config_object: azureml.core.runconfig.RunConfiguration
    :param prepare_only: If True, request environment preparation instead of a run.
    :param prepare_check: Used together with ``prepare_only``: ask the service whether
        the environment is already prepared and return that flag.
    :param custom_target_dict: Sent as ``TargetDetails`` in the run definition.
    :param run_id: Sent to the service as the ``runId`` query parameter.
    :param injected_files: Extra files added to the project zip (zip upload path only).
    :type injected_files: dict
    :param telemetry_values: Sent as ``TelemetryValues`` in the run definition.
    :param parent_run_id: If not None, sent as ``ParentRunId`` in the run definition.
    :return: azureml.core.script_run.ScriptRun or bool if prepare_check=True
    """
    # urllib3 >= 2.0 turned ``urllib3.request`` into a function, so the old
    # ``urllib3.request.urlencode`` re-export is gone; use the stdlib original.
    from urllib.parse import urlencode

    service_context = project_object.workspace.service_context
    snapshots_client = SnapshotsClient(service_context)

    service_address = service_context._get_execution_url()
    service_arm_scope = "{}/experiments/{}".format(
        service_context._get_workspace_scope(), project_object.history.name)

    if run_config_object.credential_passthrough:
        # Credential passthrough: authenticate with an AzureML client token
        # instead of the workspace's regular authentication header.
        aml_client_token = project_object.workspace._auth_object._get_azureml_client_token()
        auth_header = {"Authorization": "Bearer " + aml_client_token}
    else:
        auth_header = project_object.workspace._auth_object.get_authentication_header()

    thread_pool = Pool(1)
    temporary = None
    archive = None
    try:
        ignore_file = get_project_ignore_file(project_object.project_directory)

        snapshot_async = None
        execute_with_zip = False

        directory_size = get_directory_size(
            project_object.project_directory,
            _max_zip_size_bytes,
            exclude_function=ignore_file.is_file_excluded)

        if directory_size >= _max_zip_size_bytes:
            give_warning("Submitting {} directory for run. "
                         "The size of the directory >= {} MB, "
                         "so it can take a few minutes.".format(
                             project_object.project_directory, _num_max_mbs))

        if run_config_object.history.snapshot_project:
            snapshot_async = thread_pool.apply_async(
                snapshots_client.create_snapshot,
                (project_object.project_directory, ))

            # These can be set by users in case we have any issues with zip/snapshot and need to force a specific path
            force_execute_snapshot = os.environ.get("AML_FORCE_EXECUTE_WITH_SNAPSHOT")
            force_execute_zip = os.environ.get("AML_FORCE_EXECUTE_WITH_ZIP")

            if force_execute_zip and not force_execute_snapshot:
                execute_with_zip = True
                module_logger.debug("Executing with zip file.")
            else:
                module_logger.debug("Executing with snapshot.")

        _run_config_modification(run_config_object)

        if execute_with_zip:
            temporary = temp_dir_back.TemporaryDirectory()
            archive_path = os.path.join(temporary.name, "project.zip")
            _make_zipfile_exclude(project_object, archive_path,
                                  ignore_file.is_file_excluded)

            # Inject files into the user's project
            if injected_files:
                _add_files_to_zip(archive_path, injected_files)
            archive = open(archive_path, "rb")

        headers = _get_common_headers()
        # Merging the auth header.
        headers.update(auth_header)

        # Endpoint name is "<prefix><action>": prefix "snapshot" for the snapshot
        # path, "start" for a zip run and empty for a zip prepare.
        if execute_with_zip:
            api_prefix = "" if prepare_only else "start"
        else:
            api_prefix = "snapshot"

        if not prepare_only:
            action = "run"
        elif prepare_check:
            action = "checkprepare"
        else:
            action = "prepare"

        uri = (service_address + "/execution/v1.0" + service_arm_scope
               + "/" + api_prefix + action
               + "?" + urlencode({"runId": run_id}))

        snapshot_id = snapshot_async.get() if snapshot_async else None

        definition = {
            "TargetDetails": custom_target_dict,
            "Configuration": _serialize_run_config_to_dict(run_config_object),
            "TelemetryValues": telemetry_values
        }
        if parent_run_id is not None:
            definition["ParentRunId"] = parent_run_id

        if execute_with_zip:
            definition_json = json.dumps(definition, indent=4, sort_keys=True)
            if prepare_only:
                files = [("files", ("definition.json", definition_json)),
                         ("files", ("project.zip", archive))]
            else:
                files = [("runDefinitionFile", ("definition.json", definition_json)),
                         ("projectZipFile", ("project.zip", archive))]

            response = ClientBase._execute_func(requests.post,
                                                uri,
                                                files=files,
                                                headers=headers)
        else:
            definition["SnapshotId"] = snapshot_id

            response = ClientBase._execute_func(requests.post,
                                                uri,
                                                json=definition,
                                                headers=headers)

        _raise_request_error(response, "starting run")

    finally:
        # Release the worker pool and the zip resources even when submission fails.
        thread_pool.close()
        if archive:
            archive.close()
        if temporary:
            temporary.cleanup()

    result = response.json()
    if prepare_only and prepare_check:
        return result["environmentPrepared"]

    return _get_run_details(project_object,
                            run_config_object,
                            result["runId"],
                            snapshot_id=snapshot_id)
コード例 #20
0
def _start_internal_local_cloud(project_object,
                                run_config_object,
                                prepare_only=False,
                                custom_target_dict=None,
                                run_id=None,
                                injected_files=None,
                                telemetry_values=None,
                                parent_run_id=None):
    """Run the project locally using an invocation bundle produced by the execution service.

    Only the files selected by ``_include`` are zipped and sent to the service
    (``/localprepare`` or ``/localrun``). The project directory, minus ignored
    files, is extracted into a per-run temporary directory. The service response
    is saved there as ``invocation.zip``, extracted, and executed with
    ``_invoke_command``.

    :param project_object: Project object
    :type project_object: azureml.core.project.Project
    :param run_config_object: The run configuration object.
    :type run_config_object: azureml.core.runconfig.RunConfiguration
    :param prepare_only: If True, call ``/localprepare`` instead of ``/localrun``.
    :param custom_target_dict: Sent as ``TargetDetails`` in the run definition.
    :param run_id: Run ID; names the temp directory and is sent as the ``runId`` query parameter.
    :param injected_files: Extra files added to the zip sent to the service.
    :type injected_files: dict
    :param telemetry_values: Sent as ``TelemetryValues`` in the run definition.
    :param parent_run_id: If not None, sent as ``ParentRunId`` in the run definition.
    :return: azureml.core.script_run.ScriptRun
    :raises ExperimentExecutionException: If the included files exceed the size limit,
        or if the local command fails.
    """
    # urllib3 >= 2.0 turned ``urllib3.request`` into a function, so the old
    # ``urllib3.request.urlencode`` re-export is gone; use the stdlib original.
    from urllib.parse import urlencode

    service_context = project_object.workspace.service_context
    snapshots_client = SnapshotsClient(service_context)

    service_address = service_context._get_execution_url()
    #  TODO move this into project or experiment to avoid code dup
    service_arm_scope = "{}/experiments/{}".format(
        service_context._get_workspace_scope(), project_object.history.name)
    auth_header = project_object.workspace._auth_object.get_authentication_header()

    # Check size of config folder before starting any background work, so an
    # oversize project does not leave an orphaned snapshot upload behind.
    if get_directory_size(project_object.project_directory,
                          _max_zip_size_bytes,
                          include_function=_include) > _max_zip_size_bytes:
        error_message = "====================================================================\n" \
                        "\n" \
                        "Your configuration directory exceeds the limit of {0} MB.\n" \
                        "Please see http://aka.ms/aml-largefiles on how to work with large files.\n" \
                        "\n" \
                        "====================================================================\n" \
                        "\n".format(_max_zip_size_bytes / _one_mb)
        raise ExperimentExecutionException(error_message)

    thread_pool = Pool(1)
    try:
        snapshot_async = None
        if run_config_object.history.snapshot_project:
            snapshot_async = thread_pool.apply_async(
                snapshots_client.create_snapshot,
                (project_object.project_directory, ))

        _run_config_modification(run_config_object)

        with temp_dir_back.TemporaryDirectory() as temporary:
            archive_path = os.path.join(temporary, "aml_config.zip")
            archive_path_local = os.path.join(temporary, "temp_project.zip")

            project_temp_dir = _get_project_temporary_directory(run_id)
            os.mkdir(project_temp_dir)

            # We send only aml_config zip to service and copy only necessary files to temp dir
            ignore_file = get_project_ignore_file(project_object.project_directory)
            _make_zipfile_include(project_object, archive_path, _include)
            _make_zipfile_exclude(project_object, archive_path_local,
                                  ignore_file.is_file_excluded)

            # Inject files into the user's project
            if injected_files:
                _add_files_to_zip(archive_path, injected_files)

            # Copy current project dir to temp/azureml-runs folder.
            with zipfile.ZipFile(archive_path_local, "r") as zip_ref:
                zip_ref.extractall(project_temp_dir)

            # TODO Missing driver arguments, job_name

            definition = {
                "TargetDetails": custom_target_dict,
                "Configuration": _serialize_run_config_to_dict(run_config_object),
                "TelemetryValues": telemetry_values
            }
            if parent_run_id is not None:
                definition["ParentRunId"] = parent_run_id

            headers = _get_common_headers()
            # Merging the auth header.
            headers.update(auth_header)

            uri = service_address + "/execution/v1.0" + service_arm_scope
            uri += "/localprepare" if prepare_only else "/localrun"
            # Unfortunately, requests library does not take Queryparams nicely.
            # Appending run_id_query to the url for service to extract from it.
            uri += "?" + urlencode({"runId": run_id})

            # Keep the config archive open only for the duration of the upload.
            with open(archive_path, "rb") as archive:
                files = [("files", ("definition.json",
                                    json.dumps(definition,
                                               indent=4,
                                               sort_keys=True))),
                         ("files", ("aml_config.zip", archive))]
                response = ClientBase._execute_func(requests.post,
                                                    uri,
                                                    files=files,
                                                    headers=headers)
            _raise_request_error(response, "starting run")

            invocation_zip_path = os.path.join(project_temp_dir, "invocation.zip")
            with open(invocation_zip_path, "wb") as invocation_file:
                invocation_file.write(response.content)

            with zipfile.ZipFile(invocation_zip_path, "r") as zip_ref:
                zip_ref.extractall(project_temp_dir)

            try:
                _invoke_command(project_temp_dir)
            except subprocess.CalledProcessError as ex:
                raise ExperimentExecutionException(ex.output) from ex

            snapshot_id = snapshot_async.get() if snapshot_async else None

            return _get_run_details(project_object,
                                    run_config_object,
                                    run_id,
                                    snapshot_id=snapshot_id)
    finally:
        # Release the worker pool on every exit path, including failures.
        thread_pool.close()
コード例 #21
0
    def list(workspace,
             image_name=None,
             model_name=None,
             model_id=None,
             tags=None,
             properties=None):
        """List the Images associated with the corresponding workspace. Can be filtered with specific parameters.

        :param workspace: The Workspace object to list the Images in.
        :type workspace: azureml.core.workspace.Workspace
        :param image_name: Filter list to only include Images deployed with the specific image name.
        :type image_name: str
        :param model_name: Filter list to only include Images deployed with the specific model name.
        :type model_name: str
        :param model_id: Filter list to only include Images deployed with the specific model ID.
        :type model_id: str
        :param tags: Will filter based on the provided list, by either 'key' or '[key, value]'.
            Ex. ['key', ['key2', 'key2 value']]
        :type tags: builtin.list
        :param properties: Will filter based on the provided list, by either 'key' or '[key, value]'.
            Ex. ['key', ['key2', 'key2 value']]
        :type properties: builtin.list
        :return: A filtered list of Images in the provided workspace.
        :rtype: builtin.list[azureml.core.Image]
        :raises: azureml.exceptions.WebserviceException
        """
        warnings.warn(
            "Image class has been deprecated and will be removed in a future release. "
            + "Please migrate to using Environments. " +
            "https://docs.microsoft.com/en-us/azure/machine-learning/how-to-use-environments",
            category=DeprecationWarning,
            stacklevel=2)

        def _to_filter_query(entries):
            # Build the comma-separated filter string: a plain 'key' entry is passed
            # through, a [key, value] (or (key, value)) pair becomes 'key=value'.
            return ','.join(
                entry[0] + '=' + entry[1] if isinstance(entry, (list, tuple)) else entry
                for entry in entries)

        headers = workspace._auth.get_authentication_header()
        params = {'expand': 'true'}
        base_url = _get_mms_url(workspace)
        mms_url = base_url + '/images'

        if image_name:
            params['name'] = image_name
        if model_name:
            params['modelName'] = model_name
        if model_id:
            params['modelId'] = model_id
        if tags:
            params['tags'] = _to_filter_query(tags)
        if properties:
            params['properties'] = _to_filter_query(properties)

        try:
            resp = ClientBase._execute_func(get_requests_session().get,
                                            mms_url,
                                            headers=headers,
                                            params=params,
                                            timeout=MMS_SYNC_TIMEOUT_SECONDS)
            resp.raise_for_status()
        except requests.Timeout as e:
            raise WebserviceException(
                'Error, request to Model Management Service timed out to URL:\n'
                '{}'.format(mms_url),
                logger=module_logger) from e
        except requests.exceptions.HTTPError as e:
            raise WebserviceException(
                'Received bad response from Model Management Service:\n'
                'Response Code: {}\n'
                'Headers: {}\n'
                'Content: {}'.format(resp.status_code, resp.headers,
                                     resp.content),
                logger=module_logger) from e

        content = resp.content
        if isinstance(content, bytes):
            content = content.decode('utf-8')
        image_payload = json.loads(content)
        # Follow any further result pages and deserialize every image found.
        paginated_results = get_paginated_results(image_payload, headers)

        return [
            Image.deserialize(workspace, image_dict)
            for image_dict in paginated_results
        ]
コード例 #22
0
    def list(workspace):
        """List all ComputeTarget objects within the workspace.

        Every compute resource reported for the workspace is wrapped in the first
        child class of :class:`azureml.core.ComputeTarget` whose ``_compute_type``
        matches the resource's ``computeType``. Resources with no matching child
        class are left out. Windows computes are also left out unless
        ``azureml.contrib.compute`` can be imported.

        :param workspace: The workspace object containing the objects to list.
        :type workspace: azureml.core.Workspace
        :return: List of compute targets within the workspace.
        :rtype: builtin.list[azureml.core.ComputeTarget]
        :raises azureml.exceptions.ComputeTargetException:
        """
        endpoint = ComputeTarget._get_rp_list_computes_endpoint(workspace)
        headers = workspace._auth.get_authentication_header()
        ComputeTarget._add_request_tracking_headers(headers)
        response = ClientBase._execute_func(get_requests_session().get,
                                            endpoint,
                                            params={'api-version': MLC_WORKSPACE_API_VERSION},
                                            headers=headers)

        try:
            response.raise_for_status()
        except requests.exceptions.HTTPError:
            raise ComputeTargetException('Error occurred retrieving targets:\n'
                                         'Response Code: {}\n'
                                         'Headers: {}\n'
                                         'Content: {}'.format(response.status_code, response.headers,
                                                              response.content))

        # Windows is currently supported only for RL runs; the windows contrib is
        # installed as part of the RL SDK. Without it, windows computes are hidden
        # so they are not used by mistake for a non-RL run.
        try:
            from azureml.contrib.compute import AmlWindowsCompute  # noqa: F401
        except ImportError:
            windows_contrib_available = False
        else:
            windows_contrib_available = True

        body = response.content
        if isinstance(body, bytes):
            body = body.decode('utf-8')
        compute_payloads = get_paginated_compute_results(json.loads(body), headers)

        targets = []
        for payload in compute_payloads:
            if 'properties' not in payload or 'computeType' not in payload['properties']:
                continue
            props = payload['properties']
            compute_type = props['computeType']
            is_attached = props['isAttachedCompute']
            target = None
            for child in ComputeTarget.__subclasses__():
                child_name = child.__name__
                if compute_type == 'VirtualMachine':
                    if is_attached and child_name == 'DsvmCompute':
                        continue  # Cannot attach DsvmCompute
                    if not is_attached and child_name == 'RemoteCompute':
                        continue  # Cannot create RemoteCompute
                if compute_type == 'Kubernetes' and not is_attached and child_name == 'KubernetesCompute':
                    continue  # Cannot create KubernetesCompute
                if compute_type != child._compute_type:
                    continue
                # Evaluated left to right, so osType is only inspected when the
                # windows contrib is missing.
                hide_windows = (not windows_contrib_available and
                                "properties" in props and
                                props['properties'] is not None and
                                "osType" in props['properties'] and
                                props['properties']['osType'].lower() == 'windows')
                if not hide_windows:
                    target = child.deserialize(workspace, payload)
                break
            if target:
                targets.append(target)
        return targets
コード例 #23
0
def get_paginated_compute_supported_vms(payload, headers):
    """Collect supported VM sizes from a (possibly paginated) Machine Learning Compute response.

    Every page must contain an ``amlCompute`` list. Each entry is validated and
    reduced to the keys ``name``, ``vCPUs``, ``gpus``, ``memoryGB`` and
    ``maxResourceVolumeMB``. Further pages are fetched by following ``nextLink``
    until a page has no ``nextLink``, or one that is null or empty.

    :param payload: The first, already-decoded response page.
    :type payload: dict
    :param headers: Headers sent with every ``nextLink`` request.
    :type headers: dict
    :return: One dict per VM size, restricted to the required keys. If a
        ``nextLink`` request times out, the items collected so far are returned.
    :rtype: builtin.list[dict]
    :raises ComputeTargetException: On a malformed page or a non-200 response.
    """
    required_keys = [
        'name', 'vCPUs', 'gpus', 'memoryGB', 'maxResourceVolumeMB'
    ]

    def _extract(page):
        # Validate one page and project each of its entries onto required_keys.
        if 'amlCompute' not in page:
            raise ComputeTargetException(
                'Error, invalid paginated response payload, missing "amlCompute":\n'
                '{}'.format(page))
        extracted = []
        for vm in page['amlCompute']:
            for key in required_keys:
                if key not in vm:
                    raise ComputeTargetException(
                        'Error, invalid paginated response payload, missing "{}":\n'
                        '{}'.format(key, page))
            extracted.append({key: vm[key] for key in required_keys})
        return extracted

    items = _extract(payload)

    # The last page may omit nextLink or set it to null/empty; stop in all those cases.
    next_link = payload.get('nextLink')
    while next_link:
        try:
            resp = ClientBase._execute_func(get_requests_session().get,
                                            next_link,
                                            headers=headers)
        except requests.Timeout:
            print(
                'Error, request to Machine Learning Compute timed out. Returning with items found so far'
            )
            return items
        if resp.status_code != 200:
            raise ComputeTargetException(
                'Received bad response from Machine Learning Compute while retrieving '
                'paginated results:\n'
                'Response Code: {}\n'
                'Headers: {}\n'
                'Content: {}'.format(resp.status_code, resp.headers,
                                     resp.content))

        content = resp.content
        if isinstance(content, bytes):
            content = content.decode('utf-8')
        payload = json.loads(content)
        items.extend(_extract(payload))
        next_link = payload.get('nextLink')

    return items
コード例 #24
0
    def _get(workspace,
             name=None,
             id=None,
             tags=None,
             properties=None,
             version=None):
        """Get the image with the given filtering criteria.

        With ``id``, the image is fetched directly from ``/images/<id>``.
        Otherwise, images named ``name`` are listed newest first
        (``orderBy=CreatedAtDesc``) and the first result is returned.

        :param workspace: The workspace to look the image up in.
        :type workspace: azureml.core.workspace.Workspace
        :param name: Image name; used only when ``id`` is not given.
        :type name: str
        :param id: Image ID.
        :type id: str
        :param tags: Filter by tag; each entry is either 'key' or a [key, value] pair.
            Ex. ['key', ['key2', 'key2 value']]
        :type tags: builtin.list
        :param properties: Filter by property, in the same format as ``tags``.
        :type properties: builtin.list
        :param version: Image version filter.
        :type version: str
        :return: azureml.core.Image payload dictionary, or None when no image matches
            (including a 404 response).
        :rtype: dict
        :raises: azureml.exceptions.WebserviceException
        """
        if not name and not id:
            raise WebserviceException(
                'Error, one of id or name must be provided.',
                logger=module_logger)

        def _to_filter_query(entries):
            # Build the comma-separated filter string: a plain 'key' entry is passed
            # through, a [key, value] (or (key, value)) pair becomes 'key=value'.
            return ','.join(
                entry[0] + '=' + entry[1] if isinstance(entry, (list, tuple)) else entry
                for entry in entries)

        headers = workspace._auth.get_authentication_header()
        params = {'orderBy': 'CreatedAtDesc', 'count': 1, 'expand': 'true'}
        base_endpoint = _get_mms_url(workspace)
        mms_endpoint = base_endpoint + '/images'

        if id:
            image_url = mms_endpoint + '/{}'.format(id)
        else:
            image_url = mms_endpoint
            params['name'] = name
        if tags:
            params['tags'] = _to_filter_query(tags)
        if properties:
            params['properties'] = _to_filter_query(properties)
        if version:
            params['version'] = version

        resp = ClientBase._execute_func(get_requests_session().get,
                                        image_url,
                                        headers=headers,
                                        params=params,
                                        timeout=MMS_SYNC_TIMEOUT_SECONDS)

        if resp.status_code == 404:
            return None
        if resp.status_code != 200:
            raise WebserviceException(
                'Received bad response from Model Management Service:\n'
                'Response Code: {}\n'
                'Headers: {}\n'
                'Content: {}'.format(resp.status_code, resp.headers,
                                     resp.content),
                logger=module_logger)

        content = resp.content
        if isinstance(content, bytes):
            content = content.decode('utf-8')
        image_payload = json.loads(content)
        if id:
            # Direct lookup returns the single image payload.
            return image_payload
        paginated_results = get_paginated_results(image_payload, headers)
        return paginated_results[0] if paginated_results else None
コード例 #25
0
 def discover_services_uris(self, discovery_url=None):
     """Fetch the service discovery document and return its decoded JSON body.

     :param discovery_url: URL of the discovery endpoint.
         NOTE(review): defaults to None, which is passed unchanged to
         ``requests.get``; callers presumably always supply a URL. Confirm.
     :return: The parsed JSON response.
     :raises requests.exceptions.HTTPError: If the response has an error status.
     """
     response = ClientBase._execute_func(requests.get, discovery_url)
     response.raise_for_status()
     discovered = response.json()
     return discovered