def test_disable_backprop_homology(self):
    encoder_cls = functools.partial(encoders.LookupEncoder, emb_dim=32)
    aligner_cls = functools.partial(
        aligners.SoftAligner, gap_pen_cls=aligners.ConstantGapPenalties)
    heads_cls = multi_task.Backbone(
        embeddings=[nlp_layers.DensePerTokenOutputHead],
        alignments=[dedal.Selector, homology.LogCorrectedLogits])
    backprop = multi_task.Backbone(embeddings=[True], alignments=[True, False])
    model = dedal.Dedal(
        encoder_cls=encoder_cls,
        aligner_cls=aligner_cls,
        heads_cls=heads_cls,
        backprop=backprop)
    inputs = tf.random.uniform((4, 16), maxval=25, dtype=tf.int32)

    # The homology head has backprop disabled: no gradient reaches the encoder.
    with tf.GradientTape(persistent=True) as tape:
      y_pred = model(inputs).flatten()
      dummy_loss = tf.reduce_sum(y_pred['alignments/1'] ** 2)
    grads_encoder = tape.gradient(
        dummy_loss,
        model.encoder.trainable_variables[0],
        unconnected_gradients=tf.UnconnectedGradients.ZERO)
    self.assertAllClose(tf.linalg.norm(grads_encoder), 0.0)

    # The embeddings head has backprop enabled: gradients do reach the encoder.
    with tf.GradientTape(persistent=True) as tape:
      y_pred = model(inputs).flatten()
      dummy_loss = tf.reduce_sum(y_pred['embeddings/0'] ** 2)
    grads_encoder = tape.gradient(
        dummy_loss,
        model.encoder.trainable_variables[0],
        unconnected_gradients=tf.UnconnectedGradients.ZERO)
    self.assertGreater(tf.linalg.norm(grads_encoder), 0.0)
def __init__(self,
             workdir,
             strategy,
             split=None,
             task=None,
             scalars=multi_task.Backbone(),
             images=multi_task.Backbone(),
             means=(),
             every=1000,
             reset_every_step=False,
             start_clock=True):
    """Initialization.

    Args:
      workdir: the parent directory where to store data.
      strategy: distribution strategy.
      split: usually the name of the phase (train, test, valid).
      task: usually the name of the task (train, evaluate, downstream).
      scalars: the scalar metrics to be computed and dumped.
      images: the image metrics to be computed and dumped.
      means: the names of the scalar metrics to be aggregated as means.
        "loss" is always present; "gradient_norm" is added in training mode.
      every: the period, in steps, at which the metrics are logged.
      reset_every_step: whether to reset the metrics at every step.
      start_clock: whether or not to start the clock at instantiation.
    """
    split = '' if split is None else split
    self.workdir = os.path.join(workdir, split).rstrip('/')
    self._split = split
    self._task = task
    self._timer = timer.Timer()
    self._reset_every_step = reset_every_step
    self.training = task == 'train'

    # Takes the bigger of the two network structures.
    shape = tuple(max(scalars.shape[i], images.shape[i]) for i in range(2))
    envelope = multi_task.Backbone.constant_from_shape([], shape)
    means = set(means).union(['loss'])
    if self.training:
      means = means.union(['gradient_norm'])

    with strategy.scope():
      self._scalars = envelope.pack(
          [[metric_factory(m) for m in ms] for ms in scalars],
          default_value=[])
      self._images = envelope.pack(
          [[metric_factory(m) for m in ms] for ms in images],
          default_value=[])
      self._means = {name: tf.keras.metrics.Mean(name) for name in means}

    self._summary_writer = tf.summary.create_file_writer(self.workdir)
    self._every = every
    self._last_step = None if self.training else 0
    if start_clock:
      self.restart_clock()
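# A minimal usage sketch (illustrative, not from the library): it assumes the
# default distribution strategy and a scalar metric backbone shaped like the
# one used in `_make_loop_with_reference` below; the workdir is hypothetical.
#
#   train_logger = Logger(
#       workdir='/tmp/logs',
#       strategy=tf.distribute.get_strategy(),
#       split='train',
#       task='train',
#       scalars=multi_task.Backbone(
#           alignments=[[tf.keras.metrics.BinaryAccuracy]]),
#       every=100)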
def setUp(self):
    super().setUp()
    self.dim = 3
    self.heads_cls = multi_task.Backbone(
        embeddings=[FakeSequenceLogits],
        alignments=[homology.UncorrectedLogits, homology.LogCorrectedLogits])
    self.model = dedal.Dedal(
        encoder_cls=encoders.OneHotEncoder,
        aligner_cls=functools.partial(
            aligners.SoftAligner,
            gap_pen_cls=aligners.ConstantGapPenalties,
            align_fn=fake_align),
        heads_cls=self.heads_cls)
    batch1, batch2 = 32, 16
    seq_len1, seq_len2 = 100, 50
    self.inputs1 = tf.random.uniform(
        (batch1, seq_len1), maxval=35, dtype=tf.int32)
    self.inputs2 = tf.random.uniform(
        (batch2, seq_len2), maxval=35, dtype=tf.int32)
    self.switch = multi_task.SwitchBackbone(embeddings=[1], alignments=[0, 0])
def test_alignment_outputs(self):
    """A test with complex outputs."""
    heads_cls = multi_task.Backbone(
        embeddings=[FakeSequenceLogits],
        alignments=[
            dedal.Selector,
            homology.UncorrectedLogits,
            homology.LogCorrectedLogits,
        ])
    model = dedal.Dedal(
        encoder_cls=encoders.OneHotEncoder,
        aligner_cls=functools.partial(
            aligners.SoftAligner,
            gap_pen_cls=aligners.ConstantGapPenalties,
            align_fn=fake_align),
        heads_cls=heads_cls,
        process_negatives=True)
    preds = model(self.inputs)
    self.assertEqual(preds.embeddings[0].shape, (32, 10))
    self.assertEqual(preds.alignments[1].shape, (32, 1))
    self.assertEqual(preds.alignments[2].shape, (32, 1))
    alignment_pred = preds.alignments[0]
    self.assertLen(alignment_pred, 3)
    scores, paths, sw_params = alignment_pred
    self.assertEqual(scores.shape, (32,))
    self.assertLen(sw_params, 3)
    self.assertIsNone(paths)
    self.assertEqual(sw_params[0].shape, (32, 100, 100))
def _make_loop_with_reference(self):
    """Creates a training loop for alignments, with self._loop as reference."""
    seq_len = gin.query_parameter('%SEQUENCE_LENGTH')
    model_cls = functools.partial(
        dedal.Dedal,
        encoder_cls=functools.partial(
            encoders.TransformerEncoder,
            emb_dim=48,
            num_layers=1,
            num_heads=2,
            mlp_dim=3 * seq_len,
            max_len=seq_len),
        aligner_cls=aligners.SoftAligner,
        heads_cls=multi_task.Backbone(
            embeddings=[], alignments=[homology.UncorrectedLogits]),
    )
    workdir2 = tempfile.mkdtemp()
    ds_builder = builder.DatasetBuilder(
        labels=multi_task.Backbone(alignments=[('target', 'weights')]))
    return training_loop.TrainingLoop(
        workdir=workdir2,
        strategy=self._loop.strategy,
        dataset_builder=ds_builder,
        logger_cls=functools.partial(
            logger.Logger,
            scalars=multi_task.Backbone(
                alignments=[[tf.keras.metrics.BinaryAccuracy]]),
            every=5),
        loss_fn=losses.MultiTaskLoss(losses=multi_task.Backbone(alignments=[
            tf.keras.losses.BinaryCrossentropy(
                from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
        ])),
        optimizer_cls=self._loop._optimizer_cls,
        batch_size=self._loop._batch_size,
        model_cls=model_cls,
        num_steps=self._loop._num_steps,
        reference_workdir=self._loop._workdir,
        num_reference_steps=self._loop._num_steps)
def __init__(self,
             encoder_cls=gin.REQUIRED,
             aligner_cls=gin.REQUIRED,
             heads_cls=gin.REQUIRED,
             process_negatives=True,
             switch=None,
             backprop=None,
             **kwargs):
    """Initializes the Dedal model.

    Args:
      encoder_cls: a Keras layer (or model) that turns a batch of sequences
        into a batch of sequence embeddings.
      aligner_cls: a layer that turns a pair of sequence embeddings into
        scores (and, optionally, paths).
      heads_cls: for a multi-task setup, all the layers to be plugged either
        on the embeddings or on the alignments.
      process_negatives: whether the network should also process the negative
        pairs in a batch. Aligning can be expensive, so switching off negative
        supervision saves half of the alignment cost.
      switch: optional mapping of output heads to inputs, required only when
        calling the model in multi-input mode. See `call` for additional
        details.
      backprop: for a multi-task setup, whether to backprop each loss from the
        output head to the encoder (i.e. finetune the encoder on the loss) or
        train only the output head parameters.
      **kwargs: optional keyword arguments to be passed to `tf.keras.Model`.
    """
    super().__init__(**kwargs)
    self.encoder = encoder_cls()
    self.aligner = aligner_cls() if aligner_cls else None
    self.heads = multi_task.Backbone(
        embeddings=[
            head_cls() if head_cls is not None else None
            for head_cls in heads_cls.embeddings
        ],
        alignments=[
            head_cls() if head_cls is not None else None
            for head_cls in heads_cls.alignments
        ])
    self.process_negatives = process_negatives
    self.switch = switch
    if self.switch is None:
      self.switch = multi_task.SwitchBackbone.constant_like(self.heads)
    self.backprop = backprop
    if self.backprop is None:
      self.backprop = self.heads.constant_copy(True)
    # Flattened view of the heads, so that TF keeps track of their variables.
    self._flat_heads = self.heads.flatten()
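# A minimal construction sketch, mirroring the tests above (all classes are
# the ones already used there; the hyperparameters are illustrative):
#
#   model = dedal.Dedal(
#       encoder_cls=functools.partial(encoders.LookupEncoder, emb_dim=32),
#       aligner_cls=functools.partial(
#           aligners.SoftAligner, gap_pen_cls=aligners.ConstantGapPenalties),
#       heads_cls=multi_task.Backbone(
#           embeddings=[], alignments=[homology.UncorrectedLogits]),
#       # Skipping negatives halves the alignment cost (see docstring above).
#       process_negatives=False)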
def make_fake_builder(max_len=512):
  return builder.DatasetBuilder(
      data_loader=FakePairsLoader(max_len=max_len),
      labels=multi_task.Backbone(alignments=[
          ('alignment/targets', 'alignment/weights'), 'homology/targets'
      ]),
      batched_transformations=[
          transforms.Reshape(shape=(-1, max_len), on='sequence'),
          transforms.Reshape(shape=(-1,), on='fam_key'),
          align_transforms.CreateBatchedWeights(
              on='alignment/targets', out='alignment/weights'),
          align_transforms.PadNegativePairs(
              on=['alignment/targets', 'alignment/weights']),
          align_transforms.PadNegativePairs(
              value=-1.0, on=['alignment/pid1', 'alignment/pid3']),
          align_transforms.CreateHomologyTargets(
              process_negatives=True, on='fam_key', out='homology/targets'),
      ],
      metadata=('alignment/pid1', 'alignment/pid3'),
      split='train')
def forward(self, inputs, selector=None, training=True):
    """Runs the model on a single input, potentially selecting some heads only.

    Args:
      inputs: a Tensor<int32>[batch, seq_len] representing protein sequences.
      selector: if set, a multi_task.Backbone[bool] specifying which heads to
        apply. The output of non-selected heads is replaced by None. If not
        set, all the heads are applied.
      training: whether to run in training mode or eval mode.

    Returns:
      A multi_task.Backbone of tensors corresponding to the outputs of the
      different heads of the model.
    """
    selector = self.heads.constant_copy(True) if selector is None else selector
    embeddings = self.encoder(inputs, training=training)
    masks = self.encoder.compute_mask(inputs)

    result = multi_task.Backbone()
    for head, on, backprop in zip(self.heads.embeddings,
                                  selector.embeddings,
                                  self.backprop.embeddings):
      head_output = self.head_output(
          head, on, backprop, embeddings, mask=masks, training=training)
      result.embeddings.append(head_output)

    if not self.heads.alignments or not any(selector.alignments):
      # Ensures the structure of the result matches self.heads even when the
      # method skips the alignment phase due to the selector.
      for _ in selector.alignments:
        result.alignments.append(tf.constant([]))
      return result

    # For each head, computes the output of positive pairs and negative ones,
    # then concatenates to obtain an output batch whose first half is positive
    # and second half is negative.
    outputs = []
    pos_indices = pairs_lib.consecutive_indices(inputs)
    neg_indices = (pairs_lib.roll_indices(pos_indices)
                   if self.process_negatives else None)
    num_alignment_calls = 1 + int(self.process_negatives)
    for indices in (pos_indices, neg_indices)[:num_alignment_calls]:
      curr = []
      embedding_pairs, mask_pairs = pairs_lib.build(indices, embeddings, masks)
      alignments = self.aligner(
          embedding_pairs, mask=mask_pairs, training=training)
      for head, on, backprop in zip(self.heads.alignments,
                                    selector.alignments,
                                    self.backprop.alignments):
        head_output = self.head_output(
            head, on, backprop, alignments, mask=mask_pairs, training=training)
        curr.append(head_output)
      outputs.append(curr)

    for output in merge(*outputs):
      result.alignments.append(output)
    return result
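# A hedged sketch of head selection, using the two-alignment-head model from
# `setUp` above: only the first alignment head is run; non-selected heads
# yield None (or empty placeholders when the whole alignment phase is
# skipped), so the result keeps the structure of `self.heads`.
#
#   selector = multi_task.Backbone(
#       embeddings=[False], alignments=[True, False])
#   preds = model.forward(inputs, selector=selector, training=False)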
def test_backbone(self):
    values = [1, 2, 4]
    container = multi_task.Backbone(values)
    self.assertSequenceEqual(container.embeddings, values)
    self.assertSequenceEqual(container.alignments, [])
    self.assertSequenceEqual(list(container), values)
def make_tape_builder(root_dir,
                      task,
                      target,
                      weights=None,
                      metadata=(),
                      max_len=1024,
                      input_sequence_key='primary',
                      output_sequence_key='sequence'):
  """Creates a DatasetBuilder for TAPE's benchmark."""
  supported_tasks = list(TAPE_NUM_OUTPUTS)
  if task not in supported_tasks:
    raise ValueError(f'Task {task} not recognized. '
                     f'Supported tasks: {", ".join(supported_tasks)}.')
  num_outputs = TAPE_NUM_OUTPUTS[task].get(target, 1)

  used_keys = [input_sequence_key, target]
  if weights is not None:
    used_keys.append(weights)
  if metadata:
    used_keys.extend(metadata)
  unused_keys = [k for k in TAPE_SPECS[task] if k not in used_keys]

  ds_transformations = []
  if max_len is not None:
    ds_transformations.append(
        transforms.FilterByLength(
            on=output_sequence_key, precomputed=False, max_len=max_len - 1))

  transformations = [
      transforms.Pop(on=unused_keys),
      transforms.Reshape(on=output_sequence_key, shape=[]),
      transforms.Encode(on=output_sequence_key),
      transforms.EOS(on=output_sequence_key),
      transforms.CropOrPad(on=output_sequence_key, size=max_len),
  ]
  if target in TAPE_MULTI_CL_TASKS:
    transformations.append(transforms.OneHot(on=target, depth=num_outputs))
  elif target in TAPE_BACKBONE_ANGLE_TASKS:
    transformations.append(transforms.BackboneAngleTransform(on=target))
  elif target in TAPE_PROT_ENGINEERING_TASKS:
    transformations.append(transforms.Reshape(on=target, shape=[-1]))
  if target in TAPE_SEQ2SEQ_TASKS:
    transformations.extend([
        transforms.Reshape(on=target, shape=[-1, num_outputs]),
        transforms.CropOrPadND(on=target, size=max_len, axis=0),
    ])
  if weights is not None:  # Note: no seq-level TAPE task has weights.
    transformations.extend([
        transforms.Reshape(on=weights, shape=[-1]),
        transforms.CropOrPadND(on=weights, size=max_len),
    ])

  embeddings_labels = [target] if weights is None else [(target, weights)]
  return builder.DatasetBuilder(
      data_loader=make_tape_loader(root_dir=root_dir, task=task),
      ds_transformations=ds_transformations,
      transformations=transformations,
      labels=multi_task.Backbone(embeddings=embeddings_labels),
      metadata=metadata,
      sequence_key=output_sequence_key)
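# A usage sketch, assuming the TAPE TFRecords have been materialized under
# `root_dir`; the task/target names follow TAPE's secondary structure
# benchmark and the path is hypothetical:
#
#   ss_builder = make_tape_builder(
#       root_dir='/path/to/tape',
#       task='secondary_structure',
#       target='ss3',
#       max_len=1024)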
def make_pair_builder(max_len=512,
                      index_keys=('fam_key', 'ci_100'),
                      process_negatives=True,
                      gap_token='-',
                      sequence_key='sequence',
                      context_sequence_key='full_sequence',
                      loader_cls=make_pfam_pairs_loader,
                      pairing_cls=None,
                      lm_cls=None,
                      has_context=False,
                      append_eos=True,
                      append_eos_context=True,
                      add_random_tails=False,
                      **kwargs):
  """Creates a dataset for pairs of sequences."""
  # Convenience function to index key pairs.
  paired_keys = lambda k: tuple(f'{k}_{i}' for i in (1, 2))

  def stack_and_pop(on):
    stack = transforms.Stack(on=paired_keys(on), out=on)
    pop = transforms.Pop(on=paired_keys(on))
    return [stack, pop]

  # Defines the fields to be read from the TFRecords.
  metadata_keys = ['cla_key', 'seq_key'] + list(index_keys)
  extra_keys = metadata_keys.copy()
  # Pre-paired datasets have already been filtered by length; seq_len is only
  # needed when pairing sequences on-the-fly.
  if pairing_cls is not None:
    extra_keys.append('seq_len')
  # Optionally, adds the fields needed by the `AddAlignmentContext` `Transform`.
  if has_context:
    extra_keys.extend(['start', 'end'])
    add_alignment_context_extra_args = (paired_keys(context_sequence_key) +
                                        paired_keys('start') +
                                        paired_keys('end'))

  # Accounts for the EOS token if necessary.
  max_len_eos = max_len - 1 if append_eos else max_len

  ### Sets up the `DatasetTransform`s.
  ds_transformations = []
  if pairing_cls is not None:
    filter_by_length = transforms.FilterByLength(max_len=max_len_eos)
    # NOTE(fllinares): pairing on-the-fly is memory intensive on TPU for some
    # reason not yet understood...
    pair_sequences = pairing_cls(index_keys=index_keys)
    ds_transformations.extend([filter_by_length, pair_sequences])

  ### Sets up the `Transform`s applied *before* batching.
  project_msa_rows = align_transforms.ProjectMSARows(
      on=paired_keys(sequence_key), token=gap_token)
  # The context-dependent transforms are only defined when a context is
  # available, since they rely on `add_alignment_context_extra_args`.
  if has_context:
    append_eos_to_context = transforms.EOS(
        on=paired_keys(context_sequence_key))
    add_alignment_context = align_transforms.AddAlignmentContext(
        on=paired_keys(sequence_key) + add_alignment_context_extra_args,
        out=paired_keys(sequence_key),
        max_len=max_len_eos,
        gap_token=gap_token)
    pop_add_alignment_context_extra_args = transforms.Pop(
        on=add_alignment_context_extra_args)
  trim_alignment = align_transforms.TrimAlignment(
      on=paired_keys(sequence_key), gap_token=gap_token)
  add_random_prefix_and_suffix = align_transforms.AddRandomTails(
      on=paired_keys(sequence_key), max_len=max_len_eos)
  create_alignment_targets = align_transforms.CreateAlignmentTargets(
      on=paired_keys(sequence_key),
      out='alignment/targets',
      gap_token=gap_token,
      n_prepend_tokens=0)
  pid1 = align_transforms.PID(
      on=paired_keys(sequence_key),
      out='alignment/pid1',
      definition=1,
      token=gap_token)
  pid3 = align_transforms.PID(
      on=paired_keys(sequence_key),
      out='alignment/pid3',
      definition=3,
      token=gap_token)
  remove_gaps = transforms.RemoveTokens(
      on=paired_keys(sequence_key), tokens=gap_token)
  append_eos_to_sequence = transforms.EOS(on=paired_keys(sequence_key))
  pad_sequences = transforms.CropOrPad(
      on=paired_keys(sequence_key), size=max_len)
  pad_alignment_targets = transforms.CropOrPadND(
      on='alignment/targets', size=2 * max_len)

  transformations = [project_msa_rows]
  if has_context:
    if append_eos_context:
      transformations.append(append_eos_to_context)
    transformations.extend([
        add_alignment_context, trim_alignment,
        pop_add_alignment_context_extra_args
    ])
  if add_random_tails:
    transformations.append(add_random_prefix_and_suffix)
  transformations.append(create_alignment_targets)
  transformations.extend([pid1, pid3, remove_gaps])
  if append_eos:
    transformations.append(append_eos_to_sequence)
  transformations.extend([pad_sequences, pad_alignment_targets])
  for key in [sequence_key] + metadata_keys:
    transformations.extend(stack_and_pop(key))

  ### Sets up the `Transform`s applied *after* batching.
  flatten_sequence_pairs = transforms.Reshape(
      on=sequence_key, shape=[-1, max_len])
  flatten_metadata_pairs = transforms.Reshape(on=metadata_keys, shape=[-1])
  create_homology_targets = align_transforms.CreateHomologyTargets(
      on='fam_key',
      out='homology/targets',
      process_negatives=process_negatives)
  create_alignment_weights = align_transforms.CreateBatchedWeights(
      on='alignment/targets', out='alignment/weights')
  add_neg_alignment_targets_and_weights = align_transforms.PadNegativePairs(
      on=('alignment/targets', 'alignment/weights'))
  pad_neg_pid = align_transforms.PadNegativePairs(
      on=('alignment/pid1', 'alignment/pid3'), value=-1.0)

  batched_transformations = [
      flatten_sequence_pairs, flatten_metadata_pairs, create_homology_targets
  ]
  if process_negatives:
    batched_transformations.extend([
        create_alignment_weights, add_neg_alignment_targets_and_weights,
        pad_neg_pid
    ])
  if lm_cls is not None:
    create_lm_targets = lm_cls(
        on=sequence_key,
        out=(sequence_key, 'masked_lm/targets', 'masked_lm/weights'))
    batched_transformations.append(create_lm_targets)

  ### Sets up the remainder of the `DatasetBuilder` configuration.
  masked_lm_labels = ('masked_lm/targets', 'masked_lm/weights')
  alignment_labels = ('alignment/targets' if not process_negatives else
                      ('alignment/targets', 'alignment/weights'))
  homology_labels = 'homology/targets'
  embeddings = () if lm_cls is None else (masked_lm_labels,)
  alignments = (alignment_labels, homology_labels)
  return builder.DatasetBuilder(
      data_loader=loader_cls(extra_keys),
      ds_transformations=ds_transformations,
      transformations=transformations,
      batched_transformations=batched_transformations,
      labels=multi_task.Backbone(embeddings=embeddings, alignments=alignments),
      metadata=('seq_key', 'alignment/pid1', 'alignment/pid3'),
      sequence_key=sequence_key,
      **kwargs)
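# A usage sketch under the defaults above (no on-the-fly pairing, no masked
# LM head). The resulting labels backbone then has one alignment entry with
# targets and weights plus a homology entry, matching the
# `multi_task.Backbone` constructed at the end of the function:
#
#   pair_builder = make_pair_builder(max_len=512, process_negatives=True)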