Code Example #1
def nearest_neighbor_clusters(protein, gap_agreement_weight=0.):
    """Assign each extra MSA sequence to its nearest neighbor in sampled MSA."""

    # Determine how much weight we assign to each agreement.  In theory, we could
    # use a full blosum matrix here, but right now let's just down-weight gap
    # agreement because it could be spurious.
    # Never put weight on agreeing on BERT mask
    weights = tf.concat(
        [tf.ones(21), gap_agreement_weight * tf.ones(1),
         np.zeros(1)], 0)

    # Make agreement score as weighted Hamming distance
    sample_one_hot = (protein['msa_mask'][:, :, None] *
                      tf.one_hot(protein['msa'], 23))
    extra_one_hot = (protein['extra_msa_mask'][:, :, None] *
                     tf.one_hot(protein['extra_msa'], 23))

    num_seq, num_res, _ = shape_helpers.shape_list(sample_one_hot)
    extra_num_seq, _, _ = shape_helpers.shape_list(extra_one_hot)

    # Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot, weights)
    # in an optimized fashion to avoid possible memory or computation blowup.
    agreement = tf.matmul(tf.reshape(extra_one_hot,
                                     [extra_num_seq, num_res * 23]),
                          tf.reshape(sample_one_hot * weights,
                                     [num_seq, num_res * 23]),
                          transpose_b=True)

    # Assign each sequence in the extra sequences to the closest MSA sample
    protein['extra_cluster_assignment'] = tf.argmax(agreement,
                                                    axis=1,
                                                    output_type=tf.int32)

    return protein
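
As a sanity check, the optimized matmul can be compared against the plain two-operand einsum it stands in for. The snippet below is a minimal sketch with made-up toy shapes, not part of the original module, and assumes the TF1-style graph/session setup used throughout these examples.

# Hypothetical equivalence check (not from the original repo): flattening the
# (residue, class) axes and using matmul yields the same [extra_num_seq, num_seq]
# agreement matrix as an einsum over the weighted one-hot encodings.
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

num_seq, extra_num_seq, num_res = 3, 5, 7
sample_one_hot = tf.constant(np.random.rand(num_seq, num_res, 23), tf.float32)
extra_one_hot = tf.constant(np.random.rand(extra_num_seq, num_res, 23), tf.float32)
weights = tf.concat([tf.ones(21), tf.zeros(1), tf.zeros(1)], 0)

reference = tf.einsum('mrc,nrc->nm', sample_one_hot * weights, extra_one_hot)
optimized = tf.matmul(
    tf.reshape(extra_one_hot, [extra_num_seq, num_res * 23]),
    tf.reshape(sample_one_hot * weights, [num_seq, num_res * 23]),
    transpose_b=True)

with tf.Session() as sess:
    ref, opt = sess.run([reference, optimized])
np.testing.assert_allclose(ref, opt, rtol=1e-5)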
Code Example #2
def make_msa_mask(protein):
    """Mask features are all ones, but will later be zero-padded."""
    protein['msa_mask'] = tf.ones(shape_helpers.shape_list(protein['msa']),
                                  dtype=tf.float32)
    protein['msa_row_mask'] = tf.ones(
        shape_helpers.shape_list(protein['msa'])[0], dtype=tf.float32)
    return protein
Code Example #3
def make_masked_msa(protein, config, replace_fraction):
    """Create data for BERT on raw MSA."""
    # Add a random amino acid uniformly
    random_aa = tf.constant([0.05] * 20 + [0., 0.], dtype=tf.float32)

    categorical_probs = (config.uniform_prob * random_aa +
                         config.profile_prob * protein['hhblits_profile'] +
                         config.same_prob * tf.one_hot(protein['msa'], 22))

    # Put all remaining probability on [MASK] which is a new column
    pad_shapes = [[0, 0] for _ in range(len(categorical_probs.shape))]
    pad_shapes[-1][1] = 1
    mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob
    assert mask_prob >= 0.
    categorical_probs = tf.pad(categorical_probs,
                               pad_shapes,
                               constant_values=mask_prob)

    sh = shape_helpers.shape_list(protein['msa'])
    mask_position = tf.random.uniform(sh) < replace_fraction

    bert_msa = shaped_categorical(categorical_probs)
    bert_msa = tf.where(mask_position, bert_msa, protein['msa'])

    # Mix real and masked MSA
    protein['bert_mask'] = tf.cast(mask_position, tf.float32)
    protein['true_msa'] = protein['msa']
    protein['msa'] = bert_msa

    return protein
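
A usage sketch for the masking step. Everything here is made up for illustration: the ml_collections.ConfigDict fields, the toy shapes, and the uniform hhblits_profile are assumptions, and the call relies on shaped_categorical from Code Example #4 and on shape_helpers being importable.

# Hypothetical call with toy inputs; roughly 15% of MSA positions are replaced,
# mostly by the new [MASK] index 22, and the originals are kept in 'true_msa'.
import ml_collections

config = ml_collections.ConfigDict(
    {'uniform_prob': 0.1, 'profile_prob': 0.1, 'same_prob': 0.1})
protein = {
    'msa': tf.random.uniform([5, 16], minval=0, maxval=22, dtype=tf.int32),
    'hhblits_profile': tf.fill([16, 22], 1. / 22),
}
protein = make_masked_msa(protein, config, replace_fraction=0.15)
# protein['msa']       -> masked MSA (values in 0..22)
# protein['true_msa']  -> the unmasked original
# protein['bert_mask'] -> 1.0 at positions that were replaced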
Code Example #4
def shaped_categorical(probs, epsilon=1e-10):
    """Sample one class per position from a distribution over the last axis."""
    ds = shape_helpers.shape_list(probs)
    num_classes = ds[-1]
    counts = tf.random.categorical(tf.reshape(tf.log(probs + epsilon),
                                              [-1, num_classes]),
                                   1,
                                   dtype=tf.int32)
    return tf.reshape(counts, ds[:-1])
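
A quick usage sketch with toy shapes (assumed, not from the repo): one class index is sampled per position and the trailing class axis is dropped.

# Hypothetical usage: sample per-position class ids from uniform probabilities.
probs = tf.fill([4, 8, 23], 1. / 23)     # [num_seq, num_res, num_classes]
sampled = shaped_categorical(probs)      # int32 tensor of shape [4, 8]
with tf.Session() as sess:
    print(sess.run(sampled).shape)       # (4, 8)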
Code Example #5
def randomly_replace_msa_with_unknown(protein, replace_proportion):
    """Replace a proportion of the MSA with 'X'."""
    msa_mask = (tf.random.uniform(shape_helpers.shape_list(protein['msa'])) <
                replace_proportion)
    x_idx = 20
    gap_idx = 21
    msa_mask = tf.logical_and(msa_mask, protein['msa'] != gap_idx)
    protein['msa'] = tf.where(msa_mask,
                              tf.ones_like(protein['msa']) * x_idx,
                              protein['msa'])
    aatype_mask = (tf.random.uniform(
        shape_helpers.shape_list(protein['aatype'])) < replace_proportion)

    protein['aatype'] = tf.where(aatype_mask,
                                 tf.ones_like(protein['aatype']) * x_idx,
                                 protein['aatype'])
    return protein
Code Example #6
    def test_shape_list(self):
        """Test that shape_list can allow for reshaping to dynamic shapes."""
        a = tf.zeros([10, 4, 4, 2])
        p = tf.placeholder(tf.float32, shape=[None, None, 1, 4, 4])
        shape_dyn = shape_helpers.shape_list(p)[:2] + [4, 4]

        b = tf.reshape(a, shape_dyn)
        with self.session() as sess:
            out = sess.run(b, feed_dict={p: np.ones((20, 1, 1, 4, 4))})

        self.assertAllEqual(out.shape, (20, 1, 4, 4))
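
For context, shape_helpers.shape_list typically follows the common "static where possible, dynamic otherwise" pattern sketched below (a generic sketch, not necessarily the repo's exact implementation); returning Python ints for known dimensions and scalar tensors for unknown ones is what lets the test reshape against the placeholder's two unknown leading dimensions.

def shape_list_sketch(x):
    """Return dims as a list: Python ints where static, scalar tensors otherwise."""
    # Hypothetical helper for illustration; assumes the rank of x is known.
    x = tf.convert_to_tensor(x)
    static = x.shape.as_list()    # e.g. [None, None, 1, 4, 4]
    dynamic = tf.shape(x)         # int32 vector resolved at run time
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]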
Code Example #7
def squeeze_features(protein):
    """Remove singleton and repeated dimensions in protein features."""
    protein['aatype'] = tf.argmax(protein['aatype'],
                                  axis=-1,
                                  output_type=tf.int32)
    for k in [
            'domain_name', 'msa', 'num_alignments', 'seq_length', 'sequence',
            'superfamily', 'deletion_matrix', 'resolution',
            'between_segment_residues', 'residue_index',
            'template_all_atom_masks'
    ]:
        if k in protein:
            final_dim = shape_helpers.shape_list(protein[k])[-1]
            if isinstance(final_dim, int) and final_dim == 1:
                protein[k] = tf.squeeze(protein[k], axis=-1)

    for k in ['seq_length', 'num_alignments']:
        if k in protein:
            protein[k] = protein[k][0]  # Remove fake sequence dimension
    return protein
Code Example #8
def block_delete_msa(protein, config):
    """Sample MSA by deleting contiguous blocks.

  Jumper et al. (2021) Suppl. Alg. 1 "MSABlockDeletion"

  Arguments:
    protein: batch dict containing the msa
    config: ConfigDict with parameters

  Returns:
    updated protein
  """
    num_seq = shape_helpers.shape_list(protein['msa'])[0]
    block_num_seq = tf.cast(
        tf.floor(tf.cast(num_seq, tf.float32) * config.msa_fraction_per_block),
        tf.int32)

    if config.randomize_num_blocks:
        nb = tf.random.uniform([], 0, config.num_blocks + 1, dtype=tf.int32)
    else:
        nb = config.num_blocks

    del_block_starts = tf.random.uniform([nb], 0, num_seq, dtype=tf.int32)
    del_blocks = del_block_starts[:, None] + tf.range(block_num_seq)
    del_blocks = tf.clip_by_value(del_blocks, 0, num_seq - 1)
    del_indices = tf.unique(tf.sort(tf.reshape(del_blocks, [-1])))[0]

    # Make sure we keep the original sequence
    sparse_diff = tf.sets.difference(
        tf.range(1, num_seq)[None], del_indices[None])
    keep_indices = tf.squeeze(tf.sparse.to_dense(sparse_diff), 0)
    keep_indices = tf.concat([[0], keep_indices], axis=0)

    for k in _MSA_FEATURE_NAMES:
        if k in protein:
            protein[k] = tf.gather(protein[k], keep_indices)

    return protein
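
A toy walk-through of the index bookkeeping above, with made-up numbers: for six sequences and deletion indices {2, 3}, the set difference over range(1, num_seq) plus the prepended 0 keeps the query row and the surviving alignment rows.

# Hypothetical walk-through of the keep_indices computation with fixed inputs.
num_seq = 6
del_indices = tf.constant([2, 3])
sparse_diff = tf.sets.difference(tf.range(1, num_seq)[None], del_indices[None])
keep_indices = tf.squeeze(tf.sparse.to_dense(sparse_diff), 0)
keep_indices = tf.concat([[0], keep_indices], axis=0)
with tf.Session() as sess:
    print(sess.run(keep_indices))   # [0 1 4 5]; row 0 (the query) is always kept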
Code Example #9
def summarize_clusters(protein):
    """Produce profile and deletion_matrix_mean within each cluster."""
    num_seq = shape_helpers.shape_list(protein['msa'])[0]

    def csum(x):
        return tf.math.unsorted_segment_sum(
            x, protein['extra_cluster_assignment'], num_seq)

    mask = protein['extra_msa_mask']
    mask_counts = 1e-6 + protein['msa_mask'] + csum(mask)  # Include center

    msa_sum = csum(mask[:, :, None] * tf.one_hot(protein['extra_msa'], 23))
    msa_sum += tf.one_hot(protein['msa'], 23)  # Original sequence
    protein['cluster_profile'] = msa_sum / mask_counts[:, :, None]

    del msa_sum

    del_sum = csum(mask * protein['extra_deletion_matrix'])
    del_sum += protein['deletion_matrix']  # Original sequence
    protein['cluster_deletion_mean'] = del_sum / mask_counts
    del del_sum

    return protein
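
The csum helper is just tf.math.unsorted_segment_sum keyed by the cluster assignment from Code Example #1. A tiny illustration of that pooling pattern, with made-up values and assignments:

# Hypothetical illustration of segment-sum pooling with a fixed assignment.
extra_rows = tf.constant([[1., 0.], [0., 1.], [1., 1.]])   # 3 extra sequences
assignment = tf.constant([0, 0, 1])   # rows 0 and 1 -> cluster 0; row 2 -> cluster 1
pooled = tf.math.unsorted_segment_sum(extra_rows, assignment, num_segments=2)
with tf.Session() as sess:
    print(sess.run(pooled))   # [[1. 1.] [1. 1.]]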
Code Example #10
def make_template_mask(protein):
    """Template mask is all ones, with one entry per template."""
    protein['template_mask'] = tf.ones(
        shape_helpers.shape_list(protein['template_domain_names']),
        dtype=tf.float32)
    return protein
Code Example #11
def random_crop_to_size(protein,
                        crop_size,
                        max_templates,
                        shape_schema,
                        subsample_templates=False):
    """Crop randomly to `crop_size`, or keep as is if shorter than that."""
    seq_length = protein['seq_length']
    if 'template_mask' in protein:
        num_templates = tf.cast(
            shape_helpers.shape_list(protein['template_mask'])[0], tf.int32)
    else:
        num_templates = tf.constant(0, dtype=tf.int32)
    num_res_crop_size = tf.math.minimum(seq_length, crop_size)

    # Ensures that the cropping of residues and templates happens in the same way
    # across ensembling iterations.
    # Do not use for randomness that should vary in ensembling.
    seed_maker = utils.SeedMaker(
        initial_seed=protein['random_crop_to_size_seed'])

    if subsample_templates:
        templates_crop_start = tf.random.stateless_uniform(
            shape=(),
            minval=0,
            maxval=num_templates + 1,
            dtype=tf.int32,
            seed=seed_maker())
    else:
        templates_crop_start = 0

    num_templates_crop_size = tf.math.minimum(
        num_templates - templates_crop_start, max_templates)

    num_res_crop_start = tf.random.stateless_uniform(
        shape=(),
        minval=0,
        maxval=seq_length - num_res_crop_size + 1,
        dtype=tf.int32,
        seed=seed_maker())

    templates_select_indices = tf.argsort(
        tf.random.stateless_uniform([num_templates], seed=seed_maker()))

    for k, v in protein.items():
        if k not in shape_schema or ('template' not in k
                                     and NUM_RES not in shape_schema[k]):
            continue

        # randomly permute the templates before cropping them.
        if k.startswith('template') and subsample_templates:
            v = tf.gather(v, templates_select_indices)

        crop_sizes = []
        crop_starts = []
        for i, (dim_size, dim) in enumerate(
                zip(shape_schema[k], shape_helpers.shape_list(v))):
            is_num_res = (dim_size == NUM_RES)
            if i == 0 and k.startswith('template'):
                crop_size = num_templates_crop_size
                crop_start = templates_crop_start
            else:
                crop_start = num_res_crop_start if is_num_res else 0
                crop_size = (num_res_crop_size if is_num_res else
                             (-1 if dim is None else dim))
            crop_sizes.append(crop_size)
            crop_starts.append(crop_start)
        protein[k] = tf.slice(v, crop_starts, crop_sizes)

    protein['seq_length'] = num_res_crop_size
    return protein
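
One detail worth noting in the cropping loop is the -1 size used when a dimension's size comes back as None: tf.slice interprets a size of -1 as "everything from the start offset to the end of that dimension". A small made-up illustration:

# Hypothetical illustration of -1 sizes in tf.slice.
x = tf.reshape(tf.range(12), [3, 4])
y = tf.slice(x, [1, 0], [2, -1])   # rows 1..2, all columns -> shape [2, 4]
with tf.Session() as sess:
    print(sess.run(y))             # [[ 4  5  6  7] [ 8  9 10 11]]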
Code Example #12
def make_seq_mask(protein):
    """Sequence mask is all ones, shaped like aatype."""
    protein['seq_mask'] = tf.ones(shape_helpers.shape_list(protein['aatype']),
                                  dtype=tf.float32)
    return protein