Example #1
0
	def __init__(self, config, global_config, output_dim:int, key_dim:int, value_dim:int,
				last_bias_fuse:bool=False) -> None:
		"""Create the layers of a multi-head attention block with a fused QKV projection.

		Args:
			config: Module config; ``num_head`` and ``gating`` are read here.
			global_config: Stored on the instance; not otherwise used by the constructor.
			output_dim: Output width of ``o_linear``.
			key_dim: Total key width summed over heads.
			value_dim: Total value width summed over heads; asserted equal to ``key_dim``.
			last_bias_fuse: When True, ``o_linear`` is created without a bias
				(presumably the caller adds the output bias itself -- confirm in forward).
		"""
		super().__init__()
		self.config = config
		self.global_config = global_config
		self.output_dim = output_dim
		self.last_bias_fuse = last_bias_fuse

		total_key_dim = key_dim
		total_value_dim = value_dim
		self.num_head = self.config.num_head
		# The single fused QKV weight below only works when keys and values share a width.
		assert total_key_dim == total_value_dim
		assert total_key_dim % self.num_head == 0
		assert total_value_dim % self.num_head == 0
		# Per-head widths.
		self.key_dim = total_key_dim // self.num_head
		self.value_dim = total_value_dim // self.num_head

		# 1/sqrt(per-head key width).
		self.scaling = 1. / math.sqrt(self.key_dim)

		# One bias-free projection producing 3x the key width (query, key and value together).
		self.qkv_weights = Linear(
			total_key_dim, 3 * total_key_dim,
			use_bias=False, initializer='glorot')
		self.o_linear = Linear(
			total_value_dim, self.output_dim,
			initializer='final', use_bias=(not last_bias_fuse))

		self.softmax = nn.Softmax(dim=-1)
		self.sigmoid = nn.Sigmoid()

		if self.config.gating:
			# The gate bias is a separate parameter initialised to ones; the linear itself has no bias.
			self.gating_linear = Linear(
				total_key_dim, total_value_dim,
				initializer='gating', use_bias=False)
			self.gating_bias = nn.Parameter(torch.ones(self.num_head * self.key_dim))
Example #2
0
    def __init__(self, config, global_config, num_channel: int) -> None:
        """Create the input layer norm and the Linear-ReLU-Linear transition stack.

        Args:
            config: Module config; ``num_intermediate_factor`` scales the hidden width.
            global_config: Stored on the instance; not otherwise used by the constructor.
            num_channel: Width of the input and of the output of the transition.
        """
        super().__init__()
        self.config = config
        self.global_config = global_config

        # Hidden width is the channel count times the configured factor, truncated to int.
        hidden_channel = int(num_channel * config.num_intermediate_factor)
        self.input_layer_norm = LayerNormFF(num_channel)

        expand = Linear(num_channel, hidden_channel, initializer='relu')
        activation = nn.ReLU()
        contract = Linear(hidden_channel, num_channel, initializer='final')
        self.transition = nn.Sequential(expand, activation, contract)
Example #3
0
	def __init__(self, config, global_config, target_dim: int, msa_dim: int):
		"""Create the linear layers that embed target, MSA and relative-position features.

		Args:
			config: Module config; ``max_relative_feature``, ``msa_channel`` and
				``pair_channel`` are read here.
			global_config: Stored on the instance; not otherwise used by the constructor.
			target_dim: Input width of the target-feature projections.
			msa_dim: Input width of the MSA-feature projection.
		"""
		# Naming after this code:
		# https://github.com/lupoglaz/alphafold/blob/2d53ad87efedcbbda8e67ab3be96af769dbeae7d/alphafold/model/modules.py#L1704
		super().__init__()
		self.config = config
		self.global_config = global_config

		msa_channel = config.msa_channel
		pair_channel = config.pair_channel
		max_relative = config.max_relative_feature

		self.relpos_wind = max_relative
		self.preprocess_1d = Linear(target_dim, msa_channel)
		self.preprocess_msa = Linear(msa_dim, msa_channel)
		self.left_single = Linear(target_dim, pair_channel)
		self.right_single = Linear(target_dim, pair_channel)
		# Input width 2*max_relative_feature + 1: one slot per offset in [-max, +max].
		self.pair_activations = Linear(2*max_relative + 1, pair_channel)
Example #4
0
    def __init__(self, config, global_config, num_output_channel: int,
                 msa_dim: int) -> None:
        """Create the layer norm, the two projections and the output linear layer.

        Args:
            config: Module config; ``num_outer_channel`` is read here.
            global_config: Stored on the instance; not otherwise used by the constructor.
            num_output_channel: Output width of ``output_w``.
            msa_dim: Width of the input normalised by ``layer_norm_input``.
        """
        super().__init__()
        self.config = config
        self.global_config = global_config

        outer_channel = config.num_outer_channel

        self.layer_norm_input = nn.LayerNorm(msa_dim)
        self.left_projection = Linear(msa_dim, outer_channel)
        self.right_projection = Linear(msa_dim, outer_channel)

        # The output layer takes outer_channel * outer_channel features, i.e. the
        # flattened product of one left and one right projection vector.
        self.output_w = Linear(outer_channel * outer_channel,
                               num_output_channel,
                               initializer='final')
Example #5
0
    def __init__(self,
                 config,
                 global_config,
                 output_dim: int,
                 key_dim: int,
                 value_dim: int,
                 q_chunk_size: int = None,
                 kv_chunk_size: int = None) -> None:
        """Create the query/key/value/output projections of a multi-head attention block.

        Args:
            config: Module config; ``num_head`` and ``gating`` are read here.
            global_config: Stored on the instance; not otherwise used by the constructor.
            output_dim: Output width of ``o_linear``.
            key_dim: Total key width summed over heads; must be divisible by ``num_head``.
            value_dim: Total value width summed over heads; must be divisible by ``num_head``.
            q_chunk_size: Stored for later use; must be given together with
                ``kv_chunk_size`` or not at all.
            kv_chunk_size: Stored for later use; see ``q_chunk_size``.
        """
        super().__init__()
        self.config = config
        self.global_config = global_config
        self.output_dim = output_dim

        total_key_dim, total_value_dim = key_dim, value_dim
        self.num_head = self.config.num_head
        assert total_key_dim % self.num_head == 0
        assert total_value_dim % self.num_head == 0
        # Per-head widths.
        self.key_dim = total_key_dim // self.num_head
        self.value_dim = total_value_dim // self.num_head

        self.q_weights = Linear(total_key_dim, total_key_dim,
                                use_bias=False, initializer='glorot')
        # NOTE(review): the key projection is sized by the value width, same as the
        # value projection -- confirm this is intended when key_dim != value_dim.
        self.k_weights = Linear(total_value_dim, total_value_dim,
                                use_bias=False, initializer='glorot')
        self.v_weights = Linear(total_value_dim, total_value_dim,
                                use_bias=False, initializer='glorot')
        self.o_linear = Linear(total_value_dim, self.output_dim,
                               initializer='final')

        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()

        if self.config.gating:
            self.gating_linear = Linear(total_key_dim, total_value_dim,
                                        initializer='gating')

        # Memory optimization: the chunk sizes are either both set or both None.
        assert (q_chunk_size is None) == (kv_chunk_size is None)
        self.q_chunk_size = q_chunk_size
        self.kv_chunk_size = kv_chunk_size
Example #6
0
	def __init__(self, config, global_config, num_feat_1d:int, num_feat_2d:int) -> None:
		"""Create the attention module, transition layers, affine update and side-chain module.

		Args:
			config: Module config; ``num_channel``, ``num_layer_in_transition`` and
				``sidechain`` are read here.
			global_config: Stored on the instance and forwarded to the sub-modules.
			num_feat_1d: Forwarded to ``InvariantPointAttention`` and used as the width
				of ``attention_layer_norm``.
			num_feat_2d: Forwarded to ``InvariantPointAttention``.
		"""
		super().__init__()
		self.config = config
		self.global_config = global_config

		num_channel = self.config.num_channel

		self.attention_module = InvariantPointAttention(
			config, global_config,
			num_feat_1d=num_feat_1d, num_feat_2d=num_feat_2d)
		# NOTE(review): this norm is sized by num_feat_1d while every later layer uses
		# config.num_channel -- presumably the two are equal; confirm against the caller.
		self.attention_layer_norm = nn.LayerNorm(num_feat_1d)

		# num_layer_in_transition layers in total: all but the last use the 'relu'
		# initializer, the last one uses 'final'.
		transition_layers = [
			Linear(num_channel, num_channel, initializer='relu')
			for _ in range(self.config.num_layer_in_transition - 1)]
		transition_layers.append(Linear(num_channel, num_channel, initializer='final'))
		self.transition = nn.ModuleList(transition_layers)

		self.transition_layer_norm = nn.LayerNorm(num_channel)
		self.relu = nn.ReLU()

		# NOTE(review): affine_update_size is not assigned in this constructor; it is
		# presumably a class-level attribute defined outside this view -- confirm.
		self.affine_update = Linear(num_channel, self.affine_update_size)
		self.side_chain = MultiRigidSidechain(
			config.sidechain, global_config,
			num_repr=2, repr_dim=num_channel)
Example #7
0
	def __init__(self, config, global_config, pair_dim:int, msa_dim:int) -> None:
		"""Create the norms, the pair-bias projection, the attention core and the output bias.

		Args:
			config: Module config; ``num_head`` is read here and the whole config is
				forwarded to ``AttentionFF``.
			global_config: Stored on the instance and forwarded to ``AttentionFF``.
			pair_dim: Width of the pair features normalised by ``feat_2d_norm``.
			msa_dim: Width of the MSA features; used for query, key, value and output.
		"""
		super().__init__()
		self.config = config
		self.global_config = global_config

		self.query_norm = LayerNormFF(msa_dim)
		self.feat_2d_norm = LayerNormFF(pair_dim)
		# Bias-free projection from pair features to one value per attention head.
		self.feat_2d_weights = Linear(
			pair_dim, config.num_head,
			use_bias=False, initializer='normal')
		# last_bias_fuse=True builds the attention output layer without a bias;
		# the zero-initialised out_bias below is held by this module instead.
		self.attn = AttentionFF(
			config, global_config,
			msa_dim, msa_dim, msa_dim,
			last_bias_fuse=True)
		self.out_bias = nn.Parameter(torch.zeros(msa_dim))
Example #8
0
    def __init__(self, config, global_config, pair_dim: int) -> None:
        """Create the norms, fused left/right projections, gates and output projection.

        Args:
            config: Module config; ``num_intermediate_channel`` is read here.
            global_config: Stored on the instance; not otherwise used by the constructor.
            pair_dim: Width of the pair features at the input and output.
        """
        super().__init__()
        self.config = config
        self.global_config = global_config

        mid_channel = config.num_intermediate_channel

        self.layer_norm_input = LayerNormFF(pair_dim)
        # Left and right branches share one layer each, hence the doubled output width.
        self.left_right_projection = Linear(pair_dim, 2 * mid_channel)
        self.left_right_gate = Linear(pair_dim, 2 * mid_channel,
                                      initializer='gating')
        self.sigmoid = nn.Sigmoid()
        self.center_layer_norm = LayerNormFF(mid_channel)
        self.gating_linear = Linear(pair_dim, pair_dim, initializer='gating')
        # The output projection has no bias of its own; out_bias is a separate
        # zero-initialised parameter.
        self.output_projection = Linear(mid_channel, pair_dim,
                                        initializer='final', use_bias=False)
        self.out_bias = nn.Parameter(torch.zeros(pair_dim))
Example #9
0
	def __init__(self, config, global_config, num_feat_1d:int, num_feat_2d:int) -> None:
		"""Create the input norms, the initial projection and the fold-iteration module.

		Args:
			config: Module config; ``num_channel`` is read here and the whole config is
				forwarded to ``FoldIteration``.
			global_config: Stored on the instance and forwarded to ``FoldIteration``.
			num_feat_1d: Width of the single representation.
			num_feat_2d: Width of the pair representation.
		"""
		super().__init__()
		self.config = config
		self.global_config = global_config

		self.single_layer_norm = nn.LayerNorm(num_feat_1d)
		self.pair_layer_norm = nn.LayerNorm(num_feat_2d)
		self.initial_projection = Linear(num_feat_1d, self.config.num_channel)

		# A single FoldIteration instance is created here.
		self.fold_iteration = FoldIteration(
			config, global_config, num_feat_1d, num_feat_2d)
Example #10
0
    def __init__(self, config, global_config, pair_dim: int) -> None:
        """Create the query norm, the per-head bias projection and the attention core.

        Args:
            config: Module config; ``num_head`` is read here and the whole config is
                forwarded to ``AttentionOpt``.
            global_config: Stored on the instance and forwarded to ``AttentionOpt``.
            pair_dim: Width of the pair features; used for query, key, value and output.
        """
        super().__init__()
        self.config = config
        self.global_config = global_config

        self.query_norm = nn.LayerNorm(pair_dim)
        # Bias-free projection from pair features to one value per attention head.
        self.feat_2d_weights = Linear(
            pair_dim, config.num_head,
            use_bias=False, initializer='normal')
        # Output, key and value widths are all pair_dim.
        self.attn = AttentionOpt(
            config, global_config, pair_dim, pair_dim, pair_dim)
Example #11
0
	def __init__(self, config, global_config, num_feat_2d:int, num_feat_1d:int, dist_epsilon:float=1e-8) -> None:
		"""Create the scalar/point projections and the logit weights.

		Args:
			config: Module config; ``num_head``, ``num_scalar_qk``, ``num_scalar_v``,
				``num_point_qk``, ``num_point_v`` and ``num_channel`` are read here.
			global_config: Stored on the instance; not otherwise used by the constructor.
			num_feat_2d: Width of the pair features fed to ``attention_2d``.
			num_feat_1d: Width of the single features fed to the scalar/point projections.
			dist_epsilon: Stored as ``_dist_epsilon`` for later use.
		"""
		super().__init__()
		self.config = config
		self.global_config = global_config
		self._dist_epsilon = dist_epsilon

		cfg = self.config
		self.num_head = cfg.num_head
		self.num_scalar_qk = cfg.num_scalar_qk
		self.num_scalar_v = cfg.num_scalar_v
		self.num_point_qk = cfg.num_point_qk
		self.num_point_v = cfg.num_point_v

		# Each of the three logit terms is weighted by 1/sqrt(num_terms * variance);
		# the variances are clamped so that a zero channel count does not divide by zero.
		logit_term_count = 3
		scalar_var = max(self.num_scalar_qk, 1) * 1.0
		point_var = max(self.num_point_qk, 1) * 9.0/2.0
		self.scalar_weights = sqrt(1.0/(logit_term_count*scalar_var))
		self.point_weights = sqrt(1.0/(logit_term_count*point_var))
		self.attention_2d_weights = sqrt(1.0/logit_term_count)

		# Keys and values share one projection; points carry 3 coordinates each.
		self.q_scalar = Linear(num_feat_1d, self.num_head * self.num_scalar_qk)
		self.kv_scalar = Linear(
			num_feat_1d,
			self.num_head*(self.num_scalar_v + self.num_scalar_qk))
		self.q_point_local = Linear(num_feat_1d, self.num_head * 3 * self.num_point_qk)
		self.kv_point_local = Linear(
			num_feat_1d,
			self.num_head * 3 * (self.num_point_qk + self.num_point_v))
		self.trainable_point_weights = nn.Parameter(torch.ones(self.num_head))
		self.attention_2d = Linear(num_feat_2d, self.num_head)
		# NOTE(review): the attribute name is misspelled ("pojection"), but it is kept
		# because code outside this view and saved state dicts may refer to it.
		# Per-head input width: pair features + scalar values + 4 numbers per value point.
		output_in_dim = self.num_head * (num_feat_2d + self.num_scalar_v + 4*self.num_point_v)
		self.output_pojection = Linear(output_in_dim, cfg.num_channel)

		self.softplus = nn.Softplus()
		self.softmax = nn.Softmax(dim=-1)
Example #12
0
    def __init__(self, config, global_config, output_dim: int, key_dim: int,
                 value_dim: int) -> None:
        """Create the projections of a global attention block.

        Args:
            config: Module config; ``num_head`` and ``gating`` are read here.
            global_config: Stored on the instance; not otherwise used by the constructor.
            output_dim: Output width of ``o_linear``.
            key_dim: Total key width summed over heads; must be divisible by ``num_head``.
            value_dim: Total value width summed over heads; must be divisible by ``num_head``.
        """
        super().__init__()
        self.config = config
        self.global_config = global_config
        self.output_dim = output_dim

        total_key_dim, total_value_dim = key_dim, value_dim
        self.num_head = self.config.num_head
        assert total_key_dim % self.num_head == 0
        assert total_value_dim % self.num_head == 0
        # Per-head widths.
        self.key_dim = total_key_dim // self.num_head
        self.value_dim = total_value_dim // self.num_head

        self.q_weights = Linear(total_key_dim, total_key_dim,
                                use_bias=False, initializer='glorot')
        # Keys and values are projected to a single head's width, unlike the queries,
        # which keep the full multi-head width.
        self.k_weights = Linear(total_value_dim, self.key_dim,
                                use_bias=False, initializer='glorot')
        self.v_weights = Linear(total_value_dim, self.value_dim,
                                use_bias=False, initializer='glorot')
        self.o_linear = Linear(total_value_dim, self.output_dim,
                               initializer='final')

        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()

        if self.config.gating:
            self.gating_linear = Linear(total_key_dim, total_value_dim,
                                        initializer='gating')
Example #13
0
	def __init__(self, config, global_config):
		"""Create the layers that embed recycled positions, pair and first-row MSA features.

		Args:
			config: Module config; ``prev_pos.num_bins``, ``prev_pos.min_bin``,
				``prev_pos.max_bin``, ``pair_channel`` and ``msa_channel`` are read here.
			global_config: Stored on the instance; not otherwise used by the constructor.
		"""
		super(RecycleEmbedding, self).__init__()
		self.config = config
		self.global_config = global_config

		#Naming of the layers are:
		#https://github.com/lupoglaz/alphafold/blob/2d53ad87efedcbbda8e67ab3be96af769dbeae7d/alphafold/model/modules.py#L1730
		self.prev_pos_linear = Linear(config.prev_pos.num_bins, config.pair_channel)
		#https://github.com/lupoglaz/alphafold/blob/2d53ad87efedcbbda8e67ab3be96af769dbeae7d/alphafold/model/modules.py#L1745
		self.prev_pair_norm = nn.LayerNorm(config.pair_channel)
		#https://github.com/lupoglaz/alphafold/blob/2d53ad87efedcbbda8e67ab3be96af769dbeae7d/alphafold/model/modules.py#L1736
		self.prev_msa_first_row_norm = nn.LayerNorm(config.msa_channel)

		# Evenly spaced bin values from min_bin to max_bin (num_bins of them).
		# Registered as a buffer so the tensor follows the module across .to()/.cuda()
		# instead of staying on the device it was created on. persistent=False keeps
		# it out of state_dict, so existing checkpoints load unchanged.
		bin_values = torch.linspace(
			config.prev_pos.min_bin, config.prev_pos.max_bin, config.prev_pos.num_bins)
		self.register_buffer('bins', bin_values, persistent=False)
Example #14
0
	def __init__(self, config, global_config, target_dim:int, msa_dim:int, extra_msa_dim:int) -> None:
		"""Create the embedding modules, both Evoformer stacks and the single projection.

		Args:
			config: Module config; ``evoformer``, ``extra_msa_stack_num_block``,
				``evoformer_num_block``, ``extra_msa_channel``, ``msa_channel``,
				``pair_channel`` and ``seq_channel`` are read here.
			global_config: Stored on the instance and forwarded to the sub-modules.
			target_dim: Forwarded to ``InputEmbeddings``.
			msa_dim: Forwarded to ``InputEmbeddings``.
			extra_msa_dim: Forwarded to ``ExtraMSAEmbedding``.
		"""
		super().__init__()
		self.config = config
		self.global_config = global_config

		self.input_emb = InputEmbeddings(config, global_config, msa_dim=msa_dim, target_dim=target_dim)
		self.recycle_emb = RecycleEmbedding(config, global_config)
		self.extra_msa_emb = ExtraMSAEmbedding(config, global_config, msa_dim=extra_msa_dim)

		# Blocks working on the extra MSA representation (is_extra_msa=True).
		self.extra_msa_stack = nn.ModuleList([
			EvoformerIterationFF(
				config.evoformer, global_config,
				msa_dim=config.extra_msa_channel,
				pair_dim=config.pair_channel,
				is_extra_msa=True)
			for _ in range(self.config.extra_msa_stack_num_block)])

		# Blocks working on the main MSA representation (is_extra_msa=False).
		self.evoformer_stack = nn.ModuleList([
			EvoformerIterationFF(
				config.evoformer, global_config,
				msa_dim=config.msa_channel,
				pair_dim=config.pair_channel,
				is_extra_msa=False)
			for _ in range(self.config.evoformer_num_block)])

		self.single_activations = Linear(config.msa_channel, config.seq_channel)
Example #15
0
	def __init__(self, config, global_config, repr_dim:int, num_repr:int) -> None:
		"""Create the input projections, two residual blocks and the angle output layer.

		Args:
			config: Module config; ``num_channel`` and ``num_residual_block`` are read here.
			global_config: Stored on the instance; not otherwise used by the constructor.
			repr_dim: Input width of every input projection.
			num_repr: Number of input projections to create.
		"""
		super().__init__()
		self.config = config
		self.global_config = global_config

		self.num_repr = num_repr

		num_channel = self.config.num_channel
		relu_layer_count = self.config.num_residual_block - 1

		# One projection per input representation.
		self.input_projection = nn.ModuleList(
			[Linear(repr_dim, num_channel) for _ in range(num_repr)])

		# Each residual block holds num_residual_block layers: all but the last use
		# the 'relu' initializer, the last one uses 'final'.
		first_block = [Linear(num_channel, num_channel, initializer='relu')
						for _ in range(relu_layer_count)]
		first_block.append(Linear(num_channel, num_channel, initializer='final'))
		self.resblock1 = nn.ModuleList(first_block)

		second_block = [Linear(num_channel, num_channel, initializer='relu')
						for _ in range(relu_layer_count)]
		second_block.append(Linear(num_channel, num_channel, initializer='final'))
		self.resblock2 = nn.ModuleList(second_block)

		# 14 outputs -- presumably 7 angles as (sin, cos) pairs; confirm in forward.
		self.unnormalized_angles = Linear(num_channel, 14)
		self.relu = nn.ReLU()
Example #16
0
    def __init__(self, config, global_config, pair_dim: int) -> None:
        """Create the norms, separate left/right projections and gates, and output layers.

        Args:
            config: Module config; ``num_intermediate_channel`` is read here.
            global_config: Stored on the instance; not otherwise used by the constructor.
            pair_dim: Width of the pair features at the input and output.
        """
        super().__init__()
        self.config = config
        self.global_config = global_config

        mid_channel = config.num_intermediate_channel

        self.layer_norm_input = nn.LayerNorm(pair_dim)
        # Left and right branches each get their own projection and gate.
        self.left_projection = Linear(pair_dim, mid_channel)
        self.right_projection = Linear(pair_dim, mid_channel)
        self.left_gate = Linear(pair_dim, mid_channel, initializer='gating')
        self.right_gate = Linear(pair_dim, mid_channel, initializer='gating')
        self.sigmoid = nn.Sigmoid()
        self.center_layer_norm = nn.LayerNorm(mid_channel)
        self.output_projection = Linear(mid_channel, pair_dim,
                                        initializer='final')
        self.gating_linear = Linear(pair_dim, pair_dim, initializer='gating')
Example #17
0
	def __init__(self, config, global_config, msa_dim: int):
		"""Create the linear layer that embeds extra-MSA features.

		Args:
			config: Module config; ``extra_msa_channel`` is the output width.
			global_config: Stored on the instance; not otherwise used by the constructor.
			msa_dim: Input width of the projection.
		"""
		super().__init__()
		self.config = config
		self.global_config = global_config
		out_channel = config.extra_msa_channel
		self.extra_msa_activations = Linear(msa_dim, out_channel)