def __init__(self, config, global_config, output_dim: int, key_dim: int, value_dim: int, last_bias_fuse: bool = False) -> None:
    super(AttentionFF, self).__init__()
    self.config = config
    self.global_config = global_config
    self.output_dim = output_dim
    # When last_bias_fuse is True, o_linear carries no bias of its own; the caller adds
    # the fused output bias instead (see MSARowAttentionWithPairBiasFF.out_bias below).
    self.last_bias_fuse = last_bias_fuse

    all_key_dim = key_dim
    all_value_dim = value_dim
    self.num_head = self.config.num_head
    assert all_key_dim == all_value_dim
    assert all_key_dim % self.num_head == 0
    assert all_value_dim % self.num_head == 0
    self.key_dim = all_key_dim // self.num_head
    self.value_dim = all_value_dim // self.num_head
    self.scaling = 1. / math.sqrt(self.key_dim)
    # self.scaling = 0.0

    # Fused query/key/value projection: one matmul instead of three.
    self.qkv_weights = Linear(all_key_dim, 3 * all_key_dim, use_bias=False, initializer='glorot')
    self.o_linear = Linear(all_value_dim, self.output_dim, initializer='final', use_bias=(not last_bias_fuse))

    self.softmax = nn.Softmax(dim=-1)
    self.sigmoid = nn.Sigmoid()

    if self.config.gating:
        self.gating_linear = Linear(all_key_dim, all_value_dim, initializer='gating', use_bias=False)
        self.gating_bias = nn.Parameter(torch.ones(self.num_head * self.key_dim))
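# Illustrative sketch (not part of the module above): how a fused qkv projection is
# typically split into heads and combined with AlphaFold-style sigmoid gating.
# Shapes and the plain nn.Linear stand-ins for the repo's Linear are assumptions,
# not the module's actual forward pass.
import math
import torch
import torch.nn as nn

def fused_gated_attention_sketch(x, qkv, gate, out, num_head):
    # x: [batch, seq, dim]; qkv, gate, out are nn.Linear stand-ins.
    b, s, dim = x.shape
    head_dim = dim // num_head
    # One matmul yields Q, K, V; split and reshape to [batch, head, seq, head_dim].
    q, k, v = qkv(x).chunk(3, dim=-1)
    q, k, v = (t.view(b, s, num_head, head_dim).transpose(1, 2) for t in (q, k, v))
    logits = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(head_dim)  # self.scaling
    ctx = torch.matmul(torch.softmax(logits, dim=-1), v)
    ctx = ctx.transpose(1, 2).reshape(b, s, dim)
    # Sigmoid gate on the attention output, mirroring config.gating above.
    return out(torch.sigmoid(gate(x)) * ctx)

# Usage with hypothetical sizes:
_dim, _heads = 64, 4
_y = fused_gated_attention_sketch(
    torch.randn(2, 16, _dim),
    nn.Linear(_dim, 3 * _dim, bias=False), nn.Linear(_dim, _dim), nn.Linear(_dim, _dim),
    num_head=_heads)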
def __init__(self, config, global_config, num_channel: int) -> None:
    super(TransitionFF, self).__init__()
    self.config = config
    self.global_config = global_config

    num_intermediate = int(num_channel * config.num_intermediate_factor)
    self.input_layer_norm = LayerNormFF(num_channel)
    self.transition = nn.Sequential(
        Linear(num_channel, num_intermediate, initializer='relu'),
        nn.ReLU(),
        Linear(num_intermediate, num_channel, initializer='final'))
def __init__(self, config, global_config, target_dim: int, msa_dim: int):
    # Naming follows this code:
    # https://github.com/lupoglaz/alphafold/blob/2d53ad87efedcbbda8e67ab3be96af769dbeae7d/alphafold/model/modules.py#L1704
    super(InputEmbeddings, self).__init__()
    self.config = config
    self.global_config = global_config
    self.relpos_wind = config.max_relative_feature

    self.preprocess_1d = Linear(target_dim, config.msa_channel)
    self.preprocess_msa = Linear(msa_dim, config.msa_channel)
    self.left_single = Linear(target_dim, config.pair_channel)
    self.right_single = Linear(target_dim, config.pair_channel)
    self.pair_activations = Linear(2 * config.max_relative_feature + 1, config.pair_channel)
def __init__(self, config, global_config, num_output_channel: int, msa_dim: int) -> None:
    super(OuterProductMeanOpt, self).__init__()
    self.config = config
    self.global_config = global_config

    self.layer_norm_input = nn.LayerNorm(msa_dim)
    self.left_projection = Linear(msa_dim, config.num_outer_channel)
    self.right_projection = Linear(msa_dim, config.num_outer_channel)
    self.output_w = Linear(config.num_outer_channel * config.num_outer_channel, num_output_channel, initializer='final')
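# Illustrative sketch (not part of the module above): the outer-product-mean update
# that turns an MSA representation into a pair update. The real forward also applies
# the MSA mask and a stability epsilon; this unmasked version only shows the einsum.
import torch
import torch.nn as nn

def outer_product_mean_sketch(msa, layer_norm, left_proj, right_proj, output_w):
    # msa: [N_seq, N_res, msa_dim]
    act = layer_norm(msa)
    left = left_proj(act)    # [N_seq, N_res, c]
    right = right_proj(act)  # [N_seq, N_res, c]
    # Outer product over channels for every residue pair, averaged over sequences.
    outer = torch.einsum('sic,sjd->ijcd', left, right) / msa.shape[0]
    outer = outer.reshape(*outer.shape[:2], -1)   # [N_res, N_res, c*c]
    return output_w(outer)                        # [N_res, N_res, num_output_channel]

# Hypothetical sizes for illustration:
_msa_dim, _c, _pair_dim = 32, 8, 16
_pair_update = outer_product_mean_sketch(
    torch.randn(5, 10, _msa_dim), nn.LayerNorm(_msa_dim),
    nn.Linear(_msa_dim, _c), nn.Linear(_msa_dim, _c), nn.Linear(_c * _c, _pair_dim))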
def __init__(self, config, global_config, output_dim: int, key_dim: int, value_dim: int, q_chunk_size: int = None, kv_chunk_size: int = None) -> None:
    super(AttentionOpt, self).__init__()
    self.config = config
    self.global_config = global_config
    self.output_dim = output_dim

    all_key_dim = key_dim
    all_value_dim = value_dim
    self.num_head = self.config.num_head
    assert all_key_dim % self.num_head == 0
    assert all_value_dim % self.num_head == 0
    self.key_dim = all_key_dim // self.num_head
    self.value_dim = all_value_dim // self.num_head

    self.q_weights = Linear(all_key_dim, all_key_dim, use_bias=False, initializer='glorot')
    self.k_weights = Linear(all_value_dim, all_value_dim, use_bias=False, initializer='glorot')
    self.v_weights = Linear(all_value_dim, all_value_dim, use_bias=False, initializer='glorot')
    self.o_linear = Linear(all_value_dim, self.output_dim, initializer='final')

    self.softmax = nn.Softmax(dim=-1)
    self.sigmoid = nn.Sigmoid()

    if self.config.gating:
        self.gating_linear = Linear(all_key_dim, all_value_dim, initializer='gating')

    # Memory optimization: chunked attention; both chunk sizes must be set together.
    assert not ((q_chunk_size is None) ^ (kv_chunk_size is None))
    self.q_chunk_size = q_chunk_size
    self.kv_chunk_size = kv_chunk_size
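# Illustrative sketch (not part of the module above): why q_chunk_size helps. Chunking
# over queries keeps the [q, k] logit matrix small and is exact, because softmax is
# taken per query row; chunking over keys/values (kv_chunk_size) would additionally
# require a streaming softmax, which is omitted here.
import math
import torch

def query_chunked_attention(q, k, v, q_chunk_size):
    # q, k, v: [seq, head, head_dim] -> [seq, head, head_dim]
    scale = 1.0 / math.sqrt(q.shape[-1])
    outputs = []
    for start in range(0, q.shape[0], q_chunk_size):
        q_chunk = q[start:start + q_chunk_size] * scale
        logits = torch.einsum('qhc,khc->hqk', q_chunk, k)
        outputs.append(torch.einsum('hqk,khc->qhc', torch.softmax(logits, dim=-1), v))
    return torch.cat(outputs, dim=0)

# Matches unchunked attention up to floating-point error:
_q, _k, _v = (torch.randn(12, 4, 8) for _ in range(3))
_full = torch.softmax(torch.einsum('qhc,khc->hqk', _q / math.sqrt(8), _k), dim=-1)
assert torch.allclose(query_chunked_attention(_q, _k, _v, 5),
                      torch.einsum('hqk,khc->qhc', _full, _v), atol=1e-5)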
def __init__(self, config, global_config, num_feat_1d: int, num_feat_2d: int) -> None:
    super(FoldIteration, self).__init__()
    self.config = config
    self.global_config = global_config

    self.attention_module = InvariantPointAttention(config, global_config, num_feat_1d=num_feat_1d, num_feat_2d=num_feat_2d)
    self.attention_layer_norm = nn.LayerNorm(num_feat_1d)

    self.transition = nn.ModuleList([
        Linear(self.config.num_channel, self.config.num_channel, initializer='relu')
        for i in range(self.config.num_layer_in_transition - 1)])
    self.transition.append(Linear(self.config.num_channel, self.config.num_channel, initializer='final'))
    self.transition_layer_norm = nn.LayerNorm(self.config.num_channel)
    self.relu = nn.ReLU()

    # Rigid-body update: 3 quaternion-vector + 3 translation components
    # (value as in AlphaFold's FoldIteration).
    self.affine_update_size = 6
    self.affine_update = Linear(self.config.num_channel, self.affine_update_size)

    self.side_chain = MultiRigidSidechain(config.sidechain, global_config, num_repr=2, repr_dim=self.config.num_channel)
def __init__(self, config, global_config, pair_dim: int, msa_dim: int) -> None:
    super(MSARowAttentionWithPairBiasFF, self).__init__()
    self.config = config
    self.global_config = global_config

    self.query_norm = LayerNormFF(msa_dim)
    self.feat_2d_norm = LayerNormFF(pair_dim)
    self.feat_2d_weights = Linear(pair_dim, config.num_head, use_bias=False, initializer='normal')
    self.attn = AttentionFF(config, global_config, msa_dim, msa_dim, msa_dim, last_bias_fuse=True)
    self.out_bias = nn.parameter.Parameter(torch.zeros(msa_dim))
def __init__(self, config, global_config, pair_dim: int) -> None:
    super(TriangleMultiplicationFF, self).__init__()
    self.config = config
    self.global_config = global_config

    self.layer_norm_input = LayerNormFF(pair_dim)
    self.left_right_projection = Linear(pair_dim, 2 * config.num_intermediate_channel)
    self.left_right_gate = Linear(pair_dim, 2 * config.num_intermediate_channel, initializer='gating')
    self.sigmoid = nn.Sigmoid()
    self.center_layer_norm = LayerNormFF(config.num_intermediate_channel)
    self.gating_linear = Linear(pair_dim, pair_dim, initializer='gating')
    self.output_projection = Linear(config.num_intermediate_channel, pair_dim, initializer='final', use_bias=False)
    self.out_bias = nn.parameter.Parameter(torch.zeros(pair_dim))
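# Illustrative sketch (not part of the module above): how the fused projections are
# split, gated, and combined over the third node of each triangle. Shown for the
# "outgoing edges" variant; whether this block is the outgoing or incoming flavour
# depends on the config, so take the einsum direction as illustrative.
import torch
import torch.nn as nn

def triangle_mult_outgoing_sketch(pair, norm_in, left_right_proj, left_right_gate,
                                  center_norm, out_proj, gate_lin):
    # pair: [N_res, N_res, pair_dim]
    act = norm_in(pair)
    # One projection yields both edge activations; the sigmoid gate modulates them.
    proj = torch.sigmoid(left_right_gate(act)) * left_right_proj(act)
    left, right = proj.chunk(2, dim=-1)             # each [N_res, N_res, c]
    # Combine edges i->k and j->k over the shared node k.
    update = out_proj(center_norm(torch.einsum('ikc,jkc->ijc', left, right)))
    return torch.sigmoid(gate_lin(act)) * update    # final gate on the pair update

_pair_dim, _c = 16, 8
_out = triangle_mult_outgoing_sketch(
    torch.randn(10, 10, _pair_dim),
    nn.LayerNorm(_pair_dim), nn.Linear(_pair_dim, 2 * _c), nn.Linear(_pair_dim, 2 * _c),
    nn.LayerNorm(_c), nn.Linear(_c, _pair_dim), nn.Linear(_pair_dim, _pair_dim))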
def __init__(self, config, global_config, num_feat_1d: int, num_feat_2d: int) -> None:
    super(StructureModule, self).__init__()
    self.config = config
    self.global_config = global_config

    self.single_layer_norm = nn.LayerNorm(num_feat_1d)
    self.pair_layer_norm = nn.LayerNorm(num_feat_2d)
    self.initial_projection = Linear(num_feat_1d, self.config.num_channel)
    self.fold_iteration = FoldIteration(config, global_config, num_feat_1d, num_feat_2d)
def __init__(self, config, global_config, pair_dim: int) -> None:
    super(TriangleAttentionOpt, self).__init__()
    self.config = config
    self.global_config = global_config

    self.query_norm = nn.LayerNorm(pair_dim)
    self.feat_2d_weights = Linear(pair_dim, config.num_head, use_bias=False, initializer='normal')
    self.attn = AttentionOpt(config, global_config, pair_dim, pair_dim, pair_dim)
def __init__(self, config, global_config, num_feat_2d: int, num_feat_1d: int, dist_epsilon: float = 1e-8) -> None:
    super(InvariantPointAttention, self).__init__()
    self.config = config
    self.global_config = global_config
    self._dist_epsilon = dist_epsilon

    self.num_head = self.config.num_head
    self.num_scalar_qk = self.config.num_scalar_qk
    self.num_scalar_v = self.config.num_scalar_v
    self.num_point_qk = self.config.num_point_qk
    self.num_point_v = self.config.num_point_v

    scalar_variance = max(self.num_scalar_qk, 1) * 1.0
    point_variance = max(self.num_point_qk, 1) * 9.0 / 2.0
    num_logit_terms = 3
    self.scalar_weights = sqrt(1.0 / (num_logit_terms * scalar_variance))
    self.point_weights = sqrt(1.0 / (num_logit_terms * point_variance))
    self.attention_2d_weights = sqrt(1.0 / num_logit_terms)

    self.q_scalar = Linear(num_feat_1d, self.num_head * self.num_scalar_qk)
    self.kv_scalar = Linear(num_feat_1d, self.num_head * (self.num_scalar_v + self.num_scalar_qk))
    self.q_point_local = Linear(num_feat_1d, self.num_head * 3 * self.num_point_qk)
    self.kv_point_local = Linear(num_feat_1d, self.num_head * 3 * (self.num_point_qk + self.num_point_v))
    self.trainable_point_weights = nn.Parameter(torch.ones(self.num_head))
    self.attention_2d = Linear(num_feat_2d, self.num_head)
    self.output_pojection = Linear(self.num_head * (num_feat_2d + self.num_scalar_v + 4 * self.num_point_v), self.config.num_channel)

    self.softplus = nn.Softplus()
    self.softmax = nn.Softmax(dim=-1)
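# Worked example (not part of the module above): the three constant weights put the
# scalar, point, and pair-bias logit terms on a comparable scale, so their sum has
# roughly unit variance. The head sizes below (16 scalar, 4 point qk channels) are the
# usual AlphaFold values and are assumptions about the config.
from math import sqrt

_num_scalar_qk, _num_point_qk, _num_logit_terms = 16, 4, 3
_scalar_variance = max(_num_scalar_qk, 1) * 1.0       # 16.0
_point_variance = max(_num_point_qk, 1) * 9.0 / 2.0   # 18.0
print(sqrt(1.0 / (_num_logit_terms * _scalar_variance)))  # scalar_weights       ~0.144
print(sqrt(1.0 / (_num_logit_terms * _point_variance)))   # point_weights        ~0.136
print(sqrt(1.0 / _num_logit_terms))                       # attention_2d_weights ~0.577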
def __init__(self, config, global_config, output_dim: int, key_dim: int, value_dim: int) -> None:
    super(GlobalAttentionOptB, self).__init__()
    self.config = config
    self.global_config = global_config
    self.output_dim = output_dim

    all_key_dim = key_dim
    all_value_dim = value_dim
    self.num_head = self.config.num_head
    assert all_key_dim % self.num_head == 0
    assert all_value_dim % self.num_head == 0
    self.key_dim = all_key_dim // self.num_head
    self.value_dim = all_value_dim // self.num_head

    self.q_weights = Linear(all_key_dim, all_key_dim, use_bias=False, initializer='glorot')
    self.k_weights = Linear(all_value_dim, self.key_dim, use_bias=False, initializer='glorot')
    self.v_weights = Linear(all_value_dim, self.value_dim, use_bias=False, initializer='glorot')
    self.o_linear = Linear(all_value_dim, self.output_dim, initializer='final')

    self.softmax = nn.Softmax(dim=-1)
    self.sigmoid = nn.Sigmoid()

    if self.config.gating:
        self.gating_linear = Linear(all_key_dim, all_value_dim, initializer='gating')
def __init__(self, config, global_config):
    super(RecycleEmbedding, self).__init__()
    self.config = config
    self.global_config = global_config

    # Naming of the layers follows:
    # https://github.com/lupoglaz/alphafold/blob/2d53ad87efedcbbda8e67ab3be96af769dbeae7d/alphafold/model/modules.py#L1730
    self.prev_pos_linear = Linear(config.prev_pos.num_bins, config.pair_channel)
    # https://github.com/lupoglaz/alphafold/blob/2d53ad87efedcbbda8e67ab3be96af769dbeae7d/alphafold/model/modules.py#L1745
    self.prev_pair_norm = nn.LayerNorm(config.pair_channel)
    # https://github.com/lupoglaz/alphafold/blob/2d53ad87efedcbbda8e67ab3be96af769dbeae7d/alphafold/model/modules.py#L1736
    self.prev_msa_first_row_norm = nn.LayerNorm(config.msa_channel)

    self.bins = torch.linspace(config.prev_pos.min_bin, config.prev_pos.max_bin, config.prev_pos.num_bins)
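# Illustrative sketch (not part of the module above): how self.bins can discretize the
# previous iteration's pairwise distances so prev_pos_linear can embed them. The real
# feature uses pseudo-beta atom positions and a slightly different bin-edge convention;
# torch.bucketize plus one-hot is shown here only to make the recycling step concrete.
import torch
import torch.nn.functional as F

def recycle_positions_sketch(prev_coords, bins, prev_pos_linear):
    # prev_coords: [N_res, 3] representative atom positions from the previous cycle.
    dist = torch.cdist(prev_coords, prev_coords)                  # [N_res, N_res]
    idx = torch.bucketize(dist, bins).clamp(max=bins.shape[0] - 1)
    one_hot = F.one_hot(idx, num_classes=bins.shape[0]).float()   # [N_res, N_res, num_bins]
    return prev_pos_linear(one_hot)                               # [N_res, N_res, pair_channel]

# Hypothetical bin settings (min_bin=3.25, max_bin=20.75, num_bins=15):
_bins = torch.linspace(3.25, 20.75, 15)
_pair_update = recycle_positions_sketch(torch.randn(10, 3) * 10, _bins,
                                        torch.nn.Linear(15, 16))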
def __init__(self, config, global_config, target_dim: int, msa_dim: int, extra_msa_dim: int) -> None:
    super(EmbeddingsAndEvoformer, self).__init__()
    self.config = config
    self.global_config = global_config

    self.input_emb = InputEmbeddings(config, global_config, msa_dim=msa_dim, target_dim=target_dim)
    self.recycle_emb = RecycleEmbedding(config, global_config)
    self.extra_msa_emb = ExtraMSAEmbedding(config, global_config, msa_dim=extra_msa_dim)

    self.extra_msa_stack = nn.ModuleList()
    for i in range(self.config.extra_msa_stack_num_block):
        self.extra_msa_stack.append(EvoformerIterationFF(
            config.evoformer, global_config,
            msa_dim=config.extra_msa_channel, pair_dim=config.pair_channel,
            is_extra_msa=True))

    self.evoformer_stack = nn.ModuleList()
    for i in range(self.config.evoformer_num_block):
        self.evoformer_stack.append(EvoformerIterationFF(
            config.evoformer, global_config,
            msa_dim=config.msa_channel, pair_dim=config.pair_channel,
            is_extra_msa=False))

    self.single_activations = Linear(config.msa_channel, config.seq_channel)
def __init__(self, config, global_config, repr_dim: int, num_repr: int) -> None:
    super(MultiRigidSidechain, self).__init__()
    self.config = config
    self.global_config = global_config
    self.num_repr = num_repr

    self.input_projection = nn.ModuleList([
        Linear(repr_dim, self.config.num_channel) for i in range(num_repr)])

    self.resblock1 = nn.ModuleList([
        Linear(self.config.num_channel, self.config.num_channel, initializer='relu')
        for i in range(self.config.num_residual_block - 1)])
    self.resblock1.append(Linear(self.config.num_channel, self.config.num_channel, initializer='final'))

    self.resblock2 = nn.ModuleList([
        Linear(self.config.num_channel, self.config.num_channel, initializer='relu')
        for i in range(self.config.num_residual_block - 1)])
    self.resblock2.append(Linear(self.config.num_channel, self.config.num_channel, initializer='final'))

    self.unnormalized_angles = Linear(self.config.num_channel, 14)
    self.relu = nn.ReLU()
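# Illustrative sketch (not part of the module above): unnormalized_angles emits 14
# values per residue, a (sin, cos) pair for each of the 7 torsion angles. This shows
# how such raw outputs are typically normalized to unit (sin, cos) pairs; the epsilon
# is an assumption for numerical stability, not the module's actual forward.
import torch

def normalize_torsions_sketch(unnormalized, eps=1e-12):
    # unnormalized: [N_res, 14] -> [N_res, 7, 2] unit-norm (sin, cos) pairs.
    angles = unnormalized.view(*unnormalized.shape[:-1], 7, 2)
    return angles / torch.sqrt(torch.sum(angles ** 2, dim=-1, keepdim=True) + eps)

_angles = normalize_torsions_sketch(torch.randn(10, 14))
assert torch.allclose(_angles.norm(dim=-1), torch.ones(10, 7), atol=1e-4)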
def __init__(self, config, global_config, pair_dim: int) -> None:
    super(TriangleMultiplicationOpt, self).__init__()
    self.config = config
    self.global_config = global_config

    self.layer_norm_input = nn.LayerNorm(pair_dim)
    self.left_projection = Linear(pair_dim, config.num_intermediate_channel)
    self.right_projection = Linear(pair_dim, config.num_intermediate_channel)
    self.left_gate = Linear(pair_dim, config.num_intermediate_channel, initializer='gating')
    self.right_gate = Linear(pair_dim, config.num_intermediate_channel, initializer='gating')
    self.sigmoid = nn.Sigmoid()
    self.center_layer_norm = nn.LayerNorm(config.num_intermediate_channel)
    self.output_projection = Linear(config.num_intermediate_channel, pair_dim, initializer='final')
    self.gating_linear = Linear(pair_dim, pair_dim, initializer='gating')
def __init__(self, config, global_config, msa_dim: int):
    super(ExtraMSAEmbedding, self).__init__()
    self.config = config
    self.global_config = global_config

    self.extra_msa_activations = Linear(msa_dim, config.extra_msa_channel)