class CallVariantsAcceleratorTests(
    six.with_metaclass(parameterized.TestGeneratorMetaclass, tf.test.TestCase)):

  @parameterized.parameters(modeling.production_models())
  def test_call_variants_runs_on_gpus(self, model):
    call_variants.call_variants(
        examples_filename=testdata.GOLDEN_CALLING_EXAMPLES,
        checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
        model=model,
        execution_hardware='accelerator',
        output_file=test_utils.test_tmpfile('zzz.tfrecord'))
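# For reference: call_variants' execution_hardware flag accepts 'auto',
# 'cpu', or 'accelerator', and 'accelerator' makes the run fail fast when no
# GPU/TPU device is available, which is what the test above exercises.
#
# Later tests also refer to a module-level sentinel this excerpt omits.
# Judging from its usage (the same argument position takes
# modeling.SKIP_MODEL_INITIALIZATION_IN_TEST above), it is presumably:
_LEAVE_MODEL_UNINITIALIZED = modeling.SKIP_MODEL_INITIALIZATION_IN_TEST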
class CallVariantsEndToEndTests(
    six.with_metaclass(parameterized.TestGeneratorMetaclass, tf.test.TestCase)):

  def setUp(self):
    self.checkpoint_dir = tf.test.get_temp_dir()

  def assertCallVariantsEmitsNRecordsForInceptionV3(self, filename,
                                                    num_examples):
    outfile = test_utils.test_tmpfile('inception_v3.call_variants.tfrecord')
    model = modeling.get_model('inception_v3')
    checkpoint_path = _LEAVE_MODEL_UNINITIALIZED
    call_variants.call_variants(
        examples_filename=filename,
        checkpoint_path=checkpoint_path,
        model=model,
        output_file=outfile,
        batch_size=4,
        max_batches=None)
    call_variants_outputs = list(
        io_utils.read_tfrecords(outfile, deepvariant_pb2.CallVariantsOutput))
    # Check that we have the right number of output protos.
    self.assertEqual(len(call_variants_outputs), num_examples)

  def assertCallVariantsEmitsNRecordsForRandomGuess(self, filename,
                                                    num_examples):
    checkpoint_path = _LEAVE_MODEL_UNINITIALIZED
    outfile = test_utils.test_tmpfile('call_variants.tfrecord')
    model = modeling.get_model('random_guess')
    call_variants.call_variants(
        examples_filename=filename,
        checkpoint_path=checkpoint_path,
        model=model,
        output_file=outfile,
        batch_size=4,
        max_batches=None,
        master='',
        use_tpu=FLAGS.use_tpu)
    call_variants_outputs = list(
        io_utils.read_tfrecords(outfile, deepvariant_pb2.CallVariantsOutput))
    # Check that we have the right number of output protos.
    self.assertEqual(len(call_variants_outputs), num_examples)

  def test_call_end2end_with_empty_shards(self):
    # Get only up to 10 examples.
    examples = list(
        io_utils.read_tfrecords(
            testdata.GOLDEN_CALLING_EXAMPLES, max_records=10))
    # Write to 15 shards, which means there will be multiple empty shards.
    source_path = test_utils.test_tmpfile('sharded@{}'.format(15))
    io_utils.write_tfrecords(examples, source_path)
    self.assertCallVariantsEmitsNRecordsForRandomGuess(source_path,
                                                       len(examples))

  def test_call_end2end_empty_first_shard(self):
    # Get only up to 10 examples.
    examples = list(
        io_utils.read_tfrecords(
            testdata.GOLDEN_CALLING_EXAMPLES, max_records=10))
    empty_first_file = test_utils.test_tmpfile('empty_1st_shard-00000-of-00002')
    io_utils.write_tfrecords([], empty_first_file)
    second_file = test_utils.test_tmpfile('empty_1st_shard-00001-of-00002')
    io_utils.write_tfrecords(examples, second_file)
    self.assertCallVariantsEmitsNRecordsForRandomGuess(
        test_utils.test_tmpfile('empty_1st_shard@2'), len(examples))

  def test_call_end2end_zero_record_file_for_inception_v3(self):
    zero_record_file = test_utils.test_tmpfile('zero_record_file')
    io_utils.write_tfrecords([], zero_record_file)
    self.assertCallVariantsEmitsNRecordsForInceptionV3(
        test_utils.test_tmpfile('zero_record_file'), 0)

  def _call_end2end_helper(self, examples_path, model, shard_inputs):
    examples = list(io_utils.read_tfrecords(examples_path))

    if shard_inputs:
      # Create a sharded version of our golden examples.
      source_path = test_utils.test_tmpfile('sharded@{}'.format(3))
      io_utils.write_tfrecords(examples, source_path)
    else:
      source_path = examples_path

    # If we point the test at a headless server, it will often be 2x2,
    # which has 8 replicas. Otherwise a smaller batch size is fine.
    if FLAGS.use_tpu:
      batch_size = 8
    else:
      batch_size = 4

    if model.name == 'random_guess':
      # For the random guess model we can run everything.
      max_batches = None
    else:
      # For all other models we only run a single batch for inference.
      max_batches = 1

    outfile = test_utils.test_tmpfile('call_variants.tfrecord')
    call_variants.call_variants(
        examples_filename=source_path,
        checkpoint_path=_LEAVE_MODEL_UNINITIALIZED,
        model=model,
        output_file=outfile,
        batch_size=batch_size,
        max_batches=max_batches,
        master='',
        use_tpu=FLAGS.use_tpu,
    )
    call_variants_outputs = list(
        io_utils.read_tfrecords(outfile, deepvariant_pb2.CallVariantsOutput))

    return call_variants_outputs, examples, batch_size, max_batches

  @parameterized.parameters(model for model in modeling.production_models())
  @flagsaver.FlagSaver
  def test_call_end2end_with_labels(self, model):
    FLAGS.debugging_true_label_mode = True
    (call_variants_outputs, examples, batch_size,
     max_batches) = self._call_end2end_helper(
         testdata.GOLDEN_TRAINING_EXAMPLES, model, False)
    # Check that we have the right number of output protos.
    self.assertEqual(
        len(call_variants_outputs),
        batch_size * max_batches if max_batches else len(examples))
    # Checks that at least some of the `true_label`s are filled.
    self.assertTrue(
        any(cvo.debug_info.true_label > 0 for cvo in call_variants_outputs))

  @parameterized.parameters(model for model in modeling.production_models())
  @flagsaver.FlagSaver
  def test_call_end2end_no_labels_fails(self, model):
    FLAGS.debugging_true_label_mode = True
    if not FLAGS.use_tpu:
      # On TPUs, I got this error:
      #
      #   OP_REQUIRES failed at example_parsing_ops.cc:240 :
      #   Invalid argument: Feature: label (data type: int64) is required but
      #   could not be found.
      #
      # which cannot be caught by assertRaises.
      with self.assertRaises(tf.errors.OpError):
        self._call_end2end_helper(testdata.GOLDEN_CALLING_EXAMPLES, model,
                                  False)

  @parameterized.parameters(
      (model, shard_inputs, include_debug_info)
      for shard_inputs in [False, True]
      for model in modeling.production_models()
      for include_debug_info in [False, True])
  @flagsaver.FlagSaver
  def test_call_end2end(self, model, shard_inputs, include_debug_info):
    FLAGS.include_debug_info = include_debug_info
    (call_variants_outputs, examples, batch_size,
     max_batches) = self._call_end2end_helper(
         testdata.GOLDEN_CALLING_EXAMPLES, model, shard_inputs)

    # Check that we have the right number of output protos.
    self.assertEqual(
        len(call_variants_outputs),
        batch_size * max_batches if max_batches else len(examples))

    # Check that our CallVariantsOutput (CVO) have the following critical
    # properties:
    # - we have one CVO for each example we processed.
    # - the variant in the CVO is exactly what was in the example.
    # - the alt_allele_indices of the CVO match those of its corresponding
    #   example.
    # - there are 3 genotype probabilities and these are between 0.0 and 1.0.
    # We can only do this test when processing all of the variants
    # (max_batches is None), since we processed all of the examples with that
    # model.
    if max_batches is None:
      self.assertItemsEqual([cvo.variant for cvo in call_variants_outputs],
                            [tf_utils.example_variant(ex) for ex in examples])

    # Check the CVO debug_info: not filled if include_debug_info is False;
    # else, filled by logic based on CVO.
    if not include_debug_info:
      for cvo in call_variants_outputs:
        self.assertEqual(cvo.debug_info,
                         deepvariant_pb2.CallVariantsOutput.DebugInfo())
    else:
      for cvo in call_variants_outputs:
        self.assertEqual(cvo.debug_info.has_insertion,
                         variant_utils.has_insertion(cvo.variant))
        self.assertEqual(cvo.debug_info.has_deletion,
                         variant_utils.has_deletion(cvo.variant))
        self.assertEqual(cvo.debug_info.is_snp,
                         variant_utils.is_snp(cvo.variant))
        self.assertEqual(cvo.debug_info.predicted_label,
                         np.argmax(cvo.genotype_probabilities))

    def example_matches_call_variants_output(example, call_variants_output):
      return (tf_utils.example_variant(example) == call_variants_output.variant
              and tf_utils.example_alt_alleles_indices(example) ==
              call_variants_output.alt_allele_indices.indices)

    for call_variants_output in call_variants_outputs:
      # Find all matching examples.
      matches = [
          ex for ex in examples
          if example_matches_call_variants_output(ex, call_variants_output)
      ]
      # We should have exactly one match.
      self.assertEqual(len(matches), 1)
      example = matches[0]
      # Check that we've faithfully copied in the alt alleles (though
      # currently as implemented we find our example using this information
      # so it cannot fail). Included here in case that changes in the future.
      self.assertEqual(
          list(tf_utils.example_alt_alleles_indices(example)),
          list(call_variants_output.alt_allele_indices.indices))
      # We should have exactly three genotype probabilities (assuming our
      # ploidy == 2).
      self.assertEqual(len(call_variants_output.genotype_probabilities), 3)
      # These are probabilities so they should be between 0 and 1.
      self.assertTrue(
          all(0 <= gp <= 1
              for gp in call_variants_output.genotype_probabilities))

  @parameterized.parameters(
      (model, bad_format)
      for model in modeling.production_models()
      for bad_format in ['', 'png'])
  def test_call_variants_with_invalid_format(self, model, bad_format):
    # Read one good record from a valid file.
    example = next(io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES))
    # Overwrite the image/format field to be an invalid value
    # (anything but 'raw').
    example.features.feature['image/format'].bytes_list.value[0] = bad_format
    source_path = test_utils.test_tmpfile('make_examples_output.tfrecord')
    io_utils.write_tfrecords([example], source_path)
    outfile = test_utils.test_tmpfile('call_variants_invalid_format.tfrecord')

    with self.assertRaises(ValueError):
      call_variants.call_variants(
          examples_filename=source_path,
          checkpoint_path=_LEAVE_MODEL_UNINITIALIZED,
          model=model,
          output_file=outfile,
          batch_size=1,
          max_batches=1,
          use_tpu=FLAGS.use_tpu)

  @parameterized.parameters(model for model in modeling.production_models())
  def test_call_variants_with_no_shape(self, model):
    # Read one good record from a valid file.
    example = next(io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES))
    # Remove image/shape.
    del example.features.feature['image/shape']
    source_path = test_utils.test_tmpfile('make_examples_out_noshape.tfrecord')
    io_utils.write_tfrecords([example], source_path)
    with self.assertRaisesRegexp(
        ValueError, 'Invalid image/shape: we expect to find an image/shape '
        'field with length 3.'):
      ds = call_variants.prepare_inputs(source_path)
      _ = list(data_providers.get_infer_batches(ds, model=model, batch_size=1))

  def test_call_variants_with_empty_input(self):
    source_path = test_utils.test_tmpfile('empty.tfrecord')
    io_utils.write_tfrecords([], source_path)
    # Make sure that prepare_inputs doesn't crash on empty input.
    ds = call_variants.prepare_inputs(source_path)
    m = modeling.get_model('random_guess')
    # The API specifies that OutOfRangeError is thrown in this case.
    batches = list(data_providers.get_infer_batches(ds, model=m, batch_size=1))
    with self.test_session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(tf.global_variables_initializer())
      try:
        _ = sess.run(batches)
      except tf.errors.OutOfRangeError:
        pass
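# The 'name@N' strings passed to test_tmpfile above are sharded file specs.
# A minimal sketch of the expansion these tests rely on; the real resolution
# lives in io_utils, so this illustrative helper is an assumption, not the
# library implementation:
def _expand_sharded_spec(spec):
  """Expands 'base@N' into ['base-00000-of-0000N', ...]."""
  base, _, num_shards = spec.rpartition('@')
  n = int(num_shards)
  return ['{}-{:05d}-of-{:05d}'.format(base, i, n) for i in range(n)]


# For example, _expand_sharded_spec('empty_1st_shard@2') yields exactly the
# two filenames written out in test_call_end2end_empty_first_shard.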
class ModelEvalTest(
    six.with_metaclass(parameterized.TestGeneratorMetaclass, tf.test.TestCase)):

  def testSelectVariantsWeights(self):
    variants = [
        test_utils.make_variant(start=10, alleles=['C', 'T']),
        test_utils.make_variant(start=11, alleles=['C', 'TA']),
        test_utils.make_variant(start=12, alleles=['C', 'A']),
        test_utils.make_variant(start=13, alleles=['CA', 'T']),
    ]
    encoded = tf.constant([v.SerializeToString() for v in variants])
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      op = model_eval.select_variants_weights(
          variantutils.is_snp, encoded, name='tf_is_snp')
      self.assertTrue(op.name.startswith('tf_is_snp'))
      npt.assert_array_equal(op.eval(), [1.0, 0.0, 1.0, 0.0])

  def testCallingMetrics(self):

    def make_mock_metric(name):

      # pylint: disable=unused-argument
      def _side_effect(predictions, labels, weights):
        if weights:
          return name + ':' + ','.join(str(int(w)) for w in weights)
        else:
          return name + ':None'

      return mock.MagicMock(side_effect=_side_effect)

    predictions = tf.constant([0, 1, 2, 0])
    labels = tf.constant([0, 2, 1, 1])
    metrics = {
        'm1': make_mock_metric('mock_metric1'),
        'm2': make_mock_metric('mock_metric2'),
    }
    selectors = {'s1': [1, 1, 1, 1], 's2': [0, 0, 0, 0], 's3': None}

    # The returned dictionary has the expected keys and values.
    self.assertEqual(
        {
            'm1/s1': 'mock_metric1:1,1,1,1',
            'm1/s2': 'mock_metric1:0,0,0,0',
            'm1/s3': 'mock_metric1:None',
            'm2/s1': 'mock_metric2:1,1,1,1',
            'm2/s2': 'mock_metric2:0,0,0,0',
            'm2/s3': 'mock_metric2:None',
        },
        model_eval.calling_metrics(
            metrics_map=metrics,
            selectors_map=selectors,
            predictions=predictions,
            labels=labels))

    # Check that our mocked metrics saw all of the calls we expect.
    for mocked in metrics.values():
      self.assertEqual([
          mock.call(predictions, labels, weights=selectors[x])
          for x in selectors
      ], mocked.call_args_list)

  @parameterized.parameters(model.name
                            for model in modeling.production_models()
                            if model.is_trainable)
  @flagsaver.FlagSaver
  @mock.patch('deepvariant.data_providers.get_dataset')
  def test_end2end(self, model_name, mock_get_dataset):
    """End-to-end test of model_eval."""
    checkpoint_dir = tf.test.get_temp_dir()

    # Create a model with 3 classes, and save it to our checkpoint dir.
    with self.test_session() as sess:
      model = modeling.get_model(model_name)
      # Needed to protect ourselves for models without an input image shape.
      h, w = getattr(model, 'input_image_shape', (100, 221))
      images = tf.placeholder(
          tf.float32, shape=(4, h, w, pileup_image.DEFAULT_NUM_CHANNEL))
      model.create(images, num_classes=3, is_training=True)
      # This is gross, but necessary as model_eval assumes the model was
      # trained with model_train which uses exp moving averages. Unfortunately
      # we cannot just call into model_train as it uses FLAGS which conflict
      # with the flags in use by model_eval. So we inline the creation of the
      # EMA here.
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, slim.get_or_create_global_step())
      tf.add_to_collection(tf.GraphKeys.UPDATE_OPS,
                           variable_averages.apply(slim.get_model_variables()))
      sess.run(tf.global_variables_initializer())
      save = tf.train.Saver(slim.get_variables())
      save.save(sess, os.path.join(checkpoint_dir, 'model'))

    # Start up eval, loading that checkpoint.
    FLAGS.batch_size = 2
    FLAGS.checkpoint_dir = checkpoint_dir
    FLAGS.eval_dir = tf.test.get_temp_dir()
    FLAGS.batches_per_eval_step = 1
    FLAGS.max_evaluations = 1
    FLAGS.eval_interval_secs = 0
    FLAGS.model_name = model_name
    FLAGS.dataset_config_pbtxt = '/path/to/mock.pbtxt'

    # Always try to read in compressed inputs to stress that case.
    # Uncompressed inputs are certain to work. This test is expensive to run,
    # so we want to minimize the number of times we need to run this.
    mock_get_dataset.return_value = data_providers_test.make_golden_dataset(
        compressed_inputs=True)
    model_eval.main(0)
    mock_get_dataset.assert_called_once_with(FLAGS.dataset_config_pbtxt)
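# A pure-Python sketch of the contract testSelectVariantsWeights pins down:
# select_variants_weights should emit weight 1.0 where the predicate holds
# for the decoded variant and 0.0 elsewhere. The proto import path below is
# an assumption for illustration; the real op consumes a tensor of serialized
# Variant protos inside the TF graph.
def _select_variants_weights_reference(predicate, serialized_variants):
  from deepvariant.core.genomics import variants_pb2  # assumed import path
  return [
      1.0 if predicate(variants_pb2.Variant.FromString(s)) else 0.0
      for s in serialized_variants
  ]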
class CallVariantsEndToEndTests(
    six.with_metaclass(parameterized.TestGeneratorMetaclass, tf.test.TestCase)):

  def assertCallVariantsEmitsNRecordsForRandomGuess(self, filename,
                                                    num_examples):
    outfile = test_utils.test_tmpfile('call_variants.tfrecord')
    model = modeling.get_model('random_guess')
    call_variants.call_variants(
        examples_filename=filename,
        checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
        model=model,
        output_file=outfile,
        batch_size=4,
        max_batches=None)
    call_variants_outputs = list(
        io_utils.read_tfrecords(outfile, deepvariant_pb2.CallVariantsOutput))
    # Check that we have the right number of output protos.
    self.assertEqual(len(call_variants_outputs), num_examples)

  def test_call_end2end_with_empty_shards(self):
    # Get only up to 10 examples.
    examples = list(
        io_utils.read_tfrecords(
            testdata.GOLDEN_CALLING_EXAMPLES, max_records=10))
    # Write to 15 shards, which means there will be multiple empty shards.
    source_path = test_utils.test_tmpfile('sharded@{}'.format(15))
    io_utils.write_tfrecords(examples, source_path)
    self.assertCallVariantsEmitsNRecordsForRandomGuess(source_path,
                                                       len(examples))

  def test_call_end2end_empty_first_shard(self):
    # Get only up to 10 examples.
    examples = list(
        io_utils.read_tfrecords(
            testdata.GOLDEN_CALLING_EXAMPLES, max_records=10))
    empty_first_file = test_utils.test_tmpfile('empty_1st_shard-00000-of-00002')
    io_utils.write_tfrecords([], empty_first_file)
    second_file = test_utils.test_tmpfile('empty_1st_shard-00001-of-00002')
    io_utils.write_tfrecords(examples, second_file)
    self.assertCallVariantsEmitsNRecordsForRandomGuess(
        test_utils.test_tmpfile('empty_1st_shard@2'), len(examples))

  def test_call_end2end_zero_record_file(self):
    zero_record_file = test_utils.test_tmpfile('zero_record_file')
    io_utils.write_tfrecords([], zero_record_file)
    self.assertCallVariantsEmitsNRecordsForRandomGuess(
        test_utils.test_tmpfile('zero_record_file'), 0)

  @parameterized.parameters(
      (model, shard_inputs, include_debug_info)
      for shard_inputs in [False, True]
      for model in modeling.production_models()
      for include_debug_info in [False, True])
  @flagsaver.FlagSaver
  def test_call_end2end(self, model, shard_inputs, include_debug_info):
    FLAGS.include_debug_info = include_debug_info
    examples = list(io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES))

    if shard_inputs:
      # Create a sharded version of our golden examples.
      source_path = test_utils.test_tmpfile('sharded@{}'.format(3))
      io_utils.write_tfrecords(examples, source_path)
    else:
      source_path = testdata.GOLDEN_CALLING_EXAMPLES

    batch_size = 4
    if model.name == 'random_guess':
      # For the random guess model we can run everything.
      max_batches = None
    else:
      # For all other models we only run a single batch for inference.
      max_batches = 1

    outfile = test_utils.test_tmpfile('call_variants.tfrecord')
    call_variants.call_variants(
        examples_filename=source_path,
        checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
        model=model,
        output_file=outfile,
        batch_size=batch_size,
        max_batches=max_batches)
    call_variants_outputs = list(
        io_utils.read_tfrecords(outfile, deepvariant_pb2.CallVariantsOutput))

    # Check that we have the right number of output protos.
    self.assertEqual(
        len(call_variants_outputs),
        batch_size * max_batches if max_batches else len(examples))

    # Check that our CallVariantsOutput (CVO) have the following critical
    # properties:
    # - we have one CVO for each example we processed.
    # - the variant in the CVO is exactly what was in the example.
    # - the alt_allele_indices of the CVO match those of its corresponding
    #   example.
    # - there are 3 genotype probabilities and these are between 0.0 and 1.0.
    # We can only do this test when processing all of the variants
    # (max_batches is None), since we processed all of the examples with that
    # model.
    if max_batches is None:
      self.assertItemsEqual([cvo.variant for cvo in call_variants_outputs],
                            [tf_utils.example_variant(ex) for ex in examples])

    # Check the CVO debug_info: not filled if include_debug_info is False;
    # else, filled by logic based on CVO.
    if not include_debug_info:
      for cvo in call_variants_outputs:
        self.assertEqual(cvo.debug_info,
                         deepvariant_pb2.CallVariantsOutput.DebugInfo())
    else:
      for cvo in call_variants_outputs:
        self.assertEqual(cvo.debug_info.has_insertion,
                         variant_utils.has_insertion(cvo.variant))
        self.assertEqual(cvo.debug_info.has_deletion,
                         variant_utils.has_deletion(cvo.variant))
        self.assertEqual(cvo.debug_info.is_snp,
                         variant_utils.is_snp(cvo.variant))
        self.assertEqual(cvo.debug_info.predicted_label,
                         np.argmax(cvo.genotype_probabilities))

    def example_matches_call_variants_output(example, call_variants_output):
      return (tf_utils.example_variant(example) == call_variants_output.variant
              and tf_utils.example_alt_alleles_indices(example) ==
              call_variants_output.alt_allele_indices.indices)

    for call_variants_output in call_variants_outputs:
      # Find all matching examples.
      matches = [
          ex for ex in examples
          if example_matches_call_variants_output(ex, call_variants_output)
      ]
      # We should have exactly one match.
      self.assertEqual(len(matches), 1)
      example = matches[0]
      # Check that we've faithfully copied in the alt alleles (though
      # currently as implemented we find our example using this information
      # so it cannot fail). Included here in case that changes in the future.
      self.assertEqual(
          list(tf_utils.example_alt_alleles_indices(example)),
          list(call_variants_output.alt_allele_indices.indices))
      # We should have exactly three genotype probabilities (assuming our
      # ploidy == 2).
      self.assertEqual(len(call_variants_output.genotype_probabilities), 3)
      # These are probabilities so they should be between 0 and 1.
      self.assertTrue(
          all(0 <= gp <= 1
              for gp in call_variants_output.genotype_probabilities))

  @parameterized.parameters(
      (model, bad_format)
      for model in modeling.production_models()
      for bad_format in ['', 'png'])
  def test_call_variants_with_invalid_format(self, model, bad_format):
    # Read one good record from a valid file.
    example = next(io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES))
    # Overwrite the image/format field to be an invalid value
    # (anything but 'raw').
    example.features.feature['image/format'].bytes_list.value[0] = bad_format
    source_path = test_utils.test_tmpfile('make_examples_output.tfrecord')
    io_utils.write_tfrecords([example], source_path)
    outfile = test_utils.test_tmpfile('call_variants_invalid_format.tfrecord')

    with self.assertRaises(ValueError):
      call_variants.call_variants(
          examples_filename=source_path,
          checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
          model=model,
          output_file=outfile,
          batch_size=1,
          max_batches=1)

  @parameterized.parameters(model for model in modeling.production_models())
  def test_call_variants_with_no_shape(self, model):
    # Read one good record from a valid file.
    example = next(io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES))
    # Remove image/shape.
    del example.features.feature['image/shape']
    source_path = test_utils.test_tmpfile('make_examples_out_noshape.tfrecord')
    io_utils.write_tfrecords([example], source_path)
    with self.assertRaisesRegexp(
        ValueError, 'Invalid image/shape: we expect to find an image/shape '
        'field with length 3.'):
      call_variants.prepare_inputs(source_path, model, batch_size=1)

  def test_call_variants_with_empty_input(self):
    source_path = test_utils.test_tmpfile('empty.tfrecord')
    io_utils.write_tfrecords([], source_path)
    # Make sure that prepare_inputs doesn't crash on empty input.
    call_variants.prepare_inputs(
        source_path, modeling.get_model('random_guess'), batch_size=1)
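# A minimal sketch of the image/shape validation the no-shape tests above
# exercise, reconstructed from the asserted error message (assumed logic,
# not the actual prepare_inputs implementation):
def _check_image_shape(example):
  shape = example.features.feature['image/shape'].int64_list.value
  if len(shape) != 3:
    raise ValueError('Invalid image/shape: we expect to find an image/shape '
                     'field with length 3.')
  return list(shape)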
class ModelTrainTest(parameterized.TestCase, tf.test.TestCase):

  @flagsaver.FlagSaver
  def test_training_works_with_compressed_inputs(self):
    """End-to-end test of model_train script."""
    self._run_tiny_training(
        model_name='mobilenet_v1',
        dataset=data_providers_test.make_golden_dataset(
            compressed_inputs=True, use_tpu=FLAGS.use_tpu))

  def _run_tiny_training(self, model_name, dataset, warm_start_from=''):
    """Runs one training step. This function always starts a new train_dir."""
    with mock.patch(
        'deepvariant.data_providers.'
        'get_input_fn_from_dataset') as mock_get_input_fn_from_dataset:
      mock_get_input_fn_from_dataset.return_value = dataset
      FLAGS.train_dir = tf_test_utils.test_tmpdir(uuid.uuid4().hex)
      FLAGS.batch_size = 2
      FLAGS.model_name = model_name
      FLAGS.save_interval_secs = -1
      FLAGS.save_interval_steps = 1
      FLAGS.number_of_steps = 1
      FLAGS.dataset_config_pbtxt = '/path/to/mock.pbtxt'
      FLAGS.start_from_checkpoint = warm_start_from
      FLAGS.master = ''
      model_train.parse_and_run()
      # We have a checkpoint after training.
      mock_get_input_fn_from_dataset.assert_called_once_with(
          dataset_config_filename=FLAGS.dataset_config_pbtxt,
          mode=tf.estimator.ModeKeys.TRAIN,
          use_tpu=mock.ANY,
      )
      self.assertIsNotNone(tf.train.latest_checkpoint(FLAGS.train_dir))

  @mock.patch('deepvariant'
              '.modeling.slim.losses.softmax_cross_entropy')
  @mock.patch('deepvariant'
              '.modeling.slim.losses.get_total_loss')
  def test_loss(self, mock_total_loss, mock_cross):
    labels = [[0, 1, 0], [1, 0, 0]]
    logits = 'Logits'
    smoothing = 0.01
    actual = model_train.loss(logits, labels, smoothing)
    mock_total_loss.assert_called_once_with()
    self.assertEqual(actual, mock_total_loss.return_value)
    mock_cross.assert_called_once_with(
        logits, labels, label_smoothing=smoothing, weights=1.0)

  @parameterized.parameters(model.name
                            for model in modeling.production_models()
                            if model.is_trainable)
  @flagsaver.FlagSaver
  def test_end2end(self, model_name):
    """End-to-end test of model_train script."""
    self._run_tiny_training(
        model_name=model_name,
        dataset=data_providers_test.make_golden_dataset(
            use_tpu=FLAGS.use_tpu))

  @flagsaver.FlagSaver
  def test_end2end_inception_v3_warm_up_from(self):
    """End-to-end test of model_train script."""
    checkpoint_dir = tf_test_utils.test_tmpdir('inception_v3_warm_up_from')
    tf_test_utils.write_fake_checkpoint('inception_v3', self.test_session(),
                                        checkpoint_dir)
    self._run_tiny_training(
        model_name='inception_v3',
        dataset=data_providers_test.make_golden_dataset(
            use_tpu=FLAGS.use_tpu),
        warm_start_from=checkpoint_dir + '/model')

  @flagsaver.FlagSaver
  def test_end2end_inception_v3_warm_up_from_mobilenet_v1(self):
    """Tests warm starting from mobilenet while training inception."""
    checkpoint_dir = tf_test_utils.test_tmpdir(
        'inception_v3_warm_up_from_mobilenet_v1')
    tf_test_utils.write_fake_checkpoint('mobilenet_v1', self.test_session(),
                                        checkpoint_dir)
    self.assertTrue(
        tf_test_utils.check_equals_checkpoint_top_scopes(
            checkpoint_dir + '/model', ['MobilenetV1', 'global_step']))
    self._run_tiny_training(
        model_name='inception_v3',
        dataset=data_providers_test.make_golden_dataset(
            use_tpu=FLAGS.use_tpu),
        warm_start_from=checkpoint_dir + '/model')
    self.assertTrue(
        tf_test_utils.check_equals_checkpoint_top_scopes(
            FLAGS.train_dir + '/model.ckpt-1', ['InceptionV3', 'global_step']))

  @flagsaver.FlagSaver
  def test_end2end_inception_v3_failed_warm_up_from(self):
    """End-to-end test of model_train script with a non-existent path."""
    with self.assertRaises(tf.errors.OpError):
      self._run_tiny_training(
          model_name='inception_v3',
          dataset=data_providers_test.make_golden_dataset(
              use_tpu=FLAGS.use_tpu),
          warm_start_from='this/path/does/not/exist')

  @parameterized.parameters((False), (True))
  @flagsaver.FlagSaver
  @mock.patch('deepvariant.model_train.'
              'tf.train.replica_device_setter')
  @mock.patch('deepvariant.model_train.run')
  def test_main_internal(self, use_tpu, mock_run, mock_device_setter):
    FLAGS.master = 'some_master'
    FLAGS.use_tpu = use_tpu
    FLAGS.ps_tasks = 10
    FLAGS.task = 5
    model_train.parse_and_run()
    mock_device_setter.assert_called_once_with(10)
    mock_run.assert_called_once_with(
        'some_master' if use_tpu else '',
        False,
        device_fn=mock.ANY,
        use_tpu=mock.ANY)

  @mock.patch('deepvariant.model_train.os.environ')
  @mock.patch('deepvariant.model_train.'
              'tf.train.replica_device_setter')
  @mock.patch('deepvariant.model_train.run')
  def test_main_tfconfig_local(self, mock_run, mock_device_setter,
                               mock_environ):
    mock_environ.get.return_value = '{}'
    model_train.parse_and_run()
    mock_device_setter.assert_called_once_with(0)
    mock_run.assert_called_once_with(
        '', True, device_fn=mock.ANY, use_tpu=mock.ANY)

  @parameterized.named_parameters(
      ('master', 'master', 0, True, '/job:master/task:0'),
      ('worker', 'worker', 10, False, '/job:worker/task:10'),
  )
  @mock.patch('deepvariant.model_train.tf.train.Server')
  @mock.patch('deepvariant.model_train.os.environ')
  @mock.patch('deepvariant.model_train.'
              'tf.train.replica_device_setter')
  @mock.patch('deepvariant.model_train.run')
  def test_main_tfconfig_dist(self, job_name, task_index, expected_is_chief,
                              expected_worker, mock_run, mock_device_setter,
                              mock_environ, mock_server):
    tf_config = {
        'cluster': {
            'ps': ['ps1:800', 'ps2:800']
        },
        'task': {
            'type': job_name,
            'index': task_index,
        },
    }

    class FakeServer(object):
      target = 'some-target'

    mock_environ.get.return_value = json.dumps(tf_config)
    mock_server.return_value = FakeServer()
    model_train.parse_and_run()
    mock_device_setter.assert_called_once_with(
        2, worker_device=expected_worker, cluster=mock.ANY)
    mock_run.assert_called_once_with(
        'some-target', expected_is_chief, device_fn=mock.ANY,
        use_tpu=mock.ANY)

  @parameterized.parameters(
      ('master', 'some-master'),
      ('task', 10),
      ('ps_tasks', 5),
  )
  @flagsaver.FlagSaver
  @mock.patch('deepvariant.model_train.os.environ')
  def test_main_invalid_args(self, flag_name, flag_value, mock_environ):
    # Ensure an exception is raised if flags and TF_CONFIG are set.
    tf_config = {
        'cluster': {
            'ps': ['ps1:800', 'ps2:800']
        },
        'task': {
            'type': 'master',
            'index': 0,
        },
    }
    mock_environ.get.return_value = json.dumps(tf_config)
    setattr(FLAGS, flag_name, flag_value)
    self.assertRaises(ValueError, model_train.parse_and_run)
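# The distributed-setup tests above drive parse_and_run through a mocked
# TF_CONFIG environment variable. A minimal sketch of how such a payload is
# built (illustrative helper with made-up host names, mirroring the dicts
# mocked in the tests; not part of model_train):
def _make_tf_config(job_name, task_index):
  """Returns a TF_CONFIG JSON string like the ones these tests mock."""
  return json.dumps({
      'cluster': {'ps': ['ps1:800', 'ps2:800']},
      'task': {'type': job_name, 'index': task_index},
  })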
class ModelTrainTest(parameterized.TestCase):

  @flagsaver.FlagSaver
  def test_training_works_with_compressed_inputs(self):
    """End-to-end test of model_train script."""
    self._run_tiny_training(
        model_name='mobilenet_v1',
        dataset=data_providers_test.make_golden_dataset(
            compressed_inputs=True))

  def _run_tiny_training(self, model_name, dataset):
    with mock.patch(
        'deepvariant.data_providers.get_dataset') as mock_get_dataset:
      mock_get_dataset.return_value = dataset
      FLAGS.train_dir = test_utils.test_tmpfile(model_name)
      FLAGS.batch_size = 2
      FLAGS.model_name = model_name
      FLAGS.save_interval_secs = 0
      FLAGS.number_of_steps = 1
      FLAGS.dataset_config_pbtxt = '/path/to/mock.pbtxt'
      FLAGS.start_from_checkpoint = ''
      model_train.parse_and_run()
      # We have a checkpoint after training.
      mock_get_dataset.assert_called_once_with(FLAGS.dataset_config_pbtxt)
      self.assertIsNotNone(tf.train.latest_checkpoint(FLAGS.train_dir))

  @parameterized.parameters(model.name
                            for model in modeling.production_models()
                            if model.is_trainable)
  @flagsaver.FlagSaver
  def test_end2end(self, model_name):
    """End-to-end test of model_train script."""
    self._run_tiny_training(
        model_name=model_name,
        dataset=data_providers_test.make_golden_dataset())

  @parameterized.parameters(
      (None, None),
      ('', None),
      ('/path/to/file', MOCK_SENTINEL_RETURN_VALUE),
      ('USE_FLAG_VALUE', MOCK_SENTINEL_RETURN_VALUE),
  )
  def test_model_init_function(self, path, expected):
    model = mock.Mock(spec=modeling.DeepVariantModel)
    model.initialize_from_checkpoint.return_value = MOCK_SENTINEL_RETURN_VALUE
    self.assertEqual(expected, model_train.model_init_function(model, 3, path))
    if expected:
      model.initialize_from_checkpoint.assert_called_once_with(
          path, 3, is_training=True)
    else:
      test_utils.assert_not_called_workaround(model.initialize_from_checkpoint)

  @flagsaver.FlagSaver
  @mock.patch('deepvariant.model_train.'
              'tf.train.replica_device_setter')
  @mock.patch('deepvariant.model_train.run')
  def test_main_internal(self, mock_run, mock_device_setter):
    FLAGS.master = 'some_master'
    FLAGS.ps_tasks = 10
    FLAGS.task = 5
    model_train.parse_and_run()
    mock_device_setter.assert_called_once_with(10)
    mock_run.assert_called_once_with('some_master', False, device_fn=mock.ANY)

  @mock.patch('deepvariant.model_train.os.environ')
  @mock.patch('deepvariant.model_train.'
              'tf.train.replica_device_setter')
  @mock.patch('deepvariant.model_train.run')
  def test_main_tfconfig_local(self, mock_run, mock_device_setter,
                               mock_environ):
    mock_environ.get.return_value = '{}'
    model_train.parse_and_run()
    mock_device_setter.assert_called_once_with(0)
    mock_run.assert_called_once_with('', True, device_fn=mock.ANY)

  @parameterized.named_parameters(
      ('master', 'master', 0, True, '/job:master/task:0'),
      ('worker', 'worker', 10, False, '/job:worker/task:10'),
  )
  @mock.patch('deepvariant.model_train.tf.train.Server')
  @mock.patch('deepvariant.model_train.os.environ')
  @mock.patch('deepvariant.model_train.'
              'tf.train.replica_device_setter')
  @mock.patch('deepvariant.model_train.run')
  def test_main_tfconfig_dist(self, job_name, task_index, expected_is_chief,
                              expected_worker, mock_run, mock_device_setter,
                              mock_environ, mock_server):
    tf_config = {
        'cluster': {
            'ps': ['ps1:800', 'ps2:800']
        },
        'task': {
            'type': job_name,
            'index': task_index,
        },
    }

    class FakeServer(object):
      target = 'some-target'

    mock_environ.get.return_value = json.dumps(tf_config)
    mock_server.return_value = FakeServer()
    model_train.parse_and_run()
    mock_device_setter.assert_called_once_with(
        2, worker_device=expected_worker, cluster=mock.ANY)
    mock_run.assert_called_once_with(
        'some-target', expected_is_chief, device_fn=mock.ANY)

  @parameterized.parameters(
      ('master', 'some-master'),
      ('task', 10),
      ('ps_tasks', 5),
  )
  @flagsaver.FlagSaver
  @mock.patch('deepvariant.model_train.os.environ')
  def test_main_invalid_args(self, flag_name, flag_value, mock_environ):
    # Ensure an exception is raised if flags and TF_CONFIG are set.
    tf_config = {
        'cluster': {
            'ps': ['ps1:800', 'ps2:800']
        },
        'task': {
            'type': 'master',
            'index': 0,
        },
    }
    mock_environ.get.return_value = json.dumps(tf_config)
    setattr(FLAGS, flag_name, flag_value)
    self.assertRaises(ValueError, model_train.parse_and_run)
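# For reference, test_model_init_function above pins down the contract of
# model_train.model_init_function: a falsy checkpoint path (None or '')
# yields no init function, while a real path or the 'USE_FLAG_VALUE' sentinel
# delegates to model.initialize_from_checkpoint(path, num_classes,
# is_training=True). (Restated from the assertions above, not new behavior.)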
class ModelTrainTest(parameterized.TestCase, tf.test.TestCase):

  def _run_tiny_training(self, model_name, dataset, warm_start_from=''):
    """Runs one training step. This function always starts a new train_dir."""
    with mock.patch(
        'deepvariant.data_providers.'
        'get_input_fn_from_dataset') as mock_get_input_fn_from_dataset:
      mock_get_input_fn_from_dataset.return_value = dataset
      FLAGS.train_dir = tf_test_utils.test_tmpdir(uuid.uuid4().hex)
      FLAGS.batch_size = 2
      FLAGS.model_name = model_name
      FLAGS.save_interval_secs = -1
      FLAGS.save_interval_steps = 1
      FLAGS.number_of_steps = 1
      FLAGS.dataset_config_pbtxt = '/path/to/mock.pbtxt'
      FLAGS.start_from_checkpoint = warm_start_from
      FLAGS.master = ''
      model_train.parse_and_run()
      # We have a checkpoint after training.
      mock_get_input_fn_from_dataset.assert_called_once_with(
          dataset_config_filename=FLAGS.dataset_config_pbtxt,
          mode=tf.estimator.ModeKeys.TRAIN,
          use_tpu=mock.ANY,
          max_examples=None,
      )
      self.assertIsNotNone(tf.train.latest_checkpoint(FLAGS.train_dir))

  @mock.patch('deepvariant'
              '.modeling.tf.compat.v1.losses.softmax_cross_entropy')
  @mock.patch('deepvariant'
              '.modeling.tf.compat.v1.losses.get_total_loss')
  def test_loss(self, mock_total_loss, mock_cross):
    labels = [[0, 1, 0], [1, 0, 0]]
    logits = 'Logits'
    smoothing = 0.01
    actual = model_train.loss(logits, labels, smoothing)
    mock_total_loss.assert_called_once_with()
    self.assertEqual(actual, mock_total_loss.return_value)
    mock_cross.assert_called_once_with(
        logits, labels, label_smoothing=smoothing, weights=1.0)

  # pylint: disable=g-complex-comprehension
  @parameterized.parameters(
      (model.name, compressed_inputs)
      for model in modeling.production_models()
      if model.is_trainable
      for compressed_inputs in [True, False])
  # pylint: enable=g-complex-comprehension
  @flagsaver.flagsaver
  def test_end2end(self, model_name, compressed_inputs):
    """End-to-end test of model_train script."""
    self._run_tiny_training(
        model_name=model_name,
        dataset=data_providers_test.make_golden_dataset(
            compressed_inputs=compressed_inputs, use_tpu=FLAGS.use_tpu))

  @flagsaver.flagsaver
  def test_end2end_inception_v3_warm_up_from(self):
    """End-to-end test of model_train script."""
    checkpoint_dir = tf_test_utils.test_tmpdir('inception_v3_warm_up_from')
    tf_test_utils.write_fake_checkpoint('inception_v3', self.test_session(),
                                        checkpoint_dir)
    self._run_tiny_training(
        model_name='inception_v3',
        dataset=data_providers_test.make_golden_dataset(use_tpu=FLAGS.use_tpu),
        warm_start_from=checkpoint_dir + '/model')

  @flagsaver.flagsaver
  def test_end2end_inception_v3_warm_up_allow_different_num_channels(self):
    """End-to-end test of model_train script."""
    FLAGS.allow_warmstart_from_different_num_channels = True
    checkpoint_dir = tf_test_utils.test_tmpdir(
        'inception_v3_warm_up_allow_different_num_channels')
    tf_test_utils.write_fake_checkpoint(
        'inception_v3',
        self.test_session(),
        checkpoint_dir,
        num_channels=dv_constants.PILEUP_NUM_CHANNELS + 1)
    self._run_tiny_training(
        model_name='inception_v3',
        dataset=data_providers_test.make_golden_dataset(use_tpu=FLAGS.use_tpu),
        warm_start_from=checkpoint_dir + '/model')

  @flagsaver.flagsaver
  def test_end2end_inception_v3_warm_up_by_default_fail_diff_num_channels(
      self):
    """End-to-end test of model_train script."""
    checkpoint_dir = tf_test_utils.test_tmpdir(
        'test_end2end_inception_v3_warm_up_by_default_fail_diff_num_channels')
    tf_test_utils.write_fake_checkpoint(
        'inception_v3',
        self.test_session(),
        checkpoint_dir,
        num_channels=dv_constants.PILEUP_NUM_CHANNELS + 1)
    with self.assertRaisesRegex(
        ValueError,
        r'Shape of variable InceptionV3/Conv2d_1a_3x3/weights:0 \(\(.*\)\) '
        r'doesn\'t match with shape of tensor '
        r'InceptionV3/Conv2d_1a_3x3/weights \(\[.*\]\) from checkpoint reader.'
    ):
      self._run_tiny_training(
          model_name='inception_v3',
          dataset=data_providers_test.make_golden_dataset(
              use_tpu=FLAGS.use_tpu),
          warm_start_from=checkpoint_dir + '/model')

  @flagsaver.flagsaver
  def test_end2end_inception_v3_failed_warm_up_from(self):
    """End-to-end test of model_train script with a non-existent path."""
    # Internal TF raises tf.errors.OpError, public TF raises ValueError.
    with self.assertRaises((tf.errors.OpError, ValueError)):
      self._run_tiny_training(
          model_name='inception_v3',
          dataset=data_providers_test.make_golden_dataset(
              use_tpu=FLAGS.use_tpu),
          warm_start_from='this/path/does/not/exist')

  @flagsaver.flagsaver
  def test_end2end_inception_v3_embedding_invalid_embedding_size(self):
    """End-to-end test of model_train script with an invalid embedding size."""
    with six.assertRaisesRegex(
        self, ValueError, 'Expected seq_type_embedding_size '
        'to be a positive number but saw -100 '
        'instead.'):
      FLAGS.seq_type_embedding_size = -100
      self._run_tiny_training(
          model_name='inception_v3_embedding',
          dataset=data_providers_test.make_golden_dataset(
              use_tpu=FLAGS.use_tpu))

  @parameterized.parameters((False), (True))
  @flagsaver.flagsaver
  @mock.patch('deepvariant.model_train.'
              'tf.compat.v1.train.replica_device_setter')
  @mock.patch('deepvariant.model_train.run')
  def test_main_internal(self, use_tpu, mock_run, mock_device_setter):
    FLAGS.master = 'some_master'
    FLAGS.use_tpu = use_tpu
    FLAGS.ps_tasks = 10
    FLAGS.task = 5
    model_train.parse_and_run()
    mock_device_setter.assert_called_once_with(10)
    mock_run.assert_called_once_with(
        'some_master' if use_tpu else '',
        False,
        device_fn=mock.ANY,
        use_tpu=mock.ANY)

  @mock.patch('deepvariant.model_train.os.environ')
  @mock.patch('deepvariant.model_train.'
              'tf.compat.v1.train.replica_device_setter')
  @mock.patch('deepvariant.model_train.run')
  def test_main_tfconfig_local(self, mock_run, mock_device_setter,
                               mock_environ):
    mock_environ.get.return_value = '{}'
    model_train.parse_and_run()
    mock_device_setter.assert_called_once_with(0)
    mock_run.assert_called_once_with(
        '', True, device_fn=mock.ANY, use_tpu=mock.ANY)

  @parameterized.named_parameters(
      ('master', 'master', 0, True, '/job:master/task:0'),
      ('worker', 'worker', 10, False, '/job:worker/task:10'),
  )
  @mock.patch('deepvariant.model_train.tf.distribute.Server')
  @mock.patch('deepvariant.model_train.os.environ')
  @mock.patch('deepvariant.model_train.'
              'tf.compat.v1.train.replica_device_setter')
  @mock.patch('deepvariant.model_train.run')
  def test_main_tfconfig_dist(self, job_name, task_index, expected_is_chief,
                              expected_worker, mock_run, mock_device_setter,
                              mock_environ, mock_server):
    tf_config = {
        'cluster': {
            'ps': ['ps1:800', 'ps2:800']
        },
        'task': {
            'type': job_name,
            'index': task_index,
        },
    }

    class FakeServer(object):
      target = 'some-target'

    mock_environ.get.return_value = json.dumps(tf_config)
    mock_server.return_value = FakeServer()
    model_train.parse_and_run()
    mock_device_setter.assert_called_once_with(
        2, worker_device=expected_worker, cluster=mock.ANY)
    mock_run.assert_called_once_with(
        'some-target', expected_is_chief, device_fn=mock.ANY,
        use_tpu=mock.ANY)

  @parameterized.parameters(
      ('master', 'some-master'),
      ('task', 10),
      ('ps_tasks', 5),
  )
  @flagsaver.flagsaver
  @mock.patch('deepvariant.model_train.os.environ')
  def test_main_invalid_args(self, flag_name, flag_value, mock_environ):
    # Ensure an exception is raised if flags and TF_CONFIG are set.
    tf_config = {
        'cluster': {
            'ps': ['ps1:800', 'ps2:800']
        },
        'task': {
            'type': 'master',
            'index': 0,
        },
    }
    mock_environ.get.return_value = json.dumps(tf_config)
    setattr(FLAGS, flag_name, flag_value)
    self.assertRaises(ValueError, model_train.parse_and_run)
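# Standard entry point for TF test modules; this excerpt omits it, so the
# guard below is the conventional form rather than the verbatim original.
if __name__ == '__main__':
  tf.test.main()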