Example #1
def log_function(model_dir, config_path):
    model_logging = SimpleModelLog(model_dir)
    model_logging.open()

    config = pipeline_pb2.TrainEvalPipelineConfig()
    with open(config_path, "r") as f:
        proto_str = f.read()
        text_format.Merge(proto_str, config)

    model_logging.log_text(proto_str + "\n", 0, tag="config")
    return model_logging
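A hypothetical call site for log_function above (the paths and the "status" tag are illustrative; SimpleModelLog, pipeline_pb2, and text_format come from the surrounding project and google.protobuf):

model_log = log_function("/tmp/voxelnet_run", "/tmp/voxelnet_run/pipeline.config")
model_log.log_text("training started", 0, tag="status")  # same log_text(text, step, tag=...) API used above
model_log.close()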
Example #2
def train(config_path,
          model_dir,
          result_path=None,
          create_folder=False,
          display_step=50,
          summary_step=5,
          pretrained_path=None,
          pretrained_include=None,
          pretrained_exclude=None,
          freeze_include=None,
          freeze_exclude=None,
          multi_gpu=False,
          measure_time=False,
          resume=False):
    """train a VoxelNet model specified by a config file.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # create dir for saving training states
    model_dir = str(Path(model_dir).resolve())
    if create_folder:
        if Path(model_dir).exists():
            model_dir = torchplus.train.create_folder(model_dir)
    model_dir = Path(model_dir)
    if not resume and model_dir.exists():
        raise ValueError("model dir exists and you don't specify resume.")
    model_dir.mkdir(parents=True, exist_ok=True)
    if result_path is None:
        result_path = model_dir / 'results'
    # load config file
    config_file_bkp = "pipeline.config"
    if isinstance(config_path, str):
        config = pipeline_pb2.TrainEvalPipelineConfig()
        with open(config_path, "r") as f:
            proto_str = f.read()
            text_format.Merge(proto_str, config)
    else:
        # a config object was provided directly. This is usually used
        # when you want to train with several different parameters in
        # one script.
        config = config_path
        proto_str = text_format.MessageToString(config, indent=2)
    with (model_dir / config_file_bkp).open("w") as f:
        f.write(proto_str)

    input_cfg = config.train_input_reader
    eval_input_cfg = config.eval_input_reader
    model_cfg = config.model.second
    train_cfg = config.train_config

    net = build_network(model_cfg, measure_time).to(device)
    # if train_cfg.enable_mixed_precision:
    #     net.half()
    #     net.metrics_to_float()
    #     net.convert_norm_to_float(net)
    target_assigner = net.target_assigner
    voxel_generator = net.voxel_generator
    print("num parameters:", len(list(net.parameters())))
    torchplus.train.try_restore_latest_checkpoints(model_dir, [net])
    if pretrained_path is not None:
        # load pretrained parameters
        model_dict = net.state_dict()
        pretrained_dict = torch.load(pretrained_path)
        pretrained_dict = filter_param_dict(pretrained_dict,
                                            pretrained_include,
                                            pretrained_exclude)
        new_pretrained_dict = {}
        for k, v in pretrained_dict.items():
            if k in model_dict and v.shape == model_dict[k].shape:
                new_pretrained_dict[k] = v
        print("Load pretrained parameters:")
        for k, v in new_pretrained_dict.items():
            print(k, v.shape)
        model_dict.update(new_pretrained_dict)
        net.load_state_dict(model_dict)
        freeze_params_v2(dict(net.named_parameters()), freeze_include,
                         freeze_exclude)
        net.clear_global_step()
        net.clear_metrics()
    if multi_gpu:
        net_parallel = torch.nn.DataParallel(net)
    else:
        net_parallel = net
    optimizer_cfg = train_cfg.optimizer
    loss_scale = train_cfg.loss_scale_factor
    fastai_optimizer = optimizer_builder.build(optimizer_cfg,
                                               net,
                                               mixed=False,
                                               loss_scale=loss_scale)
    if loss_scale < 0:
        loss_scale = "dynamic"
    if train_cfg.enable_mixed_precision:
        max_num_voxels = input_cfg.preprocess.max_number_of_voxels * input_cfg.batch_size
        assert max_num_voxels < 65535, "spconv fp16 training requires max_num_voxels < 65535"
        from apex import amp
        net, amp_optimizer = amp.initialize(net,
                                            fastai_optimizer,
                                            opt_level="O2",
                                            keep_batchnorm_fp32=True,
                                            loss_scale=loss_scale)
        net.metrics_to_float()
    else:
        amp_optimizer = fastai_optimizer
    torchplus.train.try_restore_latest_checkpoints(model_dir,
                                                   [fastai_optimizer])
    lr_scheduler = lr_scheduler_builder.build(optimizer_cfg, amp_optimizer,
                                              train_cfg.steps)
    if train_cfg.enable_mixed_precision:
        float_dtype = torch.float16
    else:
        float_dtype = torch.float32

    if multi_gpu:
        num_gpu = torch.cuda.device_count()
        print(f"MULTI-GPU: use {num_gpu} gpu")
        collate_fn = merge_second_batch_multigpu
    else:
        collate_fn = merge_second_batch
        num_gpu = 1

    ######################
    # PREPARE INPUT
    ######################
    dataset = input_reader_builder.build(input_cfg,
                                         model_cfg,
                                         training=True,
                                         voxel_generator=voxel_generator,
                                         target_assigner=target_assigner,
                                         multi_gpu=multi_gpu)
    eval_dataset = input_reader_builder.build(eval_input_cfg,
                                              model_cfg,
                                              training=False,
                                              voxel_generator=voxel_generator,
                                              target_assigner=target_assigner)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=input_cfg.batch_size * num_gpu,
        shuffle=True,
        num_workers=input_cfg.preprocess.num_workers * num_gpu,
        pin_memory=False,
        collate_fn=collate_fn,
        worker_init_fn=_worker_init_fn,
        drop_last=not multi_gpu)
    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=eval_input_cfg.batch_size,  # eval runs on a single GPU; multi-GPU applies to training only
        shuffle=False,
        num_workers=eval_input_cfg.preprocess.num_workers,
        pin_memory=False,
        collate_fn=merge_second_batch)

    ######################
    # TRAINING
    ######################
    model_logging = SimpleModelLog(model_dir)
    model_logging.open()
    model_logging.log_text(proto_str + "\n", 0, tag="config")
    start_step = net.get_global_step()
    total_step = train_cfg.steps
    t = time.time()
    steps_per_eval = train_cfg.steps_per_eval
    clear_metrics_every_epoch = train_cfg.clear_metrics_every_epoch

    amp_optimizer.zero_grad()
    step_times = []
    step = start_step
    try:
        while True:
            if clear_metrics_every_epoch:
                net.clear_metrics()
            for example in dataloader:
                lr_scheduler.step(net.get_global_step())
                time_metrics = example["metrics"]
                example.pop("metrics")
                example_torch = example_convert_to_torch(example, float_dtype)

                batch_size = example["anchors"].shape[0]

                ret_dict = net_parallel(example_torch)
                cls_preds = ret_dict["cls_preds"]
                loss = ret_dict["loss"].mean()
                cls_loss_reduced = ret_dict["cls_loss_reduced"].mean()
                loc_loss_reduced = ret_dict["loc_loss_reduced"].mean()
                cls_pos_loss = ret_dict["cls_pos_loss"].mean()
                cls_neg_loss = ret_dict["cls_neg_loss"].mean()
                loc_loss = ret_dict["loc_loss"]
                cls_loss = ret_dict["cls_loss"]

                cared = ret_dict["cared"]
                labels = example_torch["labels"]
                if train_cfg.enable_mixed_precision:
                    with amp.scale_loss(loss, amp_optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                torch.nn.utils.clip_grad_norm_(net.parameters(), 10.0)
                amp_optimizer.step()
                amp_optimizer.zero_grad()
                net.update_global_step()
                net_metrics = net.update_metrics(cls_loss_reduced,
                                                 loc_loss_reduced, cls_preds,
                                                 labels, cared)

                step_time = (time.time() - t)
                step_times.append(step_time)
                t = time.time()
                metrics = {}
                num_pos = int((labels > 0)[0].float().sum().cpu().numpy())
                num_neg = int((labels == 0)[0].float().sum().cpu().numpy())
                if 'anchors_mask' not in example_torch:
                    num_anchors = example_torch['anchors'].shape[1]
                else:
                    num_anchors = int(example_torch['anchors_mask'][0].sum())
                global_step = net.get_global_step()

                if global_step % display_step == 0:
                    if measure_time:
                        for name, val in net.get_avg_time_dict().items():
                            print(f"avg {name} time = {val * 1000:.3f} ms")

                    loc_loss_elem = [
                        float(loc_loss[:, :, i].sum().detach().cpu().numpy() /
                              batch_size) for i in range(loc_loss.shape[-1])
                    ]
                    metrics["runtime"] = {
                        "step": global_step,
                        "steptime": np.mean(step_times),
                    }
                    metrics["runtime"].update(time_metrics[0])
                    step_times = []
                    metrics.update(net_metrics)
                    metrics["loss"]["loc_elem"] = loc_loss_elem
                    metrics["loss"]["cls_pos_rt"] = float(
                        cls_pos_loss.detach().cpu().numpy())
                    metrics["loss"]["cls_neg_rt"] = float(
                        cls_neg_loss.detach().cpu().numpy())
                    if model_cfg.use_direction_classifier:
                        dir_loss_reduced = ret_dict["dir_loss_reduced"].mean()
                        metrics["loss"]["dir_rt"] = float(
                            dir_loss_reduced.detach().cpu().numpy())

                    metrics["misc"] = {
                        # "num_vox": int(example_torch["voxels"].shape[0]),
                        "num_pos": int(num_pos),
                        "num_neg": int(num_neg),
                        "num_anchors": int(num_anchors),
                        "lr": float(amp_optimizer.lr),
                        "mem_usage": psutil.virtual_memory().percent,
                    }
                    model_logging.log_metrics(metrics, global_step)

                if global_step % steps_per_eval == 0:
                    torchplus.train.save_models(model_dir,
                                                [net, amp_optimizer],
                                                net.get_global_step())
                    net.eval()
                    result_path_step = result_path / f"step_{net.get_global_step()}"
                    result_path_step.mkdir(parents=True, exist_ok=True)
                    model_logging.log_text("#################################",
                                           global_step)
                    model_logging.log_text("# EVAL", global_step)
                    model_logging.log_text("#################################",
                                           global_step)
                    model_logging.log_text("Generate output labels...",
                                           global_step)
                    t = time.time()
                    detections = []
                    prog_bar = ProgressBar()
                    net.clear_timer()
                    prog_bar.start(
                        (len(eval_dataset) + eval_input_cfg.batch_size - 1) //
                        eval_input_cfg.batch_size)
                    for example in iter(eval_dataloader):
                        example = example_convert_to_torch(
                            example, float_dtype)
                        with torch.no_grad():
                            detections += net(example)
                        prog_bar.print_bar()

                    examples_per_sec = len(eval_dataset) / (time.time() - t)
                    model_logging.log_text(
                        f'generate label finished({examples_per_sec:.2f} examples/s). start eval:',
                        global_step)
                    result_dict = eval_dataset.dataset.evaluation(
                        detections, str(result_path_step))
                    for k, v in result_dict["results"].items():
                        model_logging.log_text("Evaluation {}".format(k),
                                               global_step)
                        model_logging.log_text(v, global_step)
                    model_logging.log_metrics(result_dict["detail"],
                                              global_step)
                    with open(result_path_step / "result.pkl", 'wb') as f:
                        pickle.dump(detections, f)
                    net.train()
                step += 1
                if step >= total_step:
                    break
            if step >= total_step:
                break
    except Exception as e:
        # the exception may fire before the first batch is loaded
        if 'example' in locals():
            print(json.dumps(example["metadata"], indent=2))
            model_logging.log_text(json.dumps(example["metadata"], indent=2), step)
        model_logging.log_text(str(e), step)
        torchplus.train.save_models(model_dir, [net, amp_optimizer], step)
        raise
    finally:
        model_logging.close()
    torchplus.train.save_models(model_dir, [net, amp_optimizer],
                                net.get_global_step())
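The mixed-precision path in Example #2 uses NVIDIA Apex. A minimal, standalone sketch of that amp pattern follows; the Linear model and random data are stand-ins, and it assumes apex is installed and a CUDA device is available:

import torch
from apex import amp

model = torch.nn.Linear(16, 2).cuda()              # toy stand-in for build_network(...)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# opt_level="O2" with fp32 batchnorm and dynamic loss scaling,
# mirroring the loss_scale < 0 branch above
model, optimizer = amp.initialize(model, optimizer, opt_level="O2",
                                  keep_batchnorm_fp32=True, loss_scale="dynamic")

x = torch.randn(4, 16).cuda()
loss = model(x).sum()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()                          # backprop through the scaled loss
optimizer.step()
optimizer.zero_grad()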
Example #3
def train(config_path,
          model_dir,
          result_path=None,
          create_folder=False,
          display_step=50,
          summary_step=5,
          resume=False):
    """train a VoxelNet model specified by a config file.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if create_folder:
        if pathlib.Path(model_dir).exists():
            model_dir = torchplus.train.create_folder(model_dir)
    model_dir = pathlib.Path(model_dir)
    if not resume and model_dir.exists():
        raise ValueError("model dir exists and you don't specify resume.")
    model_dir.mkdir(parents=True, exist_ok=True)
    if result_path is None:
        result_path = model_dir / 'results'
    config_file_bkp = "pipeline.config"
    if isinstance(config_path, str):
        config = pipeline_pb2.TrainEvalPipelineConfig()
        with open(config_path, "r") as f:
            proto_str = f.read()
            text_format.Merge(proto_str, config)
    else:
        # a config object was provided directly. This is usually used
        # when you want to train with several different parameters in
        # one script.
        config = config_path
        proto_str = text_format.MessageToString(config, indent=2)
    with (model_dir / config_file_bkp).open("w") as f:
        f.write(proto_str)

    input_cfg = config.train_input_reader
    eval_input_cfg = config.eval_input_reader
    model_cfg = config.model.second
    train_cfg = config.train_config

    net = build_network(model_cfg).to(device)
    if train_cfg.enable_mixed_precision:
        net.half()
        net.metrics_to_float()
        net.convert_norm_to_float(net)
    target_assigner = net.target_assigner
    voxel_generator = net.voxel_generator
    class_names = target_assigner.classes

    # net_train = torch.nn.DataParallel(net).cuda()
    print("num_trainable parameters:", len(list(net.parameters())))
    # for n, p in net.named_parameters():
    #     print(n, p.shape)
    ######################
    # BUILD OPTIMIZER
    ######################
    # we need global_step to create lr_scheduler, so restore net first.
    torchplus.train.try_restore_latest_checkpoints(model_dir, [net])
    gstep = net.get_global_step() - 1
    optimizer_cfg = train_cfg.optimizer
    loss_scale = train_cfg.loss_scale_factor
    mixed_optimizer = optimizer_builder.build(
        optimizer_cfg,
        net,
        mixed=train_cfg.enable_mixed_precision,
        loss_scale=loss_scale)
    optimizer = mixed_optimizer
    center_limit_range = model_cfg.post_center_limit_range
    """
    if train_cfg.enable_mixed_precision:
        mixed_optimizer = torchplus.train.MixedPrecisionWrapper(
            optimizer, loss_scale)
    else:
        mixed_optimizer = optimizer
    """
    # must restore optimizer AFTER using MixedPrecisionWrapper
    torchplus.train.try_restore_latest_checkpoints(model_dir,
                                                   [mixed_optimizer])
    lr_scheduler = lr_scheduler_builder.build(optimizer_cfg, optimizer,
                                              train_cfg.steps)
    if train_cfg.enable_mixed_precision:
        float_dtype = torch.float16
    else:
        float_dtype = torch.float32
    ######################
    # PREPARE INPUT
    ######################
    dataset = input_reader_builder.build(input_cfg,
                                         model_cfg,
                                         training=True,
                                         voxel_generator=voxel_generator,
                                         target_assigner=target_assigner)
    eval_dataset = input_reader_builder.build(eval_input_cfg,
                                              model_cfg,
                                              training=False,
                                              voxel_generator=voxel_generator,
                                              target_assigner=target_assigner)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=input_cfg.batch_size,
        shuffle=True,
        num_workers=input_cfg.preprocess.num_workers,
        pin_memory=False,
        collate_fn=merge_second_batch,
        worker_init_fn=_worker_init_fn)
    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=eval_input_cfg.batch_size,
        shuffle=False,
        num_workers=eval_input_cfg.preprocess.num_workers,
        pin_memory=False,
        collate_fn=merge_second_batch)

    data_iter = iter(dataloader)
    ######################
    # TRAINING
    ######################
    model_logging = SimpleModelLog(model_dir)
    model_logging.open()
    model_logging.log_text(proto_str + "\n", 0, tag="config")

    total_step_elapsed = 0
    remain_steps = train_cfg.steps - net.get_global_step()
    t = time.time()
    ckpt_start_time = t
    steps_per_eval = train_cfg.steps_per_eval
    total_loop = train_cfg.steps // train_cfg.steps_per_eval + 1
    clear_metrics_every_epoch = train_cfg.clear_metrics_every_epoch

    if train_cfg.steps % train_cfg.steps_per_eval == 0:
        total_loop -= 1
    mixed_optimizer.zero_grad()
    try:
        for _ in range(total_loop):
            if total_step_elapsed + train_cfg.steps_per_eval > train_cfg.steps:
                steps = train_cfg.steps % train_cfg.steps_per_eval
            else:
                steps = train_cfg.steps_per_eval
            for step in range(steps):
                lr_scheduler.step(net.get_global_step())
                try:
                    example = next(data_iter)
                except StopIteration:
                    print("end epoch")
                    if clear_metrics_every_epoch:
                        net.clear_metrics()
                    data_iter = iter(dataloader)
                    example = next(data_iter)
                example_torch = example_convert_to_torch(example, float_dtype)

                #batch_size = example["anchors"].shape[0]
                ret_dict = net(example_torch)

                # FCOS

                losses = ret_dict['total_loss']
                loss_cls = ret_dict["loss_cls"]
                loss_reg = ret_dict["loss_reg"]
                cls_preds = ret_dict['cls_preds']
                labels = ret_dict["labels"]
                cared = ret_dict["labels"]

                optimizer.zero_grad()
                losses.backward()
                #torch.nn.utils.clip_grad_norm_(net.parameters(), 1)
                # optimizer.step() updates the parameters, so clip gradients before calling it
                optimizer.step()
                net.update_global_step()
                # FPN returns a list of per-level predictions; use level 0 for metrics
                net_metrics = net.update_metrics(loss_cls, loss_reg,
                                                 cls_preds[0], labels, cared)
                step_time = (time.time() - t)
                t = time.time()
                metrics = {}
                global_step = net.get_global_step()

                # log training metrics periodically
                if global_step % display_step == 0:
                    metrics["runtime"] = {
                        "step": global_step,
                        "steptime": step_time,
                    }

                    metrics.update(net_metrics)
                    metrics["misc"] = {
                        "num_vox": int(example_torch["voxels"].shape[0]),
                        "lr": float(optimizer.lr),
                    }
                    model_logging.log_metrics(metrics, global_step)
                ckpt_elapsed_time = time.time() - ckpt_start_time
                # checkpoint on a time interval rather than every step
                # (train_cfg.save_checkpoints_secs comes from the train config)
                if ckpt_elapsed_time > train_cfg.save_checkpoints_secs:
                    torchplus.train.save_models(model_dir, [net, optimizer],
                                                net.get_global_step())
                    ckpt_start_time = time.time()

            total_step_elapsed += steps
            torchplus.train.save_models(model_dir, [net, optimizer],
                                        net.get_global_step())
            net.eval()
            result_path_step = result_path / f"step_{net.get_global_step()}"
            result_path_step.mkdir(parents=True, exist_ok=True)
            model_logging.log_text("#################################",
                                   global_step)
            model_logging.log_text("# EVAL", global_step)
            model_logging.log_text("#################################",
                                   global_step)
            model_logging.log_text("Generate output labels...", global_step)
            t = time.time()
            detections = []
            prog_bar = ProgressBar()
            net.clear_timer()
            prog_bar.start(
                (len(eval_dataset) + eval_input_cfg.batch_size - 1) //
                eval_input_cfg.batch_size)
            for example in iter(eval_dataloader):
                example = example_convert_to_torch(example, float_dtype)
                with torch.no_grad():
                    detections += net(example)
                prog_bar.print_bar()

            examples_per_sec = len(eval_dataset) / (time.time() - t)
            model_logging.log_text(
                f'generate label finished({examples_per_sec:.2f} examples/s). start eval:',
                global_step)
            result_dict = eval_dataset.dataset.evaluation(
                detections, str(result_path_step))
            for k, v in result_dict["results"].items():
                model_logging.log_text("Evaluation {}".format(k), global_step)
                model_logging.log_text(v, global_step)
            model_logging.log_metrics(result_dict["detail"], global_step)
            with open(result_path_step / "result.pkl", 'wb') as f:
                pickle.dump(detections, f)
            net.train()
            # An alternative layout (kept here as a note): run the same
            # evaluation block inside the inner step loop, gated by
            # `global_step % steps_per_eval == 0`; the body is identical to
            # the eval code above.

    except Exception as e:
        print("training error")
        raise
    finally:
        model_logging.close()
    # save model before exit
    torchplus.train.save_models(model_dir, [net, optimizer],
                                net.get_global_step())
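Example #3 pulls batches with next() and rebuilds the iterator on StopIteration, so the step budget is decoupled from epoch length. The same pattern in isolation, with a toy dataset (all names here are illustrative):

import torch

dataset = torch.utils.data.TensorDataset(torch.arange(10).float())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)

data_iter = iter(dataloader)
for step in range(7):          # 7 steps span more than one 3-batch epoch
    try:
        batch = next(data_iter)
    except StopIteration:
        # end of epoch: restart the iterator and keep stepping
        data_iter = iter(dataloader)
        batch = next(data_iter)
    print(step, batch[0].tolist())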
Example #4
File: second_a.py  Project: karlzipser/k3
def train(
        config_path: Union[str, Path, pipeline.TrainEvalPipelineConfig],
        model_dir: Union[str, Path],
        data_root_path: Union[str, Path],
        result_path: Optional[Union[str, Path]] = None,
        display_step: int = 50,
        pretrained_path=None,
        pretrained_include=None,
        pretrained_exclude=None,
        freeze_include=None,
        freeze_exclude=None,
        measure_time: bool = False,
        resume: bool = False,
):
    """train a VoxelNet model specified by a config file.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_dir = real_path(model_dir, check_exists=False)
    if not resume and model_dir.exists():
        raise ValueError("model dir exists and you don't specify resume.")
    model_dir.mkdir(parents=True, exist_ok=True)
    model_dir = Path(model_dir)

    if result_path is None:
        result_path = model_dir / "results"
    else:
        result_path = assert_real_path(result_path, mkdir=True)

    config_file_bkp = DEFAULT_CONFIG_FILE_NAME
    if isinstance(config_path, pipeline.TrainEvalPipelineConfig):
        # a config object was provided directly. This is usually used
        # when you want to train with several different parameters in
        # one script.
        config = config_path
        proto_str = text_format.MessageToString(config, use_short_repeated_primitives=True, indent=2)
    else:
        config_path = assert_real_path(config_path)
        data_root_path = assert_real_path(data_root_path)
        config = read_pipeline_config(config_path, data_root_path)
        # Copy the contents of config_path to config_file_bkp verbatim without passing it through the protobuf parser.
        with open(str(config_path), "r") as f:
            proto_str = f.read()
    with (model_dir / config_file_bkp).open("w") as f:
        f.write(proto_str)

    input_cfg = config.train_input_reader
    eval_input_cfg = config.eval_input_reader
    model_cfg = config.model.second
    train_cfg = config.train_config

    net = build_network(model_cfg, measure_time).to(device)
    if train_cfg.enable_mixed_precision:
        # net.half()
        net.metrics_to_float()
        net.convert_norm_to_float(net)
        
    target_assigner = net.target_assigner
    voxel_generator = net.voxel_generator
    # print("num parameters:", len(list(net.parameters())))
    print("num parameters (million): ", count_parameters(net) * 1e-6)
    torchplus.train.try_restore_latest_checkpoints(model_dir, [net])
    if pretrained_path is not None:
        model_dict = net.state_dict()
        pretrained_dict = torch.load(pretrained_path)
        pretrained_dict = filter_param_dict(pretrained_dict, pretrained_include, pretrained_exclude)
        new_pretrained_dict = {}
        for k, v in pretrained_dict.items():
            if k in model_dict and v.shape == model_dict[k].shape:
                new_pretrained_dict[k] = v        
        print("Load pretrained parameters:")
        for k, v in new_pretrained_dict.items():
            print(k, v.shape)
        model_dict.update(new_pretrained_dict) 
        net.load_state_dict(model_dict)
        freeze_params_v2(dict(net.named_parameters()), freeze_include, freeze_exclude)
        net.clear_global_step()
        net.clear_metrics()

    optimizer_cfg = train_cfg.optimizer

    loss_scale = train_cfg.loss_scale_factor

    fastai_optimizer = optimizer_builder.build(
        optimizer_cfg,
        net,
        mixed=False,
        loss_scale=loss_scale)

    if loss_scale < 0:
        loss_scale = "dynamic"

    amp_optimizer = fastai_optimizer

    torchplus.train.try_restore_latest_checkpoints(model_dir, [amp_optimizer])
    
    float_dtype = torch.float32

    collate_fn = merge_second_batch
    num_gpu = 1
    multi_gpu = False  # single-GPU variant; get_train_dataloader below needs this

    ######################
    # PREPARE INPUT
    ######################
    def get_train_dataloader(input_cfg, model_cfg, voxel_generator, target_assigner,
                             multi_gpu, num_gpu, collate_fn, _worker_init_fn):
        dataset = input_reader_builder.build(
            input_cfg,
            model_cfg,
            training=True,
            voxel_generator=voxel_generator,
            target_assigner=target_assigner,
            multi_gpu=multi_gpu)

        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=input_cfg.batch_size * num_gpu,
            shuffle=True,
            num_workers=input_cfg.preprocess.num_workers * num_gpu,
            pin_memory=True,
            collate_fn=collate_fn,
            worker_init_fn=_worker_init_fn,
            drop_last=not multi_gpu)

        return dataloader

    eval_dataset = input_reader_builder.build(
        eval_input_cfg,
        model_cfg,
        training=False,
        voxel_generator=voxel_generator,
        target_assigner=target_assigner)

    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=eval_input_cfg.batch_size,  # eval runs on a single GPU; multi-GPU applies to training only
        shuffle=False,
        num_workers=eval_input_cfg.preprocess.num_workers,
        pin_memory=False,
        collate_fn=merge_second_batch)

    ######################
    # TRAINING
    ######################
    model_logging = SimpleModelLog(model_dir)
    model_logging.open()
    model_logging.log_text(proto_str + "\n", 0, tag="config")
    epochs = train_cfg.steps  # this variant treats config "steps" as an epoch count
    epochs_per_eval = train_cfg.steps_per_eval  # and "steps_per_eval" as epochs between evals
    clear_metrics_every_epoch = train_cfg.clear_metrics_every_epoch

    amp_optimizer.zero_grad()
    step_times = []
    eval_times = []

    t = time.time()
    reset_ds_epoch = False
    run_once = True
    if not (os.getenv("MLFLOW_EXPERIMENT_ID") or os.getenv("MLFLOW_EXPERIMENT_NAME")):
        mlflow.set_experiment("object_detection")
    try:
        while True:
            if run_once or reset_ds_epoch:
                dataloader = get_train_dataloader(input_cfg, model_cfg, voxel_generator, target_assigner,
                                                  multi_gpu, num_gpu, collate_fn, _worker_init_fn)
                total_step = int(np.ceil((len(dataloader.dataset) / dataloader.batch_size) * epochs))
                steps_per_eval = int(np.floor((len(dataloader.dataset) / dataloader.batch_size) * epochs_per_eval))
                train_cfg.steps = int(total_step)
                train_cfg.steps_per_eval = int(steps_per_eval)
                lr_scheduler = lr_scheduler_builder.build(optimizer_cfg, amp_optimizer, total_step)

                print(f"\nnumber of samples: {len(dataloader.dataset)}\ntotal_steps: {total_step}\nsteps_per_eval: {steps_per_eval}")

                run_once = False

            if clear_metrics_every_epoch:
                net.clear_metrics()
            for example in dataloader:
                lr_scheduler.step(net.get_global_step())
                time_metrics = example["metrics"]
                example.pop("metrics")
                example_torch = example_convert_to_torch(example, float_dtype)

                batch_size = example["anchors"].shape[0]

                ret_dict = net(example_torch)
                cls_preds = ret_dict["cls_preds"]
                loss = ret_dict["loss"].mean()
                cls_loss_reduced = ret_dict["cls_loss_reduced"].mean()
                loc_loss_reduced = ret_dict["loc_loss_reduced"].mean()
                cls_pos_loss = ret_dict["cls_pos_loss"].mean()
                cls_neg_loss = ret_dict["cls_neg_loss"].mean()
                loc_loss = ret_dict["loc_loss"]
                # cls_loss = ret_dict["cls_loss"]
                cared = ret_dict["cared"]
                labels = example_torch["labels"]
                loss.backward()
                torch.nn.utils.clip_grad_norm_(net.parameters(), 30.0)
                # torch.nn.utils.clip_grad_norm_(amp.master_params(amp_optimizer), 10.0)

                amp_optimizer.step()
                amp_optimizer.zero_grad()
                net.update_global_step()
                global_step = net.get_global_step()
                net_metrics = net.update_metrics(cls_loss_reduced,
                                                 loc_loss_reduced, cls_preds,
                                                 labels, cared)

                step_time = (time.time() - t)
                step_times.append(step_time)
                t = time.time()
                metrics = {}
                num_pos = int((labels > 0)[0].float().sum().cpu().numpy())
                num_neg = int((labels == 0)[0].float().sum().cpu().numpy())
                if 'anchors_mask' not in example_torch:
                    num_anchors = example_torch['anchors'].shape[1]
                else:
                    num_anchors = int(example_torch['anchors_mask'][0].sum())

                if global_step % display_step == 0:
                    if measure_time:
                        for name, val in net.get_avg_time_dict().items():
                            print(f"avg {name} time = {val * 1000:.3f} ms")

                    loc_loss_elem = [
                        float(loc_loss[:, :, i].sum().detach().cpu().numpy() /
                              batch_size) for i in range(loc_loss.shape[-1])
                    ]

                    total_seconds = ((total_step - global_step) * np.mean(step_times))
                    if len(eval_times) != 0:
                        eval_seconds = ((epochs / epochs_per_eval) - len(eval_times)) * np.mean(eval_times)
                        total_seconds += eval_seconds
                    
                    next_eval_seconds = (steps_per_eval - (global_step % steps_per_eval)) * np.mean(step_times)
                    metrics["runtime"] = {
                        "step": global_step,
                        "steptime": np.mean(step_times),
                        "ETA": seconds_to_eta(total_seconds),
                        "eval_ETA": seconds_to_eta(next_eval_seconds),
                    }
                    metrics["runtime"].update(time_metrics[0])
                    step_times = []
                    metrics.update(net_metrics)
                    metrics["loss"]["loc_elem"] = loc_loss_elem
                    metrics["loss"]["cls_pos_rt"] = float(
                        cls_pos_loss.detach().cpu().numpy())
                    metrics["loss"]["cls_neg_rt"] = float(
                        cls_neg_loss.detach().cpu().numpy())
                    if model_cfg.use_direction_classifier:
                        dir_loss_reduced = ret_dict["dir_loss_reduced"].mean()
                        metrics["loss"]["dir_rt"] = float(
                            dir_loss_reduced.detach().cpu().numpy())

                    metrics["misc"] = {
                        "num_vox": int(example_torch["voxels"].shape[0]),
                        "num_pos": int(num_pos),
                        "num_neg": int(num_neg),
                        "num_anchors": int(num_anchors),
                        "lr": float(amp_optimizer.lr),
                        "mem_usage": psutil.virtual_memory().percent,
                    }

                    model_logging.log_metrics(metrics, global_step)


                # if global_step % steps_per_eval != 0 and global_step % 1000 == 0:
                    # torchplus.train.save_models(model_dir, [net, amp_optimizer], net.get_global_step())

                if global_step % steps_per_eval == 0:
                    torchplus.train.save_models(model_dir, [net, amp_optimizer], global_step)
                    net.eval()
                    result_path_step = result_path / f"step_{global_step}"
                    result_path_step.mkdir(parents=True, exist_ok=True)
                    model_logging.log_text("#################################", global_step)
                    model_logging.log_text("# EVAL", global_step)
                    model_logging.log_text("#################################", global_step)
                    model_logging.log_text("Generate output labels...", global_step)
                    t = time.time()
                    detections = []
                    prog_bar = ProgressBar()
                    net.clear_timer()
                    prog_bar.start((len(eval_dataset) + eval_input_cfg.batch_size - 1)
                                // eval_input_cfg.batch_size)
                    for example in iter(eval_dataloader):
                        example = example_convert_to_torch(example, float_dtype)
                        with torch.no_grad():
                            detections += net(example)
                        prog_bar.print_bar()

                    examples_per_sec = len(eval_dataset) / (time.time() - t)
                    eval_times.append(time.time() - t)

                    model_logging.log_text(f'generate label finished({examples_per_sec:.2f} examples/s). start eval:', global_step)
                    result_dict = eval_dataset.dataset.evaluation(detections, result_path_step)
                    if result_dict is None:
                        raise RuntimeError("eval_dataset.dataset.evaluation() returned None")
                    for k, v in result_dict["results"].items():
                        model_logging.log_text("Evaluation {}".format(k), global_step)
                        model_logging.log_text(v, global_step)
                    model_logging.log_metrics(result_dict["detail"], global_step)
                    with open(result_path_step / "result.pkl", 'wb') as f:
                        pickle.dump(detections, f)
                    net.train()
                if global_step >= total_step:
                    break
            if net.get_global_step() >= total_step:
                break
    except Exception as e:
        if 'example' in locals():
            print(json.dumps(example["metadata"], indent=2))
        global_step = net.get_global_step()
        model_logging.log_text(str(e), global_step)
        if 'example' in locals():
            model_logging.log_text(json.dumps(example["metadata"], indent=2), global_step)
        torchplus.train.save_models(model_dir, [net, amp_optimizer], global_step)
        raise
    finally:
        model_logging.close()
    torchplus.train.save_models(model_dir, [net, amp_optimizer], net.get_global_step())

    def _save_checkpoint_info(file_path, config_filename, checkpoint_filename):
        from yaml import dump
        with open(file_path, "w") as config_info_file:
            checkpoint_info = { "config": config_filename, "checkpoint": checkpoint_filename }
            dump(checkpoint_info, config_info_file, default_flow_style=False)

    ckpt_info_path = str(model_dir / "checkpoint_info.yaml")
    latest_ckpt_filename = "voxelnet-{}.tckpt".format(net.get_global_step())
    _save_checkpoint_info(ckpt_info_path, config_file_bkp, latest_ckpt_filename)
    mlflow.log_artifact(ckpt_info_path, "model")

    mlflow.log_artifact(str(model_dir / config_file_bkp), "model")
    mlflow.log_artifact(str(model_dir / latest_ckpt_filename), "model")
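Example #4 reinterprets train_cfg.steps as an epoch count and derives optimizer steps from dataset size and batch size (the np.ceil/np.floor lines above). The arithmetic in isolation, with illustrative numbers:

import numpy as np

num_samples = 3712       # illustrative dataset size
batch_size = 8
epochs = 50
epochs_per_eval = 5

steps_per_epoch = num_samples / batch_size                          # 464.0
total_step = int(np.ceil(steps_per_epoch * epochs))                 # 23200
steps_per_eval = int(np.floor(steps_per_epoch * epochs_per_eval))   # 2320
print(total_step, steps_per_eval)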
Example #5
        pin_memory=False,
        collate_fn=collate_fn,
        worker_init_fn=_worker_init_fn,
        drop_last=not cfg.multi_gpu)
    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=eval_input_cfg.batch_size,  # eval runs on a single GPU; multi-GPU applies to training only
        shuffle=False,
        num_workers=eval_input_cfg.preprocess.num_workers,
        pin_memory=False,
        collate_fn=merge_second_batch)

    ######################
    # TRAINING
    ######################
    model_logging = SimpleModelLog(cfg.model_dir)
    model_logging.open()
    model_logging.log_text(proto_str + "\n", 0, tag="config")
    start_step = net.get_global_step()
    total_step = train_cfg.steps
    t = time.time()
    steps_per_eval = train_cfg.steps_per_eval
    clear_metrics_every_epoch = train_cfg.clear_metrics_every_epoch

    amp_optimizer.zero_grad()
    step_times = []
    step = start_step
    try:
        while True:
            if clear_metrics_every_epoch:
                net.clear_metrics()
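The fragment above configures merge_second_batch as the DataLoader collate_fn. SECOND batches are dicts of numpy arrays, so the collate step combines samples per key. A simplified sketch of that idea (not the project's exact merge_second_batch, which also concatenates ragged fields such as voxels):

from collections import defaultdict
import numpy as np

def merge_dict_batch(batch_list):
    """Collate a list of per-sample dicts into one dict of stacked arrays."""
    merged = defaultdict(list)
    for sample in batch_list:
        for key, value in sample.items():
            merged[key].append(value)
    return {key: np.stack(values, axis=0) for key, values in merged.items()}

# usage: two fake samples with identically shaped fields
batch = merge_dict_batch([
    {"anchors": np.zeros((4, 7)), "labels": np.zeros(4)},
    {"anchors": np.ones((4, 7)), "labels": np.ones(4)},
])
print(batch["anchors"].shape)  # (2, 4, 7)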
Example #7
class SolverWrapper:
    def __init__(self,  train_net,
                        test_net,
                        pretrain = None,
                        prefix = "pp",
                        model_dir=None,
                        config_path=None,
                        ### Solver Params ###
                        solver_type='ADAM',
                        weight_decay=0.001,
                        lr_policy='step',
                        warmup_step=0,
                        warmup_start_lr=0,
                        lr_ratio=1,
                        end_ratio=1,
                        base_lr=0.002,
                        max_lr=0.002,
                        momentum = 0.9,
                        max_momentum = 0,
                        cycle_steps=1856,
                        gamma=0.8,  # use 0.1 with the 'step' lr_policy
                        stepsize=100,
                        test_iter=3769,
                        test_interval=50,  # set to 999999999 to disable the automatic validation runs
                        max_iter=100000,  # must be an int for the protobuf field
                        iter_size=1,
                        snapshot=9999,
                        display=1,
                        random_seed=0,
                        debug_info=False,
                        create_prototxt=True,
                        args=None):
        """Initialize the SolverWrapper."""
        self.test_net = test_net
        self.solver_param = caffe_pb2.SolverParameter()
        self.solver_param.train_net = train_net
        self.solver_param.test_initialization = False


        self.solver_param.display = display
        self.solver_param.warmup_step = warmup_step
        self.solver_param.warmup_start_lr = warmup_start_lr
        self.solver_param.lr_ratio = lr_ratio
        self.solver_param.end_ratio = end_ratio
        self.solver_param.base_lr = base_lr
        self.solver_param.max_lr = max_lr
        self.solver_param.cycle_steps = cycle_steps
        self.solver_param.max_momentum = max_momentum
        self.solver_param.lr_policy = lr_policy  # other options: "fixed", "exp"
        self.solver_param.gamma = gamma
        self.solver_param.stepsize = stepsize

        self.solver_param.max_iter = max_iter
        self.solver_param.iter_size = iter_size
        self.solver_param.snapshot = snapshot
        self.solver_param.snapshot_prefix = os.path.join(model_dir, prefix)
        self.solver_param.random_seed = random_seed

        self.solver_param.solver_mode = caffe_pb2.SolverParameter.GPU
        if solver_type == 'SGD':
            print("[Info] SGD Solver >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
            self.solver_param.solver_type = caffe_pb2.SolverParameter.SGD
        elif solver_type == 'ADAM':
            print("[Info] ADAM Solver >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
            self.solver_param.solver_type = caffe_pb2.SolverParameter.ADAM
        self.solver_param.momentum = momentum
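        # momentum2 maps to Adam's beta2 (decay rate of the second-moment estimate)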
        self.solver_param.momentum2 = 0.999

        self.solver_param.weight_decay = weight_decay
        self.solver_param.debug_info = debug_info

        if create_prototxt:
            solver_prototxt = get_prototxt(self.solver_param, os.path.join(model_dir, 'solver.prototxt'))
            print(solver_prototxt)
        else:
            # assumes a solver.prototxt already exists in model_dir
            solver_prototxt = os.path.join(model_dir, 'solver.prototxt')

        self.solver = caffe.get_solver(solver_prototxt)
        self.test_interval = test_interval

        # Model config parameter initialization
        self.args = args
        self.model_dir, self.config_path = model_dir, config_path
        _, eval_input_cfg, model_cfg, train_cfg = load_config(self.model_dir, self.config_path)
        voxel_generator, self.target_assigner = build_network(model_cfg)
        self.dataloader, self.eval_dataset = load_dataloader(eval_input_cfg, model_cfg, voxel_generator,
                                        self.target_assigner, args = args)
        self.model_cfg = model_cfg
        # NOTE: potential trouble spot; if eval results look wrong, check here first
        self._box_coder = self.target_assigner.box_coder
        classes_cfg = model_cfg.target_assigner.class_settings
        self._num_class = len(classes_cfg)
        self._encode_background_as_zeros = model_cfg.encode_background_as_zeros
        self._nms_class_agnostic = model_cfg.nms_class_agnostic
        self._use_multi_class_nms = [c.use_multi_class_nms for c in classes_cfg]
        self._nms_pre_max_sizes = [c.nms_pre_max_size for c in classes_cfg]
        self._multiclass_nms = all(self._use_multi_class_nms)
        self._use_sigmoid_score = model_cfg.use_sigmoid_score
        self._num_anchor_per_loc = self.target_assigner.num_anchors_per_location

        self._use_rotate_nms = [c.use_rotate_nms for c in classes_cfg]  # False for PointPillars, True for SECOND
        self._nms_post_max_sizes = [c.nms_post_max_size for c in classes_cfg]  # 300 for PointPillars, 100 for SECOND
        self._nms_score_thresholds = [c.nms_score_threshold for c in classes_cfg]  # 0.4 in the submission; 0.3 scores better on hard cases (PointPillars 0.05, SECOND 0.3)
        self._nms_iou_thresholds = [c.nms_iou_threshold for c in classes_cfg]  # NOTE: double check (PointPillars 0.5, SECOND 0.01)
        self._post_center_range = list(model_cfg.post_center_limit_range)  # NOTE: double check
        self._use_direction_classifier = model_cfg.use_direction_classifier  # NOTE: double check
        path = pretrain["path"]
        weight = pretrain["weight"]
        skip_layer = pretrain["skip_layer"] #list skip layer name
        if path != None and weight != None:
            self.load_pretrained_caffe_weight(path, weight, skip_layer)

        #self.model_logging = log_function(self.model_dir, self.config_path)
        ################################Log#####################################
        self.model_logging = SimpleModelLog(self.model_dir)
        self.model_logging.open()

        config = pipeline_pb2.TrainEvalPipelineConfig()
        with open(self.config_path, "r") as f:
            proto_str = f.read()
            text_format.Merge(proto_str, config)

        self.model_logging.log_text(proto_str + "\n", 0, tag="config")
        self.model_logging.close()
        ########################################################################

        #Log loss
        ########################################################################
        self.log_loss_path = Path(self.model_dir) / 'log_loss.txt'
        ########################################################################

    def load_pretrained_caffe_weight(self, path, weight_path, skip_layer):
        assert isinstance(skip_layer, list)  # skip_layer is a list of layer names to skip
        print("### Start loading pretrained caffe weights")
        old_proto_path = os.path.join(path, "train.prototxt")
        old_weight_path = os.path.join(path, weight_path)
        print("### Load old caffe model")
        old_net = caffe.Net(old_proto_path, old_weight_path, caffe.TRAIN)
        print("### Start loading model layers")
        for layer in old_net.params.keys():
            if layer in skip_layer:
                print("### Skipped layer: " + layer)
                continue
            param_length = len(old_net.params[layer])
            print("# Loading layer: " + layer)
            for index in range(param_length):
                try:
                    self.solver.net.params[layer][index].data[...] = old_net.params[layer][index].data[...]
                except Exception as e:
                    print(e)
                    print("!! Cannot load layer: " + layer)
                    continue
        print("### Finish loading pretrained model")

    def eval_model(self):

        self.model_logging.open() #logging

        cur_iter = self.solver.iter
        # if self.args["segmentation"]:
        # self.segmentation_evaluation(cur_iter)
        # else:
        self.object_detection_evaluation(cur_iter)

        self.model_logging.close()

    def train_model(self):
        cur_iter = self.solver.iter
        while cur_iter < self.solver_param.max_iter:
            for i in range(self.test_interval):
                # Restore check: stop if a resumed iteration count is already past max_iter
                if cur_iter + i >= self.solver_param.max_iter:
                    break

                self.solver.step(1)

                if (self.solver.iter-1) % self.solver_param.display == 0:
                    with open(self.log_loss_path, "a") as f:
                        lr = self.solver.lr
                        cls_loss = self.solver.net.blobs['cls_loss'].data[...][0]
                        reg_loss = self.solver.net.blobs['reg_loss'].data[...][0]
                        f.write("steps={},".format(self.solver.iter-1))
                        f.write("lr={:.8f},".format(lr))
                        f.write("cls_loss={:.3f},".format(cls_loss))
                        f.write("reg_loss={:.3f}".format(reg_loss))
                        f.write("\n")

            sut.plot_graph(self.log_loss_path, self.model_dir)
            self.eval_model()
            sut.clear_caffemodel(self.model_dir, 8)  # keep only the last 8 checkpoints
            cur_iter += self.test_interval

    def lr_finder(self):
        lr_finder_path = Path(self.model_dir) / 'log_lrf.txt'
        for _ in range(self.solver_param.max_iter):
            self.solver.step(1)

            if (self.solver.iter-1) % self.solver_param.display == 0:
                with open(lr_finder_path, "a") as f:
                    lr = self.solver.lr
                    cls_loss = self.solver.net.blobs['cls_loss'].data[...][0]
                    reg_loss = self.solver.net.blobs['reg_loss'].data[...][0]
                    f.write("steps={},".format(self.solver.iter-1))
                    f.write("lr={:.8f},".format(lr))
                    f.write("cls_loss={:.3f},".format(cls_loss))
                    f.write("reg_loss={:.3f}".format(reg_loss))
                    f.write("\n")

        sut.plot_graph(lr_finder_path, self.model_dir, name='Finder')

    def demo(self):
        print("[Info] Initialize test net\n")
        test_net = caffe.Net(self.test_net, caffe.TEST)
        test_net.share_with(self.solver.net)
        print("[Info] Loaded train net weights \n")
        data_dir = "./debug_tool/experiment/data/2011_09_26_drive_0009_sync/velodyne_points/data"
        point_cloud_files = os.listdir(data_dir)
        point_cloud_files.sort()
        obj_detections = []
        # Voxel generator
        pc_range = self.model_cfg.voxel_generator.point_cloud_range
        class_settings = self.model_cfg.target_assigner.class_settings[0]
        size = class_settings.anchor_generator_range.sizes
        rotations = class_settings.anchor_generator_range.rotations
        anchor_ranges = np.array(class_settings.anchor_generator_range.anchor_ranges)
        voxel_size = np.array(self.model_cfg.voxel_generator.voxel_size)
        out_size_factor = self.model_cfg.middle_feature_extractor.downsample_factor
        point_cloud_range = np.array(pc_range)
        grid_size = (
            point_cloud_range[3:] - point_cloud_range[:3]) / voxel_size
        grid_size = np.round(grid_size).astype(np.int64)
        feature_map_size = grid_size[:2] // out_size_factor
        feature_map_size = [*feature_map_size, 1][::-1]
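        # Shape walk-through under an illustrative KITTI-style config (numbers
        # assumed, not read from this file): pc_range [0, -40, -3, 70.4, 40, 1]
        # and voxel_size [0.05, 0.05, 0.1] give grid_size [1408, 1600, 40];
        # with downsample_factor 8, feature_map_size becomes [176, 200], and
        # [*fms, 1][::-1] reorders it to [1, 200, 176], i.e. (D, H, W).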
        for file in tqdm(point_cloud_files):
            file_path = os.path.join(data_dir, file)
            # with open(file_path, "rb") as f:
            #     points = f.read()
            points = np.fromfile(file_path, dtype=np.float32).reshape(-1, 4)
            # NOTE: Prior seg preprocessing ###
            points = box_np_ops.remove_out_pc_range_points(points, pc_range)
            # Data sampling
            seg_keep_points = 20000
            points = PointRandomChoiceV2(points, seg_keep_points)  # repeat-sample according to point distance
            points = np.expand_dims(points, 0)
            ###
            # Anchor Generator
            anchors = box_np_ops.create_anchors_3d_range(feature_map_size, anchor_ranges,
                                                         size, rotations)
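            # Each anchor is a 7-vector [x, y, z, w, l, h, ry]. The anchors
            # depend only on the fixed feature map and config, so this call
            # could be hoisted out of the per-file loop; it is left here to
            # mirror the evaluation path.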
            # input
            test_net.blobs['top_prev'].reshape(*points.shape)
            test_net.blobs['top_prev'].data[...] = points
            test_net.forward()

            # segmentation output
            try:
                seg_preds = test_net.blobs['seg_output'].data[...].squeeze()
                points = np.squeeze(points)
                pred_thresh = 0.5
                pd_points = points[seg_preds >= pred_thresh,:]
                with open(os.path.join('./debug_tool/experiment',"pd_points.pkl") , 'ab') as f:
                    pickle.dump(pd_points,f)
            except Exception:
                pass  # net has no segmentation output; skip the dump

            with open(os.path.join('./debug_tool/experiment',"points.pkl") , 'ab') as f:
                pickle.dump(points,f)

            # Bounding box output
            cls_preds = test_net.blobs['f_cls_preds'].data[...]
            box_preds = test_net.blobs['f_box_preds'].data[...]
            preds_dict = {"box_preds":box_preds, "cls_preds":cls_preds}
            example = {"anchors": np.expand_dims(anchors, 0)}
            example = example_convert_to_torch(example, torch.float32)
            preds_dict = example_convert_to_torch(preds_dict, torch.float32)
            obj_detections += self.predict(example, preds_dict)
            pd_boxes = obj_detections[-1]["box3d_lidar"].cpu().detach().numpy()
            with open(os.path.join('./debug_tool/experiment',"pd_boxes.pkl") , 'ab') as f:
                pickle.dump(pd_boxes,f)

    ############################################################################
    # For object evaluation
    ############################################################################
    def object_detection_evaluation(self, global_step):
        print("[Info] Initialize test net\n")
        test_net = caffe.Net(self.test_net, caffe.TEST)
        test_net.share_with(self.solver.net)
        print("[Info] Loaded train net weights \n")
        data_iter = iter(self.dataloader)
        obj_detections = []
        seg_detections = []
        t = time.time()
        model_dir = str(Path(self.model_dir).resolve())
        model_dir = Path(model_dir)
        result_path = model_dir / 'results'
        result_path_step = result_path / f"step_{global_step}"
        result_path_step.mkdir(parents=True, exist_ok=True)
        for i in tqdm(range(len(data_iter))):
            example = next(data_iter)
            # points = example['seg_points'] # Pointseg
            # voxels = example['voxels']
            # coors = example['coordinates']
            # coors = example['coordinates']
            # num_points = example['num_points']
            # test_net.blobs['top_prev'].reshape(*points.shape)
            # test_net.blobs['top_prev'].data[...] = points
            # test_net.forward()

            # test_net.blobs['top_lat_feats'].reshape(*(voxels.squeeze()).shape)
            # test_net.blobs['top_lat_feats'].data[...] = voxels.squeeze()
            # voxels = voxels.squeeze()
            # with open(os.path.join('./debug',"points.pkl") , 'ab') as f:
            #     pickle.dump(voxels,f)
            # voxels = voxels[cls_out,:]
            # # print("selected voxels", voxels.shape)
            # with open(os.path.join('./debug',"seg_points.pkl") , 'ab') as f:
            #     pickle.dump(voxels,f)
            # NOTE: For voxel seg net
            # seg_points = example['seg_points'] # Pointseg
            # coords = example['coords']
            # coords_center = example['coords_center']
            # p2voxel_idx = example['p2voxel_idx']
            # test_net.blobs['seg_points'].reshape(*seg_points.shape)
            # test_net.blobs['seg_points'].data[...] = seg_points
            # test_net.blobs['coords'].reshape(*coords.shape)
            # test_net.blobs['coords'].data[...] = coords
            # test_net.blobs['p2voxel_idx'].reshape(*p2voxel_idx.shape)
            # test_net.blobs['p2voxel_idx'].data[...] = p2voxel_idx
            ##
            # NOTE: For prior seg
            voxels = example['seg_points']
            test_net.blobs['top_prev'].reshape(*voxels.shape)
            test_net.blobs['top_prev'].data[...] = voxels
            test_net.forward()
            ##
            cls_preds = test_net.blobs['f_cls_preds'].data[...]
            box_preds = test_net.blobs['f_box_preds'].data[...]
            # seg_preds = test_net.blobs['seg_output'].data[...].squeeze()
            # feat_map = test_net.blobs['p2fm'].data[...].squeeze().reshape(5,-1).transpose()
            # feat_map = feat_map[(feat_map != 0).any(-1)]
            # Reverse coordinate for anchor generator
            # anchor generated from generator shape (n_anchors, 7)
            # needed to expand dim for prediction
            # example["anchors"] = np.expand_dims(anchors, 0)
            # preds_dict = {"box_preds":box_preds.reshape(1,-1,7), "cls_preds":cls_preds.reshape(1,-1,1)}
            # example["seg_points"] = voxels
            preds_dict = {"box_preds":box_preds, "cls_preds":cls_preds}
            example = example_convert_to_torch(example, torch.float32)
            preds_dict = example_convert_to_torch(preds_dict, torch.float32)
            obj_detections += self.predict(example, preds_dict)
            # seg_detections += self.seg_predict(np.arange(0.5, 0.75, 0.05), seg_preds, example, result_path_step, vis=False)
            ################ visualization #####################
            pd_boxes = obj_detections[-1]["box3d_lidar"].cpu().detach().numpy()
            with open(os.path.join(result_path_step,"pd_boxes.pkl") , 'ab') as f:
                pickle.dump(pd_boxes,f)

        self.model_logging.log_text(
            f'\nEval at step ---------> {global_step}:\n', global_step)

        # Object detection evaluation
        result_dict = self.eval_dataset.dataset.evaluation(obj_detections,
                                                str(result_path_step))
        for k, v in result_dict["results"].items():
            self.model_logging.log_text("Evaluation {}".format(k), global_step)
            self.model_logging.log_text(v, global_step)
        self.model_logging.log_metrics(result_dict["detail"], global_step)

        # Class segmentation prediction
        # result_dict = self.total_segmentation_result(seg_detections)
        # for k, v in result_dict["results"].items():
        #     self.model_logging.log_text("Evaluation {}".format(k), global_step)
        #     self.model_logging.log_text(v, global_step)
        # self.model_logging.log_metrics(result_dict["detail"], global_step)

    def predict(self, example, preds_dict):
        """start with v1.6.0, this function don't contain any kitti-specific code.
        Returns:
            predict: list of pred_dict.
            pred_dict: {
                box3d_lidar: [N, 7] 3d box.
                scores: [N]
                label_preds: [N]
                metadata: meta-data which contains dataset-specific information.
                    for kitti, it contains image idx (label idx),
                    for nuscenes, sample_token is saved in it.
            }
        """
        batch_size = example['anchors'].shape[0]
        # NOTE: for voxel seg net
        # batch_size = example['coords_center'].shape[0]

        # batch_size = example['seg_points'].shape[0]
        if "metadata" not in example or len(example["metadata"]) == 0:
            meta_list = [None] * batch_size
        else:
            meta_list = example["metadata"]

        batch_anchors = example["anchors"].view(batch_size, -1, example["anchors"].shape[-1])
        # NOTE: for voxel seg net
        # batch_anchors = example["coords_center"].view(batch_size, -1, example["coords_center"].shape[-1])

        # batch_anchors = example["seg_points"].view(batch_size, -1, example["seg_points"].shape[-1])
        if "anchors_mask" not in example:
            batch_anchors_mask = [None] * batch_size
        else:
            batch_anchors_mask = example["anchors_mask"].view(batch_size, -1)

        t = time.time()
        batch_box_preds = preds_dict["box_preds"]
        batch_cls_preds = preds_dict["cls_preds"]
        batch_box_preds = batch_box_preds.view(batch_size, -1,
                                               self._box_coder.code_size)
        num_class_with_bg = self._num_class
        if not self._encode_background_as_zeros:
            num_class_with_bg = self._num_class + 1

        batch_cls_preds = batch_cls_preds.view(batch_size, -1,
                                               num_class_with_bg)
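        # Shapes from here on, with N anchors per sample: batch_box_preds is
        # (batch, N, code_size) and batch_cls_preds is (batch, N,
        # num_class_with_bg). decode_torch then turns regression deltas into
        # absolute [x, y, z, w, l, h, ry] boxes against their anchors.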
        # NOTE: Original decoding
        batch_box_preds = self._box_coder.decode_torch(batch_box_preds,
                                                       batch_anchors)
        # NOTE: For voxel seg net and point wise prediction
        # batch_box_preds = box_np_ops.fcos_box_decoder_v2_torch(batch_anchors,
        #                                               batch_box_preds)
        if self._use_direction_classifier:
            batch_dir_preds = preds_dict["dir_cls_preds"]
            batch_dir_preds = batch_dir_preds.view(batch_size, -1,
                                                   self._num_direction_bins)
        else:
            batch_dir_preds = [None] * batch_size
        predictions_dicts = []
        post_center_range = None
        if len(self._post_center_range) > 0:
            post_center_range = torch.tensor(
                self._post_center_range,
                dtype=batch_box_preds.dtype,
                device=batch_box_preds.device).float()
        for box_preds, cls_preds, dir_preds, a_mask, meta in zip(
                batch_box_preds, batch_cls_preds, batch_dir_preds,
                batch_anchors_mask, meta_list):
            if a_mask is not None:
                box_preds = box_preds[a_mask]
                cls_preds = cls_preds[a_mask]
            box_preds = box_preds.float()
            cls_preds = cls_preds.float()
            if self._use_direction_classifier:
                if a_mask is not None:
                    dir_preds = dir_preds[a_mask]
                dir_labels = torch.max(dir_preds, dim=-1)[1]
            if self._encode_background_as_zeros:
                # this setup doesn't support softmax
                assert self._use_sigmoid_score is True
                total_scores = torch.sigmoid(cls_preds)
            else:
                # encode background as first element in one-hot vector
                if self._use_sigmoid_score:
                    total_scores = torch.sigmoid(cls_preds)[..., 1:]
                else:
                    total_scores = F.softmax(cls_preds, dim=-1)[..., 1:]
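            # Scoring recap: with encode_background_as_zeros every channel is a
            # foreground class, so sigmoid applies per channel; otherwise
            # channel 0 is background and the [..., 1:] slice drops it.
            # Illustrative: cls_preds [2.0, -1.0] -> sigmoid -> [0.88, 0.27].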
            # Apply NMS in bird's-eye view.
            # NOTE: _use_rotate_nms is a per-class list, and a non-empty list
            # is always truthy; collapse it with all() so the flag actually
            # reflects the config.
            use_rotate_nms = all(self._use_rotate_nms)
            if use_rotate_nms:
                nms_func = box_torch_ops.rotate_nms
            else:
                nms_func = box_torch_ops.nms
            feature_map_size_prod = batch_box_preds.shape[
                1] // self.target_assigner.num_anchors_per_location
            if self._multiclass_nms:
                assert self._encode_background_as_zeros is True
                boxes_for_nms = box_preds[:, [0, 1, 3, 4, 6]]
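                # Columns [0, 1, 3, 4, 6] of an [x, y, z, w, l, h, ry] box are
                # its BEV footprint [x, y, w, l, ry]; z and h play no role in
                # bird's-eye-view NMS.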
                if not use_rotate_nms:
                    box_preds_corners = box_torch_ops.center_to_corner_box2d(
                        boxes_for_nms[:, :2], boxes_for_nms[:, 2:4],
                        boxes_for_nms[:, 4])
                    boxes_for_nms = box_torch_ops.corner_to_standup_nd(
                        box_preds_corners)

                selected_boxes, selected_labels, selected_scores = [], [], []
                selected_dir_labels = []

                scores = total_scores
                boxes = boxes_for_nms
                selected_per_class = []
                score_threshs = self._nms_score_thresholds
                pre_max_sizes = self._nms_pre_max_sizes
                post_max_sizes = self._nms_post_max_sizes
                iou_thresholds = self._nms_iou_thresholds
                for class_idx, score_thresh, pre_ms, post_ms, iou_th in zip(
                        range(self._num_class),
                        score_threshs,
                        pre_max_sizes, post_max_sizes, iou_thresholds):
                    if self._nms_class_agnostic:
                        class_scores = total_scores.view(
                            feature_map_size_prod, -1,
                            self._num_class)[..., class_idx]
                        class_scores = class_scores.contiguous().view(-1)
                        class_boxes_nms = boxes.view(-1,
                                                     boxes_for_nms.shape[-1])
                        class_boxes = box_preds
                        class_dir_labels = dir_labels
                    else:
                        anchors_range = self.target_assigner.anchors_range(class_idx)
                        class_scores = total_scores.view(
                            -1,
                            self._num_class)[anchors_range[0]:anchors_range[1], class_idx]
                        class_boxes_nms = boxes.view(-1,
                            boxes_for_nms.shape[-1])[anchors_range[0]:anchors_range[1], :]
                        class_scores = class_scores.contiguous().view(-1)
                        class_boxes_nms = class_boxes_nms.contiguous().view(
                            -1, boxes_for_nms.shape[-1])
                        class_boxes = box_preds.view(-1,
                            box_preds.shape[-1])[anchors_range[0]:anchors_range[1], :]
                        class_boxes = class_boxes.contiguous().view(
                            -1, box_preds.shape[-1])
                        if self._use_direction_classifier:
                            class_dir_labels = dir_labels.view(-1)[anchors_range[0]:anchors_range[1]]
                            class_dir_labels = class_dir_labels.contiguous(
                            ).view(-1)
                    if score_thresh > 0.0:
                        class_scores_keep = class_scores >= score_thresh
                        if class_scores_keep.shape[0] == 0:
                            selected_per_class.append(None)
                            continue
                        class_scores = class_scores[class_scores_keep]
                    if class_scores.shape[0] != 0:
                        if score_thresh > 0.0:
                            class_boxes_nms = class_boxes_nms[
                                class_scores_keep]
                            class_boxes = class_boxes[class_scores_keep]
                            class_dir_labels = class_dir_labels[
                                class_scores_keep]
                        keep = nms_func(class_boxes_nms, class_scores, pre_ms,
                                        post_ms, iou_th)
                        if keep.shape[0] != 0:
                            selected_per_class.append(keep)
                        else:
                            selected_per_class.append(None)
                    else:
                        selected_per_class.append(None)
                    selected = selected_per_class[-1]

                    if selected is not None:
                        selected_boxes.append(class_boxes[selected])
                        selected_labels.append(
                            torch.full([class_boxes[selected].shape[0]],
                                       class_idx,
                                       dtype=torch.int64,
                                       device=box_preds.device))
                        if self._use_direction_classifier:
                            selected_dir_labels.append(
                                class_dir_labels[selected])
                        selected_scores.append(class_scores[selected])
                selected_boxes = torch.cat(selected_boxes, dim=0)
                selected_labels = torch.cat(selected_labels, dim=0)
                selected_scores = torch.cat(selected_scores, dim=0)
                if self._use_direction_classifier:
                    selected_dir_labels = torch.cat(selected_dir_labels, dim=0)
            else:
                # Take the highest score per prediction, then apply NMS
                # to remove overlapping boxes.
                if num_class_with_bg == 1:
                    top_scores = total_scores.squeeze(-1)
                    top_labels = torch.zeros(
                        total_scores.shape[0],
                        device=total_scores.device,
                        dtype=torch.long)
                else:
                    top_scores, top_labels = torch.max(
                        total_scores, dim=-1)
                if self._nms_score_thresholds[0] > 0.0:
                    top_scores_keep = top_scores >= self._nms_score_thresholds[0]
                    top_scores = top_scores.masked_select(top_scores_keep)
                    print("nms_thres is {} selected {} cars ".format(self._nms_score_thresholds, len(top_scores)))
                if top_scores.shape[0] != 0:
                    if self._nms_score_thresholds[0] > 0.0:
                        box_preds = box_preds[top_scores_keep]
                        if self._use_direction_classifier:
                            dir_labels = dir_labels[top_scores_keep]
                        top_labels = top_labels[top_scores_keep]
                    boxes_for_nms = box_preds[:, [0, 1, 3, 4, 6]]
                    if not use_rotate_nms:
                        box_preds_corners = box_torch_ops.center_to_corner_box2d(
                            boxes_for_nms[:, :2], boxes_for_nms[:, 2:4],
                            boxes_for_nms[:, 4])
                        boxes_for_nms = box_torch_ops.corner_to_standup_nd(
                            box_preds_corners)
                    # NMS in 3D detection only removes overlapping boxes.
                    selected = nms_func(
                        boxes_for_nms,
                        top_scores,
                        pre_max_size=self._nms_pre_max_sizes[0],
                        post_max_size=self._nms_post_max_sizes[0],
                        iou_threshold=self._nms_iou_thresholds[0],
                    )
                else:
                    selected = []
                # if selected is not None:
                selected_boxes = box_preds[selected]
                print("IoU_thresh is {} remove {} overlap".format(self._nms_iou_thresholds, (len(box_preds)-len(selected_boxes))))
                # Eval debug
                if "gt_num" in example:
                    eval_idx = example['metadata'][0]['image_idx']
                    eval_obj_num = example['gt_num']
                    detection_error = eval_obj_num - len(selected_boxes)
                    print("Eval img_{}: {} objects, {} detected, error {}".format(
                        eval_idx, eval_obj_num, len(selected_boxes), detection_error))

                if self._use_direction_classifier:
                    selected_dir_labels = dir_labels[selected]
                selected_labels = top_labels[selected]
                selected_scores = top_scores[selected]
            # finally generate predictions.
            if selected_boxes.shape[0] != 0:
                box_preds = selected_boxes
                scores = selected_scores
                label_preds = selected_labels
                if self._use_direction_classifier:
                    dir_labels = selected_dir_labels
                    period = (2 * np.pi / self._num_direction_bins)
                    dir_rot = box_torch_ops.limit_period(
                        box_preds[..., 6] - self._dir_offset,
                        self._dir_limit_offset, period)
                    box_preds[
                        ...,
                        6] = dir_rot + self._dir_offset + period * dir_labels.to(
                            box_preds.dtype)
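                    # Illustrative with num_direction_bins = 2: period = pi, so
                    # limit_period folds ry into one half-turn and the predicted
                    # direction bin adds back 0 or pi, recovering the heading
                    # that a sin-encoded regression target cannot distinguish.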
                final_box_preds = box_preds
                final_scores = scores
                final_labels = label_preds
                if post_center_range is not None:
                    mask = (final_box_preds[:, :3] >=
                            post_center_range[:3]).all(1)
                    mask &= (final_box_preds[:, :3] <=
                             post_center_range[3:]).all(1)
                    predictions_dict = {
                        "box3d_lidar": final_box_preds[mask],
                        "scores": final_scores[mask],
                        "label_preds": label_preds[mask],
                        "metadata": meta,
                    }
                else:
                    predictions_dict = {
                        "box3d_lidar": final_box_preds,
                        "scores": final_scores,
                        "label_preds": label_preds,
                        "metadata": meta,
                    }
            else:
                dtype = batch_box_preds.dtype
                device = batch_box_preds.device
                predictions_dict = {
                    "box3d_lidar":
                    torch.zeros([0, box_preds.shape[-1]],
                                dtype=dtype,
                                device=device),
                    "scores":
                    torch.zeros([0], dtype=dtype, device=device),
                    "label_preds":
                    torch.zeros([0], dtype=top_labels.dtype, device=device),
                    "metadata":
                    meta,
                }

            predictions_dicts.append(predictions_dict)

        return predictions_dicts

    ############################################################################
    # For segmentation evaluation
    ############################################################################
    def segmentation_evaluation(self, global_step):
        print("Initialize test net")
        test_net = caffe.Net(self.test_net, caffe.TEST)
        print("Load train net weights")
        test_net.share_with(self.solver.net)
        _, eval_input_cfg, model_cfg, train_cfg = load_config(self.model_dir, self.config_path)
        voxel_generator, self.target_assigner = build_network(model_cfg)
        ## TODO:
        dataloader, _ = load_dataloader(eval_input_cfg, model_cfg,
                                        voxel_generator,
                                        self.target_assigner,
                                        args=self.args)
        data_iter = iter(dataloader)


        model_dir = str(Path(self.model_dir).resolve())
        model_dir = Path(model_dir)
        result_path = model_dir / 'results'
        result_path_step = result_path / f"step_{global_step}"
        result_path_step.mkdir(parents=True, exist_ok=True)

        detections = []
        detections_voc = []
        detections_05 = []
        for i in tqdm(range(len(data_iter))):
            example = next(data_iter)
            points = example['seg_points']
            test_net.blobs['top_prev'].reshape(*points.shape)
            test_net.blobs['top_prev'].data[...] = points
            test_net.forward()

            #seg_cls_pred output shape (1,1,1,16000)
            # seg_cls_pred = test_net.blobs["output"].data[...].squeeze()
            seg_cls_pred = test_net.blobs['seg_output'].data[...].squeeze()
            detections += self.seg_predict(np.arange(0.5, 0.75, 0.05), seg_cls_pred, example, result_path_step, vis=False)
            # detections_voc += self.seg_predict([0.1, 0.3, 0.5, 0.7, 0.9], seg_cls_pred, example, result_path_step, vis=False)
            # detections_05 += self.seg_predict([0.5], seg_cls_pred, example, result_path_step, vis=False)

        result_dict = self.total_segmentation_result(detections)
        # result_dict_voc = self.total_segmentation_result(detections_voc)
        # result_dict_05 = self.total_segmentation_result(detections_05)

        self.model_logging.log_text(
            f'\nEval at step ---------> {global_step}:\n', global_step)
        for k, v in result_dict["results"].items():
            self.model_logging.log_text("Evaluation {}".format(k), global_step)
            self.model_logging.log_text(v, global_step)
        self.model_logging.log_metrics(result_dict["detail"], global_step)

        # print("\n")
        # for k, v in result_dict_voc["results"].items():
        #     self.model_logging.log_text("Evaluation VOC {}".format(k), global_step)
        #     self.model_logging.log_text(v, global_step)
        # self.model_logging.log_metrics(result_dict_voc["detail"], global_step)
        # print("\n")
        # for k, v in result_dict_05["results"].items():
        #     self.model_logging.log_text("Evaluation 0.5 {}".format(k), global_step)
        #     self.model_logging.log_text(v, global_step)
        # self.model_logging.log_metrics(result_dict_05["detail"], global_step)


    def seg_predict(self, thresh_range, pred, example, result_path_step, vis=False):
        # pred = 1 / (1 + np.exp(-pred)) #sigmoid
        gt = example['seg_labels']
        ############### Params ###############
        eps = 1e-5

        cls_thresh_range = thresh_range
        pos_class = 1 # Car
        list_score = []
        cls_thresh_list = []
        ############### Params ###############

        pred, gt = np.array(pred), np.array(gt)
        gt = np.squeeze(gt)
        labels = np.unique(gt)
        ##################Traverse cls_thresh###################################
        for cls_thresh in cls_thresh_range:
            scores = {}
            _pred = np.where(pred>cls_thresh, 1, 0)

            TPs = np.sum((gt == pos_class) * (_pred == pos_class))
            TNs = np.sum((gt != pos_class) * (_pred != pos_class))
            FPs = np.sum((gt != pos_class) * (_pred == pos_class))
            FNs = np.sum((gt == pos_class) * (_pred != pos_class))
            TargetTotal= np.sum(gt == pos_class)

            scores['accuracy'] = TPs / (TargetTotal + eps)
            scores['class_iou'] = TPs / ((TPs + FNs + FPs) + eps)
            scores['precision'] = TPs / ((TPs + FPs) + eps)
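            # NOTE: 'accuracy' here is TPs / TargetTotal, i.e. recall of the
            # positive class; 'class_iou' is TPs / (TPs + FNs + FPs) and
            # 'precision' is TPs / (TPs + FPs). eps only guards empty frames.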

            cls_thresh_list.append(scores)

        ################### Find best cls_thresh ###############################
        thresh_accuracy = []
        thresh_class_iou = []
        thresh_precision = []
        max_class_iou = 0
        max_class_iou_thresh = 0

        for thresh, cls_list in zip(cls_thresh_range, cls_thresh_list):
            accuracy = cls_list['accuracy']
            class_iou = cls_list['class_iou']
            precision = cls_list['precision']
            thresh_accuracy.append(accuracy)
            thresh_class_iou.append(class_iou)
            thresh_precision.append(precision)

            if class_iou > max_class_iou:
                max_class_iou = class_iou
                max_class_iou_thresh = thresh

        # Overwrite the last per-threshold dict with the means across all
        # thresholds, plus the IoU-maximizing threshold.
        scores['accuracy'] = np.mean(np.array(thresh_accuracy))
        scores['class_iou'] = np.mean(np.array(thresh_class_iou))
        scores['precision'] = np.mean(np.array(thresh_precision))
        scores['best_thresh'] = max_class_iou_thresh  # threshold giving max class IoU

        ############################pred_thresh#################################
        pred_thresh = self._nms_score_thresholds[0]

        points = example['seg_points']
        points = np.squeeze(points)
        pd_points = points[pred >= pred_thresh]

        with open(os.path.join(result_path_step, "gt_points.pkl"), 'ab') as f:
            pickle.dump(pd_points,f)

        if vis:
            image_idx = example['image_idx']
            gt_boxes = example['gt_boxes']
            with open(os.path.join(result_path_step, "image_idx.pkl"), 'ab') as f:
                pickle.dump(image_idx,f)
            with open(os.path.join(result_path_step, "points.pkl"), 'ab') as f:
                pickle.dump(points,f)
            with open(os.path.join(result_path_step, "gt_boxes.pkl"), 'ab') as f:
                pickle.dump(gt_boxes,f)

        list_score.append(scores)

        return list_score

    def total_segmentation_result(self, detections):
        avg_accuracy = []
        avg_class_iou = []
        avg_precision = []
        avg_thresh = []
        for det in detections:
            avg_accuracy.append(det['accuracy'])
            avg_class_iou.append(det['class_iou'])
            avg_precision.append(det['precision'])
            avg_thresh.append(det['best_thresh'])

        # Average over non-zero entries only, so frames without any cars do
        # not drag the means down.
        avg_accuracy = np.sum(np.array(avg_accuracy)) / np.sum(np.array(avg_accuracy) != 0)
        avg_class_iou = np.sum(np.array(avg_class_iou)) / np.sum(np.array(avg_class_iou) != 0)
        avg_precision = np.sum(np.array(avg_precision)) / np.sum(np.array(avg_precision) != 0)
        avg_thresh = np.sum(np.array(avg_thresh)) / np.sum(np.array(avg_thresh) != 0)

        print('-------------------- Summary --------------------')

        result_dict = {}
        result_dict['results'] ={"Summary" : 'Threshhold: {:.3f} \n'.format(avg_thresh) + \
                                             'Accuracy: {:.3f} \n'.format(avg_accuracy) + \
                                             'Car IoU: {:.3f} \n'.format(avg_class_iou) + \
                                             'Precision: {:.3f} \n'.format(avg_precision)
                                             }

        result_dict['detail'] = {"Threshold" : avg_thresh,
                                 "Accuracy" : avg_accuracy,
                                 "Car IoU": avg_class_iou,
                                 "Precision": avg_precision,
                                 }
        return result_dict
Example #8
def train(config_path,
          model_dir,
          result_path=None,
          create_folder=False,
          display_step=50,
          pretrained_path=None,
          pretrained_include=None,
          pretrained_exclude=None,
          freeze_include=None,
          freeze_exclude=None,
          multi_gpu=False,
          measure_time=False,
          resume=False):
    """train a PointPillars model specified by a config file.
    """
    torch.cuda.empty_cache()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_dir = str(Path(model_dir).resolve())
    if create_folder:
        if Path(model_dir).exists():
            model_dir = torchplus.train.create_folder(model_dir)
    model_dir = Path(model_dir)
    if not resume and model_dir.exists():
        raise ValueError("model dir exists and you don't specify resume.")
    model_dir.mkdir(parents=True, exist_ok=True)
    if result_path is None:
        result_path = model_dir / 'results'

    config, proto_str = load_config(model_dir, config_path)

    input_cfg = config.train_input_reader
    model_cfg = config.model.second
    train_cfg = config.train_config
    target_assigner_cfg = model_cfg.target_assigner

    voxel_generator = voxel_builder.build(model_cfg.voxel_generator)
    bv_range = voxel_generator.point_cloud_range[[0, 1, 3, 4]]
    box_coder = box_coder_builder.build(model_cfg.box_coder)

    target_assigner = target_assigner_builder.build(target_assigner_cfg,
                                                    bv_range, box_coder)
    box_coder.custom_ndim = target_assigner._anchor_generators[0].custom_ndim

    net = PointPillarsNet(1,
                          voxel_generator.grid_size,
                          target_assigner.num_anchors_per_location,
                          target_assigner.box_coder.code_size,
                          with_distance=False).to(device)
    kaiming_init(net, 1.0)

    net_loss = build_net_loss(model_cfg, target_assigner).to(device)
    net_loss.clear_global_step()
    net_loss.clear_metrics()
    # print("num parameters:", len(list(net.parameters())))

    load_pretrained_model(net, pretrained_path, pretrained_include,
                          pretrained_exclude, freeze_include, freeze_exclude)

    if resume:
        torchplus.train.try_restore_latest_checkpoints(model_dir, [net])

    amp_optimizer, lr_scheduler = create_optimizer(model_dir, train_cfg, net)

    collate_fn = merge_second_batch
    num_gpu = 1

    ######################
    # PREPARE INPUT
    ######################
    dataset = input_reader_builder.build(input_cfg,
                                         model_cfg,
                                         training=True,
                                         voxel_generator=voxel_generator,
                                         target_assigner=target_assigner,
                                         multi_gpu=multi_gpu)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=input_cfg.batch_size * num_gpu,
        shuffle=True,
        num_workers=input_cfg.preprocess.num_workers * num_gpu,
        pin_memory=False,
        collate_fn=collate_fn,
        worker_init_fn=_worker_init_fn,
        drop_last=not multi_gpu)
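    # merge_second_batch collates the variable-length voxel arrays of each
    # sample into single batched tensors; note drop_last is True only in the
    # single-GPU case here (drop_last=not multi_gpu).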

    ######################
    # TRAINING
    ######################
    model_logging = SimpleModelLog(model_dir)
    model_logging.open()
    model_logging.log_text(proto_str + "\n", 0, tag="config")

    start_step = net_loss.get_global_step()
    total_step = train_cfg.steps
    float_dtype = torch.float32  # dtype used by example_convert_to_torch below
    t = time.time()
    steps_per_eval = train_cfg.steps_per_eval
    clear_metrics_every_epoch = train_cfg.clear_metrics_every_epoch

    amp_optimizer.zero_grad()
    step_times = []
    step = start_step
    best_mAP = 0
    epoch = 0

    net.train()
    net_loss.train()
    try:
        while True:
            if clear_metrics_every_epoch:
                net_loss.clear_metrics()
            for example in dataloader:
                lr_scheduler.step(net_loss.get_global_step())
                time_metrics = example["metrics"]
                example.pop("metrics")
                example_torch = example_convert_to_torch(example, float_dtype)

                batch_size = example_torch["anchors"].shape[0]

                coors = example_torch["coordinates"]
                input_features = compute_model_input(
                    voxel_generator.voxel_size,
                    voxel_generator.point_cloud_range,
                    with_distance=False,
                    voxels=example_torch['voxels'],
                    num_voxels=example_torch['num_points'],
                    coors=coors)
                # input_features = reshape_input(batch_size, input_features, coors, voxel_generator.grid_size)
                input_features = reshape_input1(input_features)

                net.batch_size = batch_size
                preds_list = net(input_features, coors)

                ret_dict = net_loss(example_torch, preds_list)

                cls_preds = ret_dict["cls_preds"]
                loss = ret_dict["loss"].mean()
                cls_loss_reduced = ret_dict["cls_loss_reduced"].mean()
                loc_loss_reduced = ret_dict["loc_loss_reduced"].mean()
                cls_pos_loss = ret_dict["cls_pos_loss"].mean()
                cls_neg_loss = ret_dict["cls_neg_loss"].mean()
                loc_loss = ret_dict["loc_loss"]
                cls_loss = ret_dict["cls_loss"]

                cared = ret_dict["cared"]
                labels = example_torch["labels"]

                loss.backward()
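                # Bound the global gradient norm (here at 10.0) before the
                # optimizer step; a common guard against exploding losses.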
                torch.nn.utils.clip_grad_norm_(net.parameters(), 10.0)
                amp_optimizer.step()
                amp_optimizer.zero_grad()

                net_loss.update_global_step()

                net_metrics = net_loss.update_metrics(cls_loss_reduced,
                                                      loc_loss_reduced,
                                                      cls_preds, labels, cared)

                step_time = (time.time() - t)
                step_times.append(step_time)
                t = time.time()
                metrics = {}
                num_pos = int((labels > 0)[0].float().sum().cpu().numpy())
                num_neg = int((labels == 0)[0].float().sum().cpu().numpy())
                if 'anchors_mask' not in example_torch:
                    num_anchors = example_torch['anchors'].shape[1]
                else:
                    num_anchors = int(example_torch['anchors_mask'][0].sum())
                global_step = net_loss.get_global_step()

                if global_step % display_step == 0:
                    loc_loss_elem = [
                        float(loc_loss[:, :, i].sum().detach().cpu().numpy() /
                              batch_size) for i in range(loc_loss.shape[-1])
                    ]
                    metrics["runtime"] = {
                        "step": global_step,
                        "steptime": np.mean(step_times),
                    }
                    metrics["runtime"].update(time_metrics[0])
                    step_times = []
                    metrics.update(net_metrics)
                    metrics["loss"]["loc_elem"] = loc_loss_elem
                    metrics["loss"]["cls_pos_rt"] = float(
                        cls_pos_loss.detach().cpu().numpy())
                    metrics["loss"]["cls_neg_rt"] = float(
                        cls_neg_loss.detach().cpu().numpy())
                    if model_cfg.use_direction_classifier:
                        dir_loss_reduced = ret_dict["dir_loss_reduced"].mean()
                        metrics["loss"]["dir_rt"] = float(
                            dir_loss_reduced.detach().cpu().numpy())

                    metrics["misc"] = {
                        "num_vox": int(example_torch["voxels"].shape[0]),
                        "num_pos": int(num_pos),
                        "num_neg": int(num_neg),
                        "num_anchors": int(num_anchors),
                        "lr": float(amp_optimizer.lr),
                        "mem_usage": psutil.virtual_memory().percent,
                    }
                    model_logging.log_metrics(metrics, global_step)
                step += 1
            epoch += 1
            if epoch % 2 == 0:
                global_step = net_loss.get_global_step()
                torchplus.train.save_models(model_dir, [net, amp_optimizer],
                                            global_step)
                net.eval()
                net_loss.eval()
                best_mAP = evaluate(net, net_loss, best_mAP, voxel_generator,
                                    target_assigner, config, model_logging,
                                    model_dir, result_path)
                net.train()
                net_loss.train()
                if epoch > 100:
                    break
            if epoch > 100:
                break
    except Exception as e:
        print(json.dumps(example["metadata"], indent=2))
        model_logging.log_text(str(e), step)
        model_logging.log_text(json.dumps(example["metadata"], indent=2), step)
        torchplus.train.save_models(model_dir, [net, amp_optimizer], step)
        raise
    finally:
        model_logging.close()
    torchplus.train.save_models(model_dir, [net, amp_optimizer],
                                net_loss.get_global_step())
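
    # Invocation sketch (config path and model_dir are illustrative, not from
    # this repo):
    #   train("configs/pointpillars/car/xyres_16.config",
    #         model_dir="./models/pp_car", resume=False)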