def test_neg_axis():
    """Check that negative axis arguments match their non-negative spelling.

    A random (32, 5) tensor is reduced with ``F.argmax`` / ``F.argmin`` once
    with negative axes and once with the corresponding non-negative axes; the
    two results must be equal.
    """
    data = tensor(np.random.normal(0, 1, (32, 5)))

    # (reduction, negative-axis argument, non-negative counterpart)
    cases = (
        (F.argmax, -1, 1),
        (F.argmax, (-1, -2), (0, 1)),
        (F.argmin, (-1, -2), (0, 1)),
    )
    for reduce_fn, neg_axis, pos_axis in cases:
        from_negative = reduce_fn(data, axis=neg_axis)
        from_positive = reduce_fn(data, axis=pos_axis)
        np.testing.assert_equal(from_negative.numpy(), from_positive.numpy())
def forward(self, fpn_fms, rcnn_rois, labels=None, bbox_targets=None):
    """Run the two-branch EMD prediction head on pooled ROI features.

    Features are pooled with ``roi_pool`` (7x7, 'roi_align'), flattened and
    passed through ``fc1``/``fc2``; two classification/regression branches
    (``emd_pred_*_0`` and ``emd_pred_*_1``) are then evaluated on the result.

    Returns:
        In training mode, a dict ``{'loss_rcnn_emd': loss}`` where, per row,
        the smaller of the two ``emd_loss`` orderings is selected and the
        selected values are summed and divided by the number of rows.
        Otherwise, the decoded boxes of both branches, each with its score
        appended, reshaped to rows of 5 values.
    """
    # stride: 64,32,16,8,4 -> 4, 8, 16, 32
    strides = [4, 8, 16, 32]
    feature_maps = fpn_fms[1:][::-1]
    pooled, rcnn_rois, labels, bbox_targets = roi_pool(
        feature_maps, rcnn_rois, strides, (7, 7), 'roi_align',
        labels, bbox_targets)

    hidden = F.flatten(pooled, start_axis=1)
    hidden = F.relu(self.fc1(hidden))
    hidden = F.relu(self.fc2(hidden))

    cls_0 = self.emd_pred_cls_0(hidden)
    delta_0 = self.emd_pred_delta_0(hidden)
    cls_1 = self.emd_pred_cls_1(hidden)
    delta_1 = self.emd_pred_delta_1(hidden)

    if not self.training:
        # Drop the first class column / first 4 delta columns before decoding.
        scores_0 = F.softmax(cls_0)[:, 1:].reshape(-1, 1)
        scores_1 = F.softmax(cls_1)[:, 1:].reshape(-1, 1)
        fg_delta_0 = delta_0[:, 4:].reshape(-1, 4)
        fg_delta_1 = delta_1[:, 4:].reshape(-1, 4)

        # Repeat every ROI once per remaining class so it lines up with the deltas.
        tiled_shape = (rcnn_rois.shapeof()[0], config.num_classes - 1, 4)
        tiled_rois = F.add_axis(rcnn_rois[:, 1:5], 1)
        tiled_rois = tiled_rois.broadcast(tiled_shape).reshape(-1, 4)

        boxes_0 = restore_bbox(tiled_rois, fg_delta_0, True)
        boxes_1 = restore_bbox(tiled_rois, fg_delta_1, True)
        boxes_0 = F.concat([boxes_0, scores_0], axis=1)
        boxes_1 = F.concat([boxes_1, scores_1], axis=1)
        # [{head0, pre1, tag1}, {head1, pre1, tag1}, {head0, pre1, tag2}, ...]
        return F.concat((boxes_0, boxes_1), axis=1).reshape(-1, 5)

    # Evaluate both branch orderings and keep the cheaper one per row.
    loss_fwd = emd_loss(
        delta_0, cls_0, delta_1, cls_1, bbox_targets, labels)
    loss_rev = emd_loss(
        delta_1, cls_1, delta_0, cls_0, bbox_targets, labels)
    both = F.concat([loss_fwd, loss_rev], axis=1)
    best = F.argmin(both, axis=1)
    chosen = F.indexing_one_hot(both, best, 1)
    return {'loss_rcnn_emd': chosen.sum() / chosen.shapeof()[0]}
def run_argmin():
    """Return ``F.argmin`` along axis 0 of a (100, 100) tensor filled with +inf."""
    all_inf = F.zeros((100, 100))
    all_inf[:] = float("inf")
    return F.argmin(all_inf, axis=0)
def get_ground_truth(self, anchors_list, batched_gt_boxes, batched_num_gts):
    """Build per-anchor labels, regression offsets and centerness targets.

    For every image in the batch, each anchor point is matched to the
    ground-truth box with the smallest area among the boxes that (a) fall in
    the size range configured for the anchor's level and (b) contain the
    anchor (or its center-sampling region when
    ``cfg.center_sampling_radius > 0``). Anchors with no admissible box get
    label 0.

    Args:
        anchors_list: per-level anchor tensors; concatenated along axis 0.
        batched_gt_boxes: ground-truth boxes indexed as ``[image, box, :]``.
            Columns 0-3 are used as box coordinates and column 4 as the
            class label.
        batched_num_gts: number of valid boxes per image; only the first
            ``batched_num_gts[bid]`` boxes of image ``bid`` are used.

    Returns:
        A tuple ``(labels, offsets, ctrness)``, each stacked over the batch
        along axis 0 and detached.
    """
    labels_list = []
    offsets_list = []
    ctrness_list = []

    all_level_anchors = F.concat(anchors_list, axis=0)

    # The per-anchor size range depends only on the anchors and the config,
    # not on the image, so it is built once instead of once per batch item.
    object_sizes_of_interest = F.concat([
        F.broadcast_to(
            F.expand_dims(mge.tensor(size, dtype=np.float32), axis=0),
            (anchors_i.shape[0], 2))
        for anchors_i, size in zip(
            anchors_list, self.cfg.object_sizes_of_interest)
    ], axis=0)
    min_size_of_interest = F.expand_dims(object_sizes_of_interest[:, 0], axis=0)
    max_size_of_interest = F.expand_dims(object_sizes_of_interest[:, 1], axis=0)

    for bid in range(batched_gt_boxes.shape[0]):
        gt_boxes = batched_gt_boxes[bid, :batched_num_gts[bid]]

        # NOTE(review): offsets is presumably (num_gt, num_anchors, 4), given
        # the reductions over axis 2 and the use of offsets.shape[:2] below;
        # confirm against point_coder.encode.
        offsets = self.point_coder.encode(
            all_level_anchors, F.expand_dims(gt_boxes[:, :4], axis=1))

        # A box is handled at a level only if its largest offset lies inside
        # that level's configured size range.
        max_offsets = F.max(offsets, axis=2)
        is_cared_in_the_level = (
            (max_offsets >= min_size_of_interest)
            & (max_offsets <= max_size_of_interest))

        if self.cfg.center_sampling_radius > 0:
            # Restrict positives to a window around each box center, clipped
            # to the box itself; the window scales with the level stride.
            gt_centers = (gt_boxes[:, :2] + gt_boxes[:, 2:4]) / 2
            is_in_boxes = []
            for stride, anchors_i in zip(self.cfg.stride, anchors_list):
                radius = stride * self.cfg.center_sampling_radius
                center_boxes = F.concat([
                    F.maximum(gt_centers - radius, gt_boxes[:, :2]),
                    F.minimum(gt_centers + radius, gt_boxes[:, 2:4]),
                ], axis=1)
                center_offsets = self.point_coder.encode(
                    anchors_i, F.expand_dims(center_boxes, axis=1))
                is_in_boxes.append(F.min(center_offsets, axis=2) > 0)
            is_in_boxes = F.concat(is_in_boxes, axis=1)
        else:
            # All offsets positive <=> the anchor lies strictly inside the box.
            is_in_boxes = F.min(offsets, axis=2) > 0

        gt_area = (
            (gt_boxes[:, 2] - gt_boxes[:, 0])
            * (gt_boxes[:, 3] - gt_boxes[:, 1]))
        # FIXME: use repeat instead of broadcast_to
        areas = F.broadcast_to(
            F.expand_dims(gt_area, axis=1), offsets.shape[:2])
        # Inadmissible (box, anchor) pairs get infinite area so that argmin
        # never prefers them over an admissible pair.
        areas[~is_cared_in_the_level] = float("inf")
        areas[~is_in_boxes] = float("inf")

        match_indices = F.argmin(areas, axis=0)
        gt_boxes_matched = gt_boxes[match_indices]
        anchor_min_area = F.indexing_one_hot(areas, match_indices, axis=0)

        labels = gt_boxes_matched[:, 4].astype(np.int32)
        # An infinite minimum means no box was admissible for this anchor.
        labels[anchor_min_area == float("inf")] = 0
        offsets = self.point_coder.encode(
            all_level_anchors, gt_boxes_matched[:, :4])

        # Centerness: sqrt(min/max of offsets 0 and 2, times min/max of
        # offsets 1 and 3), with each ratio clamped at 0.
        left_right = offsets[:, [0, 2]]
        top_bottom = offsets[:, [1, 3]]
        ctrness = F.sqrt(
            F.maximum(
                F.min(left_right, axis=1) / F.max(left_right, axis=1), 0)
            * F.maximum(
                F.min(top_bottom, axis=1) / F.max(top_bottom, axis=1), 0))

        labels_list.append(labels)
        offsets_list.append(offsets)
        ctrness_list.append(ctrness)

    return (
        F.stack(labels_list, axis=0).detach(),
        F.stack(offsets_list, axis=0).detach(),
        F.stack(ctrness_list, axis=0).detach(),
    )
def forward(self, fpn_fms, rcnn_rois, labels=None, bbox_targets=None):
    """Run the EMD head followed by a refinement stage.

    ROI features are pooled with ``roi_pool`` (7x7, 'roi_align') and passed
    through ``fc1``/``fc2``. Two first-stage branches produce class logits and
    box deltas; each branch's deltas plus its class-1 softmax score are tiled
    4 times, concatenated to the ROI feature and passed through the shared
    ``fc3`` to feed the refinement branches (``ref_pred_*``).

    Returns:
        In training mode, a dict with ``'loss_rcnn_emd'`` (first stage) and
        ``'loss_ref_emd'`` (refinement); for each, the smaller of the two
        ``emd_loss`` orderings is selected per row, and the selected values
        are summed and divided by the number of rows.
        Otherwise, the refined boxes of both branches, each with its class-1
        score appended, reshaped to rows of 5 values.
    """
    def _min_pair_loss(loss_a, loss_b):
        # Keep the smaller of the two losses per row and average over rows.
        both = F.concat([loss_a, loss_b], axis=1)
        best = F.argmin(both, axis=1)
        chosen = F.indexing_one_hot(both, best, 1)
        return chosen.sum() / chosen.shapeof()[0]

    # stride: 64,32,16,8,4 -> 4, 8, 16, 32
    strides = [4, 8, 16, 32]
    feature_maps = fpn_fms[1:][::-1]
    pooled, rcnn_rois, labels, bbox_targets = roi_pool(
        feature_maps, rcnn_rois, strides, (7, 7), 'roi_align',
        labels, bbox_targets)

    hidden = F.relu(self.fc1(F.flatten(pooled, start_axis=1)))
    hidden = F.relu(self.fc2(hidden))

    emd_cls_0 = self.emd_pred_cls_0(hidden)
    emd_delta_0 = self.emd_pred_delta_0(hidden)
    emd_cls_1 = self.emd_pred_cls_1(hidden)
    emd_delta_1 = self.emd_pred_delta_1(hidden)
    emd_scores_0 = F.softmax(emd_cls_0)
    emd_scores_1 = F.softmax(emd_cls_1)

    # make refine feature
    box_0 = F.concat(
        (emd_delta_0, emd_scores_0[:, 1][:, None]), axis=1)[:, None, :]
    box_1 = F.concat(
        (emd_delta_1, emd_scores_1[:, 1][:, None]), axis=1)[:, None, :]
    rows_0 = box_0.shapeof()[0]
    rows_1 = box_1.shapeof()[0]
    tiled_0 = box_0.broadcast(
        rows_0, 4, box_0.shapeof()[-1]).reshape(rows_0, -1)
    tiled_1 = box_1.broadcast(
        rows_1, 4, box_1.shapeof()[-1]).reshape(rows_1, -1)
    refine_in_0 = F.concat((hidden, tiled_0), axis=1)
    refine_in_1 = F.concat((hidden, tiled_1), axis=1)
    # Both branches go through the same fc3 layer.
    refined_0 = F.relu(self.fc3(refine_in_0))
    refined_1 = F.relu(self.fc3(refine_in_1))

    # refine
    ref_cls_0 = self.ref_pred_cls_0(refined_0)
    ref_delta_0 = self.ref_pred_delta_0(refined_0)
    ref_cls_1 = self.ref_pred_cls_1(refined_1)
    ref_delta_1 = self.ref_pred_delta_1(refined_1)

    if not self.training:
        ref_scores_0 = F.softmax(ref_cls_0)
        ref_scores_1 = F.softmax(ref_cls_1)
        rois = rcnn_rois[:, 1:5]
        boxes_0 = restore_bbox(rois, ref_delta_0, True)
        boxes_1 = restore_bbox(rois, ref_delta_1, True)
        boxes_0 = F.concat(
            [boxes_0, ref_scores_0[:, 1].reshape(-1, 1)], axis=1)
        boxes_1 = F.concat(
            [boxes_1, ref_scores_1[:, 1].reshape(-1, 1)], axis=1)
        return F.concat((boxes_0, boxes_1), axis=1).reshape(-1, 5)

    # Each stage is scored under both branch orderings.
    emd_fwd = emd_loss(
        emd_delta_0, emd_cls_0, emd_delta_1, emd_cls_1,
        bbox_targets, labels)
    emd_rev = emd_loss(
        emd_delta_1, emd_cls_1, emd_delta_0, emd_cls_0,
        bbox_targets, labels)
    ref_fwd = emd_loss(
        ref_delta_0, ref_cls_0, ref_delta_1, ref_cls_1,
        bbox_targets, labels)
    ref_rev = emd_loss(
        ref_delta_1, ref_cls_1, ref_delta_0, ref_cls_0,
        bbox_targets, labels)

    return {
        'loss_rcnn_emd': _min_pair_loss(emd_fwd, emd_rev),
        'loss_ref_emd': _min_pair_loss(ref_fwd, ref_rev),
    }
def fwd(data):
    """Return ``(F.argmax(data), F.argmin(data))`` using the default axis."""
    max_index = F.argmax(data)
    min_index = F.argmin(data)
    return max_index, min_index