Example #1
0
    def predict(self, example, output, pc_range):
        """Turn decoded head outputs into one prediction dict per sample.

        ``ddd_decode`` is called on ``output['hm']``, ``output['rot']``,
        ``output['dim']`` and ``output['reg']`` with ``K=self.K``. Each
        decoded detection row is then consumed as follows: columns 0:6 are
        copied into the lidar box, columns 0 and 1 also feed an ``atan2``
        term, columns 6:14 are passed to ``get_alpha``, and the last two
        columns are used as the score and the class label.

        Returns:
            list of dicts with keys "bbox", "box3d_camera", "box3d_lidar",
            "scores", "label_preds" and "image_idx". "bbox" is the
            axis-aligned extent of the 3D box corners projected with P2.

        Side effects:
            Increments ``self._total_inference_count`` by the batch size and
            adds the elapsed wall time to ``self._total_postprocess_time``.
        """
        started = time.time()
        detections = ddd_decode(
            output['hm'],
            output['rot'],
            output['dim'],
            pc_range,
            example["ground"],
            reg=output['reg'],
            K=self.K,
        )

        self._total_inference_count += example['rect'].shape[0]

        per_sample = zip(detections, example["rect"], example["Trv2c"],
                         example["P2"], example['image_idx'])
        results = []
        for det, rect, Trv2c, P2, img_idx in per_sample:
            # Rotation = get_alpha(cols 6:14) + atan2(-col1, col0).
            rot_y = get_alpha(det[:, 6:14]) + torch.atan2(-det[:, 1], det[:, 0])
            box_lidar = torch.cat([det[:, :6], rot_y[:, None]], dim=-1)

            box_camera = box_torch_ops.box_lidar_to_camera(box_lidar, rect, Trv2c)
            corners_3d = box_torch_ops.center_to_corner_box3d(
                box_camera[:, :3], box_camera[:, 3:6], box_camera[:, 6],
                [0.5, 1.0, 0.5], axis=1)
            # Projected corners have shape [N, 8, 2]; reduce over the corner axis.
            corners_2d = box_torch_ops.project_to_image(corners_3d, P2)
            lower, _ = torch.min(corners_2d, dim=1)
            upper, _ = torch.max(corners_2d, dim=1)

            results.append({
                "bbox": torch.cat([lower, upper], dim=1),
                "box3d_camera": box_camera,
                "box3d_lidar": box_lidar,
                "scores": det[:, -2],
                "label_preds": det[:, -1],
                "image_idx": img_idx,
            })

        self._total_postprocess_time += time.time() - started
        return results
Example #2
0
    def predict(self, example, preds_dict):
        """Decode RPN outputs into per-sample 3D/2D detections after NMS.

        Args:
            example: batch dict. Reads "anchors", "rect", "Trv2c", "P2",
                "image_idx" and, if present, "anchors_mask".
            preds_dict: network outputs. Reads "box_preds", "cls_preds" and,
                when ``self._use_direction_classifier`` is set,
                "dir_cls_preds".

        Returns:
            list with one dict per sample, keyed by "bbox" (axis-aligned
            extent of the 3D box corners projected with P2), "box3d_camera",
            "box3d_lidar", "scores", "label_preds" and "image_idx". If no box
            survives score filtering / NMS, every value except "image_idx"
            is None.

        Side effects:
            Updates ``self._total_inference_count``,
            ``self._total_forward_time`` (time spent unpacking the batch) and
            ``self._total_postprocess_time``.
        """
        t = time.time()
        batch_size = example['anchors'].shape[0]
        batch_anchors = example["anchors"].view(batch_size, -1, 7)

        self._total_inference_count += batch_size
        batch_rect = example["rect"]
        batch_Trv2c = example["Trv2c"]
        batch_P2 = example["P2"]
        if "anchors_mask" not in example:
            batch_anchors_mask = [None] * batch_size
        else:
            batch_anchors_mask = example["anchors_mask"].view(batch_size, -1)
        batch_imgidx = example['image_idx']

        self._total_forward_time += time.time() - t
        t = time.time()
        batch_box_preds = preds_dict["box_preds"]
        batch_cls_preds = preds_dict["cls_preds"]
        batch_box_preds = batch_box_preds.view(batch_size, -1,
                                               self._box_coder.code_size)
        num_class_with_bg = self._num_class
        if not self._encode_background_as_zeros:
            # Slot 0 of the class channel is an explicit background class.
            num_class_with_bg = self._num_class + 1

        batch_cls_preds = batch_cls_preds.view(batch_size, -1,
                                               num_class_with_bg)
        batch_box_preds = self._box_coder.decode_torch(batch_box_preds,
                                                       batch_anchors)
        if self._use_direction_classifier:
            batch_dir_preds = preds_dict["dir_cls_preds"]
            batch_dir_preds = batch_dir_preds.view(batch_size, -1, 2)
        else:
            batch_dir_preds = [None] * batch_size

        # Loop-invariant: both NMS variants are called with the same signature.
        if self._use_rotate_nms:
            nms_func = box_torch_ops.rotate_nms
        else:
            nms_func = box_torch_ops.nms

        predictions_dicts = []
        for box_preds, cls_preds, dir_preds, rect, Trv2c, P2, img_idx, a_mask in zip(
                batch_box_preds, batch_cls_preds, batch_dir_preds, batch_rect,
                batch_Trv2c, batch_P2, batch_imgidx, batch_anchors_mask):
            if a_mask is not None:
                box_preds = box_preds[a_mask]
                cls_preds = cls_preds[a_mask]
            if self._use_direction_classifier:
                if a_mask is not None:
                    dir_preds = dir_preds[a_mask]
                # Index (0 or 1) of the larger of the two direction logits.
                dir_labels = torch.max(dir_preds, dim=-1)[1]
            if self._encode_background_as_zeros:
                # Softmax is not supported with this encoding.
                assert self._use_sigmoid_score is True
                total_scores = torch.sigmoid(cls_preds)
            else:
                # Background is the first element of the class vector; drop it
                # after normalising.
                if self._use_sigmoid_score:
                    total_scores = torch.sigmoid(cls_preds)[..., 1:]
                else:
                    total_scores = F.softmax(cls_preds, dim=-1)[..., 1:]

            selected_boxes = None
            selected_labels = None
            selected_scores = None
            selected_dir_labels = None

            if self._multiclass_nms:
                # Currently only class-agnostic boxes are supported.
                boxes_for_nms = box_preds[:, [0, 1, 3, 4, 6]]
                if not self._use_rotate_nms:
                    box_preds_corners = box_torch_ops.center_to_corner_box2d(
                        boxes_for_nms[:, :2], boxes_for_nms[:, 2:4],
                        boxes_for_nms[:, 4])
                    boxes_for_nms = box_torch_ops.corner_to_standup_nd(
                        box_preds_corners)
                boxes_for_mcnms = boxes_for_nms.unsqueeze(1)
                selected_per_class = box_torch_ops.multiclass_nms(
                    nms_func=nms_func,
                    boxes=boxes_for_mcnms,
                    scores=total_scores,
                    num_class=self._num_class,
                    pre_max_size=self._nms_pre_max_size,
                    post_max_size=self._nms_post_max_size,
                    iou_threshold=self._nms_iou_threshold,
                    score_thresh=self._nms_score_threshold,
                )
                selected_boxes, selected_labels, selected_scores = [], [], []
                selected_dir_labels = []
                for i, selected in enumerate(selected_per_class):
                    if selected is not None:
                        num_dets = selected.shape[0]
                        selected_boxes.append(box_preds[selected])
                        # Create labels on the prediction device so they match
                        # the other outputs (and the non-multiclass path).
                        selected_labels.append(
                            torch.full([num_dets], i, dtype=torch.int64,
                                       device=box_preds.device))
                        if self._use_direction_classifier:
                            selected_dir_labels.append(dir_labels[selected])
                        selected_scores.append(total_scores[selected, i])
                if len(selected_boxes) > 0:
                    selected_boxes = torch.cat(selected_boxes, dim=0)
                    selected_labels = torch.cat(selected_labels, dim=0)
                    selected_scores = torch.cat(selected_scores, dim=0)
                    if self._use_direction_classifier:
                        selected_dir_labels = torch.cat(selected_dir_labels,
                                                        dim=0)
                else:
                    selected_boxes = None
                    selected_labels = None
                    selected_scores = None
                    selected_dir_labels = None
            else:
                # Keep the highest-scoring class per box, then run one
                # class-agnostic NMS to remove overlapping boxes.
                if num_class_with_bg == 1:
                    top_scores = total_scores.squeeze(-1)
                    top_labels = torch.zeros(total_scores.shape[0],
                                             device=total_scores.device,
                                             dtype=torch.long)
                else:
                    top_scores, top_labels = torch.max(total_scores, dim=-1)

                if self._nms_score_threshold > 0.0:
                    thresh = torch.tensor(
                        [self._nms_score_threshold],
                        device=total_scores.device).type_as(total_scores)
                    top_scores_keep = (top_scores >= thresh)
                    top_scores = top_scores.masked_select(top_scores_keep)
                if top_scores.shape[0] != 0:
                    if self._nms_score_threshold > 0.0:
                        box_preds = box_preds[top_scores_keep]
                        if self._use_direction_classifier:
                            dir_labels = dir_labels[top_scores_keep]
                        top_labels = top_labels[top_scores_keep]
                    boxes_for_nms = box_preds[:, [0, 1, 3, 4, 6]]
                    if not self._use_rotate_nms:
                        box_preds_corners = box_torch_ops.center_to_corner_box2d(
                            boxes_for_nms[:, :2], boxes_for_nms[:, 2:4],
                            boxes_for_nms[:, 4])
                        boxes_for_nms = box_torch_ops.corner_to_standup_nd(
                            box_preds_corners)
                    # NMS here only removes overlapping boxes.
                    selected = nms_func(
                        boxes_for_nms,
                        top_scores,
                        pre_max_size=self._nms_pre_max_size,
                        post_max_size=self._nms_post_max_size,
                        iou_threshold=self._nms_iou_threshold,
                    )
                else:
                    selected = None
                if selected is not None:
                    selected_boxes = box_preds[selected]
                    if self._use_direction_classifier:
                        selected_dir_labels = dir_labels[selected]
                    selected_labels = top_labels[selected]
                    selected_scores = top_scores[selected]

            # Finally generate predictions.
            if selected_boxes is not None:
                box_preds = selected_boxes
                scores = selected_scores
                label_preds = selected_labels
                if self._use_direction_classifier:
                    dir_labels = selected_dir_labels
                    # Add pi to the last box column wherever its sign disagrees
                    # with the direction classifier. `!=` of two comparison
                    # masks is an XOR whose dtype is accepted by torch.where
                    # (bool on modern PyTorch); the former `bool ^ uint8`
                    # produced uint8, which recent torch.where rejects.
                    opp_labels = (box_preds[..., -1] > 0) != (dir_labels > 0)
                    box_preds[..., -1] += torch.where(
                        opp_labels,
                        torch.tensor(np.pi).type_as(box_preds),
                        torch.tensor(0.0).type_as(box_preds))
                final_box_preds = box_preds
                final_scores = scores
                final_labels = label_preds
                final_box_preds_camera = box_torch_ops.box_lidar_to_camera(
                    final_box_preds, rect, Trv2c)
                locs = final_box_preds_camera[:, :3]
                dims = final_box_preds_camera[:, 3:6]
                angles = final_box_preds_camera[:, 6]
                camera_box_origin = [0.5, 1.0, 0.5]
                box_corners = box_torch_ops.center_to_corner_box3d(
                    locs, dims, angles, camera_box_origin, axis=1)
                box_corners_in_image = box_torch_ops.project_to_image(
                    box_corners, P2)
                # box_corners_in_image: [N, 8, 2]
                minxy = torch.min(box_corners_in_image, dim=1)[0]
                maxxy = torch.max(box_corners_in_image, dim=1)[0]
                box_2d_preds = torch.cat([minxy, maxxy], dim=1)
                predictions_dict = {
                    "bbox": box_2d_preds,
                    "box3d_camera": final_box_preds_camera,
                    "box3d_lidar": final_box_preds,
                    "scores": final_scores,
                    "label_preds": final_labels,
                    "image_idx": img_idx,
                }
            else:
                predictions_dict = {
                    "bbox": None,
                    "box3d_camera": None,
                    "box3d_lidar": None,
                    "scores": None,
                    "label_preds": None,
                    "image_idx": img_idx,
                }
            predictions_dicts.append(predictions_dict)
        self._total_postprocess_time += time.time() - t
        return predictions_dicts
Example #3
0
    def train_stage_2(self, example, preds_dict, top_predictions_left,
                      top_predictions_right):
        """Build stage-2 (3D/2D fusion) candidates for a batch.

        For each sample:
        - decode the box predictions against the anchors;
        - score them with sigmoid or softmax, per the model config;
        - project every 3D box into the left (P2) and right (P3) images;
        - clip the resulting 2D boxes to the image;
        - pass them, with the first 20 rows of each camera's 2D detections,
          to ``se.build_stage2_training``.

        Args:
            example: batch dict; reads "anchors", "rect", "Trv2c", "P2", "P3",
                "image_shape", "image_idx" and optionally "anchors_mask".
            preds_dict: head outputs; reads "box_preds", "cls_preds" and, with
                the direction classifier enabled, "dir_cls_preds".
            top_predictions_left: 2D detections for the P2 camera; columns
                0:4 are used as boxes and column 4 as scores.
            top_predictions_right: same layout, for the P3 camera.

        Returns:
            ``(predictions_dicts, iou_tensor, index_tensor)``.
            ``predictions_dicts`` has one entry per sample. The two tensors
            hold the first ``max_num`` rows produced by
            ``se.build_stage2_training``, or a -1-filled placeholder when
            ``max_num`` is 0. They describe only the LAST sample of the batch.
        """
        # Length of the candidate buffers passed to se.build_stage2_training
        # and of the (1, 6, 1, N) reshape below (originally annotated as
        # "9x10^5 possible combinations").
        max_candidates = 900000
        # Leading 2D detections per camera that are paired with 3D boxes.
        max_2d_dets = 20
        # Divisor applied to each box's (x, y) distance from the origin;
        # presumably the maximum LiDAR range in metres -- TODO confirm.
        dist_norm = 82.0

        batch_size = example['anchors'].shape[0]
        batch_anchors = example["anchors"].view(batch_size, -1, 7)
        batch_rect = example["rect"]
        batch_Trv2c = example["Trv2c"]
        batch_P2 = example["P2"]
        batch_P3 = example["P3"]
        batch_image_shape = example["image_shape"]
        if "anchors_mask" not in example:
            batch_anchors_mask = [None] * batch_size
        else:
            batch_anchors_mask = example["anchors_mask"].view(batch_size, -1)
        batch_imgidx = example['image_idx']

        batch_box_preds = preds_dict["box_preds"]  # predicted 3D boxes (encoded)
        batch_cls_preds = preds_dict["cls_preds"]  # predicted class scores
        batch_box_preds = batch_box_preds.view(batch_size, -1,
                                               self._box_coder.code_size)
        num_class_with_bg = self._num_class
        if not self._encode_background_as_zeros:
            num_class_with_bg = self._num_class + 1
        batch_cls_preds = batch_cls_preds.view(batch_size, -1,
                                               num_class_with_bg)
        batch_box_preds = self._box_coder.decode_torch(batch_box_preds,
                                                       batch_anchors)
        if self._use_direction_classifier:
            batch_dir_preds = preds_dict["dir_cls_preds"]
            batch_dir_preds = batch_dir_preds.view(batch_size, -1, 2)
        else:
            batch_dir_preds = [None] * batch_size

        # NOTE(review): every sample is clipped against the FIRST sample's
        # image shape; correct only for batch size 1 or equal image sizes.
        img_height = batch_image_shape[0, 0]
        img_width = batch_image_shape[0, 1]

        def clipped_image_boxes(corners_in_image):
            """Return [min_xy, max_xy] extents of projected corners, clamped to the image."""
            minxy = torch.min(corners_in_image, dim=1)[0]
            maxxy = torch.max(corners_in_image, dim=1)[0]
            minxy[:, 0] = torch.clamp(minxy[:, 0], min=0, max=img_width)
            minxy[:, 1] = torch.clamp(minxy[:, 1], min=0, max=img_height)
            maxxy[:, 0] = torch.clamp(maxxy[:, 0], min=0, max=img_width)
            maxxy[:, 1] = torch.clamp(maxxy[:, 1], min=0, max=img_height)
            return torch.cat([minxy, maxxy], dim=1)

        predictions_dicts = []
        for box_preds, cls_preds, dir_preds, rect, Trv2c, P2, P3, img_idx, a_mask in zip(
                batch_box_preds, batch_cls_preds, batch_dir_preds, batch_rect,
                batch_Trv2c, batch_P2, batch_P3, batch_imgidx,
                batch_anchors_mask):
            if a_mask is not None:
                box_preds = box_preds[a_mask]
                cls_preds = cls_preds[a_mask]
            box_preds = box_preds.float()
            cls_preds = cls_preds.float()
            rect = rect.float()
            Trv2c = Trv2c.float()
            P2 = P2.float()
            P3 = P3.float()
            if self._encode_background_as_zeros:
                # Softmax is not supported with this encoding.
                assert self._use_sigmoid_score is True
                total_scores = torch.sigmoid(cls_preds)
            elif self._use_sigmoid_score:
                # Background is the first element of the class vector; drop it.
                total_scores = torch.sigmoid(cls_preds)[..., 1:]
            else:
                total_scores = F.softmax(cls_preds, dim=-1)[..., 1:]

            final_box_preds = box_preds
            final_scores = total_scores
            # Convert the 3D boxes with box_lidar_to_camera, then project their
            # 8 corners into both images.
            final_box_preds_camera = box_torch_ops.box_lidar_to_camera(
                final_box_preds, rect, Trv2c)
            box_corners = box_torch_ops.center_to_corner_box3d(
                final_box_preds_camera[:, :3], final_box_preds_camera[:, 3:6],
                final_box_preds_camera[:, 6], [0.5, 1.0, 0.5], axis=1)
            box_2d_preds_left = clipped_image_boxes(
                box_torch_ops.project_to_image(box_corners, P2))
            box_2d_preds_right = clipped_image_boxes(
                box_torch_ops.project_to_image(box_corners, P3))

            predictions_dicts.append({
                "bbox": box_2d_preds_left,
                "box3d_camera": final_box_preds_camera,
                "box3d_lidar": final_box_preds,
                "scores": final_scores,
                "image_idx": img_idx,
            })
            # (x, y) distance of each box from the origin, scaled by dist_norm.
            dis_to_lidar = torch.norm(
                box_preds[:, :2], p=2, dim=1, keepdim=True) / dist_norm

            # Slicing already copes with fewer than max_2d_dets rows.
            box_2d_detector_left = top_predictions_left[:max_2d_dets, :4]
            box_2d_detector_right = top_predictions_right[:max_2d_dets, :4]
            # NOTE(review): scores are not truncated to max_2d_dets while the
            # boxes are; confirm se.build_stage2_training only reads scores for
            # the detector rows it receives.
            box_2d_scores_left = top_predictions_left[:, 4].reshape(-1, 1)
            box_2d_scores_right = top_predictions_right[:, 4].reshape(-1, 1)

            box_2d_preds_left_np = box_2d_preds_left.detach().cpu().numpy()
            # -1-initialised buffers handed to the extension (presumably filled
            # in place by it -- verify).
            overlaps = np.full((max_candidates, 6), -1,
                               dtype=box_2d_preds_left_np.dtype)
            tensor_index1 = np.full((max_candidates, 2), -1,
                                    dtype=box_2d_preds_left_np.dtype)

            iou_test, tensor_index, max_num = se.build_stage2_training(
                box_2d_preds_left_np,
                box_2d_preds_right.detach().cpu().numpy(),
                box_2d_detector_left, box_2d_detector_right, -1,
                final_scores.detach().cpu().numpy(), box_2d_scores_left,
                box_2d_scores_right,
                dis_to_lidar.detach().cpu().numpy(), overlaps, tensor_index1)

            iou_test_tensor = torch.FloatTensor(iou_test)
            tensor_index_tensor = torch.LongTensor(tensor_index)
            iou_test_tensor = iou_test_tensor.permute(1, 0).reshape(
                1, 6, 1, max_candidates)
            tensor_index_tensor = tensor_index_tensor.reshape(-1, 2)
            if max_num == 0:
                # -1-filled placeholders used when no candidate was produced.
                non_empty_iou_test_tensor = torch.full((1, 6, 1, 2), -1.0)
                non_empty_tensor_index_tensor = torch.full((2, 2), -1.0)
            else:
                non_empty_iou_test_tensor = iou_test_tensor[:, :, :, :max_num]
                non_empty_tensor_index_tensor = tensor_index_tensor[:max_num, :]

        return predictions_dicts, non_empty_iou_test_tensor, non_empty_tensor_index_tensor
Example #4
0
def _rpn_select_rows(mask, box_props, anchors, labels, reg_target, dir_targets,
                     training):
    """Index proposals, anchors and (when training) targets with one mask.

    Returns ``(box_props, anchors, labels, reg_target, dir_targets)``; the
    three target entries are None when ``training`` is false.
    """
    if not training:
        return box_props[mask], anchors[mask], None, None, None
    return (box_props[mask], anchors[mask], labels[mask], reg_target[mask],
            dir_targets[mask])


def _rpn_bev_nms(box_props, scores, pre_max_size, post_max_size,
                 iou_threshold):
    """Run ``box_torch_ops.nms`` on axis-aligned BEV boxes built from 3D boxes.

    Columns [0, 1, 3, 4, 6] of ``box_props`` are turned into 2D corners and
    then into standup (axis-aligned) bounds before NMS. Returns the result of
    ``box_torch_ops.nms`` unchanged.
    """
    bev = box_props[:, [0, 1, 3, 4, 6]]
    corners = box_torch_ops.center_to_corner_box2d(bev[:, :2], bev[:, 2:4],
                                                   bev[:, 4])
    standup = box_torch_ops.corner_to_standup_nd(corners)
    return box_torch_ops.nms(standup,
                             scores,
                             pre_max_size=pre_max_size,
                             post_max_size=post_max_size,
                             iou_threshold=iou_threshold)


def _rpn_pad_rows(values, shape, device):
    """Return a zero tensor of ``shape`` on ``device`` with ``values`` in its leading rows."""
    padded = torch.zeros(shape, device=device)
    padded[:values.shape[0]] = values
    return padded


def _rpn_image_boxes(proposals_3d, rect, Trv2c, P2):
    """Project 3D proposals with P2 and return their [min_xy, max_xy] extents, shape (N, 4)."""
    cam = box_torch_ops.box_lidar_to_camera(proposals_3d, rect, Trv2c)
    corners = box_torch_ops.center_to_corner_box3d(
        cam[:, :3], cam[:, 3:6], cam[:, 6], [0.5, 1.0, 0.5], axis=1)
    img_corners = box_torch_ops.project_to_image(corners, P2)
    minxy = torch.min(img_corners, dim=1)[0]
    maxxy = torch.max(img_corners, dim=1)[0]
    return torch.cat([minxy, maxxy], dim=1)


def _rpn_fixed_size_group(selected, group, capacity, rect, Trv2c, P2, training,
                          device):
    """Zero-pad one range group's NMS survivors to ``capacity`` rows.

    Returns ``(bev, img, labels, reg_targets, dir_targets, anchors)``, each
    with a leading axis of size 1. When ``selected`` is None all six are zero
    tensors; otherwise the three target entries are None when not training.
    """
    if selected is None:
        return (
            torch.zeros((capacity, 5), device=device).unsqueeze(0),
            torch.zeros((capacity, 4), device=device).unsqueeze(0),
            torch.zeros((capacity,), device=device).unsqueeze(0),
            torch.zeros((capacity, 7), device=device).unsqueeze(0),
            torch.zeros((capacity, 2), device=device).unsqueeze(0),
            torch.zeros((capacity, 7), device=device).unsqueeze(0),
        )
    box_props, anchors, labels, reg_target, dir_targets = group
    proposals_3d = _rpn_pad_rows(box_props[selected], (capacity, 7), device)
    anchors_fix = _rpn_pad_rows(anchors[selected], (capacity, 7), device)
    if training:
        labels_fix = _rpn_pad_rows(labels[selected], (capacity,),
                                   device).unsqueeze(0)
        reg_fix = _rpn_pad_rows(reg_target[selected], (capacity, 7),
                                device).unsqueeze(0)
        dir_fix = _rpn_pad_rows(dir_targets[selected], (capacity, 2),
                                device).unsqueeze(0)
    else:
        labels_fix = reg_fix = dir_fix = None
    # Padded (zero) rows are projected too, exactly like the real proposals.
    bev = proposals_3d[:, [0, 1, 3, 4, 6]].unsqueeze(0)
    img = _rpn_image_boxes(proposals_3d, rect, Trv2c, P2).unsqueeze(0)
    return bev, img, labels_fix, reg_fix, dir_fix, anchors_fix.unsqueeze(0)


def rpn_nms(box_preds, cls_preds, example, box_coder, nms_score_threshold, nms_pre_max_size,
            nms_post_max_size, nms_iou_threshold, training, range_thresh=0):
    """Select fixed-size far/near proposal sets from RPN outputs.

    For each sample:
    - decode ``box_preds`` against the anchors and drop anchors outside
      ``anchors_mask`` (if given);
    - keep scores ``>= nms_score_threshold`` (when > 0);
    - split boxes on column 0 of the decoded boxes into a "far" group
      (``>= range_thresh``) and, only when ``range_thresh > 0``, a "near"
      group (``< range_thresh``);
    - run BEV NMS on each group;
    - zero-pad the survivors to ``nms_post_max_size // 2`` (far) and
      ``nms_post_max_size`` (near) rows.

    Padding buffers are allocated on the device of the decoded boxes.

    Returns:
        dict with "far_props_bev"/"near_props_bev" (columns [0, 1, 3, 4, 6] of
        the padded 3D proposals), "far_props_img"/"near_props_img" (image-plane
        [min_xy, max_xy] boxes via P2), and "rcnn_labels", "rcnn_reg_targets",
        "rcnn_dir_targets", "rcnn_anchors" concatenated near-then-far along
        dim 1. All tensors are stacked over the batch on dim 0. When not
        training, the three target entries are lists of None.
    """
    anchors = example["anchors"]
    batch_size = anchors.shape[0]
    batch_anchors = anchors.view(batch_size, -1, 7)
    batch_rect = example["calib"]["rect"]
    batch_Trv2c = example["calib"]["Trv2c"]
    batch_P2 = example["calib"]["P2"]
    if training:
        batch_labels = example["labels"]
        batch_reg_targets = example["reg_targets"]
        batch_dir_targets = get_direction_target(
            batch_anchors,
            batch_reg_targets,
            dir_offset=0.0,
            num_bins=2)
    else:
        batch_labels = [None] * batch_size
        batch_reg_targets = [None] * batch_size
        batch_dir_targets = [None] * batch_size

    if "anchors_mask" not in example:
        batch_anchors_mask = [None] * batch_size
    else:
        batch_anchors_mask = example["anchors_mask"].view(batch_size, -1)
    batch_box_props = box_preds.view(batch_size, -1, box_coder.code_size)
    batch_box_props = box_coder.decode_torch(batch_box_props, batch_anchors)
    batch_cls_props = cls_preds.view(batch_size, -1, 1)
    # Allocate padding next to the decoded boxes instead of hard-coding
    # `.cuda()`, so CPU and non-default-GPU inputs also work.
    device = batch_box_props.device
    far_capacity = nms_post_max_size // 2
    near_capacity = nms_post_max_size

    batch_far_proposals_bev = []
    batch_far_proposals_img = []
    batch_near_proposals_bev = []
    batch_near_proposals_img = []
    batch_rcnn_labels = []
    batch_rcnn_reg_target = []
    batch_rcnn_dir_target = []
    batch_rcnn_anchors = []
    for (box_props, cls_props, labels, reg_target, dir_targets, rect, Trv2c,
         P2, a_mask, anchors) in zip(
            batch_box_props, batch_cls_props, batch_labels, batch_reg_targets,
            batch_dir_targets, batch_rect, batch_Trv2c, batch_P2,
            batch_anchors_mask, batch_anchors):
        if a_mask is not None:
            cls_props = cls_props[a_mask]
            box_props, anchors, labels, reg_target, dir_targets = _rpn_select_rows(
                a_mask, box_props, anchors, labels, reg_target, dir_targets,
                training)

        # NOTE(review): the threshold is applied to the raw `cls_props`
        # values, not to sigmoid probabilities; confirm that is intended.
        top_scores = cls_props.squeeze(-1)
        if nms_score_threshold > 0.0:
            top_scores_keep = top_scores >= nms_score_threshold
            top_scores = top_scores.masked_select(top_scores_keep)

        far_group = near_group = None
        far_selected = near_selected = None
        if top_scores.shape[0] != 0:
            if nms_score_threshold > 0.0:
                box_props, anchors, labels, reg_target, dir_targets = _rpn_select_rows(
                    top_scores_keep, box_props, anchors, labels, reg_target,
                    dir_targets, training)

            # TODO (carried over): uncertain which coordinate encodes range;
            # the split uses column 0 of the decoded boxes.
            far_idx = box_props[:, 0] >= range_thresh
            far_group = _rpn_select_rows(far_idx, box_props, anchors, labels,
                                         reg_target, dir_targets, training)
            if far_group[0].shape[0] != 0:
                far_selected = _rpn_bev_nms(
                    far_group[0], top_scores[far_idx],
                    pre_max_size=nms_pre_max_size // 2,
                    post_max_size=far_capacity,
                    iou_threshold=nms_iou_threshold)

            if range_thresh > 0:
                near_idx = box_props[:, 0] < range_thresh
                near_group = _rpn_select_rows(near_idx, box_props, anchors,
                                              labels, reg_target, dir_targets,
                                              training)
                if near_group[0].shape[0] != 0:
                    near_selected = _rpn_bev_nms(
                        near_group[0], top_scores[near_idx],
                        pre_max_size=nms_pre_max_size,
                        post_max_size=near_capacity,
                        iou_threshold=nms_iou_threshold)

        (far_bev, far_img, far_labels_fix, far_reg_fix, far_dir_fix,
         far_anchors_fix) = _rpn_fixed_size_group(
            far_selected, far_group, far_capacity, rect, Trv2c, P2, training,
            device)
        (near_bev, near_img, near_labels_fix, near_reg_fix, near_dir_fix,
         near_anchors_fix) = _rpn_fixed_size_group(
            near_selected, near_group, near_capacity, rect, Trv2c, P2,
            training, device)

        batch_far_proposals_bev.append(far_bev)
        batch_far_proposals_img.append(far_img)
        batch_near_proposals_bev.append(near_bev)
        batch_near_proposals_img.append(near_img)
        if training:
            batch_rcnn_labels.append(
                torch.cat([near_labels_fix, far_labels_fix], dim=1))
            batch_rcnn_reg_target.append(
                torch.cat([near_reg_fix, far_reg_fix], dim=1))
            batch_rcnn_dir_target.append(
                torch.cat([near_dir_fix, far_dir_fix], dim=1))
        else:
            batch_rcnn_labels.append(None)
            batch_rcnn_reg_target.append(None)
            batch_rcnn_dir_target.append(None)
        batch_rcnn_anchors.append(
            torch.cat([near_anchors_fix, far_anchors_fix], dim=1))

    batch_far_proposals_bev = torch.cat(batch_far_proposals_bev, dim=0)
    batch_far_proposals_img = torch.cat(batch_far_proposals_img, dim=0)
    batch_near_proposals_bev = torch.cat(batch_near_proposals_bev, dim=0)
    batch_near_proposals_img = torch.cat(batch_near_proposals_img, dim=0)
    if training:
        batch_rcnn_labels = torch.cat(batch_rcnn_labels, dim=0)
        batch_rcnn_reg_target = torch.cat(batch_rcnn_reg_target, dim=0)
        batch_rcnn_dir_target = torch.cat(batch_rcnn_dir_target, dim=0)
    batch_rcnn_anchors = torch.cat(batch_rcnn_anchors, dim=0)
    return {
        "far_props_bev": batch_far_proposals_bev,
        "far_props_img": batch_far_proposals_img,
        "near_props_bev": batch_near_proposals_bev,
        "near_props_img": batch_near_proposals_img,
        "rcnn_labels": batch_rcnn_labels,
        "rcnn_reg_targets": batch_rcnn_reg_target,
        "rcnn_dir_targets": batch_rcnn_dir_target,
        "rcnn_anchors": batch_rcnn_anchors,
    }
Example #5
0
def predict_kitti_to_anno(net,
                          detection_2d_path,
                          fusion_layer,
                          example,
                          class_names,
                          center_limit_range=None,
                          lidar_input=False,
                          global_set=None,
                          test_mode=False):
    """Run the detector plus fusion layer on one batch and build KITTI annos.

    The fusion layer's scores replace ``cls_preds`` in the detector output
    before ``predict_v2`` decodes the final detections. Unless ``test_mode``
    is set, the focal classification loss of the fusion scores against
    IoU-derived targets is also computed and returned.

    Args:
        net: 3D detection network, called as ``net(example, detection_2d_path)``
            and unpacked into five values (camera-frame predictions, raw
            outputs, top predictions, fusion input, fusion index).
        detection_2d_path: forwarded unchanged to ``net``.
        fusion_layer: module called as ``fusion_layer(fusion_input, index)``
            returning ``(cls_preds, flag)``.
        example: batch dict already converted to torch. Always reads
            ``image_shape``; unless ``test_mode`` also reads ``d3_gt_boxes``,
            ``rect``, ``Trv2c`` and ``labels``.
        class_names: sequence indexed by integer label predictions.
        center_limit_range: optional 6 values; the first three are the lower
            and the last three the upper bound on the lidar-frame box center.
            Boxes outside are dropped.
        lidar_input: if False, boxes whose 2D box lies fully outside the
            image are dropped.
        global_set: optional set of scores already emitted; when given, each
            score is lowered by 1e-5 until it is not in the set, then added.
        test_mode: if True, skip the loss (no ground truth needed) and return
            1000 as the loss placeholder.

    Returns:
        ``(annos, cls_losses_reduced)``: one KITTI annotation dict per
        prediction dict, and the fusion classification loss as a numpy scalar
        (or 1000 when ``test_mode`` is True).
    """
    batch_image_shape = example['image_shape']
    all_3d_output_camera_dict, all_3d_output, top_predictions, fusion_input, torch_index = net(
        example, detection_2d_path)
    fusion_cls_preds, flag = fusion_layer(fusion_input.cuda(),
                                          torch_index.cuda())
    # NOTE(review): hard-codes batch size 1 and a 200 x 176 BEV grid with
    # 2 anchors per cell (200 * 176 * 2 == 70400 anchors) — confirm.
    fusion_cls_preds_reshape = fusion_cls_preds.reshape(1, 200, 176, 2)
    all_3d_output.update({'cls_preds': fusion_cls_preds_reshape})
    predictions_dicts = predict_v2(net, example, all_3d_output)

    if not test_mode:
        num_anchors = 200 * 176 * 2
        focal_loss_val = SigmoidFocalClassificationLoss()
        # Index 0 on the batch axis: assumes batch size 1 (see reshape above).
        d3_gt_boxes = example["d3_gt_boxes"][0, :, :]
        if d3_gt_boxes.shape[0] == 0:
            # No ground truth: every anchor is a negative.
            target_for_fusion = np.zeros((1, num_anchors, 1))
            positives = torch.zeros(1, num_anchors).type(torch.float32).cuda()
            negatives = torch.ones(1, num_anchors).type(torch.float32).cuda()
        else:
            d3_gt_boxes_camera = box_torch_ops.box_lidar_to_camera(
                d3_gt_boxes, example['rect'][0, :], example['Trv2c'][0, :])
            pred_3d_box = all_3d_output_camera_dict[0]["box3d_camera"]
            # 3D IoU between ground-truth and predicted boxes; the max over
            # axis 0 presumably gives the best ground-truth match per
            # prediction (assumes a [num_gt, num_pred] result — confirm).
            iou_3d = d3_box_overlap(
                d3_gt_boxes_camera.detach().cpu().numpy(),
                pred_3d_box.squeeze().detach().cpu().numpy(),
                criterion=-1)
            iou_max = np.amax(iou_3d, axis=0)
            # IoU >= 0.7 -> positive, IoU <= 0.5 -> negative; anchors in
            # between get weight 0 below and are ignored by the loss.
            positive_mask = (iou_max >= 0.7) * 1
            target_for_fusion = positive_mask.reshape(1, -1, 1)
            positives = torch.from_numpy(positive_mask.reshape(1, -1)).type(
                torch.float32).cuda()
            negative_mask = ((iou_max <= 0.5) * 1).reshape(1, -1)
            negatives = torch.from_numpy(negative_mask).type(
                torch.float32).cuda()

        one_hot_targets = torch.from_numpy(target_for_fusion).type(
            torch.float32).cuda()
        cls_weights = negatives + positives
        # Normalize by the number of positives (at least 1).
        pos_normalizer = positives.sum(1, keepdim=True).type(torch.float32)
        cls_weights /= torch.clamp(pos_normalizer, min=1.0)
        cls_losses = focal_loss_val._compute_loss(fusion_cls_preds,
                                                  one_hot_targets,
                                                  cls_weights.cuda())  # [N, M]
        cls_losses_reduced = cls_losses.sum() / example['labels'].shape[0]
        cls_losses_reduced = cls_losses_reduced.detach().cpu().numpy()
    else:
        cls_losses_reduced = 1000

    limit_range = (None if center_limit_range is None else
                   np.array(center_limit_range))
    annos = []
    for i, preds_dict in enumerate(predictions_dicts):
        image_shape = batch_image_shape[i]
        img_idx = preds_dict["image_idx"]
        bbox_preds = preds_dict["bbox"]
        if bbox_preds is not None and bbox_preds.numel() != 0:
            box_2d_preds = bbox_preds.detach().cpu().numpy()
            box_preds = preds_dict["box3d_camera"].detach().cpu().numpy()
            scores = preds_dict["scores"].detach().cpu().numpy()
            box_preds_lidar = preds_dict["box3d_lidar"].detach().cpu().numpy()
            label_preds = preds_dict["label_preds"].detach().cpu().numpy()
            anno = kitti.get_start_result_anno()
            num_example = 0
            for box, box_lidar, bbox, score, label in zip(
                    box_preds, box_preds_lidar, box_2d_preds, scores,
                    label_preds):
                if not lidar_input:
                    # Drop boxes whose 2D box lies entirely outside the image
                    # (image_shape presumably (height, width) — confirm).
                    if bbox[0] > image_shape[1] or bbox[1] > image_shape[0]:
                        continue
                    if bbox[2] < 0 or bbox[3] < 0:
                        continue
                if limit_range is not None:
                    if (np.any(box_lidar[:3] < limit_range[:3])
                            or np.any(box_lidar[:3] > limit_range[3:])):
                        continue
                # Clip the 2D box to the image bounds.
                bbox[2:] = np.minimum(bbox[2:], image_shape[::-1])
                bbox[:2] = np.maximum(bbox[:2], [0, 0])
                anno["name"].append(class_names[int(label)])
                anno["truncated"].append(0.0)
                anno["occluded"].append(0)
                anno["alpha"].append(-np.arctan2(-box_lidar[1], box_lidar[0]) +
                                     box[6])
                anno["bbox"].append(bbox)
                anno["dimensions"].append(box[3:6])
                anno["location"].append(box[:3])
                anno["rotation_y"].append(box[6])
                if global_set is not None:
                    # Make the score unique across calls sharing global_set.
                    for _ in range(100000):
                        if score in global_set:
                            score -= 1 / 100000
                        else:
                            global_set.add(score)
                            break
                anno["score"].append(score)

                num_example += 1
            if num_example != 0:
                anno = {n: np.stack(v) for n, v in anno.items()}
                annos.append(anno)
            else:
                annos.append(kitti.empty_result_anno())
        else:
            annos.append(kitti.empty_result_anno())
        num_example = annos[-1]["name"].shape[0]
        annos[-1]["image_idx"] = np.array([img_idx] * num_example,
                                          dtype=np.int64)
    return annos, cls_losses_reduced
Exemple #6
0
def train(config_path,
          model_dir,
          result_path=None,
          create_folder=False,
          display_step=50,
          summary_step=5,
          pickle_result=True,
          patchs=None):
    """Train the fusion layer on top of a 3D detector and evaluate periodically.

    Only ``fusion_layer`` parameters are given to the optimizer; the detector
    ``net`` is built from a hard-coded config/checkpoint and produces the
    fusion inputs. After every ``train_cfg.steps_per_eval`` steps the fusion
    layer and optimizer are checkpointed and evaluated on the eval input
    reader with the official KITTI and COCO-style metrics.

    Args:
        config_path: path to a text-format ``TrainEvalPipelineConfig``.
        model_dir: directory for checkpoints, ``log.txt`` and summaries;
            created if missing.
        result_path: directory for per-step eval outputs; defaults to
            ``model_dir / 'results'``.
        create_folder: if True and ``model_dir`` exists, ``model_dir`` is
            replaced by the result of ``torchplus.train.create_folder``.
        display_step: interval (in global steps) at which the mean
            classification loss and learning rate are printed and logged.
        summary_step: accepted for interface compatibility; unused here.
        pickle_result: if True, eval detections are collected in memory (with
            the validation loss) and pickled to ``result.pkl``; otherwise they
            are written to files and read back for evaluation.
        patchs: accepted for interface compatibility; unused here.

    Raises:
        Re-raises any ``Exception`` from the training/eval loop after saving
        a checkpoint.
    """
    torch.manual_seed(3)
    np.random.seed(3)
    if create_folder:
        if pathlib.Path(model_dir).exists():
            model_dir = torchplus.train.create_folder(model_dir)
    patchs = patchs or []
    model_dir = pathlib.Path(model_dir)
    model_dir.mkdir(parents=True, exist_ok=True)
    if result_path is None:
        result_path = model_dir / 'results'
    # Accept str paths too: `/` is used on result_path below.
    result_path = pathlib.Path(result_path)
    config = pipeline_pb2.TrainEvalPipelineConfig()
    with open(config_path, "r") as f:
        proto_str = f.read()
        text_format.Merge(proto_str, config)
    input_cfg = config.train_input_reader
    eval_input_cfg = config.eval_input_reader
    model_cfg = config.model.second
    train_cfg = config.train_config
    detection_2d_path = config.train_config.detection_2d_path
    print("2d detection path:", detection_2d_path)
    center_limit_range = model_cfg.post_center_limit_range
    voxel_generator = voxel_builder.build(model_cfg.voxel_generator)
    bv_range = voxel_generator.point_cloud_range[[0, 1, 3, 4]]
    box_coder = box_coder_builder.build(model_cfg.box_coder)
    target_assigner_cfg = model_cfg.target_assigner
    target_assigner = target_assigner_builder.build(target_assigner_cfg,
                                                    bv_range, box_coder)
    class_names = target_assigner.classes
    net = build_inference_net('./configs/car.fhd.config', '../model_dir')
    fusion_layer = fusion.fusion()
    fusion_layer.cuda()
    optimizer_cfg = train_cfg.optimizer
    if train_cfg.enable_mixed_precision:
        net.half()
        net.metrics_to_float()
        net.convert_norm_to_float(net)
    loss_scale = train_cfg.loss_scale_factor
    mixed_optimizer = optimizer_builder.build(
        optimizer_cfg,
        fusion_layer,
        mixed=train_cfg.enable_mixed_precision,
        loss_scale=loss_scale)
    optimizer = mixed_optimizer
    # must restore optimizer AFTER using MixedPrecisionWrapper
    torchplus.train.try_restore_latest_checkpoints(model_dir,
                                                   [mixed_optimizer])
    lr_scheduler = lr_scheduler_builder.build(optimizer_cfg, optimizer,
                                              train_cfg.steps)
    if train_cfg.enable_mixed_precision:
        float_dtype = torch.float16
    else:
        float_dtype = torch.float32
    ######################
    # PREPARE INPUT
    ######################

    dataset = input_reader_builder.build(input_cfg,
                                         model_cfg,
                                         training=True,
                                         voxel_generator=voxel_generator,
                                         target_assigner=target_assigner)
    eval_dataset = input_reader_builder.build(
        eval_input_cfg,
        model_cfg,
        training=True,  # if running for test, this needs to be False
        voxel_generator=voxel_generator,
        target_assigner=target_assigner)

    def _worker_init_fn(worker_id):
        time_seed = np.array(time.time(), dtype=np.int32)
        np.random.seed(time_seed + worker_id)
        print(f"WORKER {worker_id} seed:", np.random.get_state()[1][0])

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=input_cfg.batch_size,
                                             shuffle=True,
                                             num_workers=input_cfg.num_workers,
                                             pin_memory=False,
                                             collate_fn=merge_second_batch,
                                             worker_init_fn=_worker_init_fn)

    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=eval_input_cfg.batch_size,
        shuffle=False,
        num_workers=eval_input_cfg.num_workers,
        pin_memory=False,
        collate_fn=merge_second_batch)

    data_iter = iter(dataloader)

    ######################
    # TRAINING
    ######################
    focal_loss = SigmoidFocalClassificationLoss()
    cls_loss_sum = 0.0
    training_detail = []
    log_path = model_dir / 'log.txt'
    training_detail_path = model_dir / 'log.json'
    if training_detail_path.exists():
        with open(training_detail_path, 'r') as f:
            training_detail = json.load(f)
    logf = open(log_path, 'a')
    logf.write(proto_str)
    logf.write("\n")
    summary_dir = model_dir / 'summary'
    summary_dir.mkdir(parents=True, exist_ok=True)
    writer = SummaryWriter(str(summary_dir))
    total_step_elapsed = 0
    ckpt_start_time = time.time()
    total_loop = train_cfg.steps // train_cfg.steps_per_eval + 1
    clear_metrics_every_epoch = train_cfg.clear_metrics_every_epoch
    net.set_global_step(torch.tensor([0]))
    if train_cfg.steps % train_cfg.steps_per_eval == 0:
        total_loop -= 1
    mixed_optimizer.zero_grad()
    try:
        for _ in range(total_loop):
            if total_step_elapsed + train_cfg.steps_per_eval > train_cfg.steps:
                steps = train_cfg.steps % train_cfg.steps_per_eval
            else:
                steps = train_cfg.steps_per_eval
            for step in range(steps):
                lr_scheduler.step(net.get_global_step())
                try:
                    example = next(data_iter)
                except StopIteration:
                    print("end epoch")
                    if clear_metrics_every_epoch:
                        net.clear_metrics()
                    data_iter = iter(dataloader)
                    example = next(data_iter)
                example_torch = example_convert_to_torch(example, float_dtype)
                # NOTE(review): gradients may also flow into `net` here even
                # though only fusion_layer is optimized — confirm intended.
                all_3d_output_camera_dict, all_3d_output, top_predictions, fusion_input, tensor_index = net(
                    example_torch, detection_2d_path)
                # Index 0 on the batch axis: assumes batch size 1.
                d3_gt_boxes = example_torch["d3_gt_boxes"][0, :, :]
                if d3_gt_boxes.shape[0] == 0:
                    # No ground truth: every anchor is a negative.
                    target_for_fusion = np.zeros((1, 70400, 1))
                    positives = torch.zeros(1,
                                            70400).type(torch.float32).cuda()
                    negatives = torch.ones(1,
                                           70400).type(torch.float32).cuda()
                else:
                    d3_gt_boxes_camera = box_torch_ops.box_lidar_to_camera(
                        d3_gt_boxes, example_torch['rect'][0, :],
                        example_torch['Trv2c'][0, :])
                    pred_3d_box = all_3d_output_camera_dict[0]["box3d_camera"]
                    iou_3d = d3_box_overlap(
                        d3_gt_boxes_camera.detach().cpu().numpy(),
                        pred_3d_box.squeeze().detach().cpu().numpy(),
                        criterion=-1)
                    iou_max = np.amax(iou_3d, axis=0)
                    # IoU >= 0.7 -> positive, IoU <= 0.5 -> negative; the
                    # rest get weight 0 below and are ignored.
                    positive_mask = (iou_max >= 0.7) * 1
                    target_for_fusion = positive_mask.reshape(1, -1, 1)
                    positives = torch.from_numpy(
                        positive_mask.reshape(1, -1)).type(
                            torch.float32).cuda()
                    negative_mask = ((iou_max <= 0.5) * 1).reshape(1, -1)
                    negatives = torch.from_numpy(negative_mask).type(
                        torch.float32).cuda()

                cls_preds, flag = fusion_layer(fusion_input.cuda(),
                                               tensor_index.cuda())
                one_hot_targets = torch.from_numpy(target_for_fusion).type(
                    torch.float32).cuda()

                cls_weights = negatives + positives
                # Normalize by the number of positives (at least 1).
                pos_normalizer = positives.sum(1, keepdim=True).type(
                    torch.float32)
                cls_weights /= torch.clamp(pos_normalizer, min=1.0)
                if flag == 1:
                    cls_losses = focal_loss._compute_loss(
                        cls_preds, one_hot_targets,
                        cls_weights.cuda())  # [N, M]
                    cls_losses_reduced = cls_losses.sum(
                    ) / example_torch['labels'].shape[0]
                    # Accumulate a detached float for logging so each step's
                    # autograd graph is not kept alive until the next print.
                    cls_loss_sum += cls_losses_reduced.item()
                    if train_cfg.enable_mixed_precision:
                        cls_losses_reduced = cls_losses_reduced * loss_scale
                    cls_losses_reduced.backward()
                    mixed_optimizer.step()
                    mixed_optimizer.zero_grad()
                net.update_global_step()
                global_step = net.get_global_step()
                if global_step % display_step == 0:
                    print("now it is",
                          global_step,
                          "steps",
                          " and the cls_loss is :",
                          cls_loss_sum / display_step,
                          "learning_rate: ",
                          float(optimizer.lr),
                          file=logf)
                    print("now it is", global_step, "steps",
                          " and the cls_loss is :",
                          cls_loss_sum / display_step, "learning_rate: ",
                          float(optimizer.lr))
                    cls_loss_sum = 0.0

                ckpt_elasped_time = time.time() - ckpt_start_time

                if ckpt_elasped_time > train_cfg.save_checkpoints_secs:
                    torchplus.train.save_models(model_dir,
                                                [fusion_layer, optimizer],
                                                net.get_global_step())

                    ckpt_start_time = time.time()

            total_step_elapsed += steps

            torchplus.train.save_models(model_dir, [fusion_layer, optimizer],
                                        net.get_global_step())

            fusion_layer.eval()
            net.eval()
            result_path_step = result_path / f"step_{net.get_global_step()}"
            result_path_step.mkdir(parents=True, exist_ok=True)
            print("#################################")
            print("#################################", file=logf)
            print("# EVAL")
            print("# EVAL", file=logf)
            print("#################################")
            print("#################################", file=logf)
            print("Generate output labels...")
            print("Generate output labels...", file=logf)
            t = time.time()
            dt_annos = []
            prog_bar = ProgressBar()
            net.clear_timer()
            prog_bar.start(
                (len(eval_dataset) + eval_input_cfg.batch_size - 1) //
                eval_input_cfg.batch_size)
            val_loss_final = 0
            # Evaluation needs no gradients; avoid building autograd graphs.
            with torch.no_grad():
                for example in iter(eval_dataloader):
                    example = example_convert_to_torch(example, float_dtype)
                    if pickle_result:
                        dt_annos_i, val_losses = predict_kitti_to_anno(
                            net, detection_2d_path, fusion_layer, example,
                            class_names, center_limit_range,
                            model_cfg.lidar_input)
                        dt_annos += dt_annos_i
                        val_loss_final = val_loss_final + val_losses
                    else:
                        _predict_kitti_to_file(net, detection_2d_path, example,
                                               result_path_step, class_names,
                                               center_limit_range,
                                               model_cfg.lidar_input)

                    prog_bar.print_bar()

            sec_per_ex = len(eval_dataset) / (time.time() - t)
            num_eval_batches = max(len(eval_dataloader), 1)
            print("validation_loss:", val_loss_final / num_eval_batches)
            print("validation_loss:",
                  val_loss_final / num_eval_batches,
                  file=logf)
            print(f'generate label finished({sec_per_ex:.2f}/s). start eval:')
            print(f'generate label finished({sec_per_ex:.2f}/s). start eval:',
                  file=logf)
            gt_annos = [
                info["annos"] for info in eval_dataset.dataset.kitti_infos
            ]
            if not pickle_result:
                dt_annos = kitti.get_label_annos(result_path_step)
            eval_step = net.get_global_step()
            result = get_official_eval_result(gt_annos, dt_annos, class_names)
            print(result, file=logf)
            print(result)
            writer.add_text('eval_result', json.dumps(result, indent=2),
                            eval_step)
            result = get_coco_eval_result(gt_annos, dt_annos, class_names)
            print(result, file=logf)
            print(result)
            if pickle_result:
                with open(result_path_step / "result.pkl", 'wb') as f:
                    pickle.dump(dt_annos, f)
            writer.add_text('eval_result', result, eval_step)
            fusion_layer.train()
    except Exception:
        # Save progress before propagating the error.
        torchplus.train.save_models(model_dir, [fusion_layer, optimizer],
                                    net.get_global_step())
        raise
    else:
        # save model before exit
        torchplus.train.save_models(model_dir, [fusion_layer, optimizer],
                                    net.get_global_step())
    finally:
        writer.close()
        logf.close()