Example #1
 def test_shard_input_file_with_more_shards(self):
     input_file_pattern = os.path.join(self._base_dir, "*.tfrecord")
     shard_files, indicator = shard_input_files(input_file_pattern, 20, 1)
     expected_files = [os.path.join(self._base_dir, '1.tfrecord')]
     self.assertAllEqual(shard_files, expected_files)
     self.assertTrue(indicator)
     shard_files, indicator = shard_input_files(input_file_pattern, 20, 19)
     self.assertEqual(len(shard_files), 0)
     self.assertTrue(indicator)
    def predict(self, output_dir, input_data_path, metadata_file,
                checkpoint_path, execution_context, schema_params):
        n_records = 0
        n_batch = 0
        # Predict on the dataset
        sharded_dataset_paths, file_level_sharding = shard_input_files(
            input_data_path, execution_context[constants.NUM_SHARDS],
            execution_context[constants.SHARD_INDEX])
        if file_level_sharding and len(sharded_dataset_paths) == 0:
            logger.info("No input dataset is found, returning...")
            return

        inference_dataset = lambda: input_fn(  # noqa: E731
            input_pattern=','.join(sharded_dataset_paths),
            # DeText uses metadata_path
            metadata_path=self.model_params.metadata_path,
            batch_size=self.model_params.test_batch_size,
            mode=tf.estimator.ModeKeys.EVAL,
            vocab_table=vocab_utils.read_tf_vocab(self.model_params.vocab_file,
                                                  self.model_params.UNK),
            vocab_table_for_id_ftr=vocab_utils.read_tf_vocab(
                self.model_params.vocab_file_for_id_ftr,
                self.model_params.UNK_FOR_ID_FTR),
            feature_names=self.model_params.feature_names,
            CLS=self.model_params.CLS,
            SEP=self.model_params.SEP,
            PAD=self.model_params.PAD,
            PAD_FOR_ID_FTR=self.model_params.PAD_FOR_ID_FTR,
            max_len=self.model_params.max_len,
            min_len=self.model_params.min_len,
            cnn_filter_window_size=max(self.model_params.filter_window_sizes)
            if self.model_params.ftr_ext == 'cnn' else 0)

        self.estimator_based_model = detext_train.get_estimator(
            self.model_params,
            strategy=None,  # local mode
            best_checkpoint=self.best_checkpoint)
        output = self.estimator_based_model.predict(
            inference_dataset, yield_single_examples=False)
        detext_writer = DetextWriter(schema_params=schema_params)
        shard_index = execution_context[constants.SHARD_INDEX]
        output_file = os.path.join(output_dir,
                                   "part-{0:05d}.avro".format(shard_index))
        for batch_score in output:
            if n_batch == 0:
                # First batch: create the output file. Reporting seekable() as False
                # presumably makes the Avro writer start a fresh container with a header.
                with tf.io.gfile.GFile(output_file, 'wb') as f:
                    f.seekable = lambda: False
                    n_records, n_batch = detext_writer.save_batch(
                        f, batch_score, output_file, n_records, n_batch)
            else:
                # Later batches: reopen in append mode and report the handle as
                # seekable/readable, presumably so the Avro writer appends to the
                # existing container instead of writing a new header.
                with tf.io.gfile.GFile(output_file, 'ab+') as f:
                    f.seek(0, 2)
                    f.seekable = lambda: True
                    f.readable = lambda: True
                    n_records, n_batch = detext_writer.save_batch(
                        f, batch_score, output_file, n_records, n_batch)
        logger.info("{} batches, e.g. {} records inferenced".format(
            n_batch, n_records))
 def _get_assigned_files(input_data_path, num_shards, shard_index):
     """
     Get the assigned files from the shard
     :param input_data_path:
     :return: a list of assigned file names.
     """
     assigned_files, sample_level_shard = shard_input_files(input_data_path, num_shards, shard_index)
     assert not sample_level_shard, "Doesn't support sample level sharding," \
                                    "number of files must >= number of workers"
     return assigned_files
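
For context, a hypothetical call for worker 2 of 4 could look like the following; the glob pattern is a placeholder, not taken from the project.

# Hypothetical usage on worker 2 of 4; the pattern below is a placeholder.
files_for_worker = _get_assigned_files("/data/train/*.tfrecord", num_shards=4, shard_index=2)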
Example #4
 def test_shard_input_files_with_wrong_params(self):
     with self.assertRaises(AssertionError):
         shard_input_files(self._base_dir, 1, 2)
     with self.assertRaises(AssertionError):
         shard_input_files(self._base_dir, -1, -2)
     with self.assertRaises(tf.errors.NotFoundError):
         shard_input_files(os.path.join(self._base_dir, "nowhere/nofile"), 3, 2)
    def predict(self,
                output_dir,
                input_data_path,
                metadata_file,
                checkpoint_path,
                execution_context,
                schema_params):
        n_records = 0
        n_batch = 0
        # Predict on the dataset
        sharded_dataset_paths, file_level_sharding = shard_input_files(input_data_path,
                                                                       execution_context[constants.NUM_SHARDS],
                                                                       execution_context[constants.SHARD_INDEX])
        if file_level_sharding and len(sharded_dataset_paths) == 0:
            logger.info("No input dataset is found, returning...")
            return

        inference_dataset = input_fn_tfrecord(input_pattern=','.join(sharded_dataset_paths),
                                              batch_size=self.model_params.test_batch_size,
                                              mode=tf.estimator.ModeKeys.EVAL,
                                              feature_type2name=self.model_params.feature_type2name,
                                              feature_name2num=self.model_params.feature_name2num,
                                              task_type=self.model_params.task_type)

        self.model = train_model_helper.load_model_with_ckpt(
            parsing_utils.HParams(**asdict(self.model_params)),
            self.best_checkpoint)
        output = train_flow_helper.predict_with_additional_info(inference_dataset,
                                                                self.model,
                                                                self.model_params.feature_type2name)
        detext_writer = DetextWriter(schema_params=schema_params)
        shard_index = execution_context[constants.SHARD_INDEX]
        output_file = os.path.join(output_dir, f"part-{shard_index:05d}.avro")
        for batch_score in output:
            if n_batch == 0:
                with tf.io.gfile.GFile(output_file, 'wb') as f:
                    f.seekable = lambda: False
                    n_records, n_batch = detext_writer.save_batch(f, batch_score, output_file,
                                                                  n_records, n_batch)
            else:
                with tf.io.gfile.GFile(output_file, 'ab+') as f:
                    f.seek(0, 2)
                    f.seekable = lambda: True
                    f.readable = lambda: True
                    n_records, n_batch = detext_writer.save_batch(f, batch_score, output_file,
                                                                  n_records, n_batch)
        logger.info(f"{n_batch} records, e.g. {n_records} records inferenced")
Example #6
 def test_shard_input_file_with_filename_pattern(self):
     input_file_pattern = os.path.join(self._base_dir, "*.tfrecord")
     shard_files, indicator = shard_input_files(input_file_pattern, 3, 1)
     expected_files = [os.path.join(self._base_dir, f'{i}.tfrecord') for i in range(1, 10, 3)]
     self.assertAllEqual(shard_files, expected_files)
     self.assertFalse(indicator)
Example #7
 def test_shard_input_files_with_directory(self):
     shard_files, _ = shard_input_files(self._base_dir, 2, 0)
     expected_files = [os.path.join(self._base_dir, f'{i}.avro') for i in range(10)]
     self.assertAllEqual(shard_files, expected_files)
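
Taken together, these tests pin down the sharding contract: files matching the pattern (or found under the directory) are sorted and assigned round-robin, and the second return value flags when there are fewer files than shards. Below is a minimal sketch of a shard_input_files-like helper consistent with that behavior; the name, the directory-to-glob expansion, and the error construction are assumptions, not the project's actual implementation.

import os
import tensorflow as tf


def shard_input_files_sketch(input_path, num_shards, shard_index):
    """Return (assigned_files, need_sample_level_sharding) for this worker (sketch only)."""
    assert num_shards > 0 and 0 <= shard_index < num_shards
    # A directory is expanded to all files inside it; anything else is treated as a glob pattern.
    pattern = os.path.join(input_path, "*") if tf.io.gfile.isdir(input_path) else input_path
    files = sorted(tf.io.gfile.glob(pattern))
    if not files:
        raise tf.errors.NotFoundError(None, None, "no input files match {}".format(pattern))
    # Round-robin assignment: worker k takes files k, k + num_shards, k + 2 * num_shards, ...
    assigned = files[shard_index::num_shards]
    # With fewer files than workers, file-level sharding alone is insufficient; the caller
    # may need to shard at the record (sample) level instead.
    return assigned, len(files) < num_shards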
Example #8
def per_record_input_fn(input_path,
                        metadata_file,
                        num_shards,
                        shard_index,
                        batch_size,
                        data_format,
                        custom_input_fn=None):
    """
    Input function for a per-record dataset, where each record is an individual example.
    Batching adds an extra leading dimension on top of that.
    :param input_path: input directory or file pattern.
    :param metadata_file: tensor metadata file
    :param num_shards: number of shards
    :param shard_index: the index of this worker
    :param batch_size: batch size
    :param data_format: tfrecord or avro
    :param custom_input_fn: full name "package.module.fn" for the external custom input_fn.
    :return: a batched dataset.
    """
    if data_format == constants.TFRECORD:
        logger.info("using {} dataset".format(constants.TFRECORD))

        def build_features(tensors):
            """
            Create features from metadata, used to deserialize the tfrecord.
            :param tensors: list of metadata for all tensors.
            :return: tfrecord features
            """
            tf_features = {}
            for feature in tensors:
                if feature.isSparse:
                    # If this is a sparse tensor, we process indices and values separately.
                    # Note in the metadata, we don't see _indices and _values,
                    # only the feature name.
                    tf_features[feature.name] = tf.io.SparseFeature(
                        index_key=f"{feature.name}_{DatasetMetadata.INDICES}",
                        value_key=f"{feature.name}_{DatasetMetadata.VALUES}",
                        dtype=DatasetMetadata.map_int(feature.dtype),
                        size=_unpack_one_element_list(feature.shape))
                else:
                    tf_features[feature.name] = tf.io.FixedLenFeature(
                        shape=feature.shape,
                        dtype=DatasetMetadata.map_int(feature.dtype))
            return tf_features

        def map_fn(serialized, feature_tensors, label_tensors):
            """
            Deserialize TF records to features. This is done after batching since we are using
            tf.io.parse_example
            :param serialized: Serialized TF records
            :param feature_tensors: list of feature tensors
            :param label_tensors: list of label tensors
            :return: (features, labels) tuple where each of them is a map.
            """
            tensors = feature_tensors + label_tensors
            tf_features = build_features(tensors)
            example = tf.io.parse_example(serialized,
                                          tf_features,
                                          example_names=None,
                                          name=None)
            # then split features from labels
            return _splits_label_and_features(example, label_tensors)

        # Get shard input files
        input_filename_pattern = _convert_dir_to_filename_pattern(
            input_path, constants.TFRECORD_GLOB_PATTERN)
        input_files, _ = shard_input_files(input_filename_pattern, num_shards,
                                           shard_index)

        # Get metadata
        feature_tensors, label_tensors = _get_features_and_labels_info(
            metadata_file)

        # Batching
        dataset = tf.data.TFRecordDataset(input_files).batch(
            batch_size, drop_remainder=False)

        # Deserialize to features
        dataset = dataset.map(partial(map_fn,
                                      feature_tensors=feature_tensors,
                                      label_tensors=label_tensors),
                              num_parallel_calls=16)
    elif custom_input_fn is not None:
        logger.info("loading {} dataset by {}".format(data_format,
                                                      custom_input_fn))
        import importlib
        module_name, fn_name = custom_input_fn.rsplit('.', 1)
        dataset_module = importlib.import_module(module_name)
        dataset = getattr(dataset_module,
                          fn_name)(input_path, metadata_file, num_shards,
                                   shard_index, batch_size, data_format)
    else:
        raise Exception("Unknown data format :{}".format(data_format))
    return dataset
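
A hypothetical invocation for worker 0 of 4 might look like this; the paths, shard counts, and batch size are placeholders:

# Hypothetical call; paths, shard counts, and batch size are placeholders.
dataset = per_record_input_fn(input_path="/data/train",
                              metadata_file="/data/tensor_metadata.json",
                              num_shards=4,
                              shard_index=0,
                              batch_size=256,
                              data_format=constants.TFRECORD)
# Per the map_fn above, each element is a (features, labels) pair of dicts keyed by tensor name.
for features, labels in dataset.take(1):
    print(sorted(features.keys()), sorted(labels.keys()))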
Example #9
def per_entity_grouped_input_fn(input_path,
                                metadata_file,
                                num_shards,
                                shard_index,
                                batch_size,
                                data_format,
                                entity_name,
                                custom_input_fn=None):
    """
    Input function for a per-entity grouped dataset, where records are grouped by entity Id.
    Each feature is a vector except the entity Id, which is a scalar. Batching adds an extra
    leading dimension on top of that.
    :param input_path: input directory or file pattern.
    :param metadata_file: tensor metadata file
    :param num_shards: number of shards
    :param shard_index: the index of this worker
    :param batch_size: batch size
    :param data_format: tfrecord or avro
    :param entity_name: the name of the entity which is used to group the records.
    :param custom_input_fn: full name "package.module.fn" for the external custom input_fn.
    :return: a batched dataset.
    """

    if data_format == constants.TFRECORD:
        logger.info(f"using {constants.TFRECORD} dataset")

        # Build features
        def build_features(tensors, entity_name):
            """
            Create features from metadata, used to deserialize the tfrecord.
            :param tensors: list of metadata for all tensors.
            :param entity_name: entity by which the records are grouped.
            :return: a tuple of context_features and sequence_features
            """
            sequence_features = dict()
            context_features = dict()
            for tensor in tensors:
                tensor_dtype = DatasetMetadata.map_int(tensor.dtype)
                if tensor.name == entity_name:
                    # entity_name column is a scalar
                    context_features[entity_name] = tf.io.FixedLenFeature(
                        shape=[], dtype=tensor_dtype)
                else:
                    if tensor.isSparse:
                        # If this is a sparse tensor, we process indices and values separately.
                        # Note in the metadata, we don't see _indices and _values,
                        # only the feature name.
                        indices_name = f"{tensor.name}_{DatasetMetadata.INDICES}"
                        values_name = f"{tensor.name}_{DatasetMetadata.VALUES}"
                        sequence_features[indices_name] = tf.io.VarLenFeature(
                            dtype=tf.int64)
                        sequence_features[values_name] = tf.io.VarLenFeature(
                            dtype=tensor_dtype)
                    else:
                        context_features[tensor.name] = tf.io.VarLenFeature(
                            dtype=tensor_dtype)
            if len(sequence_features) == 0:
                sequence_features = None
            if len(context_features) == 0:
                context_features = None
            return context_features, sequence_features

        def map_fn(serialized, feature_tensors, label_tensors, entity_name):
            """
            Map serialized tfrecord to a dict of tensors.
            :param serialized: serialized tfrecord
            :param feature_tensors: list of metadata for features
            :param label_tensors: list of metadata for labels
            :param entity_name: entity by which the records are grouped.
            :return: A tuple (features, labels)
            """
            tensors = feature_tensors + label_tensors
            context_features, sequence_features = build_features(
                tensors, entity_name)
            example = tf.io.parse_sequence_example(
                serialized,
                context_features=context_features,
                sequence_features=sequence_features,
                example_names=None,
                name=None)

            # Split features from labels
            context, sequence = example[0], example[1]
            sequence.update(context)
            return _splits_label_and_features(sequence, label_tensors)

        # Get shard input files
        input_filename_pattern = _convert_dir_to_filename_pattern(
            input_path, constants.TFRECORD_GLOB_PATTERN)
        input_files, _ = shard_input_files(input_filename_pattern, num_shards,
                                           shard_index)

        # Get metadata
        feature_tensors, label_tensors = _get_features_and_labels_info(
            metadata_file)

        # Check if entity_name is one of the features
        feature_names = [x.name for x in feature_tensors]
        if entity_name not in feature_names:
            raise ValueError(
                f"entity name {entity_name} is not found among the features")

        # Batching
        dataset = tf.data.TFRecordDataset(input_files).batch(
            batch_size, drop_remainder=False)

        # Deserialize to features
        dataset = dataset.map(partial(map_fn,
                                      feature_tensors=feature_tensors,
                                      label_tensors=label_tensors,
                                      entity_name=entity_name),
                              num_parallel_calls=tf.data.experimental.AUTOTUNE)

    elif custom_input_fn:
        logger.info(f"loading {data_format} dataset by {custom_input_fn}")
        import importlib
        module_name, fn_name = custom_input_fn.rsplit('.', 1)
        dataset_module = importlib.import_module(module_name)
        dataset = getattr(dataset_module,
                          fn_name)(input_path, metadata_file, num_shards,
                                   shard_index, batch_size, entity_name,
                                   data_format)
    else:
        raise Exception(f"Unknown data format : {data_format}")
    return dataset
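
Both input functions fall back to a dynamically imported custom_input_fn for non-TFRecord formats. The following is a minimal sketch of a plug-in matching the positional arguments passed in this per-entity branch; the module path and the line-based placeholder pipeline are assumptions.

import tensorflow as tf


# Would be referenced as custom_input_fn="my_package.my_io.grouped_input_fn" (hypothetical path).
def grouped_input_fn(input_path, metadata_file, num_shards, shard_index,
                     batch_size, entity_name, data_format):
    # File-level sharding, mirroring the built-in TFRecord branch above.
    files = sorted(tf.io.gfile.glob(input_path))[shard_index::num_shards]
    # Placeholder pipeline: a real plug-in would parse `data_format` records, group them by
    # `entity_name`, and return batched (features, labels) dicts.
    return tf.data.TextLineDataset(files).batch(batch_size, drop_remainder=False)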