Пример #1
0
    def test_label_has_shape_defined(self):
        """The 'label' item from the train-split decoder has static dim 0 == 37.

        Only the graph-time (static) shape is inspected; no session is run,
        so a dummy serialized string is enough to build the decode op.
        """
        split_decoder = fsns.get_split('train', dataset_dir()).decoder

        # Unpack the single requested item; fails loudly if decode returns
        # anything other than exactly one tensor.
        label_tensor, = split_decoder.decode('fake', ['label'])

        leading_dim = label_tensor.get_shape().dims[0]
        self.assertEqual(leading_dim, 37)
Пример #2
0
    def test_decodes_example_proto(self):
        """Round-trip a hand-built serialized example through the train decoder.

        A random 150x600x3 image is PNG-encoded and packed together with label,
        text and width features; the decoded image, label, text and
        num_of_views are then compared against the inputs.
        """
        expected_label = range(37)
        expected_image, png_bytes = unittest_utils.create_random_image(
            'PNG', shape=(150, 600, 3))

        feature_map = {
            'image/encoded': [png_bytes],
            'image/format': [b'PNG'],
            'image/class': expected_label,
            'image/unpadded_class': range(10),
            'image/text': [b'Raw text'],
            'image/orig_width': [150],
            'image/width': [600],
        }
        serialized = unittest_utils.create_serialized_example(feature_map)

        split_decoder = fsns.get_split('train', dataset_dir()).decoder
        with self.test_session() as sess:
            # Name each decoded tensor after its item so results can be read
            # back as attributes (decoded.image, decoded.label, ...).
            decoded_type = collections.namedtuple('DecodedData',
                                                  split_decoder.list_items())
            decoded = sess.run(decoded_type(*split_decoder.decode(serialized)))

        self.assertAllEqual(expected_image, decoded.image)
        self.assertAllEqual(expected_label, decoded.label)
        self.assertEqual([b'Raw text'], decoded.text)
        self.assertEqual([1], decoded.num_of_views)
Пример #3
0
def get_split(split_name, dataset_dir=None, config=None):
    """Return the named split via fsns.get_split, filling in module defaults.

    Args:
      split_name: name of the split to load (passed through unchanged).
      dataset_dir: directory with the data; any falsy value (None, '') is
        replaced by DEFAULT_DATASET_DIR.
      config: dataset config; any falsy value (None, {}) is replaced by
        DEFAULT_CONFIG.

    Returns:
      Whatever fsns.get_split returns for these arguments.
    """
    return fsns.get_split(split_name,
                          dataset_dir or DEFAULT_DATASET_DIR,
                          config or DEFAULT_CONFIG)
Пример #4
0
    def test_dataset_tuple_has_all_extra_attributes(self):
        """The train split exposes each extra attribute with a truthy value.

        NOTE(review): assertTrue rejects falsy values too (e.g. a null_code of
        0), not just missing ones; consider assertIsNotNone if 0 is valid.
        """
        train_split = fsns.get_split('train', dataset_dir())

        extra_attributes = ('charset', 'num_char_classes', 'num_of_views',
                            'max_sequence_length', 'null_code')
        for attr_name in extra_attributes:
            self.assertTrue(getattr(train_split, attr_name), attr_name)
Пример #5
0
def get_test_split():
    """Return the 'test' split of dataset_dir() using a one-shard config.

    The default fsns config is shallow-copied and only its top-level 'splits'
    entry is replaced, so fsns.DEFAULT_CONFIG itself is left unmodified.
    """
    single_shard_splits = {
        'test': {'size': 5, 'pattern': 'fsns-00000-of-00001'},
    }
    test_config = fsns.DEFAULT_CONFIG.copy()
    test_config['splits'] = single_shard_splits
    return fsns.get_split('test', dataset_dir(), test_config)