예제 #1
0
def get_pred_result(hm_pred, offset_pred, wh_pred, k=100):
    """Decode the top-k detections from heatmap, offset and size predictions.

    Parameters
    ----------
    hm_pred : NDArray
        4-D heatmap prediction, indexed below as (batch, class, y, x).
    offset_pred : NDArray
        Offset prediction, indexed as (batch, channel, y, x); channel 0 is
        read as the x offset and channel 1 as the y offset.
    wh_pred : NDArray
        Size prediction, indexed as (batch, channel, y, x); channel 0 is read
        as the height and channel 1 as the width.
    k : int
        Number of peaks kept per image (forwarded to ``topk``).

    Returns
    -------
    NDArray
        Shape (batch_size, k, 6); the last axis holds
        xmin, ymin, xmax, ymax, cls_id, score.
    """
    ctx = hm_pred.context
    # Unpacking enforces a 4-D heatmap; only the batch size is needed below.
    batch_size, _, _, _ = hm_pred.shape
    num_dets = batch_size * k
    # assumes topk returns three arrays of batch_size * k elements each,
    # ordered image by image -- TODO confirm against topk's definition
    topk_cat_x_idx, topk_cat_y_idx, cls_id = topk(hm_pred, k=k)

    # One batch index per kept peak. Every image contributes k peaks, so the
    # repeat count must be k; repeating by the number of classes only produced
    # batch_size * k entries when num_classes happened to equal k.
    batch_indices = nd.repeat(nd.arange(batch_size), repeats=k)
    batch_indices = nd.reshape(batch_indices, (1, num_dets)).as_in_context(ctx)

    cls_id = nd.reshape(cls_id, (1, num_dets))
    topk_cat_y_idx = nd.reshape(topk_cat_y_idx, (1, num_dets))
    topk_cat_x_idx = nd.reshape(topk_cat_x_idx, (1, num_dets))

    # gather_nd takes one index row per axis of the gathered array.
    score_indices = nd.concat(batch_indices, cls_id, topk_cat_y_idx,
                              topk_cat_x_idx, dim=0)
    scores = nd.gather_nd(hm_pred, score_indices)

    # Constant channel selectors (0 and 1) for the two-channel regression maps.
    channel_0 = nd.zeros((1, num_dets), ctx=ctx)
    channel_1 = nd.ones((1, num_dets), ctx=ctx)
    fake_indices_0 = nd.concat(batch_indices, channel_0, topk_cat_y_idx,
                               topk_cat_x_idx, dim=0)
    fake_indices_1 = nd.concat(batch_indices, channel_1, topk_cat_y_idx,
                               topk_cat_x_idx, dim=0)

    x_offset = nd.gather_nd(offset_pred, fake_indices_0)
    y_offset = nd.gather_nd(offset_pred, fake_indices_1)

    # NOTE(review): height is read from channel 0 and width from channel 1,
    # the reverse of what the name `wh_pred` suggests -- confirm against the
    # target encoding used in training.
    h = nd.gather_nd(wh_pred, fake_indices_0)
    w = nd.gather_nd(wh_pred, fake_indices_1)

    # NOTE(review): the offset is applied multiplicatively (x + x * offset)
    # rather than additively -- confirm this matches the training targets.
    x_offset_ = nd.broadcast_mul(topk_cat_x_idx, x_offset)
    y_offset_ = nd.broadcast_mul(topk_cat_y_idx, y_offset)
    topk_cat_x_idx = nd.broadcast_add(topk_cat_x_idx, x_offset_)
    topk_cat_y_idx = nd.broadcast_add(topk_cat_y_idx, y_offset_)

    xmin = topk_cat_x_idx - w/2
    ymin = topk_cat_y_idx - h/2
    xmax = topk_cat_x_idx + w/2
    ymax = topk_cat_y_idx + h/2

    # Bring every field to (batch_size, k, 1) and join them on the last axis.
    columns = [
        nd.reshape(field, (batch_size, k)).expand_dims(axis=-1)
        for field in (xmin, ymin, xmax, ymax, cls_id, scores)
    ]
    results = nd.concat(*columns, dim=-1)

    return results
예제 #2
0
def refine_bbox_nd(bbox, bbox_delta, im_info=None, means=None, stds=None):
    """Apply regression deltas to boxes and optionally clip to the image.

    Parameters
    ----------
    bbox : NDArray
        Boxes split into xmin, ymin, xmax, ymax along axis 1.
    bbox_delta : NDArray
        Regression output; reshaped to (rows, -1, 4) and split into
        dx, dy, dw, dh along the last axis.
    im_info : NDArray, optional
        When given, its first two columns are read as (height, width) and the
        result is clipped to [0, width - 1] / [0, height - 1].
        Assumed to have shape (1, 3) as [[height, width, scale]].
    means, stds : sequence of 4 numbers, optional
        When both are given, each delta is de-normalised as
        ``delta * stds[i] + means[i]``.

    Returns
    -------
    NDArray
        Refined boxes with the four corner coordinates laid out on axis 1.
    """
    x1, y1, x2, y2 = nd.split(data=bbox, num_outputs=4, axis=1)
    box_w = x2 - x1 + 1.
    box_h = y2 - y1 + 1.
    ctr_x = 0.5 * (x1 + x2)
    ctr_y = 0.5 * (y1 + y2)

    deltas = nd.Reshape(data=bbox_delta, shape=(0, -1, 4))
    dx, dy, dw, dh = nd.split(data=deltas,
                              num_outputs=4,
                              axis=2,
                              squeeze_axis=1)
    if (means is not None) and (stds is not None):
        dx, dy, dw, dh = (
            delta * stds[i] + means[i]
            for i, delta in enumerate((dx, dy, dw, dh))
        )

    # Shift the centre by a fraction of the box size, scale the size by exp().
    new_ctr_x = nd.broadcast_add(lhs=ctr_x,
                                 rhs=nd.broadcast_mul(lhs=box_w, rhs=dx))
    new_ctr_y = nd.broadcast_add(lhs=ctr_y,
                                 rhs=nd.broadcast_mul(lhs=box_h, rhs=dy))
    new_w = nd.broadcast_mul(lhs=box_w, rhs=nd.exp(dw))
    new_h = nd.broadcast_mul(lhs=box_h, rhs=nd.exp(dh))
    half_w = 0.5 * (new_w - 1.)
    half_h = 0.5 * (new_h - 1.)

    corners = (
        new_ctr_x - half_w,
        new_ctr_y - half_h,
        new_ctr_x + half_w,
        new_ctr_y + half_h,
    )
    refined_bbox = nd.concat(
        *[nd.expand_dims(corner, axis=1) for corner in corners], dim=1)

    if im_info is None:
        return refined_bbox

    # Build the per-coordinate upper bound (w-1, h-1, w-1, h-1) as (1, 4, 1)
    # so that it broadcasts against the refined boxes.
    upper = nd.slice_axis(im_info, axis=1, begin=0, end=2)
    upper = nd.reverse(upper, axis=1)
    upper = upper - 1.
    upper = nd.tile(data=upper, reps=(1, 2))
    upper = nd.Reshape(upper, shape=(1, 4, 1))
    refined_bbox = nd.broadcast_minimum(lhs=refined_bbox, rhs=upper)
    refined_bbox = nd.broadcast_maximum(lhs=refined_bbox,
                                        rhs=nd.zeros_like(refined_bbox))
    return refined_bbox
def lifted_loss(net,data,label):
    """Compute a lifted-structure style loss matrix for one batch.

    ``net(data)`` is used as the embedding; pairwise squared Euclidean
    distances are derived from it. The returned matrix is non-zero only where
    the two samples share a label, and is scaled by 1000 and divided by twice
    the number of same-label pairs.
    """
    col_label = label.reshape(-1, 1)
    # 1.0 where the two samples carry the same label, 0.0 otherwise.
    same_label = nd.equal(col_label, col_label.T).astype('float32')

    emb = net(data)
    sq_norm = nd.sum(nd.square(emb), axis=1, keepdims=True)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b
    pair_dist = nd.broadcast_add(sq_norm, sq_norm.T) - 2 * nd.dot(emb, emb.T)

    # Per-row sum of exp(1 - d) over different-label pairs only.
    neg_terms = nd.exp(1.0 - (pair_dist)) * (1 - same_label)
    neg_row = nd.sum(neg_terms, 1, True)

    log_neg = nd.log(neg_row + neg_row.T + 1e-5)
    normaliser = 2 * same_label.sum()
    loss = 1000 * (log_neg + pair_dist) * same_label / normaliser
    return loss
def contrastive_loss(net,data,label):
    """Compute a pairwise contrastive loss matrix for one batch.

    ``net(data)`` is flattened and used as the embedding. Pairs are weighted
    by ``relu(1 - |l_i - l_j|)``, which is 1 for equal labels and 0 once the
    labels differ by at least 1.
    """
    col_label = label.reshape(-1, 1)
    pair_weight = nd.relu(-nd.abs(col_label - col_label.T) + 1).astype('float32')

    emb = nd.Flatten(net(data))
    sq_norm = nd.sum(nd.square(emb), axis=1, keepdims=True)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b
    pair_dist = nd.broadcast_add(sq_norm, sq_norm.T) - 2 * nd.dot(emb, emb.T)

    pull = pair_weight * pair_dist
    # NOTE(review): the hinge is not multiplied by (1 - pair_weight), so every
    # pair with weight 1 also contributes a constant relu(1 - 0) = 1 to the
    # loss value -- confirm this is intended.
    push = nd.relu((1.0 - pair_dist * (1 - pair_weight))).astype('float32')
    loss = pull + push
    return loss
def accuracy_metric(gallery_features, gallery_label, query_features, query_label):
    """Return the nearest-neighbour classification accuracy of the queries.

    Each query is assigned the label of the gallery item at minimum squared
    Euclidean distance (the largest label among exact ties), and the fraction
    of queries whose assigned label equals ``query_label`` is returned.

    The label is recovered as the row maximum of ``mask * label``, so labels
    are presumably expected to be non-negative -- verify against the caller.
    """
    gallery_sq = nd.sum(nd.square(gallery_features), axis=1, keepdims=True)
    query_sq = nd.sum(nd.square(query_features), axis=1, keepdims=True)
    # Rows index queries, columns index gallery items.
    dist_mat = nd.broadcast_add(query_sq, gallery_sq.T) - 2 * nd.dot(
        query_features, gallery_features.T)

    row_min = nd.min(dist_mat, axis=1, keepdims=True)
    nearest_mask = nd.broadcast_equal(dist_mat, row_min).astype('float32')

    gallery_row = gallery_label.reshape(1, -1).astype('float32')
    predicted = nd.max(nd.broadcast_mul(nearest_mask, gallery_row), axis=1)

    hits = nd.sum(nd.equal(predicted, query_label.astype('float32')))
    return hits.asnumpy()[0] / len(query_label)
def Treplit_hard_loss(net,data,label):
    """Compute a soft (log-sum-exp) triplet-style loss, one value per sample.

    ``net(data)`` is used as the embedding. For each row the loss is
    ``relu(log(sum_pos exp(d)) + log(sum_neg exp(-d)) + 1)``, where ``d`` is
    the pairwise squared Euclidean distance, ``pos`` ranges over samples with
    the same label and ``neg`` over samples with a different label.
    """
    label = label.reshape(-1, 1)
    # 1.0 where the two samples carry the same label, 0.0 otherwise.
    label_mat = nd.equal(label, label.T).astype('float32')
    vec = net(data)
    dist_self = nd.sum(nd.square(vec), axis=1, keepdims=True)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b
    dist_mat = nd.broadcast_add(dist_self, dist_self.T) - 2 * nd.dot(vec, vec.T)
    p_min = nd.log(nd.sum(label_mat * nd.exp(dist_mat), axis=1))
    # The reduction axis belongs to nd.sum (it was previously passed to
    # nd.log), so that this term is per-row like p_min.
    p_max = nd.log(nd.sum((1 - label_mat) * nd.exp(-dist_mat), axis=1))
    loss = nd.relu(p_min + p_max + 1)
    return loss
    def create_coords_pyramid(self,
                              batch_size,
                              shape,
                              n_blocks=4,
                              n_lvls=(4, 4, 3, 2),
                              r=4):
        """Build normalised sampling-window coordinates for every block/level.

        Parameters
        ----------
        batch_size : int
            Number of times the base coordinate grid is tiled.
        shape : tuple of int
            (height, width) of the input; the first block works at 1/4 of it
            and every following block halves the resolution again.
        n_blocks : int
            Number of blocks (resolutions) to generate.
        n_lvls : sequence of int
            Number of levels for each block; must have at least ``n_blocks``
            entries.
        r : int
            Window radius; each window is (2r + 1) x (2r + 1).

        Returns
        -------
        list of NDArray
            One array per (block, level), each of shape
            (batch_size * h * w, 2, 2r + 1, 2r + 1) with x in channel 0 and
            y in channel 1, scaled so that [0, size / 2**level - 1] maps to
            [-1, 1].
        """
        ori_h, ori_w = shape

        h = ori_h // 4
        w = ori_w // 4

        win = 2 * r + 1
        offsets = nd.arange(-r, r + 1)
        # dx varies along columns, dy along rows. dy has to start from a
        # column vector: repeating a (1, win) row along axis 1 would give a
        # (1, win * win) array that cannot be stacked with dx.
        dx = offsets.reshape((1, win)).repeat(win, axis=0)
        dy = offsets.reshape((win, 1)).repeat(win, axis=1)

        delta = nd.stack(dx, dy, axis=-1)
        delta_lvl = delta.reshape((1, win, win, 2))

        coords_pyramid = []
        for i in range(n_blocks):
            # x-y indexing: gx varies along columns, gy along rows.
            gx = nd.arange(w).reshape((1, w)).repeat(h, axis=0)
            gy = nd.arange(h).reshape((h, 1)).repeat(w, axis=1)

            coords = nd.stack(gx, gy, axis=-1)
            coords = nd.expand_dims(coords, axis=0)
            coords = nd.tile(coords, (batch_size, 1, 1, 1))
            # One (1, 1, 2) centre per pixel so the window offsets broadcast.
            coords = nd.reshape(coords, (-1, 1, 1, 2))

            for j in range(n_lvls[i]):
                centroid_lvl = coords / 2**j
                coords_lvl = nd.broadcast_add(centroid_lvl, delta_lvl)
                coord_x, coord_y = coords_lvl.split(axis=-1, num_outputs=2)

                # Normalise by the extent of this level (size / 2**j).
                # NOTE(review): the divisor is zero when size / 2**j == 1.
                coord_x = 2 * coord_x / ((w / 2**j) - 1) - 1
                coord_y = 2 * coord_y / ((h / 2**j) - 1) - 1
                coord = nd.concat(coord_x, coord_y, dim=-1)
                coord = nd.transpose(coord, (0, 3, 1, 2))

                coords_pyramid.append(coord)

            h, w = h // 2, w // 2

        return coords_pyramid
예제 #8
0
    def forward(self, x, gt_box=None, obj_box=None):
        """Classify ground-truth boxes using context from object boxes.

        Parameters
        ----------
        x : mxnet.nd.NDArray
            The network input tensor.
        gt_box : mxnet.nd.NDArray
            The ground-truth bbox tensor with shape (1, M, 4). Required: it
            is dereferenced unconditionally despite the ``None`` default.
        obj_box : mxnet.nd.NDArray
            The object bbox tensor with shape (1, N, 4). Required: it is
            dereferenced unconditionally despite the ``None`` default.

        Returns
        -------
        cls_pred or (cls_pred, relation)
            ``cls_pred`` is the class prediction reshaped to
            (self._max_batch, -1, self.num_class). When
            ``self._additional_output`` is set, the second output of
            ``self.relation`` is returned alongside it.

        Raises
        ------
        ValueError
            If ``self._roi_mode`` is neither 'pool' nor 'align'.
        """
        #####################################################
        # Extracting Features from First 4 Layers of Resnet #
        #####################################################

        feat = self.features(x)

        ################################################
        # ROI Pooling of features using bounding boxes #
        ################################################

        rsn_box = obj_box.reshape((-1, 4))

        # ROIs are (batch_id, x1, y1, x2, y2); the batch id is always 0 here.
        rsn_batchid = F.zeros_like(rsn_box.slice_axis(axis=-1, begin=0, end=1))
        rsn_rois = F.concat(*[rsn_batchid, rsn_box], dim=-1)
        gt_batchid = F.zeros_like(gt_box.slice_axis(axis=-1, begin=0, end=1))
        gt_rois = F.concat(
            *[gt_batchid.reshape((-1, 1)),
              gt_box.reshape((-1, 4))], dim=-1)

        # ROI features
        if self._roi_mode == 'pool':
            pooled_feat = F.ROIPooling(feat, gt_rois, self._roi_size,
                                       1. / self._stride)
            pooled_ctx_feat = F.ROIPooling(feat, rsn_rois, self._roi_size,
                                           1. / self._stride)
        elif self._roi_mode == 'align':
            pooled_feat = F.contrib.ROIAlign(feat,
                                             gt_rois,
                                             self._roi_size,
                                             1. / self._stride,
                                             sample_ratio=2)
            pooled_ctx_feat = F.contrib.ROIAlign(feat,
                                                 rsn_rois,
                                                 self._roi_size,
                                                 1. / self._stride,
                                                 sample_ratio=2)
        else:
            raise ValueError("Invalid roi mode: {}".format(self._roi_mode))

        ############################################################
        # Passing the Pooled features through Last Layer of Resnet #
        ############################################################

        # RCNN prediction
        top_feat = self.top_features(pooled_feat)
        # contextual region prediction
        top_ctx_feat = self.top_features(pooled_ctx_feat)

        ##########################
        # GLOBAL AVERAGE POOLING #
        ##########################

        if self.use_global_avg_pool:
            top_feat = self.global_avg_pool(top_feat)
            top_ctx_feat = self.global_avg_pool(top_ctx_feat)

        top_feat = self.fc(top_feat)
        top_ctx_feat = self.fc_ctx(top_ctx_feat)

        #######################################
        # Adding Class token to each Sequence #
        #######################################
        # Create the token index on the same device as the features instead
        # of a hard-coded mx.gpu(), so the block also runs on CPU and on
        # GPUs other than device 0.
        # NOTE(review): 1025 is presumably the embedding row reserved for the
        # class token -- confirm against the embedding's definition.
        cls_token_idx = mx.nd.array([1025], ctx=top_ctx_feat.context)
        top_ctx_feat = mx.nd.Concat(self.embedding(cls_token_idx),
                                    top_ctx_feat,
                                    dim=0)

        ###############################################
        # Positional Embedding of pooled eye features #
        ###############################################

        pos = self.get_indices(rsn_box)
        pos_emb = self.embedding(pos).reshape(-1, 1024)
        top_ctx_feat = top_ctx_feat + pos_emb

        ########################################################################
        # Passing the pooled eye features(with context information) in Encoder #
        ########################################################################

        if self._additional_output:
            relation_ctx_feat, relation = self.relation(top_ctx_feat)
        else:
            relation_ctx_feat = self.relation(top_ctx_feat)

        #####################################################
        # Using the cls_token to enhance our human features #
        #####################################################
        # Row 0 is the class token that was prepended above.
        top_feat = F.broadcast_add(top_feat, relation_ctx_feat[0])

        ##################
        # Classification #
        ##################

        cls_pred = self.class_predictor(top_feat)

        # cls_pred (B * N, C) -> (B, N, C)
        cls_pred = cls_pred.reshape((self._max_batch, -1, self.num_class))

        if self._additional_output:
            return cls_pred, relation
        return cls_pred
예제 #9
0
    def forward(self, is_train, req, in_data, out_data, aux):
        """Run the learned-NMS scoring pass and write three outputs.

        Steps, in order:

        1. Unpack the class scores, box deltas, rois, image info, roi
           features and the layer weights/biases from ``in_data``
           (index 19 is read only when ``self.has_non_gt_index`` is set).
        2. Restrict scores/deltas/rois to the non-ground-truth rows, either
           by slicing the first ``self.nongt_dim`` rows or by taking
           ``non_gt_index``.
        3. Refine the rois with ``refine_bbox_nd``, softmax the scores, drop
           column 0 and keep the ``self.first_n`` best rows per class.
        4. Keep only classes whose best score reaches
           ``min(self.class_thresh, best score over all classes)``, run the
           rank/feature embedding and attention on those, and turn the
           result into a sigmoid score per threshold.
        5. Pad the skipped classes with zeros, restore the original class
           order and multiply by the sorted class scores.
        6. Record the wall-clock time of this call in module globals and
           print a running average every 250 calls.

        Outputs written with ``self.assign``: ``out_data[0]`` the combined
        NMS score, ``out_data[1]`` the sorted boxes, ``out_data[2]`` the
        sorted class scores.
        """
        nms_start_time = time.time()
        #inputs
        cls_score = in_data[0]
        bbox_pred = in_data[1]
        rois = in_data[2]
        im_info = in_data[3]
        fc_all_2_relu = in_data[4]
        nms_rank_weight = in_data[5]
        nms_rank_bias = in_data[6]
        roi_feat_embedding_weight = in_data[7]
        roi_feat_embedding_bias = in_data[8]
        nms_pair_pos_fc1_1_weight = in_data[9]
        nms_pair_pos_fc1_1_bias = in_data[10]
        nms_query_1_weight = in_data[11]
        nms_query_1_bias = in_data[12]
        nms_key_1_weight = in_data[13]
        nms_key_1_bias = in_data[14]
        nms_linear_out_1_weight = in_data[15]
        nms_linear_out_1_bias = in_data[16]
        nms_logit_weight = in_data[17]
        nms_logit_bias = in_data[18]
        if self.has_non_gt_index:
            non_gt_index = in_data[19]
        else:
            non_gt_index = None

        # Select the non-gt rows: a fixed-size prefix takes precedence over
        # an explicit index; with neither, all rows are used.
        if self.nongt_dim is not None:
            cls_score_nongt = nd.slice_axis(data=cls_score,
                                            axis=0,
                                            begin=0,
                                            end=self.nongt_dim)
            # cls_score_nongt = monitor_wrapper(cls_score_nongt, 'cls_score_nongt')
            bbox_pred_nongt = nd.slice_axis(data=bbox_pred,
                                            axis=0,
                                            begin=0,
                                            end=self.nongt_dim)
        elif non_gt_index is not None:
            cls_score_nongt = nd.take(a=cls_score, indices=non_gt_index)
            bbox_pred_nongt = nd.take(a=bbox_pred, indices=non_gt_index)
        else:
            cls_score_nongt = cls_score
            bbox_pred_nongt = bbox_pred
        bbox_pred_nongt = nd.BlockGrad(bbox_pred_nongt)

        # remove batch idx and gt roi
        sliced_rois = nd.slice_axis(data=rois, axis=1, begin=1, end=None)
        if self.nongt_dim is not None:
            sliced_rois = nd.slice_axis(data=sliced_rois,
                                        axis=0,
                                        begin=0,
                                        end=self.nongt_dim)
        elif non_gt_index is not None:
            sliced_rois = nd.take(a=sliced_rois, indices=non_gt_index)
        # bbox_pred_nobg, [num_rois, 4*(num_reg_classes-1)]
        # (the first 4 columns are dropped; presumably the background deltas)
        bbox_pred_nobg = nd.slice_axis(data=bbox_pred_nongt,
                                       axis=1,
                                       begin=4,
                                       end=None)
        # [num_boxes, 4, num_reg_classes-1]
        refined_bbox = refine_bbox_nd(sliced_rois,
                                      bbox_pred_nobg,
                                      im_info,
                                      means=self.bbox_means,
                                      stds=self.bbox_stds)
        # softmax cls_score to cls_prob, [num_rois, num_classes]
        cls_prob = nd.softmax(data=cls_score_nongt, axis=-1)
        # drop column 0 (presumably the background class)
        cls_prob_nobg = nd.slice_axis(cls_prob, axis=1, begin=1, end=None)
        sorted_cls_prob_nobg = nd.sort(data=cls_prob_nobg,
                                       axis=0,
                                       is_ascend=False)
        # sorted_score, [first_n, num_fg_classes]
        sorted_score = nd.slice_axis(sorted_cls_prob_nobg,
                                     axis=0,
                                     begin=0,
                                     end=self.first_n,
                                     name='sorted_score')
        max_score_per_class = sorted_score.max(axis=0)
        max_score_per_class_numpy = max_score_per_class.asnumpy()

        # Cap the threshold at the overall best score so that at least one
        # class always passes the >= test below.
        valid_class_thresh = self.class_thresh
        valid_class_thresh = np.minimum(valid_class_thresh,
                                        max_score_per_class_numpy.max())
        valid_class_indices = np.where(
            max_score_per_class_numpy >= valid_class_thresh)[0]
        invalid_class_indices = np.where(
            max_score_per_class_numpy < valid_class_thresh)[0]
        num_valid_classes = len(valid_class_indices)
        valid_class_indices_nd = nd.array(valid_class_indices,
                                          ctx=sorted_score.context)

        # sort by score
        rank_indices = nd.argsort(data=cls_prob_nobg, axis=0, is_ascend=False)
        # first_rank_indices, [first_n, num_fg_classes]
        first_rank_indices = nd.slice_axis(rank_indices,
                                           axis=0,
                                           begin=0,
                                           end=self.first_n)
        # Select the valid class columns (transpose so take() acts on classes).
        valid_first_rank_indices = first_rank_indices.transpose().take(
            valid_class_indices_nd).transpose()

        # sorted_bbox, [first_n, num_fg_classes, 4, num_reg_classes-1]
        sorted_bbox = nd.take(a=refined_bbox, indices=first_rank_indices)
        if self.class_agnostic:
            # sorted_bbox, [first_n, num_fg_classes, 4]
            sorted_bbox = nd.Reshape(sorted_bbox,
                                     shape=(0, 0, 0),
                                     name='sorted_bbox')
        else:
            # NOTE(review): cls_mask is created without an explicit ctx while
            # sorted_bbox lives on the input's device -- confirm nd.pick
            # works here when running on GPU.
            cls_mask = nd.arange(0, self.num_fg_classes)
            cls_mask = nd.Reshape(cls_mask, shape=(1, -1, 1))
            cls_mask = nd.broadcast_to(cls_mask, shape=(self.first_n, 0, 4))
            # sorted_bbox, [first_n, num_fg_classes, 4]
            sorted_bbox = nd.pick(data=sorted_bbox,
                                  name='sorted_bbox',
                                  index=cls_mask,
                                  axis=3)

        valid_sorted_bbox = sorted_bbox.transpose(
            (1, 0, 2)).take(valid_class_indices_nd).transpose((1, 0, 2))

        # sorted_bbox = monitor_wrapper(sorted_bbox, 'sorted_bbox')
        # nms_rank_embedding, [first_n, 1024]
        nms_rank_embedding = extract_rank_embedding_nd(self.first_n, 1024)
        # nms_rank_feat: FullyConnected with num_hidden=128, so presumably
        # [first_n, 128]
        nms_rank_feat = nd.FullyConnected(name='nms_rank',
                                          data=nms_rank_embedding,
                                          num_hidden=128,
                                          weight=nms_rank_weight,
                                          bias=nms_rank_bias)
        # nms_position_matrix, [num_valid_classes, first_n, first_n, 4]
        nms_position_matrix = extract_multi_position_matrix_nd(
            valid_sorted_bbox)
        # roi_feat_embedding: FullyConnected with num_hidden=128, so
        # presumably [num_rois, 128]
        # fc_all_2_relu = monitor_wrapper(fc_all_2_relu, 'fc_all_2_relu')
        roi_feat_embedding = nd.FullyConnected(
            name='roi_feat_embedding',
            data=fc_all_2_relu,
            num_hidden=128,
            weight=roi_feat_embedding_weight,
            bias=roi_feat_embedding_bias)
        # sorted_roi_feat, [first_n, num_valid_classes, 128]
        sorted_roi_feat = nd.take(a=roi_feat_embedding,
                                  indices=valid_first_rank_indices)

        # vectorized nms
        # nms_embedding_feat, [first_n, num_valid_classes, 128]
        nms_embedding_feat = nd.broadcast_add(lhs=sorted_roi_feat,
                                              rhs=nd.expand_dims(nms_rank_feat,
                                                                 axis=1))
        # nms_attention_1: added element-wise to nms_embedding_feat below, so
        # it must be broadcast-compatible with it
        nms_attention_1 = nms_attention_nd(
            nms_embedding_feat,
            nms_position_matrix,
            nms_pair_pos_fc1_1_weight,
            nms_pair_pos_fc1_1_bias,
            nms_query_1_weight,
            nms_query_1_bias,
            nms_key_1_weight,
            nms_key_1_bias,
            nms_linear_out_1_weight,
            nms_linear_out_1_bias,
            num_rois=self.first_n,
            index=1,
            group=self.nms_attention_group,
            dim=self.nms_attention_dim,
            fc_dim=self.nms_attention_fc_dim,
            feat_dim=self.nms_attention_feat_dim)
        nms_all_feat_1 = nms_embedding_feat + nms_attention_1
        nms_all_feat_1_relu = nd.Activation(data=nms_all_feat_1,
                                            act_type='relu',
                                            name='nms_all_feat_1_relu')
        # merge the first two axes: [first_n * num_valid_classes, feat]
        nms_all_feat_1_relu_reshape = nd.Reshape(nms_all_feat_1_relu,
                                                 shape=(-3, -2))
        # logit, [first_n * num_valid_classes, num_thresh]
        nms_conditional_logit = nd.FullyConnected(
            name='nms_logit',
            data=nms_all_feat_1_relu_reshape,
            num_hidden=self.num_thresh,
            weight=nms_logit_weight,
            bias=nms_logit_bias)
        # logit_reshape, [first_n, num_valid_classes, num_thresh]
        nms_conditional_logit_reshape = nd.Reshape(nms_conditional_logit,
                                                   shape=(self.first_n,
                                                          num_valid_classes,
                                                          self.num_thresh))
        nms_conditional_score = nd.Activation(
            data=nms_conditional_logit_reshape,
            act_type='sigmoid',
            name='nms_conditional_score')
        # Pad the classes that were skipped with zero scores so the class
        # axis has num_fg_classes entries again (valid first, invalid last).
        if num_valid_classes == self.num_fg_classes:
            full_nms_conditional_score = nms_conditional_score
        else:
            full_nms_conditional_score = nd.concat(
                nms_conditional_score,
                nd.zeros(
                    (self.first_n, self.num_fg_classes - num_valid_classes,
                     self.num_thresh),
                    ctx=nms_conditional_score.context),
                dim=1)

        # Inverse permutation: restore_indexes[c] is the position of original
        # class c in the (valid, invalid) ordering built above.
        all_indexes = np.concatenate(
            (valid_class_indices, invalid_class_indices))
        restore_indexes = np.zeros((self.num_fg_classes))
        restore_indexes[all_indexes] = np.arange(self.num_fg_classes)
        restore_indexes = nd.array(restore_indexes,
                                   ctx=nms_conditional_score.context)
        full_nms_conditional_score = full_nms_conditional_score.transpose(
            (1, 0, 2)).take(restore_indexes).transpose((1, 0, 2))

        sorted_score_reshape = nd.expand_dims(sorted_score, axis=2)
        # sorted_score_reshape = nd.BlockGrad(sorted_score_reshape)
        nms_multi_score = nd.broadcast_mul(lhs=sorted_score_reshape,
                                           rhs=full_nms_conditional_score)
        # Result is discarded; presumably asnumpy() is called only to wait
        # for the computation to finish so the timing below is meaningful.
        _ = nms_multi_score.mean().asnumpy()

        # Timing bookkeeping kept in module globals: a window of at most the
        # last 1000 durations plus a call counter.
        all_time = time.time() - nms_start_time
        if 'learn_nms_time' not in globals().keys(
        ) or 'learn_nms_count' not in globals().keys():
            globals()['learn_nms_time'] = []
            globals()['learn_nms_count'] = 0

        if globals()['learn_nms_count'] >= 1000:
            globals()['learn_nms_time'].pop(0)
            globals()['learn_nms_time'].append(all_time)
        else:
            globals()['learn_nms_time'].append(all_time)

        globals()['learn_nms_count'] += 1
        if globals()['learn_nms_count'] % 250 == 0:
            print("--->> learn nms running average time cost: {}".format(
                float(sum(globals()['learn_nms_time'])) /
                (1000 if globals()['learn_nms_count'] > 1000 else
                 globals()['learn_nms_count'])))

        self.assign(out_data[0], req[0], nms_multi_score)
        self.assign(out_data[1], req[1], sorted_bbox)
        self.assign(out_data[2], req[2], sorted_score)