Example #1
    def __init__(self,
                 dataset,
                 samples_per_gpu=1,
                 num_replicas=None,
                 rank=None):
        _rank, _num_replicas = get_dist_info()
        if num_replicas is None:
            num_replicas = _num_replicas
        if rank is None:
            rank = _rank
        self.dataset = dataset
        self.samples_per_gpu = samples_per_gpu
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0

        assert hasattr(self.dataset, "flag")
        self.flag = self.dataset.flag
        self.group_sizes = np.bincount(self.flag)

        self.num_samples = 0
        # pad each size group so it divides evenly across replicas and GPU batches
        for size in self.group_sizes:
            self.num_samples += (
                int(math.ceil(size * 1.0 / self.samples_per_gpu / self.num_replicas))
                * self.samples_per_gpu
            )
        self.total_size = self.num_samples * self.num_replicas
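The loop above rounds every size group up to a whole number of per-GPU batches on every replica. A minimal sketch of that arithmetic with hypothetical group sizes (values chosen only for illustration):

import math

import numpy as np

# hypothetical setup: 3 groups, 2 replicas, 4 samples per GPU
flag = np.array([0] * 10 + [1] * 7 + [2] * 5)
samples_per_gpu, num_replicas = 4, 2

group_sizes = np.bincount(flag)  # array([10, 7, 5])
num_samples = sum(
    int(math.ceil(size * 1.0 / samples_per_gpu / num_replicas)) * samples_per_gpu
    for size in group_sizes
)  # 8 + 4 + 4 = 16 indices per replica
total_size = num_samples * num_replicas  # 32 indices drawn overall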
Example #2
def get_root_logger(log_level=logging.INFO):
    logger = logging.getLogger()
    if not logger.hasHandlers():
        logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s",
                            level=log_level)
    rank, _ = get_dist_info()
    if rank != 0:
        logger.setLevel(logging.ERROR)
    return logger
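A short usage note, assuming the helper above is importable: only rank 0 keeps the requested level, so duplicate log lines from other ranks are suppressed automatically.

import logging

logger = get_root_logger(log_level=logging.INFO)
logger.info("printed once, by rank 0")   # silenced on all other ranks
logger.error("printed on every rank")    # ERROR still passes the raised level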
Example #3
def build_dataloader(dataset,
                     batch_size,
                     workers_per_gpu,
                     num_gpus=1,
                     dist=True,
                     **kwargs):
    shuffle = kwargs.get("shuffle", True)
    if dist:
        rank, world_size = get_dist_info()
        # sampler = DistributedSamplerV2(dataset,
        #                      num_replicas=world_size,
        #                      rank=rank,
        #                      shuffle=shuffle)
        if shuffle:
            sampler = DistributedGroupSampler(dataset, batch_size, world_size,
                                              rank)
        else:
            sampler = DistributedSampler(dataset,
                                         world_size,
                                         rank,
                                         shuffle=False)
        num_workers = workers_per_gpu
    else:
        sampler = GroupSampler(dataset, batch_size) if shuffle else None
        batch_size = num_gpus * batch_size
        num_workers = num_gpus * workers_per_gpu

    # TODO change pin_memory
    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        shuffle=(sampler is None and shuffle),
        num_workers=num_workers,
        collate_fn=collate_kitti,
        # pin_memory=True,
        pin_memory=False,
    )

    return data_loader
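One design note: DataLoader rejects a sampler combined with shuffle=True, which is why the code above only enables shuffle when no sampler was built. A self-contained sketch of that interaction with PyTorch's stock DistributedSampler (passing num_replicas/rank explicitly avoids the need for an initialized process group):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(10))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)
loader = DataLoader(dataset, batch_size=2, sampler=sampler,
                    shuffle=False)  # shuffle must stay False once a sampler is set
print([batch[0].tolist() for batch in loader])  # rank 0 sees every other index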
Example #4
def main():
    args = parse_args()

    assert args.out or args.show or args.json_out, (
        "Please specify at least one operation (save or show the results) "
        'with the argument "--out" or "--show" or "--json_out"'
    )

    if args.out is not None and not args.out.endswith((".pkl", ".pickle")):
        raise ValueError("The output file must be a pkl file.")

    if args.json_out is not None and args.json_out.endswith(".json"):
        args.json_out = args.json_out[:-5]

    cfg = torchie.Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get("cudnn_benchmark", False):
        torch.backends.cudnn.benchmark = True

    # cfg.model.pretrained = None
    cfg.data.test.test_mode = True
#     cfg.data.val.test_mode = True

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == "none":
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # build the dataloader
    # TODO: support multiple images per gpu (only minor changes are needed)
    dataset = build_dataset(cfg.data.test)
#     dataset = build_dataset(cfg.data.val)
    data_loader = build_dataloader(
        dataset,
        batch_size=cfg.data.samples_per_gpu,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False,
    )

    # build the model and load checkpoint
    model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)

    checkpoint = load_checkpoint(model, args.checkpoint, map_location="cpu")
    # old versions did not save class info in checkpoints; this workaround is
    # for backward compatibility
    if "CLASSES" in checkpoint["meta"]:
        model.CLASSES = checkpoint["meta"]["CLASSES"]
    else:
        model.CLASSES = dataset.CLASSES

    model = MegDataParallel(model, device_ids=[0])
    result_dict, detections = test(
        data_loader, model, save_dir=None, distributed=distributed
    )

    for k, v in result_dict["results"].items():
        print(f"Evaluation {k}: {v}")

    rank, _ = get_dist_info()
    if args.out and rank == 0:
        print("\nwriting results to {}".format(args.out))
        torchie.dump(detections, args.out)

    if args.txt_result:
        res_dir = os.path.join(os.getcwd(), "predictions")
        os.makedirs(res_dir, exist_ok=True)  # the directory may not exist yet
        for dt in detections:
            with open(
                os.path.join(res_dir, "%06d.txt" % int(dt["metadata"]["token"])), "w"
            ) as fout:
                lines = kitti.annos_to_kitti_label(dt)
                for line in lines:
                    fout.write(line + "\n")

        ap_result_str, ap_dict = kitti_evaluate(
            "/data/Datasets/KITTI/Kitti/object/training/label_2",
            res_dir,
            label_split_file="/data/Datasets/KITTI/Kitti/ImageSets/val.txt",
            current_class=0,
        )

        print(ap_result_str)
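The CLASSES fallback in the snippet can be isolated into a small helper. A minimal sketch, assuming a checkpoint saved with a "meta" dict as above (the path and class list are placeholders):

import torch

def resolve_classes(checkpoint_path, dataset_classes):
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    # prefer classes recorded at training time, fall back to the dataset's own
    return checkpoint.get("meta", {}).get("CLASSES", dataset_classes)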
Example #5
def convert_state_dict(module, state_dict, strict=False, logger=None):
    """Load ``state_dict`` into ``module``, remapping legacy VoxelNet keys."""
    unexpected_keys = []
    shape_mismatch_pairs = []

    own_state = module.state_dict()
    for name, param in state_dict.items():
        # a hacky fix to remap keys from an older VoxelNet checkpoint
        if name not in own_state:
            if name.startswith('backbone.middle_conv'):
                index = int(name[20:].split('.')[1])

                if index in [0, 1, 2]:
                    new_name = 'backbone.conv_input.{}.{}'.format(
                        str(index), name[23:])
                elif index in [3, 4]:
                    new_name = 'backbone.conv1.{}.{}'.format(
                        str(index - 3), name[23:])
                elif index in [5, 6, 7, 8, 9]:
                    new_name = 'backbone.conv2.{}.{}'.format(
                        str(index - 5), name[23:])
                elif index in [10, 11, 12, 13, 14]:
                    new_name = 'backbone.conv3.{}.{}'.format(
                        str(index - 10), name[24:])
                elif index in [15, 16, 17, 18, 19]:
                    new_name = 'backbone.conv4.{}.{}'.format(
                        str(index - 15), name[24:])
                elif index in [20, 21, 22]:
                    new_name = 'backbone.extra_conv.{}.{}'.format(
                        str(index - 20), name[24:])
                else:
                    raise NotImplementedError(index)

                if param.size() != own_state[new_name].size():
                    shape_mismatch_pairs.append(
                        [name, own_state[new_name].size(),
                         param.size()])
                    continue

                own_state[new_name].copy_(param)
                print("load {}'s param from {}".format(new_name, name))
                continue

            unexpected_keys.append(name)
            continue
        if isinstance(param, torch.nn.Parameter):
            # backwards compatibility for serialized parameters
            param = param.data
        if param.size() != own_state[name].size():
            shape_mismatch_pairs.append(
                [name, own_state[name].size(),
                 param.size()])
            continue
        own_state[name].copy_(param)

    all_missing_keys = set(own_state.keys()) - set(state_dict.keys())
    # ignore "num_batches_tracked" of BN layers
    missing_keys = [
        key for key in all_missing_keys if "num_batches_tracked" not in key
    ]

    err_msg = []
    if unexpected_keys:
        err_msg.append("unexpected key in source state_dict: {}\n".format(
            ", ".join(unexpected_keys)))
    if missing_keys:
        err_msg.append("missing keys in source state_dict: {}\n".format(
            ", ".join(missing_keys)))
    if shape_mismatch_pairs:
        mismatch_info = "these keys have mismatched shape:\n"
        header = ["key", "expected shape", "loaded shape"]
        table_data = [header] + shape_mismatch_pairs
        table = AsciiTable(table_data)
        err_msg.append(mismatch_info + table.table)

    rank, _ = get_dist_info()
    if len(err_msg) > 0 and rank == 0:
        err_msg.insert(
            0, "The model and loaded state dict do not match exactly\n")
        err_msg = "\n".join(err_msg)
        if strict:
            raise RuntimeError(err_msg)
        elif logger is not None:
            logger.warning(err_msg)
        else:
            print(err_msg)
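The long if/elif chain reduces to a range-to-prefix lookup. A self-contained sketch of the same remapping arithmetic (the module names are the ones hard-coded above), with two checks against the original formulas:

def remap_middle_conv(name):
    # 'backbone.middle_conv.<i>.<rest>' -> renamed submodule key
    index = int(name[20:].split('.')[1])
    rest = name[23:] if index < 10 else name[24:]
    ranges = {
        'conv_input': (0, 2), 'conv1': (3, 4), 'conv2': (5, 9),
        'conv3': (10, 14), 'conv4': (15, 19), 'extra_conv': (20, 22),
    }
    for new, (lo, hi) in ranges.items():
        if lo <= index <= hi:
            return 'backbone.{}.{}.{}'.format(new, index - lo, rest)
    raise NotImplementedError(index)

assert remap_middle_conv('backbone.middle_conv.7.weight') == 'backbone.conv2.2.weight'
assert remap_middle_conv('backbone.middle_conv.16.weight') == 'backbone.conv4.1.weight'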