def test_recode(self):
    """Recoding an already-encoded sequence must equal direct target encoding."""
    src_vocab = vocabulary.proteins
    tgt_vocab = vocabulary.alternative
    text = 'AABUCCFDEFGHYBOUXZACF'
    # Sequence encoded under the source vocabulary.
    src_encoded = tf.constant(src_vocab.encode(text))
    recoded = transforms.Recode(vocab=src_vocab, target=tgt_vocab).call(src_encoded)
    # Reference: encode the raw text directly with the target vocabulary.
    expected = transforms.Encode(vocab=tgt_vocab).call(tf.constant(text))
    self.assertAllEqual(recoded, expected)
    def test_encode(self):
        """Encode maps a string tensor to an int32 token tensor, per character."""
        text = 'AABUCCFDEFGHYBOUXZACF'
        encode_fn = transforms.Encode(vocab=vocabulary.proteins)
        text_tensor = tf.constant(text)

        tokens = encode_fn.call(text_tensor)
        # One token per input character, encoded as int32.
        self.assertEqual(len(text), tokens.shape[0])
        self.assertEqual(tf.int32, tokens.dtype)

        # The dict-style __call__ must encode in place under the same key.
        result = encode_fn({'sequence': text_tensor})
        self.assertIn('sequence', result)
        self.assertAllEqual(result['sequence'], tokens)
예제 #3
0
def preprocess(left, right, max_length=512):
    """Prepares the data to be fed to the DEDAL network.

    Args:
        left: first sequence (string), whitespace-stripped and upper-cased.
        right: second sequence (string), treated identically.
        max_length: length to which both encoded sequences are cropped/padded.

    Returns:
        A tensor stacking the two processed sequences along axis 0.
    """
    keys = ['left', 'right']
    seqs = {
        'left': left.strip().upper(),
        'right': right.strip().upper(),
    }
    # Encode -> append EOS -> fix length, all under the alternative vocabulary.
    pipeline = [
        transforms.Encode(vocab=vocabulary.alternative, on=keys),
        transforms.EOS(vocab=vocabulary.alternative, on=keys),
        transforms.CropOrPad(
            size=max_length, vocab=vocabulary.alternative, on=keys),
    ]
    for transform in pipeline:
        seqs = transform(seqs)
    return tf.stack([seqs['left'], seqs['right']], axis=0)
예제 #4
0
def make_tape_builder(root_dir,
                      task,
                      target,
                      weights=None,
                      metadata=(),
                      max_len=1024,
                      input_sequence_key='primary',
                      output_sequence_key='sequence'):
    """Creates a DatasetBuilder for TAPE's benchmark.

    Args:
        root_dir: directory containing the TAPE data, passed to the loader.
        task: TAPE task name; must be a key of TAPE_NUM_OUTPUTS.
        target: key of the label field to predict for this task.
        weights: optional key of a per-position weight field (seq-level TAPE
            tasks have no weights).
        metadata: optional iterable of extra field keys to keep.
        max_len: maximum sequence length after EOS; also used to filter
            examples. None disables length filtering.
        input_sequence_key: key of the raw sequence field in TAPE records.
        output_sequence_key: key under which the processed sequence is stored.

    Returns:
        A builder.DatasetBuilder configured for the given task/target.

    Raises:
        ValueError: if `task` is not a supported TAPE task.
    """
    supported_tasks = list(TAPE_NUM_OUTPUTS)
    if task not in supported_tasks:
        # BUG FIX: the two implicitly-concatenated literals previously ran
        # together as "...not recognized.Supported tasks..."; add the space.
        raise ValueError(f'Task {task} not recognized. '
                         f'Supported tasks: {", ".join(supported_tasks)}.')
    num_outputs = TAPE_NUM_OUTPUTS[task].get(target, 1)

    # Drop every field of the task's spec that is not explicitly used.
    used_keys = [input_sequence_key, target]
    if weights is not None:
        used_keys.append(weights)
    if metadata:
        used_keys.extend(metadata)
    unused_keys = [k for k in TAPE_SPECS[task] if k not in used_keys]

    ds_transformations = []
    if max_len is not None:
        # max_len - 1 leaves room for the EOS token appended below.
        ds_transformations.append(
            transforms.FilterByLength(on=output_sequence_key,
                                      precomputed=False,
                                      max_len=max_len - 1))

    transformations = [
        transforms.Pop(on=unused_keys),
        transforms.Reshape(on=output_sequence_key, shape=[]),
        transforms.Encode(on=output_sequence_key),
        transforms.EOS(on=output_sequence_key),
        transforms.CropOrPad(on=output_sequence_key, size=max_len),
    ]

    # Target-specific post-processing by task family.
    if target in TAPE_MULTI_CL_TASKS:
        transformations.append(transforms.OneHot(on=target, depth=num_outputs))
    elif target in TAPE_BACKBONE_ANGLE_TASKS:
        transformations.append(transforms.BackboneAngleTransform(on=target))
    elif target in TAPE_PROT_ENGINEERING_TASKS:
        transformations.append(transforms.Reshape(on=target, shape=[-1]))

    if target in TAPE_SEQ2SEQ_TASKS:
        # Per-position targets must match the (cropped/padded) sequence length.
        transformations.extend([
            transforms.Reshape(on=target, shape=[-1, num_outputs]),
            transforms.CropOrPadND(on=target, size=max_len, axis=0),
        ])

    if weights is not None:  # Note: no seq-level TAPE task has weights.
        transformations.extend([
            transforms.Reshape(on=weights, shape=[-1]),
            transforms.CropOrPadND(on=weights, size=max_len),
        ])

    # Pair the target with its weights key (if any) for the labels backbone.
    embeddings_labels = [target] if weights is None else [(target, weights)]
    return builder.DatasetBuilder(
        data_loader=make_tape_loader(root_dir=root_dir, task=task),
        ds_transformations=ds_transformations,
        transformations=transformations,
        labels=multi_task.Backbone(embeddings=embeddings_labels),
        metadata=metadata,
        sequence_key=output_sequence_key)
예제 #5
0
 def load(self, split):
     """Creates CSVDataset for split, encoding the sequence."""
     dataset = super().load(split)
     # Encode the raw string sequence field into token ids, in parallel.
     encoder = transforms.Encode(on=self._output_sequence_key)
     return dataset.map(encoder, num_parallel_calls=tf.data.AUTOTUNE)