Ejemplo n.º 1
0
class ModelingTest(
        six.with_metaclass(parameterized.TestGeneratorMetaclass,
                           tf.test.TestCase)):
    """Unit tests for the helpers and model classes exposed by `modeling`."""

    @parameterized.parameters((model_class().name, type(model_class()))
                              for model_class in modeling.all_models())
    def test_get_model_existing_models(self, model_name, expected_class):
        """Looking up each known model name yields an instance of its class."""
        self.assertIsInstance(modeling.get_model(model_name), expected_class)

    def test_get_model_unknown_model_signals_error(self):
        """An unrecognized model name raises ValueError('Unknown model...')."""
        with six.assertRaisesRegex(self, ValueError, 'Unknown model'):
            modeling.get_model('unknown_model_1234')

    def test_make_deepvariant_slim_model(self):
        """Constructor arguments are exposed unchanged as attributes."""
        model = modeling.DeepVariantSlimModel(
            name='foo',
            n_classes_model_variable=['n_classes'],
            excluded_scopes_for_incompatible_classes=['logits'],
            excluded_scopes_for_incompatible_channels=['logits'],
            pretrained_model_path='path')

        self.assertEqual('foo', model.name)
        self.assertEqual(['n_classes'], model.n_classes_model_variable)
        self.assertEqual(['logits'],
                         model.excluded_scopes_for_incompatible_classes)
        self.assertEqual(['logits'],
                         model.excluded_scopes_for_incompatible_channels)
        self.assertEqual('path', model.pretrained_model_path)

    def test_is_encoded_variant_type(self):
        """is_encoded_variant_type() flags exactly the matching entries."""
        types = [
            tf_utils.EncodedVariantType.SNP.value,
            tf_utils.EncodedVariantType.INDEL.value
        ]
        # Alternating SNP/INDEL pattern, repeated four times.
        tensor = tf.constant(types * 4, dtype=tf.int64)

        def _run(tensor_to_run):
            """Evaluates the tensor in a session and returns it as a list."""
            # cached_session() replaces the deprecated test_session().
            with self.cached_session() as sess:
                return list(sess.run(tensor_to_run))

        self.assertEqual(
            _run(
                modeling.is_encoded_variant_type(
                    tensor, tf_utils.EncodedVariantType.SNP)),
            [True, False] * 4)
        self.assertEqual(
            _run(
                modeling.is_encoded_variant_type(
                    tensor, tf_utils.EncodedVariantType.INDEL)),
            [False, True] * 4)

    @parameterized.parameters(
        dict(labels=[0, 2, 1, 0], target_class=0, expected=[0, 1, 1, 0]),
        dict(labels=[0, 2, 1, 0], target_class=1, expected=[1, 1, 0, 1]),
        dict(labels=[0, 2, 1, 0], target_class=2, expected=[1, 0, 1, 1]),
    )
    def test_binarize(self, labels, target_class, expected):
        """binarize() maps labels equal to target_class to 0, others to 1."""
        # cached_session() replaces the deprecated test_session().
        with self.cached_session() as sess:
            result = sess.run(
                modeling.binarize(np.array(labels), np.array(target_class)))
            self.assertListEqual(result.tolist(), expected)

    @parameterized.parameters([True, False])
    def test_eval_metric_fn(self, include_variant_types):
        """eval_metric_fn() returns the keys listed by eval_function_metrics."""
        labels = tf.constant([1, 0], dtype=tf.int64)
        predictions = tf.constant([[1, 0], [0, 1]], dtype=tf.int64)
        if include_variant_types:
            variant_types = tf.constant([0, 1], dtype=tf.int64)
        else:
            variant_types = None

        expected = modeling.eval_function_metrics(
            has_variant_types=include_variant_types)
        actual = modeling.eval_metric_fn(labels, predictions, variant_types)
        # Only the metric names are compared, not the metric values.
        self.assertEqual(set(expected.keys()), set(actual.keys()))

    def test_variables_to_restore_from_model(self):
        """variables_to_restore_from_model() honors exclude_scopes."""
        model = modeling.DeepVariantModel('test', 'path')
        # We haven't created a slim model, so the variables_to_restore_from_model
        # should be returning an empty list.
        self.assertEqual([], model.variables_to_restore_from_model())

        # Create three model variables and one regular variable.
        with tf.compat.v1.variable_scope('model'):
            with tf.compat.v1.variable_scope('l1'):
                w1 = slim.model_variable('w1', shape=[10, 3, 3])
            with tf.compat.v1.variable_scope('l2'):
                w2 = slim.model_variable('w2', shape=[10, 3, 3])
                w3 = slim.model_variable('w3', shape=[10, 3, 3])
        v1 = slim.variable('my_var', shape=[20, 1])

        # The only variables in the system are the four we've created.
        six.assertCountEqual(self, [w1, w2, w3, v1], slim.get_variables())

        # We get just the three model variables without any excludes.
        six.assertCountEqual(self, [w1, w2, w3],
                             model.variables_to_restore_from_model())
        # As well as when exclude_scopes is an empty list.
        six.assertCountEqual(
            self, [w1, w2, w3],
            model.variables_to_restore_from_model(exclude_scopes=[]))

        # Excluding model/l1 variables gives us w2 and w3.
        six.assertCountEqual(
            self, [w2, w3],
            model.variables_to_restore_from_model(exclude_scopes=['model/l1']))
        # Excluding model/l2 gives us just w1 back.
        six.assertCountEqual(
            self, [w1],
            model.variables_to_restore_from_model(exclude_scopes=['model/l2']))
        # Excluding multiple scopes works as expected.
        six.assertCountEqual(
            self, [],
            model.variables_to_restore_from_model(
                exclude_scopes=['model/l1', 'model/l2']))
        # Excluding the root model scope also produces no variables.
        six.assertCountEqual(
            self, [],
            model.variables_to_restore_from_model(exclude_scopes=['model']))
Ejemplo n.º 2
0
class ModelingTest(
    six.with_metaclass(parameterized.TestGeneratorMetaclass, tf.test.TestCase)):
  """Unit tests for model lookup and variable restoration in `modeling`."""

  @parameterized.parameters(
      (model.name, type(model)) for model in modeling.all_models())
  def test_get_model_existing_models(self, model_name, expected_class):
    """Looking up each known model name yields an instance of its class."""
    self.assertIsInstance(modeling.get_model(model_name), expected_class)

  def test_get_model_unknown_model_signals_error(self):
    """An unrecognized model name raises ValueError('Unknown model...')."""
    # six.assertRaisesRegex works on both Python 2 and 3; the deprecated
    # assertRaisesRegexp alias is gone from unittest in Python 3.12.
    with six.assertRaisesRegex(self, ValueError, 'Unknown model'):
      modeling.get_model('unknown_model_1234')

  def test_make_deepvariant_slim_model(self):
    """Constructor arguments are exposed unchanged as attributes."""
    model = modeling.DeepVariantSlimModel(
        name='foo',
        n_classes_model_variable=['n_classes'],
        excluded_scopes=['logits'],
        pretrained_model_path='path')

    self.assertEqual('foo', model.name)
    self.assertEqual(['n_classes'], model.n_classes_model_variable)
    self.assertEqual(['logits'], model.excluded_scopes)
    self.assertEqual('path', model.pretrained_model_path)

  def test_variables_to_restore_from_model(self):
    """variables_to_restore_from_model() honors exclude_scopes."""
    model = modeling.DeepVariantModel('test', 'path')
    # We haven't created a slim model, so the variables_to_restore_from_model
    # should be returning an empty list.
    self.assertEqual([], model.variables_to_restore_from_model())

    # Create three model variables and one regular variable.
    with tf.variable_scope('model'):
      with tf.variable_scope('l1'):
        w1 = slim.model_variable('w1', shape=[10, 3, 3])
      with tf.variable_scope('l2'):
        w2 = slim.model_variable('w2', shape=[10, 3, 3])
        w3 = slim.model_variable('w3', shape=[10, 3, 3])
    v1 = slim.variable('my_var', shape=[20, 1])

    # The only variables in the system are the four we've created.
    six.assertCountEqual(self, [w1, w2, w3, v1], slim.get_variables())

    # We get just the three model variables without any excludes.
    six.assertCountEqual(
        self, [w1, w2, w3], model.variables_to_restore_from_model())
    # As well as when exclude_scopes is an empty list.
    six.assertCountEqual(
        self, [w1, w2, w3],
        model.variables_to_restore_from_model(exclude_scopes=[]))

    # Excluding model/l1 variables gives us w2 and w3.
    six.assertCountEqual(
        self, [w2, w3],
        model.variables_to_restore_from_model(exclude_scopes=['model/l1']))
    # Excluding model/l2 gives us just w1 back.
    six.assertCountEqual(
        self, [w1],
        model.variables_to_restore_from_model(exclude_scopes=['model/l2']))
    # Excluding multiple scopes works as expected.
    six.assertCountEqual(
        self, [],
        model.variables_to_restore_from_model(
            exclude_scopes=['model/l1', 'model/l2']))
    # Excluding the root model scope also produces no variables.
    six.assertCountEqual(
        self, [],
        model.variables_to_restore_from_model(exclude_scopes=['model']))