def process(self, element, model_dir):
        """Run batch prediction over *element*, yielding (input, prediction) pairs.

        Known prediction failures and unexpected exceptions are routed to the
        "errors" tagged output instead of failing the bundle; a failure to load
        the model is re-raised as a permanent exception to end the job.
        """
        try:
            # A deferred (runtime-provided) model location must be resolved first.
            if isinstance(model_dir, ValueProvider):
                model_dir = model_dir.get()

            if self._model_state is not None:
                # A model is already attached to this DoFn; it must match.
                assert self._model_state.model_dir == model_dir
            else:
                cached = getattr(self._thread_local, "model_state", None)
                if cached is None or cached.model_dir != model_dir:
                    # (Re)load the model on this thread, recording how long it took.
                    load_started = datetime.datetime.now()
                    self._thread_local.model_state = self._ModelState(
                        model_dir, self._tag_list, self._signature_name,
                        self._skip_preprocessing)
                    load_elapsed = datetime.datetime.now() - load_started
                    self._model_load_seconds_distribution.update(
                        int(load_elapsed.total_seconds()))
                self._model_state = self._thread_local.model_state

            # Single-string-input models take raw records; everything else is JSON.
            if self._model_state.model.is_single_string_input():
                records = element
            else:
                records = [json.loads(raw) for raw in element]
            instances = mlprediction.decode_base64(records)
            inputs, predictions = self._model_state.model.predict(instances)
            predictions = mlprediction.encode_base64(
                list(predictions),
                self._model_state.model.signature.outputs)

            if self._aggregator_dict:
                aggr = self._aggregator_dict.get(
                    aggregators.AggregatorName.ML_PREDICTIONS, None)
                if aggr:
                    aggr.inc(len(predictions))

            for model_input, prediction in zip(inputs, predictions):
                yield model_input, prediction

        except mlprediction.PredictionError as e:
            logging.error("Got a known exception: [%s]\n%s", str(e),
                          traceback.format_exc())
            clean_error_detail = error_filter.filter_tensorflow_error(
                e.error_detail)
            if self._cloud_logger:
                # TODO(user): consider to write a sink to buffer the logging events. It
                # also eliminates the restarting/duplicated running issue.
                self._cloud_logger.write_error_message(
                    clean_error_detail, self._create_snippet(element))
            # reraise failure to load model as permanent exception to end dataflow job
            if e.error_code == mlprediction.PredictionError.FAILED_TO_LOAD_MODEL:
                raise beam.utils.retry.PermanentException(clean_error_detail)
            # Older Beam releases expose SideOutputValue instead of TaggedOutput.
            try:
                yield beam.pvalue.TaggedOutput("errors",
                                               (clean_error_detail, element))
            except AttributeError:
                yield beam.pvalue.SideOutputValue(
                    "errors", (clean_error_detail, element))

        except Exception as e:  # pylint: disable=broad-except
            logging.error("Got an unknown exception: [%s].",
                          traceback.format_exc())
            if self._cloud_logger:
                self._cloud_logger.write_error_message(
                    str(e), self._create_snippet(element))
            # Same TaggedOutput/SideOutputValue compatibility fallback as above.
            try:
                yield beam.pvalue.TaggedOutput("errors", (str(e), element))
            except AttributeError:
                yield beam.pvalue.SideOutputValue("errors", (str(e), element))
# --- Example 2 ---
    def process(self, element, model_dir):
        """Predict over a batch *element*, yielding (input, prediction) pairs.

        Supports multiple frameworks (TensorFlow gets signature-based single
        string input detection), records model-load and batch-processing
        timings, counts successful/failed predictions, and routes failures to
        the "errors" tagged output. A failure to load the model is re-raised
        as a permanent exception to end the job.
        """
        try:
            # A deferred (runtime-provided) model location must be resolved first.
            if isinstance(model_dir, ValueProvider):
                model_dir = model_dir.get()
            framework = self._framework.get()

            if self._model_state is not None:
                # A model is already attached to this DoFn; it must match.
                assert self._model_state.model_dir == model_dir
            else:
                cached = getattr(self._thread_local, "model_state", None)
                if cached is None or cached.model_dir != model_dir:
                    # (Re)load the model on this thread, recording how long it took.
                    load_started = datetime.datetime.now()
                    self._thread_local.model_state = self._ModelState(
                        model_dir, self._tag_list, framework)
                    load_elapsed = datetime.datetime.now() - load_started
                    self._model_load_seconds_distribution.update(
                        int(load_elapsed.total_seconds()))
                self._model_state = self._thread_local.model_state

            # Measure the processing time.
            start = datetime.datetime.now()
            if framework == mlprediction.TENSORFLOW_FRAMEWORK_NAME:
                # Even though predict() checks the signature in TensorFlowModel,
                # we need to duplicate this check here to determine the single
                # string input case.
                self._signature_name, signature = (
                    self._model_state.model.get_signature(self._signature_name))
                single_string_input = (
                    self._model_state.model.is_single_string_input(signature))
            else:
                single_string_input = False
            if single_string_input:
                records = element
            else:
                records = [json.loads(raw) for raw in element]
            instances = mlprediction.decode_base64(records)

            # Actual prediction occurs.
            kwargs = (
                {"signature_name": self._signature_name}
                if self._signature_name else {})
            inputs, predictions = self._model_state.model.predict(
                instances, **kwargs)
            predictions = list(predictions)

            if self._aggregator_dict:
                self._aggregator_dict[
                    aggregators.AggregatorName.ML_PREDICTIONS].inc(
                        len(predictions))

            # For successful processing, record the elapsed time in milliseconds.
            td = datetime.datetime.now() - start
            elapsed_ms = int(td.microseconds / 10**3 +
                             (td.seconds + td.days * 24 * 3600) * 10**3)
            self._batch_process_ms_distribution.update(elapsed_ms)

            for model_input, prediction in zip(inputs, predictions):
                yield model_input, prediction

        except mlprediction.PredictionError as e:
            logging.error("Got a known exception: [%s]\n%s", str(e),
                          traceback.format_exc())
            clean_error_detail = error_filter.filter_tensorflow_error(
                e.error_detail)
            if self._cloud_logger:
                # TODO(user): consider to write a sink to buffer the logging events. It
                # also eliminates the restarting/duplicated running issue.
                self._cloud_logger.write_error_message(
                    clean_error_detail, self._create_snippet(element))
            # Track the failed batch in the counter.
            if self._aggregator_dict:
                counter_name = aggregators.AggregatorName.ML_FAILED_PREDICTIONS
                self._aggregator_dict[counter_name].inc(len(element))

            # reraise failure to load model as permanent exception to end dataflow job
            if e.error_code == mlprediction.PredictionError.FAILED_TO_LOAD_MODEL:
                raise beam.utils.retry.PermanentException(clean_error_detail)
            # Older Beam releases expose SideOutputValue instead of TaggedOutput.
            try:
                yield beam.pvalue.TaggedOutput("errors",
                                               (clean_error_detail, element))
            except AttributeError:
                yield beam.pvalue.SideOutputValue(
                    "errors", (clean_error_detail, element))

        except Exception as e:  # pylint: disable=broad-except
            logging.error("Got an unknown exception: [%s].",
                          traceback.format_exc())
            if self._cloud_logger:
                self._cloud_logger.write_error_message(
                    str(e), self._create_snippet(element))
            # Track the failed batch in the counter.
            if self._aggregator_dict:
                counter_name = aggregators.AggregatorName.ML_FAILED_PREDICTIONS
                self._aggregator_dict[counter_name].inc(len(element))

            # Same TaggedOutput/SideOutputValue compatibility fallback as above.
            try:
                yield beam.pvalue.TaggedOutput("errors", (str(e), element))
            except AttributeError:
                yield beam.pvalue.SideOutputValue("errors", (str(e), element))