def test_disable_backprop_homology(self):
  """Gradients from a head with backprop disabled must not reach the encoder.

  The model has one embedding head (backprop enabled) and two alignment heads:
  a `dedal.Selector` (backprop enabled) and a `homology.LogCorrectedLogits`
  head (backprop disabled). A loss built only from the disabled head's output
  must give an all-zero gradient w.r.t. the encoder's first trainable
  variable. A loss built from the embedding head must give a non-zero one.
  """
  encoder_cls = functools.partial(encoders.LookupEncoder, emb_dim=32)
  aligner_cls = functools.partial(
      aligners.SoftAligner, gap_pen_cls=aligners.ConstantGapPenalties)
  heads_cls = multi_task.Backbone(
      embeddings=[nlp_layers.DensePerTokenOutputHead],
      alignments=[dedal.Selector, homology.LogCorrectedLogits])
  # One flag per head, in the same positions as in `heads_cls`.
  backprop = multi_task.Backbone(embeddings=[True], alignments=[True, False])
  model = dedal.Dedal(
      encoder_cls=encoder_cls,
      aligner_cls=aligner_cls,
      heads_cls=heads_cls,
      backprop=backprop)
  inputs = tf.random.uniform((4, 16), maxval=25, dtype=tf.int32)

  def encoder_grad_norm(head_key):
    """Returns the norm of d(sum(output[head_key]**2)) / d(encoder var 0)."""
    # Each tape computes exactly one gradient, so the default non-persistent
    # tape suffices and frees its resources as soon as `gradient()` returns.
    with tf.GradientTape() as tape:
      y_pred = model(inputs).flatten()
      dummy_loss = tf.reduce_sum(y_pred[head_key]**2)
    # ZERO (instead of the default NONE) makes a fully blocked path return an
    # all-zero tensor rather than None, so the norm is defined in both cases.
    grads = tape.gradient(
        dummy_loss,
        model.encoder.trainable_variables[0],
        unconnected_gradients=tf.UnconnectedGradients.ZERO)
    return tf.linalg.norm(grads)

  # `alignments/1` is the LogCorrectedLogits head, whose backprop flag is False.
  self.assertAllClose(encoder_grad_norm('alignments/1'), 0.0)
  # `embeddings/0` is the DensePerTokenOutputHead, whose backprop flag is True.
  self.assertGreater(encoder_grad_norm('embeddings/0'), 0.0)
def setUp(self):
  """Builds the shared model, two random input batches and a switch.

  The model uses a one-hot encoder, a soft aligner with constant gap
  penalties and `fake_align`, a `FakeSequenceLogits` embedding head and two
  homology alignment heads. The two input batches differ in both batch size
  and sequence length.
  """
  super().setUp()
  self.dim = 3
  self.heads_cls = multi_task.Backbone(
      embeddings=[FakeSequenceLogits],
      alignments=[homology.UncorrectedLogits, homology.LogCorrectedLogits])
  soft_aligner_cls = functools.partial(
      aligners.SoftAligner,
      gap_pen_cls=aligners.ConstantGapPenalties,
      align_fn=fake_align)
  self.model = dedal.Dedal(
      encoder_cls=encoders.OneHotEncoder,
      aligner_cls=soft_aligner_cls,
      heads_cls=self.heads_cls)
  # (batch size, sequence length) for each input; integer token ids are
  # drawn uniformly from [0, 35).
  first_shape = (32, 100)
  second_shape = (16, 50)
  self.inputs1 = tf.random.uniform(first_shape, maxval=35, dtype=tf.int32)
  self.inputs2 = tf.random.uniform(second_shape, maxval=35, dtype=tf.int32)
  self.switch = multi_task.SwitchBackbone(embeddings=[1], alignments=[0, 0])
def test_alignment_outputs(self):
  """Checks output shapes when a Selector head is among the alignment heads.

  With `process_negatives=True`, each alignment-head output has a leading
  dimension of 32. The Selector head returns a 3-tuple
  `(scores, paths, sw_params)`, where `paths` is None and `sw_params` holds
  three entries.
  """
  soft_aligner_cls = functools.partial(
      aligners.SoftAligner,
      gap_pen_cls=aligners.ConstantGapPenalties,
      align_fn=fake_align)
  heads_cls = multi_task.Backbone(
      embeddings=[FakeSequenceLogits],
      alignments=[
          dedal.Selector,
          homology.UncorrectedLogits,
          homology.LogCorrectedLogits,
      ])
  model = dedal.Dedal(
      encoder_cls=encoders.OneHotEncoder,
      aligner_cls=soft_aligner_cls,
      heads_cls=heads_cls,
      process_negatives=True)
  preds = model(self.inputs)

  self.assertEqual(preds.embeddings[0].shape, (32, 10))
  # Alignment heads 1 and 2 are the two homology logit heads.
  for head_idx in (1, 2):
    self.assertEqual(preds.alignments[head_idx].shape, (32, 1))

  # Alignment head 0 is the Selector.
  selector_out = preds.alignments[0]
  self.assertLen(selector_out, 3)
  scores, paths, sw_params = selector_out
  self.assertEqual(scores.shape, (32,))
  self.assertLen(sw_params, 3)
  self.assertIsNone(paths)
  self.assertEqual(sw_params[0].shape, (32, 100, 100))
def test_without_negatives(self):
  """Checks output shapes when `process_negatives=False`.

  The embedding head output has 32 rows, while each of the two alignment
  heads produces only 16 rows. Presumably only positive pairs are aligned
  in this mode; confirm against `dedal.Dedal`.
  """
  soft_aligner_cls = functools.partial(
      aligners.SoftAligner,
      gap_pen_cls=aligners.ConstantGapPenalties,
      align_fn=fake_align)
  model = dedal.Dedal(
      encoder_cls=encoders.OneHotEncoder,
      aligner_cls=soft_aligner_cls,
      heads_cls=self.heads_cls,
      process_negatives=False)
  preds = model(self.inputs)

  self.assertEqual(preds.embeddings[0].shape, (32, 10))
  for head_idx in range(2):
    self.assertEqual(preds.alignments[head_idx].shape, (16, 1))