示例#1
0
 def test_get_shape_from_examples_path_invalid_path(self):
     """Checks that a nonexistent input_file_spec makes construction raise.

     The raised exception's message is expected to contain the missing path.
     """
     # assertRaisesRegexp is a deprecated alias that was removed from
     # unittest.TestCase in Python 3.12; assertRaisesRegex is the supported name.
     with self.assertRaisesRegex(Exception, '/this/path/does/not'):
         data_providers.DeepVariantInput(
             mode=tf.estimator.ModeKeys.PREDICT,
             name='test_invalid_path',
             input_file_spec='/this/path/does/not/exist',
             num_examples=1)
 def test_max_examples_overrides_num_examples(self, num_examples,
                                              max_examples, expected):
     """Checks the num_examples DeepVariantInput reports given max_examples.

     Args:
       num_examples: value passed as num_examples (may be None).
       max_examples: value passed as max_examples.
       expected: the num_examples the constructed input should report.
     """
     # PREDICT mode is chosen because it allows num_examples to be None.
     input_kwargs = dict(
         mode=tf.estimator.ModeKeys.PREDICT,
         input_file_spec=testdata.GOLDEN_TRAINING_EXAMPLES,
         num_examples=num_examples,
         max_examples=max_examples,
     )
     dvi = data_providers.DeepVariantInput(**input_kwargs)
     self.assertEqual(expected, dvi.num_examples)
示例#3
0
 def get_batch_feed(self, batch_size=1, use_tpu=False):
     """Returns next-batch tensors over testdata.GOLDEN_CALLING_EXAMPLES.

     Args:
       batch_size: forwarded to the input_fn as params['batch_size'].
       use_tpu: forwarded to DeepVariantInput.

     Returns:
       The value of get_next() on a one-shot iterator over the dataset
       produced by the input_fn.
     """
     # The input_fn is configured with
     # num_examples=testdata.N_GOLDEN_CALLING_EXAMPLES. PREDICT mode is
     # used so that the input is finite.
     input_fn = data_providers.DeepVariantInput(
         mode=tf.estimator.ModeKeys.PREDICT,
         input_file_spec=testdata.GOLDEN_CALLING_EXAMPLES,
         num_examples=testdata.N_GOLDEN_CALLING_EXAMPLES,
         tensor_shape=None,
         use_tpu=use_tpu)
     dataset = input_fn({'batch_size': batch_size})
     iterator = dataset.make_one_shot_iterator()
     return iterator.get_next()
示例#4
0
 def test_get_shape_from_examples_path(self, file_name_to_write,
                                       tfrecord_path_to_match):
     """Checks that tensor_shape matches the 'image/shape' feature on disk.

     One Example whose 'image/shape' feature is [1, 2, 3] is written to the
     temp file named by file_name_to_write. A DeepVariantInput is then built
     from the temp path for tfrecord_path_to_match, and its tensor_shape is
     compared against [1, 2, 3].
     """
     expected_shape = [1, 2, 3]
     proto = example_pb2.Example()
     shape_feature = proto.features.feature['image/shape']
     shape_feature.int64_list.value.extend(expected_shape)

     written_path = test_utils.test_tmpfile(file_name_to_write)
     tfrecord.write_tfrecords([proto], written_path)

     dvi = data_providers.DeepVariantInput(
         mode=tf.estimator.ModeKeys.PREDICT,
         name='test_shape',
         input_file_spec=test_utils.test_tmpfile(tfrecord_path_to_match),
         num_examples=1)
     self.assertEqual(expected_shape, dvi.tensor_shape)
示例#5
0
 def test_dataset_definition(self):
     """Checks that constructor arguments are exposed as same-named attributes."""
     expected_shape = [11, 13, dv_constants.PILEUP_NUM_CHANNELS]
     # Pass a copy so the object compared against below stays independent of
     # whatever the constructor does with the list it receives.
     ds = data_providers.DeepVariantInput(
         mode=tf.estimator.ModeKeys.PREDICT,
         name='name',
         input_file_spec='test.tfrecord',
         num_examples=10,
         num_classes=dv_constants.NUM_CLASSES,
         tensor_shape=list(expected_shape))

     # (attribute name, expected value), checked in order; getattr keeps each
     # attribute read lazy, immediately before its assertion.
     expectations = (
         ('name', 'name'),
         ('input_file_spec', 'test.tfrecord'),
         ('num_examples', 10),
         ('num_classes', dv_constants.NUM_CLASSES),
         ('tensor_shape', expected_shape),
     )
     for attr_name, expected_value in expectations:
         self.assertEqual(expected_value, getattr(ds, attr_name))