Ejemplo n.º 1
0
 def evaluation(self, detections, label_dir, output_dir):
     """Dump detections under ``output_dir/data`` and score them against labels.

     The detections are handed to ``self.save_detections`` together with the
     tag and calib entries of ``self._kitti_infos``; the result is then read
     back with ``kitti.get_label_annos`` and compared with the ground truth
     found in ``label_dir``.

     Args:
         detections: one entry per item of ``self._kitti_infos``; more than
             50 entries are required (checked by an assertion).
         label_dir: directory containing the ground-truth label files.
         output_dir: parent directory; detections go to its ``data`` child.

     Returns:
         Whatever ``get_official_eval_result`` returns for the selected
         classes.
     """
     infos = self._kitti_infos
     tags = [info["tag"] for info in infos]
     calibs = [info["calib"] for info in infos]
     det_path = os.path.join(output_dir, "data")
     assert len(tags) == len(detections) == len(calibs)
     self.save_detections(detections, tags, calibs, det_path)
     assert len(detections) > 50
     dt_annos = kitti.get_label_annos(det_path)
     gt_path = os.path.join(label_dir)
     # The part of each file name before the first "." is the image id.
     val_image_ids = sorted(
         int(file_name.split(".")[0]) for file_name in os.listdir(det_path))
     gt_annos = kitti.get_label_annos(gt_path, val_image_ids)
     # Position in this tuple is the class index given to the eval function.
     class_order = ("Car", "Pedestrian", "Cyclist", "Van", "Truck", "Tram",
                    "Person_sitting")
     cls_to_idx = {cls_name: idx for idx, cls_name in enumerate(class_order)}
     current_classes = [cls_to_idx[cls_name] for cls_name in self._class_names]
     return get_official_eval_result(gt_annos, dt_annos, current_classes)
Ejemplo n.º 2
0
 def evaluation(self, detections, output_dir):
     """Score detections with the KITTI official evaluation only.

     When you evaluate your own dataset you MUST set the correct z axis and
     box z center. The KITTI eval function expects ground-truth annotations
     in this format:
     {
         bbox: [N, 4], if you fill fake data, MUST HAVE >25 HEIGHT!!!!!!
         alpha: [N], you can use -10 to ignore it.
         occluded: [N], you can use zero.
         truncated: [N], you can use zero.
         name: [N]
         location: [N, 3] center of 3d box.
         dimensions: [N, 3] dim of 3d box.
         rotation_y: [N] angle.
     }
     All fields must be filled, but some of them may be filled with zero.

     Args:
         detections: standard detections, converted to KITTI-format
             annotations by ``self.convert_detection_to_kitti_annos``.
         output_dir: not used by this method.

     Returns:
         None when the first info entry has no "annos" key; otherwise a dict
         with "results" and "detail" built from the official evaluation.
     """
     infos = self._vkitti_infos
     if "annos" not in infos[0]:
         return None
     gt_annos = [info["annos"] for info in infos]
     # Standard detections have to become KITTI-format annotations first.
     dt_annos = self.convert_detection_to_kitti_annos(detections)
     # KITTI camera format uses y as the regular "z" axis and the box center
     # is [0.5, 1, 0.5]; for regular raw lidar data use z_axis = 2,
     # z_center = 0.5.
     axis_kwargs = {"z_axis": 1, "z_center": 1.0}
     official = get_official_eval_result(gt_annos, dt_annos,
                                         self._class_names, **axis_kwargs)
     # COCO-style evaluation is deliberately not run here, so only the
     # "official" entries are reported.
     results = {"official": official["result"]}
     detail = {"eval.kitti": {"official": official["detail"]}}
     return {"results": results, "detail": detail}
Ejemplo n.º 3
0
 def evaluation(self, dt_annos):
     """Run the official and the COCO-style evaluation on ``dt_annos``.

     ``dt_annos`` must have the same format as
     ``self.ground_truth_annotations``. When you evaluate your own dataset
     you MUST set the correct z axis and box z center.

     Returns:
         ``(None, None)`` when no ground truth is available, otherwise the
         pair ``(official result, coco result)``.
     """
     gt_annos = self.ground_truth_annotations
     if gt_annos is None:
         return None, None
     eval_args = (gt_annos, dt_annos, self._class_names)
     # KITTI camera format uses y as the regular "z" axis and the box center
     # is [0.5, 1, 0.5]; for regular raw lidar data use z_axis = 2,
     # z_center = 0.5.
     axis_kwargs = {"z_axis": 1, "z_center": 1.0}
     result_official = get_official_eval_result(*eval_args, **axis_kwargs)
     result_coco = get_coco_eval_result(*eval_args, **axis_kwargs)
     return result_official, result_coco
Ejemplo n.º 4
0
 def evaluation(self, detections, output_dir):
     """Score detections with the official and the COCO-style evaluation.

     When you evaluate your own dataset you MUST set the correct z axis and
     box z center.

     Args:
         detections: standard detections, converted to KITTI-format
             annotations by ``self.convert_detection_to_kitti_annos``.
         output_dir: not used by this method.

     Returns:
         None when no ground truth is available; otherwise a dict with the
         "results" and the "detail" of both evaluations.
     """
     gt_annos = self.ground_truth_annotations
     if gt_annos is None:
         return None
     # Standard detections have to become KITTI-format annotations first.
     dt_annos = self.convert_detection_to_kitti_annos(detections)
     # KITTI camera format uses y as the regular "z" axis and the box center
     # is [0.5, 1, 0.5]; for regular raw lidar data use z_axis = 2,
     # z_center = 0.5.
     axis_kwargs = {"z_axis": 1, "z_center": 1.0}
     official = get_official_eval_result(gt_annos, dt_annos,
                                         self._class_names, **axis_kwargs)
     coco = get_coco_eval_result(gt_annos, dt_annos, self._class_names,
                                 **axis_kwargs)
     results = {"official": official["result"], "coco": coco["result"]}
     detail = {
         "eval.kitti": {
             "official": official["detail"],
             "coco": coco["detail"]
         }
     }
     return {"results": results, "detail": detail}
Ejemplo n.º 5
0
    def evaluation_kitti(self, detections, output_dir):
        """Evaluate detections with the KITTI evaluation tool (unofficial).

        num_lidar_pts is used to set the difficulty levels:
        easy: num>15, mod: num>7, hard: num>0.
        Every detection gets the placeholder 2D box [0, 0, 50, 50] and an
        alpha of -10, so the reported bbox AP is meaningless.

        Args:
            detections: iterable of dicts holding "box3d_lidar",
                "label_preds" and "scores" (objects supporting
                ``.detach().cpu().numpy()``) plus a "metadata" entry.
                Columns 0-2 of a box are used as location, 3-5 as
                dimensions and 6 as rotation_y.
            output_dir: not used by this method.

        Returns:
            None when no ground truth is available; otherwise a dict with
            "results" and "detail", each holding an "official" and a "coco"
            entry.
        """
        print("++++++++NuScenes KITTI unofficial Evaluation:")
        print(
            "++++++++easy: num_lidar_pts>15, mod: num_lidar_pts>7, hard: num_lidar_pts>0"
        )
        print("++++++++The bbox AP is invalid. Don't forget to ignore it.")
        class_names = self._class_names
        gt_annos = self.ground_truth_annotations
        if gt_annos is None:
            return None

        def _map_name(name):
            # Names without an entry in NameMapping pass through unchanged.
            mapping = self.NameMapping
            return mapping[name] if name in mapping else name

        # Work on copies: the "name" arrays are rewritten further down.
        gt_annos = deepcopy(gt_annos)
        detections = deepcopy(detections)
        dt_annos = []
        for det in detections:
            box3d_lidar = det["box3d_lidar"].detach().cpu().numpy()
            label_preds = det["label_preds"].detach().cpu().numpy()
            scores = det["scores"].detach().cpu().numpy()
            anno = kitti.get_start_result_anno()
            for j, box in enumerate(box3d_lidar):
                # Fake 2D box and alpha: only the 3D fields carry information.
                anno["bbox"].append(np.array([0, 0, 50, 50]))
                anno["alpha"].append(-10)
                anno["dimensions"].append(box[3:6])
                anno["location"].append(box[:3])
                anno["rotation_y"].append(box[6])
                anno["name"].append(class_names[int(label_preds[j])])
                anno["truncated"].append(0.0)
                anno["occluded"].append(0)
                anno["score"].append(scores[j])
            if box3d_lidar.shape[0] != 0:
                anno = {n: np.stack(v) for n, v in anno.items()}
            else:
                # np.stack cannot handle empty lists; use the empty template.
                anno = kitti.empty_result_anno()
            anno["metadata"] = det["metadata"]
            dt_annos.append(anno)

        # Bring ground truth, detections and the class list into the same
        # (mapped) naming scheme before they are compared.
        for annos in (gt_annos, dt_annos):
            for anno in annos:
                anno["name"] = np.array(
                    [_map_name(n) for n in anno["name"].tolist()])
        mapped_class_names = [_map_name(n) for n in self._class_names]

        # for regular raw lidar data, z_axis = 2, z_center = 0.5.
        z_axis = 2
        z_center = 0.5
        result_official_dict = get_official_eval_result(gt_annos,
                                                        dt_annos,
                                                        mapped_class_names,
                                                        z_axis=z_axis,
                                                        z_center=z_center)
        result_coco = get_coco_eval_result(gt_annos,
                                           dt_annos,
                                           mapped_class_names,
                                           z_axis=z_axis,
                                           z_center=z_center)
        return {
            "results": {
                "official": result_official_dict["result"],
                "coco": result_coco["result"],
            },
            "detail": {
                "official": result_official_dict["detail"],
                "coco": result_coco["detail"],
            },
        }
Ejemplo n.º 6
0
def train(config_path,
          model_dir,
          result_path=None,
          create_folder=False,
          display_step=50,
          summary_step=5,
          pickle_result=True):
    """Train a VoxelNet model specified by a config file.

    Training runs in rounds of ``train_cfg.steps_per_eval`` steps. After each
    round the model is checkpointed (also into ``eval_checkpoints``) and the
    whole evaluation set is scored with the official and the COCO-style
    evaluation; results go to stdout, ``log.txt`` and the summary writer.

    Args:
        config_path: path of a text-format ``TrainEvalPipelineConfig``; a
            copy is stored in ``model_dir`` as ``pipeline.config``.
        model_dir: directory for checkpoints, ``log.txt``, ``summary`` and
            ``eval_checkpoints``. Created when missing.
        result_path: directory for the per-step evaluation output, as a
            string or path object. Defaults to ``model_dir / 'results'``.
        create_folder: when true and ``model_dir`` already exists, the value
            returned by ``torchplus.train.create_folder`` is used instead.
        display_step: metrics are printed, logged and summarized whenever
            the global step is a multiple of this value.
        summary_step: currently not read by this function.
        pickle_result: when true, detections are collected in memory and
            pickled to ``result.pkl``; otherwise they are written to label
            files and read back with ``kitti.get_label_annos``.

    Raises:
        Exception: errors raised inside the training/evaluation loop are
            re-raised after a checkpoint has been saved.
    """
    if create_folder:
        if pathlib.Path(model_dir).exists():
            model_dir = torchplus.train.create_folder(model_dir)

    model_dir = pathlib.Path(model_dir)
    model_dir.mkdir(parents=True, exist_ok=True)
    eval_checkpoint_dir = model_dir / 'eval_checkpoints'
    eval_checkpoint_dir.mkdir(parents=True, exist_ok=True)
    if result_path is None:
        result_path = model_dir / 'results'
    else:
        # A plain string would break the "/" joins in the evaluation phase.
        result_path = pathlib.Path(result_path)
    config_file_bkp = "pipeline.config"
    config = pipeline_pb2.TrainEvalPipelineConfig()
    with open(config_path, "r") as f:
        proto_str = f.read()
        text_format.Merge(proto_str, config)
    shutil.copyfile(config_path, str(model_dir / config_file_bkp))
    input_cfg = config.train_input_reader
    eval_input_cfg = config.eval_input_reader
    model_cfg = config.model.second
    train_cfg = config.train_config

    class_names = list(input_cfg.class_names)
    ######################
    # BUILD VOXEL GENERATOR
    ######################
    voxel_generator = voxel_builder.build(model_cfg.voxel_generator)
    ######################
    # BUILD TARGET ASSIGNER
    ######################
    bv_range = voxel_generator.point_cloud_range[[0, 1, 3, 4]]
    box_coder = box_coder_builder.build(model_cfg.box_coder)
    target_assigner_cfg = model_cfg.target_assigner
    target_assigner = target_assigner_builder.build(target_assigner_cfg,
                                                    bv_range, box_coder)
    ######################
    # BUILD NET
    ######################
    center_limit_range = model_cfg.post_center_limit_range
    net = second_builder.build(model_cfg, voxel_generator, target_assigner)
    net.cuda()
    print("num_trainable parameters:", len(list(net.parameters())))
    ######################
    # BUILD OPTIMIZER
    ######################
    # we need global_step to create lr_scheduler, so restore net first.
    torchplus.train.try_restore_latest_checkpoints(model_dir, [net])
    gstep = net.get_global_step() - 1
    optimizer_cfg = train_cfg.optimizer
    if train_cfg.enable_mixed_precision:
        net.half()
        net.metrics_to_float()
        net.convert_norm_to_float(net)
    optimizer = optimizer_builder.build(optimizer_cfg, net.parameters())
    if train_cfg.enable_mixed_precision:
        loss_scale = train_cfg.loss_scale_factor
        mixed_optimizer = torchplus.train.MixedPrecisionWrapper(
            optimizer, loss_scale)
    else:
        mixed_optimizer = optimizer
    # must restore optimizer AFTER using MixedPrecisionWrapper
    torchplus.train.try_restore_latest_checkpoints(model_dir,
                                                   [mixed_optimizer])
    lr_scheduler = lr_scheduler_builder.build(optimizer_cfg, optimizer, gstep)
    if train_cfg.enable_mixed_precision:
        float_dtype = torch.float16
    else:
        float_dtype = torch.float32
    ######################
    # PREPARE INPUT
    ######################

    dataset = input_reader_builder.build(input_cfg,
                                         model_cfg,
                                         training=True,
                                         voxel_generator=voxel_generator,
                                         target_assigner=target_assigner)
    eval_dataset = input_reader_builder.build(eval_input_cfg,
                                              model_cfg,
                                              training=False,
                                              voxel_generator=voxel_generator,
                                              target_assigner=target_assigner)

    def _worker_init_fn(worker_id):
        # Seed numpy in every loader worker from the wall clock plus the
        # worker id, so the workers do not share one random stream.
        time_seed = np.array(time.time(), dtype=np.int32)
        np.random.seed(time_seed + worker_id)
        print(f"WORKER {worker_id} seed:", np.random.get_state()[1][0])

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=input_cfg.batch_size,
                                             shuffle=True,
                                             num_workers=input_cfg.num_workers,
                                             pin_memory=False,
                                             collate_fn=merge_second_batch,
                                             worker_init_fn=_worker_init_fn)
    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=eval_input_cfg.batch_size,
        shuffle=False,
        num_workers=eval_input_cfg.num_workers,
        pin_memory=False,
        collate_fn=merge_second_batch)
    data_iter = iter(dataloader)

    ######################
    # TRAINING
    ######################
    log_path = model_dir / 'log.txt'
    logf = open(log_path, 'a')
    logf.write(proto_str)
    logf.write("\n")
    summary_dir = model_dir / 'summary'
    summary_dir.mkdir(parents=True, exist_ok=True)
    writer = SummaryWriter(str(summary_dir))

    total_step_elapsed = 0
    t = time.time()
    ckpt_start_time = t

    # NOTE(review): the number of rounds is derived from train_cfg.steps, not
    # from the steps remaining after a checkpoint restore - confirm that this
    # is intended when resuming.
    total_loop = train_cfg.steps // train_cfg.steps_per_eval + 1
    clear_metrics_every_epoch = train_cfg.clear_metrics_every_epoch

    if train_cfg.steps % train_cfg.steps_per_eval == 0:
        total_loop -= 1
    mixed_optimizer.zero_grad()
    try:
        for _ in range(total_loop):
            # The last round only runs the leftover steps.
            if total_step_elapsed + train_cfg.steps_per_eval > train_cfg.steps:
                steps = train_cfg.steps % train_cfg.steps_per_eval
            else:
                steps = train_cfg.steps_per_eval
            for step in range(steps):
                lr_scheduler.step()
                try:
                    example = next(data_iter)
                except StopIteration:
                    # Epoch finished: restart the loader.
                    print("end epoch")
                    if clear_metrics_every_epoch:
                        net.clear_metrics()
                    data_iter = iter(dataloader)
                    example = next(data_iter)
                example_torch = example_convert_to_torch(example, float_dtype)

                batch_size = example["anchors"].shape[0]

                ret_dict = net(example_torch)

                cls_preds = ret_dict["cls_preds"]
                loss = ret_dict["loss"].mean()
                cls_loss_reduced = ret_dict["cls_loss_reduced"].mean()
                loc_loss_reduced = ret_dict["loc_loss_reduced"].mean()
                cls_pos_loss = ret_dict["cls_pos_loss"]
                cls_neg_loss = ret_dict["cls_neg_loss"]
                loc_loss = ret_dict["loc_loss"]
                dir_loss_reduced = ret_dict["dir_loss_reduced"]
                cared = ret_dict["cared"]
                labels = example_torch["labels"]
                if train_cfg.enable_mixed_precision:
                    loss *= loss_scale
                loss.backward()
                torch.nn.utils.clip_grad_norm_(net.parameters(), 10.0)
                mixed_optimizer.step()
                mixed_optimizer.zero_grad()
                net.update_global_step()
                net_metrics = net.update_metrics(cls_loss_reduced,
                                                 loc_loss_reduced, cls_preds,
                                                 labels, cared)

                step_time = (time.time() - t)
                t = time.time()
                metrics = {}
                # The counts below are taken from the first sample of the
                # batch only.
                num_pos = int((labels > 0)[0].float().sum().cpu().numpy())
                num_neg = int((labels == 0)[0].float().sum().cpu().numpy())
                if 'anchors_mask' not in example_torch:
                    num_anchors = example_torch['anchors'].shape[1]
                else:
                    num_anchors = int(example_torch['anchors_mask'][0].sum())
                global_step = net.get_global_step()
                if global_step % display_step == 0:
                    # Localization loss per box-code element, averaged over
                    # the batch.
                    loc_loss_elem = [
                        float(loc_loss[:, :, i].sum().detach().cpu().numpy() /
                              batch_size) for i in range(loc_loss.shape[-1])
                    ]
                    metrics["step"] = global_step
                    metrics["steptime"] = step_time
                    metrics.update(net_metrics)
                    metrics["loss"] = {}
                    metrics["loss"]["loc_elem"] = loc_loss_elem
                    metrics["loss"]["cls_pos_rt"] = float(
                        cls_pos_loss.detach().cpu().numpy())
                    metrics["loss"]["cls_neg_rt"] = float(
                        cls_neg_loss.detach().cpu().numpy())
                    if model_cfg.use_direction_classifier:
                        metrics["loss"]["dir_rt"] = float(
                            dir_loss_reduced.detach().cpu().numpy())
                    metrics["num_vox"] = int(example_torch["voxels"].shape[0])
                    metrics["num_pos"] = int(num_pos)
                    metrics["num_neg"] = int(num_neg)
                    metrics["num_anchors"] = int(num_anchors)
                    metrics["lr"] = float(
                        mixed_optimizer.param_groups[0]['lr'])
                    metrics["image_idx"] = example['image_idx'][0]
                    flatted_metrics = flat_nested_json_dict(metrics)
                    flatted_summarys = flat_nested_json_dict(metrics, "/")
                    for k, v in flatted_summarys.items():
                        if isinstance(v, (list, tuple)):
                            v = {str(i): e for i, e in enumerate(v)}
                            writer.add_scalars(k, v, global_step)
                        else:
                            writer.add_scalar(k, v, global_step)
                    metrics_str_list = []
                    for k, v in flatted_metrics.items():
                        if isinstance(v, float):
                            metrics_str_list.append(f"{k}={v:.3}")
                        elif isinstance(v, (list, tuple)):
                            if v and isinstance(v[0], float):
                                v_str = ', '.join([f"{e:.3}" for e in v])
                                metrics_str_list.append(f"{k}=[{v_str}]")
                            else:
                                metrics_str_list.append(f"{k}={v}")
                        else:
                            metrics_str_list.append(f"{k}={v}")
                    log_str = ', '.join(metrics_str_list)
                    print(log_str, file=logf)
                    print(log_str)
                # Time-based checkpointing, independent of the eval rounds.
                ckpt_elasped_time = time.time() - ckpt_start_time
                if ckpt_elasped_time > train_cfg.save_checkpoints_secs:
                    torchplus.train.save_models(model_dir, [net, optimizer],
                                                net.get_global_step())
                    ckpt_start_time = time.time()
            total_step_elapsed += steps
            torchplus.train.save_models(model_dir, [net, optimizer],
                                        net.get_global_step())

            # Ensure that all evaluation points are saved forever
            torchplus.train.save_models(eval_checkpoint_dir, [net, optimizer],
                                        net.get_global_step(),
                                        max_to_keep=100)

            net.eval()
            result_path_step = result_path / f"step_{net.get_global_step()}"
            result_path_step.mkdir(parents=True, exist_ok=True)
            print("#################################")
            print("#################################", file=logf)
            print("# EVAL")
            print("# EVAL", file=logf)
            print("#################################")
            print("#################################", file=logf)
            print("Generate output labels...")
            print("Generate output labels...", file=logf)
            t = time.time()
            dt_annos = []
            prog_bar = ProgressBar()
            prog_bar.start(len(eval_dataset) // eval_input_cfg.batch_size + 1)
            for example in iter(eval_dataloader):
                example = example_convert_to_torch(example, float_dtype)
                if pickle_result:
                    dt_annos += predict_kitti_to_anno(net, example,
                                                      class_names,
                                                      center_limit_range,
                                                      model_cfg.lidar_input)
                else:
                    _predict_kitti_to_file(net, example, result_path_step,
                                           class_names, center_limit_range,
                                           model_cfg.lidar_input)

                prog_bar.print_bar()

            sec_per_ex = len(eval_dataset) / (time.time() - t)
            print(f"avg forward time per example: {net.avg_forward_time:.3f}")
            print(
                f"avg postprocess time per example: {net.avg_postprocess_time:.3f}"
            )

            net.clear_time_metrics()
            print(f'generate label finished({sec_per_ex:.2f}/s). start eval:')
            print(f'generate label finished({sec_per_ex:.2f}/s). start eval:',
                  file=logf)
            gt_annos = [
                info["annos"] for info in eval_dataset.dataset.kitti_infos
            ]
            if not pickle_result:
                dt_annos = kitti.get_label_annos(result_path_step)
            result, mAPbbox, mAPbev, mAP3d, mAPaos = get_official_eval_result(
                gt_annos, dt_annos, class_names, return_data=True)
            print(result, file=logf)
            print(result)
            writer.add_text('eval_result', result, global_step)

            # Index [i, 1, 0] is what gets summarized per class
            # (presumably difficulty 1 at the first overlap setting - verify
            # against get_official_eval_result).
            for i, class_name in enumerate(class_names):
                writer.add_scalar('bev_ap:{}'.format(class_name),
                                  mAPbev[i, 1, 0], global_step)
                writer.add_scalar('3d_ap:{}'.format(class_name),
                                  mAP3d[i, 1, 0], global_step)
                writer.add_scalar('aos_ap:{}'.format(class_name),
                                  mAPaos[i, 1, 0], global_step)
            writer.add_scalar('bev_map', np.mean(mAPbev[:, 1, 0]), global_step)
            writer.add_scalar('3d_map', np.mean(mAP3d[:, 1, 0]), global_step)
            writer.add_scalar('aos_map', np.mean(mAPaos[:, 1, 0]), global_step)

            result = get_coco_eval_result(gt_annos, dt_annos, class_names)
            print(result, file=logf)
            print(result)
            if pickle_result:
                with open(result_path_step / "result.pkl", 'wb') as f:
                    pickle.dump(dt_annos, f)
            writer.add_text('eval_result', result, global_step)
            net.train()
    except Exception:
        # Keep the progress made so far before the error propagates.
        torchplus.train.save_models(model_dir, [net, optimizer],
                                    net.get_global_step())
        raise
    else:
        # save model before exit
        torchplus.train.save_models(model_dir, [net, optimizer],
                                    net.get_global_step())
    finally:
        # Close the log on every exit path, including interrupts.
        logf.close()
Ejemplo n.º 7
0
def evaluate(config_path,
             model_dir,
             result_path=None,
             predict_test=False,
             ckpt_path=None,
             ref_detfile=None,
             pickle_result=True):
    """Run a restored model over the evaluation set and print its metrics.

    Args:
        config_path: path of a text-format ``TrainEvalPipelineConfig``.
        model_dir: directory searched for the latest checkpoint when
            ``ckpt_path`` is None; also the parent of the default result
            directory.
        result_path: output directory. Defaults to ``model_dir/predict_test``
            when ``predict_test`` is true, else ``model_dir/eval_results``.
        predict_test: when true, only predictions are generated; the
            comparison with ground truth (and the pickle dump) is skipped.
        ckpt_path: explicit checkpoint to restore instead of the latest one.
        ref_detfile: not read by this function.
        pickle_result: when true, detections are collected in memory and
            pickled to ``result.pkl``; otherwise they are written to label
            files and read back with ``kitti.get_label_annos``.
    """
    model_dir = pathlib.Path(model_dir)
    default_subdir = 'predict_test' if predict_test else 'eval_results'
    if result_path is not None:
        result_path = pathlib.Path(result_path)
    else:
        result_path = model_dir / default_subdir

    # Parse the text-format pipeline configuration.
    config = pipeline_pb2.TrainEvalPipelineConfig()
    with open(config_path, "r") as cfg_file:
        text_format.Merge(cfg_file.read(), config)

    input_cfg = config.eval_input_reader
    model_cfg = config.model.second
    train_cfg = config.train_config
    class_names = list(input_cfg.class_names)
    center_limit_range = model_cfg.post_center_limit_range

    # Voxel generator, box coder and target assigner are needed both by the
    # network and by the dataset.
    voxel_generator = voxel_builder.build(model_cfg.voxel_generator)
    bv_range = voxel_generator.point_cloud_range[[0, 1, 3, 4]]
    box_coder = box_coder_builder.build(model_cfg.box_coder)
    target_assigner = target_assigner_builder.build(
        model_cfg.target_assigner, bv_range, box_coder)

    net = second_builder.build(model_cfg, voxel_generator, target_assigner)
    net.cuda()
    use_half = train_cfg.enable_mixed_precision
    if use_half:
        net.half()
        net.metrics_to_float()
        net.convert_norm_to_float(net)

    if ckpt_path is not None:
        torchplus.train.restore(ckpt_path, net)
    else:
        torchplus.train.try_restore_latest_checkpoints(model_dir, [net])

    eval_dataset = input_reader_builder.build(input_cfg,
                                              model_cfg,
                                              training=False,
                                              voxel_generator=voxel_generator,
                                              target_assigner=target_assigner)
    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=input_cfg.batch_size,
        shuffle=False,
        num_workers=input_cfg.num_workers,
        pin_memory=False,
        collate_fn=merge_second_batch)
    float_dtype = torch.float16 if use_half else torch.float32

    net.eval()
    step_dir = result_path / f"step_{net.get_global_step()}"
    step_dir.mkdir(parents=True, exist_ok=True)
    start_time = time.time()
    dt_annos = []
    print("Generate output labels...")
    progress = ProgressBar()
    progress.start(len(eval_dataset) // input_cfg.batch_size + 1)

    for batch in eval_dataloader:
        batch = example_convert_to_torch(batch, float_dtype)
        if pickle_result:
            # Keep the annotations in memory; the last argument (global_set)
            # is always None here.
            dt_annos.extend(
                predict_kitti_to_anno(net, batch, class_names,
                                      center_limit_range,
                                      model_cfg.lidar_input, None))
        else:
            _predict_kitti_to_file(net, batch, step_dir, class_names,
                                   center_limit_range, model_cfg.lidar_input)
        progress.print_bar()

    examples_per_sec = len(eval_dataset) / (time.time() - start_time)
    print(f'generate label finished({examples_per_sec:.2f}/s). start eval:')

    print(f"avg forward time per example: {net.avg_forward_time:.3f}")
    print(f"avg postprocess time per example: {net.avg_postprocess_time:.3f}")
    if predict_test:
        # No ground truth to compare against: predictions only.
        return

    gt_annos = [info["annos"] for info in eval_dataset.dataset.kitti_infos]
    if not pickle_result:
        dt_annos = kitti.get_label_annos(step_dir)
    print(get_official_eval_result(gt_annos, dt_annos, class_names))
    print(get_coco_eval_result(gt_annos, dt_annos, class_names))
    if pickle_result:
        with open(step_dir / "result.pkl", 'wb') as pkl_file:
            pickle.dump(dt_annos, pkl_file)
def evaluate(config_path,
             model_dir,
             use_second_stage=False,
             use_endtoend=False,
             result_path=None,
             predict_test=False,
             ckpt_path=None,
             ref_detfile=None,
             pickle_result=True,
             measure_time=False,
             batch_size=None):
    """Run a restored network over the eval input reader and export detections.

    The pipeline config at ``config_path`` is parsed, the network is built and
    moved to CUDA, weights are restored, and every batch of the eval dataset is
    pushed through the network. Unless ``predict_test`` is set, the detections
    are scored with ``get_official_eval_result`` and ``get_coco_eval_result``
    and the reports are printed. In both modes each in-memory detection
    annotation is written to ``<result_path>/step_<global_step>/txt/<image_idx>.txt``.

    Args:
        config_path: Path to the text-format ``TrainEvalPipelineConfig``.
        model_dir: Directory searched for the latest checkpoint when
            ``ckpt_path`` is None; also the parent of the default result dir.
        use_second_stage: Build the network with ``second_2stage_builder``.
        use_endtoend: Build the network with ``second_endtoend_builder``
            (only consulted when ``use_second_stage`` is false).
        result_path: Output directory. Defaults to ``model_dir/eval_results``
            or ``model_dir/predict_test_0095`` when ``predict_test`` is set.
        predict_test: Skip the ground-truth evaluation and only export
            detections.
        ckpt_path: Explicit checkpoint to restore instead of the latest one
            in ``model_dir``.
        ref_detfile: Unused; kept for interface compatibility.
        pickle_result: If true, collect annotations in memory (and pickle them
            to ``result.pkl`` when evaluating); otherwise detections are
            written to label files by ``_predict_kitti_to_file``.
        measure_time: Forwarded to the network builder; also enables the
            per-batch data-prep timing printed at the end.
        batch_size: Overrides ``eval_input_reader.batch_size`` when truthy.

    Returns:
        None. Results are printed and written to disk.
    """
    model_dir = pathlib.Path(model_dir)
    result_name = 'predict_test_0095' if predict_test else 'eval_results'
    if result_path is None:
        result_path = model_dir / result_name
    else:
        result_path = pathlib.Path(result_path)
    config = pipeline_pb2.TrainEvalPipelineConfig()
    with open(config_path, "r") as f:
        proto_str = f.read()
        text_format.Merge(proto_str, config)

    input_cfg = config.eval_input_reader
    model_cfg = config.model.second
    train_cfg = config.train_config

    center_limit_range = model_cfg.post_center_limit_range
    ######################
    # BUILD VOXEL GENERATOR
    ######################
    voxel_generator = voxel_builder.build(model_cfg.voxel_generator)
    bv_range = voxel_generator.point_cloud_range[[0, 1, 3, 4]]
    box_coder = box_coder_builder.build(model_cfg.box_coder)
    target_assigner_cfg = model_cfg.target_assigner
    target_assigner = target_assigner_builder.build(target_assigner_cfg,
                                                    bv_range, box_coder)
    class_names = target_assigner.classes
    if use_second_stage:
        net = second_2stage_builder.build(model_cfg, voxel_generator, target_assigner, measure_time=measure_time)
    elif use_endtoend:
        net = second_endtoend_builder.build(model_cfg, voxel_generator, target_assigner, measure_time=measure_time)
    else:
        net = second_builder.build(model_cfg, voxel_generator, target_assigner, measure_time=measure_time)
    net.cuda()
    if ckpt_path is None:
        torchplus.train.try_restore_latest_checkpoints(model_dir, [net])
    else:
        torchplus.train.restore(ckpt_path, net)
    if train_cfg.enable_mixed_precision:
        net.half()
        print("half inference!")
        net.metrics_to_float()
        net.convert_norm_to_float(net)
    batch_size = batch_size or input_cfg.batch_size
    eval_dataset = input_reader_builder_tr.build(
        input_cfg,
        model_cfg,
        training=False,
        voxel_generator=voxel_generator,
        target_assigner=target_assigner)
    # Data is loaded in the main process (num_workers=0) rather than with
    # input_cfg.num_workers.
    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=False,
        collate_fn=merge_second_batch)

    if train_cfg.enable_mixed_precision:
        float_dtype = torch.float16
    else:
        float_dtype = torch.float32

    net.eval()
    result_path_step = result_path / f"step_{net.get_global_step()}"
    result_path_step.mkdir(parents=True, exist_ok=True)
    t = time.time()
    dt_annos = []
    global_set = None
    print("Generate output labels...")
    bar = ProgressBar()
    # Ceil-division: number of batches the loader will yield.
    bar.start((len(eval_dataset) + batch_size - 1) // batch_size)
    prep_example_times = []
    prep_times = []
    t2 = time.time()
    for example in iter(eval_dataloader):
        if measure_time:
            # Time spent waiting on the data loader since the previous batch.
            prep_times.append(time.time() - t2)
            # NOTE(review): the timer starts before the first synchronize, so
            # any pending GPU work is counted as conversion time - confirm
            # this is intended.
            t1 = time.time()
            torch.cuda.synchronize()
        example = example_convert_to_torch(example, float_dtype)
        if measure_time:
            torch.cuda.synchronize()
            prep_example_times.append(time.time() - t1)

        if pickle_result:
            dt_annos += predict_kitti_to_anno(
                net, example, class_names, center_limit_range,
                model_cfg.lidar_input, global_set)
        else:
            _predict_kitti_to_file(net, example, result_path_step, class_names,
                                   center_limit_range, model_cfg.lidar_input)
        bar.print_bar()
        if measure_time:
            t2 = time.time()

    examples_per_sec = len(eval_dataset) / (time.time() - t)
    print(f'generate label finished({examples_per_sec:.2f}/s). start eval:')
    if measure_time:
        print(f"avg example to torch time: {np.mean(prep_example_times) * 1000:.3f} ms")
        print(f"avg prep time: {np.mean(prep_times) * 1000:.3f} ms")
    for name, val in net.get_avg_time_dict().items():
        print(f"avg {name} time = {val * 1000:.3f} ms")

    kitti_infos = eval_dataset.dataset.kitti_infos
    if not predict_test:
        gt_annos = [info["annos"] for info in kitti_infos]
        if not pickle_result:
            # Detections were written to label files; read them back.
            dt_annos = kitti.get_label_annos(result_path_step)
        result = get_official_eval_result(gt_annos, dt_annos, class_names)
        print(result)
        result = get_coco_eval_result(gt_annos, dt_annos, class_names)
        print(result)
        if pickle_result:
            with open(result_path_step / "result.pkl", 'wb') as f:
                pickle.dump(dt_annos, f)

    # Export every in-memory annotation as a KITTI label txt file. This runs
    # in both modes and must stay after the pickle dump above because it
    # rewrites 'dimensions' in place.
    txt_dir = result_path_step / 'txt'
    txt_dir.mkdir(parents=True, exist_ok=True)
    img_idx = [info["image_idx"] for info in kitti_infos]
    for i, anno in enumerate(dt_annos):
        # Permute the dimension columns (0, 1, 2) -> (1, 2, 0) before
        # formatting; presumably the layout annos_to_kitti_label expects.
        anno['dimensions'] = anno['dimensions'][:, [1, 2, 0]]
        result_lines = kitti.annos_to_kitti_label(anno)
        with open(txt_dir / ('%06d.txt' % img_idx[i]), 'w') as f:
            f.writelines(result_line + '\n' for result_line in result_lines)
def train(config_path,
          model_dir,
          use_fusion=True,
          use_ft=False,
          use_second_stage=True,
          use_endtoend=True,
          result_path=None,
          create_folder=False,
          display_step=50,
          summary_step=5,
          local_rank=0,
          pickle_result=True,
          patchs=None):
    """Train a VoxelNet model specified by a config file.

    Besides the detector pipeline config at ``config_path``, a tracking YAML
    config is read from a hard-coded path; its ``common`` section supplies the
    tracking criterion, the augmentations, ``det_type`` and ``val_freq``.
    Pretrained weights are loaded from a hard-coded checkpoint, after which
    the latest checkpoint in ``model_dir`` (if any) is restored on top.

    Every ``config_tr.val_freq`` batches the model is checkpointed and run
    over the eval dataset; the report from ``get_official_eval_result`` is
    printed, logged, and partly written to TensorBoard.

    Args:
        config_path: Path to the text-format ``TrainEvalPipelineConfig``.
        model_dir: Directory for checkpoints, logs and summaries.
        use_fusion: Unused; kept for interface compatibility.
        use_ft: Unused; kept for interface compatibility.
        use_second_stage: Together with ``use_endtoend``, selects whether the
            labels come from the network output or from the input example.
        use_endtoend: Build the network with
            ``second_endtoend_builder_spatio`` instead of ``second_builder``.
        result_path: Directory for evaluation output. Defaults to
            ``model_dir/results``.
        create_folder: If true and ``model_dir`` exists, ask
            ``torchplus.train.create_folder`` for a new directory.
        display_step: Log training metrics whenever the global step is a
            multiple of this value.
        summary_step: Unused; kept for interface compatibility.
        local_rank: Unused; kept for interface compatibility.
        pickle_result: If true, collect eval annotations in memory and pickle
            them; otherwise they are written to label files and read back.
        patchs: Optional list of statements executed as ``config.<patch>``
            to override config fields.

    Returns:
        None.
    """
    ############ tracking
    config_tr_path = '/mnt/new_iou/second.pytorch/second/mmMOT/experiments/second/spatio_test/config.yaml'
    with open(config_tr_path) as f:
        config_tr = yaml.load(f, Loader=yaml.FullLoader)

    config_tr = EasyDict(config_tr['common'])
    config_tr.save_path = os.path.dirname(config_tr_path)

    criterion_tr = build_criterion(config_tr.loss)

    cudnn.benchmark = True

    # Data loading code
    train_transform, valid_transform = build_augmentation(config_tr.augmentation)

    # Not written to below; kept because constructing it creates the events
    # directory next to the tracking config.
    tb_logger = SummaryWriter(config_tr.save_path + '/events')
    logger = create_logger('global_logger', config_tr.save_path + '/log.txt')
    logger.info('config: {}'.format(pprint.pformat(config_tr)))
    #### tracking setup done

    if create_folder and pathlib.Path(model_dir).exists():
        model_dir = torchplus.train.create_folder(model_dir)
    patchs = patchs or []
    model_dir = pathlib.Path(model_dir)
    model_dir.mkdir(parents=True, exist_ok=True)
    if result_path is None:
        result_path = model_dir / 'results'
    else:
        # Accept plain strings too: result_path is joined with "/" below.
        result_path = pathlib.Path(result_path)
    config_file_bkp = "pipeline.config"
    config = pipeline_pb2.TrainEvalPipelineConfig()
    with open(config_path, "r") as f:
        proto_str = f.read()
        text_format.Merge(proto_str, config)
    # NOTE(security): each patch is executed as Python code; only pass
    # trusted strings in `patchs`.
    for patch in patchs:
        patch = "config." + patch
        exec(patch)
    shutil.copyfile(config_path, str(model_dir / config_file_bkp))
    input_cfg = config.train_input_reader
    eval_input_cfg = config.eval_input_reader
    model_cfg = config.model.second
    train_cfg = config.train_config

    ######################
    # BUILD VOXEL GENERATOR
    ######################
    voxel_generator = voxel_builder.build(model_cfg.voxel_generator)
    ######################
    # BUILD TARGET ASSIGNER
    ######################
    bv_range = voxel_generator.point_cloud_range[[0, 1, 3, 4]]
    box_coder = box_coder_builder.build(model_cfg.box_coder)
    target_assigner_cfg = model_cfg.target_assigner
    target_assigner = target_assigner_builder.build(target_assigner_cfg,
                                                    bv_range, box_coder)
    class_names = target_assigner.classes
    ######################
    # BUILD NET
    ######################
    center_limit_range = model_cfg.post_center_limit_range
    if use_endtoend:
        net = second_endtoend_builder_spatio.build(model_cfg, voxel_generator, target_assigner, criterion_tr, config_tr.det_type)
    else:
        net = second_builder.build(model_cfg, voxel_generator, target_assigner)
    net.cuda()
    print("num_trainable parameters:", len(list(net.parameters())))

    for n, p in net.named_parameters():
        print(n, p.shape)
    pth_name = './pre_weight/second_stage_gating_det/voxelnet-35000.tckpt'

    # Overlay the pretrained weights on the freshly built state dict, leaving
    # out any entry whose key mentions 'global_step'.
    res_pre_weights = torch.load(pth_name)
    new_res_state_dict = OrderedDict(
        (k, v) for k, v in res_pre_weights.items() if 'global_step' not in k)
    model_dict = net.state_dict()
    model_dict.update(new_res_state_dict)
    net.load_state_dict(model_dict)

    # BUILD OPTIMIZER
    #####################
    # we need global_step to create lr_scheduler, so restore net first.
    torchplus.train.try_restore_latest_checkpoints(model_dir, [net])
    optimizer_cfg = train_cfg.optimizer
    if train_cfg.enable_mixed_precision:
        net.half()
        net.metrics_to_float()
        net.convert_norm_to_float(net)
    loss_scale = train_cfg.loss_scale_factor
    mixed_optimizer = optimizer_builder.build(optimizer_cfg, net, mixed=train_cfg.enable_mixed_precision, loss_scale=loss_scale)
    optimizer = mixed_optimizer

    # must restore optimizer AFTER using MixedPrecisionWrapper
    torchplus.train.try_restore_latest_checkpoints(model_dir,
                                                   [mixed_optimizer])
    lr_scheduler = lr_scheduler_builder.build(optimizer_cfg, optimizer, train_cfg.steps)
    if train_cfg.enable_mixed_precision:
        float_dtype = torch.float16
    else:
        float_dtype = torch.float32
    ######################
    # PREPARE INPUT
    ######################
    dataset = input_reader_builder_tr_vid_spatio.build(
        input_cfg,
        model_cfg,
        training=True,
        voxel_generator=voxel_generator,
        target_assigner=target_assigner,
        config_tr=config_tr,
        set_source='train',
        evaluate=False,
        train_transform=train_transform)
    eval_dataset = input_reader_builder_tr_vid_spatio.build(
        eval_input_cfg,
        model_cfg,
        training=False,
        voxel_generator=voxel_generator,
        target_assigner=target_assigner,
        config_tr=config_tr,
        set_source='val',
        evaluate=True,
        valid_transform=valid_transform)

    def _worker_init_fn(worker_id):
        # Seed numpy per worker from the wall clock plus the worker id.
        time_seed = np.array(time.time(), dtype=np.int32)
        np.random.seed(time_seed + worker_id)
        print(f"WORKER {worker_id} seed:", np.random.get_state()[1][0])

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=input_cfg.batch_size,
        shuffle=True,
        num_workers=input_cfg.num_workers,
        pin_memory=False,
        collate_fn=merge_second_batch_tr_vid_spatio,
        worker_init_fn=_worker_init_fn)

    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=eval_input_cfg.batch_size,
        shuffle=False,
        num_workers=eval_input_cfg.num_workers,
        pin_memory=False,
        collate_fn=merge_second_batch_tr_vid_spatio)

    ######################
    # TRAINING
    ######################
    training_detail = []
    log_path = model_dir / 'log.txt'
    training_detail_path = model_dir / 'log.json'
    if training_detail_path.exists():
        with open(training_detail_path, 'r') as f:
            training_detail = json.load(f)

    # The context manager closes the log file even when training raises.
    with open(log_path, 'a') as logf:
        logf.write(proto_str)
        logf.write("\n")
        summary_dir = model_dir / 'summary'
        summary_dir.mkdir(parents=True, exist_ok=True)
        writer = SummaryWriter(str(summary_dir))

        t = time.time()
        ckpt_start_time = t

        mixed_optimizer.zero_grad()
        logger = logging.getLogger('global_logger')

        # Number of full passes over the dataloader that fit in the budget.
        total_loop = train_cfg.steps // len(dataloader)
        emh = ['0_easy', '1_mod', '2_hard']

        batches_seen = 0
        for _ in range(total_loop):
            for example in dataloader:
                batches_seen += 1
                lr_scheduler.step(net.get_global_step())

                example_torch = example_convert_to_torch(example, float_dtype)

                batch_size = example["anchors"].shape[0]

                ret_dict = net(example_torch, train_param=True)

                cls_preds = ret_dict["cls_preds"]
                loss = ret_dict["loss"].mean()
                cls_loss_reduced = ret_dict["cls_loss_reduced"].mean()
                loc_loss_reduced = ret_dict["loc_loss_reduced"].mean()
                cls_pos_loss = ret_dict["cls_pos_loss"]
                cls_neg_loss = ret_dict["cls_neg_loss"]
                loc_loss = ret_dict["loc_loss"]
                dir_loss_reduced = ret_dict["dir_loss_reduced"]
                cared = ret_dict["cared"]

                if use_second_stage or use_endtoend:
                    labels = ret_dict["labels"]
                else:
                    labels = example_torch["labels"]
                if train_cfg.enable_mixed_precision:
                    loss *= loss_scale

                try:
                    loss.backward()
                except Exception:
                    # Training keeps going after a failed backward pass (as
                    # before), but the failure is now logged, and
                    # KeyboardInterrupt/SystemExit are no longer swallowed.
                    logger.exception(
                        "loss.backward() failed at global step %s; "
                        "continuing with the optimizer step",
                        net.get_global_step())
                mixed_optimizer.step()
                mixed_optimizer.zero_grad()
                net.update_global_step()
                net_metrics = net.update_metrics(cls_loss_reduced,
                                                 loc_loss_reduced, cls_preds,
                                                 labels, cared)

                step_time = (time.time() - t)
                t = time.time()
                global_step = net.get_global_step()
                if global_step % display_step == 0:
                    # Counts are taken from the first sample of the batch.
                    num_pos = int((labels > 0)[0].float().sum().cpu().numpy())
                    num_neg = int((labels == 0)[0].float().sum().cpu().numpy())
                    if 'anchors_mask' not in example_torch:
                        num_anchors = example_torch['anchors'].shape[1]
                    else:
                        num_anchors = int(example_torch['anchors_mask'][0].sum())
                    loc_loss_elem = [
                        float(loc_loss[:, :, j].sum().detach().cpu().numpy() /
                              batch_size) for j in range(loc_loss.shape[-1])
                    ]
                    metrics = {}
                    metrics["type"] = "step_info"
                    metrics["step"] = global_step
                    metrics["steptime"] = step_time
                    metrics.update(net_metrics)
                    metrics["loss"] = {}
                    metrics["loss"]["loc_elem"] = loc_loss_elem
                    metrics["loss"]["cls_pos_rt"] = float(
                        cls_pos_loss.detach().cpu().numpy())
                    metrics["loss"]["cls_neg_rt"] = float(
                        cls_neg_loss.detach().cpu().numpy())
                    if model_cfg.use_direction_classifier:
                        metrics["loss"]["dir_rt"] = float(
                            dir_loss_reduced.detach().cpu().numpy())
                    metrics["num_vox"] = int(example_torch["voxels"].shape[0])
                    metrics["num_pos"] = int(num_pos)
                    metrics["num_neg"] = int(num_neg)
                    metrics["num_anchors"] = int(num_anchors)
                    metrics["lr"] = float(
                        optimizer.lr)

                    metrics["image_idx"] = example['image_idx'][0][7:]
                    training_detail.append(metrics)
                    flatted_metrics = flat_nested_json_dict(metrics)
                    flatted_summarys = flat_nested_json_dict(metrics, "/")
                    for k, v in flatted_summarys.items():
                        # Per-element localisation losses are not plotted.
                        if 'loc_elem' in k:
                            continue
                        if isinstance(v, (list, tuple)):
                            scalars = {str(idx): e for idx, e in enumerate(v)}
                            writer.add_scalars(k, scalars, global_step)
                        elif not isinstance(v, str):
                            writer.add_scalar(k, v, global_step)

                    metrics_str_list = []
                    for k, v in flatted_metrics.items():
                        if isinstance(v, float):
                            metrics_str_list.append(f"{k}={v:.3}")
                        elif isinstance(v, (list, tuple)):
                            if v and isinstance(v[0], float):
                                v_str = ', '.join([f"{e:.3}" for e in v])
                                metrics_str_list.append(f"{k}=[{v_str}]")
                            else:
                                metrics_str_list.append(f"{k}={v}")
                        else:
                            metrics_str_list.append(f"{k}={v}")
                    log_str = ', '.join(metrics_str_list)
                    print(log_str, file=logf)
                    print(log_str)

                ckpt_elasped_time = time.time() - ckpt_start_time
                if ckpt_elasped_time > train_cfg.save_checkpoints_secs:
                    torchplus.train.save_models(model_dir, [net, optimizer], net.get_global_step())
                    ckpt_start_time = time.time()

                if batches_seen % config_tr.val_freq == 0:
                    torchplus.train.save_models(model_dir, [net, optimizer], net.get_global_step())
                    net.eval()
                    result_path_step = result_path / f"step_{net.get_global_step()}"
                    result_path_step.mkdir(parents=True, exist_ok=True)
                    print("#################################")
                    print("#################################", file=logf)
                    print("# EVAL")
                    print("# EVAL", file=logf)
                    print("#################################")
                    print("#################################", file=logf)
                    print("Generate output labels...")
                    print("Generate output labels...", file=logf)
                    t = time.time()
                    dt_annos = []
                    prog_bar = ProgressBar()
                    net.clear_timer()
                    prog_bar.start((len(eval_dataset) + eval_input_cfg.batch_size - 1) // eval_input_cfg.batch_size)
                    for eval_example in iter(eval_dataloader):
                        eval_example = example_convert_to_torch(eval_example, float_dtype)
                        if pickle_result:
                            dt_annos += predict_kitti_to_anno(
                                net, eval_example, class_names, center_limit_range,
                                model_cfg.lidar_input)
                        else:
                            _predict_kitti_to_file(net, eval_example, result_path_step,
                                                   class_names, center_limit_range,
                                                   model_cfg.lidar_input)
                        prog_bar.print_bar()

                    sec_per_ex = len(eval_dataset) / (time.time() - t)
                    print(f'generate label finished({sec_per_ex:.2f}/s). start eval:')
                    print(f'generate label finished({sec_per_ex:.2f}/s). start eval:',file=logf)
                    gt_annos = [
                        info["annos"] for info in eval_dataset.dataset.kitti_infos
                    ]
                    if not pickle_result:
                        dt_annos = kitti.get_label_annos(result_path_step)
                    result = get_official_eval_result(gt_annos, dt_annos, class_names)
                    print(result, file=logf)
                    print(result)
                    # Lines 2-5 of the report are parsed as
                    # "<name> ...:<v0>,<v1>,<v2>" and the three values are
                    # logged under the easy/moderate/hard tags.
                    for report_line in result.split("\n")[1:5]:
                        fields = report_line.split(':')
                        name_val = fields[0].split(' ')[0]
                        values = fields[1].split(',')
                        for ev in range(3):
                            each_val = values[ev]
                            merge_txt = 'AP_kitti/car_70/' + name_val+'/'+emh[ev]
                            try:
                                ap_value = float(each_val)
                            except ValueError:
                                # Skip unparsable values rather than dropping
                                # into a debugger mid-training.
                                logger.warning(
                                    "could not parse AP value %r for %s",
                                    each_val, merge_txt)
                            else:
                                writer.add_scalar(merge_txt, ap_value, global_step)
                    if pickle_result:
                        with open(result_path_step / "result.pkl", 'wb') as f:
                            pickle.dump(dt_annos, f)
                    writer.add_text('eval_result', result, global_step)

                    logger.info('Evaluation on validation set:')
                    # NOTE(review): the network is left in eval mode here (the
                    # net.train() call was disabled), so training continues in
                    # eval mode after the first validation - confirm this is
                    # intentional.

        # save model before exit
        torchplus.train.save_models(model_dir, [net, optimizer],
                                    net.get_global_step())
Ejemplo n.º 10
0
def train(config_path,
          model_dir,
          use_fusion=False,
          use_ft=False,
          use_second_stage=False,
          use_endtoend=False,
          result_path=None,
          create_folder=False,
          display_step=50,
          summary_step=5,
          local_rank=0,
          pickle_result=True,
          patchs=None):
    """Train a VoxelNet model specified by a config file.

    The function alternates between ``train_cfg.steps_per_eval`` training
    steps and a full pass over the evaluation dataset until
    ``train_cfg.steps`` steps have been run. Checkpoints are written
    periodically, after every training chunk, and once more on exit
    (both on success and when an ``Exception`` escapes the loop).

    Args:
        config_path: Path to a text-format ``TrainEvalPipelineConfig``
            protobuf. The file is copied into ``model_dir`` as
            ``pipeline.config``.
        model_dir: Directory for checkpoints, ``log.txt``, ``log.json``
            and the ``summary`` folder. Created if missing.
        use_fusion: Selects which pre-trained weights are loaded: the
            fusion first-stage checkpoint when a second-stage/end-to-end
            net is built, otherwise the FPN18 backbone weights.
        use_ft: Not used by this function.
        use_second_stage: Build the network with
            ``second_2stage_builder``.
        use_endtoend: Build the network with ``second_endtoend_builder``.
            Takes precedence over ``use_second_stage`` when both are set.
        result_path: Directory for evaluation output. Defaults to
            ``model_dir / 'results'``. Strings are accepted.
        create_folder: If true and ``model_dir`` already exists, a new
            directory is obtained from ``torchplus.train.create_folder``.
        display_step: Metrics are logged whenever the global step is a
            multiple of this value.
        summary_step: Not used by this function.
        local_rank: Not used by this function.
        pickle_result: If true, detections are collected in memory and
            pickled to ``result.pkl``; otherwise they are written as label
            files and read back for evaluation.
        patchs: Optional list of statements that are prefixed with
            ``"config."`` and executed to override config fields.

    Raises:
        Exception: Anything raised inside the training/eval loop is
            re-raised after a checkpoint has been saved.
    """
    if create_folder:
        if pathlib.Path(model_dir).exists():
            model_dir = torchplus.train.create_folder(model_dir)
    patchs = patchs or []
    model_dir = pathlib.Path(model_dir)
    model_dir.mkdir(parents=True, exist_ok=True)
    if result_path is None:
        result_path = model_dir / 'results'
    else:
        # Accept plain strings: the eval code below joins with the "/" operator.
        result_path = pathlib.Path(result_path)
    config_file_bkp = "pipeline.config"
    config = pipeline_pb2.TrainEvalPipelineConfig()
    with open(config_path, "r") as f:
        proto_str = f.read()
        text_format.Merge(proto_str, config)
    # SECURITY: each patch is executed as Python code. Only pass patches that
    # come from a trusted source (e.g. the operator's own command line).
    for patch in patchs:
        patch = "config." + patch
        exec(patch)
    shutil.copyfile(config_path, str(model_dir / config_file_bkp))
    input_cfg = config.train_input_reader
    eval_input_cfg = config.eval_input_reader
    model_cfg = config.model.second
    train_cfg = config.train_config

    ######################
    # BUILD VOXEL GENERATOR
    ######################
    voxel_generator = voxel_builder.build(model_cfg.voxel_generator)
    ######################
    # BUILD TARGET ASSIGNER
    ######################
    bv_range = voxel_generator.point_cloud_range[[0, 1, 3, 4]]
    box_coder = box_coder_builder.build(model_cfg.box_coder)
    target_assigner_cfg = model_cfg.target_assigner
    target_assigner = target_assigner_builder.build(target_assigner_cfg,
                                                    bv_range, box_coder)
    class_names = target_assigner.classes
    ######################
    # BUILD NET
    ######################
    center_limit_range = model_cfg.post_center_limit_range
    # Exactly one builder must run. Previously the two-stage net was built
    # and then silently replaced by the plain net whenever use_endtoend was
    # False. End-to-end keeps priority when both flags are set.
    if use_endtoend:
        net = second_endtoend_builder.build(model_cfg, voxel_generator,
                                            target_assigner)
    elif use_second_stage:
        net = second_2stage_builder.build(model_cfg, voxel_generator,
                                          target_assigner)
    else:
        net = second_builder.build(model_cfg, voxel_generator, target_assigner)
    net.cuda()
    print("num_trainable parameters:", len(list(net.parameters())))

    ############ load first-stage pre-weight #############
    if use_second_stage or use_endtoend:
        # Checkpoint paths are relative to the current working directory.
        if use_fusion:
            pth_name = 'pre_weight/first_stage/fusion_split/voxelnet-35210.tckpt'
            for _ in range(30):
                print(
                    '################## load Fusion First stage weight complete #######################'
                )
        else:
            pth_name = 'pre_weight/first_stage/lidaronly/voxelnet-30950.tckpt'
            for _ in range(30):
                print(
                    '################## load LiDAR Only First stage weight complete #######################'
                )

        res_pre_weights = torch.load(pth_name)
        new_res_state_dict = OrderedDict()
        model_dict = net.state_dict()
        # Skip the step counter and every entry whose key contains 'dir'.
        for k, v in res_pre_weights.items():
            if 'global_step' not in k and 'dir' not in k:
                new_res_state_dict[k] = v
        model_dict.update(new_res_state_dict)
        net.load_state_dict(model_dict)

    ############ load FPN18 pre-weight #############
    if (use_fusion and not use_second_stage and not use_endtoend):
        fpn_depth = 18
        pth_name = 'pre_weight/FPN' + str(fpn_depth) + '_retinanet_968.pth'
        res_pre_weights = torch.load(pth_name)
        new_res_state_dict = OrderedDict()
        model_dict = net.state_dict()
        # Drop the RetinaNet heads and remap 'module' to 'rpn' in the keys.
        for k, v in res_pre_weights['state_dict'].items():
            if ('regressionModel' not in k) and ('classificationModel'
                                                 not in k):
                name = k.replace('module', 'rpn')
                new_res_state_dict[name] = v
        model_dict.update(new_res_state_dict)
        net.load_state_dict(model_dict)
        for _ in range(30):
            print('!!!!!!!!!!!!!!!!!! load FPN' + str(fpn_depth) +
                  ' weight complete !!!!!!!!!!!!!!!!!!')
    ################################################
    # BUILD OPTIMIZER
    #####################
    # we need global_step to create lr_scheduler, so restore net first.
    torchplus.train.try_restore_latest_checkpoints(model_dir, [net])
    optimizer_cfg = train_cfg.optimizer
    if train_cfg.enable_mixed_precision:
        net.half()
        net.metrics_to_float()
        net.convert_norm_to_float(net)
    loss_scale = train_cfg.loss_scale_factor
    mixed_optimizer = optimizer_builder.build(
        optimizer_cfg,
        net,
        mixed=train_cfg.enable_mixed_precision,
        loss_scale=loss_scale)
    optimizer = mixed_optimizer
    # must restore optimizer AFTER it has been built/wrapped
    torchplus.train.try_restore_latest_checkpoints(model_dir,
                                                   [mixed_optimizer])
    lr_scheduler = lr_scheduler_builder.build(optimizer_cfg, optimizer,
                                              train_cfg.steps)
    if train_cfg.enable_mixed_precision:
        float_dtype = torch.float16
    else:
        float_dtype = torch.float32
    ######################
    # PREPARE INPUT
    ######################

    dataset = input_reader_builder.build(input_cfg,
                                         model_cfg,
                                         training=True,
                                         voxel_generator=voxel_generator,
                                         target_assigner=target_assigner)
    eval_dataset = input_reader_builder.build(eval_input_cfg,
                                              model_cfg,
                                              training=False,
                                              voxel_generator=voxel_generator,
                                              target_assigner=target_assigner)

    def _worker_init_fn(worker_id):
        # Give every DataLoader worker its own NumPy seed.
        time_seed = np.array(time.time(), dtype=np.int32)
        np.random.seed(time_seed + worker_id)
        print(f"WORKER {worker_id} seed:", np.random.get_state()[1][0])

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=input_cfg.batch_size,
                                             shuffle=True,
                                             num_workers=input_cfg.num_workers,
                                             pin_memory=False,
                                             collate_fn=merge_second_batch,
                                             worker_init_fn=_worker_init_fn)
    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=eval_input_cfg.batch_size,
        shuffle=False,
        num_workers=eval_input_cfg.num_workers,
        pin_memory=False,
        collate_fn=merge_second_batch)

    data_iter = iter(dataloader)

    ######################
    # TRAINING
    ######################
    training_detail = []
    log_path = model_dir / 'log.txt'
    training_detail_path = model_dir / 'log.json'
    if training_detail_path.exists():
        with open(training_detail_path, 'r') as f:
            training_detail = json.load(f)
    logf = open(log_path, 'a')
    logf.write(proto_str)
    logf.write("\n")
    summary_dir = model_dir / 'summary'
    summary_dir.mkdir(parents=True, exist_ok=True)
    writer = SummaryWriter(str(summary_dir))

    total_step_elapsed = 0
    t = time.time()
    ckpt_start_time = t

    # Number of train/eval rounds; the last round may be shorter.
    total_loop = train_cfg.steps // train_cfg.steps_per_eval + 1
    clear_metrics_every_epoch = train_cfg.clear_metrics_every_epoch

    if train_cfg.steps % train_cfg.steps_per_eval == 0:
        total_loop -= 1
    mixed_optimizer.zero_grad()
    try:
        for _ in range(total_loop):
            if total_step_elapsed + train_cfg.steps_per_eval > train_cfg.steps:
                steps = train_cfg.steps % train_cfg.steps_per_eval
            else:
                steps = train_cfg.steps_per_eval
            for step in range(steps):
                lr_scheduler.step(net.get_global_step())
                try:
                    example = next(data_iter)
                except StopIteration:
                    # The loader is exhausted: start a new epoch.
                    print("end epoch")
                    if clear_metrics_every_epoch:
                        net.clear_metrics()
                    data_iter = iter(dataloader)
                    example = next(data_iter)
                example_torch = example_convert_to_torch(example, float_dtype)

                ret_dict = net(example_torch)

                cls_preds = ret_dict["cls_preds"]
                loss = ret_dict["loss"].mean()
                cls_loss_reduced = ret_dict["cls_loss_reduced"].mean()
                loc_loss_reduced = ret_dict["loc_loss_reduced"].mean()
                cls_pos_loss = ret_dict["cls_pos_loss"]
                cls_neg_loss = ret_dict["cls_neg_loss"]
                loc_loss = ret_dict["loc_loss"]
                dir_loss_reduced = ret_dict["dir_loss_reduced"]
                cared = ret_dict["cared"]

                # Second-stage / end-to-end nets return their own labels.
                if use_second_stage or use_endtoend:
                    labels = ret_dict["labels"]
                else:
                    labels = example_torch["labels"]
                if train_cfg.enable_mixed_precision:
                    loss *= loss_scale
                loss.backward()
                torch.nn.utils.clip_grad_norm_(net.parameters(), 10.0)
                mixed_optimizer.step()
                mixed_optimizer.zero_grad()
                net.update_global_step()
                net_metrics = net.update_metrics(cls_loss_reduced,
                                                 loc_loss_reduced, cls_preds,
                                                 labels, cared)

                step_time = (time.time() - t)
                t = time.time()
                global_step = net.get_global_step()
                if global_step % display_step == 0:
                    # These statistics are only needed for logging, so they
                    # are computed here instead of on every step.
                    batch_size = example["anchors"].shape[0]
                    num_pos = int((labels > 0)[0].float().sum().cpu().numpy())
                    num_neg = int((labels == 0)[0].float().sum().cpu().numpy())
                    if 'anchors_mask' not in example_torch:
                        num_anchors = example_torch['anchors'].shape[1]
                    else:
                        num_anchors = int(
                            example_torch['anchors_mask'][0].sum())
                    loc_loss_elem = [
                        float(loc_loss[:, :, i].sum().detach().cpu().numpy() /
                              batch_size) for i in range(loc_loss.shape[-1])
                    ]
                    metrics = {}
                    metrics["type"] = "step_info"
                    metrics["step"] = global_step
                    metrics["steptime"] = step_time
                    metrics.update(net_metrics)
                    metrics["loss"] = {}
                    metrics["loss"]["loc_elem"] = loc_loss_elem
                    metrics["loss"]["cls_pos_rt"] = float(
                        cls_pos_loss.detach().cpu().numpy())
                    metrics["loss"]["cls_neg_rt"] = float(
                        cls_neg_loss.detach().cpu().numpy())
                    if model_cfg.use_direction_classifier:
                        metrics["loss"]["dir_rt"] = float(
                            dir_loss_reduced.detach().cpu().numpy())
                    metrics["num_vox"] = int(example_torch["voxels"].shape[0])
                    metrics["num_pos"] = int(num_pos)
                    metrics["num_neg"] = int(num_neg)
                    metrics["num_anchors"] = int(num_anchors)
                    metrics["lr"] = float(optimizer.lr)

                    metrics["image_idx"] = example['image_idx'][0]
                    training_detail.append(metrics)
                    flatted_metrics = flat_nested_json_dict(metrics)
                    flatted_summarys = flat_nested_json_dict(metrics, "/")
                    # Per-element localisation losses are kept out of the
                    # summary; strings cannot be logged as scalars.
                    for k, v in flatted_summarys.items():
                        if 'loc_elem' in k:
                            continue
                        if isinstance(v, (list, tuple)):
                            v = {str(i): e for i, e in enumerate(v)}
                            writer.add_scalars(k, v, global_step)
                        elif not isinstance(v, str):
                            writer.add_scalar(k, v, global_step)

                    metrics_str_list = []
                    for k, v in flatted_metrics.items():
                        if isinstance(v, float):
                            metrics_str_list.append(f"{k}={v:.3}")
                        elif isinstance(v, (list, tuple)):
                            if v and isinstance(v[0], float):
                                v_str = ', '.join([f"{e:.3}" for e in v])
                                metrics_str_list.append(f"{k}=[{v_str}]")
                            else:
                                metrics_str_list.append(f"{k}={v}")
                        else:
                            metrics_str_list.append(f"{k}={v}")
                    log_str = ', '.join(metrics_str_list)
                    print(log_str, file=logf)
                    print(log_str)
                # Time-based checkpointing, independent of the eval schedule.
                ckpt_elasped_time = time.time() - ckpt_start_time
                if ckpt_elasped_time > train_cfg.save_checkpoints_secs:
                    torchplus.train.save_models(model_dir, [net, optimizer],
                                                net.get_global_step())

                    ckpt_start_time = time.time()
            total_step_elapsed += steps

            torchplus.train.save_models(model_dir, [net, optimizer],
                                        net.get_global_step())
            ######################
            # EVALUATION
            ######################
            net.eval()
            global_step = net.get_global_step()
            result_path_step = result_path / f"step_{global_step}"
            result_path_step.mkdir(parents=True, exist_ok=True)
            for line in ("#################################", "# EVAL",
                         "#################################",
                         "Generate output labels..."):
                print(line)
                print(line, file=logf)
            t = time.time()
            dt_annos = []
            prog_bar = ProgressBar()
            net.clear_timer()
            # Number of eval batches, rounded up.
            prog_bar.start(
                (len(eval_dataset) + eval_input_cfg.batch_size - 1) //
                eval_input_cfg.batch_size)
            for example in eval_dataloader:
                example = example_convert_to_torch(example, float_dtype)
                if pickle_result:
                    dt_annos += predict_kitti_to_anno(net, example,
                                                      class_names,
                                                      center_limit_range,
                                                      model_cfg.lidar_input)
                else:
                    _predict_kitti_to_file(net, example, result_path_step,
                                           class_names, center_limit_range,
                                           model_cfg.lidar_input)

                prog_bar.print_bar()

            # Throughput in examples per second.
            ex_per_sec = len(eval_dataset) / (time.time() - t)

            print(f'generate label finished({ex_per_sec:.2f}/s). start eval:')
            print(f'generate label finished({ex_per_sec:.2f}/s). start eval:',
                  file=logf)
            gt_annos = [
                info["annos"] for info in eval_dataset.dataset.kitti_infos
            ]
            if not pickle_result:
                dt_annos = kitti.get_label_annos(result_path_step)
            result = get_official_eval_result(gt_annos, dt_annos, class_names)
            print(result, file=logf)
            print(result)
            # Parse the first block of the textual report: line 0 is treated
            # as a header, each following line as "<name> ...:<e>,<m>,<h>".
            emh = ['0_easy', '1_mod', '2_hard']
            for save_targ in result.split("\n")[:5][1:]:
                name_val = save_targ.split(':')[0].split(' ')[0]
                value_val = save_targ.split(':')[1:]
                for ev in range(3):
                    each_val = value_val[0].split(',')[ev]
                    merge_txt = 'AP_kitti/car_70/' + name_val + '/' + emh[ev]
                    writer.add_scalar(merge_txt, float(each_val), global_step)
            if pickle_result:
                with open(result_path_step / "result.pkl", 'wb') as f:
                    pickle.dump(dt_annos, f)
            writer.add_text('eval_result', result, global_step)
            net.train()
    except Exception:
        # Keep whatever progress was made, then let the error propagate.
        torchplus.train.save_models(model_dir, [net, optimizer],
                                    net.get_global_step())
        raise
    else:
        # save model before exit
        torchplus.train.save_models(model_dir, [net, optimizer],
                                    net.get_global_step())
    finally:
        # Always release the log file, including on KeyboardInterrupt.
        logf.close()
Ejemplo n.º 11
0
    def evaluation_from_kitti_dets(self, dt_annos, output_dir):
        """Evaluate KITTI-format detections and annotate them with point counts.

        Runs the official and COCO-style evaluation of ``dt_annos`` against
        the ground truth stored in ``self._kitti_infos``. Afterwards, for
        every detection dict the number of reduced-point-cloud points that
        fall inside each detected box is computed and stored under the key
        ``"num_points_in_det"``.

        Args:
            dt_annos: Sequence of per-frame detection dicts, paired by
                position with ``self._kitti_infos``. Each dict is read via
                the keys ``name``, ``dimensions``, ``location`` and
                ``rotation_y``. NOTE: the dicts are modified in place.
            output_dir: Not used by this method.

        Returns:
            ``None`` when the first info entry has no ``"annos"`` key;
            otherwise a dict with the keys ``"results"``, ``"detail"`` and
            ``"result_kitti"`` built from the two evaluation results.
        """
        if "annos" not in self._kitti_infos[0]:
            return None
        gt_annos = [info["annos"] for info in self._kitti_infos]
        # dt_annos are passed to the evaluators unchanged.
        z_axis = 1  # KITTI camera format use y as regular "z" axis.
        z_center = 1.0  # KITTI camera box's center is [0.5, 1, 0.5]
        # for regular raw lidar data, z_axis = 2, z_center = 0.5.
        result_official_dict = get_official_eval_result(
            gt_annos,
            dt_annos,
            self._class_names,
            z_axis=z_axis,
            z_center=z_center)
        result_coco = get_coco_eval_result(
            gt_annos,
            dt_annos,
            self._class_names,
            z_axis=z_axis,
            z_center=z_center)

        # feature extraction: count the lidar points inside every detection
        for info, det in tqdm(zip(self._kitti_infos, dt_annos), desc="feature", total=len(dt_annos)):
            pc_info = info["point_cloud"]
            calib = info["calib"]

            num_features = pc_info["num_features"]
            # Read from the sibling "<dir>_reduced" folder; per the original
            # author's note that cloud is already filtered to the image
            # rectangle, so no extra outside-image filtering is done here.
            v_path = self._root_path / pc_info["velodyne_path"]
            v_path = str(v_path.parent.parent / (v_path.parent.stem + "_reduced") / v_path.name)
            points_v = np.fromfile(
                v_path, dtype=np.float32, count=-1).reshape([-1, num_features])
            rect = calib['R0_rect']
            Trv2c = calib['Tr_velo_to_cam']

            # The slicing below takes the first num_obj entries, i.e. it
            # assumes any 'DontCare' entries are at the end — TODO confirm.
            num_obj = sum(1 for n in det['name'] if n != 'DontCare')
            dims = det['dimensions'][:num_obj]
            loc = det['location'][:num_obj]
            rots = det['rotation_y'][:num_obj]
            boxes_camera = np.concatenate([loc, dims, rots[..., np.newaxis]],
                                          axis=1)
            boxes_lidar = box_np_ops.box_camera_to_lidar(
                boxes_camera, rect, Trv2c)
            indices = box_np_ops.points_in_rbbox(points_v[:, :3], boxes_lidar)
            num_points_in_det = indices.sum(0)
            # Entries beyond num_obj get -1 as their point count.
            num_ignored = len(det['dimensions']) - num_obj
            num_points_in_det = np.concatenate(
                [num_points_in_det, -np.ones([num_ignored])])
            det["num_points_in_det"] = num_points_in_det.astype(np.int32)

        return {
            "results": {
                "official": result_official_dict["result"],
                "coco": result_coco["result"],
            },
            "detail": {
                "eval.kitti": {
                    "official": result_official_dict["detail"],
                    "coco": result_coco["detail"]
                }
            },
            "result_kitti": result_official_dict["detections"],
        }
Ejemplo n.º 12
0
def evaluate(config_path, model_id=None, from_file_mode = False, epoch_idx=None):
    """
    Args:
        config_path     (str)   : Path to the config yaml
        model_id        (int)   : Which model id should be evaluated? "None" means the one in the config is used
        from_file_mode  (bool)  : "True" means that eval data from a file is used
        limit           (int)   : OPTIONS: None, int // after how many datapoints we want to exit 
    """

    print("**********************************************")
    print("* Start Evaluation")
    print("**********************************************")

    
    # This variable "model_id_memory" just helps to remember the original state of "model_id" 
    model_id_memory = 0
    if model_id == None: model_id_memory = None

    # ------------------------------------------------------------------------------------------------------ 
    #  load the config from file and set Variables
    # ------------------------------------------------------------------------------------------------------ 

    with open(config_path) as f1:   
        config = yaml.load(f1, Loader=yaml.FullLoader)


    # If no model_id is given we take the one from the config
    if model_id == None: model_id = config["eval_model_id"]
    eval_checkpoint = config["eval_checkpoint"]
    print("**********************************************")
    print("* Load Model ID {}".format(str(model_id)))
    print("**********************************************")


    # set training to false -> eval mode
    training = False

    # ------------------------------------------------------------------------------------------------------ 
    #  Load directory parameter and create directories
    # ------------------------------------------------------------------------------------------------------  
    
    project_dir_base = config["project_dir_base"] # path to base project dir (where the whole code is stored)

    
    # path to out base dir (where the training logs are stored)
    out_dir_base = create_out_dir_base(project_dir_base, training, model_id)
    

    # create the subdir where the training is stored
    out_dir_eval_results, out_dir_checkpoints = create_model_dirs_eval(out_dir_base)
   

    # ------------------------------------------------------------------------------------------------------
    #  Load dataloader parameter and create dataloader
    # ------------------------------------------------------------------------------------------------------ 

    limit = None # Options: {None,Int} # limit the amount of test data to be evaluated to save time
    batch_size = config["eval_input_reader"]["batch_size"] 
    num_point_features = config["train_input_reader"]["num_point_features"]
    center_limit_range = config["model"]["second"]["post_center_limit_range"]
    desired_objects = config["eval_input_reader"]["desired_objects"]
    no_annos_mode = config["eval_input_reader"]["no_annos_mode"]
    production_mode = bool(config["production_mode"])
    prediction_min_score = config["prediction_min_score"]


    # create the dataLoader object which is reponsible for loading datapoints
    # contains not much logic and basically just holds some variables
    dataset_ori = dataLoader(training, None, config)    


    # initializes the dataset object (batch creating etc.)
    dataset = dataset_ori.getIterator()


    # makes the dataset object iterable
    data_iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)


    # ------------------------------------------------------------------------------------------------------
    #  Create network
    # ------------------------------------------------------------------------------------------------------ 

    # create tensorboard writer for logging
    writer = tf.summary.create_file_writer(out_dir_base)


    # create network
    net = VoxelNet(config, writer, training=training)


    # DEBUG
    if from_file_mode:
        with open("eval_dataloader_limit200", "rb") as file:
            data_iterator = pickle.load(file)

    # ------------------------------------------------------------------------------------------------------
    #  We wrap a training step in a tf.function to to gain speedups (5x)
    # ------------------------------------------------------------------------------------------------------
    
    max_number_of_points_per_voxel = config["model"]["second"]["voxel_generator"]["max_number_of_points_per_voxel"] 

    @tf.function(input_signature = [tf.TensorSpec(shape=[None,max_number_of_points_per_voxel,num_point_features], dtype=tf.float32),tf.TensorSpec(shape=[None,], dtype=tf.int32),tf.TensorSpec(shape=[None,4], dtype=tf.int32),tf.TensorSpec(shape=[None,None,7], dtype=tf.float32)])
    def trainStep(voxels,num_points,coors,batch_anchors):
        preds_dict = net(voxels,num_points,coors,batch_anchors)
        return preds_dict

    # ------------------------------------------------------------------------------------------------------
    #  EPOCH ITERATOR Settings
    # ------------------------------------------------------------------------------------------------------

    # Helper Variable to load weights                                      
    load_weights_finished = False # This variable is part of a tensorflow workaround caused by subclassing keras.model




    # # Create a model using low-level tf.* APIs
    # class Squared(tf.Module):
    #     @tf.function
    #     def __call__(self, x):
    #         return tf.square(x)
    # model = Squared()
    # # (ro run your model) result = Squared(5.0) # This prints "25.0"
    # # (to generate a SavedModel) tf.saved_model.save(model, "saved_model_tf_dir")
    # concrete_func = model.__call__.get_concrete_function()

    # # Convert the model
    # converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
    # tflite_model = converter.convert()




    # Helper Variable to save the output of the network
    dt_annos = []

    # eval params
    measure_time = config["measure_time"]

    # Helper Variables for Time Bechmarks
    current_milli_time = lambda: int(round(time.time() * 1000))
    t_full_sample_list = []
    t_preprocess_list = []
    t_network_list = []
    t_predict_list = []
    t_anno_list = []
    t_rviz_list = []


    # ------------------------------------------------------------------------------------------------------
    # ROS
    # ------------------------------------------------------------------------------------------------------
    if production_mode:
        bb_pred_guess_1_pub = rospy.Publisher("bb_pred_guess_1", BoundingBoxArray)
        header = std_msgs.msg.Header()
        header.stamp = rospy.Time.now()
        header.frame_id = 'camera_color_frame'
        calib = {"R0_rect" : np.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]).reshape(3,3),
                    "Tr_velo_to_cam" : np.array([0.0, -1.0, 0.0, 0.0, 0.0, 0.0, -1.0, 0.0, 1.0, 0.0, 0.0, 0.0]).reshape(3,4)}

    # ------------------------------------------------------------------------------------------------------
    #  EPOCH ITERATOR
    # ------------------------------------------------------------------------------------------------------  

    for i, example in enumerate(data_iterator): 
        # example: [0:voxels, 1:num_points, 2:coordinates, 3:rect, 4:Trv2c, 5:P2, 6:anchors, 7:anchors_mask, 8:image_idx, 9:image_shape]
        
        if measure_time: 
            if i > 0:
                t_preprocess = current_milli_time() - t_preprocess
                t_preprocess_list.append(t_preprocess)
            else:
                t_preprocess_list.append(0.0)

        # save starting time to measure network speed
        if measure_time: t_full_sample = current_milli_time()

        # Progess Bar
        if not production_mode:
            if limit is None: 
                progressBar(i,dataset_ori.ndata//batch_size)
            else:
                progressBar(i,limit)
                if i == limit: break # consider the case for early exit
       
        # ------------------------------------------------------------------------------------------------------
        #  Load Model Weights and Optimizer Weights (Checkpoint)
        # ------------------------------------------------------------------------------------------------------

        if load_weights_finished == False: # support variable, to check if the weights are loaded

            # initialize the model # TODO delete?
            #net(example[0],example[1],example[2],example[6],example[8],example[9])
            net(example[0],example[1],example[2],example[6])
            load_weights_finished = True

            # load the weights depending on if we are in training mode since
            # if yes the "model_weights_temp" file needs to be evaluated 
            model_dir = ""
            if model_id_memory == None: # for explaination, see variable "model_id_memory"
                model_dir = out_dir_checkpoints + eval_checkpoint
                net.load_weights(model_dir)
            else:
                model_dir = out_dir_checkpoints + "/model_weights_temp.h5"
                net.load_weights(model_dir)

            print("**********************************************")
            print("* Model Loaded from Path: {}".format(model_dir))
            print("**********************************************")


        # ------------------------------------------------------------------------------------------------------
        #  Run Network 
        # ------------------------------------------------------------------------------------------------------

        if measure_time: t_network = current_milli_time()

        preds_dict = trainStep(example[0],example[1],example[2],example[6])

        if measure_time:
            t_network = current_milli_time() - t_network
            t_network_list.append(t_network)

        # ------------------------------------------------------------------------------------------------------
        # Convert Network Output to predictions by applying the direction classifier to predictions of rotation 
        # use nms to get the final bboxes
        # ------------------------------------------------------------------------------------------------------

        if measure_time: t_predict = current_milli_time()

        # t_full_sample: 23.91, t_preprocess: 0.85, t_network: 12.54, t_predict: 10.46, t_anno: 0.84, t_rviz: 0.0
        # t_full_sample: 23.54, t_preprocess: 1.13, t_network: 11.08, t_predict: 11.45, t_anno: 0.82, t_rviz: 0.01
        predictions_dicts = net.predict(example,preds_dict) # list(predictions) of ['name', 'truncated', 'occluded', 'alpha', 'bbox', 'dimensions', 'location', 'rotation_y', 'score', 'image_idx'])

        if measure_time:
            t_predict = current_milli_time() -t_predict
            t_predict_list.append(t_predict)
        

        # ------------------------------------------------------------------------------------------------------
        # Convert Predictions to Kitti Annotation style 
        # ------------------------------------------------------------------------------------------------------
        
        if measure_time: t_anno = current_milli_time()

        dt_anno = predict_kitti_to_anno(example, desired_objects, predictions_dicts, None, False)

        if measure_time:
            t_anno = current_milli_time() -t_anno
            t_anno_list.append(t_anno)
        
        # ------------------------------------------------------------------------------------------------------
        # Send annotation to RVIZ
        # ------------------------------------------------------------------------------------------------------

        if measure_time: t_rviz = current_milli_time()

        if production_mode:

            dt_anno = remove_low_score(dt_anno[0], float(prediction_min_score))
            # if len(dt_anno["score"]) > 0:
            #    print(dt_anno["score"])
            dims = dt_anno['dimensions']
            loc = dt_anno['location']
            rots = dt_anno['rotation_y']
            boxes_camera = np.concatenate([loc, dims, rots[..., np.newaxis]], axis=1)
            boxes_lidar = box_camera_to_lidar(boxes_camera, calib['R0_rect'],calib['Tr_velo_to_cam'])
            centers,dims,angles = boxes_lidar[:, :3], boxes_lidar[:, 3:6], boxes_lidar[:, 6] # [a,b,c] -> [c,a,b] (camera to lidar coords)
            
            # LIFT bounding boxes 
            # - the position of bboxes in the pipeline is at the z bottom of the bb and ROS needs it at the z center
            # - TODO: lift by height/2 and not just 1.0

            #centers = centers + [0.0,0.0,0.9]
            send_3d_bbox(centers, dims, angles, bb_pred_guess_1_pub, header) 
        else:
            dt_annos += dt_anno
        
        if measure_time:
            t_rviz = current_milli_time() -t_rviz
            t_rviz_list.append(t_rviz)
            
            t_full_sample = current_milli_time() -t_full_sample
            t_full_sample_list.append(t_full_sample)
            
            t_preprocess = current_milli_time()

        # ------------------------------------------------------------------------------------------------------
        # Print times
        # ------------------------------------------------------------------------------------------------------
       
        if i > 0 and measure_time: # we scipt the first network iteration (initialization)
            t_full_sample_avg = round(sum(t_full_sample_list[1:])/len(t_full_sample_list[1:]),2)
            t_preprocess_avg = round(sum(t_preprocess_list[1:])/len(t_preprocess_list[1:]),2)
            t_network_avg = round(sum(t_network_list[1:])/len(t_network_list[1:]),2)
            t_predict_avg = round(sum(t_predict_list[1:])/len(t_predict_list[1:]),2)
            t_anno_avg = round(sum(t_anno_list[1:])/len(t_anno_list[1:]),2)
            t_rviz_avg = round(sum(t_rviz_list[1:])/len(t_rviz_list[1:]),2)
        
            print(f't_full_sample: {t_full_sample_avg}, t_preprocess: {t_preprocess_avg}, t_network: {t_network_avg}, t_predict: {t_predict_avg}, t_anno: {t_anno_avg}, t_rviz: {t_rviz_avg}')

    # ------------------------------------------------------------------------------------------------------
    # save the results in a file
    # ------------------------------------------------------------------------------------------------------

    if epoch_idx is not None: # if epoch index is given include in the name (typically evaluation during training)
        with open(out_dir_eval_results + "/result_epoch_{}.pkl".format(str(epoch_idx)), 'wb') as f:
            pickle.dump(dt_annos, f, 2)
            
    else: # if epoch index is not given use generic name (typically while evaluation of certain models, epochs, and testing sets)
        with open(out_dir_eval_results + "/result.pkl", 'wb') as f:
            pickle.dump(dt_annos, f, 2)

    # ------------------------------------------------------------------------------------------------------
    # Return early if we run in no_annos_mode (since we don't have annotations, we cannot do the following evaluation).
    # ------------------------------------------------------------------------------------------------------

    if no_annos_mode:
        return (np.array([0]),"no evaluation") 

    # get all gt in dataset
    gt_annos = [info["annos"] for info in dataset_ori.img_list_and_infos]

    
    # DEBUG
    # This can limit the amount of test data to be evaluated to save time
    if limit is not None:
        gt_annos = gt_annos[0:limit]
    else:
        gt_annos = gt_annos[0:len(dt_annos)]

    # ------------------------------------------------------------------------------------------------------
    # evaluate the predictions in KITTI? style
    # - AOS is average orientation similarity, bbox is 2D
    # - Results for the columns (difficulties) will be equal since we do not have OCCLUSION and TRUNCATION 
    #   annos in our ground truth which have influence on the difficulties
    # - Input: dt_annos, gt_annos are in camera coors (in Lidar expression: (-y,-z,x))
    # ------------------------------------------------------------------------------------------------------

    compute_bbox = False # Since we dont have 2D box ground truth we don want to compute 2D bboxes results

    result1, mAPbbox, mAPbev, mAP3d, mAPaos = get_official_eval_result(gt_annos, dt_annos, desired_objects, compute_bbox = compute_bbox) 

    # ------------------------------------------------------------------------------------------------------
    # Print the evaluation result which depicts the overlaps between the gt and predictions 
    # - if multiple evals of the same class (e.g. Pedestrian,Pedestrian) occur, these depict separate evaluations 
    #   with different >overlap settings<
    # - in each of those evaluations: columns are difficulties (dependent on OCCLUSION and TRUNCATION) and 
    #   rows are [bbox,bev,3D]. The >overlap settings< are specified right to the class name e.g.:
    #   Pedestrian AP@0.50, 0.25, 0.25 -> [bbox, bev, 3D] and refer to the rows (NOT columns!). In this example
    #   bbox is 0.50, bev is 0.25 and 3D is 0.25 overlap. Also note, that bbox is missing in our evaluation.
    # ------------------------------------------------------------------------------------------------------
    print(result1)

    # ------------------------------------------------------------------------------------------------------
    # evaluate the predictions in COCO? style
    # ------------------------------------------------------------------------------------------------------

    # result2 = get_coco_eval_result(gt_annos, dt_annos, desired_objects)
    # print(result2)

    # ------------------------------------------------------------------------------------------------------
    # return results (only used during training)
    # ------------------------------------------------------------------------------------------------------
    
    return (mAP3d,result1) 
Ejemplo n.º 13
0
def evaluate(config_path,
             model_dir,
             result_path=None,
             predict_test=False,
             ckpt_path=None,
             ref_detfile=None,
             pickle_result=True,
             measure_time=False,
             batch_size=None):
    """Run the fusion layer on the evaluation split and report the results.

    Args:
        config_path: Path of a text-format ``TrainEvalPipelineConfig`` file.
        model_dir: Directory searched for the latest fusion-layer checkpoint
            (when ``ckpt_path`` is None) and used as the parent of the
            default result directory.
        result_path: Directory for the outputs. Defaults to
            ``model_dir / 'predict_test'`` or ``model_dir / 'eval_results'``
            depending on ``predict_test``.
        predict_test: When True the dataset is built with ``training=False``
            and no metrics are computed; only predictions are produced.
        ckpt_path: Explicit checkpoint to restore the fusion layer from.
        ref_detfile: Not used by this function.
        pickle_result: When True predictions are collected in memory and
            written to ``result.pkl``; otherwise they are written to files
            by ``_predict_kitti_to_file`` and read back for the metrics.
        measure_time: When True, per-batch loading and tensor-conversion
            times are collected and their averages printed.
        batch_size: Overrides the batch size of the eval input config.
    """
    model_dir = pathlib.Path(model_dir)
    print("Predict_test: ", predict_test)
    default_subdir = 'predict_test' if predict_test else 'eval_results'
    result_path = (model_dir / default_subdir
                   if result_path is None else pathlib.Path(result_path))

    config = pipeline_pb2.TrainEvalPipelineConfig()
    with open(config_path, "r") as cfg_file:
        text_format.Merge(cfg_file.read(), config)

    input_cfg = config.eval_input_reader
    model_cfg = config.model.second
    train_cfg = config.train_config
    detection_2d_path = config.train_config.detection_2d_path
    center_limit_range = model_cfg.post_center_limit_range

    # Voxel generator, box coder and target assigner feed the dataset builder.
    voxel_generator = voxel_builder.build(model_cfg.voxel_generator)
    bv_range = voxel_generator.point_cloud_range[[0, 1, 3, 4]]
    box_coder = box_coder_builder.build(model_cfg.box_coder)
    target_assigner = target_assigner_builder.build(
        model_cfg.target_assigner, bv_range, box_coder)
    class_names = target_assigner.classes

    # The 3D detector is always built from the fixed car config / model dir.
    net = build_inference_net('./configs/car.fhd.config', '../model_dir')
    fusion_layer = fusion.fusion()
    fusion_layer.cuda()
    net.cuda()

    # Restore the fusion-layer parameters.
    if ckpt_path is not None:
        torchplus.train.restore(ckpt_path, fusion_layer)
    else:
        print("load existing model for fusion layer")
        torchplus.train.try_restore_latest_checkpoints(model_dir,
                                                       [fusion_layer])

    if train_cfg.enable_mixed_precision:
        net.half()
        net.metrics_to_float()
        net.convert_norm_to_float(net)

    batch_size = batch_size or input_cfg.batch_size
    eval_dataset = input_reader_builder.build(
        input_cfg,
        model_cfg,
        training=not predict_test,
        voxel_generator=voxel_generator,
        target_assigner=target_assigner)
    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,  # input_cfg.num_workers,
        pin_memory=False,
        collate_fn=merge_second_batch)

    float_dtype = (torch.float16
                   if train_cfg.enable_mixed_precision else torch.float32)

    net.eval()
    fusion_layer.eval()
    result_path_step = result_path / f"step_{net.get_global_step()}"
    result_path_step.mkdir(parents=True, exist_ok=True)

    run_start = time.time()
    dt_annos = []
    print("Generate output labels...")
    progress = ProgressBar()
    progress.start((len(eval_dataset) + batch_size - 1) // batch_size)
    convert_times = []
    prep_times = []
    prep_start = time.time()
    val_loss_final = 0
    for example in eval_dataloader:
        if measure_time:
            # Time spent waiting for the loader since the previous batch.
            prep_times.append(time.time() - prep_start)
            convert_start = time.time()
            torch.cuda.synchronize()
        example = example_convert_to_torch(example, float_dtype)
        if measure_time:
            torch.cuda.synchronize()
            convert_times.append(time.time() - convert_start)

        if not pickle_result:
            _predict_kitti_to_file(net, detection_2d_path, fusion_layer,
                                   example, result_path_step, class_names,
                                   center_limit_range, model_cfg.lidar_input)
        else:
            batch_annos, batch_loss = predict_kitti_to_anno(
                net, detection_2d_path, fusion_layer, example, class_names,
                center_limit_range, model_cfg.lidar_input, None)
            dt_annos.extend(batch_annos)
            val_loss_final = val_loss_final + batch_loss
        progress.print_bar()
        if measure_time:
            prep_start = time.time()

    # Throughput in examples per second over the whole prediction loop.
    examples_per_sec = len(eval_dataset) / (time.time() - run_start)
    print(f'generate label finished({examples_per_sec:.2f}/s). start eval:')
    print("validation_loss:", val_loss_final / len(eval_dataloader))
    if measure_time:
        print(
            f"avg example to torch time: {np.mean(convert_times) * 1000:.3f} ms"
        )
        print(f"avg prep time: {np.mean(prep_times) * 1000:.3f} ms")
    for name, val in net.get_avg_time_dict().items():
        print(f"avg {name} time = {val * 1000:.3f} ms")

    if not predict_test:
        gt_annos = [info["annos"] for info in eval_dataset.dataset.kitti_infos]
        if not pickle_result:
            # Predictions were written to disk; read them back for scoring.
            dt_annos = kitti.get_label_annos(result_path_step)
        print(get_official_eval_result(gt_annos, dt_annos, class_names))
        print(get_coco_eval_result(gt_annos, dt_annos, class_names))

    if pickle_result:
        with open(result_path_step / "result.pkl", 'wb') as out_file:
            pickle.dump(dt_annos, out_file)
Ejemplo n.º 14
0
def _build_fusion_targets(example_torch, all_3d_output_camera_dict):
    """Build the fusion-layer targets and sample masks for one batch.

    Only the first sample of the batch is read (index 0 of
    ``d3_gt_boxes``, ``rect`` and ``Trv2c``).

    Returns:
        A tuple ``(target_for_fusion, positives, negatives)`` where
        ``target_for_fusion`` is a numpy array reshaped to ``(1, K, 1)`` and
        ``positives`` / ``negatives`` are float32 CUDA tensors of shape
        ``(1, K)``. A candidate is positive when its best overlap with any
        ground-truth box is >= 0.7 and negative when it is <= 0.5.
    """
    d3_gt_boxes = example_torch["d3_gt_boxes"][0, :, :]
    if d3_gt_boxes.shape[0] == 0:
        # No ground truth in this sample: nothing is positive, everything
        # is negative.
        # NOTE(review): 70400 is hard-coded; presumably the number of
        # candidates the detector emits for this config - confirm.
        target_for_fusion = np.zeros((1, 70400, 1))
        positives = torch.zeros(1, 70400).type(torch.float32).cuda()
        negatives = torch.ones(1, 70400).type(torch.float32).cuda()
        return target_for_fusion, positives, negatives

    d3_gt_boxes_camera = box_torch_ops.box_lidar_to_camera(
        d3_gt_boxes, example_torch['rect'][0, :],
        example_torch['Trv2c'][0, :])
    pred_3d_box = all_3d_output_camera_dict[0]["box3d_camera"]
    overlaps = d3_box_overlap(d3_gt_boxes_camera.detach().cpu().numpy(),
                              pred_3d_box.squeeze().detach().cpu().numpy(),
                              criterion=-1)
    # Best overlap of every predicted box over all ground-truth boxes.
    best_overlap = np.amax(overlaps, axis=0)

    positive_index = ((best_overlap >= 0.7) * 1).reshape(1, -1)
    negative_index = ((best_overlap <= 0.5) * 1).reshape(1, -1)
    target_for_fusion = positive_index.reshape(1, -1, 1)
    positives = torch.from_numpy(positive_index).type(torch.float32).cuda()
    negatives = torch.from_numpy(negative_index).type(torch.float32).cuda()
    return target_for_fusion, positives, negatives


def train(config_path,
          model_dir,
          result_path=None,
          create_folder=False,
          display_step=50,
          summary_step=5,
          pickle_result=True,
          patchs=None):
    """Train the fusion layer, evaluating every ``steps_per_eval`` steps.

    The 3D detector returned by ``build_inference_net`` is only used for
    forward passes; the optimizer is built over ``fusion_layer`` alone.

    Args:
        config_path: Path of a text-format ``TrainEvalPipelineConfig`` file.
        model_dir: Directory that receives checkpoints, ``log.txt`` and the
            ``summary`` folder.
        result_path: Directory for evaluation outputs (str or Path).
            Defaults to ``model_dir / 'results'``.
        create_folder: When True and ``model_dir`` already exists,
            ``model_dir`` is replaced by the value returned from
            ``torchplus.train.create_folder``.
        display_step: The accumulated classification loss is printed (and
            reset) whenever the global step is a multiple of this value.
        summary_step: Not used by this function.
        pickle_result: When True evaluation predictions are collected in
            memory and written to ``result.pkl``; otherwise they are written
            to files and read back for the metrics.
        patchs: Not used by this function.

    Raises:
        Exception: Any exception raised during training is re-raised after
            a checkpoint has been saved; the log file is always closed.
    """
    torch.manual_seed(3)
    np.random.seed(3)
    if create_folder and pathlib.Path(model_dir).exists():
        model_dir = torchplus.train.create_folder(model_dir)
    model_dir = pathlib.Path(model_dir)
    model_dir.mkdir(parents=True, exist_ok=True)
    if result_path is None:
        result_path = model_dir / 'results'
    else:
        # Plain strings are accepted too; the "/" joins below need a Path.
        result_path = pathlib.Path(result_path)

    config = pipeline_pb2.TrainEvalPipelineConfig()
    with open(config_path, "r") as f:
        proto_str = f.read()
        text_format.Merge(proto_str, config)
    input_cfg = config.train_input_reader
    eval_input_cfg = config.eval_input_reader
    model_cfg = config.model.second
    train_cfg = config.train_config
    detection_2d_path = config.train_config.detection_2d_path
    print("2d detection path:", detection_2d_path)
    center_limit_range = model_cfg.post_center_limit_range
    voxel_generator = voxel_builder.build(model_cfg.voxel_generator)
    bv_range = voxel_generator.point_cloud_range[[0, 1, 3, 4]]
    box_coder = box_coder_builder.build(model_cfg.box_coder)
    target_assigner_cfg = model_cfg.target_assigner
    target_assigner = target_assigner_builder.build(target_assigner_cfg,
                                                    bv_range, box_coder)
    class_names = target_assigner.classes
    net = build_inference_net('./configs/car.fhd.config', '../model_dir')
    fusion_layer = fusion.fusion()
    fusion_layer.cuda()
    optimizer_cfg = train_cfg.optimizer
    if train_cfg.enable_mixed_precision:
        net.half()
        net.metrics_to_float()
        net.convert_norm_to_float(net)
    loss_scale = train_cfg.loss_scale_factor
    mixed_optimizer = optimizer_builder.build(
        optimizer_cfg,
        fusion_layer,
        mixed=train_cfg.enable_mixed_precision,
        loss_scale=loss_scale)
    optimizer = mixed_optimizer
    # must restore optimizer AFTER using MixedPrecisionWrapper
    torchplus.train.try_restore_latest_checkpoints(model_dir,
                                                   [mixed_optimizer])
    lr_scheduler = lr_scheduler_builder.build(optimizer_cfg, optimizer,
                                              train_cfg.steps)
    if train_cfg.enable_mixed_precision:
        float_dtype = torch.float16
    else:
        float_dtype = torch.float32

    ######################
    # PREPARE INPUT
    ######################
    dataset = input_reader_builder.build(input_cfg,
                                         model_cfg,
                                         training=True,
                                         voxel_generator=voxel_generator,
                                         target_assigner=target_assigner)
    eval_dataset = input_reader_builder.build(
        eval_input_cfg,
        model_cfg,
        training=True,  # if running for test, this needs to be False
        voxel_generator=voxel_generator,
        target_assigner=target_assigner)

    def _worker_init_fn(worker_id):
        # Give every loader worker its own time-based numpy seed.
        time_seed = np.array(time.time(), dtype=np.int32)
        np.random.seed(time_seed + worker_id)
        print(f"WORKER {worker_id} seed:", np.random.get_state()[1][0])

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=input_cfg.batch_size,
                                             shuffle=True,
                                             num_workers=input_cfg.num_workers,
                                             pin_memory=False,
                                             collate_fn=merge_second_batch,
                                             worker_init_fn=_worker_init_fn)

    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=eval_input_cfg.batch_size,
        shuffle=False,
        num_workers=eval_input_cfg.num_workers,
        pin_memory=False,
        collate_fn=merge_second_batch)

    data_iter = iter(dataloader)

    ######################
    # TRAINING
    ######################
    focal_loss = SigmoidFocalClassificationLoss()
    cls_loss_sum = 0
    summary_dir = model_dir / 'summary'
    summary_dir.mkdir(parents=True, exist_ok=True)
    writer = SummaryWriter(str(summary_dir))
    total_step_elapsed = 0
    ckpt_start_time = time.time()
    clear_metrics_every_epoch = train_cfg.clear_metrics_every_epoch
    # One outer iteration per evaluation round; a trailing partial round is
    # added only when steps is not a multiple of steps_per_eval.
    total_loop = train_cfg.steps // train_cfg.steps_per_eval + 1
    if train_cfg.steps % train_cfg.steps_per_eval == 0:
        total_loop -= 1
    net.set_global_step(torch.tensor([0]))
    mixed_optimizer.zero_grad()

    logf = open(model_dir / 'log.txt', 'a')

    def _log(*args):
        # Echo a message to stdout and to the log file.
        print(*args)
        print(*args, file=logf)

    try:
        logf.write(proto_str)
        logf.write("\n")
        for _ in range(total_loop):
            if total_step_elapsed + train_cfg.steps_per_eval > train_cfg.steps:
                steps = train_cfg.steps % train_cfg.steps_per_eval
            else:
                steps = train_cfg.steps_per_eval
            for _step in range(steps):
                lr_scheduler.step(net.get_global_step())
                try:
                    example = next(data_iter)
                except StopIteration:
                    print("end epoch")
                    if clear_metrics_every_epoch:
                        net.clear_metrics()
                    data_iter = iter(dataloader)
                    example = next(data_iter)
                example_torch = example_convert_to_torch(example, float_dtype)
                (all_3d_output_camera_dict, _all_3d_output, _top_predictions,
                 fusion_input, tensor_index) = net(example_torch,
                                                   detection_2d_path)
                target_for_fusion, positives, negatives = \
                    _build_fusion_targets(example_torch,
                                          all_3d_output_camera_dict)

                cls_preds, flag = fusion_layer(fusion_input.cuda(),
                                               tensor_index.cuda())
                one_hot_targets = torch.from_numpy(target_for_fusion).type(
                    torch.float32).cuda()

                # Positives and negatives both get weight 1, normalised by
                # the number of positives (at least 1).
                negative_cls_weights = negatives.type(torch.float32) * 1.0
                cls_weights = negative_cls_weights + 1.0 * positives.type(
                    torch.float32)
                pos_normalizer = positives.sum(1, keepdim=True).type(
                    torch.float32)
                cls_weights /= torch.clamp(pos_normalizer, min=1.0)
                if flag == 1:
                    cls_losses = focal_loss._compute_loss(
                        cls_preds, one_hot_targets,
                        cls_weights.cuda())  # [N, M]
                    cls_losses_reduced = cls_losses.sum(
                    ) / example_torch['labels'].shape[0]
                    # Accumulate a detached copy so the running sum used for
                    # logging does not keep autograd history alive.
                    cls_loss_sum = cls_loss_sum + cls_losses_reduced.detach()
                    loss = cls_losses_reduced
                    if train_cfg.enable_mixed_precision:
                        # Scale the loss that is actually back-propagated.
                        loss = loss * loss_scale
                    loss.backward()
                    mixed_optimizer.step()
                    mixed_optimizer.zero_grad()
                net.update_global_step()
                global_step = net.get_global_step()
                if global_step % display_step == 0:
                    _log("now it is", global_step, "steps",
                         " and the cls_loss is :",
                         cls_loss_sum / display_step, "learning_rate: ",
                         float(optimizer.lr))
                    cls_loss_sum = 0

                if (time.time() - ckpt_start_time >
                        train_cfg.save_checkpoints_secs):
                    torchplus.train.save_models(model_dir,
                                                [fusion_layer, optimizer],
                                                net.get_global_step())
                    ckpt_start_time = time.time()

            total_step_elapsed += steps

            torchplus.train.save_models(model_dir, [fusion_layer, optimizer],
                                        net.get_global_step())

            ######################
            # EVALUATION
            ######################
            fusion_layer.eval()
            net.eval()
            global_step = net.get_global_step()
            result_path_step = result_path / f"step_{global_step}"
            result_path_step.mkdir(parents=True, exist_ok=True)
            _log("#################################")
            _log("# EVAL")
            _log("#################################")
            _log("Generate output labels...")
            eval_start = time.time()
            dt_annos = []
            prog_bar = ProgressBar()
            net.clear_timer()
            prog_bar.start(
                (len(eval_dataset) + eval_input_cfg.batch_size - 1) //
                eval_input_cfg.batch_size)
            val_loss_final = 0
            for example in eval_dataloader:
                example = example_convert_to_torch(example, float_dtype)
                if pickle_result:
                    dt_annos_i, val_losses = predict_kitti_to_anno(
                        net, detection_2d_path, fusion_layer, example,
                        class_names, center_limit_range, model_cfg.lidar_input)
                    dt_annos += dt_annos_i
                    val_loss_final = val_loss_final + val_losses
                else:
                    # Same argument order as the call in evaluate():
                    # the fusion layer comes right after detection_2d_path.
                    _predict_kitti_to_file(net, detection_2d_path,
                                           fusion_layer, example,
                                           result_path_step, class_names,
                                           center_limit_range,
                                           model_cfg.lidar_input)

                prog_bar.print_bar()

            # Throughput in examples per second over the prediction loop.
            sec_per_ex = len(eval_dataset) / (time.time() - eval_start)
            _log("validation_loss:", val_loss_final / len(eval_dataloader))
            _log(f'generate label finished({sec_per_ex:.2f}/s). start eval:')
            gt_annos = [
                info["annos"] for info in eval_dataset.dataset.kitti_infos
            ]
            if not pickle_result:
                dt_annos = kitti.get_label_annos(result_path_step)
            result = get_official_eval_result(gt_annos, dt_annos, class_names)
            _log(result)
            writer.add_text('eval_result', json.dumps(result, indent=2),
                            global_step)
            result = get_coco_eval_result(gt_annos, dt_annos, class_names)
            _log(result)
            if pickle_result:
                with open(result_path_step / "result.pkl", 'wb') as f:
                    pickle.dump(dt_annos, f)
            writer.add_text('eval_result', result, global_step)
            # The detector stays in eval mode; only the fusion layer trains.
            fusion_layer.train()
    except Exception:
        # Keep whatever progress was made, then let the error propagate.
        torchplus.train.save_models(model_dir, [fusion_layer, optimizer],
                                    net.get_global_step())
        raise
    else:
        # save model before exit
        torchplus.train.save_models(model_dir, [fusion_layer, optimizer],
                                    net.get_global_step())
    finally:
        logf.close()