Example #1
    def test_disable_backprop_homology(self):
        encoder_cls = functools.partial(encoders.LookupEncoder, emb_dim=32)
        aligner_cls = functools.partial(
            aligners.SoftAligner, gap_pen_cls=aligners.ConstantGapPenalties)
        heads_cls = multi_task.Backbone(
            embeddings=[nlp_layers.DensePerTokenOutputHead],
            alignments=[dedal.Selector, homology.LogCorrectedLogits])
        backprop = multi_task.Backbone(embeddings=[True],
                                       alignments=[True, False])
        model = dedal.Dedal(encoder_cls=encoder_cls,
                            aligner_cls=aligner_cls,
                            heads_cls=heads_cls,
                            backprop=backprop)
        inputs = tf.random.uniform((4, 16), maxval=25, dtype=tf.int32)

        with tf.GradientTape(persistent=True) as tape:
            y_pred = model(inputs).flatten()
            dummy_loss = tf.reduce_sum(y_pred['alignments/1']**2)
        grads_encoder = tape.gradient(
            dummy_loss,
            model.encoder.trainable_variables[0],
            unconnected_gradients=tf.UnconnectedGradients.ZERO)
        self.assertAllClose(tf.linalg.norm(grads_encoder), 0.0)

        with tf.GradientTape(persistent=True) as tape:
            y_pred = model(inputs).flatten()
            dummy_loss = tf.reduce_sum(y_pred['embeddings/0']**2)
        grads_encoder = tape.gradient(
            dummy_loss,
            model.encoder.trainable_variables[0],
            unconnected_gradients=tf.UnconnectedGradients.ZERO)
        self.assertGreater(tf.linalg.norm(grads_encoder), 0.0)
Example #2
    def __init__(self,
                 workdir,
                 strategy,
                 split=None,
                 task=None,
                 scalars=multi_task.Backbone(),
                 images=multi_task.Backbone(),
                 means=(),
                 every=1000,
                 reset_every_step=False,
                 start_clock=True):
        """Initialization.

    Args:
      workdir: the parent directory where to store data.
      strategy: distribution strategy.
      split: usually the name of the phase (train, test, valid).
      task: usually the name of the task (train, evaluate, downstream).
      scalars: the scalar metrics to be computed and dumped.
      images: the image metrics to be computed and dumped.
      means: the names of the scalar metrics that will be aggregated as means.
        At the very least, "loss" will be present (and "gradient_norm" when
        training).
      every: how often to log the metrics.
      reset_every_step: whether to reset the metrics at every step.
      start_clock: whether or not to start the clock at instantiation.
    """
        split = '' if split is None else split
        self.workdir = os.path.join(workdir, split).rstrip('/')
        self._split = split
        self._task = task
        self._timer = timer.Timer()
        self._reset_every_step = reset_every_step
        self.training = task == 'train'

        # Take the element-wise maximum of the two backbone shapes.
        shape = tuple(max(scalars.shape[i], images.shape[i]) for i in range(2))
        enveloppe = multi_task.Backbone.constant_from_shape([], shape)

        means = set(means).union(['loss'])
        if self.training:
            means = means.union(['gradient_norm'])

        with strategy.scope():
            self._scalars = enveloppe.pack([[metric_factory(m) for m in ms]
                                            for ms in scalars],
                                           default_value=[])
            self._images = enveloppe.pack([[metric_factory(m) for m in ms]
                                           for ms in images],
                                          default_value=[])
            self._means = {name: tf.keras.metrics.Mean(name) for name in means}

        self._summary_writer = tf.summary.create_file_writer(self.workdir)
        self._every = every
        self._last_step = None if self.training else 0

        if start_clock:
            self.restart_clock()
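
A minimal instantiation sketch for the logger above (not from the source; the import paths are assumptions and may differ from the actual repository layout):

# Hypothetical usage sketch for the Logger constructor shown above.
# NOTE: import paths are assumptions, not verified against the repository.
import tensorflow as tf
from dedal import multi_task
from dedal.train import logger

strategy = tf.distribute.get_strategy()  # default, single-device strategy
eval_logger = logger.Logger(
    workdir='/tmp/logs',
    strategy=strategy,
    split='valid',
    task='evaluate',
    scalars=multi_task.Backbone(
        alignments=[[tf.keras.metrics.BinaryAccuracy]]),
    means=('loss',),
    every=100)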
Example #3
    def setUp(self):
        super().setUp()
        self.dim = 3
        self.heads_cls = multi_task.Backbone(embeddings=[FakeSequenceLogits],
                                             alignments=[
                                                 homology.UncorrectedLogits,
                                                 homology.LogCorrectedLogits
                                             ])
        self.model = dedal.Dedal(encoder_cls=encoders.OneHotEncoder,
                                 aligner_cls=functools.partial(
                                     aligners.SoftAligner,
                                     gap_pen_cls=aligners.ConstantGapPenalties,
                                     align_fn=fake_align),
                                 heads_cls=self.heads_cls)

        batch1, batch2 = 32, 16
        seq_len1, seq_len2 = 100, 50
        self.inputs1 = tf.random.uniform((batch1, seq_len1),
                                         maxval=35,
                                         dtype=tf.int32)
        self.inputs2 = tf.random.uniform((batch2, seq_len2),
                                         maxval=35,
                                         dtype=tf.int32)

        self.switch = multi_task.SwitchBackbone(embeddings=[1],
                                                alignments=[0, 0])
Example #4
 def test_alignment_outputs(self):
     """A test with complex outputs."""
     heads_cls = multi_task.Backbone(embeddings=[FakeSequenceLogits],
                                     alignments=[
                                         dedal.Selector,
                                         homology.UncorrectedLogits,
                                         homology.LogCorrectedLogits
                                     ])
     model = dedal.Dedal(encoder_cls=encoders.OneHotEncoder,
                         aligner_cls=functools.partial(
                             aligners.SoftAligner,
                             gap_pen_cls=aligners.ConstantGapPenalties,
                             align_fn=fake_align),
                         heads_cls=heads_cls,
                         process_negatives=True)
     preds = model(self.inputs)
     self.assertEqual(preds.embeddings[0].shape, (32, 10))
     self.assertEqual(preds.alignments[1].shape, (32, 1))
     self.assertEqual(preds.alignments[2].shape, (32, 1))
     alignment_pred = preds.alignments[0]
     self.assertLen(alignment_pred, 3)
     scores, paths, sw_params = alignment_pred
     self.assertEqual(scores.shape, (32, ))
     self.assertLen(sw_params, 3)
     self.assertIsNone(paths)
     self.assertEqual(sw_params[0].shape, (32, 100, 100))
Example #5
 def _make_loop_with_reference(self):
     """Creates a training loop for alignments with self._loop as reference."""
     seq_len = gin.query_parameter('%SEQUENCE_LENGTH')
     model_cls = functools.partial(
         dedal.Dedal,
         encoder_cls=functools.partial(encoders.TransformerEncoder,
                                       emb_dim=48,
                                       num_layers=1,
                                       num_heads=2,
                                       mlp_dim=3 * seq_len,
                                       max_len=seq_len),
         aligner_cls=aligners.SoftAligner,
         heads_cls=multi_task.Backbone(
             embeddings=[], alignments=[homology.UncorrectedLogits]),
     )
     workdir2 = tempfile.mkdtemp()
     ds_builder = builder.DatasetBuilder(labels=multi_task.Backbone(
         alignments=[('target', 'weights')]))
     return training_loop.TrainingLoop(
         workdir=workdir2,
         strategy=self._loop.strategy,
         dataset_builder=ds_builder,
         logger_cls=functools.partial(
             logger.Logger,
             scalars=multi_task.Backbone(
                 alignments=[[tf.keras.metrics.BinaryAccuracy]]),
             every=5),
         loss_fn=losses.MultiTaskLoss(losses=multi_task.Backbone(
             alignments=[
                 tf.keras.losses.BinaryCrossentropy(
                     from_logits=True,
                     reduction=tf.keras.losses.Reduction.NONE)
             ])),
         optimizer_cls=self._loop._optimizer_cls,
         batch_size=self._loop._batch_size,
         model_cls=model_cls,
         num_steps=self._loop._num_steps,
         reference_workdir=self._loop._workdir,
         num_reference_steps=self._loop._num_steps)
Example #6
    def __init__(self,
                 encoder_cls=gin.REQUIRED,
                 aligner_cls=gin.REQUIRED,
                 heads_cls=gin.REQUIRED,
                 process_negatives=True,
                 switch=None,
                 backprop=None,
                 **kwargs):
        """Initializes.

    Args:
      encoder_cls: a keras layer (or model) that turns a batch of sequences
        into a batch of sequence embeddings.
      aligner_cls: a layer that turns a pair of sequence embeddings into
        scores (and optionally paths).
      heads_cls: for a multi-task setup all the layers to be plugged either on
        the embeddings or on the alignments.
      process_negatives: whether the network should also process the negative
        pairs in a batch. Aligning can be expensive, so switching off negative
        supervision saves half of the alignment cost.
      switch: optional mapping of output heads to inputs, required only when
        calling the model in multi-input mode. See `call` for additional
        details.
      backprop: for a multi-task setup, whether to backprop each loss from the
        output head to the encoder (i.e. finetune the encoder on the loss) or
        train the output head params only.
      **kwargs: optional keyword arguments to be passed to `tf.keras.Model`.
    """
        super().__init__(**kwargs)
        self.encoder = encoder_cls()
        self.aligner = aligner_cls() if aligner_cls else None
        self.heads = multi_task.Backbone(
            embeddings=[
                head_cls() if head_cls is not None else None
                for head_cls in heads_cls.embeddings
            ],
            alignments=[
                head_cls() if head_cls is not None else None
                for head_cls in heads_cls.alignments
            ])
        self.process_negatives = process_negatives
        self.switch = switch
        if self.switch is None:
            self.switch = multi_task.SwitchBackbone.constant_like(self.heads)
        self.backprop = backprop
        if self.backprop is None:
            self.backprop = self.heads.constant_copy(True)
        # For TF to keep track of variables
        self._flat_heads = self.heads.flatten()
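
A condensed construction sketch combining the constructor arguments documented above; it only reuses classes that appear in the other examples, and the import paths are assumptions:

# Hypothetical sketch: a two-head Dedal model where only the first alignment
# head backpropagates into the encoder.
# NOTE: import paths are assumptions, not verified against the repository.
import functools
from dedal import multi_task
from dedal.models import aligners, dedal, encoders, homology

heads_cls = multi_task.Backbone(
    embeddings=[],
    alignments=[homology.UncorrectedLogits, homology.LogCorrectedLogits])
model = dedal.Dedal(
    encoder_cls=functools.partial(encoders.LookupEncoder, emb_dim=32),
    aligner_cls=functools.partial(
        aligners.SoftAligner, gap_pen_cls=aligners.ConstantGapPenalties),
    heads_cls=heads_cls,
    process_negatives=False,
    backprop=multi_task.Backbone(embeddings=[], alignments=[True, False]))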
Example #7
def make_fake_builder(max_len=512):
  return builder.DatasetBuilder(
      data_loader=FakePairsLoader(max_len=max_len),
      labels=multi_task.Backbone(alignments=[
          ('alignment/targets', 'alignment/weights'),
          'homology/targets']),
      batched_transformations=[
          transforms.Reshape(shape=(-1, max_len), on='sequence'),
          transforms.Reshape(shape=(-1,), on='fam_key'),
          align_transforms.CreateBatchedWeights(
              on='alignment/targets', out='alignment/weights'),
          align_transforms.PadNegativePairs(
              on=['alignment/targets', 'alignment/weights']),
          align_transforms.PadNegativePairs(
              value=-1.0, on=['alignment/pid1', 'alignment/pid3']),
          align_transforms.CreateHomologyTargets(
              process_negatives=True, on='fam_key', out='homology/targets'),
      ],
      metadata=('alignment/pid1', 'alignment/pid3'),
      split='train')
Example #8
    def forward(self, inputs, selector=None, training=True):
        """Run the models on a single input and potentially selects some heads only.

    Args:
      inputs: a Tensor<int32>[batch, seq_len] representing protein sequences.
      selector: if set, a multi_task.Backbone[bool] specifying which heads to
        apply. For non-selected heads, None replaces the output. If not set,
        all heads are output.
      training: whether to run in training mode or eval mode.

    Returns:
      A multi_task.Backbone of tensors corresponding to the outputs of the
      different heads of the model.
    """
        selector = self.heads.constant_copy(
            True) if selector is None else selector

        embeddings = self.encoder(inputs, training=training)
        masks = self.encoder.compute_mask(inputs)

        result = multi_task.Backbone()
        for head, on, backprop in zip(self.heads.embeddings,
                                      selector.embeddings,
                                      self.backprop.embeddings):
            head_output = self.head_output(head,
                                           on,
                                           backprop,
                                           embeddings,
                                           mask=masks,
                                           training=training)
            result.embeddings.append(head_output)

        if not self.heads.alignments or not any(selector.alignments):
            # Ensures the structure of `result` matches `self.heads` even when
            # the selector makes the method skip the alignment phase.
            for _ in selector.alignments:
                result.alignments.append(tf.constant([]))
            return result

        # For each head, we compute the output of positive pairs and negative ones,
        # then concatenate to obtain an output batch where the first half is
        # positive and the second half is negative.
        outputs = []
        pos_indices = pairs_lib.consecutive_indices(inputs)
        neg_indices = (pairs_lib.roll_indices(pos_indices)
                       if self.process_negatives else None)
        num_alignment_calls = 1 + int(self.process_negatives)
        for indices in (pos_indices, neg_indices)[:num_alignment_calls]:
            curr = []
            embedding_pairs, mask_pairs = pairs_lib.build(
                indices, embeddings, masks)
            alignments = self.aligner(embedding_pairs,
                                      mask=mask_pairs,
                                      training=training)
            for head, on, backprop in zip(self.heads.alignments,
                                          selector.alignments,
                                          self.backprop.alignments):
                head_output = self.head_output(head,
                                               on,
                                               backprop,
                                               alignments,
                                               mask=mask_pairs,
                                               training=training)
                curr.append(head_output)
            outputs.append(curr)

        for output in merge(*outputs):
            result.alignments.append(output)
        return result
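
A short sketch of calling the method above with a selector that switches off the alignment heads; `model` and `inputs` are assumed to be set up as in Example #1:

# Hypothetical sketch: run only the embedding heads by deselecting every
# alignment head. The selector structure must mirror model.heads.
selector = multi_task.Backbone(
    embeddings=[True] * len(model.heads.embeddings),
    alignments=[False] * len(model.heads.alignments))
outputs = model.forward(inputs, selector=selector, training=False)
# outputs.alignments then holds placeholder tensors; outputs.embeddings is computed.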
Example #9
 def test_backbone(self):
     values = [1, 2, 4]
     container = multi_task.Backbone(values)
     self.assertSequenceEqual(container.embeddings, values)
     self.assertSequenceEqual(container.alignments, [])
     self.assertSequenceEqual(list(container), values)
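
A small container sketch inferred from the test above and from Example #1; the exact flatten() key format is an assumption based on the 'embeddings/0' and 'alignments/1' keys used there:

# Hypothetical sketch of the Backbone container semantics.
bb = multi_task.Backbone(embeddings=[1], alignments=[2, 3])
print(bb.embeddings)   # [1]
print(bb.alignments)   # [2, 3]
print(list(bb))        # [1, 2, 3]: embeddings first, then alignments
flat = bb.flatten()    # assumed keys: 'embeddings/0', 'alignments/0', 'alignments/1'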
Example #10
def make_tape_builder(root_dir,
                      task,
                      target,
                      weights=None,
                      metadata=(),
                      max_len=1024,
                      input_sequence_key='primary',
                      output_sequence_key='sequence'):
    """Creates a DatasetBuilder for TAPE's benchmark."""
    supported_tasks = list(TAPE_NUM_OUTPUTS)
    if task not in supported_tasks:
        raise ValueError(f'Task {task} not recognized. '
                         f'Supported tasks: {", ".join(supported_tasks)}.')
    num_outputs = TAPE_NUM_OUTPUTS[task].get(target, 1)

    used_keys = [input_sequence_key, target]
    if weights is not None:
        used_keys.append(weights)
    if metadata:
        used_keys.extend(metadata)
    unused_keys = [k for k in TAPE_SPECS[task] if k not in used_keys]

    ds_transformations = []
    if max_len is not None:
        ds_transformations.append(
            transforms.FilterByLength(on=output_sequence_key,
                                      precomputed=False,
                                      max_len=max_len - 1))

    transformations = [
        transforms.Pop(on=unused_keys),
        transforms.Reshape(on=output_sequence_key, shape=[]),
        transforms.Encode(on=output_sequence_key),
        transforms.EOS(on=output_sequence_key),
        transforms.CropOrPad(on=output_sequence_key, size=max_len),
    ]

    if target in TAPE_MULTI_CL_TASKS:
        transformations.append(transforms.OneHot(on=target, depth=num_outputs))
    elif target in TAPE_BACKBONE_ANGLE_TASKS:
        transformations.append(transforms.BackboneAngleTransform(on=target))
    elif target in TAPE_PROT_ENGINEERING_TASKS:
        transformations.append(transforms.Reshape(on=target, shape=[-1]))

    if target in TAPE_SEQ2SEQ_TASKS:
        transformations.extend([
            transforms.Reshape(on=target, shape=[-1, num_outputs]),
            transforms.CropOrPadND(on=target, size=max_len, axis=0),
        ])

    if weights is not None:  # Note: no seq-level TAPE task has weights.
        transformations.extend([
            transforms.Reshape(on=weights, shape=[-1]),
            transforms.CropOrPadND(on=weights, size=max_len),
        ])

    embeddings_labels = [target] if weights is None else [(target, weights)]
    return builder.DatasetBuilder(
        data_loader=make_tape_loader(root_dir=root_dir, task=task),
        ds_transformations=ds_transformations,
        transformations=transformations,
        labels=multi_task.Backbone(embeddings=embeddings_labels),
        metadata=metadata,
        sequence_key=output_sequence_key)
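
A hedged usage sketch for the factory above; the task and target names below are placeholders and must correspond to actual keys of TAPE_NUM_OUTPUTS / TAPE_SPECS:

# Hypothetical usage sketch (placeholder task/target names).
tape_builder = make_tape_builder(
    root_dir='/path/to/tape_data',  # directory with the TAPE TFRecords
    task='secondary_structure',     # placeholder: any key of TAPE_NUM_OUTPUTS
    target='ss3',                   # placeholder: a target key for that task
    max_len=1024)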
Example #11
def make_pair_builder(max_len=512,
                      index_keys=('fam_key', 'ci_100'),
                      process_negatives=True,
                      gap_token='-',
                      sequence_key='sequence',
                      context_sequence_key='full_sequence',
                      loader_cls=make_pfam_pairs_loader,
                      pairing_cls=None,
                      lm_cls=None,
                      has_context=False,
                      append_eos=True,
                      append_eos_context=True,
                      add_random_tails=False,
                      **kwargs):
    """Creates a dataset for pairs of sequences."""
    # Convenience function to index key pairs.
    paired_keys = lambda k: tuple(f'{k}_{i}' for i in (1, 2))

    def stack_and_pop(on):
        stack = transforms.Stack(on=paired_keys(on), out=on)
        pop = transforms.Pop(on=paired_keys(on))
        return [stack, pop]

    # Defines fields to be read from the TFRecords.
    metadata_keys = ['cla_key', 'seq_key'] + list(index_keys)
    extra_keys = metadata_keys.copy()
    # Pre-paired datasets have already been filtered by length; seq_len is only
    # needed when pairing sequences on-the-fly.
    if pairing_cls is not None:
        extra_keys.append('seq_len')
    # Optionally, adds fields needed by the `AddAlignmentContext` `Transform`.
    if has_context:
        extra_keys.extend(['start', 'end'])
    add_alignment_context_extra_args = (paired_keys(context_sequence_key) +
                                        paired_keys('start') +
                                        paired_keys('end'))
    # Accounts for EOS token if necessary.
    max_len_eos = max_len - 1 if append_eos else max_len

    ### Sets up the `DatasetTransform`s.

    ds_transformations = []
    if pairing_cls is not None:
        filter_by_length = transforms.FilterByLength(max_len=max_len_eos)
        # NOTE(fllinares): pairing on-the-fly is memory intensive on TPU for some
        # reason not yet understood...
        pair_sequences = pairing_cls(index_keys=index_keys)
        ds_transformations.extend([filter_by_length, pair_sequences])

    ### Sets up the `Transform`s applied *before* batching.

    project_msa_rows = align_transforms.ProjectMSARows(
        on=paired_keys(sequence_key), token=gap_token)
    append_eos_to_context = transforms.EOS(
        on=paired_keys(context_sequence_key))
    add_alignment_context = align_transforms.AddAlignmentContext(
        on=paired_keys(sequence_key) + add_alignment_context_extra_args,
        out=paired_keys(sequence_key),
        max_len=max_len_eos,
        gap_token=gap_token)
    trim_alignment = align_transforms.TrimAlignment(
        on=paired_keys(sequence_key), gap_token=gap_token)
    pop_add_alignment_context_extra_args = transforms.Pop(
        on=add_alignment_context_extra_args)
    add_random_prefix_and_suffix = align_transforms.AddRandomTails(
        on=paired_keys(sequence_key), max_len=max_len_eos)
    create_alignment_targets = align_transforms.CreateAlignmentTargets(
        on=paired_keys(sequence_key),
        out='alignment/targets',
        gap_token=gap_token,
        n_prepend_tokens=0)
    pid1 = align_transforms.PID(on=paired_keys(sequence_key),
                                out='alignment/pid1',
                                definition=1,
                                token=gap_token)
    pid3 = align_transforms.PID(on=paired_keys(sequence_key),
                                out='alignment/pid3',
                                definition=3,
                                token=gap_token)
    remove_gaps = transforms.RemoveTokens(on=paired_keys(sequence_key),
                                          tokens=gap_token)
    append_eos_to_sequence = transforms.EOS(on=paired_keys(sequence_key))
    pad_sequences = transforms.CropOrPad(on=paired_keys(sequence_key),
                                         size=max_len)
    pad_alignment_targets = transforms.CropOrPadND(on='alignment/targets',
                                                   size=2 * max_len)

    transformations = [project_msa_rows]
    if has_context:
        if append_eos_context:
            transformations.append(append_eos_to_context)
        transformations.extend([
            add_alignment_context, trim_alignment,
            pop_add_alignment_context_extra_args
        ])
    if add_random_tails:
        transformations.append(add_random_prefix_and_suffix)
    transformations.append(create_alignment_targets)

    transformations.extend([pid1, pid3, remove_gaps])
    if append_eos:
        transformations.append(append_eos_to_sequence)
    transformations.extend([pad_sequences, pad_alignment_targets])
    for key in [sequence_key] + metadata_keys:
        transformations.extend(stack_and_pop(key))

    ### Sets up the `Transform`s applied *after* batching.

    flatten_sequence_pairs = transforms.Reshape(on=sequence_key,
                                                shape=[-1, max_len])
    flatten_metadata_pairs = transforms.Reshape(on=metadata_keys, shape=[-1])
    create_homology_targets = align_transforms.CreateHomologyTargets(
        on='fam_key',
        out='homology/targets',
        process_negatives=process_negatives)
    create_alignment_weights = align_transforms.CreateBatchedWeights(
        on='alignment/targets', out='alignment/weights')
    add_neg_alignment_targets_and_weights = align_transforms.PadNegativePairs(
        on=('alignment/targets', 'alignment/weights'))
    pad_neg_pid = align_transforms.PadNegativePairs(on=('alignment/pid1',
                                                        'alignment/pid3'),
                                                    value=-1.0)

    batched_transformations = [
        flatten_sequence_pairs, flatten_metadata_pairs, create_homology_targets
    ]
    if process_negatives:
        batched_transformations.extend([
            create_alignment_weights, add_neg_alignment_targets_and_weights,
            pad_neg_pid
        ])
    if lm_cls is not None:
        create_lm_targets = lm_cls(on=sequence_key,
                                   out=(sequence_key, 'masked_lm/targets',
                                        'masked_lm/weights'))
        batched_transformations.append(create_lm_targets)

    ### Sets up the remainder of the `DatasetBuilder` configuration.

    masked_lm_labels = ('masked_lm/targets', 'masked_lm/weights')
    alignment_labels = ('alignment/targets' if not process_negatives else
                        ('alignment/targets', 'alignment/weights'))
    homology_labels = 'homology/targets'
    embeddings = () if lm_cls is None else (masked_lm_labels, )
    alignments = (alignment_labels, homology_labels)

    return builder.DatasetBuilder(
        data_loader=loader_cls(extra_keys),
        ds_transformations=ds_transformations,
        transformations=transformations,
        batched_transformations=batched_transformations,
        labels=multi_task.Backbone(embeddings=embeddings,
                                   alignments=alignments),
        metadata=('seq_key', 'alignment/pid1', 'alignment/pid3'),
        sequence_key=sequence_key,
        **kwargs)
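
A hedged usage sketch: the returned DatasetBuilder is meant to be passed to a training loop as dataset_builder, as in Example #5:

# Hypothetical usage sketch for the pair-dataset factory above.
pair_builder = make_pair_builder(max_len=256, process_negatives=False)
# e.g. training_loop.TrainingLoop(..., dataset_builder=pair_builder, ...)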