def test_transforms_from_gin(self):
  ds_builder = builder.DatasetBuilder()
  batch = 32
  ds = ds_builder.make('train', batch)
  inputs, y_true, weights, _ = next(iter(ds))

  self.assertIsInstance(inputs, tf.Tensor)
  self.assertEqual(inputs.dtype, tf.int32)
  self.assertIsInstance(y_true, dict)
  self.assertGreater(len(y_true), 0)
  self.assertIsInstance(weights, dict)
  self.assertGreater(len(weights), 0)

  self.assertEqual(y_true['alignments/0'].dtype, tf.int32)
  self.assertEqual(y_true['alignments/0'].shape, (2 * batch, 3, 1024))
  self.assertEqual(y_true['alignments/1'].dtype, tf.int32)
  self.assertEqual(y_true['alignments/1'].shape, (2 * batch, 1))
  self.assertEqual(weights['alignments/0'].dtype, tf.float32)
  self.assertEqual(weights['alignments/0'].shape, (2 * batch,))
  self.assertEqual(weights['alignments/1'].dtype, tf.float32)
  self.assertEqual(weights['alignments/1'].shape, (2 * batch,))
def make_fake_builder(max_len=512):
  return builder.DatasetBuilder(
      data_loader=FakePairsLoader(max_len=max_len),
      labels=multi_task.Backbone(
          alignments=[('alignment/targets', 'alignment/weights'),
                      'homology/targets']),
      batched_transformations=[
          transforms.Reshape(shape=(-1, max_len), on='sequence'),
          transforms.Reshape(shape=(-1,), on='fam_key'),
          align_transforms.CreateBatchedWeights(
              on='alignment/targets', out='alignment/weights'),
          align_transforms.PadNegativePairs(
              on=['alignment/targets', 'alignment/weights']),
          align_transforms.PadNegativePairs(
              value=-1.0, on=['alignment/pid1', 'alignment/pid3']),
          align_transforms.CreateHomologyTargets(
              process_negatives=True, on='fam_key', out='homology/targets'),
      ],
      metadata=('alignment/pid1', 'alignment/pid3'),
      split='train')
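# Hedged usage sketch (not part of the original tests): exercises
# `make_fake_builder` the same way the gin-configured test above exercises the
# default `DatasetBuilder`. The split name and batch size are illustrative
# assumptions; the unpacking mirrors the four-tuple used in the tests.
def example_fake_builder_usage():
  ds_builder = make_fake_builder(max_len=512)
  ds = ds_builder.make('train', 8)
  inputs, y_true, weights, metadata = next(iter(ds))
  # `inputs` holds the batched, reshaped sequences; `y_true` and `weights` are
  # flat dicts keyed by the `multi_task.Backbone` label structure.
  return inputs, y_true, weights, metadata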
def test_transforms_from_gin(self):
  self.mock_load.return_value = make_fake_sequence_dataset()
  ds_builder = builder.DatasetBuilder()
  batch = 32
  ds = ds_builder.make('test', batch)
  inputs, y_true, weights, metadata = next(iter(ds))

  self.assertIsInstance(inputs, tf.Tensor)
  self.assertEqual(inputs.dtype, tf.int32)
  self.assertIsInstance(y_true, dict)
  self.assertGreater(len(y_true), 0)
  self.assertIsInstance(weights, dict)
  self.assertGreater(len(weights), 0)

  for y in multi_task.Backbone.unflatten(y_true):
    self.assertEqual(y.dtype, tf.int32)
    self.assertEqual(y.shape, (batch, 1024))  # From gin.
  for y in multi_task.Backbone.unflatten(weights):
    self.assertEqual(y.dtype, tf.float32)

  self.assertIn('seq_key', metadata)
  self.assertIsInstance(metadata['seq_key'], tf.Tensor)
  self.assertEqual(metadata['seq_key'].dtype, tf.int32)
  self.assertNotIn('fam_key', metadata)
def _make_loop_with_reference(self):
  """Creates a training loop for alignments with self._loop as reference."""
  seq_len = gin.query_parameter('%SEQUENCE_LENGTH')
  model_cls = functools.partial(
      dedal.Dedal,
      encoder_cls=functools.partial(
          encoders.TransformerEncoder,
          emb_dim=48,
          num_layers=1,
          num_heads=2,
          mlp_dim=3 * seq_len,
          max_len=seq_len),
      aligner_cls=aligners.SoftAligner,
      heads_cls=multi_task.Backbone(
          embeddings=[], alignments=[homology.UncorrectedLogits]),
  )
  workdir2 = tempfile.mkdtemp()
  ds_builder = builder.DatasetBuilder(
      labels=multi_task.Backbone(alignments=[('target', 'weights')]))
  return training_loop.TrainingLoop(
      workdir=workdir2,
      strategy=self._loop.strategy,
      dataset_builder=ds_builder,
      logger_cls=functools.partial(
          logger.Logger,
          scalars=multi_task.Backbone(
              alignments=[[tf.keras.metrics.BinaryAccuracy]]),
          every=5),
      loss_fn=losses.MultiTaskLoss(losses=multi_task.Backbone(alignments=[
          tf.keras.losses.BinaryCrossentropy(
              from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
      ])),
      optimizer_cls=self._loop._optimizer_cls,
      batch_size=self._loop._batch_size,
      model_cls=model_cls,
      num_steps=self._loop._num_steps,
      reference_workdir=self._loop._workdir,
      num_reference_steps=self._loop._num_steps)
def make_tape_builder(root_dir,
                      task,
                      target,
                      weights=None,
                      metadata=(),
                      max_len=1024,
                      input_sequence_key='primary',
                      output_sequence_key='sequence'):
  """Creates a DatasetBuilder for TAPE's benchmark."""
  supported_tasks = list(TAPE_NUM_OUTPUTS)
  if task not in supported_tasks:
    raise ValueError(f'Task {task} not recognized. '
                     f'Supported tasks: {", ".join(supported_tasks)}.')
  num_outputs = TAPE_NUM_OUTPUTS[task].get(target, 1)

  used_keys = [input_sequence_key, target]
  if weights is not None:
    used_keys.append(weights)
  if metadata:
    used_keys.extend(metadata)
  unused_keys = [k for k in TAPE_SPECS[task] if k not in used_keys]

  ds_transformations = []
  if max_len is not None:
    ds_transformations.append(
        transforms.FilterByLength(
            on=output_sequence_key, precomputed=False, max_len=max_len - 1))

  transformations = [
      transforms.Pop(on=unused_keys),
      transforms.Reshape(on=output_sequence_key, shape=[]),
      transforms.Encode(on=output_sequence_key),
      transforms.EOS(on=output_sequence_key),
      transforms.CropOrPad(on=output_sequence_key, size=max_len),
  ]
  if target in TAPE_MULTI_CL_TASKS:
    transformations.append(transforms.OneHot(on=target, depth=num_outputs))
  elif target in TAPE_BACKBONE_ANGLE_TASKS:
    transformations.append(transforms.BackboneAngleTransform(on=target))
  elif target in TAPE_PROT_ENGINEERING_TASKS:
    transformations.append(transforms.Reshape(on=target, shape=[-1]))
  if target in TAPE_SEQ2SEQ_TASKS:
    transformations.extend([
        transforms.Reshape(on=target, shape=[-1, num_outputs]),
        transforms.CropOrPadND(on=target, size=max_len, axis=0),
    ])
  if weights is not None:  # Note: no seq-level TAPE task has weights.
    transformations.extend([
        transforms.Reshape(on=weights, shape=[-1]),
        transforms.CropOrPadND(on=weights, size=max_len),
    ])

  embeddings_labels = [target] if weights is None else [(target, weights)]
  return builder.DatasetBuilder(
      data_loader=make_tape_loader(root_dir=root_dir, task=task),
      ds_transformations=ds_transformations,
      transformations=transformations,
      labels=multi_task.Backbone(embeddings=embeddings_labels),
      metadata=metadata,
      sequence_key=output_sequence_key)
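# Hedged usage sketch (assumption, not from the original source): shows how a
# TAPE builder might be instantiated. The `root_dir`, `task` and `target`
# values below are illustrative placeholders; valid names are whatever keys
# `TAPE_NUM_OUTPUTS` and `TAPE_SPECS` define for the local TFRecords.
def example_tape_builder_usage():
  ds_builder = make_tape_builder(
      root_dir='/tmp/tape_data',      # Hypothetical data location.
      task='secondary_structure',     # Hypothetical task key.
      target='ss3',                   # Hypothetical target key.
      max_len=1024)
  ds = ds_builder.make('train', 16)
  # Yields (inputs, labels, weights, metadata) batches like the other builders.
  return next(iter(ds))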
def make_pair_builder(max_len=512,
                      index_keys=('fam_key', 'ci_100'),
                      process_negatives=True,
                      gap_token='-',
                      sequence_key='sequence',
                      context_sequence_key='full_sequence',
                      loader_cls=make_pfam_pairs_loader,
                      pairing_cls=None,
                      lm_cls=None,
                      has_context=False,
                      append_eos=True,
                      append_eos_context=True,
                      add_random_tails=False,
                      **kwargs):
  """Creates a DatasetBuilder for pairs of sequences."""
  # Convenience function to index key pairs.
  paired_keys = lambda k: tuple(f'{k}_{i}' for i in (1, 2))

  def stack_and_pop(on):
    stack = transforms.Stack(on=paired_keys(on), out=on)
    pop = transforms.Pop(on=paired_keys(on))
    return [stack, pop]

  # Defines fields to be read from the TFRecords.
  metadata_keys = ['cla_key', 'seq_key'] + list(index_keys)
  extra_keys = metadata_keys.copy()
  # Pre-paired datasets have already been filtered by length; `seq_len` is only
  # needed when pairing sequences on-the-fly.
  if pairing_cls is not None:
    extra_keys.append('seq_len')
  # Optionally, adds fields needed by the `AddAlignmentContext` `Transform`.
  if has_context:
    extra_keys.extend(['start', 'end'])
  add_alignment_context_extra_args = (
      paired_keys(context_sequence_key) + paired_keys('start') +
      paired_keys('end'))

  # Accounts for EOS token if necessary.
  max_len_eos = max_len - 1 if append_eos else max_len

  ### Sets up the `DatasetTransform`s.

  ds_transformations = []
  if pairing_cls is not None:
    filter_by_length = transforms.FilterByLength(max_len=max_len_eos)
    # NOTE(fllinares): pairing on-the-fly is memory intensive on TPU for some
    # reason not yet understood...
    pair_sequences = pairing_cls(index_keys=index_keys)
    ds_transformations.extend([filter_by_length, pair_sequences])

  ### Sets up the `Transform`s applied *before* batching.

  project_msa_rows = align_transforms.ProjectMSARows(
      on=paired_keys(sequence_key), token=gap_token)
  append_eos_to_context = transforms.EOS(
      on=paired_keys(context_sequence_key))
  add_alignment_context = align_transforms.AddAlignmentContext(
      on=paired_keys(sequence_key) + add_alignment_context_extra_args,
      out=paired_keys(sequence_key),
      max_len=max_len_eos,
      gap_token=gap_token)
  trim_alignment = align_transforms.TrimAlignment(
      on=paired_keys(sequence_key), gap_token=gap_token)
  pop_add_alignment_context_extra_args = transforms.Pop(
      on=add_alignment_context_extra_args)
  add_random_prefix_and_suffix = align_transforms.AddRandomTails(
      on=paired_keys(sequence_key), max_len=max_len_eos)
  create_alignment_targets = align_transforms.CreateAlignmentTargets(
      on=paired_keys(sequence_key),
      out='alignment/targets',
      gap_token=gap_token,
      n_prepend_tokens=0)
  pid1 = align_transforms.PID(
      on=paired_keys(sequence_key),
      out='alignment/pid1',
      definition=1,
      token=gap_token)
  pid3 = align_transforms.PID(
      on=paired_keys(sequence_key),
      out='alignment/pid3',
      definition=3,
      token=gap_token)
  remove_gaps = transforms.RemoveTokens(
      on=paired_keys(sequence_key), tokens=gap_token)
  append_eos_to_sequence = transforms.EOS(on=paired_keys(sequence_key))
  pad_sequences = transforms.CropOrPad(
      on=paired_keys(sequence_key), size=max_len)
  pad_alignment_targets = transforms.CropOrPadND(
      on='alignment/targets', size=2 * max_len)

  transformations = [project_msa_rows]
  if has_context:
    if append_eos_context:
      transformations.append(append_eos_to_context)
    transformations.extend([
        add_alignment_context, trim_alignment,
        pop_add_alignment_context_extra_args
    ])
  if add_random_tails:
    transformations.append(add_random_prefix_and_suffix)
  transformations.append(create_alignment_targets)
  transformations.extend([pid1, pid3, remove_gaps])
  if append_eos:
    transformations.append(append_eos_to_sequence)
  transformations.extend([pad_sequences, pad_alignment_targets])
  for key in [sequence_key] + metadata_keys:
    transformations.extend(stack_and_pop(key))

  ### Sets up the `Transform`s applied *after* batching.

  flatten_sequence_pairs = transforms.Reshape(
      on=sequence_key, shape=[-1, max_len])
  flatten_metadata_pairs = transforms.Reshape(on=metadata_keys, shape=[-1])
  create_homology_targets = align_transforms.CreateHomologyTargets(
      on='fam_key',
      out='homology/targets',
      process_negatives=process_negatives)
  create_alignment_weights = align_transforms.CreateBatchedWeights(
      on='alignment/targets', out='alignment/weights')
  add_neg_alignment_targets_and_weights = align_transforms.PadNegativePairs(
      on=('alignment/targets', 'alignment/weights'))
  pad_neg_pid = align_transforms.PadNegativePairs(
      on=('alignment/pid1', 'alignment/pid3'), value=-1.0)

  batched_transformations = [
      flatten_sequence_pairs, flatten_metadata_pairs, create_homology_targets
  ]
  if process_negatives:
    batched_transformations.extend([
        create_alignment_weights, add_neg_alignment_targets_and_weights,
        pad_neg_pid
    ])
  if lm_cls is not None:
    create_lm_targets = lm_cls(
        on=sequence_key,
        out=(sequence_key, 'masked_lm/targets', 'masked_lm/weights'))
    batched_transformations.append(create_lm_targets)

  ### Sets up the remainder of the `DatasetBuilder` configuration.

  masked_lm_labels = ('masked_lm/targets', 'masked_lm/weights')
  alignment_labels = ('alignment/targets' if not process_negatives else
                      ('alignment/targets', 'alignment/weights'))
  homology_labels = 'homology/targets'
  embeddings = () if lm_cls is None else (masked_lm_labels,)
  alignments = (alignment_labels, homology_labels)
  return builder.DatasetBuilder(
      data_loader=loader_cls(extra_keys),
      ds_transformations=ds_transformations,
      transformations=transformations,
      batched_transformations=batched_transformations,
      labels=multi_task.Backbone(embeddings=embeddings, alignments=alignments),
      metadata=('seq_key', 'alignment/pid1', 'alignment/pid3'),
      sequence_key=sequence_key,
      **kwargs)
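# Hedged usage sketch (assumption, not from the original source): a minimal
# pair builder using the defaults above, i.e. no alignment context, no
# on-the-fly pairing and no masked LM targets. Split name and batch size are
# illustrative.
def example_pair_builder_usage():
  ds_builder = make_pair_builder(max_len=512, process_negatives=True)
  ds = ds_builder.make('train', 32)
  # With negatives enabled, alignment targets and weights are padded so the
  # leading batch dimension becomes 2 * batch, as asserted in the tests above.
  return next(iter(ds))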