Example #1
    def forward(self, features, num_voxels, coors):

        # Find distance of x, y, and z from cluster center
        points_mean = features[:, :, :3].sum(
            dim=1, keepdim=True) / num_voxels.type_as(features).view(-1, 1, 1)
        f_cluster = features[:, :, :3] - points_mean

        # Find distance of x, y, and z from pillar center
        f_center = torch.zeros_like(features[:, :, :2])
        f_center[:, :, 0] = features[:, :, 0] - (
            coors[:, 3].float().unsqueeze(1) * self.vx + self.x_offset)
        f_center[:, :, 1] = features[:, :, 1] - (
            coors[:, 2].float().unsqueeze(1) * self.vy + self.y_offset)

        # Combine together feature decorations
        features_ls = [features, f_cluster, f_center]
        if self._with_distance:
            points_dist = torch.norm(features[:, :, :3], 2, 2, keepdim=True)
            features_ls.append(points_dist)
        features = torch.cat(features_ls, dim=-1)

        # The feature decorations were calculated without regard to whether pillar was empty. Need to ensure that
        # empty pillars remain set to zeros.
        voxel_count = features.shape[1]
        mask = get_paddings_indicator(num_voxels, voxel_count, axis=0)
        mask = torch.unsqueeze(mask, -1).type_as(features)
        features *= mask

        features = self.VoxelFeature_TA(points_mean, features)

        # Forward pass through PFNLayers
        for pfn in self.pfn_layers:
            features = pfn(features)

        return features.squeeze()
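These examples all rely on get_paddings_indicator, which is not included in the snippets. A minimal sketch of the usual SECOND/PointPillars helper (an assumption about its implementation, not copied from this source): it builds a boolean mask of shape (num_voxels, max_num) whose rows contain True for real points and False for padding.

import torch

def get_paddings_indicator(actual_num, max_num, axis=0):
    # actual_num: (num_voxels,) number of real points in each voxel/pillar
    actual_num = torch.unsqueeze(actual_num, axis + 1)
    max_num_shape = [1] * len(actual_num.shape)
    max_num_shape[axis + 1] = -1
    max_num = torch.arange(
        max_num, dtype=torch.int, device=actual_num.device).view(max_num_shape)
    # broadcasting compares each point index against the voxel's point count
    return actual_num.int() > max_num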
Example #2
 def forward(self, features, num_voxels, coors):
     # features: [concated_num_points, num_voxel_size, 3(4)]
     # num_voxels: [concated_num_points]
     points_mean = features[:, :, :3].sum(
         dim=1, keepdim=True) / num_voxels.type_as(features).view(-1, 1, 1)
     features_relative = features[:, :, :3] - points_mean
     if self._with_distance:
         points_dist = torch.norm(features[:, :, :3], 2, 2, keepdim=True)
         features = torch.cat([features, features_relative, points_dist],
                              dim=-1)
     else:
         features = torch.cat([features, features_relative], dim=-1)
     voxel_count = features.shape[1]
     mask = get_paddings_indicator(num_voxels, voxel_count, axis=0)
     mask = torch.unsqueeze(mask, -1).type_as(features)
     for vfe in self.vfe_layers:
         features = vfe(features)
         features *= mask
     features = self.linear(features)
     features = self.norm(features.permute(0, 2, 1).contiguous()).permute(
         0, 2, 1).contiguous()
     features = F.relu(features)
     features *= mask
     # x: [concated_num_points, num_voxel_size, 128]
     voxelwise = torch.max(features, dim=1)[0]
     return voxelwise
Example #3
    def forward(self, features, num_voxels, coors):  # forward pass of the PFN (Pillar Feature Net)
        """
        :param features: voxel features, (num_voxels, max_points, 4)
        :param num_voxels: probably misnamed; it is the number of points in each voxel, (num_voxels,)
        :param coors: voxel coordinate indices; the first column distinguishes the two frames of data, (num_voxels, 4)
        :return: the per-voxel max of each feature, (num_voxels, 64)
        """

        # Find distance of x, y, and z from cluster center
        points_mean = features[:, :, :3].sum(
            dim=1, keepdim=True) / num_voxels.type_as(features).view(
                -1, 1, 1)  # centroid of all points in each voxel, (num_voxels, 1, 3)
        f_cluster = features[:, :, :3] - points_mean  # points relative to their voxel centroid, (num_voxels, 100, 3)

        # Find distance of x, y, and z from pillar center
        # Determine the horizontal (xy) center of each voxel from the voxel index and the voxel offsets.
        # Use a fresh tensor (as in Example #1) so the raw features are not modified in place through a view.
        f_center = torch.zeros_like(features[:, :, :2])  # (num_voxels, max_points, 2)
        f_center[:, :, 0] = features[:, :, 0] - (
            coors[:, 3].float().unsqueeze(1) * self.vx + self.x_offset)
        f_center[:, :, 1] = features[:, :, 1] - (
            coors[:, 2].float().unsqueeze(1) * self.vy + self.y_offset)

        # Combine together feature decorations
        features_ls = [features, f_cluster, f_center]
        if self._with_distance:  # False
            points_dist = torch.norm(features[:, :, :3], 2, 2, keepdim=True)
            features_ls.append(points_dist)
        features = torch.cat(
            features_ls, dim=-1)  # concatenate along the last dim, (num_voxels, max_points, 4+3+2)

        # The feature decorations were calculated without regard to whether pillar was empty. Need to ensure that
        # empty pillars remain set to zeros.
        voxel_count = features.shape[1]  # max_points
        mask = get_paddings_indicator(
            num_voxels, voxel_count,
            axis=0)  # padding indicator: each row has as many 1s as the voxel has real points, (num_voxels, max_points)
        mask = torch.unsqueeze(mask, -1).type_as(
            features)  # (num_voxels, max_points, 1)
        features *= mask  # element-wise multiply so padded entries stay zero, (num_voxels, max_points, 4+3+2)

        # Forward pass through PFNLayers
        for pfn in self.pfn_layers:  # run the PFN layers defined above to extract pillar features
            features = pfn(features)

        return features.squeeze()  # drop the dimensions of size 1
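The pfn_layers used by Examples #1 and #3 are not shown. A minimal sketch of a PFNLayer following the PointPillars design (linear, BatchNorm, ReLU, then a max over the points of each pillar); the exact module in any given repository may differ:

import torch
import torch.nn as nn
import torch.nn.functional as F

class PFNLayer(nn.Module):
    def __init__(self, in_channels, out_channels, last_layer=True):
        super().__init__()
        self.last_vfe = last_layer
        if not self.last_vfe:
            out_channels = out_channels // 2
        self.linear = nn.Linear(in_channels, out_channels, bias=False)
        self.norm = nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01)

    def forward(self, inputs):
        # inputs: (num_pillars, max_points, in_channels)
        x = self.linear(inputs)
        # BatchNorm1d expects (N, C, L), so normalize over the channel dimension
        x = self.norm(x.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous()
        x = F.relu(x)
        x_max = torch.max(x, dim=1, keepdim=True)[0]  # max over the points of each pillar
        if self.last_vfe:
            return x_max  # (num_pillars, 1, out_channels)
        # intermediate layers concatenate the pillar-wise max back onto every point
        x_repeat = x_max.repeat(1, inputs.shape[1], 1)
        return torch.cat([x, x_repeat], dim=2)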
def train(config_path,
          model_dir,
          result_path=None,
          create_folder=False,
          display_step=50,
          summary_step=5,
          pickle_result=True):
    """train a VoxelNet model specified by a config file.
    """
    if create_folder:
        if pathlib.Path(model_dir).exists():
            model_dir = torchplus.train.create_folder(model_dir)

    model_dir = pathlib.Path(model_dir)
    model_dir.mkdir(parents=True, exist_ok=True)
    eval_checkpoint_dir = model_dir / 'eval_checkpoints'
    eval_checkpoint_dir.mkdir(parents=True, exist_ok=True)
    if result_path is None:
        result_path = model_dir / 'results'
    config_file_bkp = "pipeline.config"
    config = pipeline_pb2.TrainEvalPipelineConfig()
    with open(config_path, "r") as f:
        proto_str = f.read()
        text_format.Merge(proto_str, config)
    shutil.copyfile(config_path, str(model_dir / config_file_bkp))
    input_cfg = config.train_input_reader
    eval_input_cfg = config.eval_input_reader
    model_cfg = config.model.second
    train_cfg = config.train_config

    class_names = list(input_cfg.class_names)
    ######################
    # BUILD VOXEL GENERATOR
    ######################
    voxel_generator = voxel_builder.build(model_cfg.voxel_generator)
    ######################
    # BUILD TARGET ASSIGNER
    ######################
    bv_range = voxel_generator.point_cloud_range[[0, 1, 3, 4]]
    box_coder = box_coder_builder.build(model_cfg.box_coder)
    target_assigner_cfg = model_cfg.target_assigner
    target_assigner = target_assigner_builder.build(target_assigner_cfg,
                                                    bv_range, box_coder)
    ######################
    # BUILD NET
    ######################
    center_limit_range = model_cfg.post_center_limit_range
    # net = second_builder.build(model_cfg, voxel_generator, target_assigner)
    net = second_builder.build(model_cfg, voxel_generator, target_assigner, input_cfg.batch_size)
    net.cuda()
    # net_train = torch.nn.DataParallel(net).cuda()
    print("num_trainable parameters:", len(list(net.parameters())))
    # for n, p in net.named_parameters():
    #     print(n, p.shape)
    ######################
    # BUILD OPTIMIZER
    ######################
    # we need global_step to create lr_scheduler, so restore net first.
    torchplus.train.try_restore_latest_checkpoints(model_dir, [net])
    gstep = net.get_global_step() - 1
    optimizer_cfg = train_cfg.optimizer
    if train_cfg.enable_mixed_precision:
        net.half()
        net.metrics_to_float()
        net.convert_norm_to_float(net)
    optimizer = optimizer_builder.build(optimizer_cfg, net.parameters())
    if train_cfg.enable_mixed_precision:
        loss_scale = train_cfg.loss_scale_factor
        mixed_optimizer = torchplus.train.MixedPrecisionWrapper(
            optimizer, loss_scale)
    else:
        mixed_optimizer = optimizer
    # must restore optimizer AFTER using MixedPrecisionWrapper
    torchplus.train.try_restore_latest_checkpoints(model_dir,
                                                   [mixed_optimizer])
    lr_scheduler = lr_scheduler_builder.build(optimizer_cfg, optimizer, gstep)
    if train_cfg.enable_mixed_precision:
        float_dtype = torch.float16
    else:
        float_dtype = torch.float32
    ######################
    # PREPARE INPUT
    ######################

    dataset = input_reader_builder.build(
        input_cfg,
        model_cfg,
        training=True,
        voxel_generator=voxel_generator,
        target_assigner=target_assigner)
    eval_dataset = input_reader_builder.build(
        eval_input_cfg,
        model_cfg,
        training=False,
        voxel_generator=voxel_generator,
        target_assigner=target_assigner)

    def _worker_init_fn(worker_id):
        time_seed = np.array(time.time(), dtype=np.int32)
        np.random.seed(time_seed + worker_id)
        print(f"WORKER {worker_id} seed:", np.random.get_state()[1][0])

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=input_cfg.batch_size,
        shuffle=True,
        num_workers=input_cfg.num_workers,
        pin_memory=False,
        collate_fn=merge_second_batch,
        worker_init_fn=_worker_init_fn)
    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=eval_input_cfg.batch_size,
        shuffle=False,
        num_workers=eval_input_cfg.num_workers,
        pin_memory=False,
        collate_fn=merge_second_batch)
    data_iter = iter(dataloader)

    ######################
    # TRAINING
    ######################
    log_path = model_dir / 'log.txt'
    logf = open(log_path, 'a')
    logf.write(proto_str)
    logf.write("\n")
    summary_dir = model_dir / 'summary'
    summary_dir.mkdir(parents=True, exist_ok=True)
    writer = SummaryWriter(str(summary_dir))

    total_step_elapsed = 0
    remain_steps = train_cfg.steps - net.get_global_step()
    t = time.time()
    ckpt_start_time = t

    total_loop = train_cfg.steps // train_cfg.steps_per_eval + 1
    # total_loop = remain_steps // train_cfg.steps_per_eval + 1
    clear_metrics_every_epoch = train_cfg.clear_metrics_every_epoch

    if train_cfg.steps % train_cfg.steps_per_eval == 0:
        total_loop -= 1
    mixed_optimizer.zero_grad()
    try:
        for _ in range(total_loop):
            if total_step_elapsed + train_cfg.steps_per_eval > train_cfg.steps:
                steps = train_cfg.steps % train_cfg.steps_per_eval
            else:
                steps = train_cfg.steps_per_eval
            for step in range(steps):
                lr_scheduler.step()
                try:
                    example = next(data_iter)
                except StopIteration:
                    print("end epoch")
                    if clear_metrics_every_epoch:
                        net.clear_metrics()
                    data_iter = iter(dataloader)
                    example = next(data_iter)
                example_torch = example_convert_to_torch(example, float_dtype)

                batch_size = example["anchors"].shape[0]

                example_tuple = list(example_torch.values())
                example_tuple[11] = torch.from_numpy(example_tuple[11])
                example_tuple[12] = torch.from_numpy(example_tuple[12])
                assert 13 == len(example_tuple), "something wrong with training input size!"
                # training example:[0:'voxels', 1:'num_points', 2:'coordinates', 3:'rect',
                # 4:'Trv2c', 5:'P2',
                # 6:'anchors', 7:'anchors_mask', 8:'labels', 9:'reg_targets', 10:'reg_weights',
                # 11:'image_idx', 12:'image_shape']
                # ret_dict = net(example_torch)

                # training input from example
                # print("example[0] size", example_tuple[0].size())
                pillar_x = example_tuple[0][:,:,0].unsqueeze(0).unsqueeze(0)
                pillar_y = example_tuple[0][:,:,1].unsqueeze(0).unsqueeze(0)
                pillar_z = example_tuple[0][:,:,2].unsqueeze(0).unsqueeze(0)
                pillar_i = example_tuple[0][:,:,3].unsqueeze(0).unsqueeze(0)
                num_points_per_pillar = example_tuple[1].float().unsqueeze(0)

                # Find distance of x, y, and z from pillar center
                # assuming xyres_16.proto
                coors_x = example_tuple[2][:, 3].float()
                coors_y = example_tuple[2][:, 2].float()
                # self.x_offset = self.vx / 2 + pc_range[0]
                # self.y_offset = self.vy / 2 + pc_range[1]
                # this assumes xyres 20
                # x_sub = coors_x.unsqueeze(1) * 0.16 + 0.1
                # y_sub = coors_y.unsqueeze(1) * 0.16 + -39.9
                # here assumes xyres 16
                x_sub = coors_x.unsqueeze(1) * 0.16 + 0.08
                y_sub = coors_y.unsqueeze(1) * 0.16 + -39.6
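                # Assuming the xyres_16 config referenced above: vx = vy = 0.16 m and the
                # point cloud range starts at x = 0.0, y = -39.68, so
                #   x_offset = 0.16 / 2 + 0.0      = 0.08
                #   y_offset = 0.16 / 2 + (-39.68) = -39.6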
                ones = torch.ones([1, 100], dtype=torch.float32, device=pillar_x.device)
                x_sub_shaped = torch.mm(x_sub, ones).unsqueeze(0).unsqueeze(0)
                y_sub_shaped = torch.mm(y_sub, ones).unsqueeze(0).unsqueeze(0)

                num_points_for_a_pillar = pillar_x.size()[3]
                mask = get_paddings_indicator(num_points_per_pillar, num_points_for_a_pillar, axis=0)
                mask = mask.permute(0, 2, 1)
                mask = mask.unsqueeze(1)
                mask = mask.type_as(pillar_x)

                coors   = example_tuple[2]
                anchors = example_tuple[6]
                labels  = example_tuple[8]
                reg_targets = example_tuple[9]

                input = [pillar_x, pillar_y, pillar_z, pillar_i,
                         num_points_per_pillar, x_sub_shaped, y_sub_shaped, mask, coors,
                         anchors, labels, reg_targets]

                ret_dict = net(input)
                assert 10 == len(ret_dict), "something wrong with training output size!"
                # return 0
                # ret_dict {
                #     0:"loss": loss,
                #     1:"cls_loss": cls_loss,
                #     2:"loc_loss": loc_loss,
                #     3:"cls_pos_loss": cls_pos_loss,
                #     4:"cls_neg_loss": cls_neg_loss,
                #     5:"cls_preds": cls_preds,
                #     6:"dir_loss_reduced": dir_loss_reduced,
                #     7:"cls_loss_reduced": cls_loss_reduced,
                #     8:"loc_loss_reduced": loc_loss_reduced,
                #     9:"cared": cared,
                # }
                # cls_preds = ret_dict["cls_preds"]
                cls_preds = ret_dict[5]
                # loss = ret_dict["loss"].mean()
                loss = ret_dict[0].mean()
                # cls_loss_reduced = ret_dict["cls_loss_reduced"].mean()
                cls_loss_reduced = ret_dict[7].mean()
                # loc_loss_reduced = ret_dict["loc_loss_reduced"].mean()
                loc_loss_reduced = ret_dict[8].mean()
                # cls_pos_loss = ret_dict["cls_pos_loss"]
                cls_pos_loss = ret_dict[3]
                # cls_neg_loss = ret_dict["cls_neg_loss"]
                cls_neg_loss = ret_dict[4]
                # loc_loss = ret_dict["loc_loss"]
                loc_loss = ret_dict[2]
                # cls_loss = ret_dict["cls_loss"]
                cls_loss = ret_dict[1]
                # dir_loss_reduced = ret_dict["dir_loss_reduced"]
                dir_loss_reduced = ret_dict[6]
                # cared = ret_dict["cared"]
                cared = ret_dict[9]
                # labels = example_torch["labels"]
                labels = example_tuple[8]
                if train_cfg.enable_mixed_precision:
                    loss *= loss_scale
                loss.backward()
                torch.nn.utils.clip_grad_norm_(net.parameters(), 10.0)
                mixed_optimizer.step()
                mixed_optimizer.zero_grad()
                net.update_global_step()
                net_metrics = net.update_metrics(cls_loss_reduced,
                                                 loc_loss_reduced, cls_preds,
                                                 labels, cared)

                step_time = (time.time() - t)
                t = time.time()
                metrics = {}
                num_pos = int((labels > 0)[0].float().sum().cpu().numpy())
                num_neg = int((labels == 0)[0].float().sum().cpu().numpy())
                # if 'anchors_mask' not in example_torch:
                #     num_anchors = example_torch['anchors'].shape[1]
                # else:
                #     num_anchors = int(example_torch['anchors_mask'][0].sum())
                num_anchors = int(example_tuple[7][0].sum())
                global_step = net.get_global_step()
                if global_step % display_step == 0:
                    loc_loss_elem = [
                        float(loc_loss[:, :, i].sum().detach().cpu().numpy() /
                              batch_size) for i in range(loc_loss.shape[-1])
                    ]
                    metrics["step"] = global_step
                    metrics["steptime"] = step_time
                    metrics.update(net_metrics)
                    metrics["loss"] = {}
                    metrics["loss"]["loc_elem"] = loc_loss_elem
                    metrics["loss"]["cls_pos_rt"] = float(
                        cls_pos_loss.detach().cpu().numpy())
                    metrics["loss"]["cls_neg_rt"] = float(
                        cls_neg_loss.detach().cpu().numpy())
                    # if unlabeled_training:
                    #     metrics["loss"]["diff_rt"] = float(
                    #         diff_loc_loss_reduced.detach().cpu().numpy())
                    if model_cfg.use_direction_classifier:
                        metrics["loss"]["dir_rt"] = float(
                            dir_loss_reduced.detach().cpu().numpy())
                    # metrics["num_vox"] = int(example_torch["voxels"].shape[0])
                    metrics["num_vox"] = int(example_tuple[0].shape[0])
                    metrics["num_pos"] = int(num_pos)
                    metrics["num_neg"] = int(num_neg)
                    metrics["num_anchors"] = int(num_anchors)
                    metrics["lr"] = float(
                        mixed_optimizer.param_groups[0]['lr'])
                    # metrics["image_idx"] = example['image_idx'][0]
                    metrics["image_idx"] = example_tuple[11][0]
                    flatted_metrics = flat_nested_json_dict(metrics)
                    flatted_summarys = flat_nested_json_dict(metrics, "/")
                    for k, v in flatted_summarys.items():
                        if isinstance(v, (list, tuple)):
                            v = {str(i): e for i, e in enumerate(v)}
                            writer.add_scalars(k, v, global_step)
                        else:
                            writer.add_scalar(k, v, global_step)
                    metrics_str_list = []
                    for k, v in flatted_metrics.items():
                        if isinstance(v, float):
                            metrics_str_list.append(f"{k}={v:.3}")
                        elif isinstance(v, (list, tuple)):
                            if v and isinstance(v[0], float):
                                v_str = ', '.join([f"{e:.3}" for e in v])
                                metrics_str_list.append(f"{k}=[{v_str}]")
                            else:
                                metrics_str_list.append(f"{k}={v}")
                        else:
                            metrics_str_list.append(f"{k}={v}")
                    log_str = ', '.join(metrics_str_list)
                    print(log_str, file=logf)
                    print(log_str)
                ckpt_elapsed_time = time.time() - ckpt_start_time
                if ckpt_elapsed_time > train_cfg.save_checkpoints_secs:
                    torchplus.train.save_models(model_dir, [net, optimizer],
                                                net.get_global_step())
                    ckpt_start_time = time.time()
            total_step_elapsed += steps
            torchplus.train.save_models(model_dir, [net, optimizer],
                                        net.get_global_step())

            # Ensure that all evaluation points are saved forever
            torchplus.train.save_models(eval_checkpoint_dir, [net, optimizer], net.get_global_step(), max_to_keep=100)

            # net.eval()
            # result_path_step = result_path / f"step_{net.get_global_step()}"
            # result_path_step.mkdir(parents=True, exist_ok=True)
            # print("#################################")
            # print("#################################", file=logf)
            # print("# EVAL")
            # print("# EVAL", file=logf)
            # print("#################################")
            # print("#################################", file=logf)
            # print("Generate output labels...")
            # print("Generate output labels...", file=logf)
            # t = time.time()
            # dt_annos = []
            # prog_bar = ProgressBar()
            # prog_bar.start(len(eval_dataset) // eval_input_cfg.batch_size + 1)
            # for example in iter(eval_dataloader):
            #     example = example_convert_to_torch(example, float_dtype)
            #     # evaluation example:[0:'voxels', 1:'num_points', 2:'coordinates', 3:'rect',
            #     # 4:'Trv2c', 5:'P2',
            #     # 6:'anchors', 7:'anchors_mask', 8:'image_idx', 9:'image_shape']
            #     example_tuple = list(example.values())
            #     example_tuple[8] = torch.from_numpy(example_tuple[8])
            #     example_tuple[9] = torch.from_numpy(example_tuple[9])
            #     if pickle_result:
            #         dt_annos += predict_kitti_to_anno(
            #             net, example_tuple, class_names, center_limit_range,
            #             model_cfg.lidar_input)
            #     else:
            #         _predict_kitti_to_file(net, example, result_path_step,
            #                                class_names, center_limit_range,
            #                                model_cfg.lidar_input)
            #
            #     prog_bar.print_bar()
            #
            # sec_per_ex = len(eval_dataset) / (time.time() - t)
            # print(f"avg forward time per example: {net.avg_forward_time:.3f}")
            # print(
            #     f"avg postprocess time per example: {net.avg_postprocess_time:.3f}"
            # )
            #
            # net.clear_time_metrics()
            # print(f'generate label finished({sec_per_ex:.2f}/s). start eval:')
            # print(
            #     f'generate label finished({sec_per_ex:.2f}/s). start eval:',
            #     file=logf)
            # gt_annos = [
            #     info["annos"] for info in eval_dataset.dataset.kitti_infos
            # ]
            # if not pickle_result:
            #     dt_annos = kitti.get_label_annos(result_path_step)
            # result, mAPbbox, mAPbev, mAP3d, mAPaos = get_official_eval_result(gt_annos, dt_annos, class_names,
            #                                                                   return_data=True)
            # print(result, file=logf)
            # print(result)
            # writer.add_text('eval_result', result, global_step)
            #
            # for i, class_name in enumerate(class_names):
            #     writer.add_scalar('bev_ap:{}'.format(class_name), mAPbev[i, 1, 0], global_step)
            #     writer.add_scalar('3d_ap:{}'.format(class_name), mAP3d[i, 1, 0], global_step)
            #     writer.add_scalar('aos_ap:{}'.format(class_name), mAPaos[i, 1, 0], global_step)
            # writer.add_scalar('bev_map', np.mean(mAPbev[:, 1, 0]), global_step)
            # writer.add_scalar('3d_map', np.mean(mAP3d[:, 1, 0]), global_step)
            # writer.add_scalar('aos_map', np.mean(mAPaos[:, 1, 0]), global_step)
            #
            # result = get_coco_eval_result(gt_annos, dt_annos, class_names)
            # print(result, file=logf)
            # print(result)
            # if pickle_result:
            #     with open(result_path_step / "result.pkl", 'wb') as f:
            #         pickle.dump(dt_annos, f)
            # writer.add_text('eval_result', result, global_step)
            # net.train()
    except Exception as e:
        torchplus.train.save_models(model_dir, [net, optimizer],
                                    net.get_global_step())
        logf.close()
        raise e
    # save model before exit
    torchplus.train.save_models(model_dir, [net, optimizer],
                                net.get_global_step())
    logf.close()
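A hypothetical invocation of train() (the config path and model directory are placeholders, not taken from the source; the original scripts typically expose this function through a CLI wrapper such as fire.Fire):

if __name__ == '__main__':
    # placeholder paths, for illustration only
    train(
        config_path="configs/pointpillars/car/xyres_16.proto",
        model_dir="/tmp/pointpillars_model",
        display_step=50,
        summary_step=5)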
def prediction_once(net,
                    example,
                    class_names,
                    batch_image_shape,
                    center_limit_range=None,
                    lidar_input=False,
                    global_set=None):
    # predictions_dicts = net(example)
    # input_names = ['voxels', 'num_points', 'coordinates', 'rect', 'Trv2c', 'P2',
    # 'anchors', 'anchors_mask', 'labels', 'image_idx', 'image_shape']


    pillar_x = example[0][:,:,0].unsqueeze(0).unsqueeze(0)
    pillar_y = example[0][:,:,1].unsqueeze(0).unsqueeze(0)
    pillar_z = example[0][:,:,2].unsqueeze(0).unsqueeze(0)
    pillar_i = example[0][:,:,3].unsqueeze(0).unsqueeze(0)
    num_points_per_pillar = example[1].float().unsqueeze(0)

    # Find distance of x, y, and z from pillar center
    # assuming xyres_16.proto
    coors_x = example[2][:, 3].float()
    coors_y = example[2][:, 2].float()
    x_sub = coors_x.unsqueeze(1) * 0.16 + 0.1
    y_sub = coors_y.unsqueeze(1) * 0.16 + -39.9
    ones = torch.ones([1, 100], dtype=torch.float32, device=pillar_x.device)
    x_sub_shaped = torch.mm(x_sub, ones).unsqueeze(0).unsqueeze(0)
    y_sub_shaped = torch.mm(y_sub, ones).unsqueeze(0).unsqueeze(0)

    num_points_for_a_pillar = pillar_x.size()[3]
    mask = get_paddings_indicator(num_points_per_pillar, num_points_for_a_pillar, axis=0)
    mask = mask.permute(0, 2, 1)
    mask = mask.unsqueeze(1)
    mask = mask.type_as(pillar_x)

    coors = example[2]

    print(pillar_x.size())
    print(pillar_y.size())
    print(pillar_z.size())
    print(pillar_i.size())
    print(num_points_per_pillar.size())
    print(x_sub_shaped.size())
    print(y_sub_shaped.size())
    print(mask.size())

    # input = [pillar_x, pillar_y, pillar_z, pillar_i, num_points_per_pillar, x_sub_shaped, y_sub_shaped, mask, coors]
    # predictions_dicts = net(input)
    # return 0

    input_names = ["pillar_x", "pillar_y", "pillar_z", "pillar_i", "num_points_per_pillar", "x_sub_shaped", "y_sub_shaped", "mask"]
    # input_names = ["pillar_x", "pillar_y", "pillar_z"]

    # pillar_x = torch.ones([1, 8599, 100, 1],dtype=torch.float32, device=pillar_x.device )
    # pillar_y = torch.ones([1, 8599, 100, 1],dtype=torch.float32, device=pillar_x.device )
    # pillar_z = torch.ones([1, 8599, 100, 1],dtype=torch.float32, device=pillar_x.device )
    # pillar_i = torch.ones([1, 8599, 100, 1],dtype=torch.float32, device=pillar_x.device )
    # num_points_per_pillar = torch.ones([1, 8599],dtype=torch.float32, device=pillar_x.device )
    # x_sub_shaped = torch.ones([1, 8599, 100, 1],dtype=torch.float32, device=pillar_x.device )
    # y_sub_shaped = torch.ones([1, 8599, 100, 1],dtype=torch.float32, device=pillar_x.device )
    # mask = torch.ones([1, 8599, 100, 1],dtype=torch.float32, device=pillar_x.device )

    # weird conv
    pillar_x = torch.ones([1, 1, 12000, 100], dtype=torch.float32, device=pillar_x.device)
    pillar_y = torch.ones([1, 1, 12000, 100], dtype=torch.float32, device=pillar_x.device)
    pillar_z = torch.ones([1, 1, 12000, 100], dtype=torch.float32, device=pillar_x.device)
    pillar_i = torch.ones([1, 1, 12000, 100], dtype=torch.float32, device=pillar_x.device)
    num_points_per_pillar = torch.ones([1, 12000], dtype=torch.float32, device=pillar_x.device)
    x_sub_shaped = torch.ones([1, 1, 12000, 100], dtype=torch.float32, device=pillar_x.device)
    y_sub_shaped = torch.ones([1, 1, 12000, 100], dtype=torch.float32, device=pillar_x.device)
    mask = torch.ones([1, 1, 12000, 100], dtype=torch.float32, device=pillar_x.device)

    # deconv
    # pillar_x = torch.ones([1, 100, 8599, 1],dtype=torch.float32, device=pillar_x.device )
    # pillar_y = torch.ones([1, 100, 8599, 1],dtype=torch.float32, device=pillar_x.device )
    # pillar_z = torch.ones([1, 100, 8599, 1],dtype=torch.float32, device=pillar_x.device )
    # pillar_i = torch.ones([1, 100, 8599, 1],dtype=torch.float32, device=pillar_x.device )
    # num_points_per_pillar = torch.ones([1, 8599],dtype=torch.float32, device=pillar_x.device )
    # x_sub_shaped = torch.ones([1, 100,8599, 1],dtype=torch.float32, device=pillar_x.device )
    # y_sub_shaped = torch.ones([1, 100,8599, 1],dtype=torch.float32, device=pillar_x.device )
    # mask = torch.ones([1, 100, 8599, 1],dtype=torch.float32, device=pillar_x.device )

    example1 = [pillar_x, pillar_y, pillar_z, pillar_i, num_points_per_pillar, x_sub_shaped, y_sub_shaped, mask]
    # example1 = [pillar_x, pillar_y, pillar_z]
    # example1 = [pillar_x, pillar_y, pillar_z, pillar_i, num_points, mask]
    torch.onnx.export(net, example1, "pp.onnx", verbose=False, input_names=input_names)

    sp_f = torch.ones([1, 64, 496, 432], dtype=torch.float32, device=pillar_x.device)
    torch.onnx.export(net.rpn, sp_f, "rpn.onnx", verbose=False)
    return 0
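After prediction_once() writes pp.onnx and rpn.onnx, the exported graphs can be sanity-checked with the standard onnx package (an assumed follow-up step, not part of the original code):

import onnx

# load the exported PFN graph and verify it is structurally valid
model = onnx.load("pp.onnx")
onnx.checker.check_model(model)
print([inp.name for inp in model.graph.input])  # should list the input_names passed to export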
def predict_kitti_to_anno(net,
                          example,
                          class_names,
                          center_limit_range=None,
                          lidar_input=False,
                          global_set=None):
    # evaluation example:[0:'voxels', 1:'num_points', 2:'coordinates', 3:'rect',
    # 4:'Trv2c', 5:'P2',
    # 6:'anchors', 7:'anchors_mask', 8:'image_idx', 9:'image_shape']
    # batch_image_shape = example['image_shape']
    batch_image_shape = example[9]
    batch_imgidx = example[8]

    pillar_x = example[0][:,:,0].unsqueeze(0).unsqueeze(0)
    pillar_y = example[0][:,:,1].unsqueeze(0).unsqueeze(0)
    pillar_z = example[0][:,:,2].unsqueeze(0).unsqueeze(0)
    pillar_i = example[0][:,:,3].unsqueeze(0).unsqueeze(0)
    num_points_per_pillar = example[1].float().unsqueeze(0)

    # Find distance of x, y, and z from pillar center
    # assuming xyres_16.proto
    coors_x = example[2][:, 3].float()
    coors_y = example[2][:, 2].float()
    x_sub = coors_x.unsqueeze(1) * 0.16 + 0.1
    y_sub = coors_y.unsqueeze(1) * 0.16 + -39.9
    ones = torch.ones([1, 100], dtype=torch.float32, device=pillar_x.device)
    x_sub_shaped = torch.mm(x_sub, ones).unsqueeze(0).unsqueeze(0)
    y_sub_shaped = torch.mm(y_sub, ones).unsqueeze(0).unsqueeze(0)

    num_points_for_a_pillar = pillar_x.size()[3]
    mask = get_paddings_indicator(num_points_per_pillar, num_points_for_a_pillar, axis=0)
    mask = mask.permute(0, 2, 1)
    mask = mask.unsqueeze(1)
    mask = mask.type_as(pillar_x)

    coors   = example[2]
    anchors = example[6]
    anchors_mask = example[7]
    anchors_mask = torch.as_tensor(anchors_mask, dtype=torch.uint8, device=pillar_x.device)
    anchors_mask = anchors_mask.byte()
    rect = example[3]
    Trv2c = example[4]
    P2 = example[5]
    image_idx = example[8]

    input = [pillar_x, pillar_y, pillar_z, pillar_i,
             num_points_per_pillar, x_sub_shaped, y_sub_shaped, mask, coors,
             anchors, anchors_mask, rect, Trv2c, P2, image_idx]

    predictions_dicts = net(input)
    # predictions_dict = {
    #     0:"bbox": box_2d_preds,
    #     1:"box3d_camera": final_box_preds_camera,
    #     2:"box3d_lidar": final_box_preds,
    #     3:"scores": final_scores,
    #     4:"label_preds": label_preds,
    #     5:"image_idx": img_idx,
    # }
    annos = []
    for i, preds_dict in enumerate(predictions_dicts):
        image_shape = batch_image_shape[i]
        # img_idx = preds_dict["image_idx"]
        img_idx = preds_dict[5]
        # if preds_dict["bbox"] is not None:
        if preds_dict[0] is not None:
            # box_2d_preds = preds_dict["bbox"].detach().cpu().numpy()
            box_2d_preds = preds_dict[0].detach().cpu().numpy()
            # box_preds = preds_dict["box3d_camera"].detach().cpu().numpy()
            box_preds = preds_dict[1].detach().cpu().numpy()
            # scores = preds_dict["scores"].detach().cpu().numpy()
            scores = preds_dict[3].detach().cpu().numpy()
            # box_preds_lidar = preds_dict["box3d_lidar"].detach().cpu().numpy()
            box_preds_lidar = preds_dict[2].detach().cpu().numpy()
            # write pred to file
            # label_preds = preds_dict["label_preds"].detach().cpu().numpy()
            label_preds = preds_dict[4].detach().cpu().numpy()
            # label_preds = np.zeros([box_2d_preds.shape[0]], dtype=np.int32)
            anno = kitti.get_start_result_anno()
            num_example = 0
            for box, box_lidar, bbox, score, label in zip(
                    box_preds, box_preds_lidar, box_2d_preds, scores,
                    label_preds):
                if not lidar_input:
                    if bbox[0] > image_shape[1] or bbox[1] > image_shape[0]:
                        continue
                    if bbox[2] < 0 or bbox[3] < 0:
                        continue
                # print(img_shape)
                if center_limit_range is not None:
                    limit_range = np.array(center_limit_range)
                    if (np.any(box_lidar[:3] < limit_range[:3])
                            or np.any(box_lidar[:3] > limit_range[3:])):
                        continue
                image_shape = [image_shape[0], image_shape[1]]
                bbox[2:] = np.minimum(bbox[2:], image_shape[::-1])
                bbox[:2] = np.maximum(bbox[:2], [0, 0])
                anno["name"].append(class_names[int(label)])
                anno["truncated"].append(0.0)
                anno["occluded"].append(0)
                anno["alpha"].append(-np.arctan2(-box_lidar[1], box_lidar[0]) +
                                     box[6])
                anno["bbox"].append(bbox)
                anno["dimensions"].append(box[3:6])
                anno["location"].append(box[:3])
                anno["rotation_y"].append(box[6])
                if global_set is not None:
                    for _ in range(100000):
                        if score in global_set:
                            score -= 1 / 100000
                        else:
                            global_set.add(score)
                            break
                anno["score"].append(score)

                num_example += 1
            if num_example != 0:
                anno = {n: np.stack(v) for n, v in anno.items()}
                annos.append(anno)
            else:
                annos.append(kitti.empty_result_anno())
        else:
            annos.append(kitti.empty_result_anno())
        num_example = annos[-1]["name"].shape[0]
        annos[-1]["image_idx"] = np.array(
            [img_idx] * num_example, dtype=np.int64)
    return annos
    def forward(self, features, num_voxels, coors):
        # only use points above -1.6 (in z)

        #features_mask =torch.where((features>-1.6),torch.ones(features.shape[0],features.shape[1],features.shape[2],device ="cuda:0"),torch.zeros(features.shape[0],features.shape[1],features.shape[2],device ="cuda:0"))

        # Find distance of x, y, and z from cluster center
        points_mean = features[:, :, :3].sum(
            dim=1, keepdim=True) / num_voxels.type_as(features).view(-1, 1, 1)
        f_cluster = features[:, :, :3] - points_mean

        # Find distance of x, y, and z from pillar center
        f_center = torch.zeros_like(features[:, :, :2])
        f_center[:, :, 0] = features[:, :, 0] - (
            coors[:, 3].float().unsqueeze(1) * self.vx + self.x_offset)
        f_center[:, :, 1] = features[:, :, 1] - (
            coors[:, 2].float().unsqueeze(1) * self.vy + self.y_offset)

        # Combine together feature decorations
        features_ls = [features[:, :, :3], f_cluster, f_center]
        coss = torch.acos(features[:, :, 3])
        if self._with_distance:
            points_dist = torch.norm(features[:, :, :3], 2, 2, keepdim=True)
            features_ls.append(points_dist)
        features = torch.cat(features_ls, dim=-1)

        theta_features_ls = [features[:, :, :3], f_cluster]
        coss = torch.acos(features[:, :, 3])
        if self._with_distance:
            points_dist = torch.norm(features[:, :, :3], 2, 2, keepdim=True)
            theta_features_ls.append(points_dist)
        theta_features = torch.cat(theta_features_ls, dim=-1)

        #print(features)

        # The feature decorations were calculated without regard to whether pillar was empty. Need to ensure that
        # empty pillars remain set to zeros.
        voxel_count = features.shape[1]
        mask = get_paddings_indicator(num_voxels, voxel_count, axis=0)
        mask = torch.unsqueeze(mask, -1).type_as(features)
        features *= mask

        #print(features)

        # Forward pass through PFNLayers
        target_thetas = torch.reshape(coss, (-1, 50, 1))
        #features_mask_z = torch.reshape(features_mask[:,:,2],(-1,25,1))
        #print(features_mask_z)

        voxel_count = target_thetas.shape[1]
        mask = get_paddings_indicator(num_voxels, voxel_count, axis=0)
        mask = torch.unsqueeze(mask, -1).type_as(target_thetas)
        target_thetas *= mask

        #print(target_thetas.size())

        #print(theta_features.size())
        for theta in self.theta_layers:
            theta_features = theta(theta_features)

        voxel_count = target_thetas.shape[1]
        mask = get_paddings_indicator(num_voxels, voxel_count, axis=0)
        mask = torch.unsqueeze(mask, -1).type_as(theta_features)
        theta_features *= mask

        features_LS = [theta_features, features]
        features = torch.cat(features_LS, dim=-1)
        #print(features.size())
        for pfn in self.pfn_layers:
            #print(features.size())
            features = pfn(features)

        criterion = nn.MSELoss()

        theta_loss = criterion(theta_features, target_thetas)

        #print(num_voxels)
        #print(features_mask_z.sum())

        theta_loss = theta_loss / (num_voxels.sum() +
                                   0.001) * 50 * features.shape[0]

        print(theta_loss)

        return features.squeeze(), theta_loss * 1.0
def example_convert_to_torch_for_cuda_implementation(example, dtype=torch.float32,
                                                     device=None) -> dict:
    device = device or torch.device("cuda:0")
    example_torch = {}
    float_names = [
        "voxels", "anchors", "reg_targets", "reg_weights", "bev_map", "rect",
        "Trv2c", "P2"
    ]

    # almost every element is unsqueezed to [1/2, 12000, 100, 1]; the last dim is kept for concatenation

    for k, v in example.items():
        if k == 'voxels':
            v = torch.as_tensor(v, dtype=dtype, device=device)
            example_torch['dev_pillar_x_'] = v[:, :, 0]  # shape [x, 100]
            example_torch['dev_pillar_y_'] = v[:, :, 1]
            example_torch['dev_pillar_z_'] = v[:, :, 2]
            example_torch['dev_pillar_i_'] = v[:, :, 3]

        elif k == 'coordinates':
            coors = torch.as_tensor(v, dtype=torch.int32, device=device)
            example_torch['dev_x_coors_for_sub_shaped_'] = coors[:, 3].float().view(-1, 1).repeat(1, 100)
            example_torch['dev_y_coors_for_sub_shaped_'] = coors[:, 2].float().view(-1, 1).repeat(1, 100)
            # print( example_torch['dev_x_coors_for_sub_shaped_'] [0, :, 0])  # correct

            example_torch[k] = torch.as_tensor(
                v, dtype=torch.int32, device=device)

        elif k in float_names:
            example_torch[k] = torch.as_tensor(v, dtype=dtype, device=device)

        elif k in ["labels", "num_points"]:  # num_points is dev_num_points_per_pillar
            example_torch[k] = torch.as_tensor(
                v, dtype=torch.int32, device=device)

        elif k in ["anchors_mask"]:
            example_torch[k] = torch.as_tensor(
                v, dtype=torch.uint8, device=device)
        else:
            example_torch[k] = v

    """
    voxel_count = features.shape[1]
    mask = get_paddings_indicator(num_voxels, voxel_count, axis=0)
    mask = torch.unsqueeze(mask, -1).type_as(features)
    features *= mask
    """

    # get mask
    voxel_count = example_torch['dev_pillar_x_'].shape[1]
    num_voxels = example_torch['num_points']
    mask = get_paddings_indicator(num_voxels, voxel_count, axis=0)
    mask = mask.type_as(example_torch['dev_pillar_x_'])
    example_torch['dev_pillar_feature_mask_'] = mask
    # split data by batch and expand tensor to 12000 pillars with empty ones
    batch_size = example["anchors"].shape[0]
    example_torch['dev_num_points_per_pillar_'] = example_torch['num_points']

    # add zeros and split tensor by batch
    def add_batch_and_zeros(dense_pillar_tensor, batch_size, max_pillar_num, zero=True):
        if len(dense_pillar_tensor.shape) > 1:
            last = [dense_pillar_tensor.shape[-1]]
        else:
            last = []

        shape = [batch_size, 1, max_pillar_num] + last

        sparse_pillar_tensor = torch.zeros(shape, device=device)
        # default value is -1 for num and 0 for the others
        if zero is False:
            sparse_pillar_tensor = sparse_pillar_tensor.fill_(-1)

        for batch_itt in range(batch_size):
            batch_mask = coors[:, 0] == batch_itt
            batch_dense_pillar_tensor = dense_pillar_tensor[batch_mask]
            # if batch_itt == 1:
            #     print(batch_dense_pillar_tensor[0, 0, 0])
            non_empty_pillar_num = batch_dense_pillar_tensor.shape[0]
            sparse_pillar_tensor[batch_itt][0][:non_empty_pillar_num] = batch_dense_pillar_tensor

        return sparse_pillar_tensor

    for k in ["dev_pillar_x_", "dev_pillar_y_", "dev_pillar_z_", "dev_pillar_i_", "dev_num_points_per_pillar_",
              "dev_x_coors_for_sub_shaped_", "dev_y_coors_for_sub_shaped_", "dev_pillar_feature_mask_"]:

        example_torch[k] = add_batch_and_zeros(example_torch[k], batch_size, 12000, zero=(k != 'dev_num_points_per_pillar_'))
    return example_torch
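A hypothetical usage sketch (dataloader and merge_second_batch come from the training code above; the key names and the [batch, 1, 12000, 100] shape follow from the function itself):

# convert one merged batch into the dense per-pillar layout expected by the CUDA pipeline
example = next(iter(dataloader))
example_torch = example_convert_to_torch_for_cuda_implementation(example)
pillar_x = example_torch['dev_pillar_x_']             # [batch, 1, 12000, 100]
mask = example_torch['dev_pillar_feature_mask_']      # [batch, 1, 12000, 100]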