Пример #1
0
    def testTrainBaselineModel(self):
        """Fit ``BaselineModel`` on the dev-corpus CSV as a smoke test.

        The test is skipped when the annotation CSV is not on disk. Each
        parsed example is cut to at most 32 entries along its first axis,
        examples are padded into batches of two, and the model is fitted
        for 100 epochs with the same dataset used for both training and
        validation.
        """
        corpus_csv = self.data_path + "annotations/manual/dev.corpus.csv"
        if not os.path.isfile(corpus_csv):
            self.skipTest(reason="Debug data not found.")

        # Every one of the four CSV columns is read as a string.
        column_defaults = [tf.string] * 4
        records = tf.data.experimental.CsvDataset(
            filenames=corpus_csv,
            record_defaults=column_defaults,
            field_delim="|",
            header=True,
        )
        parse_fn = create_parse_fn(self.features_path, self.vocab_file)
        examples = records.map(parse_fn)

        def keep_leading_frames(frames, label):
            # Bound the first axis so a single example stays small.
            return frames[:32, :, :, :], label

        # Both components vary in length along their first axis, so each
        # batch is padded up to the longest example it contains.
        batch_shapes = ([None, 224, 224, 3], [None])
        batches = examples.map(keep_leading_frames).padded_batch(
            2, padded_shapes=batch_shapes)

        model = BaselineModel(vocab_size=980)
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
        tracked_metrics = [
            WER(),
            tf.keras.metrics.SparseCategoricalAccuracy(),
        ]
        model.compile(
            optimizer="adam",
            loss=loss_fn,
            metrics=tracked_metrics,
        )
        model.fit(batches, validation_data=batches, epochs=100)
Пример #2
0
    def testParseExample(self):
        """Check the parse function's output for the fixture example.

        The test is skipped when the features directory is missing.
        Otherwise the parser built by ``create_parse_fn`` is called on the
        fixture fields, and the returned frames and label are compared
        with the expected shape and the expected tokenized annotation.
        """
        features_available = os.path.isdir(self.features_path)
        if not features_available:
            self.skipTest(reason="Debug data not found.")

        parser = create_parse_fn(self.features_path, self.vocab_file)
        parsed = parser(self.id, self.folder, self.signer, self.annotation)
        frames, label = parsed

        expected_shape = (self.num_frames, 224, 224, 3)
        self.assertEqual(frames.shape, expected_shape)
        self.assertAllEqual(label, self.tokenized_annotation)
Пример #3
0
 def make_dataset() -> tf.data.Dataset:
     """Return the parsed dataset for the configured validation CSV.

     Reads ``validation_csv``, ``features_path`` and ``vocab_file`` from
     ``self.data_config`` (``self`` comes from the enclosing scope) and
     maps every CSV record through the parser built by
     ``create_parse_fn``.
     """
     # Four string columns, pipe-delimited, with a header row.
     column_defaults = [tf.string] * 4
     records = tf.data.experimental.CsvDataset(
         filenames=self.data_config["validation_csv"],
         record_defaults=column_defaults,
         field_delim="|",
         header=True,
     )
     dev_features_dir = self.data_config["features_path"] + "dev/"
     parse_fn = create_parse_fn(
         dev_features_dir,
         self.data_config["vocab_file"],
     )
     return records.map(parse_fn)
Пример #4
0
    def testMakeDataset(self):
        """Build the parsed dataset from the dev-corpus annotation CSV.

        The test is skipped when the CSV file does not exist.

        NOTE(review): no assertions are visible in this excerpt; as shown,
        the test only verifies that the pipeline can be constructed without
        raising. Confirm whether the method continues beyond this point.
        """
        csv_path = self.data_path + "annotations/manual/dev.corpus.csv"
        # All four CSV columns are read as strings.
        default_types = [tf.string, tf.string, tf.string, tf.string]

        # Skip rather than fail when the data file is absent.
        if not os.path.isfile(csv_path):
            self.skipTest(reason="Debug data not found.")

        # The file is pipe-delimited and is read with header=True.
        dataset = tf.data.experimental.CsvDataset(
            filenames=csv_path,
            record_defaults=default_types,
            field_delim="|",
            header=True,
        )
        # Map every CSV record through the parser built by create_parse_fn.
        dataset = dataset.map(
            create_parse_fn(self.features_path, self.vocab_file))