Example #1
0
    def test_disable_backprop_homology(self):
        """Checks that a head flagged `backprop=False` sends no gradient to the encoder.

        The model has one embedding head and two alignment heads
        (`dedal.Selector`, `homology.LogCorrectedLogits`). `backprop` has the
        same shape as `heads_cls`, and its second alignment entry is False.
        - A loss built only from output 'alignments/1' must give a zero
          gradient for the encoder's first trainable variable.
        - A loss built from 'embeddings/0' must give a non-zero gradient.
        """
        encoder_cls = functools.partial(encoders.LookupEncoder, emb_dim=32)
        aligner_cls = functools.partial(
            aligners.SoftAligner, gap_pen_cls=aligners.ConstantGapPenalties)
        heads_cls = multi_task.Backbone(
            embeddings=[nlp_layers.DensePerTokenOutputHead],
            alignments=[dedal.Selector, homology.LogCorrectedLogits])
        # Flags appear to pair with `heads_cls` entry by entry; the assertions
        # below rely on alignments[1] (LogCorrectedLogits) being disabled.
        backprop = multi_task.Backbone(embeddings=[True],
                                       alignments=[True, False])
        model = dedal.Dedal(encoder_cls=encoder_cls,
                            aligner_cls=aligner_cls,
                            heads_cls=heads_cls,
                            backprop=backprop)
        inputs = tf.random.uniform((4, 16), maxval=25, dtype=tf.int32)

        def encoder_grad_norm(output_key):
            """Returns the norm of d(sum(y**2))/d(encoder var 0) for one output.

            Args:
              output_key: key into the flattened model output, e.g.
                'alignments/1'.
            """
            # Only one gradient() call is made per tape, so the default
            # (non-persistent) tape suffices and frees its resources right
            # after that call.
            with tf.GradientTape() as tape:
                y_pred = model(inputs).flatten()
                dummy_loss = tf.reduce_sum(y_pred[output_key]**2)
            # The encoder variables are read after the forward pass, which
            # matches the original ordering.
            grads_encoder = tape.gradient(
                dummy_loss,
                model.encoder.trainable_variables[0],
                # Return zeros rather than None when the loss is disconnected.
                unconnected_gradients=tf.UnconnectedGradients.ZERO)
            return tf.linalg.norm(grads_encoder)

        # Head with backprop disabled: nothing flows back to the encoder.
        self.assertAllClose(encoder_grad_norm('alignments/1'), 0.0)
        # Head with backprop enabled: the encoder receives a gradient.
        self.assertGreater(encoder_grad_norm('embeddings/0'), 0.0)
Example #2
0
    def setUp(self):
        """Builds the shared Dedal model, head config, inputs and switch.

        Sets:
          self.dim: the constant 3.
          self.heads_cls: one FakeSequenceLogits embedding head plus
            UncorrectedLogits and LogCorrectedLogits alignment heads.
          self.model: a Dedal model with a one-hot encoder and a SoftAligner
            that uses constant gap penalties and `fake_align`.
          self.inputs1 / self.inputs2: random int32 tensors of shape
            (32, 100) and (16, 50), values in [0, 35).
          self.switch: a SwitchBackbone with embeddings=[1] and
            alignments=[0, 0]. Presumably this routes each head group to one
            of the two inputs; TODO confirm against multi_task.SwitchBackbone.
        """
        super().setUp()
        self.dim = 3

        alignment_heads = [
            homology.UncorrectedLogits,
            homology.LogCorrectedLogits,
        ]
        self.heads_cls = multi_task.Backbone(
            embeddings=[FakeSequenceLogits], alignments=alignment_heads)

        soft_aligner = functools.partial(
            aligners.SoftAligner,
            gap_pen_cls=aligners.ConstantGapPenalties,
            align_fn=fake_align)
        self.model = dedal.Dedal(
            encoder_cls=encoders.OneHotEncoder,
            aligner_cls=soft_aligner,
            heads_cls=self.heads_cls)

        # (batch, seq_len) for the two input tensors; generated in order.
        shapes = ((32, 100), (16, 50))
        self.inputs1, self.inputs2 = (
            tf.random.uniform(shape, maxval=35, dtype=tf.int32)
            for shape in shapes)

        self.switch = multi_task.SwitchBackbone(
            embeddings=[1], alignments=[0, 0])
Example #3
0
 def test_alignment_outputs(self):
     """Checks output shapes when a Selector head is mixed with logit heads.

     With `process_negatives=True`, every checked output has leading
     dimension 32:
     - The logit heads at alignments[1] and alignments[2] give (32, 1).
     - The `dedal.Selector` output at alignments[0] is a 3-tuple
       `(scores, paths, sw_params)`, where `scores` is (32,) and `paths`
       is None.
     - `sw_params` is a 3-tuple whose first entry is (32, 100, 100).
     """
     heads_cls = multi_task.Backbone(
         embeddings=[FakeSequenceLogits],
         alignments=[
             dedal.Selector,
             homology.UncorrectedLogits,
             homology.LogCorrectedLogits,
         ])
     soft_aligner = functools.partial(
         aligners.SoftAligner,
         gap_pen_cls=aligners.ConstantGapPenalties,
         align_fn=fake_align)
     model = dedal.Dedal(
         encoder_cls=encoders.OneHotEncoder,
         aligner_cls=soft_aligner,
         heads_cls=heads_cls,
         process_negatives=True)

     outputs = model(self.inputs)
     self.assertEqual(outputs.embeddings[0].shape, (32, 10))
     for head_idx in (1, 2):
         self.assertEqual(outputs.alignments[head_idx].shape, (32, 1))

     selector_out = outputs.alignments[0]
     self.assertLen(selector_out, 3)
     scores, paths, sw_params = selector_out
     self.assertEqual(scores.shape, (32,))
     self.assertLen(sw_params, 3)
     self.assertIsNone(paths)
     self.assertEqual(sw_params[0].shape, (32, 100, 100))
Example #4
0
 def test_without_negatives(self):
     """Checks output shapes when `process_negatives=False`.

     The embedding output keeps leading dimension 32, while the first two
     alignment outputs have leading dimension 16. Presumably only positive
     pairs are aligned in this mode; TODO confirm against `dedal.Dedal`.
     """
     soft_aligner = functools.partial(
         aligners.SoftAligner,
         gap_pen_cls=aligners.ConstantGapPenalties,
         align_fn=fake_align)
     model = dedal.Dedal(
         encoder_cls=encoders.OneHotEncoder,
         aligner_cls=soft_aligner,
         heads_cls=self.heads_cls,
         process_negatives=False)

     outputs = model(self.inputs)
     self.assertEqual(outputs.embeddings[0].shape, (32, 10))
     for head_idx in (0, 1):
         self.assertEqual(outputs.alignments[head_idx].shape, (16, 1))