def test_get_shape_from_examples_path_invalid_path(self):
    """Building a dataset from a nonexistent source path must raise.

    The raised exception's message is expected to contain the offending
    path prefix '/this/path/does/not'.
    """
    # assertRaisesRegexp is a deprecated alias that was removed in
    # Python 3.12; assertRaisesRegex is the supported spelling.
    with self.assertRaisesRegex(Exception, '/this/path/does/not'):
        data_providers.DeepVariantDataSet(
            name='test_invalid_path',
            source='/this/path/does/not/exist',
            num_examples=1,
            num_classes=3)
def test_dataset_definition(self):
    """Every constructor argument is exposed unchanged as an attribute."""
    dataset = data_providers.DeepVariantDataSet(
        name='name',
        source='test.tfrecord',
        num_examples=10,
        num_classes=2,
        tensor_shape=[11, 13, 7])
    # (attribute name, expected value), checked in constructor-argument order.
    expectations = [
        ('name', 'name'),
        ('source', 'test.tfrecord'),
        ('num_examples', 10),
        ('num_classes', 2),
        ('tensor_shape', [11, 13, 7]),
    ]
    for attr_name, expected_value in expectations:
        self.assertEqual(expected_value, getattr(dataset, attr_name))
def make_golden_dataset(compressed_inputs=False):
    """Returns a DeepVariantDataSet backed by the golden training examples.

    Args:
      compressed_inputs: if True, the golden records are first copied into a
        temporary file named 'make_golden_dataset.tfrecord.gz', and that copy
        is used as the dataset source. Otherwise
        test_utils.GOLDEN_TRAINING_EXAMPLES is used directly.

    Returns:
      A DeepVariantDataSet named 'labeled_golden' with num_examples=49.
    """
    dataset_source = test_utils.GOLDEN_TRAINING_EXAMPLES
    if compressed_inputs:
        # NOTE(review): presumably io_utils picks gzip compression from the
        # '.gz' suffix of the output path -- confirm against io_utils.
        compressed_copy = test_utils.test_tmpfile(
            'make_golden_dataset.tfrecord.gz')
        golden_records = io_utils.read_tfrecords(dataset_source)
        io_utils.write_tfrecords(golden_records, compressed_copy)
        dataset_source = compressed_copy
    return data_providers.DeepVariantDataSet(
        name='labeled_golden', source=dataset_source, num_examples=49)
def test_get_shape_from_examples_path(self, file_name_to_write,
                                      tfrecord_path_to_match):
    """A dataset over the written file reports the example's stored shape.

    Args:
      file_name_to_write: temp-file name the single example is written to.
      tfrecord_path_to_match: temp-file name passed as the dataset source;
        NOTE(review): presumably may be a pattern matching
        file_name_to_write -- confirm against the test parameters.
    """
    expected_shape = [1, 2, 3]
    record = example_pb2.Example()
    shape_feature = record.features.feature['image/shape']
    shape_feature.int64_list.value.extend(expected_shape)
    written_path = test_utils.test_tmpfile(file_name_to_write)
    io_utils.write_tfrecords([record], written_path)
    dataset = data_providers.DeepVariantDataSet(
        name='test_shape',
        source=test_utils.test_tmpfile(tfrecord_path_to_match),
        num_examples=1)
    self.assertEqual(expected_shape, dataset.tensor_shape)