Example #1
 def create_loss(self, prediction, config):
     pre, pos_r, q_emb, p_emb, H_i_emb = prediction
     weight = config.get('hyper_parameters.negative_weight', 0.5)
     loss = weight * paddle.sum(
         paddle.sum(
             paddle.sum(paddle.einsum('ab,ac->abc', q_emb, q_emb), 0) *
             paddle.sum(paddle.einsum('ab,ac->abc', p_emb, p_emb), 0) *
             paddle.matmul(H_i_emb, H_i_emb, transpose_y=True), 0), 0)
     loss += paddle.sum((1.0 - weight) * paddle.square(pos_r) - 2.0 * pos_r)
     return loss
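
Note: as a quick reference for the contraction above, here is a minimal sketch (tensor names and sizes are hypothetical) of what 'ab,ac->abc' produces; summing the result over the batch axis gives the Gram-style term used in this loss.

import paddle

x = paddle.rand([4, 3])                    # hypothetical embeddings: 4 rows, 3 features
outer = paddle.einsum('ab,ac->abc', x, x)  # one 3x3 outer product per row
print(outer.shape)                         # [4, 3, 3]
gram = paddle.sum(outer, 0)                # batch sum = X^T X, the Gram-style factor
print(gram.shape)                          # [3, 3]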
Example #2
    def test_static_graph(self):
        paddle.enable_static()
        fluid = paddle.fluid
        if fluid.core.is_compiled_with_cuda():
            self.place = fluid.CUDAPlace(0)
        else:
            self.place = fluid.CPUPlace()
        main = fluid.Program()
        startup = fluid.Program()
        with fluid.program_guard(main, startup):
            a = paddle.static.data(name='a',
                                   shape=[3, None, None, None],
                                   dtype='float')
            b = paddle.static.data(name='b',
                                   shape=[2, None, None, None],
                                   dtype='float')
            c = paddle.static.data(name='c',
                                   shape=[None, None, 2, None],
                                   dtype='float')
            d = paddle.static.data(name='d',
                                   shape=[None, None, 5],
                                   dtype='float')
            e = paddle.static.data(name='e',
                                   shape=[None, 2, None],
                                   dtype='float')

            outs = []
            outs.append(paddle.einsum("ibnd,jbnd->bnij", a, b))
            outs.append(paddle.einsum('...ik, ...j', c, d))
            outs.append(paddle.einsum('...kj, ...ik', d, e))
            outs.append(paddle.einsum('ijk..., ikj', c, e))
            outs.append(paddle.einsum('ijk..., ikj->...ij', c, e))
        exe = fluid.Executor(self.place)
        exe.run(startup)
        a = np.arange(72).reshape(3, 2, 3, 4).astype('float')
        b = np.arange(48).reshape(2, 2, 3, 4).astype('float')
        c = np.arange(48).reshape(2, 3, 2, 4).astype('float')
        d = np.arange(30).reshape(2, 3, 5).astype('float')
        e = np.arange(12).reshape(2, 2, 3).astype('float')
        feeds = {'a': a, 'b': b, 'c': c, 'd': d, 'e': e}
        actual = exe.run(main, feed=feeds, fetch_list=[outs])
        expect = []
        expect.append(np.einsum("ibnd,jbnd->bnij", a, b))
        expect.append(np.einsum('...ik, ...j', c, d))
        expect.append(np.einsum('...kj, ...ik', d, e))
        expect.append(np.einsum('ijk..., ikj', c, e))
        expect.append(np.einsum('ijk..., ikj->...ij', c, e))
        for a, e in zip(actual, expect):
            self.check_output_equal(a, e)
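
Note: the equations without '->' above use einsum's implicit output mode. A small numpy sketch of the rule these tests rely on (repeated labels are contracted, the remaining labels follow the ellipsis in alphabetical order):

import numpy as np

x = np.arange(6).reshape(2, 3)
y = np.arange(12).reshape(3, 4)
# 'ij,jk' contracts the repeated label j and keeps i, k, i.e. a plain matmul here.
assert np.allclose(np.einsum('ij,jk', x, y), x @ y)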
Example #3
def supervised_chi_loss(ret, batch, value, config):
    """Computes loss for direct chi angle supervision.

    Jumper et al. (2021) Suppl. Alg. 27 "torsionAngleLoss"

    Args:
        ret: Dictionary to write outputs into, needs to contain 'loss'.
        batch: Batch, needs to contain 'seq_mask', 'chi_mask', 'chi_angles'.
        value: Dictionary containing structure module output, needs to contain
            value['sidechains']['angles_sin_cos'] for angles and
            value['sidechains']['unnormalized_angles_sin_cos'] for unnormalized
            angles.
        config: Configuration of loss; should contain 'chi_weight' (scales the
            torsion term) and 'angle_norm_weight' (scales the angle-norm term).
    """
    eps = 1e-6
    
    sequence_mask = batch['seq_mask']
    num_res = sequence_mask.shape[1]
    batch_size = sequence_mask.shape[0]
    chi_mask = batch['chi_mask']
    pred_angles = paddle.reshape(value['sidechains']['angles_sin_cos'], [batch_size, -1, num_res, 7, 2])
    pred_angles = pred_angles[:, :, :, 3:]

    residue_type_one_hot = paddle.nn.functional.one_hot(batch['aatype_index'], 
                            num_classes=residue_constants.restype_num + 1)
    chi_pi_periodic = paddle.einsum('nijk, nkl->nijl', residue_type_one_hot[:, None, ...], 
                            paddle.to_tensor(residue_constants.chi_pi_periodic)[None])

    sin_cos_true_chi = batch['chi_angles_sin_cos'][:, None, ...]

    # This is -1 if chi is pi-periodic and +1 if it's 2pi-periodic
    shifted_mask = (1 - 2 * chi_pi_periodic)[..., None]
    sin_cos_true_chi_shifted = shifted_mask * sin_cos_true_chi

    sq_chi_error = paddle.sum(squared_difference(sin_cos_true_chi, pred_angles), axis=-1)
    sq_chi_error_shifted = paddle.sum(squared_difference(sin_cos_true_chi_shifted, pred_angles), axis=-1)
    sq_chi_error = paddle.minimum(sq_chi_error, sq_chi_error_shifted)

    sq_chi_loss_tmp = []
    for i in range(batch_size):
        sq_chi_loss_i = utils.mask_mean(mask=paddle.unsqueeze(chi_mask[i], axis=0), value=sq_chi_error[i])
        sq_chi_loss_tmp.append(sq_chi_loss_i)
    sq_chi_loss = paddle.to_tensor(sq_chi_loss_tmp, stop_gradient=False)
    sq_chi_loss = paddle.squeeze(sq_chi_loss, axis=-1)
    ret['chi_loss'] = sq_chi_loss
    ret['loss'] += config.chi_weight * sq_chi_loss

    unnormed_angles = paddle.reshape(value['sidechains']['unnormalized_angles_sin_cos'], [batch_size, -1, num_res, 7, 2])
    angle_norm = paddle.sqrt(paddle.sum(paddle.square(unnormed_angles), axis=-1) + eps)
    norm_error = paddle.abs(angle_norm - 1.)
    angle_norm_loss_tmp = []
    for i in range(batch_size):
        angle_norm_loss_i = utils.mask_mean(mask=paddle.unsqueeze(sequence_mask[i], axis=[0,2]), value=norm_error[i])
        angle_norm_loss_tmp.append(angle_norm_loss_i)
    angle_norm_loss = paddle.to_tensor(angle_norm_loss_tmp, stop_gradient=False)
    angle_norm_loss = paddle.squeeze(angle_norm_loss, axis=-1)
    ret['angle_norm_loss'] = angle_norm_loss
    ret['loss'] += config.angle_norm_weight * angle_norm_loss
Example #4
 def check_output(self, eqn, *ops):
     expect = np.einsum(eqn, *ops)
     with paddle.fluid.dygraph.guard(
             self._get_place(force_to_use_cpu=False)):
         pd_operands = [paddle.to_tensor(op) for op in ops]
         actual = paddle.einsum(eqn, *pd_operands)
         self.check_output_equal(actual.numpy(), expect)
Example #5
 def _get_rand_mask(self, blocked_query_mask, blocked_key_mask,
                    rand_mask_idx, batch_size, sequence_length):
     '''
     return random mask: [B, H, L-G, bs, R * bs]
     '''
     # rand_mask_idx: [H, T]
     # blocked_query_mask: [B, L, bs]
     # blocked_key_mask: [B, L, bs]
     bs = self.block_size
     B = batch_size
     L = sequence_length // bs
     H = self.num_heads
     G = self.num_global_blocks
     GB = self.num_global_blocks_back
     GF = self.num_global_blocks_front
     R = self.num_rand_blocks
     temp_block_key_mask = paddle.unsqueeze(blocked_key_mask, 1)
     temp_block_key_mask = paddle.expand(temp_block_key_mask, [B, H, L, -1])
     temp_block_key_mask_list = [
         paddle.gather_nd(temp_block_key_mask[b], rand_mask_idx)
         for b in range(B)
     ]
     temp_block_key_mask = paddle.concat(temp_block_key_mask_list, 0)
     temp_block_key_mask = paddle.reshape(temp_block_key_mask, [
         B, temp_block_key_mask.shape[0] // B //
         (L - GF - GB) // R, L - GF - GB, -1
     ])
     rand_mask = paddle.einsum("blq,bhlk->bhlqk",
                               blocked_query_mask[:, GF:-GB],
                               temp_block_key_mask)
     return rand_mask
Example #6
def _hsv_to_rgb(img):
    """Convert a image Tensor from HSV to RGB.
    """
    h, s, v = img.unbind(axis=-3)
    f = h * 6.0
    i = paddle.floor(f)
    f = f - i
    i = i.astype(paddle.int32) % 6

    p = paddle.clip(v * (1.0 - s), 0.0, 1.0)
    q = paddle.clip(v * (1.0 - s * f), 0.0, 1.0)
    t = paddle.clip(v * (1.0 - s * (1.0 - f)), 0.0, 1.0)

    mask = paddle.equal(
        i.unsqueeze(axis=-3),
        paddle.arange(
            6, dtype=i.dtype).reshape((-1, 1, 1))).astype(img.dtype)
    matrix = paddle.stack(
        [
            paddle.stack(
                [v, q, p, p, t, v], axis=-3), paddle.stack(
                    [t, v, v, q, p, p], axis=-3), paddle.stack(
                        [p, p, t, v, v, q], axis=-3)
        ],
        axis=-4)
    return paddle.einsum("...ijk, ...xijk -> ...xjk", mask, matrix)
Example #7
def rotation_3d_in_axis(points, angles, axis=0):
    # points: [N, point_size, 3]
    # angles: [N]
    rot_sin = paddle.sin(angles)
    rot_cos = paddle.cos(angles)
    ones = paddle.ones_like(rot_cos)
    zeros = paddle.zeros_like(rot_cos)
    if axis == 1:
        rot_mat_T = paddle.stack([
            paddle.stack([rot_cos, zeros, -rot_sin]),
            paddle.stack([zeros, ones, zeros]),
            paddle.stack([rot_sin, zeros, rot_cos])
        ])
    elif axis == 2 or axis == -1:
        rot_mat_T = paddle.stack([
            paddle.stack([rot_cos, -rot_sin, zeros]),
            paddle.stack([rot_sin, rot_cos, zeros]),
            paddle.stack([zeros, zeros, ones])
        ])
    elif axis == 0:
        rot_mat_T = paddle.stack([
            paddle.stack([zeros, rot_cos, -rot_sin]),
            paddle.stack([zeros, rot_sin, rot_cos]),
            paddle.stack([ones, zeros, zeros])
        ])
    else:
        raise ValueError("axis should be 0, 1, 2 or -1")

    return paddle.einsum('aij,jka->aik', points, rot_mat_T)
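
Note: a quick usage sketch (sizes are hypothetical). After stacking, rot_mat_T has shape [3, 3, N], so 'aij,jka->aik' multiplies each sample's [point_size, 3] slice by that sample's own rotation matrix.

import paddle

points = paddle.rand([5, 7, 3])   # [N, point_size, 3]
angles = paddle.rand([5])         # [N]
rotated = rotation_3d_in_axis(points, angles, axis=2)
print(rotated.shape)              # [5, 7, 3]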
Example #8
    def test_forward(self):
        operands = [
            TestEinsum.TEST_SAMPLES[operand] for operand in self.sample["data"]
        ]
        expected_result = np.einsum(self.sample["paradigm"], *operands)
        equation = self.sample["paradigm"]

        with paddle.fluid.dygraph.guard(
                self._get_place(force_to_use_cpu=False)):
            pd_operands = [paddle.to_tensor(operand) for operand in operands]
            result = paddle.einsum(equation, *pd_operands)
            self.check_output_equal(result.numpy(), expected_result)

        with paddle.fluid.dygraph.guard(self._get_place(force_to_use_cpu=True)):
            pd_operands = [paddle.to_tensor(operand) for operand in operands]
            result = paddle.einsum(equation, *pd_operands)
            self.check_output_equal(result.numpy(), expected_result)
Example #9
    def forward(self, h, attn_mask=None, mems=None):
        if mems is not None:
            c = paddle.concat([mems, h], axis=1)
        else:
            c = h

        if self.normalize_before:
            c = self.layer_norm(c)

        head_q = self.q_proj(h)
        head_k, head_v = paddle.chunk(self.kv_proj(c), chunks=2, axis=-1)

        head_q = paddle.reshape(
            head_q, shape=[h.shape[0], h.shape[1], self.n_head, self.d_head])
        head_k = paddle.reshape(
            head_k, shape=[c.shape[0], c.shape[1], self.n_head, self.d_head])
        head_v = paddle.reshape(
            head_v, shape=[c.shape[0], c.shape[1], self.n_head, self.d_head])

        attn_score = paddle.einsum('bind,bjnd->bnij', head_q, head_k)
        attn_score = attn_score * self.scale
        if attn_mask is not None:
            attn_score = attn_score - 1e30 * attn_mask  # large finite value; inf * 0 would give NaN

        attn_prob = F.softmax(attn_score, axis=-1)
        attn_prob = self.attn_drop(attn_prob)

        attn_vec = paddle.einsum('bnij,bjnd->bind', attn_prob, head_v)
        attn_vec = paddle.reshape(
            attn_vec,
            shape=[
                attn_vec.shape[0], attn_vec.shape[1], self.n_head * self.d_head
            ])

        attn_out = self.o_proj(attn_vec)
        attn_out = self.drop(attn_out)
        if self.normalize_before:
            output = h + attn_out
        else:
            output = self.layer_norm(h + attn_out)

        return output
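
Note: a shape sketch (sizes are hypothetical) of the two contractions above: 'bind,bjnd->bnij' contracts the head dimension d into a (query, key) score grid per head, and 'bnij,bjnd->bind' maps the probabilities back onto the values.

import paddle

head_q = paddle.rand([2, 5, 4, 16])   # [batch, q_len, n_head, d_head]
head_k = paddle.rand([2, 7, 4, 16])   # [batch, k_len, n_head, d_head]
head_v = paddle.rand([2, 7, 4, 16])   # [batch, k_len, n_head, d_head]
score = paddle.einsum('bind,bjnd->bnij', head_q, head_k)
print(score.shape)                    # [2, 4, 5, 7]
vec = paddle.einsum('bnij,bjnd->bind', score, head_v)
print(vec.shape)                      # [2, 5, 4, 16]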
Example #10
 def test_shape(self):
     cuda_major = paddle.version.cuda().split('.')[0].strip()
     if paddle.is_compiled_with_cuda() and int(cuda_major) >= 11:
         """ MatmulKernel support bfloat16 only if cuda_major > 11.0.
         """
         A = paddle.to_tensor(np.array([1.0, 2.0])).astype(paddle.bfloat16)
         A = A.cuda()
         B = paddle.to_tensor(np.array([2.0, 3.0])).astype(paddle.bfloat16)
         B = B.cuda()
         C = paddle.einsum('i,i->', A, B)
         self.assertEqual(C.item(), 8.0)
Example #11
 def positional_embedding(self, inputs):
     seq_len = inputs.shape[1]
     pos_seq = paddle.arange(0, seq_len, dtype=dtype_float)
     indices = paddle.arange(0, self.head_dim, 2, dtype=dtype_float)
     indices = 1 / 10000**(indices / self.head_dim)
     sinusoid_inp = paddle.einsum("i,d->id", pos_seq, indices)
     pos_emb = paddle.concat(
         [paddle.sin(sinusoid_inp),
          paddle.cos(sinusoid_inp)], axis=-1)
     pos_emb = paddle.reshape(pos_emb, (1, 1, seq_len, self.head_dim))
     pos_emb.stop_gradient = True
     return pos_emb
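
Note: a minimal sketch (sizes are hypothetical, dtype_float replaced by 'float32') of the outer-product step: 'i,d->id' pairs every position with every inverse frequency to form the sinusoid arguments.

import paddle

pos_seq = paddle.arange(0, 4, dtype='float32')                         # seq_len = 4
inv_freq = 1 / 10000 ** (paddle.arange(0, 8, 2, dtype='float32') / 8)  # head_dim = 8
sinusoid_inp = paddle.einsum('i,d->id', pos_seq, inv_freq)
print(sinusoid_inp.shape)                                              # [4, 4]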
Example #12
File: net.py  Project: duyiqi17/PaddleRec
    def forward(self,
                input_u,
                item_attribute,
                input_ur=None,
                item_bind_M=None):
        user_feature_emb = self.user_feature_emb(input_u)
        summed_user_emb = user_feature_emb.sum(1)
        all_item_feature_emb = self.all_item_feature_emb(item_attribute)
        summed_all_item_emb = all_item_feature_emb.sum(1)
        user_cross = 0.5 * (summed_user_emb**2 - (user_feature_emb**2).sum(1))
        item_cross = 0.5 * (summed_all_item_emb**2 -
                            (all_item_feature_emb**2).sum(1))
        user_cross_score = user_cross.matmul(self.H_s)
        item_cross_score = item_cross.matmul(self.H_s)
        user_bias = self.user_bias(input_u).sum(1)
        item_bias = self.item_bias(item_attribute).sum(1)

        I = paddle.ones([input_u.shape[0], 1])
        p_emb = paddle.concat(
            [summed_user_emb, user_cross_score + user_bias + self.bias, I], 1)

        I = paddle.ones([summed_all_item_emb.shape[0], 1])
        q_emb = paddle.concat(
            [summed_all_item_emb, I, item_cross_score + item_bias], 1)
        H_i_emb = paddle.concat(
            [self.H_i,
             paddle.to_tensor([[1.0]]),
             paddle.to_tensor([[1.0]])], 0)
        dot = paddle.einsum('ac,bc->abc', p_emb, q_emb)
        pre = paddle.einsum('ajk,kl->aj', dot, H_i_emb)
        if input_ur is None:
            return (pre, )

        pos_item = F.embedding(input_ur, q_emb)
        pos_num_r = (input_ur != item_bind_M).astype(default_type)
        pos_item = paddle.einsum('ab,abc->abc', pos_num_r, pos_item)

        pos_r = paddle.einsum('ac,abc->abc', p_emb, pos_item)
        pos_r = paddle.einsum('ajk,kl->ajl', pos_r, H_i_emb).flatten(1)
        return pre, pos_r, q_emb, p_emb, H_i_emb
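
Note: a shape sketch (sizes are hypothetical) of the prediction step: 'ac,bc->abc' pairs every user row with every item row elementwise, and 'ajk,kl->aj' then reduces the feature axis against H_i_emb.

import paddle

p_emb = paddle.rand([2, 6])                      # [num_users, C]
q_emb = paddle.rand([3, 6])                      # [num_items, C]
dot = paddle.einsum('ac,bc->abc', p_emb, q_emb)  # every user paired with every item
print(dot.shape)                                 # [2, 3, 6]
H = paddle.rand([6, 1])
pre = paddle.einsum('ajk,kl->aj', dot, H)        # sum over the feature axis (and the size-1 l axis)
print(pre.shape)                                 # [2, 3]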
Example #13
def rotation_2d(points, angles):
    """rotation 2d points based on origin point clockwise when angle positive.

    Args:
        points (float array, shape=[N, point_size, 2]): points to be rotated.
        angles (float array, shape=[N]): rotation angle.

    Returns:
        float array: same shape as points
    """
    rot_sin = paddle.sin(angles)
    rot_cos = paddle.cos(angles)
    rot_mat_T = paddle.stack(
        [paddle.stack([rot_cos, -rot_sin]),
         paddle.stack([rot_sin, rot_cos])])
    return paddle.einsum('aij,jka->aik', points, rot_mat_T)
Example #14
    def get_reference_out(self):
        paddle.disable_static(place=paddle.CUDAPlace(0))

        query = paddle.to_tensor(self.query, stop_gradient=False)
        key = query if self.merge_qkv else paddle.to_tensor(
            self.key, stop_gradient=False)
        q_weight = paddle.to_tensor(self.q_weight, stop_gradient=False)
        k_weight = paddle.to_tensor(self.k_weight, stop_gradient=False)
        v_weight = paddle.to_tensor(self.v_weight, stop_gradient=False)
        src_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True)

        c = self.key_dim**(-0.5)
        # [batch_size, msa_len, num_heads, res_len, key_dim]
        q = paddle.einsum('nbqa,ahc->nbqhc', query, q_weight) * c
        # [batch_size, msa_len, num_heads, m_size, key_dim]
        k = paddle.einsum('nbka,ahc->nbkhc', key, k_weight)
        # [batch_size, msa_len, num_heads, m_size, key_dim]
        v = paddle.einsum('nbka,ahc->nbkhc', key, v_weight)

        # [batch_size, msa_len, num_heads, res_len, m_size]
        logits = paddle.einsum('nbqhc,nbkhc->nbhqk', q, k)  # qk_out
        logits = logits + src_mask
        if self.bias_attr:
            nonbatched_bias = paddle.to_tensor(self.nonbatched_bias,
                                               stop_gradient=False)
            logits = logits + nonbatched_bias

        weights = nn.functional.softmax(logits)  # softmax_out
        weighted_avg = paddle.einsum('nbhqk,nbkhc->nbqhc', weights, v)

        if self.has_gating:
            gating_w = paddle.to_tensor(self.gating_w, stop_gradient=False)
            gating_b = paddle.to_tensor(self.gating_b, stop_gradient=False)
            gate_values = paddle.einsum('nbqc,chv->nbqhv', query,
                                        gating_w) + gating_b
            gate_values = nn.functional.sigmoid(gate_values)
            weighted_avg = weighted_avg * gate_values

        output_b = paddle.to_tensor(self.output_b, stop_gradient=False)
        output_w = paddle.to_tensor(self.output_w, stop_gradient=False)

        out = paddle.einsum('nbqhc,hco->nbqo', weighted_avg,
                            output_w) + output_b
        paddle.autograd.backward([out], [paddle.to_tensor(self.dout)],
                                 retain_graph=True)
        if self.merge_qkv:
            return out, query.grad, None
        else:
            return out, query.grad, key.grad
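
Note: a shape sketch (sizes are hypothetical) of the projection pattern used above: 'nbqa,ahc->nbqhc' applies one [A, num_heads, key_dim] weight to the channel axis and splits the result into heads.

import paddle

query = paddle.rand([2, 3, 5, 8])    # [batch, msa_len, res_len, A]
q_weight = paddle.rand([8, 4, 16])   # [A, num_heads, key_dim]
q = paddle.einsum('nbqa,ahc->nbqhc', query, q_weight)
print(q.shape)                       # [2, 3, 5, 4, 16]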
Example #15
 def test_diagonalize_errors(self):
     a = np.arange(4 * 3 * 4 * 4).reshape(4, 3, 4, 4).astype('float')
     a = paddle.to_tensor(a)
     with self.assertRaisesRegex(AssertionError, (
             'Diagonal and trace not implemented yet.')):
         paddle.einsum('...ii->...i', a)
     with self.assertRaisesRegex(AssertionError, (
             'Diagonal and trace not implemented yet.')):
         paddle.einsum('i...i', a)
     with self.assertRaisesRegex(AssertionError, (
             'Diagonal and trace not implemented yet.')):
         paddle.einsum('i...i->i...', a)
Example #16
 def test_diagonalize_errors(self):
     a = np.arange(4 * 3 * 4 * 4).reshape(4, 3, 4, 4).astype('float')
     a = paddle.to_tensor(a)
     with self.assertRaisesRegex(AssertionError,
                                 ('Duplicate labels are not supported.')):
         paddle.einsum('...ii->...i', a)
     with self.assertRaisesRegex(AssertionError,
                                 ('Duplicate labels are not supported.')):
         paddle.einsum('i...i', a)
     with self.assertRaisesRegex(AssertionError,
                                 ('Duplicate labels are not supported.')):
         paddle.einsum('i...i->i...', a)
Example #17
    def forward(self, single_act: paddle.Tensor, pair_act: paddle.Tensor,
                mask: paddle.Tensor, affine: quat_affine.QuatAffine):
        # single_act: [B, N, C]
        # pair_act: [B, N, M, C']
        # mask: [B, N, 1]
        num_residues = single_act.shape[1]
        num_head = self.config.num_head
        num_scalar_qk = self.config.num_scalar_qk
        num_point_qk = self.config.num_point_qk
        num_scalar_v = self.config.num_scalar_v
        num_point_v = self.config.num_point_v
        num_output = self.config.num_channel

        # Construct scalar queries of shape:
        # [batch_size, num_query_residues, num_head, num_points]
        q_scalar = self.q_scalar(single_act)
        q_scalar = paddle.reshape(
            q_scalar, [-1, num_residues, num_head, num_scalar_qk])

        # Construct scalar keys/values of shape:
        # [batch_size, num_target_residues, num_head, num_points]
        kv_scalar = self.kv_scalar(single_act)
        kv_scalar = paddle.reshape(
            kv_scalar,
            [-1, num_residues, num_head, num_scalar_v + num_scalar_qk])
        k_scalar, v_scalar = paddle.split(
            kv_scalar, [num_scalar_qk, -1], axis=-1)

        # Construct query points of shape:
        # [batch_size, num_residues, num_head, num_point_qk]
        q_point_local = self.q_point_local(single_act)
        q_point_local = paddle.split(q_point_local, 3, axis=-1)

        q_point_global = affine.apply_to_point(q_point_local, extra_dims=1)
        q_point = [
            paddle.reshape(x, [-1, num_residues, num_head, num_point_qk])
            for x in q_point_global]

        # Construct key and value points.
        # Key points shape [batch_size, num_residues, num_head, num_point_qk]
        # Value points shape [batch_size, num_residues, num_head, num_point_v]
        kv_point_local = self.kv_point_local(single_act)
        kv_point_local = paddle.split(kv_point_local, 3, axis=-1)

        kv_point_global = affine.apply_to_point(kv_point_local, extra_dims=1)
        kv_point_global = [
            paddle.reshape(x, [-1, num_residues, num_head, num_point_qk + num_point_v])
            for x in kv_point_global]

        k_point, v_point = list(
            zip(*[
                paddle.split(x, [num_point_qk, -1], axis=-1)
                for x in kv_point_global
            ]))

        # We assume that all queries and keys come iid from N(0, 1) distribution
        # and compute the variances of the attention logits.
        # Each scalar pair (q, k) contributes Var q*k = 1
        scalar_variance = max(num_scalar_qk, 1) * 1.
        # Each point pair (q, k) contributes Var [0.5 ||q||^2 - <q, k>] = 9 / 2
        point_variance = max(num_point_qk, 1) * 9. / 2

        # Allocate equal variance to scalar, point and attention 2d parts so that
        # the sum is 1.

        num_logit_terms = 3
        scalar_weights = np.sqrt(1.0 / (num_logit_terms * scalar_variance))
        point_weights = np.sqrt(1.0 / (num_logit_terms * point_variance))
        attention_2d_weights = np.sqrt(1.0 / (num_logit_terms))

        trainable_point_weights = nn.functional.softplus(
            self.trainable_point_weights)
        point_weights *= paddle.unsqueeze(
            trainable_point_weights, axis=1)

        # [B, R, H, C] => [B, H, R, C], put head dim first
        q_point = [paddle.transpose(x, [0, 2, 1, 3]) for x in q_point]
        k_point = [paddle.transpose(x, [0, 2, 1, 3]) for x in k_point]
        v_point = [paddle.transpose(x, [0, 2, 1, 3]) for x in v_point]

        dist2 = [
            paddle.square(paddle.unsqueeze(qx, axis=-2) - \
                          paddle.unsqueeze(kx, axis=-3))
            for qx, kx in zip(q_point, k_point)]
        dist2 = sum(dist2)

        attn_qk_point = -0.5 * paddle.sum(
            paddle.unsqueeze(point_weights, axis=[1, 2]) * dist2, axis=-1)

        q = paddle.transpose(scalar_weights * q_scalar, [0, 2, 1, 3])
        k = paddle.transpose(k_scalar, [0, 2, 1, 3])
        v = paddle.transpose(v_scalar, [0, 2, 1, 3])
        attn_qk_scalar = paddle.matmul(q, paddle.transpose(k, [0, 1, 3, 2]))
        attn_logits = attn_qk_scalar + attn_qk_point

        attention_2d = self.attention_2d(pair_act)
        attention_2d = paddle.transpose(attention_2d, [0, 3, 1, 2])
        attention_2d = attention_2d_weights * attention_2d
        attn_logits += attention_2d

        mask_2d = mask * paddle.transpose(mask, [0, 2, 1])
        attn_logits -= 1e5 * (1. - mask_2d)

        # [batch_size, num_head, num_query_residues, num_target_residues]
        attn = nn.functional.softmax(attn_logits)

        # o_i^h
        # [batch_size, num_query_residues, num_head, num_head * num_scalar_v]
        result_scalar = paddle.matmul(attn, v)
        result_scalar = paddle.transpose(result_scalar, [0, 2, 1, 3])

        # o_i^{hp}
        # [batch_size, num_query_residues, num_head, num_head * num_point_v]
        result_point_global = [
            paddle.sum(paddle.unsqueeze(attn, -1) * paddle.unsqueeze(vx, -3),
                       axis=-2) for vx in v_point]
        result_point_global = [
            paddle.transpose(x, [0, 2, 1, 3]) for x in result_point_global]

        # \tilde{o}_i^h
        # [batch_size, num_residues, num_head, pair_channel]
        result_attention_over_2d = paddle.einsum(
            'nhij,nijc->nihc', attn, pair_act)

        # Reshape, global-to-local and save
        result_scalar = paddle.reshape(
            result_scalar, [-1, num_residues, num_head * num_scalar_v])
        result_point_global = [
            paddle.reshape(x, [-1, num_residues, num_head * num_point_v])
            for x in result_point_global]
        result_point_local = affine.invert_point(
            result_point_global, extra_dims=1)
        result_attention_over_2d = paddle.reshape(
            result_attention_over_2d,
            [-1, num_residues, num_head * self.channel_num['pair_channel']])

        result_point_local_norm = paddle.sqrt(
            self.dist_epsilon + paddle.square(result_point_local[0]) + \
            paddle.square(result_point_local[1]) + \
            paddle.square(result_point_local[2]))

        output_features = [result_scalar]
        output_features.extend(result_point_local)
        output_features.extend(
            [result_point_local_norm, result_attention_over_2d])

        final_act = paddle.concat(output_features, axis=-1)
        return self.output_projection(final_act)
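
Note: a shape sketch (sizes are hypothetical) of the pair-activation readout: 'nhij,nijc->nihc' weights the pair features of every residue pair (i, j) by the attention map and places the query-residue axis before the head axis.

import paddle

attn = paddle.rand([2, 4, 6, 6])      # [batch, num_head, num_res, num_res]
pair_act = paddle.rand([2, 6, 6, 8])  # [batch, num_res, num_res, pair_channel]
out = paddle.einsum('nhij,nijc->nihc', attn, pair_act)
print(out.shape)                      # [2, 6, 4, 8]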
Example #18
    def forward(self,
                query_matrix,
                key_matrix,
                value_matrix,
                d_head,
                attn_mask=None,
                rand_mask_idx=None,
                query_mask=None,
                key_mask=None,
                dropout=None):
        '''
            query_matrix: [B, H, T, D]
            key_matrix: [B, H, T, D]
            value_matrix: [B, H, T, D]
            query_mask: [B, 1, T, 1]  bool mask
            key_mask: [B, 1, 1, T]    bool mask
            rand_mask_idx: [H, T//bs, bs]
            Combines global, random and window attention.
        '''
        B = query_matrix.shape[0]  # batch_size
        H = self.num_heads
        T = query_matrix.shape[2]  # sequence_length
        D = query_matrix.shape[3]  # size per head
        G = self.num_global_blocks
        GB = self.num_global_blocks_back
        GF = self.num_global_blocks_front
        R = self.num_rand_blocks
        W = self.window_size
        bs = self.block_size
        L = T // bs  # blocked length

        blocked_query_matrix = paddle.reshape(query_matrix, [B, H, L, bs, -1])
        blocked_key_matrix = paddle.reshape(key_matrix, [B, H, L, bs, -1])
        blocked_value_matrix = paddle.reshape(value_matrix, [B, H, L, bs, -1])
        blocked_query_mask = paddle.reshape(query_mask, [B, L, bs])
        blocked_key_mask = paddle.reshape(key_mask, [B, L, bs])

        # 1. global_front_product
        global_front_out = self._get_global_out(query_matrix, key_matrix,
                                                value_matrix, key_mask, d_head,
                                                dropout)

        # 2. global_back_product
        global_back_out = self._get_global_out(query_matrix, key_matrix,
                                               value_matrix, key_mask, d_head,
                                               dropout, False)

        # 3. second_product

        # create second matrix
        # [B, 1, L-G, bs, (G+W)*bs]
        band_mask = self._get_band_mask(blocked_query_mask, blocked_key_mask,
                                        B, T)
        # [B, H, L-G, bs, R*bs]
        rand_mask = self._get_rand_mask(blocked_query_mask, blocked_key_mask,
                                        rand_mask_idx, B, T)
        # [B, H, L-G, bs, (G+W+R)*bs]
        second_mask = paddle.concat([band_mask, rand_mask], axis=4)

        # [B, H, L-G, R * bs, -1]
        random_keys = self._gather_random_key_value(blocked_key_matrix,
                                                    rand_mask_idx, B, T)
        random_values = self._gather_random_key_value(blocked_value_matrix,
                                                      rand_mask_idx, B, T)

        band_keys_matrix = self._get_band_matrix(blocked_key_matrix, B, T)
        band_value_matrix = self._get_band_matrix(blocked_value_matrix, B, T)

        # [B, H, L - G, bs, -1]
        second_query_matrix = blocked_query_matrix[:, :, GF:-GB]
        # [B, H, L - G, (G+W+R)*bs, -1]
        second_key_matrix = paddle.concat([band_keys_matrix, random_keys],
                                          axis=3)
        # [B, H, L - G, (G+W+R)*bs, -1]
        second_value_matrix = paddle.concat([band_value_matrix, random_values],
                                            axis=3)
        second_top_value_matrix, second_middle_value_matrix, second_bottom_value_matrix = \
            self._get_splited_matrix(second_value_matrix)
        second_product = paddle.einsum("bhlqd,bhlkd->bhlqk",
                                       second_query_matrix, second_key_matrix)
        second_product = second_product * (d_head**-0.5)
        second_product += (1 - second_mask) * -1e6
        second_weights = F.softmax(second_product)

        second_top_weights, second_middle_weights, second_bottom_weights = \
            self._get_splited_matrix(second_weights)
        second_top_out = paddle.einsum("bhlqk,bhlkd->bhlqd",
                                       second_top_weights,
                                       second_top_value_matrix)

        second_middle_out = paddle.einsum(
            "bhlqk,bhlkd->bhlqd",
            second_middle_weights[:, :, :, :, GF * bs:-(GB + R) * bs],
            second_middle_value_matrix[:, :, :, GF * bs:-(GB + R) * bs])
        # add global block attention
        second_middle_out += paddle.einsum(
            "bhlqk,bhkd->bhlqd", second_middle_weights[:, :, :, :, :GF * bs],
            blocked_value_matrix[:, :, 0])
        second_middle_out += paddle.einsum(
            "bhlqk,bhkd->bhlqd", second_middle_weights[:, :, :, :,
                                                       -(GB + R) * bs:-R * bs],
            blocked_value_matrix[:, :, -GB])
        # add random block attention
        second_middle_out += paddle.einsum(
            "...qk,...kd->...qd", second_middle_weights[:, :, :, :, -R * bs:],
            random_values[:, :, GF:-GB])

        second_bottom_out = paddle.einsum("bhlqk,bhlkd->bhlqd",
                                          second_bottom_weights,
                                          second_bottom_value_matrix)

        second_out = paddle.concat(
            [second_top_out, second_middle_out, second_bottom_out], axis=2)
        second_out = paddle.reshape(second_out, [B, H, (L - G) * bs, -1])

        # [B, H, T, D]
        out = paddle.concat([global_front_out, second_out, global_back_out],
                            axis=2)
        out = out * query_mask
        return out
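
Note: a shape sketch (sizes are hypothetical) of the blockwise score contraction: 'bhlqd,bhlkd->bhlqk' computes query-key dot products independently for every (batch, head, block) triple.

import paddle

q = paddle.rand([2, 4, 3, 8, 16])    # [B, H, L-G, bs, D]
k = paddle.rand([2, 4, 3, 24, 16])   # [B, H, L-G, (G+W+R)*bs, D]
scores = paddle.einsum('bhlqd,bhlkd->bhlqk', q, k)
print(scores.shape)                  # [2, 4, 3, 8, 24]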
Example #19
    def _get_band_mask(self, blocked_query_mask, blocked_key_mask, batch_size,
                       sequence_length):
        '''
        Return second mask: [B, 1, L-G, bs, (G+W)*bs]
        '''
        GB = self.num_global_blocks_back
        GF = self.num_global_blocks_front
        G = self.num_global_blocks
        R = self.num_rand_blocks
        W = self.window_size
        bs = self.block_size
        T = sequence_length
        L = T // bs  # blocked length
        B = batch_size
        H = self.num_heads
        # G+W+R
        # query_mask: [B, L, bs]
        # key_mask: [B, L, bs]
        # [B, L-G, bs, 1] * [B, L-G, 1, G*bs] -> [B, L-G, bs, G*bs]
        temp_query_mask = paddle.reshape(blocked_query_mask[:, GF:-GB],
                                         [B, L - G, bs, 1])
        temp_key_mask_front = paddle.reshape(blocked_key_mask[:, :GF],
                                             [B, 1, 1, GF * bs])
        global_block_mask_front = paddle.einsum("blqd,bmdk->blqk",
                                                temp_query_mask,
                                                temp_key_mask_front)

        temp_key_mask_back = paddle.reshape(blocked_key_mask[:, -GB:],
                                            [B, 1, 1, GB * bs])
        global_block_mask_back = paddle.einsum("blqd,bmdk->blqk",
                                               temp_query_mask,
                                               temp_key_mask_back)
        # create window block mask
        key_mask_list = []
        for query_block_id in range(GF, GF + W // 2):
            left_block_id = query_block_id - W // 2
            right_block_id = query_block_id + W // 2
            zero_key_mask = paddle.zeros_like(
                blocked_key_mask[:, -(W - (right_block_id + 1 - G)):-GB])
            temp_key_mask = paddle.concat(
                [blocked_key_mask[:, GF:(right_block_id + 1)], zero_key_mask],
                axis=1)
            temp_key_mask = paddle.unsqueeze(temp_key_mask, 1)
            key_mask_list.append(temp_key_mask)
        roll_key_mask1 = paddle.concat(key_mask_list, axis=1)
        roll_key_mask1 = paddle.reshape(roll_key_mask1, [0, 0, W * bs])
        key_mask_list = []

        band_length = L - G - W // 2 * 2
        for query_block_id in range(GF + W // 2, GF + W // 2 + W):
            left_block_id = query_block_id - W // 2
            right_block_id = query_block_id + W // 2
            key_mask_list.append(
                blocked_key_mask[:, left_block_id:left_block_id + band_length])
        window_key_mask = paddle.concat(key_mask_list, axis=2)
        window_key_mask = paddle.reshape(window_key_mask, [0, 0, W * bs])

        key_mask_list = []
        for query_block_id in range((L - GB) - W // 2, L - GB):
            left_block_id = query_block_id - W // 2
            right_block_id = query_block_id + W // 2
            zero_key_mask = paddle.zeros_like(
                blocked_key_mask[:, GF:GF + W - (L - left_block_id - GB)])
            temp_key_mask = paddle.concat(
                [zero_key_mask, blocked_key_mask[:, left_block_id:-GB]],
                axis=1)
            temp_key_mask = paddle.unsqueeze(temp_key_mask, 1)
            key_mask_list.append(temp_key_mask)
        roll_key_mask2 = paddle.concat(key_mask_list, axis=1)
        roll_key_mask2 = paddle.reshape(roll_key_mask2, [0, 0, W * bs])

        window_key_mask = paddle.concat(
            [roll_key_mask1, window_key_mask, roll_key_mask2], axis=1)
        window_key_mask = paddle.unsqueeze(window_key_mask, axis=2)
        # [B, L-G, bs, 1] * [B, L-G, 1, W*bs] -> [B, L-G, bs, W*bs]
        window_block_mask = paddle.einsum("blkd,bldq->blkq", temp_query_mask,
                                          window_key_mask)
        band_mask = paddle.concat([
            global_block_mask_front, window_block_mask, global_block_mask_back
        ],
                                  axis=3)
        band_mask = paddle.unsqueeze(band_mask, 1)  # for head
        band_mask = paddle.expand(band_mask, [B, H, L - G, bs, -1])
        return band_mask
Example #20
    def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None):
        qlen, rlen, bsz = w.shape[1], r.shape[1], w.shape[0]

        if mems is not None:
            cat = paddle.concat([mems, w], axis=1)
            if self.normalize_before:
                w_heads = self.qkv_proj(self.layer_norm(cat))
            else:
                w_heads = self.qkv_proj(cat)
            r_head_k = self.r_proj(r)

            w_head_q, w_head_k, w_head_v = paddle.chunk(
                w_heads, chunks=3, axis=-1)

            w_head_q = w_head_q[:, -qlen:, :]
        else:
            if self.normalize_before:
                w_heads = self.qkv_proj(self.layer_norm(w))
            else:
                w_heads = self.qkv_proj(w)
            r_head_k = self.r_proj(r)

            w_head_q, w_head_k, w_head_v = paddle.chunk(
                w_heads, chunks=3, axis=-1)

        klen = w_head_k.shape[1]

        w_head_q = paddle.reshape(
            w_head_q, shape=[bsz, qlen, self.n_head, self.d_head])
        w_head_k = paddle.reshape(
            w_head_k, shape=[bsz, klen, self.n_head, self.d_head])
        w_head_v = paddle.reshape(
            w_head_v, shape=[bsz, klen, self.n_head, self.d_head])

        r_head_k = paddle.reshape(
            r_head_k, shape=[bsz, rlen, self.n_head, self.d_head])

        rw_head_q = w_head_q + r_w_bias

        AC = paddle.einsum('bind,bjnd->bnij', rw_head_q, w_head_k)
        rr_head_q = w_head_q + r_r_bias

        BD = paddle.einsum('bind,bjnd->bnij', rr_head_q, r_head_k)
        BD = self._rel_shift(BD)

        attn_score = AC + BD
        attn_score = attn_score * self.scale

        if attn_mask is not None:
            attn_score = attn_score - 1e30 * attn_mask

        attn_prob = F.softmax(attn_score, axis=-1)
        attn_prob = self.attn_drop(attn_prob)

        attn_vec = paddle.einsum('bnij,bjnd->bind', attn_prob, w_head_v)

        attn_vec = paddle.reshape(
            attn_vec,
            shape=[
                attn_vec.shape[0], attn_vec.shape[1], self.n_head * self.d_head
            ])

        attn_out = self.o_proj(attn_vec)
        attn_out = self.drop(attn_out)

        if self.normalize_before:
            output = w + attn_out
        else:
            output = self.layer_norm(w + attn_out)

        return output
Example #21
    def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None):
        qlen, bsz = w.shape[1], w.shape[0]

        if mems is not None:
            cat = paddle.concat([mems, w], 1)
            if self.normalize_before:
                w_heads = self.qkv_proj(self.layer_norm(cat))
            else:
                w_heads = self.qkv_proj(cat)
            w_head_q, w_head_k, w_head_v = paddle.chunk(
                w_heads, chunks=3, axis=-1)

            w_head_q = w_head_q[-qlen:]
        else:
            if self.normalize_before:
                w_heads = self.qkv_proj(self.layer_norm(w))
            else:
                w_heads = self.qkv_proj(w)
            w_head_q, w_head_k, w_head_v = paddle.chunk(
                w_heads, chunks=3, axis=-1)

        klen = w_head_k.shape[1]

        w_head_q = paddle.reshape(
            w_head_q,
            shape=[
                w_head_q.shape[0], w_head_q.shape[1], self.n_head, self.d_head
            ])
        w_head_k = paddle.reshape(
            w_head_k,
            shape=[
                w_head_k.shape[0], w_head_k.shape[1], self.n_head, self.d_head
            ])
        w_head_v = paddle.reshape(
            w_head_v,
            shape=[
                w_head_v.shape[0], w_head_v.shape[1], self.n_head, self.d_head
            ])

        if klen > r_emb.shape[0]:
            r_emb_pad = r_emb[0:1].expand([klen - r_emb.shape[0], -1, -1])
            r_emb = paddle.concat([r_emb_pad, r_emb], 0)
            r_bias_pad = r_bias[0:1].expand([klen - r_bias.shape[0], -1])
            r_bias = paddle.concat([r_bias_pad, r_bias], 0)
        else:
            r_emb = r_emb[-klen:]
            r_bias = r_bias[-klen:]

        rw_head_q = w_head_q + r_w_bias.unsqueeze([0])

        AC = paddle.einsum('bind,bjnd->bnij', rw_head_q, w_head_k)
        r_emb = r_emb.unsqueeze([0]).expand([bsz, -1, -1, -1])
        B_ = paddle.einsum('bind,bjnd->bnij', w_head_q, r_emb)
        D_ = r_bias.unsqueeze([0, 2])
        BD = self._rel_shift(B_ + D_)

        attn_score = AC + BD
        attn_score = attn_score * self.scale

        if attn_mask is not None:
            attn_score = attn_score - 1e30 * attn_mask  # large finite value; inf * 0 would give NaN

        attn_prob = F.softmax(attn_score, axis=-1)
        attn_prob = self.attn_drop(attn_prob)

        attn_vec = paddle.einsum('bnij,bjnd->bind', attn_prob, w_head_v)

        attn_vec = paddle.reshape(
            attn_vec,
            shape=[
                attn_vec.shape[0], attn_vec.shape[1], self.n_head * self.d_head
            ])

        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.normalize_before:
            output = w + attn_out
        else:
            output = self.layer_norm(w + attn_out)

        return output
Example #22
 def test_param_errors(self):
     a = np.arange(4 * 3 * 4 * 4).reshape(4, 3, 4, 4).astype('float')
     a = paddle.to_tensor(a)
     with self.assertRaisesRegex(AssertionError,
                                 ('At least one operand is expected.')):
         paddle.einsum('ijk')
     with self.assertRaisesRegex(AssertionError, (
             'Invalid equation: multiple `->` were found.')):
         paddle.einsum('i -> j -> k', a)
     with self.assertRaisesRegex(AssertionError, (
             "Invalid equation: the number of operands is 2, "
             "but found 3 segments in the label equation.")):
         paddle.einsum('i,j,k', a, a)
     with self.assertRaisesRegex(AssertionError, (
             "Invalid equation: the number of operands is 2, "
             "but found 1 segments in the label equation.")):
         paddle.einsum('ij -> k', a, a)
     with self.assertRaisesRegex(AssertionError, (
             "Invalid equation: the number of operands is 1, "
             "but found 2 segments in the label equation.")):
         paddle.einsum('i, -> k', a)
     with self.assertRaisesRegex(AssertionError, (
             "Invalid equation: the label string '' misses dimensions.")):
         paddle.einsum('->', a)
     with self.assertRaisesRegex(AssertionError, (
             "Invalid equation: the label string 'i' misses dimensions.")):
         paddle.einsum('i', a)
     with self.assertRaisesRegex(AssertionError, (
             "Invalid equation: _ is not a valid label, "
             "which should be letters.")):
         paddle.einsum('i_', a)
     with self.assertRaisesRegex(AssertionError, (
             "Invalid equation: `.` is found outside of an ellipsis.")):
         paddle.einsum('i..j', a)
     with self.assertRaisesRegex(AssertionError, (
             "Invalid equation: `.` is found outside of an ellipsis.")):
         paddle.einsum('...k...', a)
     with self.assertRaisesRegex(AssertionError, (
             "Invalid equation: missing ellipsis in output labels.")):
         paddle.einsum('i...->i', a)
     with self.assertRaisesRegex(AssertionError, (
             "Invalid equation: duplicate output labels are found.")):
         paddle.einsum('i...->i...i', a)
     with self.assertRaisesRegex(AssertionError, (
             "Invalid operands: label i "
             "corresponds to non-broadcastable dimensions.")):
         paddle.einsum('ij...,ji...', a, a)
Example #23
 def test_shape(self):
     A = paddle.static.data(name='x', shape=[-1])
     B = paddle.static.data(name='y', shape=[384])
     C = paddle.einsum('i,d->id', A, B)
     self.assertEqual(C.shape, (-1, 384))