def forward(self, mid, ref):
    """Gather the top-K best-matching ``ref`` features for every ``mid`` pixel.

    Both inputs are L2-normalised along axis 1 and matched with
    ``compute_cost_volume`` over a ``(2*self.d + 1) x (2*self.d + 1)`` window.
    For each pixel the ``self.K`` highest-scoring displacements are picked and
    the corresponding feature vectors are gathered from the (padded) ``ref``
    returned by ``compute_cost_volume``. The K gathered maps are concatenated
    along axis 1 (best match first) and passed through ``self.conv``.

    Args:
        mid: rank-4 tensor unpacked as (B, C, H, W).
        ref: tensor matched against ``mid``; assumed to share its shape.

    Returns:
        ``self.conv`` applied to a (B, K*C, H, W) tensor.
    """
    B, C, H, W = mid.shape
    win = 2 * self.d + 1          # side length of the search window
    padded_h = H + 2 * self.d
    padded_w = W + 2 * self.d     # row stride of the padded ``ref`` below

    mid = F.normalize(mid, p=2, axis=1)
    ref = F.normalize(ref, p=2, axis=1)
    # cost_volume: [B, win**2, H, W]; ``ref`` comes back padded -- its element
    # count must be B*C*padded_h*padded_w for the reshape further down.
    cost_volume, ref = compute_cost_volume(mid, ref, max_displacement=self.d)

    # [B, win**2, H, W] -> [B*H*W, win**2], then indices of the K best scores.
    cost_volume = F.dimshuffle(cost_volume, (0, 2, 3, 1))
    cost_volume = cost_volume.reshape((-1, win ** 2))
    indices = F.top_k(cost_volume, k=self.K, descending=True)[1]  # [B*H*W, K]
    del cost_volume  # free the large volume before the gather loop

    # Row/column of every output pixel in the unpadded grid.
    origin_i_j = F.arange(0, H * W, 1)  # float32
    origin_i = F.floor(origin_i_j / W)  # (H*W, )
    origin_j = F.mod(origin_i_j, W)     # (H*W, )
    del origin_i_j

    # Flatten the padded spatial grid so it can be gathered along axis 2.
    ref = ref.reshape((B, C, padded_h * padded_w))

    ref_list = []  # K tensors of shape [B, C, H, W]
    for k in range(self.K):
        index = indices[:, k].reshape((-1, H * W))  # [B, H*W]
        # Window offset (row, col) + pixel origin = position in padded ref.
        index_i = F.floor(index / win) + origin_i   # [B, H*W]
        index_j = F.mod(index, win) + origin_j      # [B, H*W]
        # Compute the flat index from each pixel's (i, j). The stride must be
        # the *padded* width (W + 2d), not W, or rows wrap for any d > 0.
        # Float32 arithmetic is exact while padded_h * padded_w < 2**24.
        flat = (index_i * padded_w + index_j).astype('int32')
        flat = F.add_axis(flat, axis=1)                 # [B, 1, H*W]
        flat = F.broadcast_to(flat, (B, C, H * W))      # same index per channel
        output = F.gather(ref, axis=2, index=flat)      # [B, C, H*W]
        ref_list.append(output.reshape((B, C, H, W)))
    return self.conv(F.concat(ref_list, axis=1))
def forward(self, x):
    """Add channel-attention and frame-attention residuals to ``x``.

    ``x`` (unpacked as B, C, H, W) is viewed as ``self.frames`` frames
    stacked on the channel axis, i.e. as [B, N, C', H, W] with
    ``C' = C // N``. Two attention maps are built from projections of ``x``:

    * a [B, C', C'] map between feature channels (``A2``/``B2``/``D2``);
    * a [B, N, N] map between frames (``A3``/``B3``/``D3``).

    Both are softmax-normalised over the last axis, applied to their value
    projections, reshaped back to [B, N*C', H, W] and added to ``x``.
    """
    batch, channels, height, width = x.shape
    n_frames = self.frames
    c = channels // n_frames
    pixels = height * width

    def frames_view(t):
        # [B, N*C', H, W] -> [B, N, C', H, W]
        return t.reshape(batch, n_frames, c, height, width)

    # Channel branch: query [B, C', N*H*W], key [B, N*H*W, C'].
    query_ch = F.dimshuffle(frames_view(self.A2(x)), (0, 2, 1, 3, 4))
    query_ch = query_ch.reshape(batch, c, n_frames * pixels)
    key_ch = F.dimshuffle(frames_view(self.B2(x)), (0, 1, 3, 4, 2))
    key_ch = key_ch.reshape(batch, n_frames * pixels, c)

    # Frame branch: query [B, N, C'*H*W], key [B, C'*H*W, N].
    query_fr = self.A3(x).reshape(batch, n_frames, c * pixels)
    key_fr = F.dimshuffle(
        self.B3(x).reshape(batch, n_frames, c * pixels), (0, 2, 1))

    # Values laid out like the corresponding queries.
    value_ch = F.dimshuffle(frames_view(self.D2(x)), (0, 2, 1, 3, 4))
    value_ch = value_ch.reshape(batch, c, n_frames * pixels)
    value_fr = self.D3(x).reshape(batch, n_frames, c * pixels)

    attn_ch = F.softmax(F.batched_matrix_mul(query_ch, key_ch), axis=-1)  # [B, C', C']
    attn_fr = F.softmax(F.batched_matrix_mul(query_fr, key_fr), axis=-1)  # [B, N, N]

    # [B, C', N*H*W] -> [B, C', N, H, W] -> [B, N, C', H, W] -> [B, N*C', H, W]
    out_ch = F.batched_matrix_mul(attn_ch, value_ch).reshape(
        batch, c, n_frames, height, width)
    out_ch = F.dimshuffle(out_ch, (0, 2, 1, 3, 4)).reshape(
        batch, n_frames * c, height, width)
    # [B, N, C'*H*W] -> [B, N*C', H, W]
    out_fr = F.batched_matrix_mul(attn_fr, value_fr).reshape(
        batch, n_frames * c, height, width)

    return x + out_ch + out_fr
def loss(self, cls_scores, bbox_preds, centernesses, gt_bboxes):
    """Compute the loss of the head.

    Args:
        cls_scores (Tensor): [B, 1, H, W] (37x37 in the original setup).
        bbox_preds (Tensor): [B, 2, H, W]; expanded below to four channels
            as ``concat([p, self.z_size - 1 - p])`` before comparison with
            the [B, 4, H, W] regression targets.
        centernesses (Tensor): [B, 1, H, W].
        gt_bboxes (Tensor): [B, 4], in [tl_x, tl_y, br_x, br_y] format.

    Returns:
        tuple: ``(loss, loss_cls, loss_reg, loss_ctr)`` where
        ``loss = loss_cls + self.lambda1 * loss_reg + self.lambda2 * loss_ctr``.
    """
    B, _, H, W = cls_scores.shape
    # (B, 1, H, W), (B, 4, H, W), (B, 1, H, W)
    cls_labels, bbox_targets, centerness_targets = self.get_cls_reg_ctr_targets(
        self.fm_ctr, gt_bboxes, self.bbox_scale)

    # Normaliser for the regression / centerness terms. Clamped to at least 1
    # so a batch with no positive locations gives 0 instead of 0/0 = NaN.
    # Assumes binary labels, for which the clamp is a no-op whenever any
    # positive exists -- NOTE(review): confirm labels are never fractional.
    num_pos = F.maximum(cls_labels.sum(), 1)

    # cls: (B, 1, H, W) -> (B, H*W, 1), averaged over all locations.
    cls_scores = F.dimshuffle(cls_scores.reshape(B, 1, -1), (0, 2, 1))
    loss_cls = self.loss_cls(cls_scores, cls_labels.reshape(B, -1)) / (B * H * W)

    # reg: per-location 4-vectors, weighted by the classification labels.
    bbox_preds = F.concat([bbox_preds, self.z_size - 1 - bbox_preds], axis=1)  # [B, 4, H, W]
    bbox_preds = F.dimshuffle(bbox_preds, (0, 2, 3, 1)).reshape(-1, 4)        # (B*H*W, 4)
    bbox_targets = F.dimshuffle(bbox_targets, (0, 2, 3, 1)).reshape(-1, 4)    # (B*H*W, 4)
    loss_reg = self.loss_bbox(
        bbox_preds, bbox_targets,
        weight=cls_labels.reshape((B * H * W, ))) / num_pos

    # centerness
    loss_ctr = self.loss_centerness(
        centernesses, centerness_targets, weight=cls_labels) / num_pos

    loss = loss_cls + self.lambda1 * loss_reg + self.lambda2 * loss_ctr
    return loss, loss_cls, loss_reg, loss_ctr
def forward(self, inputs):
    """Move channel blocks into space (sub-pixel upsampling by ``self.scale``).

    Maps [N, C, iH, iW] to [N, C // scale**2, iH * scale, iW * scale].

    NOTE(review): with the channel axis split as ``c*s*s + a*s + b`` this
    computes ``out[n, c, h*s + b, w*s + a] = in[n, c*s*s + a*s + b, h, w]``,
    i.e. the row/column roles of the two sub-indices are swapped relative to
    ``torch.nn.PixelShuffle`` / depth-to-space. Left unchanged because
    weights trained with this layout depend on it.
    """
    r = self.scale
    n, c_in, in_h, in_w = inputs.shape
    c_out = c_in // (r ** 2)
    # [n, c_out, r(a), r(b), in_h, in_w]
    blocks = inputs.reshape(n, c_out, r, r, in_h, in_w)
    # -> [n, c_out, in_h, r(b), in_w, r(a)]
    blocks = F.dimshuffle(blocks, (0, 1, 4, 3, 5, 2))
    return blocks.reshape(n, c_out, in_h * r, in_w * r)
def forward(self, now_LR, pre_h_SD):
    """Gate ``pre_h_SD`` with a similarity map from per-pixel dynamic kernels.

    ``self.conv(now_LR)`` predicts one ``self.K x self.K`` kernel per spatial
    location. Each sample's kernels are broadcast to all C channels of
    ``pre_h_SD`` and applied with ``F.local_conv2d``. The sigmoid of the
    result multiplies ``pre_h_SD`` element-wise.

    Args:
        now_LR: B,3,H,W
        pre_h_SD: B,48,H,W (any C works; unpacked as batch, C, H, W)

    Returns:
        Tensor with the shape of ``pre_h_SD``.

    Raises:
        ValueError: if ``self.K`` is even -- no symmetric padding keeps the
            output at H x W, which the (C, H, W, ...) weight requires.
    """
    if self.K % 2 == 0:
        raise ValueError(
            "kernel size K must be odd to keep the output size, got %d" % self.K)
    # 'Same' padding for an odd K (was hard-coded to 1, which is only
    # correct for K == 3). NOTE(review): an earlier remark flagged a padding
    # issue here -- confirm border handling matches expectations.
    pad = self.K // 2

    batch, C, H, W = pre_h_SD.shape
    kernels = self.conv(now_LR)  # [B, K*K, H, W]
    batchwise_ans = []
    for idx in range(batch):
        kernel = F.dimshuffle(kernels[idx], (1, 2, 0))  # [H, W, K*K]
        kernel = F.reshape(kernel, (H, W, 1, self.K, self.K, 1))
        # Repeat the same spatial kernel for each of the C channels; assumes
        # local_conv2d's grouped 7-D weight layout (one group per channel).
        kernel = F.broadcast_to(kernel, (C, H, W, 1, self.K, self.K, 1))
        sample = F.add_axis(pre_h_SD[idx], 0)  # [1, C, H, W]
        # Keyword arguments so the values cannot be bound to a different
        # positional parameter (e.g. ``bias``) across API versions.
        batchwise_ans.append(
            F.local_conv2d(sample, kernel,
                           stride=[1, 1], padding=[pad, pad], dilation=[1, 1]))
    similarity_matrix = F.concat(batchwise_ans, axis=0)  # [B, C, H, W]
    del batchwise_ans  # drop per-sample results before the sigmoid
    similarity_matrix = F.sigmoid(similarity_matrix)
    return F.multiply(pre_h_SD, similarity_matrix)