Example 1
  def test_evaluate(self):
    es_size_2 = ensemble_selection.EnsembleSelection(
        problem_statement=ps_pb2.ProblemStatement(tasks=[
            ps_pb2.Task(
                type=ps_pb2.Type(
                    one_dimensional_regression=ps_pb2.OneDimensionalRegression(
                        label='label')))
        ]),
        saved_model_paths=self.saved_model_paths,
        predict_fn=_test_predict_fn,
        ensemble_size=2,
        metric=tf.keras.metrics.MeanSquaredError(),
        goal='minimize')
    es_size_4 = ensemble_selection.EnsembleSelection(
        problem_statement=ps_pb2.ProblemStatement(tasks=[
            ps_pb2.Task(
                type=ps_pb2.Type(
                    one_dimensional_regression=ps_pb2.OneDimensionalRegression(
                        label='label')))
        ]),
        saved_model_paths=self.saved_model_paths,
        predict_fn=_test_predict_fn,
        ensemble_size=4,
        metric=tf.keras.metrics.MeanSquaredError(),
        goal='minimize')
    metrics = [tf.keras.metrics.MeanSquaredError()]

    es_size_2.fit(self.fit_examples, self.fit_label)
    es_size_4.fit(self.fit_examples, self.fit_label)
    es_2_mse = es_size_2.evaluate(self.fit_examples, self.fit_label, metrics)[0]
    es_4_mse = es_size_4.evaluate(self.fit_examples, self.fit_label, metrics)[0]

    self.assertLessEqual(es_4_mse, es_2_mse)
Example 2
    def testDoWithMajorityVoting(self):

        exec_properties = self._exec_properties.copy()
        exec_properties['tuner_fn'] = '%s.%s' % (
            tuner_module.tuner_fn.__module__, tuner_module.tuner_fn.__name__)
        exec_properties['metalearning_algorithm'] = 'majority_voting'

        input_dict = self._input_dict.copy()

        ps_type = ps_pb2.Type(
            binary_classification=ps_pb2.BinaryClassification(label='class'))
        ps = ps_pb2.ProblemStatement(
            owner=['nitroml'],
            tasks=[ps_pb2.Task(
                name='mockdata_1',
                type=ps_type,
            )])

        exec_properties['custom_config'] = json_utils.dumps({
            'problem_statement':
            text_format.MessageToString(message=ps, as_utf8=True),
        })
        hps_artifact = artifacts.KCandidateHyperParameters()
        hps_artifact.uri = os.path.join(self._testdata_dir,
                                        'MetaLearner.majority_voting',
                                        'hparams_out')
        input_dict['warmup_hyperparameters'] = [hps_artifact]

        tuner = executor.Executor(self._context)
        tuner.Do(input_dict=input_dict,
                 output_dict=self._output_dict,
                 exec_properties=exec_properties)
        self._verify_output()
Example 3
    def testDoWithTunerFn(self):

        self._exec_properties['tuner_fn'] = '%s.%s' % (
            tuner_module.tuner_fn.__module__, tuner_module.tuner_fn.__name__)

        ps_type = ps_pb2.Type(
            binary_classification=ps_pb2.BinaryClassification(label='class'))
        ps = ps_pb2.ProblemStatement(
            owner=['nitroml'],
            tasks=[ps_pb2.Task(
                name='mockdata_1',
                type=ps_type,
            )])

        self._exec_properties['custom_config'] = json_utils.dumps({
            'problem_statement':
            text_format.MessageToString(message=ps, as_utf8=True),
        })

        tuner = executor.Executor(self._context)
        tuner.Do(input_dict=self._input_dict,
                 output_dict=self._output_dict,
                 exec_properties=self._exec_properties)

        self._verify_output()
Example 4
    def problem_statement(self) -> ps_pb2.ProblemStatement:
        """Returns the ProblemStatement associated with this Task."""

        return ps_pb2.ProblemStatement(
            owner=['nitroml'],
            tasks=[ps_pb2.Task(
                name=self.name,
                type=self._get_task_type(),
            )])
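
The ProblemStatement returned here is typically serialized to a text proto before being handed to downstream components through custom_config, as in Examples 2, 3, and 5. A minimal sketch of that serialization step, assuming `task` is an object exposing the problem_statement() method above and that text_format is google.protobuf.text_format as in the other examples:

# Hedged sketch: `task` is an assumed instance; only the serialization step is shown.
ps = task.problem_statement()
custom_config = {
    'problem_statement': text_format.MessageToString(message=ps, as_utf8=True),
}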
Example 5
def preprocessing_fn(inputs: Dict[str, Tensor],
                     custom_config: Dict[str, Any]) -> Dict[str, Tensor]:
    """tf.transform's callback function for preprocessing inputs.

  Args:
    inputs: map from feature keys to raw not-yet-transformed features.
    custom_config: Custom configuration dictionary for passing the task's
      ProblemStatement as a text proto, since custom_config must be
      JSON-serializable.

  Returns:
    Map from string feature key to transformed feature operations.
  """

    problem_statement = ps_pb2.ProblemStatement()
    text_format.Parse(
        text=custom_config[BasicPreprocessor.PROBLEM_STATEMENT_KEY],
        message=problem_statement)

    outputs = {}
    for key in [k for k, v in inputs.items() if v.dtype == tf.float32]:
        # TODO(weill): Handle the case when an int field actually represents
        # numeric rather than categorical values.
        task_type = problem_statement.tasks[0].type
        if task_type.HasField('one_dimensional_regression') and (
                key == task_type.one_dimensional_regression.label):
            outputs[key] = inputs[key]
            # Skip normalizing regression tasks.
            continue

        # Preserve this feature as a dense float, setting nan's to the mean.
        outputs[_sanitize_feature_name(key)] = tft.scale_to_z_score(
            _fill_in_missing(inputs[key]))

    for key in [k for k, v in inputs.items() if v.dtype != tf.float32]:
        # Build a vocabulary for this feature.
        # TODO(weill): Risk of needlessly blowing up computation here.
        output = tft.compute_and_apply_vocabulary(_fill_in_missing(
            inputs[key]),
                                                  top_k=None,
                                                  num_oov_buckets=1)

        # Don't sanitize the label key name.
        task_type = problem_statement.tasks[0].type
        if task_type.HasField('multi_class_classification') and (
                key == task_type.multi_class_classification.label):
            outputs[key] = output
            continue
        if task_type.HasField('binary_classification') and (
                key == task_type.binary_classification.label):
            outputs[key] = output
            continue

        # Do sanitize feature key names.
        outputs[_sanitize_feature_name(key)] = output

    return outputs
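
The preprocessing_fn above relies on two helpers, _fill_in_missing and _sanitize_feature_name, that are not shown in this snippet. The sketches below are assumptions about their behavior (dense conversion of sparse inputs and feature-key sanitization), not the project's actual implementations:

import re

import tensorflow as tf


def _sanitize_feature_name(name: str) -> str:
  # Assumed behavior: replace characters that are unsafe in feature keys.
  return re.sub(r'[^a-zA-Z0-9_]', '_', name)


def _fill_in_missing(x):
  # Assumed behavior: convert a batched SparseTensor with at most one value per
  # row into a dense tensor, filling missing entries with a type-appropriate
  # default. (The comment in preprocessing_fn suggests the real helper may
  # impute the mean for floats instead.)
  default_value = '' if x.dtype == tf.string else 0
  dense = tf.sparse.to_dense(
      tf.sparse.SparseTensor(x.indices, x.values, [x.dense_shape[0], 1]),
      default_value)
  return tf.squeeze(dense, axis=1)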
Example 6
    def problem_statement(self) -> ps_pb2.ProblemStatement:
        """Returns the ProblemStatement associated with this BenchmarkTask."""

        return ps_pb2.ProblemStatement(
            owner=['nitroml'],
            tasks=[
                ps_pb2.Task(
                    name='Test',
                    type=ps_pb2.Type(one_dimensional_regression=ps_pb2.
                                     OneDimensionalRegression(label='test')),
                )
            ])
Example 7
  def test_lifecycle(self):
    es = ensemble_selection.EnsembleSelection(
        problem_statement=ps_pb2.ProblemStatement(tasks=[
            ps_pb2.Task(
                type=ps_pb2.Type(
                    one_dimensional_regression=ps_pb2.OneDimensionalRegression(
                        label='label')))
        ]),
        saved_model_paths=self.saved_model_paths,
        predict_fn=_test_predict_fn,
        ensemble_size=3,
        metric=tf.keras.metrics.MeanSquaredError(),
        goal='minimize')
    test_dir = os.path.join(self.data_path, 'test_examples.tfrecord')
    test_examples = np.asarray(
        list(tf.data.TFRecordDataset(test_dir).as_numpy_iterator()))
    test_examples_tensor = tf.convert_to_tensor(test_examples)
    model_predictions = {}
    for model_id, path in self.saved_model_paths.items():
      reloaded_model = tf.saved_model.load(path)
      model_predictions[model_id] = reloaded_model.signatures[
          'serving_default'](test_examples_tensor)['output_0'].numpy()
    want_weights = {'2': 0.3333333333333333, '4': 0.6666666666666666}
    want_prediction = want_weights['2'] * model_predictions['2'] + want_weights[
        '4'] * model_predictions['4']
    mse = tf.keras.metrics.MeanSquaredError()
    mse_scores = []
    for pred in model_predictions.values():
      mse_scores.append(mse(self.fit_label, pred))
      mse.reset_states()
    export_dir = os.path.join(
        tempfile.mkdtemp(dir=absltest.get_default_test_tmpdir()),
        'from_estimator')

    es.fit(self.fit_examples, self.fit_label)
    ensemble_predictions = es.predict(test_examples)
    ensemble_mse = es.evaluate(self.fit_examples, self.fit_label, [mse])[0]
    ensemble_path = es.save(export_dir)
    reloaded_ensemble = tf.saved_model.load(ensemble_path)
    loaded_ensemble_prediction = reloaded_ensemble.signatures[
        'serving_default'](input=test_examples_tensor)['output'].numpy()

    self.assertEqual(want_weights, es.weights)
    self.assertEqual((10, 1), ensemble_predictions.shape)
    np.testing.assert_array_almost_equal(want_prediction, ensemble_predictions,
                                         1)
    self.assertLessEqual(ensemble_mse, min(mse_scores))
    np.testing.assert_array_almost_equal(ensemble_predictions,
                                         loaded_ensemble_prediction, 1)
Example 8
  def problem_statement(self) -> ps_pb2.ProblemStatement:
    """Returns the ProblemStatement associated with this Task."""

    # supervised_keys is a two-tuple of (input_key, target_key).
    _, target_key = self._dataset_builder.info.supervised_keys
    return ps_pb2.ProblemStatement(
        owner=['nitroml'],
        tasks=[
            ps_pb2.Task(
                name=self.name,
                type=ps_pb2.Type(
                    binary_classification=ps_pb2.BinaryClassification(
                        label=target_key)),
            )
        ])
Example 9
def ensemble_selection(
    problem_statement: Parameter[str],
    examples: InputArtifact[standard_artifacts.Examples],
    evaluation_split_name: Parameter[str],
    ensemble_size: Parameter[int],
    metric: Parameter[str],
    goal: Parameter[str],
    model: OutputArtifact[standard_artifacts.Model],
    input_model0: InputArtifact[standard_artifacts.Model] = None,
    input_model1: InputArtifact[standard_artifacts.Model] = None,
    input_model2: InputArtifact[standard_artifacts.Model] = None,
    input_model3: InputArtifact[standard_artifacts.Model] = None,
    input_model4: InputArtifact[standard_artifacts.Model] = None,
    input_model5: InputArtifact[standard_artifacts.Model] = None,
    input_model6: InputArtifact[standard_artifacts.Model] = None,
    input_model7: InputArtifact[standard_artifacts.Model] = None,
    input_model8: InputArtifact[standard_artifacts.Model] = None,
    input_model9: InputArtifact[standard_artifacts.Model] = None,
) -> None:  # pytype: disable=invalid-annotation,wrong-arg-types
    """Runs the SimpleML trainer as a separate component."""

    problem_statement = text_format.Parse(problem_statement,
                                          ps_pb2.ProblemStatement())
    input_models = [
        input_model0, input_model1, input_model2, input_model3, input_model4,
        input_model5, input_model6, input_model7, input_model8, input_model9
    ]
    saved_model_paths = {
        str(i): path_utils.serving_model_path(model.uri)
        for i, model in enumerate(input_models) if model
    }
    logging.info('Saved model paths: %s', saved_model_paths)

    label_key = _label_key(problem_statement)

    es = es_lib.EnsembleSelection(problem_statement=problem_statement,
                                  saved_model_paths=saved_model_paths,
                                  ensemble_size=ensemble_size,
                                  metric=tf.keras.metrics.deserialize(
                                      json.loads(metric)),
                                  goal=goal)

    es.fit(*_data_from_examples(examples_path=os.path.join(
        examples.uri, evaluation_split_name),
                                label_key=label_key))
    logging.info('Selected ensemble weights: %s', es.weights)
    es.save(export_path=os.path.join(path_utils.serving_model_dir(model.uri),
                                     'export', 'serving'))
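
The component above receives the problem statement and the metric as strings. A hedged sketch of how those strings might be produced at pipeline-construction time; the `ps` proto is an assumption, and only the serialization round-trips are grounded in the code above:

import json

import tensorflow as tf
from google.protobuf import text_format

# `ps` is an assumed ps_pb2.ProblemStatement instance.
problem_statement_str = text_format.MessageToString(message=ps, as_utf8=True)
# Mirrors the tf.keras.metrics.deserialize(json.loads(metric)) call above.
metric_str = json.dumps(tf.keras.metrics.serialize(
    tf.keras.metrics.MeanSquaredError()))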
Example 10
  def test_get_predictions(self):
    # TODO(liumich): improve test predictions with the following steps
    # - reduce the number of samples to 4-5
    # - manually compute MSE after each iteration for each partial ensemble
    # - also output the ground truth (labels) so that we can verify
    es = ensemble_selection.EnsembleSelection(
        problem_statement=ps_pb2.ProblemStatement(tasks=[
            ps_pb2.Task(
                type=ps_pb2.Type(
                    one_dimensional_regression=ps_pb2.OneDimensionalRegression(
                        label='label')))
        ]),
        saved_model_paths=self.saved_model_paths,
        predict_fn=_test_predict_fn,
        ensemble_size=3,
        metric=tf.keras.metrics.MeanSquaredError(),
        goal='minimize')
    want_predictions = {
        '0':
            np.array([[268520.7], [172055.8], [172840.52], [203374.36],
                      [629715.5], [160393.], [242507.27], [156286.08],
                      [262261.7], [221169.3]]),
        '1':
            np.array([[262822.53], [168104.17], [168874.69], [198855.67],
                      [617477.56], [156652.53], [237280.08], [152620.],
                      [256676.81], [216328.45]]),
        '2':
            np.array([[247936.98], [158487.19], [159214.84], [187528.2],
                      [582864.9], [147672.52], [223815.3], [143864.3],
                      [242133.11], [204029.08]]),
        '3':
            np.array([[268206.75], [171761.81], [172546.36], [203073.86],
                      [629326.7], [160101.38], [242198.66], [155995.36],
                      [261948.98], [220865.14]]),
        '4':
            np.array([[257493.5], [164639.19], [165394.55], [194785.5],
                      [605169.06], [153412.9], [232453.73], [149459.73],
                      [251468.73], [211914.42]])
    }

    predictions = es._get_predictions_dict(self.fit_examples)

    for model_id in predictions.keys():
      np.testing.assert_array_almost_equal(want_predictions[model_id],
                                           predictions[model_id], 1)
Example 11
  def test_predict_before_fit(self):
    es = ensemble_selection.EnsembleSelection(
        problem_statement=ps_pb2.ProblemStatement(tasks=[
            ps_pb2.Task(
                type=ps_pb2.Type(
                    one_dimensional_regression=ps_pb2.OneDimensionalRegression(
                        label='label')))
        ]),
        saved_model_paths=self.saved_model_paths,
        predict_fn=_test_predict_fn,
        ensemble_size=3,
        metric=tf.keras.metrics.MeanSquaredError(),
        goal='minimize')

    with self.assertRaisesRegex(
        ValueError,
        'Weights cannot be empty. Must call `fit` before `predict`.'):
      _ = es.predict(self.fit_examples)
Example 12
  def test_calculate_weights(self):
    es = ensemble_selection.EnsembleSelection(
        problem_statement=ps_pb2.ProblemStatement(tasks=[
            ps_pb2.Task(
                type=ps_pb2.Type(
                    one_dimensional_regression=ps_pb2.OneDimensionalRegression(
                        label='label')))
        ]),
        saved_model_paths=self.saved_model_paths,
        predict_fn=_test_predict_fn,
        ensemble_size=4,
        metric=tf.keras.metrics.MeanSquaredError(),
        goal='minimize')
    ensemble_count = {'model_1': 1, 'model_2': 2, 'model_3': 1}
    want_weights = {'model_1': 0.25, 'model_2': 0.5, 'model_3': 0.25}

    es._calculate_weights(ensemble_count)

    self.assertEqual(want_weights, es.weights)
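
The expected weights in this test are consistent with normalizing each model's selection count by the ensemble size (1/4, 2/4, 1/4). A minimal sketch of that normalization, not the library's actual _calculate_weights implementation:

ensemble_size = 4
ensemble_count = {'model_1': 1, 'model_2': 2, 'model_3': 1}
weights = {model_id: count / ensemble_size
           for model_id, count in ensemble_count.items()}
# -> {'model_1': 0.25, 'model_2': 0.5, 'model_3': 0.25}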
Example 13
  def test_evaluate_metrics(self):
    es = ensemble_selection.EnsembleSelection(
        problem_statement=ps_pb2.ProblemStatement(tasks=[
            ps_pb2.Task(
                type=ps_pb2.Type(
                    one_dimensional_regression=ps_pb2.OneDimensionalRegression(
                        label='label')))
        ]),
        saved_model_paths=self.saved_model_paths,
        predict_fn=_test_predict_fn,
        ensemble_size=3,
        metric=tf.keras.metrics.MeanSquaredError(),
        goal='minimize')
    metrics = [
        tf.keras.metrics.MeanSquaredError(),
        tf.keras.metrics.MeanAbsoluteError(),
        tf.keras.metrics.RootMeanSquaredError()
    ]

    es.fit(self.fit_examples, self.fit_label)
    ensemble_metrics = es.evaluate(self.fit_examples, self.fit_label, metrics)

    self.assertLen(ensemble_metrics, len(metrics))
Example 14
def tuner_fn(fn_args: fn_args_utils.FnArgs) -> TunerFnResult:
    """Build the tuner using the KerasTuner API.

  Args:
    fn_args: Holds args as name/value pairs.
      - working_dir: working dir for tuning.
      - train_files: List of file paths containing training tf.Example data.
      - eval_files: List of file paths containing eval tf.Example data.
      - train_steps: number of train steps.
      - eval_steps: number of eval steps.
      - schema_path: optional schema of the input data.
      - transform_graph_path: optional transform graph produced by TFT.
      - custom_config: A dict with a single 'problem_statement' entry containing
        a text-format serialized ProblemStatement proto which defines the task.

  Returns:
    A namedtuple containing the following:
      - tuner: A BaseTuner that will be used for tuning.
      - fit_kwargs: Args to pass to the tuner's run_trial function for fitting
                    the model, e.g., the training and validation datasets.
                    Required args depend on the tuner's implementation.
  """

    problem_statement = text_format.Parse(
        fn_args.custom_config['problem_statement'], ps_pb2.ProblemStatement())
    autodata_adapter = kma.KerasModelAdapter(
        problem_statement=problem_statement,
        transform_graph_dir=fn_args.transform_graph_path)

    build_keras_model_fn = functools.partial(_build_keras_model,
                                             autodata_adapter=autodata_adapter)
    if 'warmup_hyperparameters' in fn_args.custom_config:
        hyperparameters = hp_module.HyperParameters.from_config(
            fn_args.custom_config['warmup_hyperparameters'])
    else:
        hyperparameters = _get_hyperparameters()

    tuner_cls = get_tuner_cls_with_callbacks(kerastuner.RandomSearch)
    tuner = tuner_cls(build_keras_model_fn,
                      max_trials=fn_args.custom_config.get('max_trials', 10),
                      hyperparameters=hyperparameters,
                      allow_new_entries=False,
                      objective=autodata_adapter.tuner_objective,
                      directory=fn_args.working_dir,
                      project_name=f'{problem_statement.tasks[0].name}_tuning')

    # TODO(nikhilmehta): Make batch size a tunable hyperparameter.
    train_dataset = autodata_adapter.get_dataset(
        file_pattern=fn_args.train_files,
        batch_size=128,
        num_epochs=None,
        shuffle=True)

    eval_dataset = autodata_adapter.get_dataset(
        file_pattern=fn_args.eval_files,
        batch_size=128,
        num_epochs=1,
        shuffle=False)

    return TunerFnResult(tuner=tuner,
                         fit_kwargs={
                             'x': train_dataset,
                             'validation_data': eval_dataset,
                             'steps_per_epoch': fn_args.train_steps,
                         })
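
Because tuner_fn reads warmup_hyperparameters with HyperParameters.from_config, the corresponding custom_config entry is expected to be a HyperParameters.get_config() dict. A hedged sketch of building such a custom_config; the search space and the `ps` proto below are illustrative assumptions, not the project's _get_hyperparameters():

import kerastuner
from google.protobuf import text_format

# `ps` is an assumed ps_pb2.ProblemStatement instance.
hp = kerastuner.HyperParameters()
hp.Choice('learning_rate', [1e-3, 1e-4])  # illustrative search space only
custom_config = {
    'problem_statement': text_format.MessageToString(message=ps, as_utf8=True),
    'warmup_hyperparameters': hp.get_config(),
    'max_trials': 5,  # read via fn_args.custom_config.get('max_trials', 10)
}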
Example 15
def run_fn(fn_args: trainer_executor.TrainerFnArgs):
    """Train a DNN Keras Model based on given args.

  Args:
    fn_args: Holds args used to train the model as name/value pairs.
      - train_files: A list of uris for train files.
      - transform_output: An optional single uri for transform graph produced by
        TFT. Will be None if not specified.
      - serving_model_dir: A single uri for the output directory of the serving
        model.
      - eval_model_dir: A single uri for the output directory of the eval model.
        Note that this is for estimator only; Keras doesn't require it for TFMA.
      - eval_files:  A list of uris for eval files.
      - schema_file: A single uri for schema file.
      - train_steps: Number of train steps.
      - eval_steps: Number of eval steps.
      - base_model: Base model that will be used for this training job.
      - hyperparameters: An optional kerastuner.HyperParameters config.
      - custom_config: A dict with a single 'problem_statement' entry containing
        a text-format serialized ProblemStatement proto which defines the task.
  """

    # Use EstimatorAdapter here because we will wrap the Keras Model into an
    # Estimator for training and export.
    autodata_adapter = ea.EstimatorAdapter(
        problem_statement=text_format.Parse(
            fn_args.custom_config['problem_statement'],
            ps_pb2.ProblemStatement()),
        transform_graph_dir=fn_args.transform_output)

    if fn_args.hyperparameters:
        hparams = kerastuner.HyperParameters.from_config(
            fn_args.hyperparameters)
    else:
        hparams = _get_hyperparameters()
    logging.info('HyperParameters for training: %s', hparams.get_config())

    # Use KerasModelAdapter here because we need it to create the Keras Model.
    keras_autodata_adapter = kma.KerasModelAdapter(
        problem_statement=text_format.Parse(
            fn_args.custom_config['problem_statement'],
            ps_pb2.ProblemStatement()),
        transform_graph_dir=fn_args.transform_output)
    model = _build_keras_model(hparams,
                               keras_autodata_adapter,
                               sequence_length=fn_args.custom_config.get(
                                   'sequence_length', None))

    train_spec = tf.estimator.TrainSpec(input_fn=autodata_adapter.get_input_fn(
        file_pattern=fn_args.train_files,
        batch_size=64,
        num_epochs=None,
        shuffle=True),
                                        max_steps=fn_args.train_steps)

    serving_receiver_fn = autodata_adapter.get_serving_input_receiver_fn()
    exporters = [
        tf.estimator.FinalExporter('serving_model_dir', serving_receiver_fn),
    ]
    eval_spec = tf.estimator.EvalSpec(
        input_fn=autodata_adapter.get_input_fn(file_pattern=fn_args.eval_files,
                                               batch_size=64,
                                               num_epochs=1,
                                               shuffle=False),
        steps=fn_args.eval_steps,
        exporters=exporters,
        # Since eval runs in parallel, we can begin evaluation as soon as new
        # checkpoints are written.
        start_delay_secs=1,
        throttle_secs=5)

    run_config = tf.estimator.RunConfig(model_dir=fn_args.serving_model_dir,
                                        save_checkpoints_steps=999,
                                        keep_checkpoint_max=3)

    estimator = tf.keras.estimator.model_to_estimator(model, config=run_config)
    logging.info('Training model...')
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
    logging.info('Training complete.')

    # Export an eval SavedModel for TFMA. Under distributed training, it must
    # be written only by the chief worker, as is done for the serving SavedModel.
    if run_config.is_chief:
        logging.info('Exporting eval_savedmodel for TFMA.')
        tfma.export.export_eval_savedmodel(
            estimator=estimator,
            export_dir_base=fn_args.eval_model_dir,
            eval_input_receiver_fn=autodata_adapter.get_eval_input_receiver_fn(
            ))

        logging.info('Exported eval_savedmodel to %s.', fn_args.eval_model_dir)
    else:
        logging.info('eval_savedmodel export for TFMA is skipped because '
                     'this is not the chief worker.')
Example 16
  def encode(
      self,
      component_spec: Optional[tfx_types.ComponentSpec] = None
  ) -> message.Message:
      # Return an arbitrary proto.
      return ps_pb2.ProblemStatement()
Example 17
def run_fn(fn_args: trainer_executor.TrainerFnArgs):
    """Train a DNNEstimator based on given args.

  Args:
    fn_args: Holds args used to train the model as name/value pairs.
      - train_files: A list of uris for train files.
      - transform_output: An optional single uri for transform graph produced by
        TFT. Will be None if not specified.
      - serving_model_dir: A single uri for the output directory of the serving
        model.
      - eval_model_dir: A single uri for the output directory of the eval model.
        Note that this is for estimator only; Keras doesn't require it for TFMA.
      - eval_files:  A list of uris for eval files.
      - schema_file: A single uri for schema file.
      - train_steps: Number of train steps.
      - eval_steps: Number of eval steps.
      - base_model: Base model that will be used for this training job.
      - hyperparameters: An optional kerastuner.HyperParameters config.
      - custom_config: A dict with a single 'problem_statement' entry containing
        a text-format serialized ProblemStatement proto which defines the task.
  """
    sequence_length = fn_args.custom_config.get('sequence_length', None)
    if sequence_length:
        raise ValueError('Sequential prediction tasks are not supported. '
                         'Set `use_keras=True` in AutoTrainer instead.')

    autodata_adapter = estimator_adapter.EstimatorAdapter(
        problem_statement=text_format.Parse(
            fn_args.custom_config['problem_statement'],
            ps_pb2.ProblemStatement()),
        transform_graph_dir=fn_args.transform_output)

    run_config = tf.estimator.RunConfig(model_dir=fn_args.serving_model_dir,
                                        save_checkpoints_steps=999,
                                        keep_checkpoint_max=3)

    estimator = tf.estimator.DNNEstimator(
        head=autodata_adapter.head,
        feature_columns=autodata_adapter.get_dense_feature_columns(),
        hidden_units=[128, 128],
        config=run_config)

    train_spec = tf.estimator.TrainSpec(input_fn=autodata_adapter.get_input_fn(
        file_pattern=fn_args.train_files,
        batch_size=64,
        num_epochs=None,
        shuffle=True),
                                        max_steps=fn_args.train_steps)

    serving_receiver_fn = autodata_adapter.get_serving_input_receiver_fn()
    exporters = [
        tf.estimator.FinalExporter('serving_model_dir', serving_receiver_fn),
    ]

    eval_spec = tf.estimator.EvalSpec(
        input_fn=autodata_adapter.get_input_fn(file_pattern=fn_args.eval_files,
                                               batch_size=64,
                                               num_epochs=1,
                                               shuffle=False),
        steps=fn_args.eval_steps,
        exporters=exporters,
        # Since eval runs in parallel, we can begin evaluation as soon as new
        # checkpoints are written.
        start_delay_secs=1,
        throttle_secs=5)

    # Train/Tune the model
    logging.info('Training model...')
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
    logging.info('Training complete.')

    # Export an eval SavedModel for TFMA. Under distributed training, it must
    # be written only by the chief worker, as is done for the serving SavedModel.
    if run_config.is_chief:
        logging.info('Exporting eval_savedmodel for TFMA.')
        tfma.export.export_eval_savedmodel(
            estimator=estimator,
            export_dir_base=fn_args.eval_model_dir,
            eval_input_receiver_fn=autodata_adapter.get_eval_input_receiver_fn(
            ))

        logging.info('Exported eval_savedmodel to %s.', fn_args.eval_model_dir)
    else:
        logging.info('eval_savedmodel export for TFMA is skipped because '
                     'this is not the chief worker.')