def test_recode(self):
  """Recoding a source-encoded sequence must equal encoding with the target."""
  source_vocab = vocabulary.proteins
  target_vocab = vocabulary.alternative
  raw_text = 'AABUCCFDEFGHYBOUXZACF'

  # Encode the text with the source vocabulary, then translate the integer
  # ids from source to target via the Recode transform.
  source_ids = tf.constant(source_vocab.encode(raw_text))
  recode_fn = transforms.Recode(vocab=source_vocab, target=target_vocab)
  recoded = recode_fn.call(source_ids)

  # Reference path: encode the raw text directly with the target vocabulary.
  encode_fn = transforms.Encode(vocab=target_vocab)
  expected = encode_fn.call(tf.constant(raw_text))

  self.assertAllEqual(recoded, expected)
def test_encode(self):
  """Encode maps a string tensor to one int32 id per input character."""
  raw_text = 'AABUCCFDEFGHYBOUXZACF'
  string_tensor = tf.constant(raw_text)
  encode_fn = transforms.Encode(vocab=vocabulary.proteins)

  # Direct call on the tensor: length and dtype of the encoded output.
  encoded = encode_fn.call(string_tensor)
  self.assertEqual(len(raw_text), encoded.shape[0])
  self.assertEqual(tf.int32, encoded.dtype)

  # Dict-based __call__ must encode in place under the same key.
  result = encode_fn({'sequence': string_tensor})
  self.assertIn('sequence', result)
  self.assertAllEqual(encoded, result['sequence'])
def preprocess(left, right, max_length=512):
  """Prepares the data to be fed to the DEDAL network."""
  vocab = vocabulary.alternative
  # Normalize raw inputs: drop surrounding whitespace, force upper case.
  inputs = {
      'left': left.strip().upper(),
      'right': right.strip().upper(),
  }
  keys = list(inputs.keys())
  # Encode to integer ids, append the EOS token, then crop/pad to a fixed
  # length so both sequences can be stacked into a single tensor.
  for transform in (
      transforms.Encode(vocab=vocab, on=keys),
      transforms.EOS(vocab=vocab, on=keys),
      transforms.CropOrPad(size=max_length, vocab=vocab, on=keys),
  ):
    inputs = transform(inputs)
  return tf.stack([inputs['left'], inputs['right']], axis=0)
def make_tape_builder(root_dir,
                      task,
                      target,
                      weights=None,
                      metadata=(),
                      max_len=1024,
                      input_sequence_key='primary',
                      output_sequence_key='sequence'):
  """Creates a DatasetBuilder for TAPE's benchmark.

  Args:
    root_dir: Directory containing the TAPE data.
    task: Name of the TAPE task; must be a key of `TAPE_NUM_OUTPUTS`.
    target: Key of the field holding the labels for the task.
    weights: Optional key of the field holding per-position label weights.
    metadata: Optional extra keys to keep alongside the used fields.
    max_len: Maximum sequence length; longer examples are filtered out and
      shorter ones are padded. `None` disables length filtering.
    input_sequence_key: Key of the raw sequence in the TAPE records.
    output_sequence_key: Key under which the encoded sequence is emitted.

  Returns:
    A `builder.DatasetBuilder` producing encoded, fixed-length examples.

  Raises:
    ValueError: If `task` is not a supported TAPE task.
  """
  supported_tasks = list(TAPE_NUM_OUTPUTS)
  if task not in supported_tasks:
    # Fixed: the two f-string fragments previously concatenated without a
    # separating space ("...not recognized.Supported tasks...").
    raise ValueError(f'Task {task} not recognized. '
                     f'Supported tasks: {", ".join(supported_tasks)}.')
  num_outputs = TAPE_NUM_OUTPUTS[task].get(target, 1)

  # Keep only the fields the model consumes; everything else is popped below.
  used_keys = [input_sequence_key, target]
  if weights is not None:
    used_keys.append(weights)
  if metadata:
    used_keys.extend(metadata)
  unused_keys = [k for k in TAPE_SPECS[task] if k not in used_keys]

  ds_transformations = []
  if max_len is not None:
    # max_len - 1 reserves one position for the EOS token appended below.
    ds_transformations.append(
        transforms.FilterByLength(on=output_sequence_key,
                                  precomputed=False,
                                  max_len=max_len - 1))

  transformations = [
      transforms.Pop(on=unused_keys),
      transforms.Reshape(on=output_sequence_key, shape=[]),
      transforms.Encode(on=output_sequence_key),
      transforms.EOS(on=output_sequence_key),
      transforms.CropOrPad(on=output_sequence_key, size=max_len),
  ]

  # Task-family-specific label post-processing.
  if target in TAPE_MULTI_CL_TASKS:
    transformations.append(transforms.OneHot(on=target, depth=num_outputs))
  elif target in TAPE_BACKBONE_ANGLE_TASKS:
    transformations.append(transforms.BackboneAngleTransform(on=target))
  elif target in TAPE_PROT_ENGINEERING_TASKS:
    transformations.append(transforms.Reshape(on=target, shape=[-1]))

  if target in TAPE_SEQ2SEQ_TASKS:
    # Per-position labels must be padded/cropped to the same fixed length as
    # the input sequence.
    transformations.extend([
        transforms.Reshape(on=target, shape=[-1, num_outputs]),
        transforms.CropOrPadND(on=target, size=max_len, axis=0),
    ])
    if weights is not None:  # Note: no seq-level TAPE task has weights.
      transformations.extend([
          transforms.Reshape(on=weights, shape=[-1]),
          transforms.CropOrPadND(on=weights, size=max_len),
      ])

  embeddings_labels = [target] if weights is None else [(target, weights)]
  return builder.DatasetBuilder(
      data_loader=make_tape_loader(root_dir=root_dir, task=task),
      ds_transformations=ds_transformations,
      transformations=transformations,
      labels=multi_task.Backbone(embeddings=embeddings_labels),
      metadata=metadata,
      sequence_key=output_sequence_key)
def load(self, split):
  """Creates CSVDataset for split, encoding the sequence."""
  dataset = super().load(split)
  # Encode the sequence field of every example, in parallel.
  encode = transforms.Encode(on=self._output_sequence_key)
  return dataset.map(encode, num_parallel_calls=tf.data.AUTOTUNE)