def execute(self, x):
    imsize = x.shape
    _, _, _, x = self.backbone(x)
    b, c, h, w = x.shape

    # key branch: pyramid-pooled features
    x_k = self.conv_0(x)
    x1 = self.layer5a(x_k).reshape(b, c, -1)
    x2 = self.layer5b(x_k).reshape(b, c, -1)
    x3 = self.layer5c(x_k).reshape(b, c, -1)
    x4 = self.layer5d(x_k).reshape(b, c, -1)
    x_k = concat([x1, x2, x3, x4], 2).transpose(0, 2, 1)  # b, 110, c

    # query branch: full-resolution features
    x_q = self.conv_1(x)
    x_q = x_q.reshape(b, c, -1)
    x_attention = nn.bmm(x_k, x_q)  # b, 110, N

    # value branch: pyramid-pooled features
    x_v = self.conv_2(x)
    x1 = self.layer5a(x_v).reshape(b, c, -1)
    x2 = self.layer5b(x_v).reshape(b, c, -1)
    x3 = self.layer5c(x_v).reshape(b, c, -1)
    x4 = self.layer5d(x_v).reshape(b, c, -1)
    x_v = concat([x1, x2, x3, x4], 2)  # b, c, 110

    x = nn.bmm(x_v, x_attention).reshape(b, c, h, w)
    x = self.final_conv(x)
    x = nn.resize(x, size=(imsize[2], imsize[3]), mode='bilinear', align_corners=True)
    return x
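
# A minimal sketch of the pooling branches the forward pass above relies on.
# The "110" positions in the comments equal 1 + 9 + 36 + 64, which would match
# adaptive average pooling with output sizes 1, 3, 6 and 8; those sizes and the
# class name below are assumptions for illustration, not taken from the module.
import jittor as jt
from jittor import nn

class PyramidPoolBranches(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer5a = nn.AdaptiveAvgPool2d(1)  # 1x1 ->  1 position
        self.layer5b = nn.AdaptiveAvgPool2d(3)  # 3x3 ->  9 positions
        self.layer5c = nn.AdaptiveAvgPool2d(6)  # 6x6 -> 36 positions
        self.layer5d = nn.AdaptiveAvgPool2d(8)  # 8x8 -> 64 positions

    def execute(self, x):
        b, c = x.shape[0], x.shape[1]
        parts = [pool(x).reshape(b, c, -1)
                 for pool in (self.layer5a, self.layer5b, self.layer5c, self.layer5d)]
        return jt.concat(parts, 2)  # b, c, 110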
def execute(self, point_cloud, label):
    B, D, N = point_cloud.size()

    # apply the learned input transform from the spatial transformer network
    trans = self.stn(point_cloud)
    point_cloud = point_cloud.transpose(0, 2, 1)
    point_cloud = nn.bmm(point_cloud, trans)
    point_cloud = point_cloud.transpose(0, 2, 1)

    out1 = self.relu(self.bn1(self.conv1(point_cloud)))
    out2 = self.relu(self.bn2(self.conv2(out1)))
    out3 = self.relu(self.bn3(self.conv3(out2)))

    # apply the learned feature transform
    trans_feat = self.fstn(out3)
    x = out3.transpose(0, 2, 1)
    net_transformed = nn.bmm(x, trans_feat)
    net_transformed = net_transformed.transpose(0, 2, 1)

    out4 = self.relu(self.bn4(self.conv4(net_transformed)))
    out5 = self.bn5(self.conv5(out4))

    # global feature: channel-wise max over all points
    # (jt.argmax returns (index, value); [1] keeps the max values)
    out_max = jt.argmax(out5, 2, keepdims=True)[1]
    out_max = out_max.view(-1, 2048)
    # append the one-hot object label and broadcast to every point
    out_max = concat((out_max, label), 1)
    expand = out_max.view(-1, 2048 + 16, 1).repeat(1, 1, N)

    # per-point segmentation head on the concatenated local and global features
    concat_feature = concat([expand, out1, out2, out3, out4, out5], 1)
    net = self.relu(self.bns1(self.convs1(concat_feature)))
    net = self.relu(self.bns2(self.convs2(net)))
    net = self.relu(self.bns3(self.convs3(net)))
    net = self.convs4(net)
    return net
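
# Minimal shape check of the global-feature expansion used above: channel-wise
# max over points, concatenation with a one-hot object label, then repetition
# across all N points so it can be concatenated with per-point features.
# The batch size, point count and 16-class label below are illustrative values.
import jittor as jt

B, N = 2, 1024
out5 = jt.random((B, 2048, N))                  # per-point features
label = jt.zeros((B, 16))                       # one-hot category label

out_max = jt.argmax(out5, 2, keepdims=True)[1]  # (index, value); keep the values
out_max = out_max.view(-1, 2048)
expand = jt.concat((out_max, label), 1).view(-1, 2048 + 16, 1).repeat(1, 1, N)
print(expand.shape)                             # [B, 2064, N]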
def execute(self, x):
    x_q = self.q_conv(x).permute(0, 2, 1)  # b, n, c
    x_k = self.k_conv(x)                   # b, c, n
    x_v = self.v_conv(x)
    energy = nn.bmm(x_q, x_k)              # b, n, n
    attention = self.softmax(energy)
    # re-normalize so that each column of the attention map sums to one
    attention = attention / (1e-9 + attention.sum(dim=1, keepdims=True))
    x_r = nn.bmm(x_v, attention)           # b, c, n
    # transform the difference between input and attended features,
    # then add it back as a residual
    x_r = self.act(self.after_norm(self.trans_conv(x - x_r)))
    x = x + x_r
    return x
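
# A minimal sketch of the layers that the attention forward pass above expects.
# Projecting queries and keys to channels // 4 with bias-free 1x1 Conv1d layers,
# and using BatchNorm1d plus ReLU after trans_conv, are assumptions made here
# for illustration; only the attribute names come from the code above.
import jittor as jt
from jittor import nn

class SelfAttentionLayers(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)
        self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)
        self.v_conv = nn.Conv1d(channels, channels, 1)
        self.trans_conv = nn.Conv1d(channels, channels, 1)
        self.after_norm = nn.BatchNorm1d(channels)
        self.act = nn.ReLU()
        self.softmax = nn.Softmax(dim=-1)
    # execute would follow the forward pass shown above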
def execute(self, features, points):
    if points is not None:
        features = features + points
    q = self.q_conv(features).permute(0, 2, 1)  # b, n, c
    k = self.k_conv(features)                   # b, c, n
    v = self.v_conv(features)
    energy = nn.bmm(q, k)                       # b, n, n
    attention = self.softmax(energy)
    attention = attention / (1e-9 + attention.sum(dim=1, keepdims=True))
    r = nn.bmm(v, attention)                    # b, c, n
    r = self.act(self.after_norm(self.trans_conv(features - r)))
    features = features + r
    return features
def calc(use_cuda, a, b, mask):
    jt.flags.use_cuda = use_cuda
    a = jt.array(a)
    b = jt.array(b)
    mask = jt.array(mask)
    c = nn.bmm(a, b)
    da, db = jt.grad(c * mask, [a, b])
    return c.data, da.data, db.data
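
# Example usage of calc: run the same masked bmm and its gradients on CPU and,
# when CUDA is available, on GPU, and check that the two agree. The shapes and
# tolerance are arbitrary illustrative choices.
import numpy as np
import jittor as jt

a = np.random.rand(2, 3, 4).astype("float32")
b = np.random.rand(2, 4, 5).astype("float32")
mask = np.random.rand(2, 3, 5).astype("float32")

c_cpu, da_cpu, db_cpu = calc(0, a, b, mask)
if jt.has_cuda:
    c_gpu, da_gpu, db_gpu = calc(1, a, b, mask)
    assert np.allclose(c_cpu, c_gpu, atol=1e-5)
    assert np.allclose(da_cpu, da_gpu, atol=1e-5)
    assert np.allclose(db_cpu, db_gpu, atol=1e-5)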
def execute(self, context, feature):
    batch_size, c, h, w = feature.shape
    origin_feature = feature
    feature = feature.reshape(batch_size, c, -1).transpose(0, 2, 1)   # b, h*w, c
    context = context.reshape(batch_size, context.shape[1], -1)       # b, n_cls, h*w
    attention = self.softmax(context)
    ocr_context = nn.bmm(attention, feature).transpose(0, 2, 1)       # b, c, n_cls
    relation = nn.bmm(feature, ocr_context).transpose(0, 2, 1)        # b, n_cls, h*w
    attention = self.softmax(relation)                                # b, n_cls, h*w
    result = nn.bmm(ocr_context, attention).reshape(batch_size, c, h, w)
    result = self.conv_1x1(result)
    result = concat([result, origin_feature], dim=1)
    result = self.last_conv(result)
    return result
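
# Shape walk-through of the two bmm steps above using random tensors; the class
# count and spatial size are illustrative, and softmax over the last dimension
# is an assumption (it does not affect the shapes being checked).
import jittor as jt
from jittor import nn

b, c, h, w, n_cls = 2, 64, 16, 16, 19
feature = jt.random((b, c, h * w)).transpose(0, 2, 1)         # b, h*w, c
context = jt.random((b, n_cls, h * w))                        # b, n_cls, h*w

attention = nn.softmax(context, dim=-1)
ocr_context = nn.bmm(attention, feature).transpose(0, 2, 1)   # b, c, n_cls
relation = nn.bmm(feature, ocr_context).transpose(0, 2, 1)    # b, n_cls, h*w
result = nn.bmm(ocr_context, nn.softmax(relation, dim=-1))    # b, c, h*w
print(result.shape)  # [b, c, h*w]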
def execute(self, x): """ inputs : x : input feature maps( B X C X H X W) returns : out : attention value + input feature attention: B X C X C """ m_batchsize, C, height, width = x.size() proj_query = x.reshape(m_batchsize, C, -1) proj_key = x.reshape(m_batchsize, C, -1).transpose(0, 2, 1) energy = nn.bmm(proj_query, proj_key) #energy_new = jt.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy attention = self.softmax(energy) proj_value = x.reshape(m_batchsize, C, -1) out = nn.bmm(attention, proj_value) out = out.reshape(m_batchsize, C, height, width) out = self.gamma * out + x return out
def execute(self, x): """ inputs : x : input feature maps( B X C X H X W) returns : out : attention value + input feature attention: B X (HxW) X (HxW) """ m_batchsize, C, height, width = x.size() proj_query = self.query_conv(x).reshape(m_batchsize, -1, width * height).transpose( 0, 2, 1) proj_key = self.key_conv(x).reshape(m_batchsize, -1, width * height) energy = nn.bmm(proj_query, proj_key) attention = self.softmax(energy) proj_value = self.value_conv(x).reshape(m_batchsize, -1, width * height) out = nn.bmm(proj_value, attention.transpose(0, 2, 1)) out = out.reshape(m_batchsize, C, height, width) out = self.gamma * out + x return out
def execute(self, x, mask=None):
    b, n, _ = x.shape
    h = self.heads
    q, k, v = self.to_qkv(x).chunk(3, dim=-1)
    q = q.reshape(b, n, h, -1)
    q = q.transpose(0, 2, 1, 3)
    k = k.reshape(b, n, h, -1)
    k = k.transpose(0, 2, 1, 3)
    v = v.reshape(b, n, h, -1)
    v = v.transpose(0, 2, 1, 3)  # b, h, n, d
    d = q.shape[-1]

    q = q.reshape(b * h, n, d)
    k = k.reshape(b * h, n, d).transpose(0, 2, 1)
    dots = nn.bmm(q, k).reshape(b, h, n, n)
    dots = dots * self.scale

    if mask is not None:
        mask = nn.pad(mask.flatten(1), (1, 0), value=1)
        assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect shapes'
        mask = mask.unsqueeze(1) * mask.unsqueeze(2)
        dots.masked_fill_(~mask, float('-inf'))
        del mask

    attn = nn.softmax(dots, dim=-1)
    out = nn.bmm(attn.reshape(b * h, n, n),
                 v.reshape(b * h, n, d)).reshape(b, h, n, d)
    out = out.transpose(0, 2, 1, 3).reshape(b, n, h * d)
    out = self.to_out(out)
    return out
def execute(self, x):
    b, n, c = x.shape
    qkv = self.qkv(x).reshape(b, n, 3, self.num_heads,
                              c // self.num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]
    # attn = nn.bmm(q, k.transpose(0, 1, 3, 2)) * self.scale
    attn = nn.bmm_transpose(q, k) * self.scale
    attn = nn.softmax(attn, dim=-1)
    attn = self.attn_drop(attn)
    out = nn.bmm(attn, v)
    out = out.transpose(0, 2, 1, 3).reshape(b, n, c)
    out = self.proj(out)
    out = self.proj_drop(out)
    return out
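
# Quick check that nn.bmm_transpose matches the commented-out formulation above,
# i.e. a batched matmul with the last two axes of k transposed. The shapes
# (batch 2, 4 heads, 16 tokens, 8 dims per head) are illustrative.
import numpy as np
import jittor as jt
from jittor import nn

q = jt.random((2, 4, 16, 8))
k = jt.random((2, 4, 16, 8))

out1 = nn.bmm_transpose(q, k)
out2 = nn.bmm(q, k.transpose(0, 1, 3, 2))
assert np.allclose(out1.data, out2.data, atol=1e-5)
print(out1.shape)  # [2, 4, 16, 16]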
def execute(
    self,
    query,
    key=None,
    value=None,
    key_padding_mask=None,
    incremental_state=None,
    need_weights=True,
    static_kv=False,
    attn_mask=None,
    before_softmax=False,
    need_head_weights=False,
):
    if need_head_weights:
        need_weights = True

    tgt_len, bsz, embed_dim = query.shape
    assert embed_dim == self.embed_dim
    assert list(query.shape) == [tgt_len, bsz, embed_dim]
    assert incremental_state is None, "TODO: incremental_state is not None"
    saved_state = None

    if self.self_attention:
        q = self.q_proj(query)
        k = self.k_proj(query)
        v = self.v_proj(query)
    elif self.encoder_decoder_attention:
        # encoder-decoder attention
        q = self.q_proj(query)
        if key is None:
            assert value is None
            k = v = None
        else:
            k = self.k_proj(key)
            v = self.v_proj(key)
    else:
        assert key is not None and value is not None
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)

    q = q * self.scaling
    assert self.bias_k is None, "TODO: self.bias_k is not None:"

    q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(1, 0, 2)
    if k is not None:
        k = k.view(-1, bsz * self.num_heads, self.head_dim).transpose(1, 0, 2)
    if v is not None:
        v = v.view(-1, bsz * self.num_heads, self.head_dim).transpose(1, 0, 2)

    assert saved_state is None, "TODO: saved_state is not None"
    assert k is not None
    src_len = k.shape[1]
    assert key_padding_mask is None, "TODO: key_padding_mask is not None"
    assert not self.add_zero_attn, "TODO: self.add_zero_attn=True"

    attn_weights = nn.bmm(q, k.transpose(0, 2, 1))
    assert list(attn_weights.shape) == [bsz * self.num_heads, tgt_len, src_len]
    assert attn_mask is None, "TODO: attn_mask is not None"
    assert key_padding_mask is None, "TODO: key_padding_mask is not None"

    if before_softmax:
        return attn_weights, v

    attn_weights_float = nn.softmax(attn_weights, dim=-1)
    attn_weights = attn_weights_float.type_as(attn_weights)

    assert v is not None
    attn = nn.bmm(attn_weights, v)
    assert list(attn.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]

    if self.onnx_trace and attn.shape[1] == 1:
        # when ONNX tracing a single decoder step (sequence length == 1)
        # the transpose is a no-op copy before view, thus unnecessary
        attn = attn.view(tgt_len, bsz, embed_dim)
    else:
        attn = attn.transpose(1, 0, 2).view(tgt_len, bsz, embed_dim)
    attn = self.out_proj(attn)

    attn_weights = None
    if need_weights:
        attn_weights = attn_weights_float.view(
            bsz, self.num_heads, tgt_len, src_len).transpose(1, 0, 2, 3)
        if not need_head_weights:
            # average attention weights over heads
            attn_weights = attn_weights.mean(dims=[0])
    return attn, attn_weights
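
# A standalone illustration of the head folding used above: (len, bsz, embed_dim)
# tensors are reshaped to (bsz * num_heads, len, head_dim), attention is computed
# per head with two bmm calls, and the result is unfolded back. Projections and
# masking are omitted, and all sizes are illustrative.
import jittor as jt
from jittor import nn

tgt_len, src_len, bsz, num_heads, head_dim = 5, 7, 2, 4, 16
embed_dim = num_heads * head_dim

q = jt.random((tgt_len, bsz, embed_dim)) * head_dim ** -0.5  # pre-scaled queries
k = jt.random((src_len, bsz, embed_dim))
v = jt.random((src_len, bsz, embed_dim))

q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(1, 0, 2)
k = k.view(src_len, bsz * num_heads, head_dim).transpose(1, 0, 2)
v = v.view(src_len, bsz * num_heads, head_dim).transpose(1, 0, 2)

attn_weights = nn.softmax(nn.bmm(q, k.transpose(0, 2, 1)), dim=-1)
attn = nn.bmm(attn_weights, v)                      # bsz*heads, tgt_len, head_dim
attn = attn.transpose(1, 0, 2).view(tgt_len, bsz, embed_dim)
print(attn.shape)  # [tgt_len, bsz, embed_dim]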