def forward(self, x):
    residual = x

    t = self.t(x)
    p = self.p(x)
    g = self.g(x)

    b, c, h, w = t.shape

    t = paddle.transpose(paddle.reshape(t, (b, c, -1)), (0, 2, 1))
    p = paddle.reshape(p, (b, c, -1))
    g = paddle.transpose(paddle.reshape(g, (b, c, -1)), (0, 2, 1))

    att = paddle.bmm(t, p)

    if self.use_scale:
        att = paddle.divide(att, paddle.to_tensor(c**0.5))

    att = self.softmax(att)
    x = paddle.bmm(att, g)

    x = paddle.transpose(x, (0, 2, 1))
    x = paddle.reshape(x, (b, c, h, w))

    x = self.z(x)
    x = self.bn(x) + residual
    return x
def forward(self, x: paddle.Tensor, proxy: paddle.Tensor) -> paddle.Tensor:
    n, _, h, w = x.shape

    # query: from (n, c1, h1, w1) to (n, h1 * w1, key_channels)
    query = self.f_pixel(x)
    query = paddle.reshape(query, (n, self.key_channels, -1))
    query = paddle.transpose(query, [0, 2, 1])

    # key: from (n, c2, h2, w2) to (n, key_channels, h2 * w2)
    key = self.f_object(proxy)
    key = paddle.reshape(key, (n, self.key_channels, -1))

    # value: from (n, c2, h2, w2) to (n, h2 * w2, key_channels)
    value = self.f_down(proxy)
    value = paddle.reshape(value, (n, self.key_channels, -1))
    value = paddle.transpose(value, [0, 2, 1])

    # sim_map: (n, h1 * w1, h2 * w2)
    sim_map = paddle.bmm(query, key)
    sim_map = (self.key_channels**-.5) * sim_map
    sim_map = F.softmax(sim_map, axis=-1)

    # context: from (n, h1 * w1, key_channels) to (n, out_channels, h1, w1)
    context = paddle.bmm(sim_map, value)
    context = paddle.transpose(context, [0, 2, 1])
    context = paddle.reshape(context, (n, self.key_channels, h, w))
    context = self.f_up(context)

    return context
def forward(self, x):
    h_ = x
    h_ = self.norm(h_)
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)

    # compute attention
    b, c, h, w = q.shape
    q = q.reshape([b, c, h * w])
    q = q.transpose([0, 2, 1])  # b, hw, c
    k = k.reshape([b, c, h * w])  # b, c, hw
    w_ = paddle.bmm(q, k)  # b, hw, hw    w[b, i, j] = sum_c q[b, i, c] k[b, c, j]
    w_ = w_ * (int(c)**(-0.5))
    w_ = paddle.nn.functional.softmax(w_, 2)

    # attend to values
    v = v.reshape([b, c, h * w])
    w_ = w_.transpose([0, 2, 1])  # b, hw, hw (first hw of k, second of q)
    # b, c, hw (hw of q)    h_[b, c, j] = sum_i v[b, c, i] w_[b, i, j]
    h_ = paddle.bmm(v, w_)
    h_ = h_.reshape([b, c, h, w])

    h_ = self.proj_out(h_)

    return x + h_
def forward(self, input, mask=None):
    """
    Args:
        input (obj: `paddle.Tensor`) of shape (batch, seq_len, input_size):
            Tensor containing the features of the input sequence.
        mask (obj: `paddle.Tensor`, optional, defaults to `None`) of shape (batch, seq_len):
            Bool tensor in which each element indicates whether the corresponding
            input word id is a pad token or not.
    """
    forward_input, backward_input = paddle.chunk(input, chunks=2, axis=2)
    # elementwise-sum forward_x and backward_x
    # Shape: (batch_size, max_seq_len, hidden_size)
    h = paddle.add_n([forward_input, backward_input])
    # Shape: (batch_size, hidden_size, 1)
    att_weight = self.att_weight.tile(
        repeat_times=(paddle.shape(h)[0], 1, 1))
    # Shape: (batch_size, max_seq_len, 1)
    att_score = paddle.bmm(paddle.tanh(h), att_weight)
    if mask is not None:
        # mask, remove the effect of 'PAD'
        mask = paddle.cast(mask, dtype='float32')
        mask = mask.unsqueeze(axis=-1)
        inf_tensor = paddle.full(
            shape=mask.shape, dtype='float32', fill_value=-INF)
        att_score = paddle.multiply(att_score, mask) + paddle.multiply(
            inf_tensor, (1 - mask))
    # Shape: (batch_size, max_seq_len, 1)
    att_weight = F.softmax(att_score, axis=1)
    # Shape: (batch_size, lstm_hidden_size)
    reps = paddle.bmm(h.transpose(perm=(0, 2, 1)),
                      att_weight).squeeze(axis=-1)
    reps = paddle.tanh(reps)
    return reps, att_weight
def forward(self, x):
    b, c, h, w = x.shape
    x = paddle.reshape(x, [b, c, h * w])
    mu = paddle.tile(self.mu, [b, 1, 1])

    with paddle.no_grad():
        for i in range(self.stage_num):
            x_t = paddle.transpose(x, [0, 2, 1])
            z = paddle.bmm(x_t, mu)
            z = F.softmax(z, axis=2)
            z_ = F.normalize(z, axis=1, p=1)
            mu = paddle.bmm(x, z_)
            mu = F.normalize(mu, axis=1, p=2)

    z_t = paddle.transpose(z, [0, 2, 1])
    x = paddle.matmul(mu, z_t)
    x = paddle.reshape(x, [b, c, h, w])

    if self.training:
        mu = paddle.mean(mu, 0, keepdim=True)
        if paddle.distributed.get_world_size() > 1:
            paddle.distributed.reduce(
                mu / paddle.distributed.get_world_size(), 0)
        mu = F.normalize(mu, axis=1, p=2)
        self.mu = self.mu * (1 - self.momentum) + mu * self.momentum
    return x
def forward(self, input, mask=None):
    """
    Args:
        input (obj: `paddle.Tensor`) of shape (batch, seq_len, input_size):
            Tensor containing the features of the input sequence.
        mask (obj: `paddle.Tensor`, optional, defaults to `None`) of shape (batch, seq_len):
            Bool tensor in which each element indicates whether the corresponding
            input word id is a pad token or not.
    """
    weight = self.input_weight.tile(
        repeat_times=(paddle.shape(input)[0], 1, 1))
    bias = self.bias.tile(repeat_times=(paddle.shape(input)[0], 1, 1))
    # Shape: (batch_size, max_seq_len, hidden_size)
    word_squish = paddle.bmm(input, weight) + bias

    att_context_vector = self.att_context_vector.tile(
        repeat_times=(paddle.shape(input)[0], 1, 1))
    # Shape: (batch_size, max_seq_len, 1)
    att_score = paddle.bmm(word_squish, att_context_vector)
    if mask is not None:
        # mask, remove the effect of 'PAD'
        mask = paddle.cast(mask, dtype='float32')
        mask = mask.unsqueeze(axis=-1)
        inf_tensor = paddle.full(
            shape=paddle.shape(mask), dtype='float32', fill_value=-INF)
        att_score = paddle.multiply(att_score, mask) + paddle.multiply(
            inf_tensor, (1 - mask))
    att_weight = F.softmax(att_score, axis=1)
    # Shape: (batch_size, hidden_size)
    reps = paddle.bmm(input.transpose(perm=(0, 2, 1)),
                      att_weight).squeeze(-1)
    return reps, att_weight
def neg_score(self, inputs, inputs_rel, inputs_last, neg_sample_size,
              chunk_size):
    inputs_size = inputs.shape
    assert inputs_size[:-1] == inputs_rel.shape[:-1]
    num_dim = inputs_size[-1]

    inputs = inputs.reshape([-1, 1, self.num_elem])
    if self.use_scale:
        rel = inputs_rel.reshape([-1, self.num_elem, self.num_elem + 1])
        scale = self.get_scale(rel[:, :, self.num_elem:])
        scale = scale / scale.norm(axis=-1, p=2, keepdim=True)
        rel_scale = rel[:, :, :self.num_elem] * scale
        outputs = paddle.bmm(inputs, rel_scale)
    else:
        rel = inputs_rel.reshape([-1, self.num_elem, self.num_elem])
        outputs = paddle.bmm(inputs, rel)

    outputs = outputs.reshape([-1, chunk_size, 1, inputs_size[-1]])
    inputs_last = inputs_last.reshape(
        [-1, 1, neg_sample_size, inputs_size[-1]])
    outputs = outputs - inputs_last

    outputs_size = outputs.shape
    num_dim = outputs_size[-1]
    outputs = outputs.reshape([-1, self.num_elem])
    scores = outputs.norm(
        p=2, axis=-1).reshape([-1, num_dim // self.num_elem]).sum(
            axis=-1).reshape(outputs_size[:-1])
    return scores
def forward(self, x):
    n, _, h, w = x.shape

    # query: n, h * w, c1
    query = self.query_conv(x)
    query = paddle.reshape(query, (n, -1, h * w))
    query = paddle.transpose(query, (0, 2, 1))

    # key: n, c1, h * w
    key = self.key_conv(x)
    key = paddle.reshape(key, (n, -1, h * w))

    # sim: n, h * w, h * w
    sim = paddle.bmm(query, key)
    sim = F.softmax(sim, axis=-1)

    value = self.value_conv(x)
    value = paddle.reshape(value, (n, -1, h * w))
    sim = paddle.transpose(sim, (0, 2, 1))

    # feat: from (n, c2, h * w) to (n, c2, h, w)
    feat = paddle.bmm(value, sim)
    feat = paddle.reshape(feat, (n, -1, h, w))

    out = self.gamma * feat + x
    return out
def forward(self, x, proxy):
    x_shape = paddle.shape(x)

    # query: from (n, c1, h1, w1) to (n, h1 * w1, key_channels)
    query = self.f_pixel(x)
    query = paddle.reshape(query, (0, self.key_channels, -1))
    query = paddle.transpose(query, (0, 2, 1))

    # key: from (n, c2, h2, w2) to (n, key_channels, h2 * w2)
    key = self.f_object(proxy)
    key = paddle.reshape(key, (0, self.key_channels, -1))

    # value: from (n, c2, h2, w2) to (n, h2 * w2, key_channels)
    value = self.f_down(proxy)
    value = paddle.reshape(value, (0, self.key_channels, -1))
    value = paddle.transpose(value, (0, 2, 1))

    # sim_map: (n, h1 * w1, h2 * w2)
    sim_map = paddle.bmm(query, key)
    sim_map = (self.key_channels**-.5) * sim_map
    sim_map = F.softmax(sim_map, axis=-1)

    # context: from (n, h1 * w1, key_channels) to (n, out_channels, h1, w1)
    context = paddle.bmm(sim_map, value)
    context = paddle.transpose(context, (0, 2, 1))
    context = paddle.reshape(context,
                             (0, self.key_channels, x_shape[2], x_shape[3]))
    context = self.f_up(context)

    return context
def forward(self, pro_features, roi_features):
    '''
    pro_features: (1, N * nr_boxes, self.d_model)
    roi_features: (49, N * nr_boxes, self.d_model)
    '''
    features = roi_features.transpose(perm=[1, 0, 2])
    parameters = self.dynamic_layer(pro_features).transpose(perm=[1, 0, 2])

    param1 = parameters[:, :, :self.num_params].reshape(
        [-1, self.hidden_dim, self.dim_dynamic])
    param2 = parameters[:, :, self.num_params:].reshape(
        [-1, self.dim_dynamic, self.hidden_dim])

    features = paddle.bmm(features, param1)
    features = self.norm1(features)
    features = self.activation(features)

    features = paddle.bmm(features, param2)
    features = self.norm2(features)
    features = self.activation(features)

    features = features.flatten(1)
    features = self.out_layer(features)
    features = self.norm3(features)
    features = self.activation(features)

    return features
def forward(self, x):
    x_shape = paddle.shape(x)
    x = x.flatten(2)
    mu = paddle.tile(self.mu, [x_shape[0], 1, 1])

    with paddle.no_grad():
        for i in range(self.stage_num):
            x_t = paddle.transpose(x, [0, 2, 1])
            z = paddle.bmm(x_t, mu)
            z = F.softmax(z, axis=2)
            z_ = F.normalize(z, axis=1, p=1)
            mu = paddle.bmm(x, z_)
            mu = F.normalize(mu, axis=1, p=2)

    z_t = paddle.transpose(z, [0, 2, 1])
    x = paddle.matmul(mu, z_t)
    x = paddle.reshape(x, [0, self.c, x_shape[2], x_shape[3]])

    if self.training:
        mu = paddle.mean(mu, 0, keepdim=True)
        mu = F.normalize(mu, axis=1, p=2)
        mu = self.mu * (1 - self.momentum) + mu * self.momentum
        if paddle.distributed.get_world_size() > 1:
            mu = paddle.distributed.all_reduce(mu)
            mu /= paddle.distributed.get_world_size()
        self.mu = mu
    return x
def forward(self, X: paddle.Tensor):
    # input X is a 3D feature map
    self.P = paddle.bmm(self.weight.expand_as(self.G), self.G)
    x = paddle.bmm(
        self.P.transpose((0, 2, 1)).expand((X.shape[0], self.C, self.C)),
        X.reshape((X.shape[0], X.shape[1], -1))).reshape(X.shape)
    return x
def forward(self, x):
    # Notation from https://arxiv.org/pdf/1805.08318.pdf
    size = x.shape
    x = paddle.reshape(x, list(size[:2]) + [-1])
    f, g, h = self.query(x), self.key(x), self.value(x)
    beta = paddle.nn.functional.softmax(
        paddle.bmm(paddle.transpose(f, [0, 2, 1]), g), axis=1)
    o = self.gamma * paddle.bmm(h, beta) + x
    return paddle.reshape(o, size)
def func(value_local, query_local, key_local):
    batch_size_new = value_local.shape[0]
    h_local, w_local = value_local.shape[2], value_local.shape[3]
    value_local = value_local.reshape(
        [batch_size_new, self.value_channels, -1])

    query_local = query_local.reshape(
        [batch_size_new, self.key_channels, -1])
    query_local = query_local.transpose((0, 2, 1))
    key_local = key_local.reshape([batch_size_new, self.key_channels, -1])

    sim_map = paddle.bmm(query_local, key_local)
    sim_map = (self.key_channels**-.5) * sim_map
    attention = F.softmax(
        sim_map - paddle.max(sim_map, axis=-1, keepdim=True), axis=-1)

    context_local = paddle.bmm(value_local, attention.transpose((0, 2, 1)))
    context_local = context_local.reshape(
        [batch_size_new, self.value_channels, h_local, w_local, 2])
    return context_local
def forward(self, node_feat, edge_feat):
    # get size
    num_tasks = node_feat.shape[0]
    num_data = node_feat.shape[1]

    # get eye matrix (batch_size x 2 x node_size x node_size)
    diag_mask = 1.0 - paddle.expand(
        paddle.eye(num_data),
        [num_tasks, self.edge_dim, num_data, num_data])

    # set diagonal as zero and normalize
    edge_feat = F.normalize(edge_feat * diag_mask, p=1, axis=-1)

    # compute attention and aggregate
    aggr_feat = paddle.bmm(
        paddle.concat(paddle.split(edge_feat, 2, 1),
                      self.edge_dim).squeeze(1), node_feat)

    node_feat = paddle.transpose(
        paddle.concat(
            [node_feat, paddle.concat(paddle.split(aggr_feat, 2, 1), -1)],
            -1), (0, 2, 1))

    # non-linear transform
    node_feat = paddle.transpose(
        self.network(node_feat.unsqueeze(-1)), (0, 2, 1, 3)).squeeze(-1)
    return node_feat
def cdist(self, a, b):
    a_s = paddle.norm(a, p=2, axis=-1).pow(2)
    b_s = paddle.norm(b, p=2, axis=-1).pow(2)
    dist_score = -2 * paddle.bmm(a, b.transpose(
        [0, 2, 1])) + a_s.unsqueeze(-1)
    dist_score = paddle.sqrt(paddle.clip(dist_score, min=1e-30))
    return dist_score
def forward(self, student, teacher):
    # reshape for feature map distillation
    bs = student.shape[0]
    student = student.reshape([bs, -1])
    teacher = teacher.reshape([bs, -1])

    td = (teacher.unsqueeze(0) - teacher.unsqueeze(1))
    norm_td = F.normalize(td, p=2, axis=2)
    t_angle = paddle.bmm(norm_td, norm_td.transpose([0, 2, 1])).reshape(
        [-1, 1])

    sd = (student.unsqueeze(0) - student.unsqueeze(1))
    norm_sd = F.normalize(sd, p=2, axis=2)
    s_angle = paddle.bmm(norm_sd, norm_sd.transpose([0, 2, 1])).reshape(
        [-1, 1])

    loss = F.smooth_l1_loss(s_angle, t_angle, reduction='mean')
    return loss
def single_attn(self, x):
    x = self.kqv(x)
    k, q, v = paddle.split(x, x.shape[-1] // self.emb, axis=-1)
    kp, qp = self.prm_exp(k), self.prm_exp(q)  # (B, T, m), (B, T, m)
    # (B, T, m) * (B, m) -> (B, T, 1)
    D = paddle.bmm(qp, kp.sum(axis=1).unsqueeze(axis=-1))
    kptv = paddle.bmm(v.astype("float32").transpose((0, 2, 1)),
                      kp)  # (B, emb, m)
    y = paddle.bmm(qp, kptv.transpose((0, 2, 1))) / (
        D.tile([1, 1, self.emb]) + self.epsilon)  # (B, T, emb) / Diag
    # skip connection
    # same as token_transformer in T2T layer, use v as skip connection
    y = v + self.dp(self.proj(y))
    return y
def forward(self, source, reference):
    s_batchsize, sC, sT, sH, sW = source.shape
    r_batchsize, rC, rT, rH, rW = reference.shape

    proj_query = paddle.reshape(self.query_conv(source),
                                [s_batchsize, -1, sT * sH * sW])
    proj_query = paddle.transpose(proj_query, [0, 2, 1])
    proj_key = paddle.reshape(self.key_conv(reference),
                              [r_batchsize, -1, rT * rW * rH])
    energy = paddle.bmm(proj_query, proj_key)
    attention = F.softmax(energy)

    proj_value = paddle.reshape(self.value_conv(reference),
                                [r_batchsize, -1, rT * rH * rW])
    out = paddle.bmm(proj_value, paddle.transpose(attention, [0, 2, 1]))
    out = paddle.reshape(out, [s_batchsize, sC, sT, sH, sW])
    out = self.gamma * out + source
    return out, attention
def similarity_matrix(self, embeds):
    # (N, M, C)
    speakers_per_batch, utterances_per_speaker, embed_dim = embeds.shape

    # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
    centroids_incl = paddle.mean(embeds, axis=1)
    centroids_incl_norm = paddle.norm(
        centroids_incl, p=2, axis=1, keepdim=True)
    normalized_centroids_incl = centroids_incl / centroids_incl_norm

    # Exclusive centroids (1 per utterance)
    centroids_excl = paddle.broadcast_to(
        paddle.sum(embeds, axis=1, keepdim=True), embeds.shape) - embeds
    centroids_excl /= (utterances_per_speaker - 1)
    centroids_excl_norm = paddle.norm(
        centroids_excl, p=2, axis=2, keepdim=True)
    normalized_centroids_excl = centroids_excl / centroids_excl_norm

    p1 = paddle.matmul(
        embeds.reshape([-1, embed_dim]),
        normalized_centroids_incl,
        transpose_y=True)  # (NMN)
    p1 = p1.reshape([-1])
    # print("p1: ", p1.shape)
    p2 = paddle.bmm(
        embeds.reshape([-1, 1, embed_dim]),
        normalized_centroids_excl.reshape([-1, embed_dim, 1]))  # (NM, 1, 1)
    p2 = p2.reshape([-1])  # (NM)

    # begin: alternative implementation for scatter
    with paddle.no_grad():
        index = paddle.arange(
            0, speakers_per_batch * utterances_per_speaker,
            dtype="int64").reshape(
                [speakers_per_batch, utterances_per_speaker])
        index = index * speakers_per_batch + paddle.arange(
            0, speakers_per_batch, dtype="int64").unsqueeze(-1)
        index = paddle.reshape(index, [-1])
    ones = paddle.ones([
        speakers_per_batch * utterances_per_speaker * speakers_per_batch
    ])
    zeros = paddle.zeros_like(index, dtype=ones.dtype)
    mask_p1 = paddle.scatter(ones, index, zeros)
    p = p1 * mask_p1 + (1 - mask_p1) * paddle.scatter(ones, index, p2)
    # end: alternative implementation for scatter
    # p = paddle.scatter(p1, index, p2)

    p = p * self.similarity_weight + self.similarity_bias  # neg

    p = p.reshape(
        [speakers_per_batch * utterances_per_speaker, speakers_per_batch])
    return p, p1, p2
def forward(self, query, keys, attn_mask=None):
    # query shape: batch x query_size
    # keys shape: batch x num keys x key_size
    # proj_query shape: batch x key_size x 1
    proj_query = self.query_proj(query).unsqueeze(2)
    # attn_logits shape: batch x num keys
    attn_logits = paddle.bmm(keys, proj_query).squeeze(2) / self.temp
    maybe_mask(attn_logits, attn_mask)
    return attn_logits
def forward(self, x):
    t, b, d = x.shape
    x = x.reshape((t * b, self.groups, int(d / self.groups)))
    x = x.transpose((1, 0, 2))
    x = paddle.bmm(x, self.group_weight)
    x = x.transpose((1, 0, 2))
    x = x.reshape((t, b, self.out_dim))
    if self.group_bias is not None:
        x = x + self.group_bias
    return x
def forward(self, word, vis_node):
    m_batchsize, C, Nc = paddle.shape(word)
    m_batchsize, C, Nn = paddle.shape(vis_node)

    proj_query = self.query_conv(word).reshape(
        (m_batchsize, -1, Nc)).transpose((0, 2, 1))
    proj_key = self.key_conv(vis_node).reshape((m_batchsize, -1, Nn))
    energy = paddle.bmm(proj_query, proj_key)
    attention_vis = self.softmax_vis(energy).transpose((0, 2, 1))
    attention_word = self.softmax_word(energy)

    proj_value_vis = self.value_conv_vis(vis_node).reshape(
        (m_batchsize, -1, Nn))
    proj_value_word = self.value_conv_word(word).reshape(
        (m_batchsize, -1, Nc))

    class_out = paddle.bmm(proj_value_vis, attention_vis)
    node_out = paddle.bmm(proj_value_word, attention_word)
    return class_out, node_out
def forward(self, input):
    """
    Args:
        input: input feature maps (B, C, W, H)
    Returns:
        out: self-attention value + input feature
        attention: (B, N, N), where N is Width * Height
    """
    x = self.pool(input)
    N, C, H, W = x.shape
    proj_query = self.query_conv(x).reshape([N, -1, H * W]).transpose(
        (0, 2, 1))
    proj_key = self.key_conv(x).reshape([N, -1, H * W])
    energy = paddle.bmm(proj_query, proj_key)
    energy = (self.key_channel**-.5) * energy
    # subtract the row-wise max to prevent overflow in softmax
    attention = self.softmax(
        energy - paddle.max(energy, axis=-1, keepdim=True))
    proj_value = self.value_conv(x).reshape([N, -1, H * W])

    out = paddle.bmm(proj_value, attention.transpose((0, 2, 1)))
    out = out.reshape([N, C, H, W])
    out = F.interpolate(out, [H * self.ds, W * self.ds])
    out = out + input
    return out
def forward(self, query, values, attn_mask=None):
    # query shape: batch x query_size
    # values shape: batch x num values x value_size
    # attn_logits shape: batch x num values
    attn_logits = self.pointer(query, values, attn_mask)
    # attn shape: batch x num values
    attn = self.softmax(attn_logits)
    # output shape: batch x 1 x value_size
    output = paddle.bmm(attn.unsqueeze(1), values)
    output = output.squeeze(1)
    return output, attn
def test_out(self):
    input1 = np.array([[[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]],
                       [[3.0, 3.0, 3.0], [4.0, 4.0, 4.0]]])
    input2 = np.array([[[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]],
                       [[4.0, 4.0], [5.0, 5.0], [6.0, 6.0]]])
    with fluid.dygraph.guard():
        x = fluid.dygraph.to_variable(input1)
        y = fluid.dygraph.to_variable(input2)
        out = paddle.bmm(x, y)
        out_np = out.numpy()
    expected_result = np.matmul(input1, input2)
    self.assertTrue(np.allclose(expected_result, out_np))
def forward(self, x):
    x_shape = paddle.shape(x)

    # query: n, c, h * w
    query = paddle.reshape(x, (0, self.channels, -1))
    # key: n, h * w, c
    key = paddle.reshape(x, (0, self.channels, -1))
    key = paddle.transpose(key, (0, 2, 1))

    # sim: n, c, c
    sim = paddle.bmm(query, key)
    # The danet author claims that this can avoid gradient divergence
    sim = paddle.max(sim, axis=-1, keepdim=True).tile(
        [1, 1, self.channels]) - sim
    sim = F.softmax(sim, axis=-1)

    # feat: from (n, c, h * w) to (n, c, h, w)
    value = paddle.reshape(x, (0, self.channels, -1))
    feat = paddle.bmm(sim, value)
    feat = paddle.reshape(feat, (0, self.channels, x_shape[2], x_shape[3]))

    out = self.gamma * feat + x
    return out
def forward(self, x):
    n, c, h, w = x.shape

    # query: n, c, h * w
    query = paddle.reshape(x, (n, c, h * w))
    # key: n, h * w, c
    key = paddle.reshape(x, (n, c, h * w))
    key = paddle.transpose(key, (0, 2, 1))

    # sim: n, c, c
    sim = paddle.bmm(query, key)
    # The danet author claims that this can avoid gradient divergence
    sim = paddle.max(sim, axis=-1, keepdim=True).expand_as(sim) - sim
    sim = F.softmax(sim, axis=-1)

    # feat: from (n, c, h * w) to (n, c, h, w)
    value = paddle.reshape(x, (n, c, h * w))
    feat = paddle.bmm(sim, value)
    feat = paddle.reshape(feat, (n, c, h, w))

    out = self.gamma * feat + x
    return out
def forward(self, x, inp):
    B = self.conv_theta(x)
    sizeB = paddle.shape(B)
    B = paddle.flatten(B, 2, 3)

    sizex = paddle.shape(x)
    x_reduce = self.conv_phi(x)
    x_reduce = paddle.flatten(x_reduce, 2, 3).transpose((0, 2, 1))

    V = paddle.bmm(B, x_reduce).transpose((0, 2, 1))
    V = paddle.divide(V, (sizex[2] * sizex[3]).astype('float32'))

    class_node, new_V = self.graph(inp, V)
    D = B.transpose((0, 2, 1))
    Y = paddle.bmm(D, new_V.transpose((0, 2, 1)))
    Y = Y.transpose((0, 2, 1)).reshape(
        (sizex[0], self.num_state, sizex[2], -1))
    Y = self.extend_dim(Y)
    Y = self.bn(Y)
    out = Y + x
    return out, class_node
def multi_head_attention_forward(x: Tensor,
                                 num_heads: int,
                                 q_proj: Linear,
                                 k_proj: Linear,
                                 v_proj: Linear,
                                 c_proj: Linear,
                                 attn_mask: Optional[Tensor] = None):
    max_len, batch_size, emb_dim = x.shape
    head_dim = emb_dim // num_heads
    scaling = float(head_dim)**-0.5

    q = q_proj(x)  # L, N, E
    k = k_proj(x)  # L, N, E
    v = v_proj(x)  # L, N, E

    v = v.reshape((-1, batch_size * num_heads, head_dim)).transpose(
        (1, 0, 2))
    k = k.reshape((-1, batch_size * num_heads, head_dim)).transpose(
        (1, 0, 2))
    q = q.reshape((-1, batch_size * num_heads, head_dim)).transpose(
        (1, 0, 2))

    q = q * scaling
    qk = paddle.bmm(q, k.transpose((0, 2, 1)))  # attn_output_weights in torch
    if attn_mask is not None:
        if attn_mask.ndim == 2:
            attn_mask.unsqueeze_(0)
        assert str(attn_mask.dtype) == 'VarType.FP32' and attn_mask.ndim == 3
        assert attn_mask.shape[0] == 1 and attn_mask.shape[
            1] == max_len and attn_mask.shape[2] == max_len
        qk += attn_mask

    qk = paddle.nn.functional.softmax(qk, axis=-1)
    atten = paddle.bmm(qk, v)
    atten = atten.transpose((1, 0, 2))
    atten = atten.reshape((max_len, batch_size, emb_dim))
    atten = c_proj(atten)
    return atten
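# A minimal usage sketch for multi_head_attention_forward above. The layer names
# and sizes (emb_dim, num_heads, seq_len, batch_size) are illustrative assumptions,
# not values from the original source; it only demonstrates the expected
# (L, N, E) input layout and that the output keeps the same shape.
import paddle
from paddle.nn import Linear

emb_dim, num_heads, seq_len, batch_size = 64, 4, 7, 2
q_proj = Linear(emb_dim, emb_dim)
k_proj = Linear(emb_dim, emb_dim)
v_proj = Linear(emb_dim, emb_dim)
c_proj = Linear(emb_dim, emb_dim)

x = paddle.randn([seq_len, batch_size, emb_dim])  # (L, N, E)
out = multi_head_attention_forward(x, num_heads, q_proj, k_proj, v_proj, c_proj)
print(out.shape)  # expected: [7, 2, 64]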