Example #1
    def from_config(cls, config: Dict):
        """
        Convert from Data Step config to ZenML Datasource object.

        The data step is also populated, and its configuration is set to the
        parameters defined in the config file.

        Args:
            config: a DataStep config in dict-form (probably loaded from YAML).
        """
        if keys.DataSteps.DATA not in config[keys.PipelineKeys.STEPS]:
            raise Exception("Cant have datasource without data step.")

        # this is the data step config block
        step_config = config[keys.PipelineKeys.STEPS][keys.DataSteps.DATA]
        source = config[keys.PipelineKeys.DATASOURCE][
            keys.DatasourceKeys.SOURCE]
        datasource_class = source_utils.load_source_path_class(source)
        datasource_name = config[keys.PipelineKeys.DATASOURCE][
            keys.DatasourceKeys.NAME]
        _id = config[keys.PipelineKeys.DATASOURCE][keys.DatasourceKeys.ID]
        obj = datasource_class(
            name=datasource_name, _id=_id, _source=source,
            **step_config[keys.StepKeys.ARGS])
        obj._immutable = True
        return obj
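A minimal sketch of the config dict this `from_config` expects, based only on the keys accessed above; the literal key strings ('steps', 'data', 'datasource', 'source', 'name', 'id', 'args') are assumptions standing in for the constants in the `keys` module:

config = {
    'datasource': {
        'source': 'my_package.datasources.CSVDatasource',  # hypothetical path
        'name': 'my_datasource',
        'id': '6f7c...',  # a UUID string (assumed)
    },
    'steps': {
        'data': {
            'source': 'my_package.steps.CSVDataStep',  # hypothetical path
            'args': {'path': '/tmp/data.csv'},  # passed to the constructor
        },
    },
}
obj = BaseDatasource.from_config(config)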
Example #2
    def from_config(config_block: Dict):
        """
        Takes a config block that represents a Step and converts it back into
        its Python equivalent. This functionality is similar for most steps,
        and the expected config_block may look like

        {
            'source': this.module.StepClass@sha  # where sha is optional
            'args': {}  # to be passed to the constructor
        }

        Args:
            config_block: config block representing source and args of step.
        """
        # resolve source path
        if StepKeys.SOURCE in config_block:
            source = config_block[StepKeys.SOURCE]
            class_ = load_source_path_class(source)
            args = config_block[StepKeys.ARGS]
            obj = class_(**args)

            # If we load from config, it's immutable
            obj._immutable = True
            obj._source = source
            return obj
        else:
            raise AssertionError("Cannot create a step from a config_block "
                                 "without a 'source' key.")
Example #3
    def Do(self, input_dict: Dict[Text, List[types.Artifact]],
           output_dict: Dict[Text, List[types.Artifact]],
           exec_properties: Dict[Text, Any]) -> None:
        """
        Args:
            input_dict:
            output_dict:
            exec_properties:
        """
        source = exec_properties[StepKeys.SOURCE]
        args = exec_properties[StepKeys.ARGS]

        c = source_utils.load_source_path_class(source)
        data_step: BaseDataStep = c(**args)

        # Get output split path
        examples_artifact = artifact_utils.get_single_instance(
            output_dict[DATA_SPLIT_NAME])
        split_names = [DATA_SPLIT_NAME]
        examples_artifact.split_names = artifact_utils.encode_split_names(
            split_names)
        output_split_path = artifact_utils.get_split_uri([examples_artifact],
                                                         DATA_SPLIT_NAME)

        with self._make_beam_pipeline() as p:
            (p
             | data_step.read_from_source()
             # | data_step.convert_to_dict()
             | WriteToTFRecord(data_step.schema, output_split_path))
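The executor only requires that `read_from_source()` return a Beam PTransform (so it composes with `p | ...`) and that the step expose a `schema`. A hypothetical data step meeting that contract; the class, constructor arguments, and CSV reading are illustrative assumptions:

import apache_beam as beam

class CSVDataStep:  # would subclass BaseDataStep in ZenML (assumption)
    def __init__(self, path: str, schema=None):
        self.path = path
        self.schema = schema  # consumed by WriteToTFRecord above

    def read_from_source(self):
        # Returns a PTransform, as the pipeline above expects.
        return beam.io.ReadFromText(self.path, skip_header_lines=1)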
Example #4
    def from_config(config_block: Dict):
        """
        Takes a config block that represents a Step and converts it back into
        its Python equivalent. This functionality is similar for most steps,
        and the expected config_block may look like

        {
            'source': this.module.StepClass@sha  # where sha is optional
            'args': {}  # to be passed to the constructor
        }

        Args:
            config_block: config block representing source and args of step.
        """
        # resolve source path
        if StepKeys.SOURCE in config_block:
            source = config_block[StepKeys.SOURCE]
            class_ = load_source_path_class(source)
            args = config_block[StepKeys.ARGS]
            resolved_args = {}

            # resolve backend
            backend = None
            if StepKeys.BACKEND in config_block:
                backend_config = config_block[StepKeys.BACKEND]
                backend = BaseBackend.from_config(backend_config)

            # resolve args for special cases
            for k, v in args.items():
                if isinstance(v, str) and is_valid_source(v):
                    resolved_args[k] = load_source_path_class(v)
                else:
                    resolved_args[k] = v

            obj = class_(**resolved_args)

            # If we load from config, it's immutable
            obj._immutable = True
            obj.backend = backend
            return obj
        else:
            raise AssertionError("Cannot create a step from a config_block "
                                 "without a 'source' key.")
Example #5
    def from_config(cls, config: Dict):
        """
        Convert from pipeline config to ZenML Pipeline object.

        All steps are also populated, and their configuration is set to the
        parameters defined in the config file.

        Args:
            config: a ZenML config in dict-form (probably loaded from YAML).
        """
        # start with artifact store
        artifact_store = ArtifactStore(config[keys.GlobalKeys.ARTIFACT_STORE])

        # metadata store
        metadata_store = ZenMLMetadataStore.from_config(
            config=config[keys.GlobalKeys.METADATA_STORE]
        )

        # orchestration backend
        backend = OrchestratorBaseBackend.from_config(
            config[keys.GlobalKeys.BACKEND])

        # pipeline configuration
        p_config = config[keys.GlobalKeys.PIPELINE]
        pipeline_name = p_config[keys.PipelineKeys.NAME]
        pipeline_source = p_config[keys.PipelineKeys.SOURCE]

        # populate steps
        steps_dict: Dict = {}
        for step_key, step_config in p_config[keys.PipelineKeys.STEPS].items():
            steps_dict[step_key] = BaseStep.from_config(step_config)

        # datasource
        datasource = BaseDatasource.from_config(
            config[keys.GlobalKeys.PIPELINE])

        # enable cache
        enable_cache = p_config[keys.PipelineKeys.ENABLE_CACHE]

        class_ = source_utils.load_source_path_class(pipeline_source)

        obj = class_(
            name=cls.get_name_from_pipeline_name(pipeline_name),
            pipeline_name=pipeline_name,
            enable_cache=enable_cache,
            steps_dict=steps_dict,
            backend=backend,
            artifact_store=artifact_store,
            metadata_store=metadata_store,
            datasource=datasource)
        obj._immutable = True
        logger.debug(f'Pipeline {pipeline_name} loaded and is immutable.')
        return obj
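Assembling the keys accessed above, the top-level config could be shaped roughly as follows; every literal key name is an assumption standing in for the `keys.GlobalKeys` / `keys.PipelineKeys` constants, and the nested block shapes are inferred from the other examples on this page:

config = {
    'artifact_store': '/path/to/artifact/store',
    'metadata_store': {'type': 'sqlite', 'args': {}},  # shape assumed
    'backend': {'source': 'my_package.backends.LocalBackend', 'args': {}},
    'pipeline': {
        'name': 'my-pipeline-1234',
        'source': 'zenml.pipelines.TrainingPipeline',  # hypothetical path
        'enable_cache': True,
        'steps': {'data': {'source': '...', 'args': {}}},
        'datasource': {'source': '...', 'name': '...', 'id': '...'},
    },
}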
Example #6
    def from_config(cls, config: Dict):
        """
        Convert from ZenML config dict to ZenML Backend object.

        Args:
            config: a ZenML config in dict-form (probably loaded from YAML)
        """
        backend_class = source_utils.load_source_path_class(
            config[BackendKeys.SOURCE])
        backend_args = config[BackendKeys.ARGS]
        obj = backend_class(**backend_args)
        obj._immutable = True
        return obj
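The backend block itself is minimal; assuming `BackendKeys.SOURCE == 'source'` and `BackendKeys.ARGS == 'args'`, with a hypothetical backend class path:

backend_config = {
    'source': 'my_package.backends.LocalBackend',  # hypothetical class
    'args': {},
}
backend = OrchestratorBaseBackend.from_config(backend_config)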
Example #7
def run_fn(fn_args):
    c = load_source_path_class(fn_args.pop(StepKeys.SOURCE))

    # Pop the user-defined step args out of fn_args
    args = fn_args.pop(StepKeys.ARGS)

    # TODO: [LOW] Hard-coded
    fn_args.pop('data_accessor')

    # We update the user's args first, because fn_args might have overlaps
    args.update(fn_args)

    return c(**args).get_run_fn()()
Example #8
    def Do(self, input_dict: Dict[Text, List[types.Artifact]],
           output_dict: Dict[Text, List[types.Artifact]],
           exec_properties: Dict[Text, Any]) -> None:
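        """Runs the tokenizer step: trains and saves the tokenizer unless
        training is skipped, then appends tokens to every example in a Beam
        pipeline.
        """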

        source = exec_properties[StepKeys.SOURCE]
        args = exec_properties[StepKeys.ARGS]

        c = source_utils.load_source_path_class(source)
        tokenizer_step: BaseTokenizer = c(**args)

        tokenizer_location = artifact_utils.get_single_uri(
            output_dict["tokenizer"])

        split_uris, split_names, all_files = [], [], []
        for artifact in input_dict["examples"]:
            for split in artifact_utils.decode_split_names(
                    artifact.split_names):
                split_names.append(split)
                uri = os.path.join(artifact.uri, split)
                split_uris.append((split, uri))
                all_files += path_utils.list_dir(uri)

        # Get output split path
        output_examples = artifact_utils.get_single_instance(
            output_dict["output_examples"])
        output_examples.split_names = artifact_utils.encode_split_names(
            split_names)

        if not tokenizer_step.skip_training:
            tokenizer_step.train(files=all_files)

            tokenizer_step.save(output_dir=tokenizer_location)

        with self._make_beam_pipeline() as p:
            for split, uri in split_uris:
                input_uri = io_utils.all_files_pattern(uri)

                _ = (p
                     | 'ReadData.' + split >> beam.io.ReadFromTFRecord(
                            file_pattern=input_uri)
                     | "ParseTFExFromString." + split >> beam.Map(
                            tf.train.Example.FromString)
                     | "AddTokens." + split >> beam.Map(
                            append_tf_example,
                            tokenizer_step=tokenizer_step)
                     | 'Serialize.' + split >> beam.Map(
                            lambda x: x.SerializeToString())
                     | 'WriteSplit.' + split >> WriteSplit(
                            get_split_uri(
                                output_dict["output_examples"],
                                split)))
Example #9
    def read(
        self,
        output_data_type: Optional[Type[Any]] = None,
        materializer_class: Optional[Type["BaseMaterializer"]] = None,
    ) -> Any:
        """Materializes the data stored in this artifact.

        Args:
            output_data_type: The datatype to which the materializer should
                read; it will be passed to the materializer's `handle_input`
                method.
            materializer_class: The class of the materializer that should be
                used to read the artifact data. If no materializer class is
                given, we use the materializer that was used to write the
                artifact during execution of the pipeline.

        Returns:
            The materialized data.
        """
        if not materializer_class:
            materializer_class = source_utils.load_source_path_class(
                self._materializer)

        if not output_data_type:
            output_data_type = source_utils.load_source_path_class(
                self._data_type)

        logger.debug(
            "Using '%s' to read '%s' (uri: %s).",
            materializer_class.__qualname__,
            self._type,
            self._uri,
        )

        # TODO [ENG-162]: passing in `self` to initialize the materializer only
        #  works because materializers only require a `.uri` property at the
        #  moment.
        materializer = materializer_class(self)  # type: ignore[arg-type]
        return materializer.handle_input(output_data_type)
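A typical call site might look like this; `artifact` stands for an artifact obtained from a pipeline run, and `PandasMaterializer` plus the DataFrame output type are illustrative assumptions, not part of the snippet above:

import pandas as pd

# Default: reuse the materializer recorded when the artifact was written.
data = artifact.read()

# Or override both (the materializer class here is hypothetical):
df = artifact.read(
    output_data_type=pd.DataFrame,
    materializer_class=PandasMaterializer,
)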
Example #10
def run_fn(fn_args):
    fn_args_dict = fn_args.__dict__
    custom_config = fn_args_dict.pop('custom_config')
    c = load_source_path_class(custom_config.pop(StepKeys.SOURCE))

    # Pop the user-defined step args out of custom_config
    args = custom_config.pop(StepKeys.ARGS)

    # TODO: [LOW] Hard-coded
    fn_args_dict.pop('data_accessor')

    # We update the user's args first, because fn_args might have overlaps
    args.update(fn_args_dict)

    return c(**args).run_fn()
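The step definition travels inside `custom_config` here; assuming the literal keys 'source' and 'args', the dict handed over by the TFX Trainer could look like:

custom_config = {
    'source': 'my_package.steps.MyTrainerStep',  # hypothetical step class
    'args': {'epochs': 10, 'batch_size': 32},    # constructor arguments
}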
Example #11
    def resolve_input_artifact(self, artifact: BaseArtifact,
                               data_type: Type[Any]) -> Any:
        """Resolves an input artifact, i.e., reading it from the Artifact Store
        to a pythonic object.

        Args:
            artifact: A TFX artifact type.
            data_type: The type of data to be materialized.

        Returns:
            The output of `handle_input()` of the selected materializer.
        """
        materializer = source_utils.load_source_path_class(
            artifact.materializer)(artifact)
        # The materializer now returns a resolved input
        return materializer.handle_input(data_type=data_type)
Example #12
def get_component_from_key(
        key: str, mapping: Dict[str, UUIDSourceTuple]) -> BaseComponent:
    """Given a key and a mapping, return an initialized component.

    Args:
        key: Unique key.
        mapping: Dict of type str -> UUIDSourceTuple.

    Returns:
        An object which is a subclass of type BaseComponent.
    """
    tuple_ = mapping[key]
    class_ = source_utils.load_source_path_class(tuple_.source)
    if not issubclass(class_, BaseComponent):
        raise TypeError("")
    return class_(uuid=tuple_.uuid)
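Usage once a mapping exists; this assumes `UUIDSourceTuple` exposes `uuid` and `source` fields, which is all the function reads, and the values below are hypothetical:

mapping = {
    'trainer': UUIDSourceTuple(
        uuid='1f0e...',                           # assumed UUID string
        source='my_package.components.Trainer',   # assumed source path
    ),
}
component = get_component_from_key('trainer', mapping)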
Example #13
    def Do(self, input_dict: Dict[Text, List[types.Artifact]],
           output_dict: Dict[Text, List[types.Artifact]],
           exec_properties: Dict[Text, Any]) -> None:
        """
        Write description regarding this beautiful executor.

        Args:
            input_dict:
            output_dict:
            exec_properties:
        """
        self._log_startup(input_dict, output_dict, exec_properties)

        schema = parse_schema(input_dict=input_dict)

        statistics = parse_statistics(
            split_name=DATA_SPLIT_NAME,
            statistics=input_dict[constants.STATISTICS])

        source = exec_properties[StepKeys.SOURCE]
        args = exec_properties[StepKeys.ARGS]

        # pass the schema and stats straight to the Step
        args[constants.SCHEMA] = schema
        args[constants.STATISTICS] = statistics

        c = source_utils.load_source_path_class(source)
        split_step: BaseSplit = c(**args)

        # infer the names of the splits from the config
        split_names = split_step.get_split_names()

        # Get output split path
        examples_artifact = artifact_utils.get_single_instance(
            output_dict[constants.OUTPUT_EXAMPLES])
        if SKIP in split_names:
            sanitized_names = [name for name in split_names if name != SKIP]
            examples_artifact.split_names = artifact_utils.encode_split_names(
                sanitized_names)
        else:
            examples_artifact.split_names = artifact_utils.encode_split_names(
                split_names)

        split_uris = []
        for artifact in input_dict[constants.INPUT_EXAMPLES]:
            for split in artifact_utils.decode_split_names(
                    artifact.split_names):
                uri = os.path.join(artifact.uri, split)
                split_uris.append((split, uri))

        with self._make_beam_pipeline() as p:
            # For now, the outer loop will only run once
            for split, uri in split_uris:
                input_uri = io_utils.all_files_pattern(uri)

                new_splits = (
                    p
                    | 'ReadData.' + split >>
                    beam.io.ReadFromTFRecord(file_pattern=input_uri)
                    | beam.Map(tf.train.Example.FromString)
                    |
                    'Split' >> beam.Partition(split_step.partition_fn()[0],
                                              split_step.get_num_splits(),
                                              **split_step.partition_fn()[1]))

                for split_name, new_split in zip(split_names,
                                                 list(new_splits)):
                    if split_name != SKIP:
                        # WriteSplit function writes to TFRecord again
                        (new_split
                         | 'Serialize.' + split_name >>
                         beam.Map(lambda x: x.SerializeToString())
                         | 'WriteSplit_' + split_name >> WriteSplit(
                             get_split_uri(
                                 output_dict[constants.OUTPUT_EXAMPLES],
                                 split_name)))
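`partition_fn()` must return a `(fn, kwargs)` pair usable with `beam.Partition`, whose callback receives the element and the number of partitions. A hypothetical method on a split step satisfying that contract; the feature name and routing rule are assumptions:

import tensorflow as tf

def partition_fn(self):
    # beam.Partition calls fn(element, num_partitions, **kwargs) -> int.
    def split_by_feature(example: tf.train.Example, num_partitions: int,
                         feature_name: str) -> int:
        value = example.features.feature[feature_name].bytes_list.value
        # Hypothetical rule: send rows tagged 'eval' to split 1.
        return 1 if value and value[0] == b'eval' else 0

    return split_by_feature, {'feature_name': 'split_hint'}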
Example #14
    def Do(self, input_dict: Dict[Text, List[types.Artifact]],
           output_dict: Dict[Text, List[types.Artifact]],
           exec_properties: Dict[Text, Any]) -> None:
        """Runs batch inference on a given model with given input examples.

        Args:
          input_dict: Input dict from input key to a list of Artifacts.
            - examples: examples for inference.
            - model: exported model.
            - model_blessing: model blessing result, optional.
          output_dict: Output dict from output key to a list of Artifacts.
            - output: bulk inference results.
          exec_properties: A dict of execution properties.
            - model_spec: JSON string of bulk_inferrer_pb2.ModelSpec instance.
            - data_spec: JSON string of bulk_inferrer_pb2.DataSpec instance.

        Returns:
          None
        """
        self._log_startup(input_dict, output_dict, exec_properties)

        source = exec_properties[StepKeys.SOURCE]
        args = exec_properties[StepKeys.ARGS]
        c = source_utils.load_source_path_class(source)
        inferrer_step: BaseInferrer = c(**args)

        output_examples = artifact_utils.get_single_instance(
            output_dict[PREDICTIONS])

        if EXAMPLES not in input_dict:
            raise ValueError('\'examples\' is missing in input dict.')
        if MODEL not in input_dict:
            raise ValueError('Input model is not valid; a model needs '
                             'to be specified.')
        if MODEL_BLESSING in input_dict:
            model_blessing = artifact_utils.get_single_instance(
                input_dict['model_blessing'])
            if not model_utils.is_model_blessed(model_blessing):
                logging.info('Model on %s was not blessed', model_blessing.uri)
                return
        else:
            logging.info(
                'Model blessing is not provided, exported model will be '
                'used.')

        model = artifact_utils.get_single_instance(input_dict[MODEL])
        model_path = path_utils.serving_model_path(model.uri)
        logging.info('Use exported model from %s.', model_path)

        output_example_spec = bulk_inferrer_pb2.OutputExampleSpec(
            output_columns_spec=[
                bulk_inferrer_pb2.OutputColumnsSpec(
                    predict_output=bulk_inferrer_pb2.PredictOutput(
                        output_columns=[
                            bulk_inferrer_pb2.PredictOutputCol(
                                output_key=x,
                                output_column=f'{x}_label',
                            ) for x in inferrer_step.get_labels()
                        ]))
            ])

        model_spec = bulk_inferrer_pb2.ModelSpec()
        saved_model_spec = model_spec_pb2.SavedModelSpec(
            model_path=model_path,
            tag=model_spec.tag,
            signature_name=model_spec.model_signature_name)
        inference_spec = model_spec_pb2.InferenceSpecType()
        inference_spec.saved_model_spec.CopyFrom(saved_model_spec)

        self._run_model_inference(output_example_spec, input_dict[EXAMPLES],
                                  output_examples, inference_spec,
                                  inferrer_step)
Example #15
    def Do(self, input_dict: Dict[Text, List[types.Artifact]],
           output_dict: Dict[Text, List[types.Artifact]],
           exec_properties: Dict[Text, Any]) -> None:
        """
        Main execution logic for the Sequencer component.

        :param input_dict: input channels
        :param output_dict: output channels
        :param exec_properties: the execution properties defined in the spec
        """

        source = exec_properties[StepKeys.SOURCE]
        args = exec_properties[StepKeys.ARGS]

        c = source_utils.load_source_path_class(source)

        # Get the schema
        schema_path = io_utils.get_only_uri_in_dir(
            artifact_utils.get_single_uri(input_dict[constants.SCHEMA]))
        schema = io_utils.SchemaReader().read(schema_path)

        # TODO: Getting the statistics might help future implementations

        sequence_step: BaseSequencerStep = c(schema=schema,
                                             statistics=None,
                                             **args)

        # Get split names
        input_artifact = artifact_utils.get_single_instance(
            input_dict[constants.INPUT_EXAMPLES])
        split_names = artifact_utils.decode_split_names(
            input_artifact.split_names)

        # Create output artifact
        output_artifact = artifact_utils.get_single_instance(
            output_dict[constants.OUTPUT_EXAMPLES])
        output_artifact.split_names = artifact_utils.encode_split_names(
            split_names)

        with self._make_beam_pipeline() as p:
            for s in split_names:
                input_uri = io_utils.all_files_pattern(
                    artifact_utils.get_split_uri(
                        input_dict[constants.INPUT_EXAMPLES], s))

                output_uri = artifact_utils.get_split_uri(
                    output_dict[constants.OUTPUT_EXAMPLES], s)
                output_path = os.path.join(output_uri, self._DEFAULT_FILENAME)

                # Read and decode the data
                data = \
                    (p
                     | 'Read_' + s >> beam.io.ReadFromTFRecord(
                                file_pattern=input_uri)
                     | 'Decode_' + s >> tf_example_decoder.DecodeTFExample()
                     | 'ToDataFrame_' + s >> beam.ParDo(utils.ConvertToDataframe()))

                # Window into sessions
                s_data = \
                    (data
                     | 'AddCategory_' + s >> beam.ParDo(
                                sequence_step.get_category_do_fn())
                     | 'AddTimestamp_' + s >> beam.ParDo(
                                sequence_step.get_timestamp_do_fn())
                     | 'Sessions_' + s >> beam.WindowInto(
                                sequence_step.get_window()))

                # Combine and transform
                p_data = \
                    (s_data
                     | 'Combine_' + s >> beam.CombinePerKey(
                                sequence_step.get_combine_fn()))

                # Write the results
                _ = \
                    (p_data
                     | 'Global_' + s >> beam.WindowInto(GlobalWindows())
                     | 'RemoveKey_' + s >> beam.ParDo(RemoveKey())
                     | 'ToExample_' + s >> beam.Map(utils.df_to_example)
                     | 'Serialize_' + s >> beam.Map(utils.serialize)
                     | 'Write_' + s >> beam.io.WriteToTFRecord(
                                output_path,
                                file_name_suffix='.gz'))
Example #16
    def Do(self, input_dict: Dict[Text, List[types.Artifact]],
           output_dict: Dict[Text, List[types.Artifact]],
           exec_properties: Dict[Text, Any]):
        """Overrides the tfx_pusher_executor.

        Args:
          input_dict: Input dict from input key to a list of artifacts,
            including:
            - model_export: exported model from trainer.
            - model_blessing: model blessing path from evaluator.
          output_dict: Output dict from key to a list of artifacts, including:
            - model_push: A list of 'ModelPushPath' artifacts of size one. It
              will include the model in this push execution if the model was
              pushed.
          exec_properties: Mostly a passthrough input dict for
            tfx.components.Pusher.executor.custom_config.

        Raises:
          ValueError: if the custom config is not present or not a dict.
        """
        self._log_startup(input_dict, output_dict, exec_properties)

        # check model blessing
        model_push = artifact_utils.get_single_instance(
            output_dict[tfx_pusher_executor.PUSHED_MODEL_KEY])
        if not self.CheckBlessing(input_dict):
            self._MarkNotPushed(model_push)
            return

        model_export = artifact_utils.get_single_instance(
            input_dict[tfx_pusher_executor.MODEL_KEY])

        custom_config = json_utils.loads(
            exec_properties.get(_CUSTOM_CONFIG_KEY, 'null'))
        if custom_config is not None and not isinstance(custom_config, Dict):
            raise ValueError(
                'custom_config in execution properties needs to be a '
                'dict.')

        cortex_serving_args = custom_config.get(SERVING_ARGS_KEY)
        if not cortex_serving_args:
            raise ValueError(
                '\'cortex_serving_args\' is missing in \'custom_config\'')

        # Deploy the model.
        io_utils.copy_dir(
            src=path_utils.serving_model_path(model_export.uri),
            dst=model_push.uri)
        model_path = model_push.uri

        # Cortex implementation starts here
        # pop the env and initialize client
        cx = cortex.client(cortex_serving_args.pop('env'))

        # load the predictor
        predictor_path = cortex_serving_args.pop('predictor_path')
        predictor = load_source_path_class(predictor_path)

        # edit the api_config
        api_config = cortex_serving_args.pop('api_config')
        if 'config' not in api_config['predictor']:
            api_config['predictor']['config'] = {}
        api_config['predictor']['config']['model_artifact'] = model_path

        # launch the api
        cx.create_api(
            api_config=api_config, predictor=predictor, **cortex_serving_args)

        self._MarkPushed(
            model_push,
            pushed_destination=model_path)
Example #17
def preprocessing_fn(inputs, custom_config):
    c = load_source_path_class(custom_config[StepKeys.SOURCE])
    args = custom_config[StepKeys.ARGS]
    return c(**args).get_preprocessing_fn()(inputs)
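This module-level `preprocessing_fn` is the hook the TFX Transform component imports; the actual step rides along in `custom_config`. A sketch of the wiring, where the upstream components, key names, and step path are all assumptions:

from tfx.components import Transform

transform = Transform(
    examples=example_gen.outputs['examples'],  # upstream ExampleGen (assumed)
    schema=schema_gen.outputs['schema'],       # upstream SchemaGen (assumed)
    module_file='path/to/this_module.py',
    custom_config={
        'source': 'my_package.steps.MyPreprocessorStep',
        'args': {},
    },
)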
Example #18
    def Do(self, input_dict: Dict[Text, List[types.Artifact]],
           output_dict: Dict[Text, List[types.Artifact]],
           exec_properties: Dict[Text, Any]) -> None:
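        """Runs a TFMA evaluation over the provided examples and writes the
        results to the evaluation artifact.
        """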

        # Check the inputs
        if constants.EXAMPLES not in input_dict:
            raise ValueError(f'{constants.EXAMPLES} is missing from inputs')
        examples_artifact = input_dict[constants.EXAMPLES]

        input_uri = artifact_utils.get_single_uri(examples_artifact)
        if len(zenml_path_utils.list_dir(input_uri)) == 0:
            raise AssertionError(
                'ZenML can not run the evaluation as the provided input '
                'configuration does not point towards any data. Specifically, '
                'if you are using the agnostic evaluator, please make sure '
                'that you are using a proper test_fn in your trainer step to '
                'write these results.')

        else:
            # Check the outputs
            if constants.EVALUATION not in output_dict:
                raise ValueError(
                    f'{constants.EVALUATION} is missing from outputs')
            evaluation_artifact = output_dict[constants.EVALUATION]
            output_uri = artifact_utils.get_single_uri(evaluation_artifact)

            # Resolve the schema
            schema = None
            if constants.SCHEMA in input_dict:
                schema_artifact = input_dict[constants.SCHEMA]
                schema_uri = artifact_utils.get_single_uri(schema_artifact)
                reader = io_utils.SchemaReader()
                schema = reader.read(io_utils.get_only_uri_in_dir(schema_uri))

            # Create the step with the schema attached if provided
            source = exec_properties[StepKeys.SOURCE]
            args = exec_properties[StepKeys.ARGS]
            c = source_utils.load_source_path_class(source)
            evaluator_step: BaseEvaluatorStep = c(**args)

            # Check the execution parameters
            eval_config = evaluator_step.build_config()
            eval_config = tfma.update_eval_config_with_defaults(eval_config)
            tfma.verify_eval_config(eval_config)

            # Resolve the model
            if constants.MODEL in input_dict:
                model_artifact = input_dict[constants.MODEL]
                model_uri = artifact_utils.get_single_uri(model_artifact)
                model_path = path_utils.serving_model_path(model_uri)

                model_fn = try_get_fn(evaluator_step.CUSTOM_MODULE,
                                      'custom_eval_shared_model'
                                      ) or tfma.default_eval_shared_model

                eval_shared_model = model_fn(
                    model_name='',  # TODO: Fix with model names
                    eval_saved_model_path=model_path,
                    eval_config=eval_config)
            else:
                eval_shared_model = None

            self._log_startup(input_dict, output_dict, exec_properties)

            # Main pipeline
            logging.info('Evaluating model.')
            with self._make_beam_pipeline() as pipeline:
                examples_list = []
                tensor_adapter_config = None

                if tfma.is_batched_input(eval_shared_model, eval_config):
                    tfxio_factory = tfxio_utils.get_tfxio_factory_from_artifact(
                        examples=[
                            artifact_utils.get_single_instance(
                                examples_artifact)
                        ],
                        telemetry_descriptors=_TELEMETRY_DESCRIPTORS,
                        schema=schema,
                        raw_record_column_name=tfma_constants.
                        ARROW_INPUT_COLUMN)
                    for split in evaluator_step.splits:
                        file_pattern = io_utils.all_files_pattern(
                            artifact_utils.get_split_uri(
                                examples_artifact, split))
                        tfxio = tfxio_factory(file_pattern)
                        data = (pipeline
                                | 'ReadFromTFRecordToArrow[%s]' % split >>
                                tfxio.BeamSource())
                        examples_list.append(data)
                    if schema is not None:
                        tensor_adapter_config = tensor_adapter.TensorAdapterConfig(
                            arrow_schema=tfxio.ArrowSchema(),
                            tensor_representations=tfxio.TensorRepresentations(
                            ))
                else:
                    for split in evaluator_step.splits:
                        file_pattern = io_utils.all_files_pattern(
                            artifact_utils.get_split_uri(
                                examples_artifact, split))
                        data = (pipeline
                                | 'ReadFromTFRecord[%s]' % split >> beam.io.
                                ReadFromTFRecord(file_pattern=file_pattern))
                        examples_list.append(data)

                # Resolve custom extractors
                custom_extractors = try_get_fn(evaluator_step.CUSTOM_MODULE,
                                               'custom_extractors')
                extractors = None
                if custom_extractors:
                    extractors = custom_extractors(
                        eval_shared_model=eval_shared_model,
                        eval_config=eval_config,
                        tensor_adapter_config=tensor_adapter_config)

                # Resolve custom evaluators
                custom_evaluators = try_get_fn(evaluator_step.CUSTOM_MODULE,
                                               'custom_evaluators')
                evaluators = None
                if custom_evaluators:
                    evaluators = custom_evaluators(
                        eval_shared_model=eval_shared_model,
                        eval_config=eval_config,
                        tensor_adapter_config=tensor_adapter_config)

                # Extract, evaluate and write
                (examples_list | 'FlattenExamples' >> beam.Flatten()
                 | 'ExtractEvaluateAndWriteResults' >>
                 tfma.ExtractEvaluateAndWriteResults(
                     eval_config=eval_config,
                     eval_shared_model=eval_shared_model,
                     output_path=output_uri,
                     extractors=extractors,
                     evaluators=evaluators,
                     tensor_adapter_config=tensor_adapter_config))
            logging.info('Evaluation complete. Results written to %s.',
                         output_uri)