Esempio n. 1
0
def label_offset(anchors, bbox, match, sample,
                 means=(0, 0, 0, 0), stds=(0.1, 0.1, 0.2, 0.2), flatten=True):
    """Encode the matched ground-truth box of every anchor as a normalized offset.

    Parameters
    ----------
    anchors : NDArray
        Anchor boxes; reshaped to (N, 4) before use.
    bbox : NDArray
        Ground-truth boxes, a 3-D array reshaped to (B, 1, M, 4) internally.
    match : NDArray
        Index handed to ``nd.pick`` to choose one of the M boxes per anchor.
        (assumes shape (B, N) -- TODO confirm against caller)
    sample : NDArray
        Per-anchor flag, reshaped to (B, N, 1); only entries > 0.5 are kept.
    means, stds : sequence
        Entries 0..3 normalize the x, y, w and h offsets respectively.
    flatten : bool
        When true both outputs are reshaped to (B, -1).

    Returns
    -------
    tuple
        ``(anchor_mask, anchor_offset)``; the mask is 1 where the sample flag
        is set and 0 elsewhere, the offset is zeroed outside the mask.
    """
    anchors = anchors.reshape((-1, 4))
    num_anchors, _ = anchors.shape
    batch, num_gt, _ = bbox.shape
    a_x, a_y, a_w, a_h = corner_to_center(anchors, split=True)

    # Replicate the ground truth for every anchor, then keep only the matched
    # box, one coordinate at a time.
    tiled = nd.broadcast_to(bbox.reshape((batch, 1, num_gt, 4)),
                            (batch, num_anchors, num_gt, 4))
    picked = [nd.pick(tiled[:, :, :, coord], match) for coord in range(4)]
    matched = nd.stack(*picked, axis=-1)
    g_x, g_y, g_w, g_h = corner_to_center(matched, split=True)

    # Centre offsets are relative to the anchor size, size offsets are log-ratios.
    deltas = (
        ((g_x - a_x) / a_w - means[0]) / stds[0],
        ((g_y - a_y) / a_h - means[1]) / stds[1],
        (nd.log(g_w / a_w) - means[2]) / stds[2],
        (nd.log(g_h / a_h) - means[3]) / stds[3],
    )
    offset = nd.concat(*deltas, dim=-1)

    keep = nd.broadcast_to(sample.reshape((batch, num_anchors, 1)),
                           (batch, num_anchors, 4)) > 0.5
    anchor_offset = nd.where(keep, offset, nd.zeros_like(offset))
    anchor_mask = nd.where(keep, nd.ones_like(offset), nd.zeros_like(offset))

    if not flatten:
        return anchor_mask, anchor_offset
    return anchor_mask.reshape((batch, -1)), anchor_offset.reshape((batch, -1))
Esempio n. 2
0
def test_broadcast():
    """Check ``broadcast_to`` and ``broadcast_like`` on a (LARGE_X, 1) column."""
    reference = nd.ones(shape=(LARGE_X, SMALL_Y))
    column = nd.arange(0, LARGE_X).reshape(LARGE_X, 1)

    # Broadcast to an explicitly given shape.
    res = nd.broadcast_to(column, shape=(column.shape[0], SMALL_Y))
    last_row = res[-1].asnumpy()
    # NOTE(review): arange(0, LARGE_X) ends at LARGE_X - 1, yet the last row is
    # compared with LARGE_X; presumably this relies on the default float dtype
    # rounding LARGE_X - 1 to LARGE_X -- confirm before changing LARGE_X.
    assert np.sum(last_row == LARGE_X) == res.shape[1]

    # Broadcast to the shape of another array.
    res = mx.nd.broadcast_like(column, reference)
    last_row = res[-1].asnumpy()
    assert np.sum(last_row == LARGE_X) == reference.shape[1]
Esempio n. 3
0
 def forward(self, x):
     """Pick one entry per row of ``x`` and squash it with a scaled tanh.

     The index stored in ``self._dim`` is broadcast to ``x.shape[0]`` and used
     with ``nd.pick`` (``keepdims=True``). The picked value is shifted by
     ``self._split`` and multiplied by ``relu(self._sharpness)`` before
     ``tanh`` is applied.

     Parameters
     ----------
     x : NDArray
         Input whose first axis length sizes the broadcast pick index.

     Returns
     -------
     NDArray
         ``tanh((picked - split) * relu(sharpness))``.
     """
     pick_index = nd.broadcast_to(self._dim.data(), x.shape[0])
     x = nd.pick(x, pick_index, keepdims=True)
     # Use out-of-place arithmetic: MXNet rejects in-place updates (-=, *=) on
     # arrays that were produced while autograd is recording.
     x = x - self._split.data()
     x = x * nd.relu(self._sharpness.data())
     return nd.tanh(x)
Esempio n. 4
0
    def init_weights(self, ctx):
        """Initialize the network and optionally inflate pre-trained 2D weights.

        ``first_stage``, ``res_layers`` and ``head`` are initialized on ``ctx``.
        When ``self.pretrained_base`` is set, a pre-trained 2D ResNet-v1b of the
        matching depth is loaded and its parameters are paired with this
        network's parameters by position:

        * parameters whose 3D name contains 'conv' are repeated along a new
          temporal axis (axis 2) and divided by the temporal size;
        * parameters whose 3D name contains 'batchnorm' are copied unchanged;
        * parameters whose 3D name contains 'dense' are counted but not copied.

        When ``self.nonlocal_cfg`` is not None, parameters whose name contains
        'nonlocal' are left out of the pairing.

        Parameters
        ----------
        ctx
            Passed through to ``initialize(ctx=...)`` of every sub-block.

        Raises
        ------
        ValueError
            If ``self.pretrained_base`` is set and ``self.depth`` is neither
            50 nor 101 (previously this only printed a message and then failed
            on an unbound local variable).
        """
        self.first_stage.initialize(ctx=ctx)
        self.res_layers.initialize(ctx=ctx)
        self.head.initialize(ctx=ctx)

        if not self.pretrained_base:
            return

        if self.depth == 50:
            resnet2d = resnet50_v1b(pretrained=True)
        elif self.depth == 101:
            resnet2d = resnet101_v1b(pretrained=True)
        else:
            raise ValueError('No such 2D pre-trained network of depth %d.' % (self.depth))

        weights2d = resnet2d.collect_params()
        if self.nonlocal_cfg is None:
            weights3d = self.collect_params()
        else:
            # Select every parameter except the non-local ones; the joined
            # names are used as the selection pattern of collect_params.
            train_params_list = [raw_name for raw_name in self.collect_params().keys()
                                 if 'nonlocal' not in raw_name]
            init_patterns = '|'.join(train_params_list)
            weights3d = self.collect_params(init_patterns)
        assert len(weights2d.keys()) == len(weights3d.keys()), 'Number of parameters should be same.'

        # Pair the i-th 2D parameter name with the i-th 3D parameter name.
        dict_transform = dict(zip(weights2d.keys(), weights3d.keys()))

        cnt = 0
        for key2d, key3d in dict_transform.items():
            if 'conv' in key3d:
                temporal_dim = weights3d[key3d].shape[2]
                temporal_2d = nd.expand_dims(weights2d[key2d].data(), axis=2)
                # A 0 in the target shape keeps the source size of that axis;
                # dividing by temporal_dim keeps the sum over time equal to
                # the original 2D kernel.
                inflated_2d = nd.broadcast_to(temporal_2d, shape=[0, 0, temporal_dim, 0, 0]) / temporal_dim
                assert inflated_2d.shape == weights3d[key3d].shape, 'the shape of %s and %s does not match. ' % (key2d, key3d)
                weights3d[key3d].set_data(inflated_2d)
                cnt += 1
                print('%s is done with shape: ' % (key3d), weights3d[key3d].shape)
            if 'batchnorm' in key3d:
                assert weights2d[key2d].shape == weights3d[key3d].shape, 'the shape of %s and %s does not match. ' % (key2d, key3d)
                weights3d[key3d].set_data(weights2d[key2d].data())
                cnt += 1
                print('%s is done with shape: ' % (key3d), weights3d[key3d].shape)
            if 'dense' in key3d:
                cnt += 1
                print('%s is skipped with shape: ' % (key3d), weights3d[key3d].shape)

        assert cnt == len(weights2d.keys()), 'Not all parameters have been ported, check the initialization.'
Esempio n. 5
0
def test_broadcast():
    """Verify the last row produced by ``broadcast_to`` and ``broadcast_like``."""
    target = nd.ones(shape=(LARGE_X, SMALL_Y))
    col = nd.arange(0, LARGE_X).reshape(LARGE_X, 1)

    # NOTE(review): the column ends at LARGE_X - 1 but both checks compare the
    # last row with LARGE_X; presumably float rounding of the default dtype
    # makes them equal for the configured LARGE_X -- confirm.
    by_shape = nd.broadcast_to(col, shape=(col.shape[0], SMALL_Y))
    hits = np.sum(by_shape[-1].asnumpy() == LARGE_X)
    assert hits == by_shape.shape[1]

    by_like = mx.nd.broadcast_like(col, target)
    hits = np.sum(by_like[-1].asnumpy() == LARGE_X)
    assert hits == target.shape[1]
def create_input_for_rounding_ops():
    """Build a (SMALL_Y, LARGE_X) float64 input with half-integer steps.

    A single row holding (-LARGE_X/2, ..., -2, -1, 0, 1, 2, ..., LARGE_X/2-1)
    is divided by 2, giving (-LARGE_X/4, ..., -1, -0.5, 0, 0.5, 1, ...,
    LARGE_X/4-1), and is then broadcast to SMALL_Y identical rows.
    """
    start = -LARGE_X // 2
    stop = LARGE_X // 2
    row = nd.arange(start, stop, dtype=np.float64).reshape(1, LARGE_X)
    halved = row / 2
    return nd.broadcast_to(halved, (SMALL_Y, LARGE_X))
Esempio n. 7
0
    def forward(self, x, crisp=False):
        """Pick one entry per row of ``x`` and squash it with a sigmoid.

        The index stored in ``self._dim`` is broadcast to ``x.shape[0]`` and
        used with ``nd.pick`` (``keepdims=True``). The picked value is shifted
        by ``self._split``; unless ``crisp`` is set it is also multiplied by
        ``relu(self._sharpness)``.

        Parameters
        ----------
        x : NDArray
            Input whose first axis length sizes the broadcast pick index.
        crisp : bool
            When true the sharpness scaling is skipped.

        Returns
        -------
        NDArray
            Sigmoid of the shifted (and optionally scaled) picked values.
        """
        pick_index = nd.broadcast_to(self._dim.data(), x.shape[0])
        x = nd.pick(x, pick_index, keepdims=True)
        x = x - self._split.data()
        if not crisp:
            x = x * nd.relu(self._sharpness.data())

        return nd.sigmoid(x)
Esempio n. 8
0
    def forward(self, feature, label, begin_states, is_training):
        ''' Decode the hidden states to a temporal sequence.

        Parameters
        ----------
        feature: a NDArray with shape [n, d].
        label: a NDArray with shape [n, b, t, d]. The first ``self.output_dim``
            channels of the last axis are the targets, the remaining channels
            are auxiliary inputs.
        begin_states: a list of hidden states (list of hidden units with shape [n, b, d]) of RNNs.
            The list is copied, so the caller's list is left untouched.
        is_training: bool

        Returns
        -------
            outputs: the prediction, which is a NDArray with shape [n, b, t, d]
        '''
        ctx = label.context

        num_nodes, batch_size, seq_len, _ = label.shape
        aux = label[:, :, :, self.output_dim:]  # auxiliary channels
        label = label[:, :, :, :self.output_dim]  # target channels

        go = nd.zeros(shape=(num_nodes, batch_size, self.input_dim), ctx=ctx)
        # Copy the state list so per-step updates do not overwrite the
        # caller's begin_states.
        output, states = [], list(begin_states)

        # The feature appended to every step does not change inside the loop,
        # so tile it over the batch axis once up front.
        _feature = nd.expand_dims(feature, axis=1)  # [n, 1, d]
        _feature = nd.broadcast_to(_feature,
                                   shape=(0, batch_size, 0))  # [n, b, d]

        for i in range(seq_len):
            # get next input
            if i == 0:
                data = go
            else:
                prev = nd.concat(output[i - 1], aux[:, :, i - 1], dim=-1)
                truth = nd.concat(label[:, :, i - 1], aux[:, :, i - 1], dim=-1)
                # value blends ground truth and the previous prediction;
                # 0 means the previous prediction is used unchanged.
                if is_training and self.use_sampling:
                    value = self.sampling()
                else:
                    value = 0
                data = value * truth + (1 - value) * prev

            # unroll 1 step
            for depth, cell in enumerate(self.cells):
                data, states[depth] = cell.forward_single(
                    feature, data, states[depth])
                if self.graphs[depth] is not None:
                    # Add the output of every graph block to the cell output.
                    _data = data
                    for g in self.graphs[depth]:
                        _data = _data + g(data, feature)
                    data = _data

            # append feature to output
            data = nd.concat(data, _feature, dim=-1)  # [n, b, d']

            # proj output to prediction
            data = nd.reshape(data, shape=(num_nodes * batch_size, -1))
            data = self.proj(data)
            data = nd.reshape(data, shape=(num_nodes, batch_size, -1))

            output.append(data)

        output = nd.stack(*output, axis=2)
        return output
Esempio n. 9
0
 def msg_edge(self, edge):
     """Compute the per-edge message from source state, destination state and distance.

     ``edge.data['dist']`` gets extra axes inserted at position 1 until its
     rank equals that of the source state, and is broadcast over all but the
     last axis of the source state. Source state, destination state and the
     broadcast distance are concatenated on the last axis, passed through
     ``self.dense`` and a LeakyReLU.

     Returns
     -------
     dict
         ``{'alpha': <activation>, 'state': <source state>}``.
     """
     distance = edge.data['dist']
     target_rank = len(edge.src['state'].shape)
     while len(distance.shape) < target_rank:
         distance = nd.expand_dims(distance, axis=1)

     # A trailing 0 keeps the distance's own last-axis size.
     target_shape = edge.src['state'].shape[:-1] + (0, )
     distance = nd.broadcast_to(distance, shape=target_shape)

     joined = nd.concat(edge.src['state'], edge.dst['state'], distance, dim=-1)
     return {'alpha': nd.LeakyReLU(self.dense(joined)), 'state': edge.src['state']}
Esempio n. 10
0
def test_where():
    """Check dense ``nd.where`` and ``nd.sparse.where`` with a CSR condition."""
    ones = nd.ones(shape=(LARGE_X, SMALL_Y))
    column = nd.arange(0, LARGE_X).reshape(LARGE_X, 1)
    values = nd.broadcast_to(column, shape=(column.shape[0], SMALL_Y))
    num_cols = values.shape[1]

    # Dense condition: the last row must come entirely from `ones`.
    dense_res = nd.where(values > 100, ones, values)
    assert np.sum(dense_res[-1].asnumpy() == 1) == num_cols

    # CSR condition: the first row (value 0 < 10) must come from `ones`.
    sparse_cond = nd.sparse.cast_storage(values < 10, 'csr')
    sparse_res = nd.sparse.where(sparse_cond, ones, values)
    assert np.sum(sparse_res[0].asnumpy() == 1) == num_cols
Esempio n. 11
0
def test_where():
    """Exercise ``where`` with a dense condition and with a CSR condition."""
    fill = nd.ones(shape=(LARGE_X, SMALL_Y))
    data = nd.arange(0, LARGE_X).reshape(LARGE_X, 1)
    data = nd.broadcast_to(data, shape=(data.shape[0], SMALL_Y))
    width = data.shape[1]

    # Rows with value > 100 are replaced by `fill`; inspect the last row.
    out = nd.where(data > 100, fill, data)
    last_row = out[-1].asnumpy()
    assert np.sum(last_row == 1) == width

    # Rows with value < 10 are replaced by `fill`; inspect the first row.
    cond = nd.sparse.cast_storage(data < 10, 'csr')
    out = nd.sparse.where(cond, fill, data)
    first_row = out[0].asnumpy()
    assert np.sum(first_row == 1) == width
Esempio n. 12
0
def label_box_cls(match, sample, gt_cls, ignore_label=-1):
    """Build per-anchor class labels from matched ground-truth classes.

    Parameters
    ----------
    match : NDArray, shape (B, N)
        Index of the ground-truth entry picked for every anchor.
    sample : NDArray
        Per-anchor flag: > 0.5 keeps the matched class (shifted by +1),
        < -0.5 yields label 0, anything else yields ``ignore_label``.
    gt_cls : NDArray, shape (B, M)
        Ground-truth class ids.
    ignore_label : number
        Label written for anchors that are neither kept nor set to 0.

    Returns
    -------
    tuple
        ``(label_cls, label_mask)`` where the mask is ``label_cls > -0.5``.
    """
    batch, num_anchors = match.shape
    batch, num_gt = gt_cls.shape

    # (B, N, M): every anchor sees all ground-truth classes.
    tiled_cls = nd.broadcast_to(gt_cls.reshape((batch, 1, num_gt)),
                                (batch, num_anchors, num_gt))

    # (B, N): matched class id, shifted by one so that 0 stays free.
    matched_cls = nd.pick(tiled_cls, match, axis=-1) + 1
    ignored = nd.ones_like(matched_cls) * ignore_label
    label_cls = nd.where(sample > 0.5, matched_cls, ignored)
    label_cls = nd.where(sample < -0.5, nd.zeros_like(label_cls), label_cls)

    # (B, N)
    return label_cls, label_cls > -0.5
Esempio n. 13
0
    def forward(self, x):
        """Run the two-branch network and fuse both branches.

        ``net[0]`` produces a shared intermediate; ``net[1]`` and ``net[2]``
        are both applied to it. The ``net[2]`` output gets two extra axes
        (positions 2 and 3), is broadcast to the shape of the ``net[1]``
        output, concatenated with it on axis 1 and fed to ``net[3]``.
        """
        stem_out = self.net[0](x)
        branch_b = self.net[1](stem_out)

        # A disabled variant fed a 224x224-resized copy of x through net[0]
        # instead of reusing stem_out.
        branch_a = self.net[2](stem_out)
        branch_a = branch_a.expand_dims(axis=2).expand_dims(axis=3)
        branch_a = nd.broadcast_to(branch_a, shape=branch_b.shape)

        fused = nd.concat(branch_a, branch_b, dim=1)
        return self.net[3](fused)
Esempio n. 14
0
def test_clip():
    """Check that ``nd.clip`` caps the last row of a broadcast column at 1000."""
    column = nd.arange(0, LARGE_X).reshape(LARGE_X, 1)
    tiled = nd.broadcast_to(column, shape=(column.shape[0], SMALL_Y))
    clipped = nd.clip(tiled, a_min=100, a_max=1000)
    last_row = clipped[-1].asnumpy()
    assert np.sum(last_row == 1000) == tiled.shape[1]
def create_2d_tensor(rows, columns, dtype=np.int64):
    """Return a (rows, columns) array in which every entry of row i equals i.

    Parameters
    ----------
    rows : int
        Number of rows; row values run from 0 to rows - 1.
    columns : int
        Number of columns each row value is repeated over.
    dtype
        Data type used for both the ``arange`` and the returned array.
    """
    column = nd.arange(0, rows, dtype=dtype).reshape(rows, 1)
    tiled = nd.broadcast_to(column, shape=(column.shape[0], columns))
    return nd.array(tiled, dtype=dtype)
Esempio n. 16
0
    def __init__(self, nclass=1000, norm_layer=BatchNorm, num_segments=1,
                 norm_kwargs=None, partial_bn=False, pretrained_base=True,
                 dropout_ratio=0.5, init_std=0.01,
                 ctx=None, **kwargs):
        """Build the I3D Inception-V1 network and optionally inflate 2D weights.

        Parameters
        ----------
        nclass : int
            Number of units of the final dense layer.
        norm_layer
            Normalization layer class forwarded to the layer builders.
        num_segments : int
            Stored on the instance as ``self.num_segments``.
        norm_kwargs : dict or None
            Extra arguments forwarded to the layer builders. The dict passed
            in is never modified; a copy is made when ``partial_bn`` is set.
        partial_bn : bool
            When true, ``use_global_stats=True`` is added to the norm
            arguments of every layer built after the first convolution.
        pretrained_base : bool
            When true, parameters of a pre-trained 2D GoogLeNet are paired by
            position with this network's parameters and ported: 'conv'
            parameters are repeated along a new temporal axis and divided by
            its size, 'batchnorm' parameters are copied, 'dense' parameters
            are skipped.
        dropout_ratio : float
            Rate of the dropout layer in the head.
        init_std : float
            Sigma of the normal initializer of the dense layer weights.
        ctx
            Passed to ``initialize(ctx=...)`` of ``features`` and ``head``.
        **kwargs
            Forwarded to the parent constructor.
        """
        super(I3D_InceptionV1, self).__init__(**kwargs)

        self.num_segments = num_segments
        self.feat_dim = 1024
        self.dropout_ratio = dropout_ratio
        self.init_std = init_std

        with self.name_scope():
            self.features = nn.HybridSequential(prefix='')

            self.features.add(_make_basic_conv(in_channels=3, channels=64, kernel_size=7, strides=2, padding=3, norm_layer=norm_layer, norm_kwargs=norm_kwargs))
            self.features.add(nn.MaxPool3D(pool_size=(1, 3, 3), strides=(1, 2, 2), padding=(0, 1, 1)))

            if partial_bn:
                # Work on a copy so the caller's dict is not mutated; the first
                # convolution above was deliberately built without this flag.
                norm_kwargs = dict(norm_kwargs) if norm_kwargs is not None else {}
                norm_kwargs['use_global_stats'] = True

            self.features.add(_make_basic_conv(in_channels=64, channels=64, kernel_size=1, norm_layer=norm_layer, norm_kwargs=norm_kwargs))
            self.features.add(_make_basic_conv(in_channels=64, channels=192, kernel_size=3, padding=(1, 1, 1), norm_layer=norm_layer, norm_kwargs=norm_kwargs))
            self.features.add(nn.MaxPool3D(pool_size=(1, 3, 3), strides=(1, 2, 2), padding=(0, 1, 1)))

            self.features.add(_make_Mixed_3a(192, 32, 'Mixed_3a_', norm_layer, norm_kwargs))
            self.features.add(_make_Mixed_3b(256, 64, 'Mixed_3b_', norm_layer, norm_kwargs))
            self.features.add(nn.MaxPool3D(pool_size=3, strides=(2, 2, 2), padding=(1, 1, 1)))

            self.features.add(_make_Mixed_4a(480, 64, 'Mixed_4a_', norm_layer, norm_kwargs))
            self.features.add(_make_Mixed_4b(512, 64, 'Mixed_4b_', norm_layer, norm_kwargs))
            self.features.add(_make_Mixed_4c(512, 64, 'Mixed_4c_', norm_layer, norm_kwargs))
            self.features.add(_make_Mixed_4d(512, 64, 'Mixed_4d_', norm_layer, norm_kwargs))
            self.features.add(_make_Mixed_4e(528, 128, 'Mixed_4e_', norm_layer, norm_kwargs))
            self.features.add(nn.MaxPool3D(pool_size=2, strides=(2, 2, 2)))

            self.features.add(_make_Mixed_5a(832, 128, 'Mixed_5a_', norm_layer, norm_kwargs))
            self.features.add(_make_Mixed_5b(832, 128, 'Mixed_5b_', norm_layer, norm_kwargs))
            self.features.add(nn.GlobalAvgPool3D())

            self.head = nn.HybridSequential(prefix='')
            self.head.add(nn.Dropout(rate=self.dropout_ratio))
            self.output = nn.Dense(units=nclass, in_units=self.feat_dim, weight_initializer=init.Normal(sigma=self.init_std))
            self.head.add(self.output)

            self.features.initialize(ctx=ctx)
            self.head.initialize(ctx=ctx)

            if pretrained_base:
                inceptionv1_2d = googlenet(pretrained=True)
                weights2d = inceptionv1_2d.collect_params()
                weights3d = self.collect_params()
                assert len(weights2d.keys()) == len(weights3d.keys()), 'Number of parameters should be same.'

                # Pair the i-th 2D parameter name with the i-th 3D parameter
                # name; this relies on both networks declaring their layers
                # in the same order.
                dict_transform = dict(zip(weights2d.keys(), weights3d.keys()))

                cnt = 0
                for key2d, key3d in dict_transform.items():
                    if 'conv' in key3d:
                        temporal_dim = weights3d[key3d].shape[2]
                        temporal_2d = nd.expand_dims(weights2d[key2d].data(), axis=2)
                        # A 0 in the target shape keeps the source size of
                        # that axis; dividing by temporal_dim keeps the sum
                        # over time equal to the original 2D kernel.
                        inflated_2d = nd.broadcast_to(temporal_2d, shape=[0, 0, temporal_dim, 0, 0]) / temporal_dim
                        assert inflated_2d.shape == weights3d[key3d].shape, 'the shape of %s and %s does not match. ' % (key2d, key3d)
                        weights3d[key3d].set_data(inflated_2d)
                        cnt += 1
                        print('%s is done with shape: ' % (key3d), weights3d[key3d].shape)
                    if 'batchnorm' in key3d:
                        assert weights2d[key2d].shape == weights3d[key3d].shape, 'the shape of %s and %s does not match. ' % (key2d, key3d)
                        weights3d[key3d].set_data(weights2d[key2d].data())
                        cnt += 1
                        print('%s is done with shape: ' % (key3d), weights3d[key3d].shape)
                    if 'dense' in key3d:
                        cnt += 1
                        print('%s is skipped with shape: ' % (key3d), weights3d[key3d].shape)

                assert cnt == len(weights2d.keys()), 'Not all parameters have been ported, check the initialization.'
def create_input_for_trigonometric_ops(vals):
    """Broadcast ``vals`` into a (LARGE_X*10, SMALL_Y//10) array.

    ``vals`` is reshaped to a single row of 5 entries, so it must hold exactly
    five elements; that row is then repeated along the first axis with
    ``broadcast_to``.
    """
    row = nd.array(vals).reshape(1, 5)
    return nd.broadcast_to(row, (LARGE_X*10, SMALL_Y//10))
Esempio n. 18
0
def test_clip():
    """Clip a broadcast (LARGE_X, SMALL_Y) array and check its last row."""
    base = nd.arange(0, LARGE_X).reshape(LARGE_X, 1)
    wide = nd.broadcast_to(base, shape=(base.shape[0], SMALL_Y))
    res = nd.clip(wide, a_min=100, a_max=1000)
    # Every entry of the last row is expected to equal the upper bound.
    capped = np.sum(res[-1].asnumpy() == 1000)
    assert capped == wide.shape[1]
Esempio n. 19
0
    def __init__(self, nclass=1000, pretrained=False, pretrained_base=True,
                 num_segments=1, num_crop=1, feat_ext=False,
                 dropout_ratio=0.5, init_std=0.01, partial_bn=False,
                 ctx=None, norm_layer=BatchNorm, norm_kwargs=None, **kwargs):
        """Build the I3D Inception-V3 network and optionally inflate 2D weights.

        Parameters
        ----------
        nclass : int
            Number of units of the final dense layer.
        pretrained : bool
            When true, the 2D weight porting is skipped even if
            ``pretrained_base`` is set.
        pretrained_base : bool
            When true (and ``pretrained`` is false), parameters of a
            pre-trained 2D Inception-V3 are paired by position with this
            network's parameters and ported: 'conv' parameters are repeated
            along a new temporal axis and divided by its size, 'batchnorm'
            parameters are copied, 'dense' parameters are skipped.
        num_segments, num_crop, feat_ext
            Stored on the instance under the same names.
        dropout_ratio : float
            Rate of the dropout layer in the head.
        init_std : float
            Sigma of the normal initializer of the dense layer weights.
        partial_bn : bool
            When true, ``use_global_stats=True`` is added to the norm
            arguments of every layer built after the first convolution.
        ctx
            Passed to ``initialize(ctx=...)`` of ``features`` and ``head``.
        norm_layer
            Normalization layer class forwarded to the layer builders.
        norm_kwargs : dict or None
            Extra arguments forwarded to the layer builders. The dict passed
            in is never modified; a copy is made when ``partial_bn`` is set.
        **kwargs
            Forwarded to the parent constructor.
        """
        super(I3D_InceptionV3, self).__init__(**kwargs)

        self.num_segments = num_segments
        self.num_crop = num_crop
        self.feat_dim = 2048
        self.dropout_ratio = dropout_ratio
        self.init_std = init_std
        self.feat_ext = feat_ext

        with self.name_scope():
            self.features = nn.HybridSequential(prefix='')
            self.features.add(_make_basic_conv(in_channels=3, channels=32, kernel_size=3, strides=2, padding=(1, 0, 0),
                                               norm_layer=norm_layer, norm_kwargs=norm_kwargs))
            if partial_bn:
                # Work on a copy so the caller's dict is not mutated; the first
                # convolution above was deliberately built without this flag.
                norm_kwargs = dict(norm_kwargs) if norm_kwargs is not None else {}
                norm_kwargs['use_global_stats'] = True

            self.features.add(_make_basic_conv(in_channels=32, channels=32, kernel_size=3, padding=(1, 0, 0),
                                               norm_layer=norm_layer, norm_kwargs=norm_kwargs))
            self.features.add(_make_basic_conv(in_channels=32, channels=64, kernel_size=3, padding=1,
                                               norm_layer=norm_layer, norm_kwargs=norm_kwargs))
            self.features.add(nn.MaxPool3D(pool_size=3, strides=(1, 2, 2), padding=(1, 0, 0)))
            self.features.add(_make_basic_conv(in_channels=64, channels=80, kernel_size=1,
                                               norm_layer=norm_layer, norm_kwargs=norm_kwargs))
            self.features.add(_make_basic_conv(in_channels=80, channels=192, kernel_size=3, padding=(1, 0, 0),
                                               norm_layer=norm_layer, norm_kwargs=norm_kwargs))
            self.features.add(nn.MaxPool3D(pool_size=3, strides=(1, 2, 2), padding=(1, 0, 0)))
            self.features.add(_make_A(192, 32, 'A1_', norm_layer, norm_kwargs))
            self.features.add(_make_A(256, 64, 'A2_', norm_layer, norm_kwargs))
            self.features.add(_make_A(288, 64, 'A3_', norm_layer, norm_kwargs))
            self.features.add(_make_B('B_', norm_layer, norm_kwargs))
            self.features.add(_make_C(768, 128, 'C1_', norm_layer, norm_kwargs))
            self.features.add(_make_C(768, 160, 'C2_', norm_layer, norm_kwargs))
            self.features.add(_make_C(768, 160, 'C3_', norm_layer, norm_kwargs))
            self.features.add(_make_C(768, 192, 'C4_', norm_layer, norm_kwargs))
            self.features.add(_make_D('D_', norm_layer, norm_kwargs))
            self.features.add(_make_E(1280, 'E1_', norm_layer, norm_kwargs))
            self.features.add(_make_E(2048, 'E2_', norm_layer, norm_kwargs))
            self.features.add(nn.GlobalAvgPool3D())

            self.head = nn.HybridSequential(prefix='')
            self.head.add(nn.Dropout(rate=self.dropout_ratio))
            self.output = nn.Dense(units=nclass, in_units=self.feat_dim, weight_initializer=init.Normal(sigma=self.init_std))
            self.head.add(self.output)

            self.features.initialize(ctx=ctx)
            self.head.initialize(ctx=ctx)

            if pretrained_base and not pretrained:
                inceptionv3_2d = inception_v3(pretrained=True)
                weights2d = inceptionv3_2d.collect_params()
                weights3d = self.collect_params()
                assert len(weights2d.keys()) == len(weights3d.keys()), 'Number of parameters should be same.'

                # Pair the i-th 2D parameter name with the i-th 3D parameter
                # name; this relies on both networks declaring their layers
                # in the same order.
                dict_transform = dict(zip(weights2d.keys(), weights3d.keys()))

                cnt = 0
                for key2d, key3d in dict_transform.items():
                    if 'conv' in key3d:
                        temporal_dim = weights3d[key3d].shape[2]
                        temporal_2d = nd.expand_dims(weights2d[key2d].data(), axis=2)
                        # A 0 in the target shape keeps the source size of
                        # that axis; dividing by temporal_dim keeps the sum
                        # over time equal to the original 2D kernel.
                        inflated_2d = nd.broadcast_to(temporal_2d, shape=[0, 0, temporal_dim, 0, 0]) / temporal_dim
                        assert inflated_2d.shape == weights3d[key3d].shape, 'the shape of %s and %s does not match. ' % (key2d, key3d)
                        weights3d[key3d].set_data(inflated_2d)
                        cnt += 1
                        print('%s is done with shape: ' % (key3d), weights3d[key3d].shape)
                    if 'batchnorm' in key3d:
                        assert weights2d[key2d].shape == weights3d[key3d].shape, 'the shape of %s and %s does not match. ' % (key2d, key3d)
                        weights3d[key3d].set_data(weights2d[key2d].data())
                        cnt += 1
                        print('%s is done with shape: ' % (key3d), weights3d[key3d].shape)
                    if 'dense' in key3d:
                        cnt += 1
                        print('%s is skipped with shape: ' % (key3d), weights3d[key3d].shape)

                assert cnt == len(weights2d.keys()), 'Not all parameters have been ported, check the initialization.'
Esempio n. 20
0
    def forward(self, is_train, req, in_data, out_data, aux):
        """Score the top ``first_n`` detections per class with a learned NMS head.

        Inputs are read positionally from ``in_data``: class scores (0), box
        regression output (1), rois (2), image info (3), a feature tensor (4),
        the weights/biases of the fully connected and attention layers (5-18)
        and, when ``self.has_non_gt_index`` is set, an index of rows to keep (19).

        Three results are written with ``self.assign``:
        ``out_data[0]`` the per-threshold NMS score multiplied by the sorted
        class score, ``out_data[1]`` the sorted boxes and ``out_data[2]`` the
        sorted class scores.

        The wall-clock time of each call is recorded in module globals
        (``learn_nms_time`` / ``learn_nms_count``) and a running average is
        printed every 250 calls.

        NOTE(review): the signature looks like an ``mx.operator.CustomOp``
        forward; ``is_train`` and ``aux`` are not used here -- confirm.
        """
        nms_start_time = time.time()
        #inputs
        cls_score = in_data[0]
        bbox_pred = in_data[1]
        rois = in_data[2]
        im_info = in_data[3]
        fc_all_2_relu = in_data[4]
        nms_rank_weight = in_data[5]
        nms_rank_bias = in_data[6]
        roi_feat_embedding_weight = in_data[7]
        roi_feat_embedding_bias = in_data[8]
        nms_pair_pos_fc1_1_weight = in_data[9]
        nms_pair_pos_fc1_1_bias = in_data[10]
        nms_query_1_weight = in_data[11]
        nms_query_1_bias = in_data[12]
        nms_key_1_weight = in_data[13]
        nms_key_1_bias = in_data[14]
        nms_linear_out_1_weight = in_data[15]
        nms_linear_out_1_bias = in_data[16]
        nms_logit_weight = in_data[17]
        nms_logit_bias = in_data[18]
        if self.has_non_gt_index:
            non_gt_index = in_data[19]
        else:
            non_gt_index = None

        # Restrict scores and box predictions either to the first nongt_dim
        # rows or to the rows listed in non_gt_index; otherwise use all rows.
        if self.nongt_dim is not None:
            cls_score_nongt = nd.slice_axis(data=cls_score,
                                            axis=0,
                                            begin=0,
                                            end=self.nongt_dim)
            # cls_score_nongt = monitor_wrapper(cls_score_nongt, 'cls_score_nongt')
            bbox_pred_nongt = nd.slice_axis(data=bbox_pred,
                                            axis=0,
                                            begin=0,
                                            end=self.nongt_dim)
        elif non_gt_index is not None:
            cls_score_nongt = nd.take(a=cls_score, indices=non_gt_index)
            bbox_pred_nongt = nd.take(a=bbox_pred, indices=non_gt_index)
        else:
            cls_score_nongt = cls_score
            bbox_pred_nongt = bbox_pred
        bbox_pred_nongt = nd.BlockGrad(bbox_pred_nongt)

        # remove batch idx and gt roi
        sliced_rois = nd.slice_axis(data=rois, axis=1, begin=1, end=None)
        if self.nongt_dim is not None:
            sliced_rois = nd.slice_axis(data=sliced_rois,
                                        axis=0,
                                        begin=0,
                                        end=self.nongt_dim)
        elif non_gt_index is not None:
            sliced_rois = nd.take(a=sliced_rois, indices=non_gt_index)
        # bbox_pred_nobg, [num_rois, 4*(num_reg_classes-1)]
        bbox_pred_nobg = nd.slice_axis(data=bbox_pred_nongt,
                                       axis=1,
                                       begin=4,
                                       end=None)
        # [num_boxes, 4, num_reg_classes-1]
        refined_bbox = refine_bbox_nd(sliced_rois,
                                      bbox_pred_nobg,
                                      im_info,
                                      means=self.bbox_means,
                                      stds=self.bbox_stds)
        # softmax cls_score to cls_prob, [num_rois, num_classes]
        cls_prob = nd.softmax(data=cls_score_nongt, axis=-1)
        # drop column 0 (presumably the background class -- confirm)
        cls_prob_nobg = nd.slice_axis(cls_prob, axis=1, begin=1, end=None)
        sorted_cls_prob_nobg = nd.sort(data=cls_prob_nobg,
                                       axis=0,
                                       is_ascend=False)
        # sorted_score, [first_n, num_fg_classes]
        sorted_score = nd.slice_axis(sorted_cls_prob_nobg,
                                     axis=0,
                                     begin=0,
                                     end=self.first_n,
                                     name='sorted_score')
        max_score_per_class = sorted_score.max(axis=0)
        max_score_per_class_numpy = max_score_per_class.asnumpy()

        # A class is "valid" when its best score reaches the threshold; the
        # threshold is capped at the overall best score so that at least one
        # class is always valid.
        valid_class_thresh = self.class_thresh
        valid_class_thresh = np.minimum(valid_class_thresh,
                                        max_score_per_class_numpy.max())
        valid_class_indices = np.where(
            max_score_per_class_numpy >= valid_class_thresh)[0]
        invalid_class_indices = np.where(
            max_score_per_class_numpy < valid_class_thresh)[0]
        num_valid_classes = len(valid_class_indices)
        valid_class_indices_nd = nd.array(valid_class_indices,
                                          ctx=sorted_score.context)

        # sort by score
        rank_indices = nd.argsort(data=cls_prob_nobg, axis=0, is_ascend=False)
        # first_rank_indices, [first_n, num_fg_classes]
        first_rank_indices = nd.slice_axis(rank_indices,
                                           axis=0,
                                           begin=0,
                                           end=self.first_n)
        # keep only the columns of the valid classes
        valid_first_rank_indices = first_rank_indices.transpose().take(
            valid_class_indices_nd).transpose()

        # sorted_bbox, [first_n, num_fg_classes, 4, num_reg_classes-1]
        sorted_bbox = nd.take(a=refined_bbox, indices=first_rank_indices)
        if self.class_agnostic:
            # sorted_bbox, [first_n, num_fg_classes, 4]
            sorted_bbox = nd.Reshape(sorted_bbox,
                                     shape=(0, 0, 0),
                                     name='sorted_bbox')
        else:
            # For class c pick the box regressed for class c (last axis).
            # NOTE(review): nd.arange is created without ctx, i.e. on the
            # default context -- confirm this matches sorted_bbox's context.
            cls_mask = nd.arange(0, self.num_fg_classes)
            cls_mask = nd.Reshape(cls_mask, shape=(1, -1, 1))
            cls_mask = nd.broadcast_to(cls_mask, shape=(self.first_n, 0, 4))
            # sorted_bbox, [first_n, num_fg_classes, 4]
            sorted_bbox = nd.pick(data=sorted_bbox,
                                  name='sorted_bbox',
                                  index=cls_mask,
                                  axis=3)

        # keep only the valid classes (class axis moved to the front for take)
        valid_sorted_bbox = sorted_bbox.transpose(
            (1, 0, 2)).take(valid_class_indices_nd).transpose((1, 0, 2))

        # sorted_bbox = monitor_wrapper(sorted_bbox, 'sorted_bbox')
        # nms_rank_embedding, [first_n, 1024]
        nms_rank_embedding = extract_rank_embedding_nd(self.first_n, 1024)
        # nms_rank_feat, [first_n, 128] (num_hidden=128)
        nms_rank_feat = nd.FullyConnected(name='nms_rank',
                                          data=nms_rank_embedding,
                                          num_hidden=128,
                                          weight=nms_rank_weight,
                                          bias=nms_rank_bias)
        # nms_position_matrix, [num_valid_classes, first_n, first_n, 4]
        nms_position_matrix = extract_multi_position_matrix_nd(
            valid_sorted_bbox)
        # roi_feat_embedding, [num_rois, 128] (num_hidden=128)
        # fc_all_2_relu = monitor_wrapper(fc_all_2_relu, 'fc_all_2_relu')
        roi_feat_embedding = nd.FullyConnected(
            name='roi_feat_embedding',
            data=fc_all_2_relu,
            num_hidden=128,
            weight=roi_feat_embedding_weight,
            bias=roi_feat_embedding_bias)
        # sorted_roi_feat, [first_n, num_valid_classes, 128]
        sorted_roi_feat = nd.take(a=roi_feat_embedding,
                                  indices=valid_first_rank_indices)

        # vectorized nms
        # nms_embedding_feat, [first_n, num_valid_classes, 128]
        nms_embedding_feat = nd.broadcast_add(lhs=sorted_roi_feat,
                                              rhs=nd.expand_dims(nms_rank_feat,
                                                                 axis=1))
        # nms_attention_1, [first_n, num_valid_classes, 1024]
        # NOTE(review): it is added elementwise to the 128-wide embedding
        # below, so the 1024 above looks stale -- confirm.
        nms_attention_1 = nms_attention_nd(
            nms_embedding_feat,
            nms_position_matrix,
            nms_pair_pos_fc1_1_weight,
            nms_pair_pos_fc1_1_bias,
            nms_query_1_weight,
            nms_query_1_bias,
            nms_key_1_weight,
            nms_key_1_bias,
            nms_linear_out_1_weight,
            nms_linear_out_1_bias,
            num_rois=self.first_n,
            index=1,
            group=self.nms_attention_group,
            dim=self.nms_attention_dim,
            fc_dim=self.nms_attention_fc_dim,
            feat_dim=self.nms_attention_feat_dim)
        # residual connection followed by ReLU
        nms_all_feat_1 = nms_embedding_feat + nms_attention_1
        nms_all_feat_1_relu = nd.Activation(data=nms_all_feat_1,
                                            act_type='relu',
                                            name='nms_all_feat_1_relu')
        # merge the first two axes: [first_n * num_valid_classes, feat]
        nms_all_feat_1_relu_reshape = nd.Reshape(nms_all_feat_1_relu,
                                                 shape=(-3, -2))
        # logit, [first_n * num_valid_classes, num_thresh]
        nms_conditional_logit = nd.FullyConnected(
            name='nms_logit',
            data=nms_all_feat_1_relu_reshape,
            num_hidden=self.num_thresh,
            weight=nms_logit_weight,
            bias=nms_logit_bias)
        # logit_reshape, [first_n, num_valid_classes, num_thresh]
        nms_conditional_logit_reshape = nd.Reshape(nms_conditional_logit,
                                                   shape=(self.first_n,
                                                          num_valid_classes,
                                                          self.num_thresh))
        nms_conditional_score = nd.Activation(
            data=nms_conditional_logit_reshape,
            act_type='sigmoid',
            name='nms_conditional_score')
        # Pad the class axis with zeros for the invalid classes so the result
        # again covers all num_fg_classes.
        if num_valid_classes == self.num_fg_classes:
            full_nms_conditional_score = nms_conditional_score
        else:
            full_nms_conditional_score = nd.concat(
                nms_conditional_score,
                nd.zeros(
                    (self.first_n, self.num_fg_classes - num_valid_classes,
                     self.num_thresh),
                    ctx=nms_conditional_score.context),
                dim=1)

        # The padded tensor is ordered as [valid classes..., invalid
        # classes...]; restore_indexes maps every original class id to its
        # position in that ordering so take() puts classes back in order.
        all_indexes = np.concatenate(
            (valid_class_indices, invalid_class_indices))
        restore_indexes = np.zeros((self.num_fg_classes))
        restore_indexes[all_indexes] = np.arange(self.num_fg_classes)
        restore_indexes = nd.array(restore_indexes,
                                   ctx=nms_conditional_score.context)
        full_nms_conditional_score = full_nms_conditional_score.transpose(
            (1, 0, 2)).take(restore_indexes).transpose((1, 0, 2))

        sorted_score_reshape = nd.expand_dims(sorted_score, axis=2)
        # sorted_score_reshape = nd.BlockGrad(sorted_score_reshape)
        nms_multi_score = nd.broadcast_mul(lhs=sorted_score_reshape,
                                           rhs=full_nms_conditional_score)
        # Result is discarded; presumably asnumpy() is called only to wait for
        # the computation so the timing below is meaningful -- confirm.
        _ = nms_multi_score.mean().asnumpy()

        # Timing bookkeeping kept in module globals: at most the last 1000
        # durations are stored.
        all_time = time.time() - nms_start_time
        if 'learn_nms_time' not in globals().keys(
        ) or 'learn_nms_count' not in globals().keys():
            globals()['learn_nms_time'] = []
            globals()['learn_nms_count'] = 0

        if globals()['learn_nms_count'] >= 1000:
            globals()['learn_nms_time'].pop(0)
            globals()['learn_nms_time'].append(all_time)
        else:
            globals()['learn_nms_time'].append(all_time)

        globals()['learn_nms_count'] += 1
        # Report the average over the stored durations every 250 calls.
        if globals()['learn_nms_count'] % 250 == 0:
            print("--->> learn nms running average time cost: {}".format(
                float(sum(globals()['learn_nms_time'])) /
                (1000 if globals()['learn_nms_count'] > 1000 else
                 globals()['learn_nms_count'])))

        self.assign(out_data[0], req[0], nms_multi_score)
        self.assign(out_data[1], req[1], sorted_bbox)
        self.assign(out_data[2], req[2], sorted_score)