def test_add_random_tails(self):
  """Checks `AddRandomTails` outputs contain the inputs at matching offsets.

  The transform is applied to a pair of gapped rows. Each output must contain
  the corresponding input as a substring, and the alignment targets computed
  on the outputs must equal those computed on the inputs, shifted by the
  position at which each input was found.
  """
  vocab = self.vocab

  def encode(text):
    return tf.convert_to_tensor(vocab.encode(text), tf.int32)

  row_x, row_y = (
      encode(text)
      for text in ('ACG----AATGGCACC--CTAA---', '---GGGTAA-GGTACCTACT--TCG'))

  tailed_x, tailed_y = align_transforms.AddRandomTails().call(row_x, row_y)

  # Each original row must appear somewhere inside its transformed row.
  offset_x = vocab.decode(tailed_x).find(vocab.decode(row_x))
  offset_y = vocab.decode(tailed_y).find(vocab.decode(row_y))
  for offset in (offset_x, offset_y):
    self.assertNotEqual(offset, -1)

  # Targets of the transformed rows must be shifted by exactly those offsets;
  # the third row (states) is expected to be unchanged, i.e. shifted by zero.
  targets_fn = align_transforms.CreateAlignmentTargets()
  base = targets_fn.call(row_x, row_y)
  shifted = targets_fn.call(tailed_x, tailed_y)
  n_cols = base.shape[1]
  for row, offset in enumerate((offset_x, offset_y, 0)):
    self.assertAllEqual(shifted[row] - base[row], n_cols * [offset])
def test_create_alignment_targets(self):
  """Compares `CreateAlignmentTargets` outputs with hand-written targets."""
  align_fn = align_transforms.CreateAlignmentTargets(
      gap_token='-', n_prepend_tokens=0, vocab=self.vocab)

  def encode(text):
    return tf.convert_to_tensor(self.vocab.encode(text), tf.int32)

  # Each scenario is (first row, second row, expected targets).
  scenarios = (
      (
          'XX-XXXX',
          'YYYY-YY',
          tf.convert_to_tensor([[1, 2, 2, 3, 4, 5, 6],
                                [1, 2, 3, 4, 4, 5, 6],
                                [0, 1, 4, 2, 6, 3, 1]], tf.int32),
      ),
      (
          '--XXXXXX',
          'YYYY-YY-',
          tf.convert_to_tensor([[1, 2, 3, 4, 5],
                                [3, 4, 4, 5, 6],
                                [0, 1, 6, 3, 1]], tf.int32),
      ),
      # Every column pairs a token with a gap: the expected targets are empty.
      (
          'X-X-X-X-',
          '-Y-Y-Y-Y',
          tf.zeros([3, 0], tf.int32),
      ),
  )

  for row_x, row_y, expected_output in scenarios:
    output = align_fn.call(encode(row_x), encode(row_y))
    self.assertAllEqual(output, expected_output)
def test_add_alignment_context(self):
  """Checks `AddAlignmentContext` against the full sequences and targets.

  The outputs must have equal length, their decoded form must be contained in
  the decoded full sequences, and alignment targets computed on them must
  equal the targets of the original rows shifted by the position at which
  each original row is found in the corresponding output.
  """
  vocab = self.vocab

  def encode(text):
    return tf.convert_to_tensor(vocab.encode(text), tf.int32)

  def decode(tokens):
    return vocab.decode(tokens)

  aligned_x = 'AATGGCACC--CT'
  aligned_y = 'AA-GGTACCTACT'
  # Full sequences are the ungapped rows surrounded by extra flanking tokens.
  full_x = encode('ACG' + aligned_x.replace('-', '') + 'AA')
  full_y = encode('GGGT' + aligned_y.replace('-', '') + 'TCG')
  row_x = encode(aligned_x)
  row_y = encode(aligned_y)

  starts = (4, 5)
  ends = (14, 16)
  ctx_x, ctx_y = align_transforms.AddAlignmentContext().call(
      row_x, row_y, full_x, full_y, *starts, *ends)

  self.assertEqual(len(ctx_x), len(ctx_y))
  for with_ctx, full in ((ctx_x, full_x), (ctx_y, full_y)):
    self.assertIn(decode(with_ctx), decode(full))

  targets_fn = align_transforms.CreateAlignmentTargets()
  targets = targets_fn.call(row_x, row_y)
  targets_with_ctx = targets_fn.call(ctx_x, ctx_y)

  # Position of each original row inside its context-extended counterpart.
  shift_x = decode(ctx_x).find(decode(row_x))
  shift_y = decode(ctx_y).find(decode(row_y))
  self.assertAllEqual(targets_with_ctx[0], targets[0] + shift_x)
  self.assertAllEqual(targets_with_ctx[1], targets[1] + shift_y)
  self.assertAllEqual(targets_with_ctx[2], targets[2])
def make_pair_builder(max_len=512,
                      index_keys=('fam_key', 'ci_100'),
                      process_negatives=True,
                      gap_token='-',
                      sequence_key='sequence',
                      context_sequence_key='full_sequence',
                      loader_cls=make_pfam_pairs_loader,
                      pairing_cls=None,
                      lm_cls=None,
                      has_context=False,
                      append_eos=True,
                      append_eos_context=True,
                      add_random_tails=False,
                      **kwargs):
  """Creates a dataset for pairs of sequences.

  Args:
    max_len: Size the paired sequences are cropped or padded to. Alignment
      targets are cropped or padded to `2 * max_len`. When `append_eos` is
      set, length-dependent transforms use `max_len - 1` instead.
    index_keys: Keys appended to the metadata fields and forwarded to
      `pairing_cls`.
    process_negatives: Forwarded to `CreateHomologyTargets`. When true,
      alignment weights are created, the `PadNegativePairs` transforms are
      added and 'alignment/weights' is included in the alignment labels.
    gap_token: Gap token passed to the alignment transforms and removed from
      the sequences by `RemoveTokens`.
    sequence_key: Base key of the sequences; the two elements of a pair are
      stored under `f'{sequence_key}_1'` and `f'{sequence_key}_2'`.
    context_sequence_key: Base key of the context sequences consumed by
      `AddAlignmentContext`.
    loader_cls: Callable invoked with the list of extra keys to read; its
      result is used as the `data_loader` of the `DatasetBuilder`.
    pairing_cls: If not None, called with `index_keys=...` to create a
      dataset-level transform that is applied after a `FilterByLength`
      transform. Also adds 'seq_len' to the keys to be read.
    lm_cls: If not None, called with `on=...` and `out=...` to create a
      batched transform writing 'masked_lm/targets' and 'masked_lm/weights',
      which are then used as embeddings labels.
    has_context: Whether to read 'start' and 'end' fields and apply the
      `AddAlignmentContext` and `TrimAlignment` transforms.
    append_eos: Whether to apply the `EOS` transform to the sequences.
    append_eos_context: Whether to apply the `EOS` transform to the context
      sequences. Only has an effect when `has_context` is true.
    add_random_tails: Whether to apply the `AddRandomTails` transform.
    **kwargs: Forwarded to `builder.DatasetBuilder`.

  Returns:
    A `builder.DatasetBuilder` configured with the transforms above.
  """

  def paired_keys(k):
    """Returns the keys of both elements of a pair for base key `k`."""
    return tuple(f'{k}_{i}' for i in (1, 2))

  def stack_and_pop(on):
    """Returns transforms stacking the paired keys into `on` and popping them."""
    return [
        transforms.Stack(on=paired_keys(on), out=on),
        transforms.Pop(on=paired_keys(on)),
    ]

  seq_pair = paired_keys(sequence_key)
  ctx_pair = paired_keys(context_sequence_key)

  # Defines fields to be read from the TFRecords.
  metadata_keys = ['cla_key', 'seq_key', *index_keys]
  extra_keys = list(metadata_keys)
  # Pre-paired datasets have already been filtered by length, so `seq_len` is
  # only needed when pairing sequences on-the-fly.
  if pairing_cls is not None:
    extra_keys.append('seq_len')
  # Optionally, adds fields needed by the `AddAlignmentContext` `Transform`.
  if has_context:
    extra_keys += ['start', 'end']
  # Built unconditionally: the `AddAlignmentContext` transform below is
  # constructed regardless of `has_context`.
  context_args = ctx_pair + paired_keys('start') + paired_keys('end')

  # Accounts for EOS token if necessary.
  if append_eos:
    max_len_eos = max_len - 1
  else:
    max_len_eos = max_len

  ### Sets up the `DatasetTransform`s.

  ds_transformations = []
  if pairing_cls is not None:
    # NOTE(fllinares): pairing on-the-fly is memory intensive on TPU for some
    # reason not yet understood...
    ds_transformations += [
        transforms.FilterByLength(max_len=max_len_eos),
        pairing_cls(index_keys=index_keys),
    ]

  ### Sets up the `Transform`s applied *before* batching.

  project_rows = align_transforms.ProjectMSARows(on=seq_pair, token=gap_token)
  eos_on_context = transforms.EOS(on=ctx_pair)
  add_context = align_transforms.AddAlignmentContext(
      on=seq_pair + context_args,
      out=seq_pair,
      max_len=max_len_eos,
      gap_token=gap_token)
  trim = align_transforms.TrimAlignment(on=seq_pair, gap_token=gap_token)
  drop_context_args = transforms.Pop(on=context_args)
  random_tails = align_transforms.AddRandomTails(
      on=seq_pair, max_len=max_len_eos)
  make_alignment_targets = align_transforms.CreateAlignmentTargets(
      on=seq_pair,
      out='alignment/targets',
      gap_token=gap_token,
      n_prepend_tokens=0)
  pid1 = align_transforms.PID(
      on=seq_pair, out='alignment/pid1', definition=1, token=gap_token)
  pid3 = align_transforms.PID(
      on=seq_pair, out='alignment/pid3', definition=3, token=gap_token)
  strip_gaps = transforms.RemoveTokens(on=seq_pair, tokens=gap_token)
  eos_on_sequence = transforms.EOS(on=seq_pair)
  pad_sequences = transforms.CropOrPad(on=seq_pair, size=max_len)
  pad_targets = transforms.CropOrPadND(
      on='alignment/targets', size=2 * max_len)

  transformations = [project_rows]
  if has_context:
    context_steps = [add_context, trim, drop_context_args]
    if append_eos_context:
      context_steps.insert(0, eos_on_context)
    transformations += context_steps
  if add_random_tails:
    transformations.append(random_tails)
  transformations += [make_alignment_targets, pid1, pid3, strip_gaps]
  if append_eos:
    transformations.append(eos_on_sequence)
  transformations += [pad_sequences, pad_targets]
  for key in (sequence_key, *metadata_keys):
    transformations += stack_and_pop(key)

  ### Sets up the `Transform`s applied *after* batching.

  flatten_sequences = transforms.Reshape(on=sequence_key, shape=[-1, max_len])
  flatten_metadata = transforms.Reshape(on=metadata_keys, shape=[-1])
  homology_targets = align_transforms.CreateHomologyTargets(
      on='fam_key',
      out='homology/targets',
      process_negatives=process_negatives)
  alignment_weights = align_transforms.CreateBatchedWeights(
      on='alignment/targets', out='alignment/weights')
  pad_neg_alignments = align_transforms.PadNegativePairs(
      on=('alignment/targets', 'alignment/weights'))
  pad_neg_pid = align_transforms.PadNegativePairs(
      on=('alignment/pid1', 'alignment/pid3'), value=-1.0)

  batched_transformations = [
      flatten_sequences, flatten_metadata, homology_targets
  ]
  if process_negatives:
    batched_transformations += [
        alignment_weights, pad_neg_alignments, pad_neg_pid
    ]
  if lm_cls is not None:
    batched_transformations.append(
        lm_cls(on=sequence_key,
               out=(sequence_key, 'masked_lm/targets', 'masked_lm/weights')))

  ### Sets up the remainder of the `DatasetBuilder` configuration.

  masked_lm_labels = ('masked_lm/targets', 'masked_lm/weights')
  if process_negatives:
    alignment_labels = ('alignment/targets', 'alignment/weights')
  else:
    alignment_labels = 'alignment/targets'
  homology_labels = 'homology/targets'
  embeddings = (masked_lm_labels,) if lm_cls is not None else ()
  alignments = (alignment_labels, homology_labels)

  data_loader = loader_cls(extra_keys)
  labels = multi_task.Backbone(embeddings=embeddings, alignments=alignments)
  return builder.DatasetBuilder(
      data_loader=data_loader,
      ds_transformations=ds_transformations,
      transformations=transformations,
      batched_transformations=batched_transformations,
      labels=labels,
      metadata=('seq_key', 'alignment/pid1', 'alignment/pid3'),
      sequence_key=sequence_key,
      **kwargs)