def forward(self,
            pair_act: torch.Tensor,
            pair_mask: torch.Tensor,
            is_training: bool = False) -> torch.Tensor:
    assert pair_act.ndimension() == 3
    assert pair_mask.ndimension() == 2

    if self.config.orientation == 'per_column':
        pair_act = pair_act.transpose(-2, -3)
        pair_mask = pair_mask.transpose(-1, -2)

    # Convert the 0/1 mask into an additive bias: 1 -> 0, 0 -> -1e9.
    bias = (1e9 * (pair_mask.to(dtype=pair_act.dtype) - 1.0))[:, None, None, :]
    assert bias.ndimension() == 4

    pair_act = self.query_norm(pair_act)
    nonbatched_bias = self.feat_2d_weights(pair_act)
    nonbatched_bias = permute_final_dims(nonbatched_bias, (2, 0, 1))

    # Full (non-subbatched) equivalent:
    # pair_act = self.attn(pair_act, pair_act, bias, nonbatched_bias)
    pair_act = inference_subbatch(
        self.attn,
        self.global_config.subbatch_size,
        batched_args=[pair_act, pair_act, bias],
        nonbatched_args=[nonbatched_bias],
        low_memory=(not is_training))

    if self.config.orientation == 'per_column':
        pair_act = pair_act.transpose(-2, -3)
    return pair_act
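# Hedged sketch (not part of the module above): illustrates how the 0/1 pair mask becomes
# an additive attention bias and how a per-head pair bias of shape
# [num_head, num_queries, num_keys] is produced. All names, shapes, and the stand-in
# Linear layer here are illustrative assumptions, not the module's actual attributes.
def _mask_to_bias_demo():
    import torch
    n_res, c_z, num_head = 4, 8, 2
    pair_act = torch.randn(n_res, n_res, c_z)
    pair_mask = torch.tensor([[1., 1., 1., 0.]]).repeat(n_res, 1)  # last key masked out

    # 1 -> 0, 0 -> -1e9; broadcastable over the (head, query) dims of the logits.
    bias = (1e9 * (pair_mask - 1.0))[:, None, None, :]
    assert bias.shape == (n_res, 1, 1, n_res)

    # Stand-in for feat_2d_weights: Linear(c_z -> num_head), then move the head
    # channel in front, giving [num_head, n_res, n_res].
    feat_2d = torch.nn.Linear(c_z, num_head)
    nonbatched_bias = feat_2d(pair_act).permute(2, 0, 1)
    assert nonbatched_bias.shape == (num_head, n_res, n_res)

    # Masked keys receive (numerically) zero attention weight after the softmax.
    logits = torch.randn(n_res, num_head, n_res, n_res) + bias
    weights = torch.softmax(logits, dim=-1)
    assert torch.allclose(weights[..., -1], torch.zeros_like(weights[..., -1]))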
def forward(self,
            msa_act: torch.Tensor,
            msa_mask: torch.Tensor,
            pair_act: torch.Tensor,
            is_training: bool = False):
    assert msa_act.ndimension() == 3
    assert msa_mask.ndimension() == 2
    assert self.config.orientation == 'per_row'

    bias = (1e9 * (msa_mask.to(dtype=msa_act.dtype) - 1.0))[:, None, None, :]
    msa_act = self.query_norm(msa_act)
    pair_act = self.feat_2d_norm(pair_act)
    nonbatched_bias = self.feat_2d_weights(pair_act)
    nonbatched_bias = permute_final_dims(nonbatched_bias, (2, 0, 1))

    # msa_act = self.attn(msa_act, msa_act, bias, nonbatched_bias)
    msa_act = inference_subbatch(
        self.attn,
        self.global_config.subbatch_size,
        batched_args=[msa_act, msa_act, bias],
        nonbatched_args=[nonbatched_bias],
        low_memory=(not is_training))
    return msa_act
def forward(self,
            pair_act_raw: torch.Tensor,
            pair_mask: torch.Tensor,
            is_training: bool = False) -> torch.Tensor:
    pair_mask = pair_mask[..., None]
    input_act = self.layer_norm_input(pair_act_raw)

    left_right_proj_act = self.left_right_projection(input_act)
    left_right_proj_act = left_right_proj_act * pair_mask
    left_right_proj_act *= self.sigmoid(self.left_right_gate(input_act))
    left_proj_act, right_proj_act = left_right_proj_act.chunk(2, dim=-1)

    if self.config.equation == 'ikc,jkc->ijc':  # triangle_multiplication_outgoing
        left_proj_act = permute_final_dims(left_proj_act, (2, 0, 1))
        right_proj_act = permute_final_dims(right_proj_act, (2, 1, 0))
        act = torch.matmul(left_proj_act, right_proj_act)
        act = permute_final_dims(act, (1, 2, 0))
    elif self.config.equation == 'kjc,kic->ijc':  # triangle_multiplication_incoming
        left_proj_act = permute_final_dims(left_proj_act, (2, 1, 0))
        right_proj_act = permute_final_dims(right_proj_act, (2, 0, 1))
        act = torch.matmul(left_proj_act, right_proj_act)
        act = permute_final_dims(act, (2, 1, 0))
    else:
        raise ValueError(f"Unknown equation: {self.config.equation}")

    act = self.center_layer_norm(act)
    act = self.output_projection(act)
    gate_values = self.sigmoid(self.gating_linear(input_act))
    dropout_mask = torch.ones_like(act, device=act.device, dtype=act.dtype)
    return bias_ele_dropout_residual(act,
                                     self.out_bias,
                                     gate_values,
                                     dropout_mask,
                                     pair_act_raw,
                                     prob=self.config.dropout_rate,
                                     training=is_training)
def forward(self,
            pair_act: torch.Tensor,
            pair_mask: torch.Tensor,
            is_training: bool = False) -> torch.Tensor:
    pair_mask = pair_mask[..., None]
    input_act = self.layer_norm_input(pair_act)

    left_proj_act = self.left_projection(input_act) * pair_mask
    right_proj_act = self.right_projection(input_act) * pair_mask
    left_gate_values = self.sigmoid(self.left_gate(input_act))
    right_gate_values = self.sigmoid(self.right_gate(input_act))
    left_proj_act *= left_gate_values
    right_proj_act *= right_gate_values

    if self.config.equation == 'ikc,jkc->ijc':  # triangle_multiplication_outgoing
        left_proj_act = permute_final_dims(left_proj_act, (2, 0, 1))
        right_proj_act = permute_final_dims(right_proj_act, (2, 1, 0))
        act = torch.matmul(left_proj_act, right_proj_act)
        act = permute_final_dims(act, (1, 2, 0))
    elif self.config.equation == 'kjc,kic->ijc':  # triangle_multiplication_incoming
        left_proj_act = permute_final_dims(left_proj_act, (2, 1, 0))
        right_proj_act = permute_final_dims(right_proj_act, (2, 0, 1))
        act = torch.matmul(left_proj_act, right_proj_act)
        act = permute_final_dims(act, (2, 1, 0))
    else:
        raise ValueError(f"Unknown equation: {self.config.equation}")

    act = self.center_layer_norm(act)
    act = self.output_projection(act)
    gate_values = self.sigmoid(self.gating_linear(input_act))
    act *= gate_values
    return act
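# Hedged sketch (illustrative only): checks that the permute + matmul route used by both
# triangle-multiplication variants above reproduces the einsum named in config.equation.
# Plain tensor permutes stand in for permute_final_dims; shapes are toy assumptions.
def _triangle_mult_equivalence_demo():
    import torch
    n_res, c = 5, 3
    left = torch.randn(n_res, n_res, c)
    right = torch.randn(n_res, n_res, c)

    # Outgoing: 'ikc,jkc->ijc'
    ref_out = torch.einsum('ikc,jkc->ijc', left, right)
    out = torch.matmul(left.permute(2, 0, 1),   # (c, i, k)
                       right.permute(2, 1, 0))  # (c, k, j)
    out = out.permute(1, 2, 0)                  # (i, j, c)
    assert torch.allclose(ref_out, out, atol=1e-5)

    # Incoming: 'kjc,kic->ijc'
    ref_in = torch.einsum('kjc,kic->ijc', left, right)
    inc = torch.matmul(left.permute(2, 1, 0),   # (c, j, k)
                       right.permute(2, 0, 1))  # (c, k, i)
    inc = inc.permute(2, 1, 0)                  # (i, j, c)
    assert torch.allclose(ref_in, inc, atol=1e-5)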
def forward(self, inputs_1d: torch.Tensor, inputs_2d: torch.Tensor,
            mask: torch.Tensor, affine) -> torch.Tensor:
    assert inputs_1d.ndimension() == 3
    assert inputs_2d.ndimension() == 4
    assert affine.translation[0].ndimension() == 2
    assert mask.ndimension() == 3
    batch_size = inputs_1d.size(0)
    num_res = inputs_1d.size(1)

    q_scalar = self.q_scalar(inputs_1d)
    q_scalar = q_scalar.view(batch_size, num_res, self.num_head, self.num_scalar_qk)

    affine.cast_to(dtype=torch.float32)  # All affine operations in float32

    kv_scalar = self.kv_scalar(inputs_1d)
    kv_scalar = kv_scalar.view(batch_size, num_res, self.num_head,
                               self.num_scalar_v + self.num_scalar_qk)
    k_scalar, v_scalar = kv_scalar.split(self.num_scalar_qk, dim=-1)

    # Query points: project, lift to the global frame (in float32), then split per head.
    q_point_local = self.q_point_local(inputs_1d)
    q_point_local = q_point_local.split(self.num_head * self.num_point_qk, dim=-1)
    # Float32 region
    q_point_local = [res.to(dtype=torch.float32) for res in q_point_local]
    q_point_global = affine.apply_to_point(q_point_local, extra_dims=1)
    q_point_global = [res.to(dtype=q_scalar.dtype) for res in q_point_global]
    # Float32 region
    q_point = [
        x.view(batch_size, num_res, self.num_head, self.num_point_qk)
        for x in q_point_global
    ]

    # Key/value points: same treatment, then split into key and value parts.
    kv_point_local = self.kv_point_local(inputs_1d)
    kv_point_local = kv_point_local.split(
        self.num_head * (self.num_point_qk + self.num_point_v), dim=-1)
    # Float32 region
    kv_point_local = [res.to(dtype=torch.float32) for res in kv_point_local]
    kv_point_global = affine.apply_to_point(kv_point_local, extra_dims=1)
    kv_point_global = [res.to(dtype=k_scalar.dtype) for res in kv_point_global]
    # Float32 region
    kv_point_global = [
        x.view(batch_size, num_res, self.num_head,
               self.num_point_qk + self.num_point_v) for x in kv_point_global
    ]
    k_point, v_point = list(
        zip(*[(x[..., :self.num_point_qk], x[..., self.num_point_qk:])
              for x in kv_point_global]))

    point_weights = self.softplus(
        self.trainable_point_weights).unsqueeze(dim=1) * self.point_weights
    v_point = [x.transpose(-2, -3) for x in v_point]
    q_point = [x.transpose(-2, -3) for x in q_point]
    k_point = [x.transpose(-2, -3) for x in k_point]

    # Squared query-key point distances in the global frame, summed over x, y, z.
    dist2 = [
        torch.pow(qx[..., :, None, :] - kx[..., None, :, :], 2)
        for qx, kx in zip(q_point, k_point)
    ]
    dist2 = sum(dist2)
    attn_qk_point = -0.5 * torch.sum(
        point_weights[..., None, None, :] * dist2, dim=-1)

    v = v_scalar.transpose(-2, -3)
    q = (self.scalar_weights * q_scalar).transpose(-2, -3)
    k = k_scalar.transpose(-2, -3)
    attn_qk_scalar = torch.matmul(q, k.transpose(-2, -1))
    attn_logits = attn_qk_scalar + attn_qk_point

    attention_2d = self.attention_2d(inputs_2d)
    attn_logits += permute_final_dims(attention_2d, (2, 0, 1)) * float(
        self.attention_2d_weights)
    mask_2d = mask * (mask.transpose(-1, -2))
    attn_logits -= 1e5 * (1.0 - mask_2d[..., None, :, :])

    attn = self.softmax(attn_logits)
    result_scalar = torch.matmul(attn, v).transpose(-2, -3)
    result_point_global = [
        torch.sum(attn[..., None] * vx[..., None, :, :], dim=-2).transpose(-2, -3)
        for vx in v_point
    ]

    output_features = []
    result_scalar = result_scalar.reshape(batch_size, num_res,
                                          self.num_head * self.num_scalar_v)
    output_features.append(result_scalar)

    result_point_global = [
        r.reshape(batch_size, num_res, self.num_head * self.num_point_v)
        for r in result_point_global
    ]

    # vvv Float32 region (geometry region)
    # affine.cast_to(dtype=torch.float32)
    result_point_global = [
        res.to(dtype=torch.float32) for res in result_point_global
    ]
    result_point_local = affine.invert_point(result_point_global, extra_dims=1)
    dist = torch.sqrt(self._dist_epsilon +
                      torch.pow(result_point_local[0], 2) +
                      torch.pow(result_point_local[1], 2) +
                      torch.pow(result_point_local[2], 2))
    dist = dist.to(dtype=result_scalar.dtype)
    affine.cast_to(dtype=result_scalar.dtype)
    result_point_local = [
        res.to(dtype=result_scalar.dtype) for res in result_point_local
    ]
    # ^^^ Float32 region

    output_features.extend(result_point_local)
    output_features.append(dist)

    result_attention_over_2d = torch.einsum('bhij,bijc->bihc', attn, inputs_2d)
    num_out = self.num_head * result_attention_over_2d.shape[-1]
    output_features.append(
        result_attention_over_2d.view(batch_size, num_res, num_out))

    final_act = torch.cat(output_features, dim=-1)
    return self.output_pojection(final_act)
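# Hedged sketch (illustrative only): the point term of the IPA logits above is a weighted,
# negated squared distance between query and key points in the global frame. The
# per-coordinate list formulation is checked against an equivalent computation on stacked
# xyz coordinates; all shapes here are toy assumptions, not the module's configuration.
def _ipa_point_term_demo():
    import torch
    batch, num_head, num_res, num_point_qk = 1, 2, 6, 4
    # One tensor of shape (batch, num_head, num_res, num_point_qk) per xyz coordinate,
    # matching q_point / k_point after the transpose(-2, -3) above.
    q_point = [torch.randn(batch, num_head, num_res, num_point_qk) for _ in range(3)]
    k_point = [torch.randn(batch, num_head, num_res, num_point_qk) for _ in range(3)]
    point_weights = torch.rand(num_head, 1)  # per-head weight, as in the code above

    dist2 = sum(torch.pow(qx[..., :, None, :] - kx[..., None, :, :], 2)
                for qx, kx in zip(q_point, k_point))
    attn_qk_point = -0.5 * torch.sum(point_weights[..., None, None, :] * dist2, dim=-1)

    # Same quantity with xyz stacked on a trailing axis.
    q_xyz = torch.stack(q_point, dim=-1)  # (batch, head, res, point, 3)
    k_xyz = torch.stack(k_point, dim=-1)
    d2 = torch.sum((q_xyz[..., :, None, :, :] - k_xyz[..., None, :, :, :]) ** 2, dim=-1)
    ref = -0.5 * torch.sum(point_weights[..., None, None, :] * d2, dim=-1)
    assert torch.allclose(attn_qk_point, ref, atol=1e-5)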
def forward(self,
            q_data: torch.Tensor,
            m_data: torch.Tensor,
            bias: torch.Tensor,
            nonbatched_bias: torch.Tensor = None) -> torch.Tensor:
    """
    Arguments:
        q_data: [batch_size, num_queries, query_dim]
        m_data: [batch_size, num_keys, value_dim]
        bias: broadcastable to [batch_size, num_head, num_queries, num_keys]
        nonbatched_bias: [num_head, num_queries, num_keys]
    Returns:
        [batch_size, num_queries, output_dim]
    """
    # Split the last dimension into (num_head, head_dim).
    flat_head = lambda t: t.view(t.shape[:-1] + (self.num_head, -1))
    assert self.key_dim * self.num_head == q_data.size(-1)
    assert self.value_dim * self.num_head == m_data.size(-1)

    q = self.q_weights(q_data) * (1. / math.sqrt(self.key_dim))
    q = flat_head(q)
    k = self.k_weights(m_data)
    k = flat_head(k)
    v = self.v_weights(m_data)
    v = flat_head(v)

    if self.q_chunk_size is not None:
        # Low-memory path: process queries in chunks.
        weighted_avg = self.iterative_qkv(q, k, v, bias, nonbatched_bias)
    else:
        # High-memory path: materialise the full attention matrix.
        q = permute_final_dims(q, (1, 0, 2))
        k = permute_final_dims(k, (1, 2, 0))
        logits = torch.matmul(q, k) + bias
        del q, k
        if nonbatched_bias is not None:
            logits += nonbatched_bias.unsqueeze(dim=0)
        weights = self.softmax(logits)
        v = permute_final_dims(v, (1, 0, 2))
        weighted_avg = torch.matmul(weights, v).transpose(-2, -3)

    if self.config.gating:
        gate_values = self.gating_linear(q_data)
        gate_values = self.sigmoid(gate_values)
        gate_values = flat_head(gate_values)
        weighted_avg *= gate_values

    weighted_avg = flatten_final_dims(weighted_avg, 2)
    output = self.o_linear(weighted_avg)
    return output
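# Hedged sketch (illustrative only): a plain einsum reference for the high-memory branch
# above, i.e. logits = q·k / sqrt(key_dim) + bias (+ nonbatched_bias), softmax over keys,
# then a weighted sum over values. Tensor names, shapes, and the zero/random biases are
# assumptions for the demo, not the module's API.
def _attention_reference_demo():
    import math
    import torch
    batch, num_q, num_k, num_head, key_dim, value_dim = 2, 5, 7, 3, 4, 4
    q = torch.randn(batch, num_q, num_head, key_dim) * (1.0 / math.sqrt(key_dim))
    k = torch.randn(batch, num_k, num_head, key_dim)
    v = torch.randn(batch, num_k, num_head, value_dim)
    bias = torch.zeros(batch, 1, 1, num_k)                 # broadcast over heads/queries
    nonbatched_bias = torch.randn(num_head, num_q, num_k)  # e.g. the pair bias

    # Reference einsum formulation.
    logits = torch.einsum('bqhc,bkhc->bhqk', q, k) + bias + nonbatched_bias.unsqueeze(0)
    weights = torch.softmax(logits, dim=-1)
    ref = torch.einsum('bhqk,bkhc->bqhc', weights, v)

    # Same computation via the permute + matmul route used in the module above.
    logits2 = torch.matmul(q.permute(0, 2, 1, 3), k.permute(0, 2, 3, 1)) + bias
    logits2 = logits2 + nonbatched_bias.unsqueeze(0)
    out = torch.matmul(torch.softmax(logits2, dim=-1), v.permute(0, 2, 1, 3))
    assert torch.allclose(ref, out.transpose(-2, -3), atol=1e-5)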