def __init__(self, dataset, samples_per_gpu=1, num_replicas=None, rank=None):
    """Set up a sampler over ``dataset`` for one distributed replica.

    Args:
        dataset: Dataset to sample from. It must expose a ``flag``
            attribute; ``np.bincount`` is applied to it to obtain the
            size of each group.
        samples_per_gpu: Number of samples per GPU; per-group sample
            counts are rounded up to a multiple of this value.
        num_replicas: Number of replicas. When ``None`` the value
            reported by ``get_dist_info()`` is used.
        rank: Rank of this replica. When ``None`` the value reported by
            ``get_dist_info()`` is used.
    """
    default_rank, default_world_size = get_dist_info()

    self.dataset = dataset
    self.samples_per_gpu = samples_per_gpu
    self.num_replicas = default_world_size if num_replicas is None else num_replicas
    self.rank = default_rank if rank is None else rank
    self.epoch = 0

    assert hasattr(self.dataset, "flag")
    self.flag = self.dataset.flag
    self.group_sizes = np.bincount(self.flag)

    # Each group contributes its share per replica, rounded up to a whole
    # number of ``samples_per_gpu``-sized batches.
    self.num_samples = sum(
        int(math.ceil(size * 1.0 / self.samples_per_gpu / self.num_replicas))
        * self.samples_per_gpu
        for size in self.group_sizes
    )
    self.total_size = self.num_samples * self.num_replicas
def get_root_logger(log_level=logging.INFO):
    """Return the root logger, configuring it on first use.

    If the root logger has no handlers yet, ``logging.basicConfig`` is
    called with a timestamped format and ``log_level``. On every process
    whose rank (as reported by ``get_dist_info()``) is not 0, the logger
    level is set to ``"ERROR"``.

    Args:
        log_level: Level passed to ``logging.basicConfig`` when the root
            logger is configured here.

    Returns:
        The root ``logging.Logger``.
    """
    root_logger = logging.getLogger()

    needs_setup = not root_logger.hasHandlers()
    if needs_setup:
        logging.basicConfig(
            format="%(asctime)s - %(levelname)s - %(message)s",
            level=log_level,
        )

    current_rank, _world_size = get_dist_info()
    if current_rank != 0:
        root_logger.setLevel("ERROR")

    return root_logger
def build_dataloader(dataset, batch_size, workers_per_gpu, num_gpus=1, dist=True, **kwargs):
    """Build a ``DataLoader`` for distributed or single-process use.

    Args:
        dataset: Dataset to load from.
        batch_size: Samples per GPU. In non-distributed mode the effective
            batch size is ``num_gpus * batch_size``.
        workers_per_gpu: Worker processes per GPU. In non-distributed mode
            the loader uses ``num_gpus * workers_per_gpu`` workers.
        num_gpus: Number of GPUs; only used in non-distributed mode.
        dist: Whether to build a loader for distributed execution.
        **kwargs: Only ``shuffle`` (default ``True``) is read; other
            entries are ignored.

    Returns:
        The constructed ``DataLoader``.
    """
    shuffle = kwargs.get("shuffle", True)

    if dist:
        rank, world_size = get_dist_info()
        if shuffle:
            sampler = DistributedGroupSampler(dataset, batch_size, world_size, rank)
        else:
            sampler = DistributedSampler(dataset, world_size, rank, shuffle=False)
        num_workers = workers_per_gpu
    else:
        # No sampler is used in non-distributed mode (the previous code built
        # a GroupSampler and immediately discarded it); ordering is left to
        # the DataLoader's own ``shuffle`` flag below.
        sampler = None
        batch_size = num_gpus * batch_size
        num_workers = num_gpus * workers_per_gpu

    # TODO change pin_memory
    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        # ``sampler`` and ``shuffle`` are mutually exclusive in DataLoader, so
        # shuffle only when there is no sampler *and* the caller asked for it.
        # Previously this was ``sampler is None``, which forced shuffling in
        # non-distributed mode even when ``shuffle=False`` was requested.
        shuffle=(sampler is None and bool(shuffle)),
        num_workers=num_workers,
        collate_fn=collate_kitti,
        pin_memory=False,
    )

    return data_loader
def main():
    """Run inference with a trained detector and optionally save/evaluate results.

    Steps: parse CLI arguments, load the config, initialise the distributed
    environment when a launcher is given, build the test dataset/dataloader
    and the model, load the checkpoint, run ``test`` and print its results.
    With ``--out`` the detections are dumped by rank 0; with ``--txt_result``
    each detection is written as a KITTI label file under
    ``<cwd>/predictions`` and evaluated with ``kitti_evaluate``.

    Raises:
        AssertionError: If none of ``--out``, ``--show`` or ``--json_out``
            is given.
        ValueError: If ``--out`` does not end with ``.pkl`` or ``.pickle``.
    """
    args = parse_args()

    assert args.out or args.show or args.json_out, (
        "Please specify at least one operation (save or show the results) "
        'with the argument "--out" or "--show" or "--json_out"'
    )

    if args.out is not None and not args.out.endswith((".pkl", ".pickle")):
        raise ValueError("The output file must be a pkl file.")

    # Strip a trailing ".json" so the value can be used as a filename stem.
    if args.json_out is not None and args.json_out.endswith(".json"):
        args.json_out = args.json_out[:-5]

    cfg = torchie.Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get("cudnn_benchmark", False):
        torch.backends.cudnn.benchmark = True

    cfg.data.test.test_mode = True

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == "none":
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # build the dataloader
    # TODO: support multiple images per gpu (only minor changes are needed)
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        batch_size=cfg.data.samples_per_gpu,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False,
    )

    # build the model and load checkpoint
    model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
    checkpoint = load_checkpoint(model, args.checkpoint, map_location="cpu")
    # old versions did not save class info in checkpoints, this workaround is
    # for backward compatibility
    if "CLASSES" in checkpoint["meta"]:
        model.CLASSES = checkpoint["meta"]["CLASSES"]
    else:
        model.CLASSES = dataset.CLASSES

    model = MegDataParallel(model, device_ids=[0])
    result_dict, detections = test(
        data_loader, model, save_dir=None, distributed=distributed
    )

    for k, v in result_dict["results"].items():
        print(f"Evaluation {k}: {v}")

    rank, _ = get_dist_info()
    if args.out and rank == 0:
        print("\nwriting results to {}".format(args.out))
        torchie.dump(detections, args.out)

    if args.txt_result:
        res_dir = os.path.join(os.getcwd(), "predictions")
        # The label files are opened for writing below; make sure the target
        # directory exists instead of failing with FileNotFoundError.
        os.makedirs(res_dir, exist_ok=True)

        for dt in detections:
            label_path = os.path.join(
                res_dir, "%06d.txt" % int(dt["metadata"]["token"])
            )
            with open(label_path, "w") as fout:
                lines = kitti.annos_to_kitti_label(dt)
                for line in lines:
                    fout.write(line + "\n")

        # NOTE(review): ground-truth and split paths are hard-coded.
        ap_result_str, _ap_dict = kitti_evaluate(
            "/data/Datasets/KITTI/Kitti/object/training/label_2",
            res_dir,
            label_split_file="/data/Datasets/KITTI/Kitti/ImageSets/val.txt",
            current_class=0,
        )

        print(ap_result_str)
def _remap_middle_conv_key(name):
    """Translate a legacy ``backbone.middle_conv.<idx>.<rest>`` key.

    The flat layer index is mapped onto the newer per-stage layout
    (``conv_input``, ``conv1`` .. ``conv4``, ``extra_conv``) and re-based to
    start at 0 within its stage.

    Raises:
        NotImplementedError: If the index is outside the known layout.
    """
    prefix = "backbone.middle_conv"
    # (indices covered by the stage, new stage name, first index of the stage)
    layout = (
        (range(0, 3), "backbone.conv_input", 0),
        (range(3, 5), "backbone.conv1", 3),
        (range(5, 10), "backbone.conv2", 5),
        (range(10, 15), "backbone.conv3", 10),
        (range(15, 20), "backbone.conv4", 15),
        (range(20, 23), "backbone.extra_conv", 20),
    )

    index_token = name[len(prefix):].split(".")[1]
    index = int(index_token)
    # Everything after "<prefix>.<idx>." is carried over unchanged.
    remainder = name[len(prefix) + len(index_token) + 2:]

    for indices, stage, first in layout:
        if index in indices:
            return "{}.{}.{}".format(stage, str(index - first), remainder)
    raise NotImplementedError(index)


def convert_state_dict(module, state_dict, strict=False, logger=None):
    """Load ``state_dict`` into ``module``, remapping legacy voxelnet keys.

    Keys present in the module are copied directly. Keys that are absent but
    start with ``backbone.middle_conv`` are translated to the newer backbone
    layout and copied to the translated name. Any other absent key is
    recorded as unexpected. Shape mismatches are recorded and skipped.

    Args:
        module: Module whose ``state_dict()`` tensors are filled in place.
        state_dict: Mapping from parameter name to tensor/parameter.
        strict: If ``True``, raise when the two state dicts do not match.
        logger: Logger used for the mismatch warning when not strict;
            falls back to ``print`` when ``None``.

    Raises:
        RuntimeError: If ``strict`` is set and there are unexpected keys,
            missing keys or shape mismatches (only raised on rank 0).
        NotImplementedError: If a legacy key has an unknown layer index.
    """
    legacy_prefix = "backbone.middle_conv"
    unexpected_keys = []
    shape_mismatch_pairs = []
    # Module keys that were addressed through a translated legacy name.
    remapped_targets = set()

    own_state = module.state_dict()
    for name, param in state_dict.items():
        # a hacky fix to load a new voxelnet
        if name not in own_state:
            if not name.startswith(legacy_prefix):
                unexpected_keys.append(name)
                continue

            new_name = _remap_middle_conv_key(name)
            target = own_state[new_name]
            remapped_targets.add(new_name)

            if param.size() != target.size():
                # ``name`` is not in ``own_state`` here, so the expected
                # shape must be looked up under the translated key.
                shape_mismatch_pairs.append(
                    [name, target.size(), param.size()])
                continue
            target.copy_(param)
            print("load {}'s param from {}".format(new_name, name))
            continue

        if isinstance(param, torch.nn.Parameter):
            # backwards compatibility for serialized parameters
            param = param.data
        if param.size() != own_state[name].size():
            shape_mismatch_pairs.append(
                [name, own_state[name].size(), param.size()])
            continue
        own_state[name].copy_(param)

    # Keys reached through the legacy remap are not missing, even though
    # their new names never appear in ``state_dict``.
    all_missing_keys = (
        set(own_state.keys()) - set(state_dict.keys()) - remapped_targets
    )
    # ignore "num_batches_tracked" of BN layers
    missing_keys = sorted(
        key for key in all_missing_keys if "num_batches_tracked" not in key
    )

    err_msg = []
    if unexpected_keys:
        err_msg.append("unexpected key in source state_dict: {}\n".format(
            ", ".join(unexpected_keys)))
    if missing_keys:
        err_msg.append("missing keys in source state_dict: {}\n".format(
            ", ".join(missing_keys)))
    if shape_mismatch_pairs:
        mismatch_info = "these keys have mismatched shape:\n"
        header = ["key", "expected shape", "loaded shape"]
        table_data = [header] + shape_mismatch_pairs
        table = AsciiTable(table_data)
        err_msg.append(mismatch_info + table.table)

    # Only rank 0 reports, so the message is not repeated per process.
    rank, _ = get_dist_info()
    if err_msg and rank == 0:
        err_msg.insert(
            0, "The model and loaded state dict do not match exactly\n")
        err_msg = "\n".join(err_msg)
        if strict:
            raise RuntimeError(err_msg)
        elif logger is not None:
            logger.warning(err_msg)
        else:
            print(err_msg)