Пример #1
0
def _input_fn(
    file_pattern: List[Text],
    data_accessor: DataAccessor,
    tf_transform_output: tft.TFTransformOutput,
    batch_size: int = 200,
) -> tf.data.Dataset:
    """Build the batched dataset of features and labels for tuning/training.

    Args:
        file_pattern: List of paths or patterns of input tfrecord files.
        data_accessor: DataAccessor for converting input to RecordBatch; its
            `tf_dataset_factory` is what actually builds the dataset.
        tf_transform_output: A TFTransformOutput. Its transformed metadata
            schema is handed to the dataset factory.
        batch_size: Number of consecutive elements of the returned dataset
            to combine in a single batch.

    Returns:
        A dataset that contains (features, indices) tuples, where features
        is a dictionary of Tensors and indices is a single Tensor of label
        indices.
    """
    make_dataset = data_accessor.tf_dataset_factory
    # The label column is looked up under its transformed name.
    options = dataset_options.TensorFlowDatasetOptions(
        batch_size=batch_size,
        label_key=_transformed_name(_LABEL_KEY),
    )
    transformed_schema = tf_transform_output.transformed_metadata.schema
    return make_dataset(file_pattern, options, transformed_schema)
Пример #2
0
def _input_fn(
    file_pattern: List[Text],
    data_accessor: DataAccessor,
    schema: schema_pb2.Schema,
    batch_size: int = 20,
) -> tf.data.Dataset:
    """Build the batched, repeating dataset of features and labels.

    Args:
        file_pattern: List of paths or patterns of input tfrecord files.
        data_accessor: DataAccessor for converting input to RecordBatch; its
            `tf_dataset_factory` is what actually builds the dataset.
        schema: Schema of the input data.
        batch_size: Number of consecutive elements of the returned dataset
            to combine in a single batch.

    Returns:
        A dataset that contains (features, indices) tuples, where features
        is a dictionary of Tensors and indices is a single Tensor of label
        indices. `.repeat()` is applied with no count, so the caller is
        responsible for bounding the number of steps consumed.
    """
    make_dataset = data_accessor.tf_dataset_factory
    options = dataset_options.TensorFlowDatasetOptions(
        batch_size=batch_size,
        label_key=_LABEL_KEY,
    )
    dataset = make_dataset(file_pattern, options, schema)
    return dataset.repeat()
Пример #3
0
def _input_fn(
    file_pattern: Text,
    data_accessor: DataAccessor,
    schema: schema_pb2.Schema,
    batch_size: int = 20,
) -> Tuple[np.ndarray, np.ndarray]:
    """Load one epoch of records into in-memory feature and label arrays.

    Args:
        file_pattern: Input tfrecord file pattern.
        data_accessor: DataAccessor for converting input to RecordBatch; its
            `record_batch_factory` yields the record batches read here.
        schema: Schema of the input data.
        batch_size: An int representing the number of records to combine in
            a single batch.

    Returns:
        A (features, indices) tuple where features is a matrix of features,
        and indices is a single vector of label indices.
    """
    make_batches = data_accessor.record_batch_factory
    # num_epochs=1: the input is read through exactly once.
    options = dataset_options.RecordBatchesOptions(batch_size=batch_size,
                                                   num_epochs=1)
    batches = make_batches(file_pattern, options, schema)

    feature_chunks = []
    label_chunks = []
    for batch in batches:
        # Map each column name to its flattened values for this batch.
        columns_by_name = {
            field.name: column.flatten()
            for column, field in zip(batch, batch.schema)
        }

        label_chunks.append(columns_by_name[_LABEL_KEY])
        # Stack the selected feature columns along the last axis, in the
        # order given by _FEATURE_KEYS.
        selected = [columns_by_name[name] for name in _FEATURE_KEYS]
        feature_chunks.append(np.stack(selected, axis=-1))

    return np.concatenate(feature_chunks), np.concatenate(label_chunks)
Пример #4
0
    def testTrainerFn(self):
        """Run `taxi_utils.trainer_fn` end to end on the test data.

        Builds the trainer arguments, checks the types of the returned
        training spec entries, trains and evaluates for one step each, then
        verifies the serving export and the eval saved model.
        """
        # Directory layout for this test's outputs.
        base_dir = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                  self.get_temp_dir())
        temp_dir = os.path.join(base_dir, self._testMethodName)
        output_dir = os.path.join(temp_dir, 'output_dir')
        serving_model_dir = os.path.join(temp_dir, 'serving_model_dir')

        # Inputs taken from the checked-in test data.
        testdata = self._testdata_path
        schema_file = os.path.join(testdata, 'schema_gen/schema.pbtxt')
        train_files = os.path.join(
            testdata, 'transform/transformed_examples/train/*.gz')
        eval_files = os.path.join(
            testdata, 'transform/transformed_examples/eval/*.gz')
        transform_output = os.path.join(testdata, 'transform/transform_graph')

        dataset_factory = tfxio_utils.get_tf_dataset_factory_from_artifact(
            [standard_artifacts.Examples()], [])
        data_accessor = DataAccessor(tf_dataset_factory=dataset_factory,
                                     record_batch_factory=None)

        trainer_fn_args = trainer_executor.TrainerFnArgs(
            train_files=train_files,
            transform_output=transform_output,
            output_dir=output_dir,
            serving_model_dir=serving_model_dir,
            eval_files=eval_files,
            schema_file=schema_file,
            train_steps=1,
            eval_steps=1,
            verbosity='INFO',
            base_model=None,
            data_accessor=data_accessor)
        schema = io_utils.parse_pbtxt_file(schema_file, schema_pb2.Schema())
        training_spec = taxi_utils.trainer_fn(trainer_fn_args, schema)

        estimator = training_spec['estimator']
        train_spec = training_spec['train_spec']
        eval_spec = training_spec['eval_spec']
        eval_input_receiver_fn = training_spec['eval_input_receiver_fn']

        # Each entry of the training spec must have the expected type.
        expected_types = (
            (estimator, tf.estimator.DNNLinearCombinedClassifier),
            (train_spec, tf.estimator.TrainSpec),
            (eval_spec, tf.estimator.EvalSpec),
            (eval_input_receiver_fn, types.FunctionType),
        )
        for value, expected_type in expected_types:
            self.assertIsInstance(value, expected_type)

        # Test keep_max_checkpoint in RunConfig
        self.assertGreater(estimator._config.keep_checkpoint_max, 1)

        # Train for one step, then eval for one step.
        eval_result, exports = tf.estimator.train_and_evaluate(
            estimator, train_spec, eval_spec)
        self.assertGreater(eval_result['loss'], 0.0)
        self.assertEqual(len(exports), 1)
        serving_export_dir = exports[0]
        exported_entries = tf.io.gfile.listdir(serving_export_dir)
        self.assertGreaterEqual(len(exported_entries), 1)

        # Export the eval saved model.
        eval_savedmodel_path = tfma.export.export_eval_savedmodel(
            estimator=estimator,
            export_dir_base=path_utils.eval_model_dir(output_dir),
            eval_input_receiver_fn=eval_input_receiver_fn)
        eval_entries = tf.io.gfile.listdir(eval_savedmodel_path)
        self.assertGreaterEqual(len(eval_entries), 1)

        # Test exported serving graph.
        with tf.compat.v1.Session() as sess:
            metagraph_def = tf.compat.v1.saved_model.loader.load(
                sess, [tf.saved_model.SERVING], serving_export_dir)
            self.assertIsInstance(metagraph_def, tf.compat.v1.MetaGraphDef)