def create_loss(self, prediction, config):
    pre, pos_r, q_emb, p_emb, H_i_emb = prediction
    weight = config.get('hyper_parameters.negative_weight', 0.5)
    loss = weight * paddle.sum(
        paddle.sum(
            paddle.sum(paddle.einsum('ab,ac->abc', q_emb, q_emb), 0) *
            paddle.sum(paddle.einsum('ab,ac->abc', p_emb, p_emb), 0) *
            paddle.matmul(H_i_emb, H_i_emb, transpose_y=True), 0), 0)
    loss += paddle.sum((1.0 - weight) * paddle.square(pos_r) - 2.0 * pos_r)
    return loss
def test_static_graph(self):
    paddle.enable_static()
    fluid = paddle.fluid
    if fluid.core.is_compiled_with_cuda():
        self.place = fluid.CUDAPlace(0)
    else:
        self.place = fluid.CPUPlace()
    main = fluid.Program()
    startup = fluid.Program()
    with fluid.program_guard(main, startup):
        a = paddle.static.data(
            name='a', shape=[3, None, None, None], dtype='float')
        b = paddle.static.data(
            name='b', shape=[2, None, None, None], dtype='float')
        c = paddle.static.data(
            name='c', shape=[None, None, 2, None], dtype='float')
        d = paddle.static.data(name='d', shape=[None, None, 5], dtype='float')
        e = paddle.static.data(name='e', shape=[None, 2, None], dtype='float')
        outs = []
        outs.append(paddle.einsum("ibnd,jbnd->bnij", a, b))
        outs.append(paddle.einsum('...ik, ...j', c, d))
        outs.append(paddle.einsum('...kj, ...ik', d, e))
        outs.append(paddle.einsum('ijk..., ikj', c, e))
        outs.append(paddle.einsum('ijk..., ikj->...ij', c, e))
    exe = fluid.Executor(self.place)
    exe.run(startup)
    a = np.arange(72).reshape(3, 2, 3, 4).astype('float')
    b = np.arange(48).reshape(2, 2, 3, 4).astype('float')
    c = np.arange(48).reshape(2, 3, 2, 4).astype('float')
    d = np.arange(30).reshape(2, 3, 5).astype('float')
    e = np.arange(12).reshape(2, 2, 3).astype('float')
    feeds = {'a': a, 'b': b, 'c': c, 'd': d, 'e': e}
    actual = exe.run(main, feed=feeds, fetch_list=[outs])
    expect = []
    expect.append(np.einsum("ibnd,jbnd->bnij", a, b))
    expect.append(np.einsum('...ik, ...j', c, d))
    expect.append(np.einsum('...kj, ...ik', d, e))
    expect.append(np.einsum('ijk..., ikj', c, e))
    expect.append(np.einsum('ijk..., ikj->...ij', c, e))
    for a, e in zip(actual, expect):
        self.check_output_equal(a, e)
def supervised_chi_loss(ret, batch, value, config):
    """Computes loss for direct chi angle supervision.

    Jumper et al. (2021) Suppl. Alg. 27 "torsionAngleLoss"

    Args:
        ret: Dictionary to write outputs into, needs to contain 'loss'.
        batch: Batch, needs to contain 'seq_mask', 'chi_mask', 'chi_angles'.
        value: Dictionary containing structure module output, needs to contain
            value['sidechains']['angles_sin_cos'] for angles and
            value['sidechains']['unnormalized_angles_sin_cos'] for
            unnormalized angles.
        config: Configuration of loss, should contain 'chi_weight' and
            'angle_norm_weight'; 'angle_norm_weight' scales the angle norm
            term, 'chi_weight' scales the torsion term.
    """
    eps = 1e-6
    sequence_mask = batch['seq_mask']
    num_res = sequence_mask.shape[1]
    batch_size = sequence_mask.shape[0]
    chi_mask = batch['chi_mask']
    pred_angles = paddle.reshape(
        value['sidechains']['angles_sin_cos'],
        [batch_size, -1, num_res, 7, 2])
    pred_angles = pred_angles[:, :, :, 3:]

    residue_type_one_hot = paddle.nn.functional.one_hot(
        batch['aatype_index'],
        num_classes=residue_constants.restype_num + 1)
    chi_pi_periodic = paddle.einsum(
        'nijk, nkl->nijl', residue_type_one_hot[:, None, ...],
        paddle.to_tensor(residue_constants.chi_pi_periodic)[None])

    sin_cos_true_chi = batch['chi_angles_sin_cos'][:, None, ...]

    # This is -1 if chi is pi-periodic and +1 if it's 2pi-periodic
    shifted_mask = (1 - 2 * chi_pi_periodic)[..., None]
    sin_cos_true_chi_shifted = shifted_mask * sin_cos_true_chi

    sq_chi_error = paddle.sum(
        squared_difference(sin_cos_true_chi, pred_angles), axis=-1)
    sq_chi_error_shifted = paddle.sum(
        squared_difference(sin_cos_true_chi_shifted, pred_angles), axis=-1)
    sq_chi_error = paddle.minimum(sq_chi_error, sq_chi_error_shifted)

    sq_chi_loss_tmp = []
    for i in range(batch_size):
        sq_chi_loss_i = utils.mask_mean(
            mask=paddle.unsqueeze(chi_mask[i], axis=0),
            value=sq_chi_error[i])
        sq_chi_loss_tmp.append(sq_chi_loss_i)
    sq_chi_loss = paddle.to_tensor(sq_chi_loss_tmp, stop_gradient=False)
    sq_chi_loss = paddle.squeeze(sq_chi_loss, axis=-1)
    ret['chi_loss'] = sq_chi_loss
    ret['loss'] += config.chi_weight * sq_chi_loss

    unnormed_angles = paddle.reshape(
        value['sidechains']['unnormalized_angles_sin_cos'],
        [batch_size, -1, num_res, 7, 2])
    angle_norm = paddle.sqrt(
        paddle.sum(paddle.square(unnormed_angles), axis=-1) + eps)
    norm_error = paddle.abs(angle_norm - 1.)
    angle_norm_loss_tmp = []
    for i in range(batch_size):
        angle_norm_loss_i = utils.mask_mean(
            mask=paddle.unsqueeze(sequence_mask[i], axis=[0, 2]),
            value=norm_error[i])
        angle_norm_loss_tmp.append(angle_norm_loss_i)
    angle_norm_loss = paddle.to_tensor(
        angle_norm_loss_tmp, stop_gradient=False)
    angle_norm_loss = paddle.squeeze(angle_norm_loss, axis=-1)
    ret['angle_norm_loss'] = angle_norm_loss
    ret['loss'] += config.angle_norm_weight * angle_norm_loss
def check_output(self, eqn, *ops):
    expect = np.einsum(eqn, *ops)
    with paddle.fluid.dygraph.guard(
            self._get_place(force_to_use_cpu=False)):
        pd_operands = [paddle.to_tensor(op) for op in ops]
        actual = paddle.einsum(eqn, *pd_operands)
        self.check_output_equal(actual.numpy(), expect)
def _get_rand_mask(self, blocked_query_mask, blocked_key_mask, rand_mask_idx,
                   batch_size, sequence_length):
    '''
    return random mask: [B, H, L-G, bs, R * bs]
    '''
    # rand_mask_idx: [H, T]
    # blocked_query_mask: [B, L, bs]
    # blocked_key_mask: [B, L, bs]
    bs = self.block_size
    B = batch_size
    L = sequence_length // bs
    H = self.num_heads
    G = self.num_global_blocks
    GB = self.num_global_blocks_back
    GF = self.num_global_blocks_front
    R = self.num_rand_blocks
    temp_block_key_mask = paddle.unsqueeze(blocked_key_mask, 1)
    temp_block_key_mask = paddle.expand(temp_block_key_mask, [B, H, L, -1])
    temp_block_key_mask_list = [
        paddle.gather_nd(temp_block_key_mask[b], rand_mask_idx)
        for b in range(B)
    ]
    temp_block_key_mask = paddle.concat(temp_block_key_mask_list, 0)
    temp_block_key_mask = paddle.reshape(temp_block_key_mask, [
        B, temp_block_key_mask.shape[0] // B // (L - GF - GB) // R,
        L - GF - GB, -1
    ])
    rand_mask = paddle.einsum("blq,bhlk->bhlqk",
                              blocked_query_mask[:, GF:-GB],
                              temp_block_key_mask)
    return rand_mask
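# Standalone sketch (shapes are assumed, not taken from the class above):
# "blq,bhlk->bhlqk" forms, for every block l, the outer product of a
# query-mask row and a key-mask row, which is how the [bs, R*bs] tile of the
# random-attention mask is built. The broadcasting check below verifies that.
import paddle

B, H, L, bs, K = 2, 3, 4, 5, 6
qm = paddle.rand([B, L, bs])
km = paddle.rand([B, H, L, K])
mask = paddle.einsum('blq,bhlk->bhlqk', qm, km)
# Same result via explicit broadcasting: [B,1,L,bs,1] * [B,H,L,1,K]
check = qm.unsqueeze(1).unsqueeze(-1) * km.unsqueeze(-2)
print(paddle.allclose(mask, check).item())  # True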
def _hsv_to_rgb(img):
    """Convert an image Tensor from HSV to RGB.
    """
    h, s, v = img.unbind(axis=-3)
    f = h * 6.0
    i = paddle.floor(f)
    f = f - i
    i = i.astype(paddle.int32) % 6

    p = paddle.clip(v * (1.0 - s), 0.0, 1.0)
    q = paddle.clip(v * (1.0 - s * f), 0.0, 1.0)
    t = paddle.clip(v * (1.0 - s * (1.0 - f)), 0.0, 1.0)

    mask = paddle.equal(
        i.unsqueeze(axis=-3),
        paddle.arange(6, dtype=i.dtype).reshape((-1, 1, 1))).astype(img.dtype)
    matrix = paddle.stack(
        [
            paddle.stack([v, q, p, p, t, v], axis=-3),
            paddle.stack([t, v, v, q, p, p], axis=-3),
            paddle.stack([p, p, t, v, v, q], axis=-3)
        ],
        axis=-4)
    return paddle.einsum("...ijk, ...xijk -> ...xjk", mask, matrix)
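# A minimal usage sketch for the function above (the input values are made
# up). `mask` is one-hot over the hue sextant index i, so the einsum
# "...ijk, ...xijk -> ...xjk" picks, for each pixel (j, k), the RGB triple x
# belonging to its sextant. Pure red in HSV is (h=0, s=1, v=1):
import paddle

hsv = paddle.to_tensor([[[0.0]], [[1.0]], [[1.0]]])  # shape [3, 1, 1]
rgb = _hsv_to_rgb(hsv)
print(rgb.numpy().reshape(3))  # approximately [1., 0., 0.]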
def rotation_3d_in_axis(points, angles, axis=0):
    # points: [N, point_size, 3]
    # angles: [N]
    rot_sin = paddle.sin(angles)
    rot_cos = paddle.cos(angles)
    ones = paddle.ones_like(rot_cos)
    zeros = paddle.zeros_like(rot_cos)
    if axis == 1:
        rot_mat_T = paddle.stack([
            paddle.stack([rot_cos, zeros, -rot_sin]),
            paddle.stack([zeros, ones, zeros]),
            paddle.stack([rot_sin, zeros, rot_cos])
        ])
    elif axis == 2 or axis == -1:
        rot_mat_T = paddle.stack([
            paddle.stack([rot_cos, -rot_sin, zeros]),
            paddle.stack([rot_sin, rot_cos, zeros]),
            paddle.stack([zeros, zeros, ones])
        ])
    elif axis == 0:
        rot_mat_T = paddle.stack([
            paddle.stack([zeros, rot_cos, -rot_sin]),
            paddle.stack([zeros, rot_sin, rot_cos]),
            paddle.stack([ones, zeros, zeros])
        ])
    else:
        raise ValueError("axis should be in range")
    # paddle.einsum takes the operands as separate arguments, not a tuple.
    return paddle.einsum('aij,jka->aik', points, rot_mat_T)
def test_forward(self):
    operands = [
        TestEinsum.TEST_SAMPLES[operand] for operand in self.sample["data"]
    ]
    expected_result = np.einsum(self.sample["paradigm"], *operands)
    equation = self.sample["paradigm"]

    with paddle.fluid.dygraph.guard(
            self._get_place(force_to_use_cpu=False)):
        pd_operands = [paddle.to_tensor(operand) for operand in operands]
        result = paddle.einsum(equation, *pd_operands)
        self.check_output_equal(result.numpy(), expected_result)

    with paddle.fluid.dygraph.guard(self._get_place(force_to_use_cpu=True)):
        pd_operands = [paddle.to_tensor(operand) for operand in operands]
        result = paddle.einsum(equation, *pd_operands)
        self.check_output_equal(result.numpy(), expected_result)
def forward(self, h, attn_mask=None, mems=None):
    if mems is not None:
        c = paddle.concat([mems, h], axis=1)
    else:
        c = h

    if self.normalize_before:
        c = self.layer_norm(c)

    head_q = self.q_proj(h)
    head_k, head_v = paddle.chunk(self.kv_proj(c), chunks=2, axis=-1)

    head_q = paddle.reshape(
        head_q, shape=[h.shape[0], h.shape[1], self.n_head, self.d_head])
    head_k = paddle.reshape(
        head_k, shape=[c.shape[0], c.shape[1], self.n_head, self.d_head])
    head_v = paddle.reshape(
        head_v, shape=[c.shape[0], c.shape[1], self.n_head, self.d_head])

    attn_score = paddle.einsum('bind,bjnd->bnij', head_q, head_k)
    attn_score = attn_score * self.scale
    if attn_mask is not None:
        attn_score = attn_score - float('inf') * attn_mask

    # Paddle's F.softmax takes `axis`, not `dim`.
    attn_prob = F.softmax(attn_score, axis=-1)
    attn_prob = self.attn_drop(attn_prob)

    attn_vec = paddle.einsum('bnij,bjnd->bind', attn_prob, head_v)
    attn_vec = paddle.reshape(
        attn_vec,
        shape=[
            attn_vec.shape[0], attn_vec.shape[1], self.n_head * self.d_head
        ])

    attn_out = self.o_proj(attn_vec)
    attn_out = self.drop(attn_out)
    if self.normalize_before:
        output = h + attn_out
    else:
        output = self.layer_norm(h + attn_out)
    return output
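# Standalone sketch (sizes are assumed) showing what the attention-score
# einsum 'bind,bjnd->bnij' computes: move the head axis forward, then do a
# batched matmul of q against k^T.
import paddle

b, i, j, n, d = 2, 4, 5, 3, 8
q = paddle.rand([b, i, n, d])
k = paddle.rand([b, j, n, d])
by_einsum = paddle.einsum('bind,bjnd->bnij', q, k)
by_matmul = paddle.matmul(
    q.transpose([0, 2, 1, 3]),  # [b, n, i, d]
    k.transpose([0, 2, 1, 3]),  # [b, n, j, d]
    transpose_y=True)           # [b, n, i, j]
print(paddle.allclose(by_einsum, by_matmul).item())  # True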
def test_shape(self):
    cuda_major = paddle.version.cuda().split('.')[0].strip()
    if paddle.is_compiled_with_cuda() and int(cuda_major) >= 11:
        # MatmulKernel supports bfloat16 only when the CUDA major
        # version is at least 11.
        A = paddle.to_tensor(np.array([1.0, 2.0])).astype(paddle.bfloat16)
        A = A.cuda()
        B = paddle.to_tensor(np.array([2.0, 3.0])).astype(paddle.bfloat16)
        B = B.cuda()
        C = paddle.einsum('i,i->', A, B)
        self.assertEqual(C.item(), 8.0)
def positional_embedding(self, inputs):
    seq_len = inputs.shape[1]
    pos_seq = paddle.arange(0, seq_len, dtype=dtype_float)
    indices = paddle.arange(0, self.head_dim, 2, dtype=dtype_float)
    indices = 1 / 10000**(indices / self.head_dim)
    sinusoid_inp = paddle.einsum("i,d->id", pos_seq, indices)
    pos_emb = paddle.concat(
        [paddle.sin(sinusoid_inp), paddle.cos(sinusoid_inp)], axis=-1)
    pos_emb = paddle.reshape(pos_emb, (1, 1, seq_len, self.head_dim))
    pos_emb.stop_gradient = True
    return pos_emb
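# Hedged standalone check (sizes are made up): "i,d->id" is just the outer
# product, so the sinusoid-input einsum above can be verified against
# unsqueeze-and-broadcast multiplication.
import paddle

pos_seq = paddle.arange(0, 4, dtype='float32')
indices = 1 / 10000 ** (paddle.arange(0, 8, 2, dtype='float32') / 8)
by_einsum = paddle.einsum('i,d->id', pos_seq, indices)
by_broadcast = pos_seq.unsqueeze(-1) * indices.unsqueeze(0)
print(paddle.allclose(by_einsum, by_broadcast).item())  # True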
def forward(self, input_u, item_attribute, input_ur=None, item_bind_M=None):
    user_feature_emb = self.user_feature_emb(input_u)
    summed_user_emb = user_feature_emb.sum(1)
    all_item_feature_emb = self.all_item_feature_emb(item_attribute)
    summed_all_item_emb = all_item_feature_emb.sum(1)
    user_cross = 0.5 * (summed_user_emb**2 - (user_feature_emb**2).sum(1))
    item_cross = 0.5 * (summed_all_item_emb**2 -
                        (all_item_feature_emb**2).sum(1))
    user_cross_score = user_cross.matmul(self.H_s)
    item_cross_score = item_cross.matmul(self.H_s)
    user_bias = self.user_bias(input_u).sum(1)
    item_bias = self.item_bias(item_attribute).sum(1)

    I = paddle.ones([input_u.shape[0], 1])
    p_emb = paddle.concat(
        [summed_user_emb, user_cross_score + user_bias + self.bias, I], 1)
    I = paddle.ones([summed_all_item_emb.shape[0], 1])
    q_emb = paddle.concat(
        [summed_all_item_emb, I, item_cross_score + item_bias], 1)
    H_i_emb = paddle.concat(
        [self.H_i, paddle.to_tensor([[1.0]]), paddle.to_tensor([[1.0]])], 0)

    dot = paddle.einsum('ac,bc->abc', p_emb, q_emb)
    pre = paddle.einsum('ajk,kl->aj', dot, H_i_emb)
    if input_ur is None:
        return (pre, )

    pos_item = F.embedding(input_ur, q_emb)
    pos_num_r = (input_ur != item_bind_M).astype(default_type)
    pos_item = paddle.einsum('ab,abc->abc', pos_num_r, pos_item)
    pos_r = paddle.einsum('ac,abc->abc', p_emb, pos_item)
    pos_r = paddle.einsum('ajk,kl->ajl', pos_r, H_i_emb).flatten(1)
    return pre, pos_r, q_emb, p_emb, H_i_emb
def rotation_2d(points, angles):
    """Rotate 2d points around the origin, clockwise for positive angles.

    Args:
        points (float array, shape=[N, point_size, 2]): points to be rotated.
        angles (float array, shape=[N]): rotation angle.

    Returns:
        float array: same shape as points
    """
    rot_sin = paddle.sin(angles)
    rot_cos = paddle.cos(angles)
    rot_mat_T = paddle.stack(
        [paddle.stack([rot_cos, -rot_sin]),
         paddle.stack([rot_sin, rot_cos])])
    # paddle.einsum takes the operands as separate arguments, not a tuple.
    return paddle.einsum('aij,jka->aik', points, rot_mat_T)
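# A small sanity check (assumed usage, not from the source): with this
# convention, rotating the point (1, 0) by pi/2 lands on (0, -1), i.e. the
# rotation is clockwise for positive angles, matching the docstring.
import numpy as np
import paddle

pts = paddle.to_tensor(np.array([[[1.0, 0.0]]], dtype='float32'))  # [1, 1, 2]
ang = paddle.to_tensor(np.array([np.pi / 2], dtype='float32'))
print(rotation_2d(pts, ang).numpy().round(6))  # approximately [[[0., -1.]]]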
def get_reference_out(self):
    paddle.disable_static(place=paddle.CUDAPlace(0))
    query = paddle.to_tensor(self.query, stop_gradient=False)
    key = query if self.merge_qkv else paddle.to_tensor(
        self.key, stop_gradient=False)
    q_weight = paddle.to_tensor(self.q_weight, stop_gradient=False)
    k_weight = paddle.to_tensor(self.k_weight, stop_gradient=False)
    v_weight = paddle.to_tensor(self.v_weight, stop_gradient=False)
    src_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True)
    c = self.key_dim**(-0.5)
    # [batch_size, msa_len, num_heads, res_len, key_dim]
    q = paddle.einsum('nbqa,ahc->nbqhc', query, q_weight) * c
    # [batch_size, msa_len, num_heads, m_size, key_dim]
    k = paddle.einsum('nbka,ahc->nbkhc', key, k_weight)
    # [batch_size, msa_len, num_heads, m_size, key_dim]
    v = paddle.einsum('nbka,ahc->nbkhc', key, v_weight)
    # [batch_size, msa_len, num_heads, res_len, m_size]
    logits = paddle.einsum('nbqhc,nbkhc->nbhqk', q, k)  # qk_out
    logits = logits + src_mask
    if self.bias_attr:
        nonbatched_bias = paddle.to_tensor(
            self.nonbatched_bias, stop_gradient=False)
        logits = logits + nonbatched_bias
    weights = nn.functional.softmax(logits)  # softmax_out
    weighted_avg = paddle.einsum('nbhqk,nbkhc->nbqhc', weights, v)
    if self.has_gating:
        gating_w = paddle.to_tensor(self.gating_w, stop_gradient=False)
        gating_b = paddle.to_tensor(self.gating_b, stop_gradient=False)
        gate_values = paddle.einsum('nbqc,chv->nbqhv', query,
                                    gating_w) + gating_b
        gate_values = nn.functional.sigmoid(gate_values)
        weighted_avg = weighted_avg * gate_values
    output_b = paddle.to_tensor(self.output_b, stop_gradient=False)
    output_w = paddle.to_tensor(self.output_w, stop_gradient=False)
    out = paddle.einsum('nbqhc,hco->nbqo', weighted_avg,
                        output_w) + output_b
    paddle.autograd.backward(
        [out], [paddle.to_tensor(self.dout)], retain_graph=True)
    if self.merge_qkv:
        return out, query.grad, None
    else:
        return out, query.grad, key.grad
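# Standalone check (sizes are assumed): the gating einsum 'nbqc,chv->nbqhv'
# used above is an ordinary linear projection of the channel axis, with the
# output reshaped into (heads, values).
import paddle

n, b, q, c, h, v = 2, 3, 4, 8, 2, 5
x = paddle.rand([n, b, q, c])
w = paddle.rand([c, h, v])
by_einsum = paddle.einsum('nbqc,chv->nbqhv', x, w)
by_matmul = paddle.matmul(x, w.reshape([c, h * v])).reshape([n, b, q, h, v])
print(paddle.allclose(by_einsum, by_matmul).item())  # True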
def test_diagonalize_errors(self):
    a = np.arange(4 * 3 * 4 * 4).reshape(4, 3, 4, 4).astype('float')
    a = paddle.to_tensor(a)
    with self.assertRaisesRegex(
            AssertionError, ('Diagonal and trace not implemented yet.')):
        paddle.einsum('...ii->...i', a)
    with self.assertRaisesRegex(
            AssertionError, ('Diagonal and trace not implemented yet.')):
        paddle.einsum('i...i', a)
    with self.assertRaisesRegex(
            AssertionError, ('Diagonal and trace not implemented yet.')):
        paddle.einsum('i...i->i...', a)
def test_diagonalize_errors(self):
    a = np.arange(4 * 3 * 4 * 4).reshape(4, 3, 4, 4).astype('float')
    a = paddle.to_tensor(a)
    with self.assertRaisesRegex(
            AssertionError, ('Duplicate labels are not supported.')):
        paddle.einsum('...ii->...i', a)
    with self.assertRaisesRegex(
            AssertionError, ('Duplicate labels are not supported.')):
        paddle.einsum('i...i', a)
    with self.assertRaisesRegex(
            AssertionError, ('Duplicate labels are not supported.')):
        paddle.einsum('i...i->i...', a)
def forward(self, single_act: paddle.Tensor, pair_act: paddle.Tensor,
            mask: paddle.Tensor, affine: quat_affine.QuatAffine):
    # single_act: [B, N, C]
    # pair_act: [B, N, M, C']
    # mask: [B, N, 1]
    num_residues = single_act.shape[1]
    num_head = self.config.num_head
    num_scalar_qk = self.config.num_scalar_qk
    num_point_qk = self.config.num_point_qk
    num_scalar_v = self.config.num_scalar_v
    num_point_v = self.config.num_point_v
    num_output = self.config.num_channel

    # Construct scalar queries of shape:
    # [batch_size, num_query_residues, num_head, num_points]
    q_scalar = self.q_scalar(single_act)
    q_scalar = paddle.reshape(
        q_scalar, [-1, num_residues, num_head, num_scalar_qk])

    # Construct scalar keys/values of shape:
    # [batch_size, num_target_residues, num_head, num_points]
    kv_scalar = self.kv_scalar(single_act)
    kv_scalar = paddle.reshape(
        kv_scalar,
        [-1, num_residues, num_head, num_scalar_v + num_scalar_qk])
    k_scalar, v_scalar = paddle.split(
        kv_scalar, [num_scalar_qk, -1], axis=-1)

    # Construct query points of shape:
    # [batch_size, num_residues, num_head, num_point_qk]
    q_point_local = self.q_point_local(single_act)
    q_point_local = paddle.split(q_point_local, 3, axis=-1)
    q_point_global = affine.apply_to_point(q_point_local, extra_dims=1)
    q_point = [
        paddle.reshape(x, [-1, num_residues, num_head, num_point_qk])
        for x in q_point_global]

    # Construct key and value points.
    # Key points shape [batch_size, num_residues, num_head, num_point_qk]
    # Value points shape [batch_size, num_residues, num_head, num_point_v]
    kv_point_local = self.kv_point_local(single_act)
    kv_point_local = paddle.split(kv_point_local, 3, axis=-1)
    kv_point_global = affine.apply_to_point(kv_point_local, extra_dims=1)
    kv_point_global = [
        paddle.reshape(
            x, [-1, num_residues, num_head, num_point_qk + num_point_v])
        for x in kv_point_global]
    k_point, v_point = list(
        zip(*[
            paddle.split(x, [num_point_qk, -1], axis=-1)
            for x in kv_point_global
        ]))

    # We assume that all queries and keys come iid from N(0, 1) distribution
    # and compute the variances of the attention logits.
    # Each scalar pair (q, k) contributes Var q*k = 1
    scalar_variance = max(num_scalar_qk, 1) * 1.
    # Each point pair (q, k) contributes Var [0.5 ||q||^2 - <q, k>] = 9 / 2
    point_variance = max(num_point_qk, 1) * 9. / 2

    # Allocate equal variance to scalar, point and attention 2d parts so
    # that the sum is 1.
    num_logit_terms = 3
    scalar_weights = np.sqrt(1.0 / (num_logit_terms * scalar_variance))
    point_weights = np.sqrt(1.0 / (num_logit_terms * point_variance))
    attention_2d_weights = np.sqrt(1.0 / (num_logit_terms))

    trainable_point_weights = nn.functional.softplus(
        self.trainable_point_weights)
    point_weights *= paddle.unsqueeze(trainable_point_weights, axis=1)

    # [B, R, H, C] => [B, H, R, C], put head dim first
    q_point = [paddle.transpose(x, [0, 2, 1, 3]) for x in q_point]
    k_point = [paddle.transpose(x, [0, 2, 1, 3]) for x in k_point]
    v_point = [paddle.transpose(x, [0, 2, 1, 3]) for x in v_point]

    dist2 = [
        paddle.square(paddle.unsqueeze(qx, axis=-2) -
                      paddle.unsqueeze(kx, axis=-3))
        for qx, kx in zip(q_point, k_point)]
    dist2 = sum(dist2)

    attn_qk_point = -0.5 * paddle.sum(
        paddle.unsqueeze(point_weights, axis=[1, 2]) * dist2, axis=-1)

    q = paddle.transpose(scalar_weights * q_scalar, [0, 2, 1, 3])
    k = paddle.transpose(k_scalar, [0, 2, 1, 3])
    v = paddle.transpose(v_scalar, [0, 2, 1, 3])
    attn_qk_scalar = paddle.matmul(q, paddle.transpose(k, [0, 1, 3, 2]))
    attn_logits = attn_qk_scalar + attn_qk_point

    attention_2d = self.attention_2d(pair_act)
    attention_2d = paddle.transpose(attention_2d, [0, 3, 1, 2])
    attention_2d = attention_2d_weights * attention_2d
    attn_logits += attention_2d

    mask_2d = mask * paddle.transpose(mask, [0, 2, 1])
    attn_logits -= 1e5 * (1. - mask_2d)

    # [batch_size, num_head, num_query_residues, num_target_residues]
    attn = nn.functional.softmax(attn_logits)

    # o_i^h
    # [batch_size, num_query_residues, num_head, num_head * num_scalar_v]
    result_scalar = paddle.matmul(attn, v)
    result_scalar = paddle.transpose(result_scalar, [0, 2, 1, 3])

    # o_i^{hp}
    # [batch_size, num_query_residues, num_head, num_head * num_point_v]
    result_point_global = [
        paddle.sum(
            paddle.unsqueeze(attn, -1) * paddle.unsqueeze(vx, -3), axis=-2)
        for vx in v_point]
    result_point_global = [
        paddle.transpose(x, [0, 2, 1, 3]) for x in result_point_global]

    # \tilde{o}_i^h
    # [batch_size, num_residues, num_head, pair_channel]
    result_attention_over_2d = paddle.einsum(
        'nhij,nijc->nihc', attn, pair_act)

    # Reshape, global-to-local and save
    result_scalar = paddle.reshape(
        result_scalar, [-1, num_residues, num_head * num_scalar_v])
    result_point_global = [
        paddle.reshape(x, [-1, num_residues, num_head * num_point_v])
        for x in result_point_global]
    result_point_local = affine.invert_point(
        result_point_global, extra_dims=1)
    result_attention_over_2d = paddle.reshape(
        result_attention_over_2d,
        [-1, num_residues, num_head * self.channel_num['pair_channel']])

    result_point_local_norm = paddle.sqrt(
        self.dist_epsilon + paddle.square(result_point_local[0]) +
        paddle.square(result_point_local[1]) +
        paddle.square(result_point_local[2]))

    output_features = [result_scalar]
    output_features.extend(result_point_local)
    output_features.extend(
        [result_point_local_norm, result_attention_over_2d])

    final_act = paddle.concat(output_features, axis=-1)
    return self.output_projection(final_act)
def forward(self,
            query_matrix,
            key_matrix,
            value_matrix,
            d_head,
            attn_mask=None,
            rand_mask_idx=None,
            query_mask=None,
            key_mask=None,
            dropout=None):
    '''
    query_matrix: [B, H, T, D]
    key_matrix: [B, H, T, D]
    value_matrix: [B, H, T, D]
    query_mask: [B, 1, T, 1] bool mask
    key_mask: [B, 1, 1, T] bool mask
    rand_mask_idx: [H, T//bs, bs]
    Global Attention
    Random Attention
    Window Attention
    '''
    B = query_matrix.shape[0]  # batch_size
    H = self.num_heads
    T = query_matrix.shape[2]  # sequence_length
    D = query_matrix.shape[3]  # size per head
    G = self.num_global_blocks
    GB = self.num_global_blocks_back
    GF = self.num_global_blocks_front
    R = self.num_rand_blocks
    W = self.window_size
    bs = self.block_size
    L = T // bs  # blocked length

    blocked_query_matrix = paddle.reshape(query_matrix, [B, H, L, bs, -1])
    blocked_key_matrix = paddle.reshape(key_matrix, [B, H, L, bs, -1])
    blocked_value_matrix = paddle.reshape(value_matrix, [B, H, L, bs, -1])
    blocked_query_mask = paddle.reshape(query_mask, [B, L, bs])
    blocked_key_mask = paddle.reshape(key_mask, [B, L, bs])

    # 1. global_front_product
    global_front_out = self._get_global_out(
        query_matrix, key_matrix, value_matrix, key_mask, d_head, dropout)

    # 2. global_back_product
    global_back_out = self._get_global_out(
        query_matrix, key_matrix, value_matrix, key_mask, d_head, dropout,
        False)

    # 3. second_product
    # create second matrix
    # [B, 1, L-G, bs, (G+W)*bs]
    band_mask = self._get_band_mask(blocked_query_mask, blocked_key_mask,
                                    B, T)
    # [B, H, L-G, bs, R*bs]
    rand_mask = self._get_rand_mask(blocked_query_mask, blocked_key_mask,
                                    rand_mask_idx, B, T)
    # [B, H, L-G, bs, (G+W+R)*bs]
    second_mask = paddle.concat([band_mask, rand_mask], axis=4)

    # [B, H, L-G, R * bs, -1]
    random_keys = self._gather_random_key_value(blocked_key_matrix,
                                                rand_mask_idx, B, T)
    random_values = self._gather_random_key_value(blocked_value_matrix,
                                                  rand_mask_idx, B, T)

    band_keys_matrix = self._get_band_matrix(blocked_key_matrix, B, T)
    band_value_matrix = self._get_band_matrix(blocked_value_matrix, B, T)

    # [B, H, L - G, bs, -1]
    second_query_matrix = blocked_query_matrix[:, :, GF:-GB]
    # [B, H, L - G, (G+W+R)*bs, -1]
    second_key_matrix = paddle.concat(
        [band_keys_matrix, random_keys], axis=3)
    # [B, H, L - G, (G+W+R)*bs, -1]
    second_value_matrix = paddle.concat(
        [band_value_matrix, random_values], axis=3)
    second_top_value_matrix, second_middle_value_matrix, \
        second_bottom_value_matrix = \
        self._get_splited_matrix(second_value_matrix)

    second_product = paddle.einsum(
        "bhlqd,bhlkd->bhlqk", second_query_matrix, second_key_matrix)
    second_product = second_product * (d_head**-0.5)
    second_product += (1 - second_mask) * -1e6
    second_weights = F.softmax(second_product)

    second_top_weights, second_middle_weights, second_bottom_weights = \
        self._get_splited_matrix(second_weights)
    second_top_out = paddle.einsum(
        "bhlqk,bhlkd->bhlqd", second_top_weights, second_top_value_matrix)

    second_middle_out = paddle.einsum(
        "bhlqk,bhlkd->bhlqd",
        second_middle_weights[:, :, :, :, GF * bs:-(GB + R) * bs],
        second_middle_value_matrix[:, :, :, GF * bs:-(GB + R) * bs])
    # add global block attention
    second_middle_out += paddle.einsum(
        "bhlqk,bhkd->bhlqd",
        second_middle_weights[:, :, :, :, :GF * bs],
        blocked_value_matrix[:, :, 0])
    second_middle_out += paddle.einsum(
        "bhlqk,bhkd->bhlqd",
        second_middle_weights[:, :, :, :, -(GB + R) * bs:-R * bs],
        blocked_value_matrix[:, :, -GB])
    # add random block attention
    second_middle_out += paddle.einsum(
        "...qk,...kd->...qd",
        second_middle_weights[:, :, :, :, -R * bs:],
        random_values[:, :, GF:-GB])

    second_bottom_out = paddle.einsum(
        "bhlqk,bhlkd->bhlqd", second_bottom_weights,
        second_bottom_value_matrix)

    second_out = paddle.concat(
        [second_top_out, second_middle_out, second_bottom_out], axis=2)
    second_out = paddle.reshape(second_out, [B, H, (L - G) * bs, -1])

    # [B, H, T, D]
    out = paddle.concat([global_front_out, second_out, global_back_out],
                        axis=2)
    out = out * query_mask
    return out
def _get_band_mask(self, blocked_query_mask, blocked_key_mask, batch_size,
                   sequence_length):
    '''
    Return second mask: [B, 1, L-G, bs, G+W]
    '''
    GB = self.num_global_blocks_back
    GF = self.num_global_blocks_front
    G = self.num_global_blocks
    R = self.num_rand_blocks
    W = self.window_size
    bs = self.block_size
    T = sequence_length
    L = T // bs  # blocked length
    B = batch_size
    H = self.num_heads
    # G+W+R
    # query_mask: [B, L, bs]
    # key_mask: [B, L, bs]
    # [B, L-G, bs, 1] * [B, L-G, 1, G*bs] -> [B, L-G, bs, G*bs]
    temp_query_mask = paddle.reshape(blocked_query_mask[:, GF:-GB],
                                     [B, L - G, bs, 1])
    temp_key_mask_front = paddle.reshape(blocked_key_mask[:, :GF],
                                         [B, 1, 1, GF * bs])
    global_block_mask_front = paddle.einsum(
        "blqd,bmdk->blqk", temp_query_mask, temp_key_mask_front)

    temp_key_mask_back = paddle.reshape(blocked_key_mask[:, -GB:],
                                        [B, 1, 1, GB * bs])
    global_block_mask_back = paddle.einsum(
        "blqd,bmdk->blqk", temp_query_mask, temp_key_mask_back)

    # create window block mask
    key_mask_list = []
    for query_block_id in range(GF, GF + W // 2):
        left_block_id = query_block_id - W // 2
        right_block_id = query_block_id + W // 2
        zero_key_mask = paddle.zeros_like(
            blocked_key_mask[:, -(W - (right_block_id + 1 - G)):-GB])
        temp_key_mask = paddle.concat(
            [blocked_key_mask[:, GF:(right_block_id + 1)], zero_key_mask],
            axis=1)
        temp_key_mask = paddle.unsqueeze(temp_key_mask, 1)
        key_mask_list.append(temp_key_mask)
    roll_key_mask1 = paddle.concat(key_mask_list, axis=1)
    roll_key_mask1 = paddle.reshape(roll_key_mask1, [0, 0, W * bs])

    key_mask_list = []
    band_length = L - G - W // 2 * 2
    for query_block_id in range(GF + W // 2, GF + W // 2 + W):
        left_block_id = query_block_id - W // 2
        right_block_id = query_block_id + W // 2
        key_mask_list.append(
            blocked_key_mask[:, left_block_id:left_block_id + band_length])
    window_key_mask = paddle.concat(key_mask_list, axis=2)
    window_key_mask = paddle.reshape(window_key_mask, [0, 0, W * bs])

    key_mask_list = []
    for query_block_id in range((L - GB) - W // 2, L - GB):
        left_block_id = query_block_id - W // 2
        right_block_id = query_block_id + W // 2
        zero_key_mask = paddle.zeros_like(
            blocked_key_mask[:, GF:GF + W - (L - left_block_id - GB)])
        temp_key_mask = paddle.concat(
            [zero_key_mask, blocked_key_mask[:, left_block_id:-GB]],
            axis=1)
        temp_key_mask = paddle.unsqueeze(temp_key_mask, 1)
        key_mask_list.append(temp_key_mask)
    roll_key_mask2 = paddle.concat(key_mask_list, axis=1)
    roll_key_mask2 = paddle.reshape(roll_key_mask2, [0, 0, W * bs])

    window_key_mask = paddle.concat(
        [roll_key_mask1, window_key_mask, roll_key_mask2], axis=1)
    window_key_mask = paddle.unsqueeze(window_key_mask, axis=2)
    # [B, L-G, bs, 1] * [B, L-G, 1, W*bs] -> [B, L-G, bs, W*bs]
    window_block_mask = paddle.einsum(
        "blkd,bldq->blkq", temp_query_mask, window_key_mask)
    band_mask = paddle.concat(
        [global_block_mask_front, window_block_mask,
         global_block_mask_back],
        axis=3)
    band_mask = paddle.unsqueeze(band_mask, 1)  # for head
    band_mask = paddle.expand(band_mask, [B, H, L - G, bs, -1])
    return band_mask
def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None):
    qlen, rlen, bsz = w.shape[1], r.shape[1], w.shape[0]

    if mems is not None:
        cat = paddle.concat([mems, w], axis=1)
        if self.normalize_before:
            w_heads = self.qkv_proj(self.layer_norm(cat))
        else:
            w_heads = self.qkv_proj(cat)
        r_head_k = self.r_proj(r)

        w_head_q, w_head_k, w_head_v = paddle.chunk(
            w_heads, chunks=3, axis=-1)
        w_head_q = w_head_q[:, -qlen:, :]
    else:
        if self.normalize_before:
            w_heads = self.qkv_proj(self.layer_norm(w))
        else:
            w_heads = self.qkv_proj(w)
        r_head_k = self.r_proj(r)

        w_head_q, w_head_k, w_head_v = paddle.chunk(
            w_heads, chunks=3, axis=-1)

    klen = w_head_k.shape[1]

    w_head_q = paddle.reshape(
        w_head_q, shape=[bsz, qlen, self.n_head, self.d_head])
    w_head_k = paddle.reshape(
        w_head_k, shape=[bsz, klen, self.n_head, self.d_head])
    w_head_v = paddle.reshape(
        w_head_v, shape=[bsz, klen, self.n_head, self.d_head])

    r_head_k = paddle.reshape(
        r_head_k, shape=[bsz, rlen, self.n_head, self.d_head])

    rw_head_q = w_head_q + r_w_bias
    AC = paddle.einsum('bind,bjnd->bnij', rw_head_q, w_head_k)

    rr_head_q = w_head_q + r_r_bias
    BD = paddle.einsum('bind,bjnd->bnij', rr_head_q, r_head_k)
    BD = self._rel_shift(BD)

    attn_score = AC + BD
    attn_score = attn_score * self.scale

    if attn_mask is not None:
        attn_score = attn_score - 1e30 * attn_mask

    attn_prob = F.softmax(attn_score, axis=-1)
    attn_prob = self.attn_drop(attn_prob)

    attn_vec = paddle.einsum('bnij,bjnd->bind', attn_prob, w_head_v)
    attn_vec = paddle.reshape(
        attn_vec,
        shape=[
            attn_vec.shape[0], attn_vec.shape[1], self.n_head * self.d_head
        ])

    attn_out = self.o_proj(attn_vec)
    attn_out = self.drop(attn_out)
    if self.normalize_before:
        output = w + attn_out
    else:
        output = self.layer_norm(w + attn_out)
    return output
def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None):
    qlen, bsz = w.shape[1], w.shape[0]

    if mems is not None:
        cat = paddle.concat([mems, w], 1)
        if self.normalize_before:
            w_heads = self.qkv_proj(self.layer_norm(cat))
        else:
            w_heads = self.qkv_proj(cat)
        w_head_q, w_head_k, w_head_v = paddle.chunk(
            w_heads, chunks=3, axis=-1)
        w_head_q = w_head_q[-qlen:]
    else:
        if self.normalize_before:
            w_heads = self.qkv_proj(self.layer_norm(w))
        else:
            w_heads = self.qkv_proj(w)
        w_head_q, w_head_k, w_head_v = paddle.chunk(
            w_heads, chunks=3, axis=-1)

    klen = w_head_k.shape[1]

    w_head_q = paddle.reshape(
        w_head_q,
        shape=[
            w_head_q.shape[0], w_head_q.shape[1], self.n_head, self.d_head
        ])
    w_head_k = paddle.reshape(
        w_head_k,
        shape=[
            w_head_k.shape[0], w_head_k.shape[1], self.n_head, self.d_head
        ])
    w_head_v = paddle.reshape(
        w_head_v,
        shape=[
            w_head_v.shape[0], w_head_v.shape[1], self.n_head, self.d_head
        ])

    if klen > r_emb.shape[0]:
        r_emb_pad = r_emb[0:1].expand(klen - r_emb.shape[0], -1, -1)
        r_emb = paddle.concat([r_emb_pad, r_emb], 0)
        r_bias_pad = r_bias[0:1].expand(klen - r_bias.shape[0], -1)
        r_bias = paddle.concat([r_bias_pad, r_bias], 0)
    else:
        r_emb = r_emb[-klen:]
        r_bias = r_bias[-klen:]

    rw_head_q = w_head_q + r_w_bias.unsqueeze([0])

    AC = paddle.einsum('bind,bjnd->bnij', rw_head_q, w_head_k)
    r_emb = r_emb.unsqueeze([0]).expand([bsz, -1, -1, -1])
    B_ = paddle.einsum('bind,bjnd->bnij', w_head_q, r_emb)
    D_ = r_bias.unsqueeze([0, 2])
    BD = self._rel_shift(B_ + D_)

    attn_score = AC + BD
    attn_score = attn_score * self.scale

    if attn_mask is not None:
        attn_score = attn_score - float('inf') * attn_mask

    # Paddle's F.softmax takes `axis`, not `dim`.
    attn_prob = F.softmax(attn_score, axis=-1)
    attn_prob = self.attn_drop(attn_prob)

    attn_vec = paddle.einsum('bnij,bjnd->bind', attn_prob, w_head_v)
    attn_vec = paddle.reshape(
        attn_vec,
        shape=[
            attn_vec.shape[0], attn_vec.shape[1], self.n_head * self.d_head
        ])

    attn_out = self.o_net(attn_vec)
    attn_out = self.drop(attn_out)
    if self.normalize_before:
        output = w + attn_out
    else:
        output = self.layer_norm(w + attn_out)
    return output
def test_param_errors(self):
    a = np.arange(4 * 3 * 4 * 4).reshape(4, 3, 4, 4).astype('float')
    a = paddle.to_tensor(a)
    with self.assertRaisesRegex(AssertionError,
                                ('At least one operand is expected.')):
        paddle.einsum('ijk')
    with self.assertRaisesRegex(
            AssertionError,
            ('Invalid equation: multiple `->` were found.')):
        paddle.einsum('i -> j -> k', a)
    with self.assertRaisesRegex(
            AssertionError,
            ("Invalid equation: the number of operands is 2, "
             "but found 3 segments in the label equation.")):
        paddle.einsum('i,j,k', a, a)
    with self.assertRaisesRegex(
            AssertionError,
            ("Invalid equation: the number of operands is 2, "
             "but found 1 segments in the label equation.")):
        paddle.einsum('ij -> k', a, a)
    with self.assertRaisesRegex(
            AssertionError,
            ("Invalid equation: the number of operands is 1, "
             "but found 2 segments in the label equation.")):
        paddle.einsum('i, -> k', a)
    with self.assertRaisesRegex(
            AssertionError,
            ("Invalid equation: the label string '' misses dimensions.")):
        paddle.einsum('->', a)
    with self.assertRaisesRegex(
            AssertionError,
            ("Invalid equation: the label string 'i' misses dimensions.")):
        paddle.einsum('i', a)
    with self.assertRaisesRegex(
            AssertionError,
            ("Invalid equation: _ is not a valid label, "
             "which should be letters.")):
        paddle.einsum('i_', a)
    with self.assertRaisesRegex(
            AssertionError,
            ("Invalid equation: `.` is found outside of an ellipsis.")):
        paddle.einsum('i..j', a)
    with self.assertRaisesRegex(
            AssertionError,
            ("Invalid equation: `.` is found outside of an ellipsis.")):
        paddle.einsum('...k...', a)
    with self.assertRaisesRegex(
            AssertionError,
            ("Invalid equation: missing ellipsis in output labels.")):
        paddle.einsum('i...->i', a)
    with self.assertRaisesRegex(
            AssertionError,
            ("Invalid equation: duplicate output labels are found.")):
        paddle.einsum('i...->i...i', a)
    with self.assertRaisesRegex(
            AssertionError,
            ("Invalid operands: label i "
             "corresponds to non-broadcastable dimensions.")):
        paddle.einsum('ij...,ji...', a, a)
def test_shape(self):
    A = paddle.static.data(name='x', shape=[-1])
    B = paddle.static.data(name='y', shape=[384])
    C = paddle.einsum('i,d->id', A, B)
    self.assertEqual(C.shape, (-1, 384))