Example 1
    def forward(self,
                pair_act: torch.Tensor,
                pair_mask: torch.Tensor,
                is_training: bool = False) -> torch.Tensor:
        """Triangle self-attention over the pair activations.

        Args:
            pair_act: pair activations, expected 3-dimensional.
            pair_mask: pair mask, expected 2-dimensional.
            is_training: when False, the attention runs in low-memory
                subbatched mode.

        Returns:
            Updated pair activations with the same layout as the input.
        """
        assert pair_act.ndimension() == 3
        assert pair_mask.ndimension() == 2

        per_column = self.config.orientation == 'per_column'
        if per_column:
            # Attend along the other residue axis by swapping it in.
            pair_act = pair_act.transpose(-2, -3)
            pair_mask = pair_mask.transpose(-1, -2)

        # Masked positions get a large negative additive bias (mask==0 -> -1e9).
        mask_bias = 1e9 * (pair_mask.to(dtype=pair_act.dtype) - 1.0)
        bias = mask_bias[:, None, None, :]
        assert bias.ndimension() == 4

        pair_act = self.query_norm(pair_act)
        # Per-pair bias shared across the batch, moved to [c, i, j] layout.
        pair_bias = permute_final_dims(self.feat_2d_weights(pair_act),
                                       (2, 0, 1))
        pair_act = inference_subbatch(self.attn,
                                      self.global_config.subbatch_size,
                                      batched_args=[pair_act, pair_act, bias],
                                      nonbatched_args=[pair_bias],
                                      low_memory=(not is_training))

        if per_column:
            # Restore the original axis order.
            pair_act = pair_act.transpose(-2, -3)
        return pair_act
Example 2
    def forward(self,
                msa_act: torch.Tensor,
                msa_mask: torch.Tensor,
                pair_act: torch.Tensor,
                is_training: bool = False):
        """MSA row attention with an additive bias derived from the pair stack.

        Args:
            msa_act: MSA activations, expected 3-dimensional.
            msa_mask: MSA mask, expected 2-dimensional.
            pair_act: pair activations used to produce the attention bias.
            is_training: when False, the attention runs in low-memory
                subbatched mode.

        Returns:
            Updated MSA activations.
        """
        assert msa_act.ndimension() == 3
        assert msa_mask.ndimension() == 2
        # This module only supports row-wise attention.
        assert self.config.orientation == 'per_row'

        # Masked positions get a large negative additive bias (mask==0 -> -1e9).
        mask_bias = 1e9 * (msa_mask.to(dtype=msa_act.dtype) - 1.0)
        bias = mask_bias[:, None, None, :]

        msa_act = self.query_norm(msa_act)
        # Project the normalised pair activations into a per-head bias,
        # shared across the batch, in [c, i, j] layout.
        pair_bias = self.feat_2d_weights(self.feat_2d_norm(pair_act))
        pair_bias = permute_final_dims(pair_bias, (2, 0, 1))

        return inference_subbatch(self.attn,
                                  self.global_config.subbatch_size,
                                  batched_args=[msa_act, msa_act, bias],
                                  nonbatched_args=[pair_bias],
                                  low_memory=(not is_training))
Example 3
    def forward(self,
                pair_act_raw: torch.Tensor,
                pair_mask: torch.Tensor,
                is_training: bool = False) -> torch.Tensor:
        """Triangle multiplication with a fused gated, biased residual output.

        Args:
            pair_act_raw: raw (un-normalised) pair activations; also the
                residual input.
            pair_mask: pair mask, broadcast over the channel axis.
            is_training: forwarded to the fused dropout/residual op.

        Returns:
            Pair activations after the triangle update and residual add.

        Raises:
            ValueError: if ``self.config.equation`` is not one of the two
                supported einsum patterns.
        """
        mask = pair_mask[..., None]
        input_act = self.layer_norm_input(pair_act_raw)

        # Joint left/right projection, masked then gated, then split in two.
        projected = self.left_right_projection(input_act) * mask
        projected = projected * self.sigmoid(self.left_right_gate(input_act))
        left, right = projected.chunk(2, dim=-1)

        if self.config.equation == 'ikc,jkc->ijc':
            # triangle_multiplication_outgoing
            act = torch.matmul(permute_final_dims(left, (2, 0, 1)),
                               permute_final_dims(right, (2, 1, 0)))
            act = permute_final_dims(act, (1, 2, 0))
        elif self.config.equation == 'kjc,kic->ijc':
            # triangle_multiplication_incoming
            act = torch.matmul(permute_final_dims(left, (2, 1, 0)),
                               permute_final_dims(right, (2, 0, 1)))
            act = permute_final_dims(act, (2, 1, 0))
        else:
            raise ValueError(f"Unknown equation: {self.config.equation}")

        act = self.output_projection(self.center_layer_norm(act))
        gate_values = self.sigmoid(self.gating_linear(input_act))

        # All-ones mask: presumably dropout is sampled inside the fused op
        # from `prob` — confirm against bias_ele_dropout_residual.
        dropout_mask = torch.ones_like(act, device=act.device, dtype=act.dtype)
        return bias_ele_dropout_residual(act,
                                         self.out_bias,
                                         gate_values,
                                         dropout_mask,
                                         pair_act_raw,
                                         prob=self.config.dropout_rate,
                                         training=is_training)
Example 4
    def forward(self,
                pair_act: torch.Tensor,
                pair_mask: torch.Tensor,
                is_training: bool = False) -> torch.Tensor:
        """Triangle multiplication; returns the gated update (no residual).

        Args:
            pair_act: pair activations.
            pair_mask: pair mask, broadcast over the channel axis.
            is_training: unused here; kept for interface parity.

        Returns:
            The gated triangle-multiplication update.

        Raises:
            ValueError: if ``self.config.equation`` is not one of the two
                supported einsum patterns.
        """
        mask = pair_mask[..., None]
        input_act = self.layer_norm_input(pair_act)

        # Separate left/right projections, each masked then gated.
        left = self.left_projection(input_act) * mask
        left = left * self.sigmoid(self.left_gate(input_act))
        right = self.right_projection(input_act) * mask
        right = right * self.sigmoid(self.right_gate(input_act))

        if self.config.equation == 'ikc,jkc->ijc':
            # triangle_multiplication_outgoing
            act = torch.matmul(permute_final_dims(left, (2, 0, 1)),
                               permute_final_dims(right, (2, 1, 0)))
            act = permute_final_dims(act, (1, 2, 0))
        elif self.config.equation == 'kjc,kic->ijc':
            # triangle_multiplication_incoming
            act = torch.matmul(permute_final_dims(left, (2, 1, 0)),
                               permute_final_dims(right, (2, 0, 1)))
            act = permute_final_dims(act, (2, 1, 0))
        else:
            raise ValueError(f"Unknown equation: {self.config.equation}")

        act = self.output_projection(self.center_layer_norm(act))
        # Final output gate computed from the normalised input.
        return act * self.sigmoid(self.gating_linear(input_act))
Example 5
    def forward(self, inputs_1d: torch.Tensor, inputs_2d: torch.Tensor,
                mask: torch.Tensor, affine) -> torch.Tensor:
        """Invariant point attention over scalar, point and pair features.

        Args:
            inputs_1d: [batch, num_res, channels] per-residue activations.
            inputs_2d: [batch, num_res, num_res, channels] pair activations.
            mask: [batch, num_res, 1] residue mask (3-dimensional).
            affine: frame object exposing ``apply_to_point``,
                ``invert_point`` and ``cast_to``; NOTE it is mutated in
                place by the ``cast_to`` calls below.

        Returns:
            [batch, num_res, out_channels] output projection of the
            concatenated attention results.
        """
        assert inputs_1d.ndimension() == 3
        assert inputs_2d.ndimension() == 4
        assert affine.translation[0].ndimension() == 2
        assert mask.ndimension() == 3

        batch_size = inputs_1d.size(0)
        num_res = inputs_1d.size(1)
        # Scalar queries, split per head.
        q_scalar = self.q_scalar(inputs_1d)
        q_scalar = q_scalar.view(batch_size, num_res, self.num_head,
                                 self.num_scalar_qk)

        affine.cast_to(dtype=torch.float32)  #All affine operations to float32

        # Scalar keys/values come from one projection and are split apart.
        kv_scalar = self.kv_scalar(inputs_1d)
        kv_scalar = kv_scalar.view(batch_size, num_res, self.num_head,
                                   self.num_scalar_v + self.num_scalar_qk)
        k_scalar, v_scalar = kv_scalar.split(self.num_scalar_qk, dim=-1)

        # Point queries: one chunk per coordinate (x, y, z).
        q_point_local = self.q_point_local(inputs_1d)
        q_point_local = q_point_local.split(self.num_head * self.num_point_qk,
                                            dim=-1)

        #Float32 region: frame application happens in float32, then results
        #are cast back to the activation dtype.
        q_point_local = [res.to(dtype=torch.float32) for res in q_point_local]
        q_point_global = affine.apply_to_point(q_point_local, extra_dims=1)
        q_point_global = [
            res.to(dtype=q_scalar.dtype) for res in q_point_global
        ]
        #Float32 region

        q_point = [
            x.view(batch_size, num_res, self.num_head, self.num_point_qk)
            for x in q_point_global
        ]

        # Point keys/values, again one chunk per coordinate.
        kv_point_local = self.kv_point_local(inputs_1d)
        kv_point_local = kv_point_local.split(
            self.num_head * (self.num_point_qk + self.num_point_v), dim=-1)

        #Float32 region
        kv_point_local = [
            res.to(dtype=torch.float32) for res in kv_point_local
        ]
        kv_point_global = affine.apply_to_point(kv_point_local, extra_dims=1)
        kv_point_global = [
            res.to(dtype=k_scalar.dtype) for res in kv_point_global
        ]
        #Float32 region

        kv_point_global = [
            x.view(batch_size, num_res, self.num_head,
                   self.num_point_qk + self.num_point_v)
            for x in kv_point_global
        ]
        # Split each coordinate into its key part and value part.
        k_point, v_point = list(
            zip(*[(x[..., :self.num_point_qk], x[..., self.num_point_qk:])
                  for x in kv_point_global]))

        # Learned, softplus-constrained per-head weights for the point term.
        point_weights = self.softplus(
            self.trainable_point_weights).unsqueeze(dim=1) * self.point_weights
        # Move the head axis in front of the residue axis for attention.
        v_point = [x.transpose(-2, -3) for x in v_point]
        q_point = [x.transpose(-2, -3) for x in q_point]
        k_point = [x.transpose(-2, -3) for x in k_point]
        # Squared distance between every query/key point pair, summed over
        # the three coordinates.
        dist2 = [
            torch.pow(qx[..., :, None, :] - kx[..., None, :, :], 2)
            for qx, kx in zip(q_point, k_point)
        ]
        dist2 = sum(dist2)

        # Closer points attend more strongly (negative weighted distance).
        attn_qk_point = -0.5 * torch.sum(
            point_weights[..., None, None, :] * dist2, dim=-1)

        v = v_scalar.transpose(-2, -3)
        q = (self.scalar_weights * q_scalar).transpose(-2, -3)
        k = k_scalar.transpose(-2, -3)

        # Combine scalar QK logits, point-distance logits and the 2D bias.
        attn_qk_scalar = torch.matmul(q, k.transpose(-2, -1))
        attn_logits = attn_qk_scalar + attn_qk_point
        attention_2d = self.attention_2d(inputs_2d)
        attn_logits += permute_final_dims(attention_2d, (2, 0, 1)) * float(
            self.attention_2d_weights)

        # Pairwise mask from the outer product of the residue mask.
        mask_2d = mask * (mask.transpose(-1, -2))
        attn_logits -= 1e5 * (1.0 - mask_2d[..., None, :, :])

        attn = self.softmax(attn_logits)
        # Scalar output per head, back to [batch, res, head, channel].
        result_scalar = torch.matmul(attn, v).transpose(-2, -3)
        # Point output per coordinate: attention-weighted sum of value points.
        result_point_global = [
            torch.sum(attn[..., None] * vx[..., None, :, :],
                      dim=-2).transpose(-2, -3) for vx in v_point
        ]

        output_features = []
        result_scalar = result_scalar.reshape(
            batch_size, num_res, self.num_head * self.num_scalar_v)
        output_features.append(result_scalar)

        result_point_global = [
            r.reshape(batch_size, num_res, self.num_head * self.num_point_v)
            for r in result_point_global
        ]

        ## VVV Float32 region (geometry region)
        # affine.cast_to(dtype=torch.float32)
        result_point_global = [
            res.to(dtype=torch.float32) for res in result_point_global
        ]

        # Map the attended points back into each residue's local frame.
        result_point_local = affine.invert_point(result_point_global,
                                                 extra_dims=1)
        # Norm of each local point; epsilon keeps sqrt differentiable at 0.
        dist = torch.sqrt(self._dist_epsilon +
                          torch.pow(result_point_local[0], 2) +
                          torch.pow(result_point_local[1], 2) +
                          torch.pow(result_point_local[2], 2))

        dist = dist.to(dtype=result_scalar.dtype)
        affine.cast_to(dtype=result_scalar.dtype)
        result_point_local = [
            res.to(dtype=result_scalar.dtype) for res in result_point_local
        ]
        # ^^^ Float32 region

        output_features.extend(result_point_local)
        output_features.append(dist)

        # Attention-weighted pooling of the pair activations per head.
        result_attention_over_2d = torch.einsum('bhij,bijc->bihc', attn,
                                                inputs_2d)
        num_out = self.num_head * result_attention_over_2d.shape[-1]
        output_features.append(
            result_attention_over_2d.view(batch_size, num_res, num_out))
        final_act = torch.cat(output_features, dim=-1)

        # NOTE(review): 'output_pojection' looks like a typo for
        # 'output_projection', but it must match the attribute declared in
        # __init__ — confirm there before renaming.
        return self.output_pojection(final_act)
Example 6
    def forward(self,
                q_data: torch.Tensor,
                m_data: torch.Tensor,
                bias: torch.Tensor,
                nonbatched_bias: torch.Tensor = None) -> torch.Tensor:
        """Multi-head attention with an optional output gate.

        Arguments:
            q_data: [batch_size, num_queries, query_dim]
            m_data: [batch_size, num_keys, value_dim]
            bias: [batch_size, num_queries, num_keys]
            nonbatched_bias: [num_queries, num_keys], shared across the batch
        Returns:
            [batch_size, num_queries, output_dim]
        """
        assert self.key_dim * self.num_head == q_data.size(-1)
        assert self.value_dim * self.num_head == m_data.size(-1)

        def split_heads(t):
            # [..., num_head * c] -> [..., num_head, c]
            return t.view(t.shape[:-1] + (self.num_head, -1))

        # Scale queries by 1/sqrt(d_k) before splitting into heads.
        q = split_heads(self.q_weights(q_data) * (1. / math.sqrt(self.key_dim)))
        k = split_heads(self.k_weights(m_data))
        v = split_heads(self.v_weights(m_data))

        if self.q_chunk_size is not None:
            # Low-memory path: attention computed in query chunks.
            weighted_avg = self.iterative_qkv(q, k, v, bias, nonbatched_bias)
        else:
            # High-memory path: materialise the full [b, h, q, k] logits.
            logits = torch.matmul(permute_final_dims(q, (1, 0, 2)),
                                  permute_final_dims(k, (1, 2, 0))) + bias
            del q, k  # free the largest intermediates early

            if nonbatched_bias is not None:
                # Broadcast the shared bias over the batch axis.
                logits += nonbatched_bias.unsqueeze(dim=0)

            weights = self.softmax(logits)
            weighted_avg = torch.matmul(
                weights, permute_final_dims(v, (1, 0, 2))).transpose(-2, -3)

        if self.config.gating:
            # Sigmoid gate computed from the raw queries, applied per head.
            gate = split_heads(self.sigmoid(self.gating_linear(q_data)))
            weighted_avg = weighted_avg * gate

        return self.o_linear(flatten_final_dims(weighted_avg, 2))