Example #1
0
    def construct(self, teacher, student, neg):
        """Compute the weighted contrastive feature loss of ``student``.

        For every feature level ``i`` produced by ``self.vgg`` the term
        ``self.l1(teacher_i, student_i) / (sum_over_negatives(|neg_i - student_i|).mean() + 1e-7)``
        is weighted by ``self.weights[i]`` and accumulated. Teacher and negative
        features are wrapped in ``stop_gradient``, so gradients flow only through
        the student features.

        Args:
            teacher: Positive (teacher) images passed to ``self.vgg``.
            student: Student images passed to ``self.vgg``.
            neg: Negative images passed to ``self.vgg``.

        Returns:
            Whatever ``self.get_loss`` returns for the accumulated weighted loss.
        """
        expand_dims = ops.ExpandDims()  # unsqueeze operator
        # Build the reduction primitive once instead of once per feature level.
        reduce_sum = ops.ReduceSum()
        teacher_vgg, student_vgg, neg_vgg = self.vgg(teacher), self.vgg(
            student), self.vgg(neg)

        loss = 0
        for i in range(len(teacher_vgg)):
            # Shape notes assume 8 negatives and a student batch of 16, as in the
            # original annotations -- TODO confirm against the data pipeline.
            neg_i = expand_dims(neg_vgg[i], 0)  # [8, n_feats, w, h] -> [1, 8, n_feats, w, h]
            # Tensor.repeat is only supported from version 1.3, hence np.repeat.
            neg_i = np.repeat(neg_i, student_vgg[i].shape[0],
                              axis=0)  # [16, 8, n_feats, w, h]
            neg_i = neg_i.transpose((1, 0, 2, 3, 4))  # [8, 16, n_feats, w, h]

            # Distance to the (detached) teacher features.
            d_ts = self.l1(stop_gradient(teacher_vgg[i]), student_vgg[i])
            # Distance to the (detached) negatives: summed over the negatives
            # axis, then averaged. Tensor.sum is only supported from version 1.3,
            # hence ReduceSum.
            d_sn = (stop_gradient(neg_i) -
                    student_vgg[i]).abs()  # [8, 16, n_feats, w, h]
            d_sn = reduce_sum(d_sn, 0).mean()

            # The epsilon guards against division by zero.
            contrastive = d_ts / (d_sn + 1e-7)
            loss += self.weights[i] * contrastive

        return self.get_loss(loss)
Example #2
0
    def construct(self, logit_paf, logit_heatmap, gt_paf, gt_heatmap,
                  ignore_mask):
        """Sum the masked mean-square errors of every PAF / heatmap stage.

        Args:
            logit_paf: Per-stage PAF predictions, iterated together with
                ``logit_heatmap``.
            logit_heatmap: Per-stage heatmap predictions.
            gt_paf: Ground-truth PAFs, shared by every stage.
            gt_heatmap: Ground-truth heatmaps, shared by every stage.
            ignore_mask: Pixel mask. It must be a 0/1 numeric array, not a
                boolean array. It is expanded on axis 1, tiled across the
                channels of each target, and detached.

        Returns:
            ``(total_loss, heatmaps_loss, pafs_loss)``. The last two are lists
            holding the loss of each stage, in stage order.
        """
        total_loss = 0
        heatmaps_loss = []
        pafs_loss = []

        # Broadcast the per-pixel mask over every channel of each target and
        # keep it out of the gradient computation.
        mask = self.expand_dims(ignore_mask, 1)
        paf_masks = F.stop_gradient(
            self.tile(mask, (1, self.shape(gt_paf)[1], 1, 1)))
        heatmap_masks = F.stop_gradient(
            self.tile(mask, (1, self.shape(gt_heatmap)[1], 1, 1)))

        for stage_paf, stage_heatmap in zip(logit_paf, logit_heatmap):
            stage_paf_loss = self.mean_square_error(stage_paf, gt_paf, paf_masks)
            stage_heatmap_loss = self.mean_square_error(stage_heatmap, gt_heatmap,
                                                        heatmap_masks)
            pafs_loss.append(stage_paf_loss)
            heatmaps_loss.append(stage_heatmap_loss)
            total_loss += stage_paf_loss + stage_heatmap_loss

        return total_loss, heatmaps_loss, pafs_loss
 def construct(self, x1, x2, x3, x4, x5):
     """Detach one element that is read back out of a tuple.

     ``z2 = x1 * x2`` is packed into a tuple, fetched again via ``t[1]`` and
     wrapped in ``stop_gradient``. ``z1`` and the pass-through inputs are
     returned without detaching. This is presumably a unit-test fixture for
     tuple-getitem + stop_gradient.
     """
     z1 = x1 + x1
     z2 = x1 * x2
     # Round-trip z2 through a tuple so stop_gradient is applied to a getitem result.
     t = (z1, z2, x3, x4, x5)
     z2 = t[1]
     z2 = stop_gradient(z2)
     return z1, z2, x3, x4, x5
def stop_test5(x, y):
    """ stop_test5

    Call ``stop_func(x + y, y)`` and return ``o2 + stop_gradient(o1)``.
    The ``o1`` term contributes no gradient.
    """
    x = x + y
    o1, o2 = stop_func(x, y)
    c = stop_gradient(o1)
    c = o2 + c
    return c
Example #5
0
    def _sample(self, shape=(), probs=None):
        """
        Sampling.

        Draw category indices from the distribution with ``self.multinomial``.

        Args:
            shape (tuple): The shape of the sample. Default: ().
            probs (Tensor): Event probabilities. Default: self.probs.

        Returns:
            Tensor, shape is shape + shape(probs)[:-1] (sample shape followed by
            batch shape), or a scalar when both are empty. The values are cast
            to ``self.dtype`` and always wrapped in ``stop_gradient``.
        """
        if self.device_target == 'Ascend':
            raise_not_implemented_util('On d backend, sample', self.name)
        shape = self.checktuple(shape, 'shape')
        probs = self._check_param_type(probs)
        num_classes = self.shape(probs)[-1]
        batch_shape = self.shape(probs)[:-1]

        sample_shape = shape + batch_shape
        # A scalar request is sampled as a length-1 tensor and squeezed at the end.
        drop_dim = False
        if sample_shape == ():
            drop_dim = True
            sample_shape = (1, )

        # Flatten the batch dims so multinomial receives [-1, num_classes].
        probs_2d = self.reshape(probs, (-1, num_classes))
        # The number of draws per batch entry is the element count of ``shape``.
        sample_tensor = self.fill(self.dtype, shape, 1.0)
        sample_tensor = self.reshape(sample_tensor, (-1, 1))
        num_sample = self.shape(sample_tensor)[0]
        samples = self.multinomial(probs_2d, num_sample)
        # Swap the two axes so the sample dims lead, matching sample_shape.
        samples = self.squeeze(self.transpose(samples, (1, 0)))
        samples = self.cast(self.reshape(samples, sample_shape), self.dtype)
        if drop_dim:
            samples = self.squeeze_first_axis(samples)
        # Samples carry no gradient. Detach on every path: the scalar path
        # previously returned early and skipped this.
        return stop_gradient(samples)
 def construct(self, x, y):
     """Detach the first output of ``self.mul(x + y, x - y)``.

     ``ret1`` receives gradient only through its direct ``x + y`` terms.
     ``ret2`` still propagates through the second output ``z``.
     """
     u = x + y
     v = x - y
     c, z = self.mul(u, v)
     c = stop_gradient(c)
     ret1 = c + x + y
     ret2 = z + y + y
     return ret1, ret2
 def construct(self, x):
     """Run ``self.fc`` / ``self.fc2`` / ``self.relu`` and detach the result before ``self.biasAdd``.

     Everything up to the second ``self.fc2`` call is wrapped in
     ``stop_gradient``. Only the final ``self.biasAdd(x, self.bias)`` sits
     after the detach point.
     """
     x = self.fc(x, self.weight)
     x = self.cast(x, mstype.float32)
     x = self.relu(self.fc2(x))
     # NOTE(review): self.fc2 is applied twice in a row; presumably intentional
     # in this fixture -- confirm.
     x = self.fc2(x)
     x = stop_gradient(x)
     x = self.biasAdd(x, self.bias)
     return x
Example #8
0
    def construct(self, prediction, pred_xy, pred_wh, y_true, gt_box, input_shape):
        """
        Compute the YOLO loss with an IoU-based box term.

        prediction : origin output from yolo
        pred_xy: (sigmoid(xy)+grid)/grid_size
        pred_wh: (exp(wh)*anchors)/input_shape
        y_true : after normalize
        gt_box: [batch, maxboxes, 4] boxes after normalize, presumably in the
            same layout as pred_boxes (x, y, w, h) -- confirm
        input_shape: not used by this loss; kept so existing callers still work

        Returns (10 * box_loss + confidence_loss + class_loss) / batch_size.
        """
        object_mask = y_true[:, :, :, :, 4:5]
        class_probs = y_true[:, :, :, :, 5:]
        true_boxes = y_true[:, :, :, :, :4]

        pred_boxes = self.concat((pred_xy, pred_wh))
        # 2-w*h: boxes with a small normalized area get a larger weight, since
        # small objects need more precise localisation.
        box_loss_scale = 2 - y_true[:, :, :, :, 2:3] * y_true[:, :, :, :, 3:4]

        gt_shape = P.Shape()(gt_box)
        gt_box = P.Reshape()(gt_box, (gt_shape[0], 1, 1, 1, gt_shape[1], gt_shape[2]))

        # add one more dimension for broadcast
        # [batch, grid[0], grid[1], num_anchor, num_gt]
        iou = self.iou(P.ExpandDims()(pred_boxes, -2), gt_box)
        # [batch, grid[0], grid[1], num_anchor]
        best_iou = self.reduce_max(iou, -1)

        # 1.0 where the best IoU with any ground-truth box is below the threshold.
        ignore_mask = best_iou < self.ignore_threshold
        ignore_mask = P.Cast()(ignore_mask, ms.float32)
        ignore_mask = P.ExpandDims()(ignore_mask, -1)
        # Backprop through ignore_mask would cost a lot of MaximumGrad/MinimumGrad
        # time, so its gradient is turned off.
        ignore_mask = F.stop_gradient(ignore_mask)

        confidence_loss = self.confidence_loss(object_mask, prediction[:, :, :, :, 4:5], ignore_mask)
        class_loss = self.class_loss(object_mask, prediction[:, :, :, :, 5:], class_probs)

        # Flatten to one row per (image, cell, anchor) for the box IoU term.
        object_mask_me = P.Reshape()(object_mask, (-1, 1))
        box_loss_scale_me = P.Reshape()(box_loss_scale, (-1, 1))
        pred_boxes_me = xywh2x1y1x2y2(pred_boxes)
        pred_boxes_me = P.Reshape()(pred_boxes_me, (-1, 4))
        true_boxes_me = xywh2x1y1x2y2(true_boxes)
        true_boxes_me = P.Reshape()(true_boxes_me, (-1, 4))
        # The op is self.giou (this local was previously misnamed ``ciou``).
        box_iou = self.giou(pred_boxes_me, true_boxes_me)
        box_loss = object_mask_me * box_loss_scale_me * (1 - box_iou)
        box_loss = self.reduce_sum(box_loss, ())
        loss = box_loss * 10 + confidence_loss + class_loss
        batch_size = P.Shape()(prediction)[0]
        return loss / batch_size
 def construct(self, x, y):
     """Use both a detached and a non-detached copy of the same tensor.

     ``c`` is the first output of ``self.mul(x + y, x - y)``. Only the ``c2``
     occurrence in ``ret1`` carries gradient back through ``self.mul``;
     ``c1`` is detached.
     """
     u = x + y
     v = x - y
     c, z = self.mul(u, v)
     c1 = stop_gradient(c)
     c2 = c
     ret1 = c1 + x + y + c2
     ret2 = z + y + y
     return ret1, ret2
Example #10
0
    def construct(self, grid, prediction, pred_xy, pred_wh, y_true, gt_box):
        """Compute the xy, wh, confidence and class loss terms.

        Args:
            grid: Cell offsets subtracted from the grid-scaled ground-truth xy.
            prediction: Raw network output. Channels are sliced as ``[0:2]`` xy,
                ``[2:4]`` wh, ``[4:5]`` objectness and ``[5:]`` classes.
            pred_xy: Decoded predicted box centers.
            pred_wh: Decoded predicted box sizes.
            y_true: Normalized targets with the same channel layout as
                ``prediction``.
            gt_box: Ground-truth boxes. Reshaped to
                ``[batch, 1, 1, 1, shape[1], shape[2]]`` for broadcasting.

        Returns:
            The sum of all terms divided by ``shape(prediction)[0]``.
        """
        target_xy_norm = y_true[:, :, :, :, :2]
        target_wh_norm = y_true[:, :, :, :, 2:4]
        object_mask = y_true[:, :, :, :, 4:5]
        class_probs = y_true[:, :, :, :, 5:]

        # Spatial dims of the prediction, reversed, as a float32 tensor.
        grid_size = P.Shape()(prediction)[1:3]
        grid_size = P.Cast()(F.tuple_to_array(grid_size[::-1]), ms.float32)

        pred_boxes = self.concat((pred_xy, pred_wh))

        # Ground-truth xy relative to its grid cell.
        true_xy = target_xy_norm * grid_size - grid
        # Replace zero sizes by 1 so the log stays finite, then express the
        # ground-truth wh in log-space relative to the anchors.
        ones = P.Fill()(P.DType()(target_wh_norm), P.Shape()(target_wh_norm), 1.0)
        safe_wh = P.Select()(P.Equal()(target_wh_norm, 0.0), ones, target_wh_norm)
        true_wh = P.Log()(safe_wh / self.anchors * self.input_shape)
        # Smaller boxes get a larger weight: 2 - w*h.
        box_loss_scale = 2 - y_true[:, :, :, :, 2:3] * y_true[:, :, :, :, 3:4]

        gt_dims = P.Shape()(gt_box)
        gt_box = P.Reshape()(gt_box, (gt_dims[0], 1, 1, 1, gt_dims[1], gt_dims[2]))

        # [batch, grid[0], grid[1], num_anchor, num_gt]
        iou = self.iou(P.ExpandDims()(pred_boxes, -2), gt_box)
        # [batch, grid[0], grid[1], num_anchor]
        best_iou = self.reduce_max(iou, -1)
        # 1.0 where the best IoU is below the threshold; no gradient through it.
        ignore_mask = P.Cast()(best_iou < self.ignore_threshold, ms.float32)
        ignore_mask = F.stop_gradient(P.ExpandDims()(ignore_mask, -1))

        pred_xy_raw = prediction[:, :, :, :, :2]
        pred_wh_raw = prediction[:, :, :, :, 2:4]
        pred_conf_raw = prediction[:, :, :, :, 4:5]
        pred_cls_raw = prediction[:, :, :, :, 5:]

        box_weight = object_mask * box_loss_scale
        xy_loss = box_weight * self.cross_entropy(pred_xy_raw, true_xy)
        wh_loss = box_weight * 0.5 * P.Square()(true_wh - pred_wh_raw)

        conf_ce = self.cross_entropy(pred_conf_raw, object_mask)
        confidence_loss = object_mask * conf_ce + (1 - object_mask) * conf_ce * ignore_mask
        class_loss = object_mask * self.cross_entropy(pred_cls_raw, class_probs)

        # Reduce every term over all axes, then normalize by the batch size.
        total = (self.reduce_sum(xy_loss, ()) + self.reduce_sum(wh_loss, ())
                 + self.reduce_sum(confidence_loss, ()) + self.reduce_sum(class_loss, ()))
        return total / P.Shape()(prediction)[0]
Example #11
0
 def construct(self, x):
     """Return ``(mean, variance)`` of ``x`` reduced over ``self.axis``.

     float16 inputs are computed in float32 and both results are cast back to
     float16. The mean used inside the variance is detached with
     ``F.stop_gradient``. Results are squeezed unless ``self.keep_dims`` is set.
     """
     input_dtype = x.dtype
     _check_input_dtype("input x", input_dtype, [mstype.float16, mstype.float32], self.cls_name)
     is_half = input_dtype == mstype.float16
     if is_half:
         x = self.cast(x, mstype.float32)
     mean = self.reduce_mean(x, self.axis)
     centered_sq = self.square_diff(x, F.stop_gradient(mean))
     variance = self.reduce_mean(centered_sq, self.axis)
     if not self.keep_dims:
         mean, variance = self.squeeze(mean), self.squeeze(variance)
     if is_half:
         mean = self.cast(mean, mstype.float16)
         variance = self.cast(variance, mstype.float16)
     return mean, variance
Example #12
0
    def construct(self, x):
        """Run the stem and ``layer1``..``layer4``; return the four stage outputs.

        When ``self.weights_update`` is false, the ``layer1`` output is detached
        with ``F.stop_gradient``. No gradient then reaches the stem or
        ``layer1``.

        Returns:
            ``(c2, c3, c4, c5)``, where ``c2`` may be the detached copy.
        """
        stem = self.relu(self.bn1(self.conv1(x)))
        c2 = self.layer1(self.maxpool(stem))
        if self.weights_update:
            identity = c2
        else:
            identity = F.stop_gradient(c2)
        c3 = self.layer2(identity)
        c4 = self.layer3(c3)
        return identity, c3, c4, self.layer4(c4)
Example #13
0
    def construct(self, inputs, img_metas, anchor_list, gt_bboxes, gt_labels,
                  gt_valids):
        '''
        Compute RPN classification/regression outputs and, in training mode, their losses.

        inputs(Tensor): Inputs tensor from lstm.
        img_metas(Tensor): Image shape.
        anchor_list(Tensor): Total anchor list.
        gt_bboxes(Tensor): Ground truth boxes.
        gt_labels(Tensor): Ground truth labels.
        gt_valids(Tensor): Whether ground truth is valid.

        Returns:
            Training: (loss_total, rpn_cls_score_ori, rpn_bbox_pred_ori, loss_cls, loss_reg).
            Otherwise self.placeh1 stands in for the three loss entries.
        '''
        rpn_cls_score_ori, rpn_bbox_pred_ori = self.rpn_convs_list(inputs)
        # Transposed/flattened copies of the head outputs are used for the loss;
        # the *_ori tensors are what the caller receives.
        rpn_cls_score = self.transpose(rpn_cls_score_ori, self.trans_shape)
        rpn_cls_score = self.reshape(rpn_cls_score, self.reshape_shape_cls)
        rpn_bbox_pred = self.transpose(rpn_bbox_pred_ori, self.trans_shape)
        rpn_bbox_pred = self.reshape(rpn_bbox_pred, self.reshape_shape_reg)
        output = ()
        bbox_targets = ()
        bbox_weights = ()
        labels = ()
        label_weights = ()
        if self.training:
            # Build targets image by image, then concatenate them across the batch.
            for i in range(self.batch_size):
                # Anchor validity for image i, from its slice of img_metas.
                valid_flag_list = self.cast(self.CheckValid(anchor_list, self.squeeze(img_metas[i:i + 1:1, ::])),\
                    mstype.int32)
                gt_bboxes_i = self.squeeze(gt_bboxes[i:i + 1:1, ::])
                gt_labels_i = self.squeeze(gt_labels[i:i + 1:1, ::])
                gt_valids_i = self.squeeze(gt_valids[i:i + 1:1, ::])
                bbox_target, bbox_weight, label, label_weight = self.get_targets(
                    gt_bboxes_i, gt_labels_i,
                    self.cast(valid_flag_list, mstype.bool_), anchor_list,
                    gt_valids_i)
                bbox_weight = self.cast(bbox_weight, mstype.float16)
                label_weight = self.cast(label_weight, mstype.float16)
                bbox_targets += (bbox_target, )
                bbox_weights += (bbox_weight, )
                labels += (label, )
                label_weights += (label_weight, )
            bbox_target_with_batchsize = self.concat(bbox_targets)
            bbox_weight_with_batchsize = self.concat(bbox_weights)
            label_with_batchsize = self.concat(labels)
            label_weight_with_batchsize = self.concat(label_weights)

            # Targets and weights are constants for the loss: no gradient through them.
            bbox_target_ = F.stop_gradient(bbox_target_with_batchsize)
            bbox_weight_ = F.stop_gradient(bbox_weight_with_batchsize)
            label_ = F.stop_gradient(label_with_batchsize)
            label_weight_ = F.stop_gradient(label_weight_with_batchsize)
            rpn_cls_score = self.cast(rpn_cls_score, mstype.float32)
            if self.use_sigmoid_cls:
                label_ = self.cast(label_, mstype.float32)
            # Classification loss: weighted, summed, normalized by num_expected_total.
            loss_cls = self.loss_cls(rpn_cls_score, label_)
            loss_cls = loss_cls * label_weight_
            loss_cls = self.sum_loss(loss_cls, (0, )) / self.num_expected_total
            rpn_bbox_pred = self.cast(rpn_bbox_pred, mstype.float32)
            bbox_target_ = self.cast(bbox_target_, mstype.float32)
            loss_reg = self.loss_bbox(rpn_bbox_pred, bbox_target_)
            # Repeat the per-anchor weight across the 4 box coordinates.
            bbox_weight_ = self.tile(
                self.reshape(bbox_weight_, (self.feature_anchor_shape, 1)),
                (1, 4))
            loss_reg = loss_reg * bbox_weight_
            loss_reg = self.sum_loss(loss_reg, (1, ))
            loss_reg = self.sum_loss(loss_reg, (0, )) / self.num_expected_total
            loss_total = self.rpn_loss_cls_weight * loss_cls + self.rpn_loss_reg_weight * loss_reg
            output = (loss_total, rpn_cls_score_ori, rpn_bbox_pred_ori,
                      loss_cls, loss_reg)
        else:
            # Inference: no losses; self.placeh1 fills the loss slots.
            output = (self.placeh1, rpn_cls_score_ori, rpn_bbox_pred_ori,
                      self.placeh1, self.placeh1)
        return output
 def stop_test(x):
     """Return ``x + stop_gradient(x + x)``; only the outer ``x`` term carries gradient."""
     y = x + x
     y = stop_gradient(y)
     ret = x + y
     return ret
 def construct(self, x, y):
     """Return ``x * y`` wrapped in ``stop_gradient`` (no gradient to either input)."""
     ret = x * y
     ret = stop_gradient(ret)
     return ret
def stop_test4(x, y):
    """ stop_test4

    Return ``c + stop_gradient(c)`` with ``c = x + y``. Only the first ``c``
    term contributes gradient.
    """
    c = x + y
    c_s = stop_gradient(c)
    e = c + c_s
    return e
def stop_test2(x, y):
    """ stop_test2

    Return ``(stop_gradient(x * y) + x * y) * y``. The detached product adds
    no gradient; the second ``x * y`` and the final ``* y`` do.
    """
    c = x * y
    c_s = stop_gradient(c)
    d = c_s + x * y
    return d * y
def stop_test1(x, y):
    """ stop_test1

    Return ``x * y`` wrapped in ``stop_gradient``.
    """
    c = x * y
    c_s = stop_gradient(c)
    return c_s
 def construct(self, x, y):
     """Detach the first output of ``self.prim_with_no_bprop``; the second is returned as is."""
     x, y = self.prim_with_no_bprop(x, y)
     x = stop_gradient(x)
     return x, y
 def stop_test(x):
     """Return ``x`` wrapped in ``stop_gradient``."""
     return stop_gradient(x)
Example #21
0
    def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_valids):
        """Full detector forward pass.

        Pipeline: ``self.backbone`` -> ``self.fpn_ncek`` -> ``self.rpn_with_loss``
        -> proposal generation -> RoI align -> ``self.rcnn``.

        Args:
            img_data: Input images for ``self.backbone``.
            img_metas: Per-image meta information, forwarded to the RPN and to
                ``self.get_det_bboxes``.
            gt_bboxes: Ground-truth boxes, sliced one row block per image.
            gt_labels: Ground-truth labels, sliced one row block per image.
            gt_valids: Ground-truth validity flags, sliced one row block per image.

        Returns:
            Training: ``(rpn_loss, rcnn_loss, rpn_cls_loss, rpn_reg_loss,
            rcnn_cls_loss, rcnn_reg_loss)``.
            Inference: the result of ``self.get_det_bboxes``.
        """
        x = self.backbone(img_data)
        x = self.fpn_ncek(x)

        # The RPN receives self.gt_labels_stage1 rather than the gt_labels argument.
        rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss, _ = self.rpn_with_loss(
            x, img_metas, self.anchor_list, gt_bboxes, self.gt_labels_stage1,
            gt_valids)

        # Separate proposal generators are used for training and inference.
        if self.training:
            proposal, proposal_mask = self.proposal_generator(
                cls_score, bbox_pred, self.anchor_list)
        else:
            proposal, proposal_mask = self.proposal_generator_test(
                cls_score, bbox_pred, self.anchor_list)

        gt_labels = self.cast(gt_labels, mstype.int32)
        gt_valids = self.cast(gt_valids, mstype.int32)
        bboxes_tuple = ()
        deltas_tuple = ()
        labels_tuple = ()
        mask_tuple = ()
        if self.training:
            # Per image: match/sample proposals against that image's ground truth.
            for i in range(self.train_batch_size):
                gt_bboxes_i = self.squeeze(gt_bboxes[i:i + 1:1, ::])

                gt_labels_i = self.squeeze(gt_labels[i:i + 1:1, ::])
                gt_labels_i = self.cast(gt_labels_i, mstype.uint8)

                gt_valids_i = self.squeeze(gt_valids[i:i + 1:1, ::])
                gt_valids_i = self.cast(gt_valids_i, mstype.bool_)

                # Only the first 4 proposal columns (the box) are passed on.
                bboxes, deltas, labels, mask = self.bbox_assigner_sampler_for_rcnn(
                    gt_bboxes_i, gt_labels_i, proposal_mask[i],
                    proposal[i][::, 0:4:1], gt_valids_i)
                bboxes_tuple += (bboxes, )
                deltas_tuple += (deltas, )
                labels_tuple += (labels, )
                mask_tuple += (mask, )

            # R-CNN targets are constants for the loss: no gradient through them.
            bbox_targets = self.concat(deltas_tuple)
            rcnn_labels = self.concat(labels_tuple)
            bbox_targets = F.stop_gradient(bbox_targets)
            rcnn_labels = F.stop_gradient(rcnn_labels)
            rcnn_labels = self.cast(rcnn_labels, mstype.int32)
        else:
            # Inference: proposal_mask fills the target slots passed to self.rcnn.
            mask_tuple += proposal_mask
            bbox_targets = proposal_mask
            rcnn_labels = proposal_mask
            for p_i in proposal:
                bboxes_tuple += (p_i[::, 0:4:1], )

        # Prefix every box with the RoI-align index tensor (presumably the
        # image index within the batch -- TODO confirm).
        if self.training:
            if self.train_batch_size > 1:
                bboxes_all = self.concat(bboxes_tuple)
            else:
                bboxes_all = bboxes_tuple[0]
            rois = self.concat_1((self.roi_align_index_tensor, bboxes_all))
        else:
            if self.test_batch_size > 1:
                bboxes_all = self.concat(bboxes_tuple)
            else:
                bboxes_all = bboxes_tuple[0]
            # On Ascend the boxes are cast to float16 before concatenation.
            if self.device_type == "Ascend":
                bboxes_all = self.cast(bboxes_all, mstype.float16)
            rois = self.concat_1(
                (self.roi_align_index_test_tensor, bboxes_all))

        rois = self.cast(rois, mstype.float32)
        rois = F.stop_gradient(rois)

        # RoI align over the first four feature maps, all cast to float32.
        if self.training:
            roi_feats = self.roi_align(rois, self.cast(x[0], mstype.float32),
                                       self.cast(x[1], mstype.float32),
                                       self.cast(x[2], mstype.float32),
                                       self.cast(x[3], mstype.float32))
        else:
            roi_feats = self.roi_align_test(rois,
                                            self.cast(x[0], mstype.float32),
                                            self.cast(x[1], mstype.float32),
                                            self.cast(x[2], mstype.float32),
                                            self.cast(x[3], mstype.float32))

        roi_feats = self.cast(roi_feats, self.ms_type)
        rcnn_masks = self.concat(mask_tuple)
        rcnn_masks = F.stop_gradient(rcnn_masks)
        rcnn_mask_squeeze = self.squeeze(self.cast(rcnn_masks, mstype.bool_))
        rcnn_loss, rcnn_cls_loss, rcnn_reg_loss, _ = self.rcnn(
            roi_feats, bbox_targets, rcnn_labels, rcnn_mask_squeeze)

        output = ()
        if self.training:
            output += (rpn_loss, rcnn_loss, rpn_cls_loss, rpn_reg_loss,
                       rcnn_cls_loss, rcnn_reg_loss)
        else:
            # At inference the 2nd/3rd outputs of self.rcnn are fed to box decoding.
            output = self.get_det_bboxes(rcnn_cls_loss, rcnn_reg_loss,
                                         rcnn_masks, bboxes_all, img_metas)

        return output
Example #22
0
    def construct(self, inputs, img_metas, anchor_list, gt_bboxes, gt_labels, gt_valids):
        """Run the multi-level RPN head and, in training mode, compute its losses.

        Args:
            inputs: Per-level feature maps; ``inputs[i]`` feeds
                ``self.rpn_convs_list[i]``.
            img_metas: Per-image meta information. Each image's slice is passed
                to ``self.CheckValid`` together with that level's anchors.
            anchor_list: Per-level anchors.
            gt_bboxes: Ground-truth boxes, sliced one row block per image.
            gt_labels: Ground-truth labels, sliced one row block per image.
            gt_valids: Ground-truth validity flags, sliced one row block per image.

        Returns:
            Training: ``(loss, rpn_cls_score_total, rpn_bbox_pred_total, clsloss,
            regloss, loss_print)``. ``loss_print`` holds
            ``(loss_total, loss_cls, loss_reg)`` for every level, in order.
            Inference: ``self.placeh1`` fills every loss slot.
        """
        loss_print = ()
        rpn_cls_score = ()
        rpn_bbox_pred = ()
        rpn_cls_score_total = ()
        rpn_bbox_pred_total = ()

        # Head outputs per level. The raw maps are returned to the caller; a
        # transposed/flattened copy is kept for the loss.
        for i in range(self.num_layers):
            x1, x2 = self.rpn_convs_list[i](inputs[i])

            rpn_cls_score_total = rpn_cls_score_total + (x1,)
            rpn_bbox_pred_total = rpn_bbox_pred_total + (x2,)

            x1 = self.transpose(x1, self.trans_shape)
            x1 = self.reshape(x1, self.reshape_shape_cls)

            x2 = self.transpose(x2, self.trans_shape)
            x2 = self.reshape(x2, self.reshape_shape_reg)

            rpn_cls_score = rpn_cls_score + (x1,)
            rpn_bbox_pred = rpn_bbox_pred + (x2,)

        loss = self.loss
        clsloss = self.clsloss
        regloss = self.regloss
        bbox_targets = ()
        bbox_weights = ()
        labels = ()
        label_weights = ()

        output = ()
        if self.training:
            # Per image: check anchor validity on every level, assign targets on
            # the concatenated anchors, then split the targets back per level.
            # The flat tuples are ordered image-major: index i + num_layers * j.
            for i in range(self.batch_size):
                multi_level_flags = ()
                anchor_list_tuple = ()

                for j in range(self.num_layers):
                    res = self.cast(self.CheckValid(anchor_list[j], self.squeeze(img_metas[i:i + 1:1, ::])),
                                    mstype.int32)
                    multi_level_flags = multi_level_flags + (res,)
                    anchor_list_tuple = anchor_list_tuple + (anchor_list[j],)

                valid_flag_list = self.concat(multi_level_flags)
                anchor_using_list = self.concat(anchor_list_tuple)

                gt_bboxes_i = self.squeeze(gt_bboxes[i:i + 1:1, ::])
                gt_labels_i = self.squeeze(gt_labels[i:i + 1:1, ::])
                gt_valids_i = self.squeeze(gt_valids[i:i + 1:1, ::])

                bbox_target, bbox_weight, label, label_weight = self.get_targets(gt_bboxes_i,
                                                                                 gt_labels_i,
                                                                                 self.cast(valid_flag_list,
                                                                                           mstype.bool_),
                                                                                 anchor_using_list, gt_valids_i)

                bbox_weight = self.cast(bbox_weight, self.ms_type)
                label = self.cast(label, self.ms_type)
                label_weight = self.cast(label_weight, self.ms_type)

                # Split this image's targets back into per-level chunks.
                for j in range(self.num_layers):
                    begin = self.slice_index[j]
                    end = self.slice_index[j + 1]
                    stride = 1
                    bbox_targets += (bbox_target[begin:end:stride, ::],)
                    bbox_weights += (bbox_weight[begin:end:stride],)
                    labels += (label[begin:end:stride],)
                    label_weights += (label_weight[begin:end:stride],)

            # Per level: regroup the per-image chunks into batch tensors and
            # accumulate that level's classification and regression losses.
            for i in range(self.num_layers):
                bbox_target_using = ()
                bbox_weight_using = ()
                label_using = ()
                label_weight_using = ()

                for j in range(self.batch_size):
                    bbox_target_using += (bbox_targets[i + (self.num_layers * j)],)
                    bbox_weight_using += (bbox_weights[i + (self.num_layers * j)],)
                    label_using += (labels[i + (self.num_layers * j)],)
                    label_weight_using += (label_weights[i + (self.num_layers * j)],)

                bbox_target_with_batchsize = self.concat(bbox_target_using)
                bbox_weight_with_batchsize = self.concat(bbox_weight_using)
                label_with_batchsize = self.concat(label_using)
                label_weight_with_batchsize = self.concat(label_weight_using)

                # Targets and weights are constants for the loss: no gradient.
                bbox_target_ = F.stop_gradient(bbox_target_with_batchsize)
                bbox_weight_ = F.stop_gradient(bbox_weight_with_batchsize)
                label_ = F.stop_gradient(label_with_batchsize)
                label_weight_ = F.stop_gradient(label_weight_with_batchsize)

                cls_score_i = rpn_cls_score[i]
                reg_score_i = rpn_bbox_pred[i]

                loss_cls = self.loss_cls(cls_score_i, label_)
                loss_cls_item = loss_cls * label_weight_
                loss_cls_item = self.sum_loss(loss_cls_item, (0,)) / self.num_expected_total

                loss_reg = self.loss_bbox(reg_score_i, bbox_target_)
                # Repeat the per-anchor weight across the 4 box coordinates.
                bbox_weight_ = self.tile(self.reshape(bbox_weight_, (self.feature_anchor_shape[i], 1)), (1, 4))
                loss_reg = loss_reg * bbox_weight_
                loss_reg_item = self.sum_loss(loss_reg, (1,))
                loss_reg_item = self.sum_loss(loss_reg_item, (0,)) / self.num_expected_total

                loss_total = self.rpn_loss_cls_weight * loss_cls_item + self.rpn_loss_reg_weight * loss_reg_item

                loss += loss_total
                loss_print += (loss_total, loss_cls_item, loss_reg_item)
                clsloss += loss_cls_item
                regloss += loss_reg_item

            # Assemble the result once, after every level has been accumulated.
            output = (loss, rpn_cls_score_total, rpn_bbox_pred_total, clsloss, regloss, loss_print)
        else:
            output = (self.placeh1, rpn_cls_score_total, rpn_bbox_pred_total, self.placeh1, self.placeh1, self.placeh1)

        return output
 def stop_test(x, y):
     """Return ``x * y`` wrapped in ``stop_gradient``."""
     ret = x * y
     ret = stop_gradient(ret)
     return ret
 def construct(self, x1, x2):
     """Apply ``stop_gradient`` to the whole multi-output result of ``self.prim_with_multi_output``.

     The detached result is unpacked into two values, so the primitive must
     return exactly two outputs.
     """
     x1, x2 = stop_gradient(self.prim_with_multi_output(x1, x2))
     return x1, x2
    def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_valids):
        """Detector forward pass on the last two VGG16 feature maps.

        Pipeline: ``self.vgg16_feature_extractor`` (f4, f5) -> ``self.rpn_with_loss``
        -> proposal generation -> RoI align on f4 and f5 -> ``self.roi_align_fuse``
        -> ``self.rcnn``.

        Args:
            img_data: Input images.
            img_metas: Per-image meta information, forwarded to the RPN and to
                ``self.get_det_bboxes``.
            gt_bboxes: Ground-truth boxes, sliced one row block per image.
            gt_labels: Ground-truth labels, sliced one row block per image.
            gt_valids: Ground-truth validity flags, sliced one row block per image.

        Returns:
            Training: ``(rpn_loss, rcnn_loss, rpn_cls_loss, rpn_reg_loss,
            rcnn_cls_loss, rcnn_reg_loss)``.
            Inference: the result of ``self.get_det_bboxes``.
        """
        # The extractor yields five feature maps (f1..f5); only f4 and f5 are used.
        _, _, _, f4, f5 = self.vgg16_feature_extractor(img_data)
        f4 = self.cast(f4, mstype.float32)
        f5 = self.cast(f5, mstype.float32)
        x = (f4, f5)

        # The RPN receives self.gt_labels_stage1 rather than the gt_labels argument.
        rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss, _ = self.rpn_with_loss(
            x, img_metas, self.anchor_list, gt_bboxes, self.gt_labels_stage1,
            gt_valids)

        # Separate proposal generators are used for training and inference.
        if self.training:
            proposal, proposal_mask = self.proposal_generator(
                cls_score, bbox_pred, self.anchor_list)
        else:
            proposal, proposal_mask = self.proposal_generator_test(
                cls_score, bbox_pred, self.anchor_list)

        gt_labels = self.cast(gt_labels, mstype.int32)
        gt_valids = self.cast(gt_valids, mstype.int32)
        bboxes_tuple = ()
        deltas_tuple = ()
        labels_tuple = ()
        mask_tuple = ()
        if self.training:
            # Per image: match/sample proposals against that image's ground truth.
            for i in range(self.train_batch_size):
                gt_bboxes_i = self.squeeze(gt_bboxes[i:i + 1:1, ::])

                gt_labels_i = self.squeeze(gt_labels[i:i + 1:1, ::])
                gt_labels_i = self.cast(gt_labels_i, mstype.uint8)

                gt_valids_i = self.squeeze(gt_valids[i:i + 1:1, ::])
                gt_valids_i = self.cast(gt_valids_i, mstype.bool_)

                # Only the first 4 proposal columns (the box) are passed on.
                bboxes, deltas, labels, mask = self.bbox_assigner_sampler_for_rcnn(
                    gt_bboxes_i, gt_labels_i, proposal_mask[i],
                    proposal[i][::, 0:4:1], gt_valids_i)
                bboxes_tuple += (bboxes, )
                deltas_tuple += (deltas, )
                labels_tuple += (labels, )
                mask_tuple += (mask, )

            # R-CNN targets are constants for the loss: no gradient through them.
            bbox_targets = self.concat(deltas_tuple)
            rcnn_labels = self.concat(labels_tuple)
            bbox_targets = F.stop_gradient(bbox_targets)
            rcnn_labels = F.stop_gradient(rcnn_labels)
            rcnn_labels = self.cast(rcnn_labels, mstype.int32)
        else:
            # Inference: proposal_mask fills the target slots passed to self.rcnn.
            mask_tuple += proposal_mask
            bbox_targets = proposal_mask
            rcnn_labels = proposal_mask
            for p_i in proposal:
                bboxes_tuple += (p_i[::, 0:4:1], )

        # Prefix every box with the RoI-align index tensor (presumably the
        # image index within the batch -- TODO confirm).
        if self.training:
            if self.train_batch_size > 1:
                bboxes_all = self.concat(bboxes_tuple)
            else:
                bboxes_all = bboxes_tuple[0]
            rois = self.concat_1((self.roi_align_index_tensor, bboxes_all))
        else:
            if self.test_batch_size > 1:
                bboxes_all = self.concat(bboxes_tuple)
            else:
                bboxes_all = bboxes_tuple[0]
            rois = self.concat_1(
                (self.roi_align_index_test_tensor, bboxes_all))

        rois = self.cast(rois, mstype.float32)
        rois = F.stop_gradient(rois)

        # RoI align on f5 and f4 separately, then concatenate and fuse.
        roi_feats = self.roi_align5(x[1], rois)
        roi_align4_out = self.roi_align4(x[0], rois)

        roi_align4_out = self.cast(roi_align4_out, mstype.float32)
        roi_feats = self.cast(roi_feats, mstype.float32)
        roi_feats = self.concat1((roi_feats, roi_align4_out))

        # NOTE(review): both concat inputs are already float32, so this cast
        # looks redundant -- confirm before removing.
        roi_feats = self.cast(roi_feats, mstype.float32)
        roi_feats = self.roi_align_fuse(roi_feats)

        roi_feats = self.cast(roi_feats, mstype.float32)

        rcnn_masks = self.concat(mask_tuple)
        rcnn_masks = F.stop_gradient(rcnn_masks)
        rcnn_mask_squeeze = self.squeeze(self.cast(rcnn_masks, mstype.bool_))
        rcnn_loss, rcnn_cls_loss, rcnn_reg_loss, _ = self.rcnn(
            roi_feats, bbox_targets, rcnn_labels, rcnn_mask_squeeze)

        output = ()
        if self.training:
            output += (rpn_loss, rcnn_loss, rpn_cls_loss, rpn_reg_loss,
                       rcnn_cls_loss, rcnn_reg_loss)
        else:
            # At inference the 2nd/3rd outputs of self.rcnn are fed to box decoding.
            output = self.get_det_bboxes(rcnn_cls_loss, rcnn_reg_loss,
                                         rcnn_masks, bboxes_all, img_metas)

        return output