Example 1
 def test_transforms_from_gin(self):
     ds_builder = builder.DatasetBuilder()
     batch = 32
     ds = ds_builder.make('train', batch)
     inputs, y_true, weights, _ = next(iter(ds))
     self.assertIsInstance(inputs, tf.Tensor)
     self.assertEqual(inputs.dtype, tf.int32)
     self.assertIsInstance(y_true, dict)
     self.assertGreater(len(y_true), 0)
     self.assertIsInstance(weights, dict)
     self.assertGreater(len(weights), 0)
     self.assertEqual(y_true['alignments/0'].dtype, tf.int32)
     self.assertEqual(y_true['alignments/0'].shape, (2 * batch, 3, 1024))
     self.assertEqual(y_true['alignments/1'].dtype, tf.int32)
     self.assertEqual(y_true['alignments/1'].shape, (2 * batch, 1))
     self.assertEqual(weights['alignments/0'].dtype, tf.float32)
     self.assertEqual(weights['alignments/0'].shape, (2 * batch, ))
     self.assertEqual(weights['alignments/1'].dtype, tf.float32)
     self.assertEqual(weights['alignments/1'].shape, (2 * batch, ))
Example 2
def make_fake_builder(max_len=512):
  return builder.DatasetBuilder(
      data_loader=FakePairsLoader(max_len=max_len),
      labels=multi_task.Backbone(alignments=[
          ('alignment/targets', 'alignment/weights'),
          'homology/targets']),
      batched_transformations=[
          transforms.Reshape(shape=(-1, max_len), on='sequence'),
          transforms.Reshape(shape=(-1,), on='fam_key'),
          align_transforms.CreateBatchedWeights(
              on='alignment/targets', out='alignment/weights'),
          align_transforms.PadNegativePairs(
              on=['alignment/targets', 'alignment/weights']),
          align_transforms.PadNegativePairs(
              value=-1.0, on=['alignment/pid1', 'alignment/pid3']),
          align_transforms.CreateHomologyTargets(
              process_negatives=True, on='fam_key', out='homology/targets'),
      ],
      metadata=('alignment/pid1', 'alignment/pid3'),
      split='train')
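A minimal usage sketch for the helper above (hypothetical, not part of the original test file): it assumes FakePairsLoader yields already-paired examples and that DatasetBuilder.make(split, batch) behaves as in Example 1.

fake_builder = make_fake_builder(max_len=512)
ds = fake_builder.make('train', 8)  # split name and batch size, as in Example 1
inputs, y_true, weights, metadata = next(iter(ds))  # one batch of paired data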
 def test_transforms_from_gin(self):
     self.mock_load.return_value = make_fake_sequence_dataset()
     ds_builder = builder.DatasetBuilder()
     batch = 32
     ds = ds_builder.make('test', batch)
     inputs, y_true, weights, metadata = next(iter(ds))
     self.assertIsInstance(inputs, tf.Tensor)
     self.assertEqual(inputs.dtype, tf.int32)
     self.assertIsInstance(y_true, dict)
     self.assertGreater(len(y_true), 0)
     self.assertIsInstance(weights, dict)
     self.assertGreater(len(weights), 0)
     for y in multi_task.Backbone.unflatten(y_true):
         self.assertEqual(y.dtype, tf.int32)
         self.assertEqual(y.shape, (batch, 1024))  # from gin.
     for y in multi_task.Backbone.unflatten(weights):
         self.assertEqual(y.dtype, tf.float32)
     self.assertIn('seq_key', metadata)
     self.assertIsInstance(metadata['seq_key'], tf.Tensor)
     self.assertEqual(metadata['seq_key'].dtype, tf.int32)
     self.assertNotIn('fam_key', metadata)
Example 4
 def _make_loop_with_reference(self):
     """Creates a training loop for alignments with self._loop as reference."""
     seq_len = gin.query_parameter('%SEQUENCE_LENGTH')
     model_cls = functools.partial(
         dedal.Dedal,
         encoder_cls=functools.partial(encoders.TransformerEncoder,
                                       emb_dim=48,
                                       num_layers=1,
                                       num_heads=2,
                                       mlp_dim=3 * seq_len,
                                       max_len=seq_len),
         aligner_cls=aligners.SoftAligner,
         heads_cls=multi_task.Backbone(
             embeddings=[], alignments=[homology.UncorrectedLogits]),
     )
     workdir2 = tempfile.mkdtemp()
     ds_builder = builder.DatasetBuilder(labels=multi_task.Backbone(
         alignments=[('target', 'weights')]))
     return training_loop.TrainingLoop(
         workdir=workdir2,
         strategy=self._loop.strategy,
         dataset_builder=ds_builder,
         logger_cls=functools.partial(
             logger.Logger,
             scalars=multi_task.Backbone(
                 alignments=[[tf.keras.metrics.BinaryAccuracy]]),
             every=5),
         loss_fn=losses.MultiTaskLoss(losses=multi_task.Backbone(
             alignments=[
                 tf.keras.losses.BinaryCrossentropy(
                     from_logits=True,
                     reduction=tf.keras.losses.Reduction.NONE)
             ])),
         optimizer_cls=self._loop._optimizer_cls,
         batch_size=self._loop._batch_size,
         model_cls=model_cls,
         num_steps=self._loop._num_steps,
         reference_workdir=self._loop._workdir,
         num_reference_steps=self._loop._num_steps)
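As a hedged follow-up (illustrative only, not shown in the example), the loop returned above would typically be driven by the training entry point; the run() call below is an assumption about the TrainingLoop API, and initializing from reference_workdir is inferred from the constructor arguments.

loop = self._make_loop_with_reference()
loop.run()  # assumed entry point; expected to initialize from reference_workdir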
Example 5
def make_tape_builder(root_dir,
                      task,
                      target,
                      weights=None,
                      metadata=(),
                      max_len=1024,
                      input_sequence_key='primary',
                      output_sequence_key='sequence'):
    """Creates a DatasetBuilder for TAPE's benchmark."""
    supported_tasks = list(TAPE_NUM_OUTPUTS)
    if task not in supported_tasks:
        raise ValueError(f'Task {task} not recognized. '
                         f'Supported tasks: {", ".join(supported_tasks)}.')
    num_outputs = TAPE_NUM_OUTPUTS[task].get(target, 1)

    used_keys = [input_sequence_key, target]
    if weights is not None:
        used_keys.append(weights)
    if metadata:
        used_keys.extend(metadata)
    unused_keys = [k for k in TAPE_SPECS[task] if k not in used_keys]

    ds_transformations = []
    if max_len is not None:
        ds_transformations.append(
            transforms.FilterByLength(on=output_sequence_key,
                                      precomputed=False,
                                      max_len=max_len - 1))

    transformations = [
        transforms.Pop(on=unused_keys),
        transforms.Reshape(on=output_sequence_key, shape=[]),
        transforms.Encode(on=output_sequence_key),
        transforms.EOS(on=output_sequence_key),
        transforms.CropOrPad(on=output_sequence_key, size=max_len),
    ]

    if target in TAPE_MULTI_CL_TASKS:
        transformations.append(transforms.OneHot(on=target, depth=num_outputs))
    elif target in TAPE_BACKBONE_ANGLE_TASKS:
        transformations.append(transforms.BackboneAngleTransform(on=target))
    elif target in TAPE_PROT_ENGINEERING_TASKS:
        transformations.append(transforms.Reshape(on=target, shape=[-1]))

    if target in TAPE_SEQ2SEQ_TASKS:
        transformations.extend([
            transforms.Reshape(on=target, shape=[-1, num_outputs]),
            transforms.CropOrPadND(on=target, size=max_len, axis=0),
        ])

    if weights is not None:  # Note: no seq-level TAPE task has weights.
        transformations.extend([
            transforms.Reshape(on=weights, shape=[-1]),
            transforms.CropOrPadND(on=weights, size=max_len),
        ])

    embeddings_labels = [target] if weights is None else [(target, weights)]
    return builder.DatasetBuilder(
        data_loader=make_tape_loader(root_dir=root_dir, task=task),
        ds_transformations=ds_transformations,
        transformations=transformations,
        labels=multi_task.Backbone(embeddings=embeddings_labels),
        metadata=metadata,
        sequence_key=output_sequence_key)
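A hedged usage sketch for make_tape_builder: the path, task, and target below are illustrative assumptions rather than values defined above, and the make(split, batch) call mirrors Example 1.

tape_builder = make_tape_builder(
    root_dir='/path/to/tape_data',  # hypothetical TFRecord location
    task='secondary_structure',     # assumed key of TAPE_NUM_OUTPUTS
    target='ss3',                   # assumed per-residue target for that task
    max_len=1024)
ds = tape_builder.make('train', 64)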
Example 6
def make_pair_builder(max_len=512,
                      index_keys=('fam_key', 'ci_100'),
                      process_negatives=True,
                      gap_token='-',
                      sequence_key='sequence',
                      context_sequence_key='full_sequence',
                      loader_cls=make_pfam_pairs_loader,
                      pairing_cls=None,
                      lm_cls=None,
                      has_context=False,
                      append_eos=True,
                      append_eos_context=True,
                      add_random_tails=False,
                      **kwargs):
    """Creates a dataset for pairs of sequences."""
    # Convenience function to index key pairs.
    paired_keys = lambda k: tuple(f'{k}_{i}' for i in (1, 2))

    def stack_and_pop(on):
        stack = transforms.Stack(on=paired_keys(on), out=on)
        pop = transforms.Pop(on=paired_keys(on))
        return [stack, pop]

    # Defines fields to be read from the TFRecords.
    metadata_keys = ['cla_key', 'seq_key'] + list(index_keys)
    extra_keys = metadata_keys.copy()
    # Pre-paired datasets have already been filtered by length; seq_len is only
    # needed when pairing sequences on-the-fly.
    if pairing_cls is not None:
        extra_keys.append('seq_len')
    # Optionally, adds fields needed by the `AddAlignmentContext` `Transform`.
    if has_context:
        extra_keys.extend(['start', 'end'])
    add_alignment_context_extra_args = (paired_keys(context_sequence_key) +
                                        paired_keys('start') +
                                        paired_keys('end'))
    # Accounts for EOS token if necessary.
    max_len_eos = max_len - 1 if append_eos else max_len

    ### Sets up the `DatasetTransform`s.

    ds_transformations = []
    if pairing_cls is not None:
        filter_by_length = transforms.FilterByLength(max_len=max_len_eos)
        # NOTE(fllinares): pairing on-the-fly is memory intensive on TPU for some
        # reason not yet understood...
        pair_sequences = pairing_cls(index_keys=index_keys)
        ds_transformations.extend([filter_by_length, pair_sequences])

    ### Sets up the `Transform`s applied *before* batching.

    project_msa_rows = align_transforms.ProjectMSARows(
        on=paired_keys(sequence_key), token=gap_token)
    append_eos_to_context = transforms.EOS(
        on=paired_keys(context_sequence_key))
    add_alignment_context = align_transforms.AddAlignmentContext(
        on=paired_keys(sequence_key) + add_alignment_context_extra_args,
        out=paired_keys(sequence_key),
        max_len=max_len_eos,
        gap_token=gap_token)
    trim_alignment = align_transforms.TrimAlignment(
        on=paired_keys(sequence_key), gap_token=gap_token)
    pop_add_alignment_context_extra_args = transforms.Pop(
        on=add_alignment_context_extra_args)
    add_random_prefix_and_suffix = align_transforms.AddRandomTails(
        on=paired_keys(sequence_key), max_len=max_len_eos)
    create_alignment_targets = align_transforms.CreateAlignmentTargets(
        on=paired_keys(sequence_key),
        out='alignment/targets',
        gap_token=gap_token,
        n_prepend_tokens=0)
    pid1 = align_transforms.PID(on=paired_keys(sequence_key),
                                out='alignment/pid1',
                                definition=1,
                                token=gap_token)
    pid3 = align_transforms.PID(on=paired_keys(sequence_key),
                                out='alignment/pid3',
                                definition=3,
                                token=gap_token)
    remove_gaps = transforms.RemoveTokens(on=paired_keys(sequence_key),
                                          tokens=gap_token)
    append_eos_to_sequence = transforms.EOS(on=paired_keys(sequence_key))
    pad_sequences = transforms.CropOrPad(on=paired_keys(sequence_key),
                                         size=max_len)
    pad_alignment_targets = transforms.CropOrPadND(on='alignment/targets',
                                                   size=2 * max_len)

    transformations = [project_msa_rows]
    if has_context:
        if append_eos_context:
            transformations.append(append_eos_to_context)
        transformations.extend([
            add_alignment_context, trim_alignment,
            pop_add_alignment_context_extra_args
        ])
    if add_random_tails:
        transformations.append(add_random_prefix_and_suffix)
    transformations.append(create_alignment_targets)

    transformations.extend([pid1, pid3, remove_gaps])
    if append_eos:
        transformations.append(append_eos_to_sequence)
    transformations.extend([pad_sequences, pad_alignment_targets])
    for key in [sequence_key] + metadata_keys:
        transformations.extend(stack_and_pop(key))

    ### Sets up the `Transform`s applied *after* batching.

    flatten_sequence_pairs = transforms.Reshape(on=sequence_key,
                                                shape=[-1, max_len])
    flatten_metadata_pairs = transforms.Reshape(on=metadata_keys, shape=[-1])
    create_homology_targets = align_transforms.CreateHomologyTargets(
        on='fam_key',
        out='homology/targets',
        process_negatives=process_negatives)
    create_alignment_weights = align_transforms.CreateBatchedWeights(
        on='alignment/targets', out='alignment/weights')
    add_neg_alignment_targets_and_weights = align_transforms.PadNegativePairs(
        on=('alignment/targets', 'alignment/weights'))
    pad_neg_pid = align_transforms.PadNegativePairs(on=('alignment/pid1',
                                                        'alignment/pid3'),
                                                    value=-1.0)

    batched_transformations = [
        flatten_sequence_pairs, flatten_metadata_pairs, create_homology_targets
    ]
    if process_negatives:
        batched_transformations.extend([
            create_alignment_weights, add_neg_alignment_targets_and_weights,
            pad_neg_pid
        ])
    if lm_cls is not None:
        create_lm_targets = lm_cls(on=sequence_key,
                                   out=(sequence_key, 'masked_lm/targets',
                                        'masked_lm/weights'))
        batched_transformations.append(create_lm_targets)

    ### Sets up the remainder of the `DatasetBuilder` configuration.

    masked_lm_labels = ('masked_lm/targets', 'masked_lm/weights')
    alignment_labels = ('alignment/targets' if not process_negatives else
                        ('alignment/targets', 'alignment/weights'))
    homology_labels = 'homology/targets'
    embeddings = () if lm_cls is None else (masked_lm_labels, )
    alignments = (alignment_labels, homology_labels)

    return builder.DatasetBuilder(
        data_loader=loader_cls(extra_keys),
        ds_transformations=ds_transformations,
        transformations=transformations,
        batched_transformations=batched_transformations,
        labels=multi_task.Backbone(embeddings=embeddings,
                                   alignments=alignments),
        metadata=('seq_key', 'alignment/pid1', 'alignment/pid3'),
        sequence_key=sequence_key,
        **kwargs)
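A minimal usage sketch, assuming the default Pfam pairs loader can locate its data and that DatasetBuilder.make(split, batch) behaves as in Example 1; all values below are illustrative.

pair_builder = make_pair_builder(max_len=512, process_negatives=True)
ds = pair_builder.make('train', 32)
inputs, y_true, weights, metadata = next(iter(ds))  # metadata holds seq_key and the PID values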