Example 1
 def forward(self, mid, ref):
     B, C, H, W = mid.shape
     mid = F.normalize(mid, p=2, axis=1)
     ref = F.normalize(ref, p=2, axis=1)
     cost_volume, ref = compute_cost_volume(
         mid, ref, max_displacement=self.d)  # [B, (2d+1)**2, H, W]
     cost_volume = F.dimshuffle(cost_volume, (0, 2, 3, 1))
     cost_volume = cost_volume.reshape((-1, (2 * self.d + 1)**2))
     # indices of the K best-matching displacements per pixel
     indices = F.top_k(cost_volume, k=self.K,
                       descending=True)[1]  # [B*H*W, K]
     del cost_volume
     ref_list = []  # each element: [B, C, H, W]
     origin_i_j = F.arange(0, H * W, 1)  # float32
     origin_i = F.floor(origin_i_j / W)  # (H*W, )
     origin_j = F.mod(origin_i_j, W)  # (H*W, )
     del origin_i_j
     # reshape ref
     ref = ref.reshape((B, C, (H + 2 * self.d) * (W + 2 * self.d)))
     for i in range(self.K):
         index = indices[:, i]  # [B*H*W, ]
         index = index.reshape((-1, H * W))
         index_i = F.floor(index / (2 * self.d + 1)) + origin_i  # [B, H*W]
         index_j = F.mod(index, (2 * self.d + 1)) + origin_j  # [B, H*W]
         # turn each pixel's (i, j) into a flat index over the padded ref,
         # whose rows have stride W + 2*d
         index = index_i * (W + 2 * self.d) + index_j  # [B, H*W]
         index = index.astype('int32')
         # add axis
         index = F.add_axis(index, axis=1)  # [B, 1, H*W]
         # broadcast
         index = F.broadcast_to(index, (B, C, H * W))
         # gather
         output = F.gather(ref, axis=2, index=index)  # [B, C, H*W]
         ref_list.append(output.reshape((B, C, H, W)))
     return self.conv(F.concat(ref_list, axis=1))
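To sanity-check the gather indexing above, here is a minimal NumPy sketch (toy sizes and a hypothetical pixel position, not from the source) of how a flat top-K index splits into a displacement and then into a flat offset over the padded ref. It assumes, as the reshape of ref implies, that rows of the padded map have stride W + 2*d.

    import numpy as np

    # Toy check of the flat-index arithmetic (NumPy stand-in, not MegEngine).
    d, H, W = 1, 4, 5
    t = np.arange((2 * d + 1) ** 2)              # flat top-k indices 0..8
    di, dj = t // (2 * d + 1), t % (2 * d + 1)   # row/col displacements
    origin_i, origin_j = 2, 3                    # an example pixel (i, j)
    flat = (di + origin_i) * (W + 2 * d) + (dj + origin_j)
    # flat stays inside the zero-padded ref of shape (H + 2d, W + 2d)
    assert flat.max() < (H + 2 * d) * (W + 2 * d)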
Example 2
    def forward(self, x):
        B, C, H, W = x.shape
        N = self.frames
        C = C // N
        A2 = F.dimshuffle(self.A2(x).reshape(B, N, C, H, W), (0, 2, 1, 3, 4)).reshape(B, C, N*H*W)  # channel queries
        B2 = F.dimshuffle(self.B2(x).reshape(B, N, C, H, W), (0, 1, 3, 4, 2)).reshape(B, N*H*W, C)  # channel keys
        A3 = self.A3(x).reshape(B, N, C*H*W)  # frame queries
        B3 = F.dimshuffle(self.B3(x).reshape(B, N, C*H*W), (0, 2, 1))  # frame keys

        D2 = F.dimshuffle(self.D2(x).reshape(B, N, C, H, W), (0, 2, 1, 3, 4)).reshape(B, C, N*H*W)  # channel values
        D3 = self.D3(x).reshape(B, N, C*H*W)  # frame values

        attention2 = F.softmax(F.batched_matrix_mul(A2, B2), axis=-1)  # [B, C, C]
        attention3 = F.softmax(F.batched_matrix_mul(A3, B3), axis=-1)  # [B, N, N]

        E2 = F.dimshuffle(F.batched_matrix_mul(attention2, D2).reshape(B, C, N, H, W), (0, 2, 1, 3, 4)).reshape(B, N*C, H, W)
        E3 = F.batched_matrix_mul(attention3, D3).reshape(B, N*C, H, W)
        return x + E2 + E3
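A quick NumPy shape sketch (illustrative sizes, not from the source) of why the two batched matmuls above produce a [B, C, C] channel-attention map and a [B, N, N] frame-attention map:

    import numpy as np

    B, N, C, H, W = 2, 5, 16, 8, 8
    A2 = np.random.rand(B, C, N * H * W)   # channel queries
    B2 = np.random.rand(B, N * H * W, C)   # channel keys
    A3 = np.random.rand(B, N, C * H * W)   # frame queries
    B3 = np.random.rand(B, C * H * W, N)   # frame keys
    att2 = A2 @ B2                         # [B, C, C] channel-channel affinity
    att3 = A3 @ B3                         # [B, N, N] frame-frame affinity
    assert att2.shape == (B, C, C) and att3.shape == (B, N, N)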
Example 3
    def loss(self, cls_scores, bbox_preds, centernesses, gt_bboxes):
        """Compute loss of the head.

            Args:
                cls_scores (Tensor): [B,1,37,37]
                bbox_preds (Tensor): [B,2,37,37]
                centernesses (Tensor): [B,1,37,37]
                gt_bboxes (Tensor): [B,4], in [tl_x, tl_y, br_x, br_y] format.
                
            Returns:
                dict[str, Tensor]: A dictionary of loss components.
        """

        B, _, H, W = cls_scores.shape
        cls_labels, bbox_targets, centerness_targets = self.get_cls_reg_ctr_targets(
            self.fm_ctr, gt_bboxes,
            self.bbox_scale)  # (B, 1, 37, 37), (B, 4, 37, 37), (B,1,37,37)

        # cls
        cls_scores = cls_scores.reshape(B, 1, -1)  # (B, 1, 37*37)
        cls_scores = F.dimshuffle(cls_scores, (0, 2, 1))  # (B, 37*37, 1)
        loss_cls = self.loss_cls(cls_scores, cls_labels.reshape(
            B, -1)) / (B * H * W)

        # reg
        bbox_preds = F.concat([bbox_preds, self.z_size - 1 - bbox_preds],
                              axis=1)  # [B,4,37,37]
        bbox_preds = F.dimshuffle(bbox_preds, (0, 2, 3, 1))
        bbox_preds = bbox_preds.reshape(-1, 4)  # (B*37*37, 4)

        bbox_targets = F.dimshuffle(bbox_targets, (0, 2, 3, 1))
        bbox_targets = bbox_targets.reshape(-1, 4)  # (B*37*37, 4)
        loss_reg = self.loss_bbox(
            bbox_preds, bbox_targets, weight=cls_labels.reshape(
                (B * H * W, ))) / cls_labels.sum()

        # center
        loss_ctr = self.loss_centerness(centernesses,
                                        centerness_targets,
                                        weight=cls_labels) / cls_labels.sum()

        loss = (loss_cls + self.lambda1 * loss_reg + self.lambda2 * loss_ctr)
        return loss, loss_cls, loss_reg, loss_ctr
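The F.concat on bbox_preds is the least obvious step; a small NumPy illustration follows (values are made up, and it assumes z_size is the side length of the 37x37 response map, which the [tl_x, tl_y, br_x, br_y] target format suggests):

    import numpy as np

    z_size = 37
    pred = np.array([10.0, 20.0])                     # 2-channel pred, one location
    box = np.concatenate([pred, z_size - 1 - pred])   # -> [10., 20., 26., 16.]
    # The mirrored channels satisfy tl + br == z_size - 1 on each axis, so the
    # 2-channel prediction determines a full [tl_x, tl_y, br_x, br_y] box.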
Example 4
 def forward(self, inputs):
     # N C iH iW
     N, C, iH, iW = inputs.shape
     oH = iH * self.scale
     oW = iW * self.scale
     oC = C // (self.scale ** 2)
     # N oC s s iH iW
     output = inputs.reshape(N, oC, self.scale, self.scale, iH, iW)
     # N oC iH s iW s
     output = F.dimshuffle(output, (0, 1, 4, 3, 5, 2))
     # N oC oH oW
     output = output.reshape(N, oC, oH, oW)
     return output
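The same depth-to-space shuffle in plain NumPy (toy sizes; a sketch, not the module itself), confirming that channels are traded for spatial resolution:

    import numpy as np

    N, C, iH, iW, s = 1, 8, 3, 3, 2
    x = np.arange(N * C * iH * iW, dtype=np.float32).reshape(N, C, iH, iW)
    oC = C // (s * s)
    y = x.reshape(N, oC, s, s, iH, iW)
    y = y.transpose(0, 1, 4, 3, 5, 2)      # same axis order as the dimshuffle
    y = y.reshape(N, oC, iH * s, iW * s)
    assert y.shape == (1, 2, 6, 6)         # channels traded for resolution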
Example 5
 def forward(self, now_LR, pre_h_SD):
     """
     now_LR: B,3,H,W
     pre_h_SD: B,48,H,W
     """
     batch, C, H, W = pre_h_SD.shape
     kernels = self.conv(now_LR)  # [B, k*k, H, W]
     batchwise_ans = []
     for idx in range(batch):
         kernel = kernels[idx]  # [k*k, H, W]
         kernel = F.dimshuffle(kernel, (1, 2, 0))  # [H, W, k*k]
         kernel = F.reshape(kernel, (H, W, 1, self.K, self.K, 1))
         kernel = F.broadcast_to(kernel, (C, H, W, 1, self.K, self.K, 1))
         batchwise_ans.append(
             F.local_conv2d(
                 F.add_axis(pre_h_SD[idx], 0), kernel, [1, 1], [1, 1],
                 [1, 1]))  # [1, C, H, W]  (NOTE: padding here is buggy)
     similarity_matrix = F.concat(batchwise_ans, axis=0)  # [B,C,H,W]
     del batchwise_ans
     similarity_matrix = F.sigmoid(similarity_matrix)
     return F.multiply(pre_h_SD, similarity_matrix)
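To see how the predicted kernels are laid out before F.local_conv2d, here is a NumPy walk-through (illustrative sizes) of the dimshuffle/reshape/broadcast chain for a single sample:

    import numpy as np

    K, C, H, W = 3, 48, 4, 4
    kernel = np.random.rand(K * K, H, W)          # one sample of self.conv output
    kernel = kernel.transpose(1, 2, 0)            # [H, W, K*K]
    kernel = kernel.reshape(H, W, 1, K, K, 1)     # a K x K filter per pixel
    kernel = np.broadcast_to(kernel, (C, H, W, 1, K, K, 1))
    # every channel of pre_h_SD reuses the same spatially-varying kernel
    assert kernel.shape == (C, H, W, 1, K, K, 1)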