Example #1
0
 def test_get_shape_from_examples_path_invalid_path(self):
     """Building a dataset over a nonexistent source path must raise.

     tensor_shape is not passed, so the constructor presumably has to read the
     (missing) source to infer a shape — confirm against data_providers. The
     raised exception's message must contain '/this/path/does/not'
     (assertRaisesRegex does a regex search on str(exception)).
     """
     # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
     # which no longer exists on unittest.TestCase as of Python 3.12.
     with self.assertRaisesRegex(Exception, '/this/path/does/not'):
         data_providers.DeepVariantDataSet(
             name='test_invalid_path',
             source='/this/path/does/not/exist',
             num_examples=1,
             num_classes=3)
 def test_dataset_definition(self):
     """Checks that explicit constructor arguments come back as attributes."""
     dataset = data_providers.DeepVariantDataSet(
         name='name',
         source='test.tfrecord',
         num_examples=10,
         num_classes=2,
         tensor_shape=[11, 13, 7])
     # (attribute name, value that attribute must hold), checked in order.
     expectations = [
         ('name', 'name'),
         ('source', 'test.tfrecord'),
         ('num_examples', 10),
         ('num_classes', 2),
         ('tensor_shape', [11, 13, 7]),
     ]
     for attribute, expected_value in expectations:
         self.assertEqual(expected_value, getattr(dataset, attribute))
def make_golden_dataset(compressed_inputs=False):
    """Returns a DeepVariantDataSet over the golden training examples.

    Args:
      compressed_inputs: when truthy, the golden records are first re-written
        to the temporary file 'make_golden_dataset.tfrecord.gz' and the dataset
        reads that copy; otherwise the dataset reads
        test_utils.GOLDEN_TRAINING_EXAMPLES directly.

    Returns:
      A data_providers.DeepVariantDataSet named 'labeled_golden' with
      num_examples=49.
    """
    if not compressed_inputs:
        dataset_source = test_utils.GOLDEN_TRAINING_EXAMPLES
    else:
        # NOTE(review): the '.gz' suffix presumably makes write_tfrecords emit
        # a compressed file — confirm against io_utils.
        dataset_source = test_utils.test_tmpfile(
            'make_golden_dataset.tfrecord.gz')
        golden_records = io_utils.read_tfrecords(
            test_utils.GOLDEN_TRAINING_EXAMPLES)
        io_utils.write_tfrecords(golden_records, dataset_source)
    return data_providers.DeepVariantDataSet(
        name='labeled_golden', source=dataset_source, num_examples=49)
 def test_get_shape_from_examples_path(self, file_name_to_write,
                                       tfrecord_path_to_match):
     """Dataset shape is read from the 'image/shape' feature of its records.

     Writes a single example_pb2.Example whose 'image/shape' int64 feature is
     [1, 2, 3] to the temp file `file_name_to_write`. It then builds a dataset,
     without passing tensor_shape, whose source is the temp path
     `tfrecord_path_to_match`. Finally it checks that the dataset's
     tensor_shape equals [1, 2, 3].

     Args:
       file_name_to_write: temp-dir-relative name of the record file to create.
       tfrecord_path_to_match: temp-dir-relative source handed to the dataset;
         presumably a pattern or sharded spec expected to resolve to the
         written file — confirm against the test's parameterization.
     """
     expected_shape = [1, 2, 3]

     proto = example_pb2.Example()
     shape_feature = proto.features.feature['image/shape']
     shape_feature.int64_list.value.extend(expected_shape)
     io_utils.write_tfrecords(
         [proto], test_utils.test_tmpfile(file_name_to_write))

     dataset = data_providers.DeepVariantDataSet(
         name='test_shape',
         source=test_utils.test_tmpfile(tfrecord_path_to_match),
         num_examples=1)
     self.assertEqual(expected_shape, dataset.tensor_shape)