def nearest_neighbor_clusters(protein, gap_agreement_weight=0.):
    """Assign each extra MSA sequence to its best-agreeing sampled MSA row.

    Both alignments are one-hot encoded over 23 classes and zeroed wherever
    their mask is zero. For every (extra row, sampled row) pair a weighted
    count of positions carrying the same class is computed, and each extra
    row is assigned to the sampled row with the highest count.

    Args:
      protein: feature dict. Reads 'msa', 'msa_mask', 'extra_msa' and
        'extra_msa_mask'.
      gap_agreement_weight: weight given to agreement on class index 21
        (the gap class, per the parameter name). Defaults to 0.

    Returns:
      The same dict with 'extra_cluster_assignment' added: an int32 tensor
      with one entry per extra sequence holding the index of the chosen
      sampled row.
    """
    # Per-class agreement weights: the first 21 classes count fully, class 21
    # is down-weighted because agreeing on it could be spurious, and class 22
    # (the BERT mask token) never contributes.
    class_weights = tf.concat(
        [tf.ones(21), gap_agreement_weight * tf.ones(1), np.zeros(1)], 0)

    # Masked one-hot encodings; the trailing None broadcasts the 2-D mask over
    # the class axis.
    sampled_one_hot = (
        protein['msa_mask'][:, :, None] * tf.one_hot(protein['msa'], 23))
    extra_one_hot_enc = (
        protein['extra_msa_mask'][:, :, None]
        * tf.one_hot(protein['extra_msa'], 23))

    num_seq, num_res, _ = shape_helpers.shape_list(sampled_one_hot)
    extra_num_seq, _, _ = shape_helpers.shape_list(extra_one_hot_enc)

    # Equivalent to
    #   tf.einsum('mrc,nrc,c->mn', sampled_one_hot, extra_one_hot_enc, weights)
    # but flattened to a single matmul to avoid a memory/compute blowup.
    flat_extra = tf.reshape(extra_one_hot_enc, [extra_num_seq, num_res * 23])
    weighted_sampled = sampled_one_hot * class_weights
    flat_sampled = tf.reshape(weighted_sampled, [num_seq, num_res * 23])
    agreement = tf.matmul(flat_extra, flat_sampled, transpose_b=True)

    # Rows index extra sequences, columns index sampled sequences, so argmax
    # over axis 1 picks the closest sampled sequence for each extra one.
    protein['extra_cluster_assignment'] = tf.argmax(
        agreement, axis=1, output_type=tf.int32)
    return protein
def make_msa_mask(protein):
    """Add all-ones float32 masks for the MSA.

    'msa_mask' has the same shape as protein['msa']; 'msa_row_mask' has one
    entry per leading-axis element of protein['msa']. The masks start as all
    ones and are expected to be zero-padded by later processing.

    Args:
      protein: feature dict containing 'msa'.

    Returns:
      The same dict with 'msa_mask' and 'msa_row_mask' set.
    """
    full_shape = shape_helpers.shape_list(protein['msa'])
    protein['msa_mask'] = tf.ones(full_shape, dtype=tf.float32)

    num_rows = shape_helpers.shape_list(protein['msa'])[0]
    protein['msa_row_mask'] = tf.ones(num_rows, dtype=tf.float32)
    return protein
def make_masked_msa(protein, config, replace_fraction):
    """Create BERT-style corrupted MSA data from the raw MSA.

    For every MSA position a categorical distribution over 23 classes is
    built as a mixture of a uniform amino-acid distribution, the
    'hhblits_profile', and the position's current token; whatever probability
    is left goes to an extra final class (the [MASK] token). A replacement is
    sampled from that distribution everywhere, but only applied at positions
    selected with probability `replace_fraction`.

    Args:
      protein: feature dict. Reads 'msa' and 'hhblits_profile'
        (presumably broadcastable against one_hot(msa, 22) - verify against
        the caller).
      config: object exposing `uniform_prob`, `profile_prob` and `same_prob`.
      replace_fraction: probability that a position is selected for
        replacement.

    Returns:
      The same dict with 'bert_mask' (float32, 1 where replaced),
      'true_msa' (the original MSA) and 'msa' (the corrupted MSA) set.
    """
    # 0.05 on each of the first 20 classes, nothing on the last two.
    uniform_aa = tf.constant([0.05] * 20 + [0., 0.], dtype=tf.float32)

    probs = config.uniform_prob * uniform_aa
    probs = probs + config.profile_prob * protein['hhblits_profile']
    probs = probs + config.same_prob * tf.one_hot(protein['msa'], 22)

    # Append one new column on the last axis only; it receives all remaining
    # probability mass and represents [MASK].
    rank = len(probs.shape)
    paddings = [[0, 0]] * (rank - 1) + [[0, 1]]
    mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob
    assert mask_prob >= 0.
    probs = tf.pad(probs, paddings, constant_values=mask_prob)

    msa_shape = shape_helpers.shape_list(protein['msa'])
    mask_position = tf.random.uniform(msa_shape) < replace_fraction

    # Sample a candidate token everywhere, then keep it only where selected.
    sampled_msa = shaped_categorical(probs)
    mixed_msa = tf.where(mask_position, sampled_msa, protein['msa'])

    protein['bert_mask'] = tf.cast(mask_position, tf.float32)
    protein['true_msa'] = protein['msa']
    protein['msa'] = mixed_msa
    return protein
def shaped_categorical(probs, epsilon=1e-10):
    """Draw one categorical sample per distribution along the last axis.

    Args:
      probs: tensor whose last axis holds class probabilities.
      epsilon: small constant added before taking the log so that zero
        probabilities do not produce -inf logits.

    Returns:
      An int32 tensor of sampled class indices with the shape of `probs`
      minus its last axis.
    """
    full_shape = shape_helpers.shape_list(probs)
    batch_shape, num_classes = full_shape[:-1], full_shape[-1]

    # tf.random.categorical expects 2-D logits, so flatten all leading axes.
    logits = tf.log(probs + epsilon)
    flat_logits = tf.reshape(logits, [-1, num_classes])
    samples = tf.random.categorical(flat_logits, 1, dtype=tf.int32)

    return tf.reshape(samples, batch_shape)
def randomly_replace_msa_with_unknown(protein, replace_proportion):
    """Replace a proportion of the MSA and of aatype with 'X'.

    Each MSA position is selected independently with probability
    `replace_proportion`; selected positions that are not gaps are overwritten
    with the 'X' index (20). 'aatype' positions are selected the same way
    with an independent random draw and overwritten with 'X' with no gap
    exclusion.

    Args:
      protein: feature dict. Reads and overwrites 'msa' and 'aatype'.
      replace_proportion: per-position replacement probability.

    Returns:
      The same dict with 'msa' and 'aatype' updated.
    """
    x_idx = 20
    gap_idx = 21

    msa_mask = (tf.random.uniform(shape_helpers.shape_list(protein['msa']))
                < replace_proportion)
    # Use tf.not_equal rather than `!=`: the operator is only elementwise when
    # tensor equality is enabled. With it disabled (TF1 behaviour) `!=` is an
    # identity comparison that yields the Python bool True, which would
    # silently turn this gap guard into a no-op.
    not_gap = tf.not_equal(protein['msa'], gap_idx)
    msa_mask = tf.logical_and(msa_mask, not_gap)
    protein['msa'] = tf.where(msa_mask,
                              tf.ones_like(protein['msa']) * x_idx,
                              protein['msa'])

    aatype_mask = (tf.random.uniform(
        shape_helpers.shape_list(protein['aatype'])) < replace_proportion)
    protein['aatype'] = tf.where(aatype_mask,
                                 tf.ones_like(protein['aatype']) * x_idx,
                                 protein['aatype'])
    return protein
def test_shape_list(self):
    """Check that shape_list output can drive a reshape to a dynamic shape.

    The first two dimensions of the placeholder are unknown at graph build
    time; the reshape target mixes those shape_list entries with static
    sizes, and the result is checked after feeding a concrete array.
    """
    source = tf.zeros([10, 4, 4, 2])
    placeholder = tf.placeholder(tf.float32, shape=[None, None, 1, 4, 4])

    # Two leading (dynamic) dims from the placeholder, then two static dims.
    leading_dims = shape_helpers.shape_list(placeholder)[:2]
    target_shape = leading_dims + [4, 4]
    reshaped = tf.reshape(source, target_shape)

    with self.session() as sess:
        feed = {placeholder: np.ones((20, 1, 1, 4, 4))}
        result = sess.run(reshaped, feed_dict=feed)
        self.assertAllEqual(result.shape, (20, 1, 4, 4))
def squeeze_features(protein):
    """Remove singleton and repeated dimensions in protein features.

    * 'aatype' is collapsed from its last axis to int32 indices via argmax.
    * For a fixed list of feature names, a trailing axis of static size 1 is
      squeezed away when the feature is present.
    * 'seq_length' and 'num_alignments' are reduced to their first element.

    Args:
      protein: feature dict; must contain 'aatype'.

    Returns:
      The same dict with the affected features replaced.
    """
    protein['aatype'] = tf.argmax(
        protein['aatype'], axis=-1, output_type=tf.int32)

    squeezable = (
        'domain_name',
        'msa',
        'num_alignments',
        'seq_length',
        'sequence',
        'superfamily',
        'deletion_matrix',
        'resolution',
        'between_segment_residues',
        'residue_index',
        'template_all_atom_masks',
    )
    for name in squeezable:
        if name not in protein:
            continue
        last_dim = shape_helpers.shape_list(protein[name])[-1]
        # Only squeeze when the trailing size is statically known to be 1.
        if isinstance(last_dim, int) and last_dim == 1:
            protein[name] = tf.squeeze(protein[name], axis=-1)

    # Remove fake sequence dimension.
    for name in ('seq_length', 'num_alignments'):
        if name in protein:
            protein[name] = protein[name][0]
    return protein
def block_delete_msa(protein, config):
    """Sample MSA by deleting contiguous blocks of rows.

    Jumper et al. (2021) Suppl. Alg. 1 "MSABlockDeletion"

    Random block start rows are drawn, each block covering
    floor(num_seq * config.msa_fraction_per_block) consecutive rows (clipped
    to the last row). Every row falling in some block is dropped, except
    row 0, which is always kept.

    Arguments:
      protein: batch dict containing the msa
      config: ConfigDict with parameters; reads `msa_fraction_per_block`,
        `randomize_num_blocks` and `num_blocks`.

    Returns:
      updated protein, with every feature named in _MSA_FEATURE_NAMES that is
      present gathered down to the kept rows.
    """
    num_seq = shape_helpers.shape_list(protein['msa'])[0]

    scaled_rows = tf.cast(num_seq, tf.float32) * config.msa_fraction_per_block
    block_num_seq = tf.cast(tf.floor(scaled_rows), tf.int32)

    # Either a random block count in [0, num_blocks] or exactly num_blocks.
    num_blocks = (
        tf.random.uniform([], 0, config.num_blocks + 1, dtype=tf.int32)
        if config.randomize_num_blocks else config.num_blocks)

    # Expand each start row into a run of block_num_seq consecutive rows.
    block_starts = tf.random.uniform([num_blocks], 0, num_seq, dtype=tf.int32)
    block_rows = block_starts[:, None] + tf.range(block_num_seq)
    block_rows = tf.clip_by_value(block_rows, 0, num_seq - 1)

    flat_rows = tf.reshape(block_rows, [-1])
    del_indices = tf.unique(tf.sort(flat_rows))[0]

    # Make sure we keep the original sequence: row 0 is excluded from the
    # candidates for deletion and prepended unconditionally afterwards.
    candidate_rows = tf.range(1, num_seq)[None]
    deleted_rows = del_indices[None]
    surviving = tf.sets.difference(candidate_rows, deleted_rows)
    keep_indices = tf.squeeze(tf.sparse.to_dense(surviving), 0)
    keep_indices = tf.concat([[0], keep_indices], axis=0)

    for name in _MSA_FEATURE_NAMES:
        if name in protein:
            protein[name] = tf.gather(protein[name], keep_indices)
    return protein
def summarize_clusters(protein):
    """Produce profile and deletion_matrix_mean within each cluster.

    Extra MSA rows are summed into the cluster given by
    'extra_cluster_assignment'; the cluster centre (the corresponding row of
    'msa' / 'deletion_matrix') is added on top, and the totals are divided by
    the per-position count of contributing sequences.

    Args:
      protein: feature dict. Reads 'msa', 'msa_mask', 'deletion_matrix',
        'extra_msa', 'extra_msa_mask', 'extra_deletion_matrix' and
        'extra_cluster_assignment'.

    Returns:
      The same dict with 'cluster_profile' and 'cluster_deletion_mean' set.
    """
    num_seq = shape_helpers.shape_list(protein['msa'])[0]

    def _sum_per_cluster(values):
        """Sum rows of `values` into their assigned cluster."""
        return tf.math.unsorted_segment_sum(
            values, protein['extra_cluster_assignment'], num_seq)

    extra_mask = protein['extra_msa_mask']
    # 1e-6 guards the divisions below; 'msa_mask' accounts for the centre.
    counts = 1e-6 + protein['msa_mask'] + _sum_per_cluster(extra_mask)

    masked_one_hot = (
        extra_mask[:, :, None] * tf.one_hot(protein['extra_msa'], 23))
    # Add the original (centre) sequence to the summed extra sequences.
    profile_sum = (
        _sum_per_cluster(masked_one_hot) + tf.one_hot(protein['msa'], 23))
    protein['cluster_profile'] = profile_sum / counts[:, :, None]
    # Drop the local references to the large intermediates.
    del masked_one_hot, profile_sum

    masked_deletions = extra_mask * protein['extra_deletion_matrix']
    # Add the original (centre) sequence's deletions as well.
    deletion_sum = (
        _sum_per_cluster(masked_deletions) + protein['deletion_matrix'])
    protein['cluster_deletion_mean'] = deletion_sum / counts
    del masked_deletions, deletion_sum

    return protein
def make_template_mask(protein):
    """Add an all-ones float32 'template_mask'.

    The mask has the same shape as protein['template_domain_names'].

    Args:
      protein: feature dict containing 'template_domain_names'.

    Returns:
      The same dict with 'template_mask' set.
    """
    template_shape = shape_helpers.shape_list(protein['template_domain_names'])
    protein['template_mask'] = tf.ones(template_shape, dtype=tf.float32)
    return protein
def random_crop_to_size(protein, crop_size, max_templates, shape_schema,
                        subsample_templates=False):
    """Crop randomly to `crop_size`, or keep as is if shorter than that.

    Features are cropped along every axis that `shape_schema` marks as
    NUM_RES, all with the same random window. Features whose name starts with
    'template' are additionally cropped along axis 0 to at most
    `max_templates` entries, optionally after a random permutation and with a
    random start offset.

    Args:
      protein: feature dict. Reads 'seq_length', 'random_crop_to_size_seed'
        and, if present, 'template_mask'.
      crop_size: maximum number of residues to keep.
      max_templates: maximum number of templates to keep.
      shape_schema: mapping from feature name to a per-axis schema; only
        features listed here are considered for cropping.
      subsample_templates: if True, shuffle templates and pick a random start
        offset before cropping them; otherwise templates are taken from the
        front in their existing order.

    Returns:
      The same dict with the cropped features and 'seq_length' set to the
      cropped residue count.
    """
    seq_length = protein['seq_length']
    if 'template_mask' in protein:
        template_mask_shape = shape_helpers.shape_list(protein['template_mask'])
        num_templates = tf.cast(template_mask_shape[0], tf.int32)
    else:
        num_templates = tf.constant(0, dtype=tf.int32)
    num_res_crop_size = tf.math.minimum(seq_length, crop_size)

    # Ensures that the cropping of residues and templates happens in the same
    # way across ensembling iterations.
    # Do not use for randomness that should vary in ensembling.
    seed_maker = utils.SeedMaker(
        initial_seed=protein['random_crop_to_size_seed'])

    # NOTE: the order of the seed_maker() calls below determines which seed
    # each random op receives, so it must not be changed.
    if subsample_templates:
        templates_crop_start = tf.random.stateless_uniform(
            shape=(), minval=0, maxval=num_templates + 1, dtype=tf.int32,
            seed=seed_maker())
    else:
        templates_crop_start = 0

    num_templates_crop_size = tf.math.minimum(
        num_templates - templates_crop_start, max_templates)

    num_res_crop_start = tf.random.stateless_uniform(
        shape=(), minval=0, maxval=seq_length - num_res_crop_size + 1,
        dtype=tf.int32, seed=seed_maker())

    template_shuffle_keys = tf.random.stateless_uniform(
        [num_templates], seed=seed_maker())
    templates_select_indices = tf.argsort(template_shuffle_keys)

    for name, value in protein.items():
        # Only features described by the schema can be cropped.
        if name not in shape_schema:
            continue
        # Skip features that have neither a template axis nor a residue axis.
        if 'template' not in name and NUM_RES not in shape_schema[name]:
            continue

        is_template = name.startswith('template')

        # Randomly permute the templates before cropping them.
        if is_template and subsample_templates:
            value = tf.gather(value, templates_select_indices)

        starts = []
        sizes = []
        axes = zip(shape_schema[name], shape_helpers.shape_list(value))
        for axis, (schema_dim, actual_dim) in enumerate(axes):
            if axis == 0 and is_template:
                # Leading axis of a template feature: crop the templates.
                start = templates_crop_start
                size = num_templates_crop_size
            elif schema_dim == NUM_RES:
                # Residue axis: apply the shared residue window.
                start = num_res_crop_start
                size = num_res_crop_size
            else:
                # Any other axis is kept whole; -1 tells tf.slice to take
                # everything when the size is not known.
                start = 0
                size = -1 if actual_dim is None else actual_dim
            starts.append(start)
            sizes.append(size)
        protein[name] = tf.slice(value, starts, sizes)

    protein['seq_length'] = num_res_crop_size
    return protein
def make_seq_mask(protein):
    """Add an all-ones float32 'seq_mask'.

    The mask has the same shape as protein['aatype'].

    Args:
      protein: feature dict containing 'aatype'.

    Returns:
      The same dict with 'seq_mask' set.
    """
    aatype_shape = shape_helpers.shape_list(protein['aatype'])
    protein['seq_mask'] = tf.ones(aatype_shape, dtype=tf.float32)
    return protein