Exemple #1
0
def _dump_examples_to_tfrecord(path, examples):
    """Serializes example dicts as tf.Example protos into a TFRecord file.

    Args:
        path: Destination passed to `tf.io.TFRecordWriter`.
        examples: Iterable of dicts; each one is converted with
            `dataset_utils.dict_to_tfexample` and written as one record.
    """
    logging.info("Writing examples to TFRecord: %s", path)
    with tf.io.TFRecordWriter(path) as record_writer:
        for example in examples:
            proto = dataset_utils.dict_to_tfexample(example)
            record_writer.write(proto.SerializeToString())
Exemple #2
0
  def test_dict_to_tfexample(self):
    """Checks the feature lists produced for str, float, int array and bool."""
    example = {
        "inputs": "this is an input",
        "targets": "this is a target",
        "weight": 5.0,
        "idx1": np.array([1, 2], np.int32),
        "idx2": np.array([3, 4], np.int64),
        "is_correct": False,
    }
    proto = utils.dict_to_tfexample(example)
    feature_map = proto.features.feature

    self.assertLen(feature_map, 6)

    # String values are expected in the bytes list.
    self.assertEqual(
        feature_map["inputs"].bytes_list.value, [b"this is an input"])
    self.assertEqual(
        feature_map["targets"].bytes_list.value, [b"this is a target"])

    # The Python float is expected in the float list.
    self.assertEqual(feature_map["weight"].float_list.value, [5.0])

    # int32/int64 arrays and the bool are all expected in the int64 list.
    expected_int64 = (
        ("idx1", [1, 2]),
        ("idx2", [3, 4]),
        ("is_correct", [0]),
    )
    for key, values in expected_int64:
      np.testing.assert_array_equal(
          feature_map[key].int64_list.value,
          np.array(values, np.int64))
Exemple #3
0
    def test_dict_to_tfexample(self):
        """Checks the feature lists produced for string and float values."""
        example = {
            "inputs": "this is an input",
            "targets": "this is a target",
            "weight": 5.0,
        }
        proto = utils.dict_to_tfexample(example)
        feature_map = proto.features.feature

        self.assertLen(feature_map, 3)

        # String values are expected in the bytes list.
        self.assertEqual(
            feature_map["inputs"].bytes_list.value, [b"this is an input"])
        self.assertEqual(
            feature_map["targets"].bytes_list.value, [b"this is a target"])

        # The Python float is expected in the float list.
        self.assertEqual(feature_map["weight"].float_list.value, [5.0])