Example #1
 def execute(self, x):
     imsize = x.shape
     _, _, _, x = self.backbone(x)
     b, c, h, w = x.shape
     x_k = self.conv_0(x)
     x1 = self.layer5a(x_k).reshape(b, c, -1)
     x2 = self.layer5b(x_k).reshape(b, c, -1)
     x3 = self.layer5c(x_k).reshape(b, c, -1)
     x4 = self.layer5d(x_k).reshape(b, c, -1)
     x_k = concat([x1, x2, x3, x4], 2).transpose(0, 2, 1)  # b, 110, c
     x_q = self.conv_1(x)
     x_q = x_q.reshape(b, c, -1)
     x_attention = nn.bmm(x_k, x_q)  # b 110 N
     x_v = self.conv_2(x)
     x1 = self.layer5a(x_v).reshape(b, c, -1)
     x2 = self.layer5b(x_v).reshape(b, c, -1)
     x3 = self.layer5c(x_v).reshape(b, c, -1)
     x4 = self.layer5d(x_v).reshape(b, c, -1)
     x_v = concat([x1, x2, x3, x4], 2)  # b c 110
     x = nn.bmm(x_v, x_attention).reshape(b, c, h, w)
     x = self.final_conv(x)
     x = nn.resize(x,
                   size=(imsize[2], imsize[3]),
                   mode='bilinear',
                   align_corners=True)
     return x
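
For reference, here is a minimal shape sketch (not from the project above) of the two nn.bmm calls in Example #1. It assumes layer5a–layer5d are adaptive pools with 1x1, 3x3, 6x6 and 8x8 outputs, which is what the 110 (= 1 + 9 + 36 + 64) in the comments suggests; all sizes below are made up for illustration.

    import jittor as jt
    from jittor import nn

    b, c, h, w = 2, 64, 32, 32
    x_k = jt.rand(b, 110, c)        # pooled keys, already transposed to (b, 110, c)
    x_q = jt.rand(b, c, h * w)      # queries flattened over space
    attn = nn.bmm(x_k, x_q)         # (b, 110, h*w): every position attends to 110 samples
    x_v = jt.rand(b, c, 110)        # pooled values
    out = nn.bmm(x_v, attn).reshape(b, c, h, w)   # (b, c, h, w)
    print(attn.shape, out.shape)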
Example #2
    def execute(self, point_cloud, label):
        B, D, N = point_cloud.size()
        trans = self.stn(point_cloud)
        point_cloud = point_cloud.transpose(0, 2, 1)
        point_cloud = nn.bmm(point_cloud, trans)

        point_cloud = point_cloud.transpose(0, 2, 1)

        out1 = self.relu(self.bn1(self.conv1(point_cloud)))
        out2 = self.relu(self.bn2(self.conv2(out1)))
        out3 = self.relu(self.bn3(self.conv3(out2)))

        trans_feat = self.fstn(out3)
        x = out3.transpose(0, 2, 1)
        net_transformed = nn.bmm(x, trans_feat)
        net_transformed = net_transformed.transpose(0, 2, 1)

        out4 = self.relu(self.bn4(self.conv4(net_transformed)))
        out5 = self.bn5(self.conv5(out4))
        out_max = jt.argmax(out5, 2, keepdims=True)[1]  # jt.argmax returns (index, value); [1] keeps the max values
        out_max = out_max.view(-1, 2048)

        out_max = concat((out_max, label), 1)
        expand = out_max.view(-1, 2048 + 16, 1).repeat(1, 1, N)
        concat_feature = concat([expand, out1, out2, out3, out4, out5], 1)
        net = self.relu(self.bns1(self.convs1(concat_feature)))
        net = self.relu(self.bns2(self.convs2(net)))
        net = self.relu(self.bns3(self.convs3(net)))
        net = self.convs4(net)
        return net
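
The first two nn.bmm calls in Example #2 apply a per-sample learned transform from the STN to the point cloud, the usual PointNet alignment step. A tiny standalone sketch of that batched transform (jt.rand stands in for the real STN output; sizes are made up):

    import jittor as jt
    from jittor import nn

    B, D, N = 4, 3, 1024
    point_cloud = jt.rand(B, D, N)     # (B, 3, N) input points
    trans = jt.rand(B, D, D)           # stand-in for self.stn(point_cloud), a (B, 3, 3) transform
    aligned = nn.bmm(point_cloud.transpose(0, 2, 1), trans)  # (B, N, 3) x (B, 3, 3) -> (B, N, 3)
    aligned = aligned.transpose(0, 2, 1)                     # back to (B, 3, N) for the convs
    print(aligned.shape)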
Example #3
 def execute(self, x):
     x_q = self.q_conv(x).permute(0, 2, 1)  # b, n, c
     x_k = self.k_conv(x)  # b, c, n
     x_v = self.v_conv(x)
     energy = nn.bmm(x_q, x_k)  # b, n, n
     attention = self.softmax(energy)
     attention = attention / (1e-9 + attention.sum(dim=1, keepdims=True))
     x_r = nn.bmm(x_v, attention)  # b, c, n
     x_r = self.act(self.after_norm(self.trans_conv(x - x_r)))
     x = x + x_r
     return x
Example #4
File: pct.py Project: xiaoxTM/jittor-pcl
 def execute(self, features, points):
     if points is not None:
         features = features + points
     q = self.q_conv(features).permute(0, 2, 1)  # b, n, c
     k = self.k_conv(features)  # b, c, n
     v = self.v_conv(features)
     energy = nn.bmm(q, k)  # b, n, n
     attention = self.softmax(energy)
     attention = attention / (1e-9 + attention.sum(dim=1, keepdims=True))
     r = nn.bmm(v, attention)  # b, c, n
     r = self.act(self.after_norm(self.trans_conv(features - r)))
     features = features + r
     return features
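
Examples #3 and #4 are the same offset-attention block from PCT: the (b, n, n) energy goes through softmax (presumably over the last axis, as in the original PCT code), is then re-normalized column-wise by the sum(dim=1) line, and the attended values are subtracted from the input before trans_conv. A minimal sketch of just that attention math, with made-up sizes:

    import jittor as jt
    from jittor import nn

    b, c, n = 2, 64, 128
    q = jt.rand(b, n, c)                    # queries (b, n, c)
    k = jt.rand(b, c, n)                    # keys    (b, c, n)
    v = jt.rand(b, c, n)                    # values  (b, c, n)
    energy = nn.bmm(q, k)                   # (b, n, n)
    attention = nn.softmax(energy, dim=-1)
    attention = attention / (1e-9 + attention.sum(dim=1, keepdims=True))  # column-wise re-normalization
    r = nn.bmm(v, attention)                # (b, c, n) attended values
    print(r.shape)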
Example #5
 def calc(use_cuda, a, b, mask):
     jt.flags.use_cuda = use_cuda
     a = jt.array(a)
     b = jt.array(b)
     mask = jt.array(mask)
     c = nn.bmm(a, b)
     da, db = jt.grad(c * mask, [a, b])
     return c.data, da.data, db.data
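
Example #5 is a small test helper: it runs nn.bmm forward and backward on either CPU or CUDA and returns the results as numpy arrays, so the two backends can be compared. A hypothetical usage (shapes chosen for illustration only):

    import numpy as np

    a = np.random.rand(2, 3, 4).astype(np.float32)
    b = np.random.rand(2, 4, 5).astype(np.float32)
    mask = np.random.rand(2, 3, 5).astype(np.float32)

    c_cpu, da_cpu, db_cpu = calc(0, a, b, mask)   # CPU
    c_gpu, da_gpu, db_gpu = calc(1, a, b, mask)   # CUDA, if available
    assert np.allclose(c_cpu, c_gpu, atol=1e-4)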
Example #6
 def execute(self, context, feature):
     batch_size, c, h, w = feature.shape
     origin_feature = feature
     feature = feature.reshape(batch_size, c, -1).transpose(0, 2,
                                                            1)  # b, h*w, c
     context = context.reshape(batch_size, context.shape[1],
                               -1)  # b, n_cls, h*w
     attention = self.softmax(context)
     ocr_context = nn.bmm(attention, feature).transpose(0, 2,
                                                        1)  # b, c, n_cls
     relation = nn.bmm(feature, ocr_context).transpose(0, 2,
                                                       1)  # b, n_cls, h*w
     attention = self.softmax(relation)  # b, n_cls, h*w
     result = nn.bmm(ocr_context, attention).reshape(batch_size, c, h, w)
     result = self.conv_1x1(result)
     result = concat([result, origin_feature], dim=1)
     result = self.last_conv(result)
     return result
Example #7
    def execute(self, x):
        """
            inputs :
                x : input feature maps( B X C X H X W)
            returns :
                out : attention value + input feature
                attention: B X C X C
        """
        m_batchsize, C, height, width = x.size()
        proj_query = x.reshape(m_batchsize, C, -1)
        proj_key = x.reshape(m_batchsize, C, -1).transpose(0, 2, 1)
        energy = nn.bmm(proj_query, proj_key)
        #energy_new = jt.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy
        attention = self.softmax(energy)
        proj_value = x.reshape(m_batchsize, C, -1)

        out = nn.bmm(attention, proj_value)
        out = out.reshape(m_batchsize, C, height, width)

        out = self.gamma * out + x
        return out
Example #8
    def execute(self, x):
        """
            inputs :
                x : input feature maps( B X C X H X W)
            returns :
                out : attention value + input feature
                attention: B X (HxW) X (HxW)
        """
        m_batchsize, C, height, width = x.size()
        proj_query = self.query_conv(x).reshape(m_batchsize, -1,
                                                width * height).transpose(
                                                    0, 2, 1)
        proj_key = self.key_conv(x).reshape(m_batchsize, -1, width * height)
        energy = nn.bmm(proj_query, proj_key)
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).reshape(m_batchsize, -1,
                                                width * height)

        out = nn.bmm(proj_value, attention.transpose(0, 2, 1))
        out = out.reshape(m_batchsize, C, height, width)

        out = self.gamma * out + x
        return out
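
Examples #7 and #8 are the two DANet attention modules: #7 builds a channel-by-channel energy (B x C x C) directly from the feature map, while #8 first projects the features through 1x1 convolutions and builds a position-by-position energy (B x HW x HW). For nn.bmm the only difference is which axis becomes the matrix dimension; a bare shape sketch (projections omitted, sizes made up):

    import jittor as jt
    from jittor import nn

    B, C, H, W = 2, 32, 16, 16
    x = jt.rand(B, C, H * W)

    energy_c = nn.bmm(x, x.transpose(0, 2, 1))   # channel attention (Example #7): (B, C, C)
    energy_p = nn.bmm(x.transpose(0, 2, 1), x)   # position attention (Example #8): (B, HW, HW)
    print(energy_c.shape, energy_p.shape)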
Example #9
    def execute(self, x, mask=None):
        b, n, _ = x.shape
        h = self.heads
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)

        q = q.reshape(b, n, h, -1)
        q = q.transpose(0, 2, 1, 3)

        k = k.reshape(b, n, h, -1)
        k = k.transpose(0, 2, 1, 3)

        v = v.reshape(b, n, h, -1)
        v = v.transpose(0, 2, 1, 3)

        #b,h,n,d
        d = q.shape[-1]
        q = q.reshape(b * h, n, d)
        k = k.reshape(b * h, n, d).transpose(0, 2, 1)

        dots = nn.bmm(q, k).reshape(b, h, n, n)
        dots = dots * self.scale

        if mask is not None:
            mask = nn.pad(mask.flatten(1), (1, 0), value=1)
            assert mask.shape[-1] == dots.shape[
                -1], 'mask has incorrect shapes'
            mask = mask.unsqueeze(1) * mask.unsqueeze(2)
            dots.masked_fill_(~mask, float('-inf'))
            del mask

        attn = nn.softmax(dots, dim=-1)

        out = nn.bmm(attn.reshape(b * h, n, n),
                     v.reshape(b * h, n, d)).reshape(b, h, n, d)
        out = out.transpose(0, 2, 1, 3).reshape(b, n, h * d)
        out = self.to_out(out)
        return out
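
Because nn.bmm batches over the leading dimension, Example #9 folds batch and heads into a single b*h axis before the two matrix products and unfolds them afterwards. A condensed sketch of that reshaping, with made-up sizes (the d ** -0.5 factor stands in for self.scale):

    import jittor as jt
    from jittor import nn

    b, h, n, d = 2, 4, 16, 8
    q, k, v = jt.rand(b, h, n, d), jt.rand(b, h, n, d), jt.rand(b, h, n, d)

    dots = nn.bmm(q.reshape(b * h, n, d),
                  k.reshape(b * h, n, d).transpose(0, 2, 1)).reshape(b, h, n, n)
    attn = nn.softmax(dots * (d ** -0.5), dim=-1)
    out = nn.bmm(attn.reshape(b * h, n, n),
                 v.reshape(b * h, n, d)).reshape(b, h, n, d)
    print(out.shape)   # (b, h, n, d)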
Example #10
    def execute(self, x):
        b, n, c = x.shape
        qkv = self.qkv(x).reshape(b, n, 3, self.num_heads,
                                  c // self.num_heads).permute(2, 0, 3, 1, 4)

        q, k, v = qkv[0], qkv[1], qkv[2]

        # attn = nn.bmm(q,k.transpose(0,1,3,2))*self.scale
        attn = nn.bmm_transpose(q, k) * self.scale

        attn = nn.softmax(attn, dim=-1)

        attn = self.attn_drop(attn)

        out = nn.bmm(attn, v)
        out = out.transpose(0, 2, 1, 3).reshape(b, n, c)
        out = self.proj(out)
        out = self.proj_drop(out)

        return out
    def execute(
        self,
        query,
        key=None,
        value=None,
        key_padding_mask=None,
        incremental_state=None,
        need_weights=True,
        static_kv=False,
        attn_mask=None,
        before_softmax=False,
        need_head_weights=False,
    ):
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.shape
        assert embed_dim == self.embed_dim
        assert list(query.shape) == [tgt_len, bsz, embed_dim]

        assert incremental_state is None, "TODO: incremental_state is not None"
        saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)
        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q = q * self.scaling

        assert self.bias_k is None, "TODO: self.bias_k is not None:"

        q = q.view(tgt_len, bsz * self.num_heads,
                   self.head_dim).transpose(1, 0, 2)
        if k is not None:
            k = k.view(-1, bsz * self.num_heads,
                       self.head_dim).transpose(1, 0, 2)
        if v is not None:
            v = v.view(-1, bsz * self.num_heads,
                       self.head_dim).transpose(1, 0, 2)

        assert saved_state is None, "TODO: saved_state is not None"
        assert k is not None
        src_len = k.shape[1]

        assert key_padding_mask is None, "TODO: key_padding_mask is not None"
        assert not self.add_zero_attn, "TODO: self.add_zero_attn=True"

        attn_weights = nn.bmm(q, k.transpose(0, 2, 1))

        assert list(
            attn_weights.shape) == [bsz * self.num_heads, tgt_len, src_len]

        assert attn_mask is None, "TODO: attn_mask is not None"
        assert key_padding_mask is None, "TODO: key_padding_mask is not None"

        if before_softmax:
            return attn_weights, v

        attn_weights_float = nn.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights_float.type_as(attn_weights)

        assert v is not None
        attn = nn.bmm(attn_weights, v)
        assert list(
            attn.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if self.onnx_trace and attn.shape[1] == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(1, 0, 2).view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights = None
        if need_weights:
            attn_weights = attn_weights_float.view(bsz, self.num_heads,
                                                   tgt_len, src_len).transpose(
                                                       1, 0, 2, 3)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dims=[0])

        return attn, attn_weights
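
The first snippet of Example #10 uses nn.bmm_transpose(q, k), which, as the commented-out line there suggests, multiplies q by k with its last two axes transposed and so avoids an explicit transpose. A quick equivalence check with made-up sizes:

    import jittor as jt
    from jittor import nn

    b, heads, n, d = 2, 4, 16, 8
    q = jt.rand(b, heads, n, d)
    k = jt.rand(b, heads, n, d)

    a1 = nn.bmm_transpose(q, k)                 # (b, heads, n, n)
    a2 = nn.bmm(q, k.transpose(0, 1, 3, 2))     # same result with an explicit transpose
    print(jt.abs(a1 - a2).max())                # should be ~0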