def test_get_shape_from_examples_path_invalid_path(self):
    """Building an input from a nonexistent path raises an error naming it."""
    # assertRaisesRegexp is a deprecated alias (removed in Python 3.12);
    # assertRaisesRegex has identical semantics and is the supported name.
    with self.assertRaisesRegex(Exception, '/this/path/does/not'):
        data_providers.DeepVariantInput(
            mode=tf.estimator.ModeKeys.PREDICT,
            name='test_invalid_path',
            input_file_spec='/this/path/does/not/exist',
            num_examples=1)
def test_max_examples_overrides_num_examples(self, num_examples, max_examples,
                                             expected):
    """Checks the num_examples reported when max_examples is also supplied."""
    # PREDICT mode is used because it permits num_examples to be None.
    input_kwargs = dict(
        mode=tf.estimator.ModeKeys.PREDICT,
        input_file_spec=testdata.GOLDEN_TRAINING_EXAMPLES,
        num_examples=num_examples,
        max_examples=max_examples,
    )
    provider = data_providers.DeepVariantInput(**input_kwargs)
    self.assertEqual(expected, provider.num_examples)
def get_batch_feed(self, batch_size=1, use_tpu=False):
    """Returns the next-element tensors for the golden calling examples.

    Args:
      batch_size: int. Placed in the params dict handed to the input_fn as
        'batch_size'.
      use_tpu: bool. Forwarded to DeepVariantInput's use_tpu argument.

    Returns:
      The result of get_next() on a one-shot iterator over the dataset that
      the input_fn builds.
    """
    # The input_fn reads testdata.N_GOLDEN_CALLING_EXAMPLES records from
    # testdata.GOLDEN_CALLING_EXAMPLES. PREDICT mode is used so the input is
    # finite.
    dvi = data_providers.DeepVariantInput(
        mode=tf.estimator.ModeKeys.PREDICT,
        input_file_spec=testdata.GOLDEN_CALLING_EXAMPLES,
        num_examples=testdata.N_GOLDEN_CALLING_EXAMPLES,
        tensor_shape=None,
        use_tpu=use_tpu)
    params = {'batch_size': batch_size}
    batch_feed = dvi(params).make_one_shot_iterator().get_next()
    return batch_feed
def test_get_shape_from_examples_path(self, file_name_to_write,
                                      tfrecord_path_to_match):
    """Checks tensor_shape is taken from the examples' 'image/shape' feature.

    A single Example carrying the shape is written to file_name_to_write, and
    the input is then built from the tfrecord_path_to_match spec.
    """
    expected_shape = [1, 2, 3]
    proto = example_pb2.Example()
    shape_feature = proto.features.feature['image/shape']
    shape_feature.int64_list.value.extend(expected_shape)

    written_path = test_utils.test_tmpfile(file_name_to_write)
    tfrecord.write_tfrecords([proto], written_path)

    match_spec = test_utils.test_tmpfile(tfrecord_path_to_match)
    provider = data_providers.DeepVariantInput(
        mode=tf.estimator.ModeKeys.PREDICT,
        name='test_shape',
        input_file_spec=match_spec,
        num_examples=1)
    self.assertEqual(expected_shape, provider.tensor_shape)
def test_dataset_definition(self):
    """Checks constructor arguments are exposed unchanged as attributes."""
    provider = data_providers.DeepVariantInput(
        mode=tf.estimator.ModeKeys.PREDICT,
        name='name',
        input_file_spec='test.tfrecord',
        num_examples=10,
        num_classes=dv_constants.NUM_CLASSES,
        tensor_shape=[11, 13, dv_constants.PILEUP_NUM_CHANNELS])
    # (attribute name, expected value), checked in order.
    expected_attrs = (
        ('name', 'name'),
        ('input_file_spec', 'test.tfrecord'),
        ('num_examples', 10),
        ('num_classes', dv_constants.NUM_CLASSES),
        ('tensor_shape', [11, 13, dv_constants.PILEUP_NUM_CHANNELS]),
    )
    for attr, want in expected_attrs:
        self.assertEqual(want, getattr(provider, attr))