示例#1
0
    def testSequenceRecord(self):
        """Round-trips a single sequence vector through a TFRecord file.

        Writes one 2x2 float32 vector with ``create_sequence_records``, then
        reads it back with a ``SequenceRecordInputter(2)`` (2 presumably being
        the feature depth, matching the last dimension of ``vector`` — confirm)
        through the ``_makeDataset`` test helper. Checks the decoded length,
        the decoded tensor and the transformed output.
        """
        vector = np.array([[0.2, 0.3], [0.4, 0.5]], dtype=np.float32)

        record_file = os.path.join(self.get_temp_dir(), "data.records")
        record_inputter.create_sequence_records([vector], record_file)

        inputter = record_inputter.SequenceRecordInputter(2)
        features, transformed = self._makeDataset(
            inputter,
            record_file,
            dataset_size=None,
            shapes={"tensor": [None, None, 2], "length": [None]},
        )

        # Compare element-wise. assertEqual would rely on the truthiness of a
        # `list == array/tensor` comparison, which is ambiguous for more than
        # one element and produces an unhelpful failure message.
        self.assertAllEqual([2], features["length"])
        self.assertAllEqual([vector], features["tensor"])
        self.assertAllEqual([vector], transformed)
示例#2
0
 def testSequenceRecordWithCompression(self):
     """Reads a GZIP-compressed sequence record back through inference.

     The records are read from the path returned by
     ``create_sequence_records``, not from the requested path. Presumably
     compression can change the file name; confirm against that helper.
     """
     expected = np.array([[0.2, 0.3], [0.4, 0.5]], dtype=np.float32)
     requested_path = os.path.join(self.get_temp_dir(), "data.records")
     written_path = record_inputter.create_sequence_records(
         [expected], requested_path, compression="GZIP")

     reader = record_inputter.SequenceRecordInputter(2)
     inference_data = reader.make_inference_dataset(written_path, batch_size=1)
     first_batch = next(iter(inference_data))

     # With batch_size=1, index 0 is the single example that was written.
     self.assertAllEqual(first_batch["tensor"].numpy()[0], expected)
示例#3
0
    def testSequenceRecordBatch(self):
        """Batches variable-length sequence records and checks each example.

        Three sequences of different lengths are written to one TFRecord file.
        The raw dataset is batched before ``make_features`` is mapped, so
        ``make_features`` receives a whole batch of records. The test checks
        the decoded lengths, and checks that each example's first ``length``
        rows match the vector that was written.
        """
        # A seeded generator makes the test data identical on every run, so a
        # failure can be reproduced. float32 matches the dtype used by the
        # other sequence-record tests.
        rng = np.random.RandomState(42)
        expected_lengths = [3, 6, 1]
        vectors = [
            rng.rand(length, 2).astype(np.float32) for length in expected_lengths
        ]

        record_file = os.path.join(self.get_temp_dir(), "data.records")
        record_inputter.create_sequence_records(vectors, record_file)

        inputter = record_inputter.SequenceRecordInputter(2)
        dataset = inputter.make_dataset(record_file)
        # Put every written record into a single batch.
        dataset = dataset.batch(len(vectors))
        dataset = dataset.map(inputter.make_features)

        features = next(iter(dataset))
        lengths = features["length"]
        tensors = features["tensor"]
        self.assertAllEqual(lengths, expected_lengths)
        # Only the leading `length` rows are compared. Rows beyond that are
        # presumably padding up to the longest sequence in the batch.
        for length, tensor, expected_vector in zip(lengths, tensors, vectors):
            self.assertAllClose(tensor[:length], expected_vector)