Example 1
    def fast_nms(self, boxes, scores, masks, max_num_detections=100):
        iou_threshold = self.nms_thresh
        top_k = self.top_k

        # Within each class, sort boxes by score in descending order
        scores, idx = P.argsort(scores, axis=1, descending=True)

        idx = idx[:, :top_k]
        scores = scores[:, :top_k]

        num_classes, num_dets = P.shape(idx)[0], P.shape(idx)[1]

        idx = P.reshape(idx, (-1, ))
        boxes = P.gather(boxes, idx)
        boxes = P.reshape(boxes, (num_classes, num_dets, 4))
        masks = P.gather(masks, idx)
        masks = P.reshape(masks, (num_classes, num_dets, -1))

        # Compute a c x n x n IoU matrix, where each n x n slice holds the
        # pairwise IoU between that class's n candidate boxes
        iou = jaccard(boxes, boxes)
        # Since IoU(A, A) = 1 and IoU(A, B) = IoU(B, A), post-process the IoU
        # matrix by zeroing the diagonal and the lower triangle of every slice
        rows = P.range(0, num_dets, 1, 'int32')
        cols = P.range(0, num_dets, 1, 'int32')
        rows = P.expand(P.reshape(rows, (1, -1)), [num_dets, 1])
        cols = P.expand(P.reshape(cols, (-1, 1)), [1, num_dets])
        tri_mask = P.cast(rows > cols, 'float32')
        tri_mask = P.expand(P.reshape(tri_mask, (1, num_dets, num_dets)),
                            [num_classes, 1, 1])
        iou = tri_mask * iou
        iou_max = P.reduce_max(iou, dim=1)

        # Now just filter out the ones higher than the threshold
        keep = P.where(iou_max <= iou_threshold)

        # Assign each kept detection to its corresponding class
        classes = P.range(0, num_classes, 1, 'int32')
        classes = P.expand(P.reshape(classes, (-1, 1)), [1, num_dets])
        classes = P.gather_nd(classes, keep)

        boxes = P.gather_nd(boxes, keep)
        masks = P.gather_nd(masks, keep)
        scores = P.gather_nd(scores, keep)

        # Keep only the max_num_detections highest-scoring detections across all classes
        scores, idx = P.argsort(scores, axis=0, descending=True)
        idx = idx[:max_num_detections]
        scores = scores[:max_num_detections]

        classes = P.gather(classes, idx)
        boxes = P.gather(boxes, idx)
        masks = P.gather(masks, idx)

        return boxes, masks, classes, scores
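Note: the tri_mask construction above is the whole of the Fast NMS keep rule. A minimal NumPy sketch of the same decision (fast_nms_keep is a hypothetical name; the IoU matrix is assumed to come from boxes already sorted by descending score within each class):

import numpy as np

def fast_nms_keep(iou, iou_threshold):
    # iou: [num_classes, n, n] pairwise IoU of score-sorted boxes.
    n = iou.shape[-1]
    upper = np.triu(np.ones((n, n), dtype=bool), k=1)  # strict upper triangle, like tri_mask
    iou = np.where(upper, iou, 0.0)                    # zero the diagonal and lower triangle
    iou_max = iou.max(axis=1)        # per box: max IoU with any higher-scoring box of its class
    return iou_max <= iou_threshold  # boolean keep mask, shape [num_classes, n]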
Example 2
def reorder_neuron_head(model, head_importance, neuron_importance):
    # reorder heads and ffn neurons
    for layer, current_importance in enumerate(neuron_importance):
        # reorder heads
        idx = L.argsort(head_importance[layer], descending=True)[-1]  # argsort returns (values, indices); [-1] is the indices
        #model.encoder_stack.block[layer].attn.reorder_heads(idx)
        reorder_head(model.encoder_stack.block[layer].attn, idx)
        # reorder neurons
        idx = L.argsort(FD.to_variable(current_importance), descending=True)[-1]
        #model.encoder_stack.block[layer].ffn.reorder_neurons(idx)
        reorder_neuron(model.encoder_stack.block[layer].ffn, idx)
Example 3
def fast_nms(boxes, scores, conf_thresh, nms_thresh, keep_top_k, nms_top_k):
    '''
    :param boxes:    [?, 4]
    :param scores:   [80, ?]
    '''

    # Within each class, sort boxes by score in descending order
    scores, idx = P.argsort(scores, axis=1, descending=True)

    idx = idx[:, :keep_top_k]
    scores = scores[:, :keep_top_k]

    num_classes, num_dets = P.shape(idx)[0], P.shape(idx)[1]

    idx = P.reshape(idx, (-1, ))
    boxes = P.gather(boxes, idx)
    boxes = P.reshape(boxes, (num_classes, num_dets, 4))

    # Compute a c x n x n IoU matrix, where each n x n slice holds the
    # pairwise IoU between that class's n candidate boxes
    iou = _iou(boxes, boxes)

    # Since IoU(A, A) = 1 and IoU(A, B) = IoU(B, A), post-process the IoU
    # matrix by zeroing the diagonal and the lower triangle of every slice
    rows = P.range(0, num_dets, 1, 'int32')
    cols = P.range(0, num_dets, 1, 'int32')
    rows = P.expand(P.reshape(rows, (1, -1)), [num_dets, 1])
    cols = P.expand(P.reshape(cols, (-1, 1)), [1, num_dets])
    tri_mask = P.cast(rows > cols, 'float32')
    tri_mask = P.expand(P.reshape(tri_mask, (1, num_dets, num_dets)),
                        [num_classes, 1, 1])
    iou = tri_mask * iou
    iou_max = P.reduce_max(iou, dim=1)

    # Within each class, drop a box if its max IoU with any higher-scoring box
    # exceeds nms_thresh. The box at index 0 is always kept.
    keep = P.where(iou_max <= nms_thresh)

    # Assign each kept detection to its corresponding class
    classes = P.range(0, num_classes, 1, 'int32')
    classes = P.expand(P.reshape(classes, (-1, 1)), [1, num_dets])
    classes = P.gather_nd(classes, keep)

    boxes = P.gather_nd(boxes, keep)
    scores = P.gather_nd(scores, keep)

    # Keep only the nms_top_k highest-scoring detections across all classes
    scores, idx = P.argsort(scores, axis=0, descending=True)
    idx = idx[:nms_top_k]
    scores = scores[:nms_top_k]

    classes = P.gather(classes, idx)
    boxes = P.gather(boxes, idx)

    return boxes, scores, classes
Example 4
def matrix_nms(bboxes,
               scores,
               score_threshold,
               post_threshold,
               nms_top_k,
               keep_top_k,
               use_gaussian=False,
               gaussian_sigma=2.):
    scores = L.transpose(scores, [1, 0])
    inds = L.where(scores > score_threshold)
    if len(inds) == 0:
        return L.zeros((0, 6), 'float32') - 1.0

    cate_scores = L.gather_nd(scores, inds)
    cate_labels = inds[:, 1]
    bboxes = L.gather(bboxes, inds[:, 0])

    # sort and keep top nms_top_k
    _, sort_inds = L.argsort(cate_scores, descending=True)
    if nms_top_k > 0 and len(sort_inds) > nms_top_k:
        sort_inds = sort_inds[:nms_top_k]
    bboxes = L.gather(bboxes, sort_inds)
    cate_scores = L.gather(cate_scores, sort_inds)
    cate_labels = L.gather(cate_labels, sort_inds)

    # Matrix NMS
    kernel = 'gaussian' if use_gaussian else 'linear'
    cate_scores = _matrix_nms(bboxes, cate_labels, cate_scores, kernel=kernel, sigma=gaussian_sigma)

    # filter.
    keep = L.where(cate_scores >= post_threshold)
    if len(keep) == 0:
        return L.zeros((0, 6), 'float32') - 1.0
    bboxes = L.gather(bboxes, keep)
    cate_scores = L.gather(cate_scores, keep)
    cate_labels = L.gather(cate_labels, keep)

    # sort and keep keep_top_k
    _, sort_inds = L.argsort(cate_scores, descending=True)
    if len(sort_inds) > keep_top_k:
        sort_inds = sort_inds[:keep_top_k]
    bboxes = L.gather(bboxes, sort_inds)
    cate_scores = L.gather(cate_scores, sort_inds)
    cate_labels = L.gather(cate_labels, sort_inds)

    cate_scores = L.unsqueeze(cate_scores, 1)
    cate_labels = L.unsqueeze(cate_labels, 1)
    cate_labels = L.cast(cate_labels, 'float32')
    pred = L.concat([cate_labels, cate_scores, bboxes], 1)

    return pred
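Note: the _matrix_nms helper is not shown on this page. A rough NumPy sketch of the score decay it is expected to compute, following the SOLOv2 formulation (matrix_nms_decay is a hypothetical name; same-class masking is omitted and the IoU matrix is assumed score-sorted):

import numpy as np

def matrix_nms_decay(iou, kernel='linear', sigma=2.0):
    # iou: [n, n] pairwise IoU of score-sorted detections.
    iou = np.triu(iou, k=1)            # keep only IoU with higher-scoring detections
    compensate = iou.max(axis=0)       # how much each detection itself is overlapped
    if kernel == 'gaussian':
        decay = np.exp(-sigma * (iou ** 2 - compensate[:, None] ** 2))
    else:  # linear
        decay = (1.0 - iou) / (1.0 - compensate[:, None])
    return decay.min(axis=0)           # multiply the scores by this per-detection factor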
Example 5
def no_nms(bboxes,
           scores,
           score_threshold,
           keep_top_k):
    scores = L.transpose(scores, [1, 0])
    inds = L.where(scores > score_threshold)
    if len(inds) == 0:
        return L.zeros((0, 6), 'float32') - 1.0

    cate_scores = L.gather_nd(scores, inds)
    cate_labels = inds[:, 1]
    bboxes = L.gather(bboxes, inds[:, 0])

    # sort and keep top keep_top_k
    _, sort_inds = L.argsort(cate_scores, descending=True)
    if keep_top_k > 0 and len(sort_inds) > keep_top_k:
        sort_inds = sort_inds[:keep_top_k]
    bboxes = L.gather(bboxes, sort_inds)
    cate_scores = L.gather(cate_scores, sort_inds)
    cate_labels = L.gather(cate_labels, sort_inds)

    cate_scores = L.unsqueeze(cate_scores, 1)
    cate_labels = L.unsqueeze(cate_labels, 1)
    cate_labels = L.cast(cate_labels, 'float32')
    pred = L.concat([cate_labels, cate_scores, bboxes], 1)

    return pred
Example 6
    def __build_edges(self, edges, node_shift, edge_lod, edge_feats):
        """ Merge subgraph edges. 
        """
        if isinstance(edges, tuple):
            src, dst = edges
        else:
            src = edges[:, 0]
            dst = edges[:, 1]

        src = L.reshape(src, [-1])
        dst = L.reshape(dst, [-1])
        src = paddle_helper.ensure_dtype(src, dtype="int32")
        dst = paddle_helper.ensure_dtype(dst, dtype="int32")
        # preprocess edges
        lod_dst = L.lod_reset(dst, edge_lod)
        node_shift = L.reshape(node_shift, [-1, 1])
        node_shift = L.sequence_expand_as(node_shift, lod_dst)
        node_shift = L.reshape(node_shift, [-1])
        src = src + node_shift
        dst = dst + node_shift
        # sort edges
        self._edges_dst, index = L.argsort(dst)
        self._edges_src = L.gather(src, index, overwrite=False)

        # assign edge features
        if edge_feats is not None:
            for key, efeat in edge_feats.items():
                self.edge_feat_tensor_dict[key] = L.gather(efeat,
                                                           index,
                                                           overwrite=False)
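Note: a toy NumPy analogue of the edge sort above; edges are ordered by destination node and the same permutation is applied to the sources (and would be applied to any per-edge features):

import numpy as np

src = np.array([2, 0, 1, 2], dtype=np.int32)
dst = np.array([1, 2, 0, 0], dtype=np.int32)
index = np.argsort(dst, kind='stable')          # permutation that sorts by destination
edges_dst, edges_src = dst[index], src[index]   # dst: [0 0 1 2], src: [1 2 2 0]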
Example 7
        def exist_objs_3(keep, masks, classes, scores, upsampled_size_out,
                         resize_shape, ori_shape):
            keep = L.reshape(keep, (-1, ))
            keep.stop_gradient = True
            masks = L.gather(masks, keep)  # [M4, s4, s4]   mask probabilities of the M4 objects
            scores = L.gather(scores, keep)  # [M4, ]   scores of the M4 objects
            classes = L.gather(classes, keep)  # [M4, ]   class ids of the M4 objects

            # Fifth filter: keep only the top cfg['max_per_img'] objects by score
            _, sort_inds = L.argsort(scores, axis=-1, descending=True)
            sort_inds = sort_inds[:cfg['max_per_img']]
            sort_inds.stop_gradient = True

            masks = L.gather(masks, sort_inds)  # [M5, s4, s4]   mask probabilities of the M5 objects
            scores = L.gather(scores, sort_inds)  # [M5, ]   scores of the M5 objects
            classes = L.gather(classes, sort_inds)  # [M5, ]   class ids of the M5 objects

            masks = L.resize_bilinear(
                L.unsqueeze(masks, axes=[0]),
                out_shape=upsampled_size_out,
                align_corners=False,
                align_mode=0)[:, :, :resize_shape[0], :resize_shape[1]]  # crop away the padded border
            masks = L.resize_bilinear(masks,
                                      out_shape=ori_shape[:2],
                                      align_corners=False,
                                      align_mode=0)  # resize to the original image size
            masks = L.cast(masks > cfg['mask_thr'], 'float32')[0]
            return masks, classes, scores
Example 8
def chunk_softmax(logits, labels, topk=10):
    after_exp = L.exp(logits)
    out, _ = L.argsort(after_exp, axis=-1)
    denorm = L.reduce_sum(out[:, -topk:], dim=-1, keep_dim=True)
    probs = after_exp / denorm
    one_hot = F.one_hot(labels, depth=probs.shape[-1])
    loss = -L.reduce_sum(one_hot * L.log(probs)) / logits.shape[0]
    return loss
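Note: a NumPy restatement of the loss above (chunk_softmax_np is a hypothetical name; labels is assumed to be an integer vector of shape [batch]). The only twist is that the softmax denominator sums just the topk largest exponentials:

import numpy as np

def chunk_softmax_np(logits, labels, topk=10):
    after_exp = np.exp(logits)
    denom = np.sort(after_exp, axis=-1)[:, -topk:].sum(axis=-1, keepdims=True)
    probs = after_exp / denom
    # cross-entropy against the label column, averaged over the batch
    return -np.log(probs[np.arange(len(labels)), labels]).mean()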
Example 9
 def test_argsort(self):
     program = Program()
     with program_guard(program):
         data = layers.data(name='x', shape=[2, 3, 3], dtype="float32")
         out, ids = layers.argsort(input=data, axis=1)
         self.assertIsNotNone(out)
         self.assertIsNotNone(ids)
     print(str(program))
Example 10
def uniq_edges(src, dst, num_nodes):
    sorted_dst = L.cast(dst, dtype="int64")
    sorted_src = L.cast(src, dtype="int64")
    num_nodes = L.cast(num_nodes, dtype="int64")
    edge_hash = sorted_dst * num_nodes + sorted_src
    edge_hash, _ = L.argsort(edge_hash)
    edge_hash, _ = L.unique(edge_hash, dtype="int64")
    sorted_src = L.elementwise_mod(edge_hash, num_nodes)
    sorted_dst = L.elementwise_div(edge_hash, num_nodes)
    sorted_src = L.cast(sorted_src, dtype="int32")
    sorted_dst = L.cast(sorted_dst, dtype="int32")
    return sorted_src, sorted_dst
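Note: a minimal NumPy sketch of the same hashing trick (uniq_edges_np is a hypothetical name). np.unique sorts its output, so the separate argsort step above is not needed here:

import numpy as np

def uniq_edges_np(src, dst, num_nodes):
    # Encode each (src, dst) pair as one integer, dedupe, then decode.
    edge_hash = dst.astype(np.int64) * num_nodes + src.astype(np.int64)
    edge_hash = np.unique(edge_hash)
    return edge_hash % num_nodes, edge_hash // num_nodes  # (sorted_src, sorted_dst)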
Example 11
    def test_affine_grid(self):
        program = Program()
        with program_guard(program):
            data = layers.data(name='data', shape=[2, 3, 3], dtype="float32")
            out, ids = layers.argsort(input=data, axis=1)

            theta = layers.data(name="theta", shape=[2, 3], dtype="float32")
            out_shape = layers.data(
                name="out_shape", shape=[-1], dtype="float32")
            data_0 = layers.affine_grid(theta, out_shape)
            data_1 = layers.affine_grid(theta, [5, 3, 28, 28])

            self.assertIsNotNone(data_0)
            self.assertIsNotNone(data_1)
        print(str(program))
Example 12
    def pack_padded_sequence(self, x, mask, pad_index):
        """
        Packs a padded sequences x.

        Args:
            x: input matrix
            mask: mask matrix
            pad_index: pad_index

        Returns:
            new_x: the packed output
            batch_sizes: the batch size at each time step
            sorted_indices: the indices that sort x by descending length

        >>> x
        [
            [5, 6, 7, 0],
            [1, 2, 3, 4],
            [8, 9, 0, 0]
        ]
        >>> mask
        [
            [True, True, True, False],
            [True, True, True, True],
            [True, True, False, False]
        ]
        >>> self.pack_padded_sequence(x, mask, 0)
        [1, 5, 8, 2, 6, 9, 3, 7, 4]
        """
        # sentence length
        mask = layers.cast(mask, 'int64')
        lens = layers.reduce_sum(mask, dim=-1)
        # Sort by sentence length in descending order
        _, sorted_indices = layers.argsort(lens, descending=True)
        sorted_x = layers.index_select(x, sorted_indices)
        sorted_mask = layers.index_select(mask, sorted_indices)
        # transpose
        t_x = layers.transpose(sorted_x, perm=[1, 0, 2])
        t_mask = layers.transpose(sorted_mask, perm=[1, 0])
        # mask_select
        new_x = nn.masked_select(t_x, t_mask)
        # Batch by step
        batch_sizes = layers.reduce_sum(t_mask, -1)
        # remove zero
        batch_sizes = nn.masked_select(batch_sizes, batch_sizes != 0)

        return new_x, batch_sizes.numpy().tolist(), sorted_indices
Example 13
 def topp_sampling(self, probs):
     sorted_probs, sorted_idx = layers.argsort(probs, descending=True)
     cum_sorted_probs = layers.cumsum(sorted_probs, axis=1, exclusive=True)
     lt_cond = paddle.cast(
         paddle.less_than(
             cum_sorted_probs,
             layers.fill_constant_batch_size_like(cum_sorted_probs,
                                                  cum_sorted_probs.shape,
                                                  cum_sorted_probs.dtype,
                                                  self.topp)), "float32")
     old_probs = probs
     candidate_probs = sorted_probs * lt_cond
     probs = candidate_probs / paddle.sum(
         candidate_probs, axis=-1, keepdim=True)
     sampling_ids = layers.sampling_id(probs, dtype="int")
     sampling_ids = paddle.index_sample(sorted_idx,
                                        paddle.unsqueeze(sampling_ids, [1]))
     sampling_ids = paddle.squeeze(sampling_ids, [1])
     probs = old_probs
     return probs, sampling_ids
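Note: a NumPy sketch of the same top-p (nucleus) sampling scheme (topp_sample is a hypothetical name): keep the descending-sorted prefix whose exclusive cumulative mass stays below topp, renormalize it, sample, then map back through the sort order:

import numpy as np

def topp_sample(probs, topp, seed=None):
    # probs: [batch, vocab]
    rng = np.random.default_rng(seed)
    order = np.argsort(-probs, axis=1)                        # descending
    sorted_probs = np.take_along_axis(probs, order, axis=1)
    exclusive_cum = np.cumsum(sorted_probs, axis=1) - sorted_probs
    candidate = np.where(exclusive_cum < topp, sorted_probs, 0.0)
    candidate /= candidate.sum(axis=1, keepdims=True)
    picks = np.array([rng.choice(candidate.shape[1], p=row) for row in candidate])
    return np.take_along_axis(order, picks[:, None], axis=1)[:, 0]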
Example 14
    def forward(self, x, seq_mask, pad_index, hx=None):
        """Forward network"""
        x, batch_sizes, sorted_indices = self.pack_padded_sequence(
            x, seq_mask, pad_index)
        _, unsorted_indices = layers.argsort(sorted_indices)
        batch_size = batch_sizes[0]
        h_n, c_n = [], []

        if hx is None:
            ih = layers.zeros(shape=(self.num_layers * 2, batch_size,
                                     self.hidden_size),
                              dtype=x[0].dtype)
            h, c = ih, ih
        else:
            h, c = self.permute_hidden(hx, sorted_indices)
        h = layers.reshape(h, shape=(self.num_layers, 2, -1, self.hidden_size))
        c = layers.reshape(c, shape=(self.num_layers, 2, -1, self.hidden_size))

        for i in range(self.num_layers):
            x = layers.split(x, batch_sizes, dim=0)
            if self.training and self.dropout > 0:
                mask = SharedDropout.get_mask(x[0], self.dropout)
                x = [j * mask[:len(j)] for j in x]
            x_f, (h_f, c_f) = self.layer_forward(x=x,
                                                 hx=(h[i, 0], c[i, 0]),
                                                 cell=self.f_cells[i],
                                                 batch_sizes=batch_sizes)
            x_b, (h_b, c_b) = self.layer_forward(x=x,
                                                 hx=(h[i, 1], c[i, 1]),
                                                 cell=self.b_cells[i],
                                                 batch_sizes=batch_sizes,
                                                 reverse=True)
            x = layers.concat((x_f, x_b), axis=-1)
            h_n.append(layers.stack((h_f, h_b)))
            c_n.append(layers.stack((c_f, c_b)))
        x = self.pad_packed_sequence(x, batch_sizes, unsorted_indices)
        hx = layers.concat(h_n, axis=0), layers.concat(c_n, axis=0)
        hx = self.permute_hidden(hx, unsorted_indices)

        return x, hx
Example 15
    def __build_edges(self, edges, node_shift, edge_lod):
        """ Merge subgraph edges. 
        """
        if isinstance(edges, tuple):
            src, dst = edges
        else:
            src = edges[:, 0]
            dst = edges[:, 1]

        src = L.reshape(src, [-1])
        dst = L.reshape(dst, [-1])
        src = paddle_helper.ensure_dtype(src, dtype="int32")
        dst = paddle_helper.ensure_dtype(dst, dtype="int32")
        # preprocess edges
        lod_dst = L.lod_reset(dst, edge_lod)
        node_shift = L.reshape(node_shift, [-1, 1])
        node_shift = L.sequence_expand_as(node_shift, lod_dst)
        node_shift = L.reshape(node_shift, [-1])
        src = src + node_shift
        dst = dst + node_shift
        # sort edges
        self._edges_dst, index = L.argsort(dst)
        self._edges_src = L.gather(src, index, overwrite=False)
Example 16
def topk_pool(gw, score, graph_id, ratio):
    """Implementation of topk pooling, where k means pooling ratio.
    
    Args:
        gw: Graph wrapper object.

        score: The attention score of all nodes, which is used to select 
               important nodes.

        graph_id: The graphs that the nodes belong to.

        ratio: The pooling ratio of nodes we want to select.

    Return: 
        perm: The index of nodes we choose.

        ratio_length: The selected node numbers of each graph.
    """

    graph_lod = gw.graph_lod
    graph_nodes = gw.num_nodes
    num_graph = gw.num_graph

    num_nodes = L.ones(shape=[graph_nodes], dtype="float32")
    num_nodes = L.lod_reset(num_nodes, graph_lod)
    num_nodes_per_graph = L.sequence_pool(num_nodes, pool_type='sum')
    max_num_nodes = L.reduce_max(num_nodes_per_graph, dim=0)
    max_num_nodes = L.cast(max_num_nodes, dtype="int32")

    index = L.arange(0, gw.num_nodes, dtype="int64")
    offset = L.gather(graph_lod, graph_id, overwrite=False)
    index = (index - offset) + (graph_id * max_num_nodes)
    index.stop_gradient = True

    # padding
    dense_score = L.fill_constant(shape=[num_graph * max_num_nodes],
                                  dtype="float32",
                                  value=-999999)
    index = L.reshape(index, shape=[-1])
    dense_score = L.scatter(dense_score, index, updates=score)
    num_graph = L.cast(num_graph, dtype="int32")
    dense_score = L.reshape(dense_score, shape=[num_graph, max_num_nodes])

    # record the sorted index
    _, sort_index = L.argsort(dense_score, axis=-1, descending=True)

    # recover the index range
    graph_lod = graph_lod[:-1]
    graph_lod = L.reshape(graph_lod, shape=[-1, 1])
    graph_lod = L.cast(graph_lod, dtype="int64")
    sort_index = L.elementwise_add(sort_index, graph_lod, axis=-1)
    sort_index = L.reshape(sort_index, shape=[-1, 1])

    # use sequence_slice to choose selected node index
    pad_lod = L.arange(0, (num_graph + 1) * max_num_nodes,
                       step=max_num_nodes,
                       dtype="int32")
    sort_index = L.lod_reset(sort_index, pad_lod)
    ratio_length = L.ceil(num_nodes_per_graph * ratio)
    ratio_length = L.cast(ratio_length, dtype="int64")
    ratio_length = L.reshape(ratio_length, shape=[-1, 1])
    offset = L.zeros(shape=[num_graph, 1], dtype="int64")
    choose_index = L.sequence_slice(input=sort_index,
                                    offset=offset,
                                    length=ratio_length)

    perm = L.reshape(choose_index, shape=[-1])
    return perm, ratio_length
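Note: the dense_score padding and single batched argsort above exist because per-graph loops are awkward in static graphs. A plain NumPy equivalent of what topk_pool selects (topk_pool_np is a hypothetical name; graph_lod holds the cumulative node offsets [0, n1, n1+n2, ...]):

import numpy as np

def topk_pool_np(score, graph_lod, ratio):
    perm = []
    for g in range(len(graph_lod) - 1):
        start, end = graph_lod[g], graph_lod[g + 1]
        k = int(np.ceil((end - start) * ratio))        # per-graph ratio_length
        top = np.argsort(-score[start:end])[:k]        # top-k nodes of graph g
        perm.append(top + start)                       # back to global node ids
    return np.concatenate(perm)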
Example 17
     for r1, r2, r3 in zip(token_ids, seg_ids, labels):
         print(r1)
         print(r2)
         print(r3)
         print(convert_ids_to_tokens(tokenizer.vocab, r1))
 for step, d in enumerate(tqdm(train_batch_data, desc='training')):
     ids, sids, labels = d
     # print(ids.shape, sids.shape, labels.shape)
     ids, sids, labels = FD.to_variable(ids), FD.to_variable(
         sids), FD.to_variable(labels)
     loss, logits = model(ids, sids, labels=labels)
     if args.ohem_ratio > 0:
         labels = L.reshape(labels, [-1, 1])
         loss = L.softmax_with_cross_entropy(logits, labels)
         N = int(args.bsz * args.ohem_ratio)
         top_loss = L.argsort(loss, axis=0)[0][-N:]  # ascending sort: [0] is the sorted values, [-N:] the N largest (OHEM)
         if args.debug:
             print(loss)
             print(top_loss)
             print(N)
         loss = L.reduce_sum(top_loss) / N
     loss.backward()
     global_step += 1
     if step % 1000 == 0 and step > 0:
         print('train loss %.5f lr %.3e' %
               (loss.numpy(), opt.current_step_lr()))
     opt.minimize(loss)
     model.clear_gradients()
     if global_step % args.save_steps == 0:
         F.save_dygraph(model.state_dict(),
                        args.save_dir + '_%s' % global_step)
Example 18
    def __call__(self, step_fn, state):
        """
        Running beam search.

        @param : step_fn : decoding one step
        @type : function

        @param : state : initial state
        @type : dict
        """
        batch_size = state["batch_size"]
        beam_size = self.beam_size

        # shape: [batch_size, 1]
        pos_index = layers.range(0, batch_size, 1, dtype="int64")
        pos_index = layers.scale(pos_index, beam_size)
        pos_index = F.unsqueeze(pos_index, [1])

        # shape: [batch_size, beam_size, 1]
        predictions = layers.fill_constant(shape=[batch_size, beam_size, 1],
                                           dtype="int64",
                                           value=self.bos_id)

        # initial input
        state["pred_token"] = predictions[:, :1]
        # shape: [batch_size, vocab_size]
        scores, state = step_fn(state)

        unk_penalty = np.zeros(self.vocab_size, dtype="float32")
        unk_penalty[self.unk_id] = -1e10
        unk_penalty = layers.assign(unk_penalty)

        eos_penalty = np.zeros(self.vocab_size, dtype="float32")
        eos_penalty[self.eos_id] = -1e10
        eos_penalty = layers.assign(eos_penalty)

        scores_after_end = np.full(self.vocab_size, -1e10, dtype="float32")
        scores_after_end[self.pad_id] = 0
        scores_after_end = layers.assign(scores_after_end)

        if self.ignore_unk:
            scores = scores + unk_penalty
        scores = scores + eos_penalty

        # shape: [batch_size, beam_size]
        sequence_scores, preds = layers.topk(scores, self.beam_size)

        predictions = layers.concat(
            [predictions, F.unsqueeze(preds, [2])], axis=2)
        state = repeat(state, beam_size)

        parent_idx_list = []
        pred_list = []

        for step in range(2, self.max_gen_len + 1):
            pre_ids = predictions[:, :, -1:]
            state["pred_token"] = layers.reshape(
                pre_ids, shape=[batch_size * beam_size, 1, 1])
            state["pred_mask"] = 1 - F.equal(state["pred_token"], self.pad_id)
            state["pred_pos"] = state["pred_pos"] + 1
            scores, state = step_fn(state)

            # Generate next
            # scores shape: [batch_size, beam_size, vocab_size]
            if self.ignore_unk:
                scores = scores + unk_penalty

            if step <= self.min_gen_len:
                scores = scores + eos_penalty

            scores = layers.reshape(
                scores, shape=[batch_size, beam_size, self.vocab_size])

            # previous token is [PAD] or [EOS]
            pre_eos_mask = F.equal(pre_ids, self.eos_id) + F.equal(
                pre_ids, self.pad_id)

            scores = scores * (1 - pre_eos_mask) + \
                layers.expand(pre_eos_mask, [1, 1, self.vocab_size]) * scores_after_end
            if self.length_average:
                scaled_value = pre_eos_mask + (1 - pre_eos_mask) * (1 -
                                                                    1 / step)
                sequence_scores = F.unsqueeze(sequence_scores,
                                              [2]) * scaled_value
                scaled_value = pre_eos_mask + (1 - pre_eos_mask) * (1 / step)
                scores = scores * scaled_value
            elif self.length_penalty >= 0.0:
                scaled_value = pre_eos_mask + (1 - pre_eos_mask) * \
                    (math.pow((4 + step) / (5 + step), self.length_penalty))
                sequence_scores = layers.elementwise_mul(scaled_value,
                                                         sequence_scores,
                                                         axis=0)
                scaled_value = pre_eos_mask + (1 - pre_eos_mask) * \
                    (math.pow(1 / (5 + step), self.length_penalty))
                scores = scores * scaled_value
            scores = layers.elementwise_add(scores, sequence_scores, axis=0)
            scores = layers.reshape(
                scores, shape=[batch_size, beam_size * self.vocab_size])

            topk_scores, topk_indices = layers.topk(scores, beam_size)
            vocab_size = layers.fill_constant(shape=[1],
                                              dtype="int64",
                                              value=self.vocab_size)
            parent_idx = layers.elementwise_floordiv(topk_indices, vocab_size)
            preds = layers.elementwise_mod(topk_indices, vocab_size)

            # Gather state / sequence_scores
            parent_idx = layers.elementwise_add(parent_idx, pos_index, axis=0)
            parent_idx = layers.reshape(parent_idx, [batch_size * beam_size])
            state = gather(state, parent_idx)
            sequence_scores = topk_scores

            predictions = layers.reshape(predictions,
                                         shape=[batch_size * beam_size, step])
            predictions = gather(predictions, parent_idx)
            predictions = layers.reshape(predictions,
                                         shape=[batch_size, beam_size, step])
            predictions = layers.concat(
                [predictions, F.unsqueeze(preds, [2])], axis=2)

        pre_ids = predictions[:, :, -1]
        pre_eos_mask = F.equal(pre_ids, self.eos_id) + F.equal(
            pre_ids, self.pad_id)
        sequence_scores = sequence_scores * pre_eos_mask + layers.scale(
            1 - pre_eos_mask, -1e10)

        _, indices = layers.argsort(sequence_scores, axis=1)
        indices = indices + pos_index
        indices = layers.reshape(indices, [-1])
        sequence_scores = layers.reshape(sequence_scores,
                                         [batch_size * beam_size])
        predictions = layers.reshape(predictions, [batch_size * beam_size, -1])
        sequence_scores = gather(sequence_scores, indices)
        predictions = layers.gather(predictions, indices)
        sequence_scores = layers.reshape(sequence_scores,
                                         [batch_size, beam_size])
        predictions = layers.reshape(predictions, [batch_size, beam_size, -1])

        results = {
            "preds": predictions[:, -1],
            "scores": sequence_scores[:, -1]
        }
        return results
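Note: the length_average branch above maintains a running mean of the per-step scores; restated without the masking of finished beams (update_length_average is a hypothetical name), after step t the beam score is the mean of the t per-step scores seen so far:

def update_length_average(prev_score, step_score, step):
    # prev_score is the mean over the first (step - 1) steps.
    return prev_score * (1 - 1 / step) + step_score * (1 / step)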
Example 19
    def inference(self, model, inputs, outputs):
        """
        Run inference.

        Args:
            inputs(dict): Its key is input name(str) and its value is a Variable.
            model(object): A generate model. Need to implement `_generation_network` and `_calc_logits`.

        Returns:
            dict(str:Variable): Its key is output name(str) and its value is a Variable.
        """
        # prepare while loop
        max_len = layers.fill_constant(shape=[1],
                                       dtype="int64",
                                       value=self.max_dec_len,
                                       force_cpu=True)
        min_len = layers.fill_constant(shape=[1],
                                       dtype="int64",
                                       value=self.min_dec_len,
                                       force_cpu=True)
        step_idx = layers.fill_constant(shape=[1],
                                        dtype="int64",
                                        value=0,
                                        force_cpu=True)

        ids = layers.array_write(layers.reshape(inputs["tgt_ids"], (-1, 1)),
                                 step_idx)
        pos_biases = layers.array_write(
            layers.reshape(inputs["tgt_pos"], (-1, 1)), step_idx)
        scores = layers.array_write(inputs["init_score"], step_idx)
        tgt_generation_mask = layers.array_write(inputs["tgt_generation_mask"],
                                                 step_idx)
        parent_idx = inputs["parent_idx"]

        if self.decoding_strategy == "beam_search":
            beam_size = self.beam_size
        else:
            beam_size = 1

        eos_penalty = np.zeros(self.vocab_size, dtype="float32")
        eos_penalty[self.eos_id] = -1e9
        eos_penalty = layers.assign(eos_penalty)

        token_penalty = np.zeros(self.vocab_size, dtype="float32")
        token_penalty[self.unk_id] = -1e9
        if self.mask_id >= 0:
            token_penalty[self.mask_id] = -1e9
        token_penalty = layers.assign(token_penalty)

        # start while loop
        cond = layers.less_than(x=step_idx, y=max_len)
        while_op = layers.While(cond)
        with while_op.block():
            pre_ids = layers.array_read(array=ids, i=step_idx)
            pre_ids = layers.reshape(pre_ids, (-1, 1, 1), inplace=True)
            pre_scores = layers.array_read(array=scores, i=step_idx)
            pos_bias = layers.array_read(array=pos_biases, i=step_idx)
            pos_bias = layers.gather(input=pos_bias, index=parent_idx)

            tmp_tgt_generation_mask = layers.array_read(tgt_generation_mask,
                                                        i=step_idx)
            dtype = tmp_tgt_generation_mask.dtype

            append_mask = layers.fill_constant_batch_size_like(
                input=pre_ids, value=1.0, shape=[-1, 1, 1], dtype=dtype)
            tmp_tgt_generation_mask = layers.concat(
                [tmp_tgt_generation_mask, append_mask], axis=2)
            pre_mask = tmp_tgt_generation_mask = layers.gather(
                input=tmp_tgt_generation_mask, index=parent_idx)

            pre_sent = layers.fill_constant_batch_size_like(
                input=pre_mask, value=1, shape=[-1, 1, 1], dtype=pre_ids.dtype)

            if self.continuous_position:
                pre_pos = layers.elementwise_mul(
                    x=layers.fill_constant_batch_size_like(
                        input=pre_mask,
                        value=1,
                        shape=[-1, 1, 1],
                        dtype=pre_ids.dtype),
                    y=step_idx,
                    axis=0) + pos_bias
            else:
                pre_pos = layers.elementwise_mul(
                    x=layers.fill_constant_batch_size_like(
                        input=pre_mask,
                        value=1,
                        shape=[-1, 1, 1],
                        dtype=pre_ids.dtype),
                    y=step_idx,
                    axis=0)

            dec_out, _ = model._generation_network(
                token_ids=pre_ids,
                type_ids=pre_sent,
                pos_ids=pre_pos,
                generation_mask=tmp_tgt_generation_mask,
                gather_idx=parent_idx)
            logits = model._calc_logits(dec_out)

            # ignore unk and mask token
            if self.ignore_unk:
                logits = layers.elementwise_add(logits, token_penalty, axis=1)

            # min dec length
            min_len_cond = layers.less_than(x=step_idx, y=min_len)

            def min_len_penalty():
                """Plus minimum length penalty."""
                return layers.elementwise_add(logits, eos_penalty, axis=1)

            def no_penalty():
                """No penalty."""
                return logits

            logits = layers.case([(min_len_cond, min_len_penalty)],
                                 default=no_penalty)

            # get probs
            probs = layers.softmax(logits / self.temperature)

            if self.decoding_strategy == "beam_search":
                topk_scores, topk_indices = layers.topk(input=probs,
                                                        k=beam_size)
            else:
                if self.decoding_strategy.startswith("sampling"):
                    sampling_ids = layers.sampling_id(probs, dtype="int")
                elif self.decoding_strategy.startswith("topk_sampling"):
                    topk_probs, _ = layers.topk(input=probs, k=self.topk)
                    ge_cond = layers.cast(
                        layers.greater_equal(
                            probs, layers.unsqueeze(topk_probs[:, -1], [1])),
                        "float32")
                    old_probs = probs
                    probs = probs * ge_cond / layers.reduce_sum(
                        topk_probs, dim=-1, keep_dim=True)
                    sampling_ids = layers.sampling_id(probs, dtype="int")
                    probs = old_probs
                elif self.decoding_strategy.startswith("topp_sampling"):
                    sorted_probs, sorted_idx = layers.argsort(probs,
                                                              descending=True)
                    cum_sorted_probs = layers.cumsum(sorted_probs,
                                                     axis=1,
                                                     exclusive=True)
                    lt_cond = layers.cast(
                        layers.less_than(
                            cum_sorted_probs,
                            layers.fill_constant_batch_size_like(
                                cum_sorted_probs, cum_sorted_probs.shape,
                                cum_sorted_probs.dtype, self.topp)), "float32")
                    old_probs = probs
                    candidate_probs = sorted_probs * lt_cond
                    probs = candidate_probs / layers.reduce_sum(
                        candidate_probs, dim=-1, keep_dim=True)
                    sampling_ids = layers.sampling_id(probs, dtype="int")
                    sampling_ids = layers.index_sample(
                        sorted_idx, layers.unsqueeze(sampling_ids, [1]))
                    sampling_ids = layers.squeeze(sampling_ids, [1])
                    probs = old_probs
                else:
                    raise ValueError(self.decoding_strategy)

                sampling_scores = layers.one_hot(
                    layers.unsqueeze(sampling_ids, [1]), probs.shape[1])
                sampling_scores = sampling_scores * probs - (
                    1 - sampling_scores) * 1e3
                topk_scores, topk_indices = layers.topk(input=sampling_scores,
                                                        k=1)

            pre_len = layers.cast(step_idx, "float32")
            layers.increment(x=step_idx, value=1.0, in_place=True)
            cur_len = layers.cast(step_idx, "float32")

            # update scores
            if self.length_average:
                accu_scores = layers.elementwise_add(x=layers.log(topk_scores),
                                                     y=pre_scores * pre_len,
                                                     axis=0) / cur_len
            elif self.length_penalty > 0:
                pre_lp = layers.pow((5 + pre_len) / 6, self.length_penalty)
                cur_lp = layers.pow((5 + cur_len) / 6, self.length_penalty)
                accu_scores = layers.elementwise_add(x=layers.log(topk_scores),
                                                     y=pre_scores * pre_lp,
                                                     axis=0) / cur_lp
            else:
                accu_scores = layers.elementwise_add(x=layers.log(topk_scores),
                                                     y=pre_scores,
                                                     axis=0)
            topk_indices = layers.lod_reset(topk_indices, pre_ids)
            accu_scores = layers.lod_reset(accu_scores, pre_ids)
            selected_ids, selected_scores, gather_idx = layers.beam_search(
                pre_ids=pre_ids,
                pre_scores=pre_scores,
                ids=topk_indices,
                scores=accu_scores,
                beam_size=beam_size,
                end_id=self.eos_id,
                return_parent_idx=True)

            layers.array_write(selected_ids, i=step_idx, array=ids)
            layers.array_write(selected_scores, i=step_idx, array=scores)
            layers.array_write(pre_mask, i=step_idx, array=tgt_generation_mask)
            layers.array_write(pos_bias, i=step_idx, array=pos_biases)

            layers.assign(gather_idx, parent_idx)

            length_cond = layers.less_than(x=step_idx, y=max_len)
            finish_cond = layers.logical_not(layers.is_empty(x=selected_ids))
            layers.logical_and(x=length_cond, y=finish_cond, out=cond)

        finished_ids, finished_scores = layers.beam_search_decode(
            ids, scores, beam_size=beam_size, end_id=self.eos_id)

        predictions = {
            "finished_ids": finished_ids,
            "finished_scores": finished_scores,
            "token_ids": inputs["token_ids"],
            "data_id": inputs["data_id"]
        }
        return predictions
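Note: a NumPy sketch of the topk_sampling branch above (topk_sample is a hypothetical name): probabilities below the k-th largest are zeroed and the survivors renormalized before sampling. The original divides by the sum of the top-k probabilities, which agrees with this as long as there are no ties at the k-th value:

import numpy as np

def topk_sample(probs, k, seed=None):
    # probs: [batch, vocab]
    rng = np.random.default_rng(seed)
    kth = np.sort(probs, axis=1)[:, -k][:, None]       # k-th largest per row
    candidate = np.where(probs >= kth, probs, 0.0)
    candidate /= candidate.sum(axis=1, keepdims=True)
    return np.array([rng.choice(candidate.shape[1], p=row) for row in candidate])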
Example 20
    def ohem_conf_loss(self, pred_allboxes_conf, batch_size, labels_neg_mask,
                       labels_pos_mask, labels_pos_index, class_vectors,
                       labels_pos_cid):
        batch_conf = P.reshape(pred_allboxes_conf, (-1, self.num_classes))
        loss_c = log_sum_exp(batch_conf) - batch_conf[:, 0]
        loss_c = P.reshape(loss_c, (batch_size, -1))  # (batch_size, 19248)
        labels_neg_mask = P.concat(labels_neg_mask,
                                   axis=0)  # (batch_size*19248, 1)
        labels_neg_mask = P.reshape(labels_neg_mask,
                                    (batch_size, -1))  # (batch_size, 19248)
        loss_c = labels_neg_mask * loss_c  # keep only the negative-sample loss, (batch_size, 19248)
        sorted_loss_c, loss_idx = P.argsort(loss_c, axis=-1, descending=True)

        labels_pos_mask = P.concat(labels_pos_mask,
                                   axis=0)  # (batch_size*19248, 1)
        labels_pos_mask = P.reshape(labels_pos_mask,
                                    (batch_size, -1))  # (batch_size, 19248)
        num_pos = P.cast(P.reduce_sum(labels_pos_mask, dim=1),
                         'int32')  # (batch_size, )
        num_neg = self.negpos_ratio * num_pos  # (batch_size, )
        neg_topk_mask = []
        for idx in range(batch_size):
            desc = P.range(num_neg[idx],
                           num_neg[idx] - P.shape(labels_pos_mask)[1], -1,
                           'int32')
            neg_topk_mask.append(desc)
        neg_topk_mask = P.concat(neg_topk_mask, axis=0)  # (batch_size*19248, )
        neg_topk_mask = P.reshape(neg_topk_mask,
                                  (batch_size, -1))  # (batch_size, 19248)
        neg_topk_mask = P.cast(neg_topk_mask > 0,
                               'float32')  # (batch_size, 19248)
        sorted_loss_c = neg_topk_mask * sorted_loss_c
        selected_poss = []
        selected_negs = []
        selected_pos_class_vectors = []
        selected_neg_class_vectors = []
        for idx in range(batch_size):
            selected_neg_idx_idx = P.where(sorted_loss_c[idx] > 0)
            selected_neg_idx_idx.stop_gradient = True
            selected_neg_idx = P.gather(loss_idx[idx], selected_neg_idx_idx)
            selected_neg_idx.stop_gradient = True
            selected_neg = P.gather(pred_allboxes_conf[idx], selected_neg_idx)
            selected_neg.stop_gradient = True
            selected_negs.append(selected_neg)
            selected_pos = P.gather(pred_allboxes_conf[idx],
                                    labels_pos_index[idx])
            selected_pos.stop_gradient = True
            selected_poss.append(selected_pos)

            zeros = P.fill_constant(shape=[
                P.shape(selected_neg)[0],
            ],
                                    value=0,
                                    dtype='int32')
            zeros.stop_gradient = True
            selected_neg_class_vector = P.gather(class_vectors, zeros)
            selected_neg_class_vector.stop_gradient = True
            selected_neg_class_vectors.append(selected_neg_class_vector)

            labels_pos_cid.stop_gradient = True
            labels_pos_index[idx].stop_gradient = True
            selected_pos_cid = P.gather(labels_pos_cid[idx],
                                        labels_pos_index[idx])
            selected_pos_cid.stop_gradient = True
            selected_pos_class_vector = P.gather(class_vectors,
                                                 selected_pos_cid)
            selected_pos_class_vector.stop_gradient = True
            selected_pos_class_vectors.append(selected_pos_class_vector)
        selected_negs = P.concat(selected_negs, axis=0)  # (?, 1+80)
        selected_poss = P.concat(selected_poss, axis=0)  # (?, 1+80)
        pred_ = P.concat([selected_negs, selected_poss], axis=0)  # (?, 1+80)
        selected_neg_class_vectors = P.concat(selected_neg_class_vectors,
                                              axis=0)  # (?, 1+80)
        selected_pos_class_vectors = P.concat(selected_pos_class_vectors,
                                              axis=0)  # (?, 1+80)
        labels_ = P.concat(
            [selected_neg_class_vectors, selected_pos_class_vectors],
            axis=0)  # (?, 1+80)

        # softmax cross-entropy ("fenzi"/"fenmu" are pinyin for numerator/denominator)
        fenzi = P.exp(pred_)
        fenmu = P.reduce_sum(fenzi, dim=1, keep_dim=True)
        pred_prob = fenzi / P.expand_as(fenmu, target_tensor=fenzi)
        conf_loss = labels_ * (0 - P.log(pred_prob + 1e-9))  # cross-entropy; the small constant guards against NaN
        conf_loss = P.reduce_sum(conf_loss)
        return conf_loss
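Note: stripped of the static-graph masking gymnastics, the per-image hard-negative selection above amounts to this NumPy sketch (ohem_pick_negatives is a hypothetical name):

import numpy as np

def ohem_pick_negatives(neg_loss, num_pos, negpos_ratio=3):
    # neg_loss: [num_anchors] losses with positive anchors already zeroed,
    # i.e. the labels_neg_mask * loss_c product above.
    num_neg = int(negpos_ratio * num_pos)
    hardest_first = np.argsort(-neg_loss)   # descending loss
    return hardest_first[:num_neg]          # indices of the selected negatives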
Example 21
    def get_seg_single(self, cate_preds, mask_proto, kernel_preds,
                       featmap_size, resize_shape, ori_shape):
        '''
        :param cate_preds:    [num_grids_total, 80]
        :param mask_proto:    [1, 256, s4, s4]   mask prototypes
        :param kernel_preds:  [num_grids_total, 256]   the 1x1 conv kernel predicted per grid cell;
                              its input channel count is 256, the number of prototype channels
        :param featmap_size:  (s4, s4)
        :param resize_shape:  shape=[3, ]
        :param ori_shape:     shape=[3, ]
        :return:
        '''
        # overall info.
        upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4)  # the image size fed to the network
        cfg = self.nms_cfg

        # First filter: score threshold
        inds = L.where(cate_preds > cfg['score_thr'])  # [M, 2]

        # if len(inds) == 0:
        #     return None
        # Writing conditionals in static-graph mode is painful.
        def exist_objs_1(inds, cate_preds):
            inds.stop_gradient = True
            scores = L.gather_nd(cate_preds, inds)  # [M, ]   scores of the M objects
            return inds, scores

        def no_objs_1(cate_preds):
            inds = L.zeros((1, 2), np.int64)
            inds.stop_gradient = True
            scores = L.gather_nd(cate_preds, inds) - 99.0  # [M, ]   dummy scores; filtered out later
            return inds, scores

        # Are there any objects?
        inds, scores = L.cond(
            L.shape(inds)[0] == 0, lambda: no_objs_1(cate_preds),
            lambda: exist_objs_1(inds, cate_preds))

        classes = inds[:, 1]  # [M, ]   class ids of the M objects
        kernel_preds = L.gather(kernel_preds, inds[:, 0])  # [M, 256]   conv kernels of the M objects

        n_stage = len(self.seg_num_grids)  # 5 output levels
        strides = []
        for ind_ in range(n_stage):
            st = L.zeros((1, ), dtype=np.float32) + self.strides[ind_]
            st = L.expand(st, [self.seg_num_grids[ind_]**2, ])  # [40*40, ]
            strides.append(st)
        strides = L.concat(strides, axis=0)
        strides.stop_gradient = True
        strides = L.gather(strides, inds[:, 0])  # [M, ]   downsampling stride of each of the M objects

        # Mask encoding, as written in the original SOLO: convolving the mask
        # prototypes with the predicted 1x1 kernels yields the masks.
        # M, C = kernel_preds.shape
        # kernel_preds = kernel_preds.view(M, C, 1, 1)    # used as conv kernels
        # seg_preds = F.conv2d(seg_preds, kernel_preds, stride=1).squeeze(0).sigmoid()
        # A 1x1 convolution over the prototypes is equivalent to a matrix multiply
        # (not so for 3x3 kernels). No equivalent API was found here, so a matmul
        # is used instead; SOLOv2 and YOLACT do the same thing at this point.
        mask_proto = L.squeeze(mask_proto, axes=[0])  # [256, s4, s4]
        mask_proto = L.transpose(mask_proto, perm=[1, 2, 0])  # [s4, s4, 256]
        masks = L.matmul(mask_proto, kernel_preds,
                         transpose_y=True)  # [s4, s4, M]
        masks = L.sigmoid(masks)  # [s4, s4, M]
        masks = L.transpose(masks, perm=[2, 0, 1])  # [M, s4, s4]

        # mask.
        seg_masks = L.cast(masks > cfg['mask_thr'], 'float32')  # [M, s4, s4]   1 where foreground
        sum_masks = L.reduce_sum(seg_masks, dim=[1, 2])  # [M, ]   mask areas of the M objects

        # Second filter: stride filter. Keep a mask only if its area exceeds its downsampling stride.
        keep = L.where(sum_masks > strides)

        # if keep.sum() == 0:
        #     return None

        # Writing conditionals in static-graph mode is painful.
        def exist_objs_2(keep, seg_masks, masks, sum_masks, scores, classes):
            keep = L.reshape(keep, (-1, ))  # [M2, ]
            keep.stop_gradient = True
            seg_masks = L.gather(seg_masks, keep)  # [M2, s4, s4]   binary masks of the M2 objects
            masks = L.gather(masks, keep)  # [M2, s4, s4]   mask probabilities of the M2 objects
            sum_masks = L.gather(sum_masks, keep)  # [M2, ]   mask areas of the M2 objects
            scores = L.gather(scores, keep)  # [M2, ]   scores of the M2 objects
            classes = L.gather(classes, keep)  # [M2, ]   class ids of the M2 objects
            return seg_masks, masks, sum_masks, scores, classes

        def no_objs_2(seg_masks, masks, sum_masks, scores, classes):
            keep = L.zeros((1, ), np.int64)
            keep.stop_gradient = True
            seg_masks = L.gather(seg_masks, keep)  # [M2, s4, s4]   binary masks of the M2 objects
            masks = L.gather(masks, keep)  # [M2, s4, s4]   mask probabilities of the M2 objects
            sum_masks = L.gather(sum_masks, keep)  # [M2, ]   mask areas of the M2 objects
            scores = L.gather(scores, keep) - 99.0  # [M2, ]   negative dummy scores; filtered out later
            classes = L.gather(classes, keep)  # [M2, ]   class ids of the M2 objects
            return seg_masks, masks, sum_masks, scores, classes

        # Are there any objects?
        seg_masks, masks, sum_masks, scores, classes = L.cond(
            L.shape(keep)[0] == 0,
            lambda: no_objs_2(seg_masks, masks, sum_masks, scores, classes),
            lambda: exist_objs_2(keep, seg_masks, masks, sum_masks, scores,
                                 classes))

        # mask scoring.
        # [M2, ]   sum of foreground mask probabilities divided by mask area,
        # i.e. the mean foreground mask probability of each of the M2 objects
        avg_prob = L.reduce_sum(masks * seg_masks, dim=[1, 2]) / sum_masks
        scores *= avg_prob  # [M2, ]   final score = classification prob * mean mask prob

        # Third filter: keep only the top cfg['nms_pre'] objects by score
        _, sort_inds = L.argsort(scores, axis=-1,
                                 descending=True)  # descending by final score: index of the largest, second largest, ...
        sort_inds = sort_inds[:cfg['nms_pre']]  # at most cfg['nms_pre'] objects

        seg_masks = L.gather(seg_masks, sort_inds)  # [M3, s4, s4]   binary masks of the M3 objects
        masks = L.gather(masks, sort_inds)  # [M3, s4, s4]   mask probabilities of the M3 objects
        sum_masks = L.gather(sum_masks, sort_inds)  # [M3, ]   mask areas of the M3 objects
        scores = L.gather(scores, sort_inds)  # [M3, ]   scores of the M3 objects
        classes = L.gather(classes, sort_inds)  # [M3, ]   class ids of the M3 objects

        # Matrix NMS
        scores = matrix_nms(seg_masks,
                            classes,
                            scores,
                            kernel=cfg['kernel'],
                            sigma=cfg['sigma'],
                            sum_masks=sum_masks)

        # Fourth filter: score threshold
        keep = L.where(scores >= cfg['update_thr'])

        # if keep.sum() == 0:
        #     return None

        def exist_objs_3(keep, masks, classes, scores, upsampled_size_out,
                         resize_shape, ori_shape):
            keep = L.reshape(keep, (-1, ))
            keep.stop_gradient = True
            masks = L.gather(masks, keep)  # [M4, s4, s4]   mask probabilities of the M4 objects
            scores = L.gather(scores, keep)  # [M4, ]   scores of the M4 objects
            classes = L.gather(classes, keep)  # [M4, ]   class ids of the M4 objects

            # Fifth filter: keep only the top cfg['max_per_img'] objects by score
            _, sort_inds = L.argsort(scores, axis=-1, descending=True)
            sort_inds = sort_inds[:cfg['max_per_img']]
            sort_inds.stop_gradient = True

            masks = L.gather(masks, sort_inds)  # [M5, s4, s4]   mask probabilities of the M5 objects
            scores = L.gather(scores, sort_inds)  # [M5, ]   scores of the M5 objects
            classes = L.gather(classes, sort_inds)  # [M5, ]   class ids of the M5 objects

            masks = L.resize_bilinear(
                L.unsqueeze(masks, axes=[0]),
                out_shape=upsampled_size_out,
                align_corners=False,
                align_mode=0)[:, :, :resize_shape[0], :resize_shape[1]]  # crop away the padded border
            masks = L.resize_bilinear(masks,
                                      out_shape=ori_shape[:2],
                                      align_corners=False,
                                      align_mode=0)  # resize to the original image size
            masks = L.cast(masks > cfg['mask_thr'], 'float32')[0]
            return masks, classes, scores

        def no_objs_3():
            masks = L.zeros([1, 1, 1], 'float32') - 1.0
            classes = L.zeros([1, ], 'int64') - 1
            scores = L.zeros([1, ], 'float32') - 2.0
            return masks, classes, scores

        # Are there any objects?
        masks, classes, scores = L.cond(
            L.shape(keep)[0] == 0, no_objs_3,
            lambda: exist_objs_3(keep, masks, classes, scores,
                                 upsampled_size_out, resize_shape, ori_shape))
        return masks, classes, scores