class CallVariantsAcceleratorTests(
        six.with_metaclass(parameterized.TestGeneratorMetaclass,
                           tf.test.TestCase)):
    @parameterized.parameters(modeling.production_models())
    def test_call_variants_runs_on_gpus(self, model):
        call_variants.call_variants(
            examples_filename=testdata.GOLDEN_CALLING_EXAMPLES,
            checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
            model=model,
            execution_hardware='accelerator',
            output_file=test_utils.test_tmpfile('zzz.tfrecord'))
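The test above only exercises call_variants when an accelerator is actually visible to TensorFlow. As a minimal sketch (not part of the original suite, and assuming the TF 1.x tf.test.is_gpu_available API), a test method could skip itself when no GPU is present:

import unittest

import tensorflow as tf


def skip_unless_gpu(test_method):
    """Hypothetical decorator: skips a test when no GPU is visible."""

    def wrapper(self, *args, **kwargs):
        if not tf.test.is_gpu_available():
            raise unittest.SkipTest('No GPU/accelerator available.')
        return test_method(self, *args, **kwargs)

    return wrapper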
Example 2
class CallVariantsEndToEndTests(
        six.with_metaclass(parameterized.TestGeneratorMetaclass,
                           tf.test.TestCase)):
    def setUp(self):
        self.checkpoint_dir = tf.test.get_temp_dir()

    def assertCallVariantsEmitsNRecordsForInceptionV3(self, filename,
                                                      num_examples):
        outfile = test_utils.test_tmpfile(
            'inception_v3.call_variants.tfrecord')
        model = modeling.get_model('inception_v3')
        checkpoint_path = _LEAVE_MODEL_UNINITIALIZED

        call_variants.call_variants(examples_filename=filename,
                                    checkpoint_path=checkpoint_path,
                                    model=model,
                                    output_file=outfile,
                                    batch_size=4,
                                    max_batches=None)
        call_variants_outputs = list(
            io_utils.read_tfrecords(outfile,
                                    deepvariant_pb2.CallVariantsOutput))
        # Check that we have the right number of output protos.
        self.assertEqual(len(call_variants_outputs), num_examples)

    def assertCallVariantsEmitsNRecordsForRandomGuess(self, filename,
                                                      num_examples):
        checkpoint_path = _LEAVE_MODEL_UNINITIALIZED
        outfile = test_utils.test_tmpfile('call_variants.tfrecord')
        model = modeling.get_model('random_guess')
        call_variants.call_variants(examples_filename=filename,
                                    checkpoint_path=checkpoint_path,
                                    model=model,
                                    output_file=outfile,
                                    batch_size=4,
                                    max_batches=None,
                                    master='',
                                    use_tpu=FLAGS.use_tpu)
        call_variants_outputs = list(
            io_utils.read_tfrecords(outfile,
                                    deepvariant_pb2.CallVariantsOutput))
        # Check that we have the right number of output protos.
        self.assertEqual(len(call_variants_outputs), num_examples)

    def test_call_end2end_with_empty_shards(self):
        # Get only up to 10 examples.
        examples = list(
            io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES,
                                    max_records=10))
        # Write to 15 shards, which means there will be multiple empty shards.
        source_path = test_utils.test_tmpfile('sharded@{}'.format(15))
        io_utils.write_tfrecords(examples, source_path)
        self.assertCallVariantsEmitsNRecordsForRandomGuess(
            source_path, len(examples))

    def test_call_end2end_empty_first_shard(self):
        # Get only up to 10 examples.
        examples = list(
            io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES,
                                    max_records=10))
        empty_first_file = test_utils.test_tmpfile(
            'empty_1st_shard-00000-of-00002')
        io_utils.write_tfrecords([], empty_first_file)
        second_file = test_utils.test_tmpfile('empty_1st_shard-00001-of-00002')
        io_utils.write_tfrecords(examples, second_file)
        self.assertCallVariantsEmitsNRecordsForRandomGuess(
            test_utils.test_tmpfile('empty_1st_shard@2'), len(examples))

    def test_call_end2end_zero_record_file_for_inception_v3(self):
        zero_record_file = test_utils.test_tmpfile('zero_record_file')
        io_utils.write_tfrecords([], zero_record_file)
        self.assertCallVariantsEmitsNRecordsForInceptionV3(
            test_utils.test_tmpfile('zero_record_file'), 0)

    def _call_end2end_helper(self, examples_path, model, shard_inputs):
        examples = list(io_utils.read_tfrecords(examples_path))

        if shard_inputs:
            # Create a sharded version of our golden examples.
            source_path = test_utils.test_tmpfile('sharded@{}'.format(3))
            io_utils.write_tfrecords(examples, source_path)
        else:
            source_path = examples_path

        # If we point the test at a headless TPU server, it will often be a
        # 2x2 slice, which has 8 replicas. Otherwise a smaller batch size is
        # fine.
        if FLAGS.use_tpu:
            batch_size = 8
        else:
            batch_size = 4

        if model.name == 'random_guess':
            # For the random guess model we can run everything.
            max_batches = None
        else:
            # For all other models we only run a single batch for inference.
            max_batches = 1

        outfile = test_utils.test_tmpfile('call_variants.tfrecord')
        call_variants.call_variants(
            examples_filename=source_path,
            checkpoint_path=_LEAVE_MODEL_UNINITIALIZED,
            model=model,
            output_file=outfile,
            batch_size=batch_size,
            max_batches=max_batches,
            master='',
            use_tpu=FLAGS.use_tpu,
        )

        call_variants_outputs = list(
            io_utils.read_tfrecords(outfile,
                                    deepvariant_pb2.CallVariantsOutput))

        return call_variants_outputs, examples, batch_size, max_batches

    @parameterized.parameters(model for model in modeling.production_models())
    @flagsaver.FlagSaver
    def test_call_end2end_with_labels(self, model):
        FLAGS.debugging_true_label_mode = True
        (call_variants_outputs, examples, batch_size,
         max_batches) = self._call_end2end_helper(
             testdata.GOLDEN_TRAINING_EXAMPLES, model, False)
        # Check that we have the right number of output protos.
        self.assertEqual(
            len(call_variants_outputs),
            batch_size * max_batches if max_batches else len(examples))

        # Checks that at least some of the `true_label`s are filled.
        self.assertTrue(
            any(cvo.debug_info.true_label > 0
                for cvo in call_variants_outputs))

    @parameterized.parameters(model for model in modeling.production_models())
    @flagsaver.FlagSaver
    def test_call_end2end_no_labels_fails(self, model):
        FLAGS.debugging_true_label_mode = True
        if not FLAGS.use_tpu:
            # On TPUs, I got this error:
            #
            # OP_REQUIRES failed at example_parsing_ops.cc:240 :
            # Invalid argument: Feature: label (data type: int64) is required but
            # could not be found.
            #
            # which cannot be caught by assertRaises.
            with self.assertRaises(tf.errors.OpError):
                self._call_end2end_helper(testdata.GOLDEN_CALLING_EXAMPLES,
                                          model, False)

    @parameterized.parameters((model, shard_inputs, include_debug_info)
                              for shard_inputs in [False, True]
                              for model in modeling.production_models()
                              for include_debug_info in [False, True])
    @flagsaver.FlagSaver
    def test_call_end2end(self, model, shard_inputs, include_debug_info):
        FLAGS.include_debug_info = include_debug_info
        (call_variants_outputs, examples, batch_size,
         max_batches) = self._call_end2end_helper(
             testdata.GOLDEN_CALLING_EXAMPLES, model, shard_inputs)
        # Check that we have the right number of output protos.
        self.assertEqual(
            len(call_variants_outputs),
            batch_size * max_batches if max_batches else len(examples))

        # Check that our CallVariantsOutput (CVO) have the following critical
        # properties:
        # - we have one CVO for each example we processed.
        # - the variant in the CVO is exactly what was in the example.
        # - the alt_allele_indices of the CVO match those of its corresponding
        #   example.
        # - there are 3 genotype probabilities and these are between 0.0 and 1.0.
        # We can only do this test when processing all of the variants
        # (max_batches is None), since only then have all of the examples been
        # run through the model.
        if max_batches is None:
            self.assertItemsEqual(
                [cvo.variant for cvo in call_variants_outputs],
                [tf_utils.example_variant(ex) for ex in examples])

        # Check the CVO debug_info: not filled if include_debug_info is False;
        # else, filled by logic based on CVO.
        if not include_debug_info:
            for cvo in call_variants_outputs:
                self.assertEqual(
                    cvo.debug_info,
                    deepvariant_pb2.CallVariantsOutput.DebugInfo())
        else:
            for cvo in call_variants_outputs:
                self.assertEqual(cvo.debug_info.has_insertion,
                                 variant_utils.has_insertion(cvo.variant))
                self.assertEqual(cvo.debug_info.has_deletion,
                                 variant_utils.has_deletion(cvo.variant))
                self.assertEqual(cvo.debug_info.is_snp,
                                 variant_utils.is_snp(cvo.variant))
                self.assertEqual(cvo.debug_info.predicted_label,
                                 np.argmax(cvo.genotype_probabilities))

        def example_matches_call_variants_output(example,
                                                 call_variants_output):
            return (tf_utils.example_variant(example)
                    == call_variants_output.variant
                    and tf_utils.example_alt_alleles_indices(example)
                    == call_variants_output.alt_allele_indices.indices)

        for call_variants_output in call_variants_outputs:
            # Find all matching examples.
            matches = [
                ex for ex in examples if example_matches_call_variants_output(
                    ex, call_variants_output)
            ]
            # We should have exactly one match.
            self.assertEqual(len(matches), 1)
            example = matches[0]
            # Check that we've faithfully copied in the alt alleles (though currently
            # as implemented we find our example using this information so it cannot
            # fail). Included here in case that changes in the future.
            self.assertEqual(
                list(tf_utils.example_alt_alleles_indices(example)),
                list(call_variants_output.alt_allele_indices.indices))
            # We should have exactly three genotype probabilities (assuming our
            # ploidy == 2).
            self.assertEqual(len(call_variants_output.genotype_probabilities),
                             3)
            # These are probabilities so they should be between 0 and 1.
            self.assertTrue(
                all(0 <= gp <= 1
                    for gp in call_variants_output.genotype_probabilities))

    @parameterized.parameters((model, bad_format)
                              for model in modeling.production_models()
                              for bad_format in ['', 'png'])
    def test_call_variants_with_invalid_format(self, model, bad_format):
        # Read one good record from a valid file.
        example = next(
            io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES))
        # Overwrite the image/format field to be an invalid value
        # (anything but 'raw').
        example.features.feature['image/format'].bytes_list.value[0] = six.b(
            bad_format)
        source_path = test_utils.test_tmpfile('make_examples_output.tfrecord')
        io_utils.write_tfrecords([example], source_path)
        outfile = test_utils.test_tmpfile(
            'call_variants_invalid_format.tfrecord')

        with self.assertRaises(ValueError):
            call_variants.call_variants(
                examples_filename=source_path,
                checkpoint_path=_LEAVE_MODEL_UNINITIALIZED,
                model=model,
                output_file=outfile,
                batch_size=1,
                max_batches=1,
                use_tpu=FLAGS.use_tpu)

    @parameterized.parameters(model for model in modeling.production_models())
    def test_call_variants_with_no_shape(self, model):
        # Read one good record from a valid file.
        example = next(
            io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES))
        # Remove image/shape.
        del example.features.feature['image/shape']
        source_path = test_utils.test_tmpfile(
            'make_examples_out_noshape.tfrecord')
        io_utils.write_tfrecords([example], source_path)
        with self.assertRaisesRegexp(
                ValueError,
                'Invalid image/shape: we expect to find an image/shape '
                'field with length 3.'):
            ds = call_variants.prepare_inputs(source_path)
            _ = list(
                data_providers.get_infer_batches(ds, model=model,
                                                 batch_size=1))

    def test_call_variants_with_empty_input(self):
        source_path = test_utils.test_tmpfile('empty.tfrecord')
        io_utils.write_tfrecords([], source_path)
        # Make sure that prepare_inputs doesn't crash on empty input.
        ds = call_variants.prepare_inputs(source_path)
        m = modeling.get_model('random_guess')

        # The API specifies that OutOfRangeError is thrown in this case.
        batches = list(
            data_providers.get_infer_batches(ds, model=m, batch_size=1))
        with self.test_session() as sess:
            sess.run(tf.local_variables_initializer())
            sess.run(tf.global_variables_initializer())
            try:
                _ = sess.run(batches)
            except tf.errors.OutOfRangeError:
                pass
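The sharded-input tests in this example write to specs like 'sharded@15' and read back per-shard files such as 'empty_1st_shard-00000-of-00002'. A minimal sketch of that naming convention follows; expand_sharded_spec is a hypothetical helper for illustration, not part of the DeepVariant API.

def expand_sharded_spec(spec):
    """Expands a 'name@N' spec into the per-shard filenames used above."""
    basename, _, num_shards = spec.partition('@')
    num_shards = int(num_shards)
    return [
        '{}-{:05d}-of-{:05d}'.format(basename, i, num_shards)
        for i in range(num_shards)
    ]


assert expand_sharded_spec('empty_1st_shard@2') == [
    'empty_1st_shard-00000-of-00002',
    'empty_1st_shard-00001-of-00002',
]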
Example 3
class ModelEvalTest(
        six.with_metaclass(parameterized.TestGeneratorMetaclass,
                           tf.test.TestCase)):
    def testSelectVariantsWeights(self):
        variants = [
            test_utils.make_variant(start=10, alleles=['C', 'T']),
            test_utils.make_variant(start=11, alleles=['C', 'TA']),
            test_utils.make_variant(start=12, alleles=['C', 'A']),
            test_utils.make_variant(start=13, alleles=['CA', 'T']),
        ]
        encoded = tf.constant([v.SerializeToString() for v in variants])

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            op = model_eval.select_variants_weights(variantutils.is_snp,
                                                    encoded,
                                                    name='tf_is_snp')
            self.assertTrue(op.name.startswith('tf_is_snp'))
            npt.assert_array_equal(op.eval(), [1.0, 0.0, 1.0, 0.0])

    def testCallingMetrics(self):
        def make_mock_metric(name):
            # pylint: disable=unused-argument
            def _side_effect(predictions, labels, weights):
                if weights:
                    return name + ':' + ','.join(str(int(w)) for w in weights)
                else:
                    return name + ':None'

            return mock.MagicMock(side_effect=_side_effect)

        predictions = tf.constant([0, 1, 2, 0])
        labels = tf.constant([0, 2, 1, 1])
        metrics = {
            'm1': make_mock_metric('mock_metric1'),
            'm2': make_mock_metric('mock_metric2')
        }
        selectors = {'s1': [1, 1, 1, 1], 's2': [0, 0, 0, 0], 's3': None}

        # The returned dictionary has the expected keys and values.
        self.assertEqual(
            {
                'm1/s1': 'mock_metric1:1,1,1,1',
                'm1/s2': 'mock_metric1:0,0,0,0',
                'm1/s3': 'mock_metric1:None',
                'm2/s1': 'mock_metric2:1,1,1,1',
                'm2/s2': 'mock_metric2:0,0,0,0',
                'm2/s3': 'mock_metric2:None',
            },
            model_eval.calling_metrics(metrics_map=metrics,
                                       selectors_map=selectors,
                                       predictions=predictions,
                                       labels=labels))

        # Check that our mocked metrics received all of the calls we expect.
        for mocked in metrics.values():
            self.assertEqual([
                mock.call(predictions, labels, weights=selectors[x])
                for x in selectors
            ], mocked.call_args_list)

    @parameterized.parameters(model.name
                              for model in modeling.production_models()
                              if model.is_trainable)
    @flagsaver.FlagSaver
    @mock.patch('deepvariant.data_providers.get_dataset')
    def test_end2end(self, model_name, mock_get_dataset):
        """End-to-end test of model_eval."""
        checkpoint_dir = tf.test.get_temp_dir()

        # Create a model with 3 classes, and save it to our checkpoint dir.
        with self.test_session() as sess:
            model = modeling.get_model(model_name)
            # Needed to protect ourselves against models without an input image shape.
            h, w = getattr(model, 'input_image_shape', (100, 221))
            images = tf.placeholder(tf.float32,
                                    shape=(4, h, w,
                                           pileup_image.DEFAULT_NUM_CHANNEL))
            model.create(images, num_classes=3, is_training=True)
            # This is gross, but necessary as model_eval assumes the model was trained
            # with model_train which uses exp moving averages. Unfortunately we cannot
            # just call into model_train as it uses FLAGS which conflict with the
            # flags in use by model_eval. So we inline the creation of the EMA here.
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, slim.get_or_create_global_step())
            tf.add_to_collection(
                tf.GraphKeys.UPDATE_OPS,
                variable_averages.apply(slim.get_model_variables()))
            sess.run(tf.global_variables_initializer())
            save = tf.train.Saver(slim.get_variables())
            save.save(sess, os.path.join(checkpoint_dir, 'model'))

        # Start up eval, loading that checkpoint.
        FLAGS.batch_size = 2
        FLAGS.checkpoint_dir = checkpoint_dir
        FLAGS.eval_dir = tf.test.get_temp_dir()
        FLAGS.batches_per_eval_step = 1
        FLAGS.max_evaluations = 1
        FLAGS.eval_interval_secs = 0
        FLAGS.model_name = model_name
        FLAGS.dataset_config_pbtxt = '/path/to/mock.pbtxt'
        # Always try to read in compressed inputs to stress that case. Uncompressed
        # inputs are certain to work. This test is expensive to run, so we want to
        # minimize the number of times we need to run this.
        mock_get_dataset.return_value = data_providers_test.make_golden_dataset(
            compressed_inputs=True)
        model_eval.main(0)
        mock_get_dataset.assert_called_once_with(FLAGS.dataset_config_pbtxt)
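testCallingMetrics above pins down the layout that model_eval.calling_metrics is expected to return: one entry per metric/selector pair, keyed '<metric>/<selector>', with the selector passed to the metric as weights. A minimal sketch of that cross-product, for illustration only (this is not the real calling_metrics implementation):

def calling_metrics_sketch(metrics_map, selectors_map, predictions, labels):
    """Illustrates the '<metric>/<selector>' cross-product checked above."""
    return {
        '{}/{}'.format(metric_name, selector_name): metric_fn(
            predictions, labels, weights=selector)
        for metric_name, metric_fn in metrics_map.items()
        for selector_name, selector in selectors_map.items()
    }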
Example 4
class CallVariantsEndToEndTests(
        six.with_metaclass(parameterized.TestGeneratorMetaclass,
                           tf.test.TestCase)):
    def assertCallVariantsEmitsNRecordsForRandomGuess(self, filename,
                                                      num_examples):
        outfile = test_utils.test_tmpfile('call_variants.tfrecord')
        model = modeling.get_model('random_guess')
        call_variants.call_variants(
            examples_filename=filename,
            checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
            model=model,
            output_file=outfile,
            batch_size=4,
            max_batches=None)
        call_variants_outputs = list(
            io_utils.read_tfrecords(outfile,
                                    deepvariant_pb2.CallVariantsOutput))
        # Check that we have the right number of output protos.
        self.assertEqual(len(call_variants_outputs), num_examples)

    def test_call_end2end_with_empty_shards(self):
        # Get only up to 10 examples.
        examples = list(
            io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES,
                                    max_records=10))
        # Write to 15 shards, which means there will be multiple empty shards.
        source_path = test_utils.test_tmpfile('sharded@{}'.format(15))
        io_utils.write_tfrecords(examples, source_path)
        self.assertCallVariantsEmitsNRecordsForRandomGuess(
            source_path, len(examples))

    def test_call_end2end_empty_first_shard(self):
        # Get only up to 10 examples.
        examples = list(
            io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES,
                                    max_records=10))
        empty_first_file = test_utils.test_tmpfile(
            'empty_1st_shard-00000-of-00002')
        io_utils.write_tfrecords([], empty_first_file)
        second_file = test_utils.test_tmpfile('empty_1st_shard-00001-of-00002')
        io_utils.write_tfrecords(examples, second_file)
        self.assertCallVariantsEmitsNRecordsForRandomGuess(
            test_utils.test_tmpfile('empty_1st_shard@2'), len(examples))

    def test_call_end2end_zero_record_file(self):
        zero_record_file = test_utils.test_tmpfile('zero_record_file')
        io_utils.write_tfrecords([], zero_record_file)
        self.assertCallVariantsEmitsNRecordsForRandomGuess(
            test_utils.test_tmpfile('zero_record_file'), 0)

    @parameterized.parameters((model, shard_inputs, include_debug_info)
                              for shard_inputs in [False, True]
                              for model in modeling.production_models()
                              for include_debug_info in [False, True])
    @flagsaver.FlagSaver
    def test_call_end2end(self, model, shard_inputs, include_debug_info):
        FLAGS.include_debug_info = include_debug_info
        examples = list(
            io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES))

        if shard_inputs:
            # Create a sharded version of our golden examples.
            source_path = test_utils.test_tmpfile('sharded@{}'.format(3))
            io_utils.write_tfrecords(examples, source_path)
        else:
            source_path = testdata.GOLDEN_CALLING_EXAMPLES

        batch_size = 4
        if model.name == 'random_guess':
            # For the random guess model we can run everything.
            max_batches = None
        else:
            # For all other models we only run a single batch for inference.
            max_batches = 1

        outfile = test_utils.test_tmpfile('call_variants.tfrecord')
        call_variants.call_variants(
            examples_filename=source_path,
            checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
            model=model,
            output_file=outfile,
            batch_size=batch_size,
            max_batches=max_batches)

        call_variants_outputs = list(
            io_utils.read_tfrecords(outfile,
                                    deepvariant_pb2.CallVariantsOutput))

        # Check that we have the right number of output protos.
        self.assertEqual(
            len(call_variants_outputs),
            batch_size * max_batches if max_batches else len(examples))

        # Check that our CallVariantsOutput (CVO) have the following critical
        # properties:
        # - we have one CVO for each example we processed.
        # - the variant in the CVO is exactly what was in the example.
        # - the alt_allele_indices of the CVO match those of its corresponding
        #   example.
        # - there are 3 genotype probabilities and these are between 0.0 and 1.0.
        # We can only do this test when processing all of the variants
        # (max_batches is None), since only then have all of the examples been
        # run through the model.
        if max_batches is None:
            self.assertItemsEqual(
                [cvo.variant for cvo in call_variants_outputs],
                [tf_utils.example_variant(ex) for ex in examples])

        # Check the CVO debug_info: not filled if include_debug_info is False;
        # else, filled by logic based on CVO.
        if not include_debug_info:
            for cvo in call_variants_outputs:
                self.assertEqual(
                    cvo.debug_info,
                    deepvariant_pb2.CallVariantsOutput.DebugInfo())
        else:
            for cvo in call_variants_outputs:
                self.assertEqual(cvo.debug_info.has_insertion,
                                 variant_utils.has_insertion(cvo.variant))
                self.assertEqual(cvo.debug_info.has_deletion,
                                 variant_utils.has_deletion(cvo.variant))
                self.assertEqual(cvo.debug_info.is_snp,
                                 variant_utils.is_snp(cvo.variant))
                self.assertEqual(cvo.debug_info.predicted_label,
                                 np.argmax(cvo.genotype_probabilities))

        def example_matches_call_variants_output(example,
                                                 call_variants_output):
            return (tf_utils.example_variant(example)
                    == call_variants_output.variant
                    and tf_utils.example_alt_alleles_indices(example)
                    == call_variants_output.alt_allele_indices.indices)

        for call_variants_output in call_variants_outputs:
            # Find all matching examples.
            matches = [
                ex for ex in examples if example_matches_call_variants_output(
                    ex, call_variants_output)
            ]
            # We should have exactly one match.
            self.assertEqual(len(matches), 1)
            example = matches[0]
            # Check that we've faithfully copied in the alt alleles (though currently
            # as implemented we find our example using this information so it cannot
            # fail). Included here in case that changes in the future.
            self.assertEqual(
                list(tf_utils.example_alt_alleles_indices(example)),
                list(call_variants_output.alt_allele_indices.indices))
            # We should have exactly three genotype probabilities (assuming our
            # ploidy == 2).
            self.assertEqual(len(call_variants_output.genotype_probabilities),
                             3)
            # These are probabilities so they should be between 0 and 1.
            self.assertTrue(
                all(0 <= gp <= 1
                    for gp in call_variants_output.genotype_probabilities))

    @parameterized.parameters((model, bad_format)
                              for model in modeling.production_models()
                              for bad_format in ['', 'png'])
    def test_call_variants_with_invalid_format(self, model, bad_format):
        # Read one good record from a valid file.
        example = next(
            io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES))
        # Overwrite the image/format field to be an invalid value
        # (anything but 'raw').
        example.features.feature['image/format'].bytes_list.value[0] = six.b(
            bad_format)
        source_path = test_utils.test_tmpfile('make_examples_output.tfrecord')
        io_utils.write_tfrecords([example], source_path)
        outfile = test_utils.test_tmpfile(
            'call_variants_invalid_format.tfrecord')

        with self.assertRaises(ValueError):
            call_variants.call_variants(
                examples_filename=source_path,
                checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
                model=model,
                output_file=outfile,
                batch_size=1,
                max_batches=1)

    @parameterized.parameters(model for model in modeling.production_models())
    def test_call_variants_with_no_shape(self, model):
        # Read one good record from a valid file.
        example = next(
            io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES))
        # Remove image/shape.
        del example.features.feature['image/shape']
        source_path = test_utils.test_tmpfile(
            'make_examples_out_noshape.tfrecord')
        io_utils.write_tfrecords([example], source_path)
        with self.assertRaisesRegexp(
                ValueError,
                'Invalid image/shape: we expect to find an image/shape '
                'field with length 3.'):
            call_variants.prepare_inputs(source_path, model, batch_size=1)

    def test_call_variants_with_empty_input(self):
        source_path = test_utils.test_tmpfile('empty.tfrecord')
        io_utils.write_tfrecords([], source_path)
        # Make sure that prepare_inputs doesn't crash on empty input.
        call_variants.prepare_inputs(source_path,
                                     modeling.get_model('random_guess'),
                                     batch_size=1)
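The "exactly three genotype probabilities (assuming our ploidy == 2)" check above is simple combinatorics: a biallelic site has the unordered diploid genotypes {0/0, 0/1, 1/1}. A small standalone sketch of the general count (not part of the test suite):

from math import factorial


def num_genotypes(n_alleles, ploidy=2):
    """Number of unordered genotypes: C(n_alleles + ploidy - 1, ploidy)."""
    n = n_alleles + ploidy - 1
    return factorial(n) // (factorial(ploidy) * factorial(n - ploidy))


assert num_genotypes(2) == 3  # ref + 1 alt, diploid -> 3 probabilities.
assert num_genotypes(3) == 6  # ref + 2 alts, diploid -> 6.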
Example 5
class ModelTrainTest(parameterized.TestCase, tf.test.TestCase):
    @flagsaver.FlagSaver
    def test_training_works_with_compressed_inputs(self):
        """End-to-end test of model_train script."""
        self._run_tiny_training(
            model_name='mobilenet_v1',
            dataset=data_providers_test.make_golden_dataset(
                compressed_inputs=True, use_tpu=FLAGS.use_tpu))

    def _run_tiny_training(self, model_name, dataset, warm_start_from=''):
        """Runs one training step. This function always starts a new train_dir."""
        with mock.patch(
                'deepvariant.data_providers.'
                'get_input_fn_from_dataset') as mock_get_input_fn_from_dataset:
            mock_get_input_fn_from_dataset.return_value = dataset
            FLAGS.train_dir = tf_test_utils.test_tmpdir(uuid.uuid4().hex)
            FLAGS.batch_size = 2
            FLAGS.model_name = model_name
            FLAGS.save_interval_secs = -1
            FLAGS.save_interval_steps = 1
            FLAGS.number_of_steps = 1
            FLAGS.dataset_config_pbtxt = '/path/to/mock.pbtxt'
            FLAGS.start_from_checkpoint = warm_start_from
            FLAGS.master = ''
            model_train.parse_and_run()
            # We have a checkpoint after training.
            mock_get_input_fn_from_dataset.assert_called_once_with(
                dataset_config_filename=FLAGS.dataset_config_pbtxt,
                mode=tf.estimator.ModeKeys.TRAIN,
                use_tpu=mock.ANY,
            )
            self.assertIsNotNone(tf.train.latest_checkpoint(FLAGS.train_dir))

    @mock.patch('deepvariant' '.modeling.slim.losses.softmax_cross_entropy')
    @mock.patch('deepvariant' '.modeling.slim.losses.get_total_loss')
    def test_loss(self, mock_total_loss, mock_cross):
        labels = [[0, 1, 0], [1, 0, 0]]
        logits = 'Logits'
        smoothing = 0.01
        actual = model_train.loss(logits, labels, smoothing)
        mock_total_loss.assert_called_once_with()
        self.assertEqual(actual, mock_total_loss.return_value)
        mock_cross.assert_called_once_with(logits,
                                           labels,
                                           label_smoothing=smoothing,
                                           weights=1.0)

    @parameterized.parameters(model.name
                              for model in modeling.production_models()
                              if model.is_trainable)
    @flagsaver.FlagSaver
    def test_end2end(self, model_name):
        """End-to-end test of model_train script."""
        self._run_tiny_training(
            model_name=model_name,
            dataset=data_providers_test.make_golden_dataset(
                use_tpu=FLAGS.use_tpu))

    @flagsaver.FlagSaver
    def test_end2end_inception_v3_warm_up_from(self):
        """End-to-end test of model_train script."""
        checkpoint_dir = tf_test_utils.test_tmpdir('inception_v3_warm_up_from')
        tf_test_utils.write_fake_checkpoint('inception_v3',
                                            self.test_session(),
                                            checkpoint_dir)
        self._run_tiny_training(
            model_name='inception_v3',
            dataset=data_providers_test.make_golden_dataset(
                use_tpu=FLAGS.use_tpu),
            warm_start_from=checkpoint_dir + '/model')

    @flagsaver.FlagSaver
    def test_end2end_inception_v3_warm_up_from_mobilenet_v1(self):
        """Tests the behavior when warm start from mobilenet but train inception."""
        checkpoint_dir = tf_test_utils.test_tmpdir(
            'inception_v3_warm_up_from_mobilenet_v1')
        tf_test_utils.write_fake_checkpoint('mobilenet_v1',
                                            self.test_session(),
                                            checkpoint_dir)
        self.assertTrue(
            tf_test_utils.check_equals_checkpoint_top_scopes(
                checkpoint_dir + '/model', ['MobilenetV1', 'global_step']))
        self._run_tiny_training(
            model_name='inception_v3',
            dataset=data_providers_test.make_golden_dataset(
                use_tpu=FLAGS.use_tpu),
            warm_start_from=checkpoint_dir + '/model')
        self.assertTrue(
            tf_test_utils.check_equals_checkpoint_top_scopes(
                FLAGS.train_dir + '/model.ckpt-1',
                ['InceptionV3', 'global_step']))

    @flagsaver.FlagSaver
    def test_end2end_inception_v3_failed_warm_up_from(self):
        """End-to-end test of model_train script with a non-existent path."""
        with self.assertRaises(tf.errors.OpError):
            self._run_tiny_training(
                model_name='inception_v3',
                dataset=data_providers_test.make_golden_dataset(
                    use_tpu=FLAGS.use_tpu),
                warm_start_from='this/path/does/not/exist')

    @parameterized.parameters((False), (True))
    @flagsaver.FlagSaver
    @mock.patch('deepvariant.model_train.' 'tf.train.replica_device_setter')
    @mock.patch('deepvariant.model_train.run')
    def test_main_internal(self, use_tpu, mock_run, mock_device_setter):
        FLAGS.master = 'some_master'
        FLAGS.use_tpu = use_tpu
        FLAGS.ps_tasks = 10
        FLAGS.task = 5

        model_train.parse_and_run()

        mock_device_setter.assert_called_once_with(10)
        mock_run.assert_called_once_with('some_master' if use_tpu else '',
                                         False,
                                         device_fn=mock.ANY,
                                         use_tpu=mock.ANY)

    @mock.patch('deepvariant.model_train.os.environ')
    @mock.patch('deepvariant.model_train.' 'tf.train.replica_device_setter')
    @mock.patch('deepvariant.model_train.run')
    def test_main_tfconfig_local(self, mock_run, mock_device_setter,
                                 mock_environ):
        mock_environ.get.return_value = '{}'
        model_train.parse_and_run()

        mock_device_setter.assert_called_once_with(0)
        mock_run.assert_called_once_with('',
                                         True,
                                         device_fn=mock.ANY,
                                         use_tpu=mock.ANY)

    @parameterized.named_parameters(
        ('master', 'master', 0, True, '/job:master/task:0'),
        ('worker', 'worker', 10, False, '/job:worker/task:10'),
    )
    @mock.patch('deepvariant.model_train.tf.train.Server')
    @mock.patch('deepvariant.model_train.os.environ')
    @mock.patch('deepvariant.model_train.' 'tf.train.replica_device_setter')
    @mock.patch('deepvariant.model_train.run')
    def test_main_tfconfig_dist(self, job_name, task_index, expected_is_chief,
                                expected_worker, mock_run, mock_device_setter,
                                mock_environ, mock_server):
        tf_config = {
            'cluster': {
                'ps': ['ps1:800', 'ps2:800']
            },
            'task': {
                'type': job_name,
                'index': task_index,
            },
        }

        class FakeServer(object):
            target = 'some-target'

        mock_environ.get.return_value = json.dumps(tf_config)
        mock_server.return_value = FakeServer()

        model_train.parse_and_run()

        mock_device_setter.assert_called_once_with(
            2, worker_device=expected_worker, cluster=mock.ANY)
        mock_run.assert_called_once_with('some-target',
                                         expected_is_chief,
                                         device_fn=mock.ANY,
                                         use_tpu=mock.ANY)

    @parameterized.parameters(
        ('master', 'some-master'),
        ('task', 10),
        ('ps_tasks', 5),
    )
    @flagsaver.FlagSaver
    @mock.patch('deepvariant.model_train.os.environ')
    def test_main_invalid_args(self, flag_name, flag_value, mock_environ):
        # Ensure an exception is raised if flags and TF_CONFIG are set.
        tf_config = {
            'cluster': {
                'ps': ['ps1:800', 'ps2:800']
            },
            'task': {
                'type': 'master',
                'index': 0,
            },
        }

        mock_environ.get.return_value = json.dumps(tf_config)
        setattr(FLAGS, flag_name, flag_value)
        self.assertRaises(ValueError, model_train.parse_and_run)
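test_main_tfconfig_dist above fixes the expected TF_CONFIG handling: the task type and index determine the worker device string, and only a 'master' task with index 0 is the chief. A minimal sketch of that parsing (hypothetical, not the model_train implementation):

import json
import os


def parse_tf_config(environ=os.environ):
    """Returns (is_chief, worker_device) derived from TF_CONFIG."""
    tf_config = json.loads(environ.get('TF_CONFIG', '{}'))
    task = tf_config.get('task', {})
    job_name = task.get('type', 'master')
    task_index = task.get('index', 0)
    is_chief = job_name == 'master' and task_index == 0
    worker_device = '/job:{}/task:{}'.format(job_name, task_index)
    return is_chief, worker_device


# Matches the parameterized expectations above:
#   ('master', 0)  -> (True,  '/job:master/task:0')
#   ('worker', 10) -> (False, '/job:worker/task:10')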
Example 6
class ModelTrainTest(parameterized.TestCase):
    @flagsaver.FlagSaver
    def test_training_works_with_compressed_inputs(self):
        """End-to-end test of model_train script."""
        self._run_tiny_training(
            model_name='mobilenet_v1',
            dataset=data_providers_test.make_golden_dataset(
                compressed_inputs=True))

    def _run_tiny_training(self, model_name, dataset):
        with mock.patch(
                'deepvariant.data_providers.get_dataset') as mock_get_dataset:
            mock_get_dataset.return_value = dataset
            FLAGS.train_dir = test_utils.test_tmpfile(model_name)
            FLAGS.batch_size = 2
            FLAGS.model_name = model_name
            FLAGS.save_interval_secs = 0
            FLAGS.number_of_steps = 1
            FLAGS.dataset_config_pbtxt = '/path/to/mock.pbtxt'
            FLAGS.start_from_checkpoint = ''
            model_train.parse_and_run()
            # We have a checkpoint after training.
            mock_get_dataset.assert_called_once_with(
                FLAGS.dataset_config_pbtxt)
            self.assertIsNotNone(tf.train.latest_checkpoint(FLAGS.train_dir))

    @parameterized.parameters(model.name
                              for model in modeling.production_models()
                              if model.is_trainable)
    @flagsaver.FlagSaver
    def test_end2end(self, model_name):
        """End-to-end test of model_train script."""
        self._run_tiny_training(
            model_name=model_name,
            dataset=data_providers_test.make_golden_dataset())

    @parameterized.parameters(
        (None, None),
        ('', None),
        ('/path/to/file', MOCK_SENTINEL_RETURN_VALUE),
        ('USE_FLAG_VALUE', MOCK_SENTINEL_RETURN_VALUE),
    )
    def test_model_init_function(self, path, expected):
        model = mock.Mock(spec=modeling.DeepVariantModel)
        model.initialize_from_checkpoint.return_value = MOCK_SENTINEL_RETURN_VALUE
        self.assertEqual(expected,
                         model_train.model_init_function(model, 3, path))
        if expected:
            model.initialize_from_checkpoint.assert_called_once_with(
                path, 3, is_training=True)
        else:
            test_utils.assert_not_called_workaround(
                model.initialize_from_checkpoint)

    @flagsaver.FlagSaver
    @mock.patch('deepvariant.model_train.' 'tf.train.replica_device_setter')
    @mock.patch('deepvariant.model_train.run')
    def test_main_internal(self, mock_run, mock_device_setter):
        FLAGS.master = 'some_master'
        FLAGS.ps_tasks = 10
        FLAGS.task = 5

        model_train.parse_and_run()

        mock_device_setter.assert_called_once_with(10)
        mock_run.assert_called_once_with('some_master',
                                         False,
                                         device_fn=mock.ANY)

    @mock.patch('deepvariant.model_train.os.environ')
    @mock.patch('deepvariant.model_train.' 'tf.train.replica_device_setter')
    @mock.patch('deepvariant.model_train.run')
    def test_main_tfconfig_local(self, mock_run, mock_device_setter,
                                 mock_environ):
        mock_environ.get.return_value = '{}'
        model_train.parse_and_run()

        mock_device_setter.assert_called_once_with(0)
        mock_run.assert_called_once_with('', True, device_fn=mock.ANY)

    @parameterized.named_parameters(
        ('master', 'master', 0, True, '/job:master/task:0'),
        ('worker', 'worker', 10, False, '/job:worker/task:10'),
    )
    @mock.patch('deepvariant.model_train.tf.train.Server')
    @mock.patch('deepvariant.model_train.os.environ')
    @mock.patch('deepvariant.model_train.' 'tf.train.replica_device_setter')
    @mock.patch('deepvariant.model_train.run')
    def test_main_tfconfig_dist(self, job_name, task_index, expected_is_chief,
                                expected_worker, mock_run, mock_device_setter,
                                mock_environ, mock_server):
        tf_config = {
            'cluster': {
                'ps': ['ps1:800', 'ps2:800']
            },
            'task': {
                'type': job_name,
                'index': task_index,
            },
        }

        class FakeServer(object):
            target = 'some-target'

        mock_environ.get.return_value = json.dumps(tf_config)
        mock_server.return_value = FakeServer()

        model_train.parse_and_run()

        mock_device_setter.assert_called_once_with(
            2, worker_device=expected_worker, cluster=mock.ANY)
        mock_run.assert_called_once_with('some-target',
                                         expected_is_chief,
                                         device_fn=mock.ANY)

    @parameterized.parameters(
        ('master', 'some-master'),
        ('task', 10),
        ('ps_tasks', 5),
    )
    @flagsaver.FlagSaver
    @mock.patch('deepvariant.model_train.os.environ')
    def test_main_invalid_args(self, flag_name, flag_value, mock_environ):
        # Ensure an exception is raised if flags and TF_CONFIG are set.
        tf_config = {
            'cluster': {
                'ps': ['ps1:800', 'ps2:800']
            },
            'task': {
                'type': 'master',
                'index': 0,
            },
        }

        mock_environ.get.return_value = json.dumps(tf_config)
        setattr(FLAGS, flag_name, flag_value)
        self.assertRaises(ValueError, model_train.parse_and_run)
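test_model_init_function above constrains model_train.model_init_function: an empty or None checkpoint path means no initialization, and any other path delegates to the model's initialize_from_checkpoint. A plausible reconstruction consistent with those expectations (not necessarily the real implementation):

def model_init_function_sketch(model, num_classes, checkpoint_path):
    """Returns the model's checkpoint-init value, or None if no path given."""
    if not checkpoint_path:
        return None
    return model.initialize_from_checkpoint(
        checkpoint_path, num_classes, is_training=True)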
Example 7
class ModelTrainTest(parameterized.TestCase, tf.test.TestCase):

  def _run_tiny_training(self, model_name, dataset, warm_start_from=''):
    """Runs one training step. This function always starts a new train_dir."""
    with mock.patch(
        'deepvariant.data_providers.'
        'get_input_fn_from_dataset') as mock_get_input_fn_from_dataset:
      mock_get_input_fn_from_dataset.return_value = dataset
      FLAGS.train_dir = tf_test_utils.test_tmpdir(uuid.uuid4().hex)
      FLAGS.batch_size = 2
      FLAGS.model_name = model_name
      FLAGS.save_interval_secs = -1
      FLAGS.save_interval_steps = 1
      FLAGS.number_of_steps = 1
      FLAGS.dataset_config_pbtxt = '/path/to/mock.pbtxt'
      FLAGS.start_from_checkpoint = warm_start_from
      FLAGS.master = ''
      model_train.parse_and_run()
      # We have a checkpoint after training.
      mock_get_input_fn_from_dataset.assert_called_once_with(
          dataset_config_filename=FLAGS.dataset_config_pbtxt,
          mode=tf.estimator.ModeKeys.TRAIN,
          use_tpu=mock.ANY,
          max_examples=None,
      )
      self.assertIsNotNone(tf.train.latest_checkpoint(FLAGS.train_dir))

  @mock.patch('deepvariant'
              '.modeling.tf.compat.v1.losses.softmax_cross_entropy')
  @mock.patch('deepvariant'
              '.modeling.tf.compat.v1.losses.get_total_loss')
  def test_loss(self, mock_total_loss, mock_cross):
    labels = [[0, 1, 0], [1, 0, 0]]
    logits = 'Logits'
    smoothing = 0.01
    actual = model_train.loss(logits, labels, smoothing)
    mock_total_loss.assert_called_once_with()
    self.assertEqual(actual, mock_total_loss.return_value)
    mock_cross.assert_called_once_with(
        logits, labels, label_smoothing=smoothing, weights=1.0)

  # pylint: disable=g-complex-comprehension
  @parameterized.parameters((model.name, compressed_inputs)
                            for model in modeling.production_models()
                            if model.is_trainable
                            for compressed_inputs in [True, False])
  # pylint: enable=g-complex-comprehension
  @flagsaver.flagsaver
  def test_end2end(self, model_name, compressed_inputs):
    """End-to-end test of model_train script."""
    self._run_tiny_training(
        model_name=model_name,
        dataset=data_providers_test.make_golden_dataset(
            compressed_inputs=compressed_inputs, use_tpu=FLAGS.use_tpu))

  @flagsaver.flagsaver
  def test_end2end_inception_v3_warm_up_from(self):
    """End-to-end test of model_train script."""
    checkpoint_dir = tf_test_utils.test_tmpdir('inception_v3_warm_up_from')
    tf_test_utils.write_fake_checkpoint('inception_v3', self.test_session(),
                                        checkpoint_dir)
    self._run_tiny_training(
        model_name='inception_v3',
        dataset=data_providers_test.make_golden_dataset(use_tpu=FLAGS.use_tpu),
        warm_start_from=checkpoint_dir + '/model')

  @flagsaver.flagsaver
  def test_end2end_inception_v3_warm_up_allow_different_num_channels(self):
    """End-to-end test of model_train script."""
    FLAGS.allow_warmstart_from_different_num_channels = True
    checkpoint_dir = tf_test_utils.test_tmpdir(
        'inception_v3_warm_up_allow_different_num_channels')
    tf_test_utils.write_fake_checkpoint(
        'inception_v3',
        self.test_session(),
        checkpoint_dir,
        num_channels=dv_constants.PILEUP_NUM_CHANNELS + 1)
    self._run_tiny_training(
        model_name='inception_v3',
        dataset=data_providers_test.make_golden_dataset(use_tpu=FLAGS.use_tpu),
        warm_start_from=checkpoint_dir + '/model')

  @flagsaver.flagsaver
  def test_end2end_inception_v3_warm_up_by_default_fail_diff_num_channels(self):
    """End-to-end test of model_train script."""
    checkpoint_dir = tf_test_utils.test_tmpdir(
        'test_end2end_inception_v3_warm_up_by_default_fail_diff_num_channels')
    tf_test_utils.write_fake_checkpoint(
        'inception_v3',
        self.test_session(),
        checkpoint_dir,
        num_channels=dv_constants.PILEUP_NUM_CHANNELS + 1)
    with self.assertRaisesRegex(
        ValueError,
        r'Shape of variable InceptionV3/Conv2d_1a_3x3/weights:0 \(\(.*\)\) '
        r'doesn\'t match with shape of tensor '
        r'InceptionV3/Conv2d_1a_3x3/weights \(\[.*\]\) from checkpoint reader.'
    ):
      self._run_tiny_training(
          model_name='inception_v3',
          dataset=data_providers_test.make_golden_dataset(
              use_tpu=FLAGS.use_tpu),
          warm_start_from=checkpoint_dir + '/model')

  @flagsaver.flagsaver
  def test_end2end_inception_v3_failed_warm_up_from(self):
    """End-to-end test of model_train script with a non-existent path."""
    # Internal TF raises tf.errors.OpError, public TF raises ValueError.
    with self.assertRaises((tf.errors.OpError, ValueError)):
      self._run_tiny_training(
          model_name='inception_v3',
          dataset=data_providers_test.make_golden_dataset(
              use_tpu=FLAGS.use_tpu),
          warm_start_from='this/path/does/not/exist')

  @flagsaver.flagsaver
  def test_end2end_inception_v3_embedding_invalid_embedding_size(self):
    """End-to-end test of model_train script with an invalid embedding size."""
    with six.assertRaisesRegex(
        self, ValueError, 'Expected seq_type_embedding_size '
        'to be a positive number but saw -100 '
        'instead.'):
      FLAGS.seq_type_embedding_size = -100
      self._run_tiny_training(
          model_name='inception_v3_embedding',
          dataset=data_providers_test.make_golden_dataset(
              use_tpu=FLAGS.use_tpu))

  @parameterized.parameters((False), (True))
  @flagsaver.flagsaver
  @mock.patch('deepvariant.model_train.'
              'tf.compat.v1.train.replica_device_setter')
  @mock.patch('deepvariant.model_train.run')
  def test_main_internal(self, use_tpu, mock_run, mock_device_setter):
    FLAGS.master = 'some_master'
    FLAGS.use_tpu = use_tpu
    FLAGS.ps_tasks = 10
    FLAGS.task = 5

    model_train.parse_and_run()

    mock_device_setter.assert_called_once_with(10)
    mock_run.assert_called_once_with(
        'some_master' if use_tpu else '',
        False,
        device_fn=mock.ANY,
        use_tpu=mock.ANY)

  @mock.patch('deepvariant.model_train.os.environ')
  @mock.patch('deepvariant.model_train.'
              'tf.compat.v1.train.replica_device_setter')
  @mock.patch('deepvariant.model_train.run')
  def test_main_tfconfig_local(self, mock_run, mock_device_setter,
                               mock_environ):
    mock_environ.get.return_value = '{}'
    model_train.parse_and_run()

    mock_device_setter.assert_called_once_with(0)
    mock_run.assert_called_once_with(
        '', True, device_fn=mock.ANY, use_tpu=mock.ANY)

  @parameterized.named_parameters(
      ('master', 'master', 0, True, '/job:master/task:0'),
      ('worker', 'worker', 10, False, '/job:worker/task:10'),
  )
  @mock.patch(
      'deepvariant.model_train.tf.distribute.Server')
  @mock.patch('deepvariant.model_train.os.environ')
  @mock.patch('deepvariant.model_train.'
              'tf.compat.v1.train.replica_device_setter')
  @mock.patch('deepvariant.model_train.run')
  def test_main_tfconfig_dist(self, job_name, task_index, expected_is_chief,
                              expected_worker, mock_run, mock_device_setter,
                              mock_environ, mock_server):
    tf_config = {
        'cluster': {
            'ps': ['ps1:800', 'ps2:800']
        },
        'task': {
            'type': job_name,
            'index': task_index,
        },
    }

    class FakeServer(object):
      target = 'some-target'

    mock_environ.get.return_value = json.dumps(tf_config)
    mock_server.return_value = FakeServer()

    model_train.parse_and_run()

    mock_device_setter.assert_called_once_with(
        2, worker_device=expected_worker, cluster=mock.ANY)
    mock_run.assert_called_once_with(
        'some-target', expected_is_chief, device_fn=mock.ANY, use_tpu=mock.ANY)

  @parameterized.parameters(
      ('master', 'some-master'),
      ('task', 10),
      ('ps_tasks', 5),
  )
  @flagsaver.flagsaver
  @mock.patch('deepvariant.model_train.os.environ')
  def test_main_invalid_args(self, flag_name, flag_value, mock_environ):
    # Ensure an exception is raised if flags and TF_CONFIG are set.
    tf_config = {
        'cluster': {
            'ps': ['ps1:800', 'ps2:800']
        },
        'task': {
            'type': 'master',
            'index': 0,
        },
    }

    mock_environ.get.return_value = json.dumps(tf_config)
    setattr(FLAGS, flag_name, flag_value)
    self.assertRaises(ValueError, model_train.parse_and_run)
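test_main_invalid_args in these ModelTrainTest examples expects a ValueError whenever TF_CONFIG defines a cluster and any of the legacy --master/--task/--ps_tasks flags is also set. A minimal sketch of that validation (hypothetical, not the model_train code):

import json


def check_no_conflicting_flags(tf_config_json, master='', task=0, ps_tasks=0):
  """Raises ValueError if TF_CONFIG and the legacy cluster flags are mixed."""
  tf_config = json.loads(tf_config_json or '{}')
  if tf_config.get('cluster') and (master or task or ps_tasks):
    raise ValueError(
        'Cannot set --master, --task, or --ps_tasks when TF_CONFIG defines '
        'a cluster.')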