def main_worker_eval(worker_id, args):
    """Run single-GPU evaluation of the best checkpointed model on the test split.

    NOTE(review): a second ``main_worker_eval`` is defined later in this file
    and shadows this one at import time — confirm which is intended to be live.

    Args:
        worker_id: CUDA device index to run evaluation on.
        args: parsed command-line arguments; consumed by ``setup`` and read
            here for the ``eval_p2m`` flag selecting the evaluation routine.

    Raises:
        ValueError: if ``cfg.MODEL.CHECKPOINT`` is empty.
    """
    device = torch.device("cuda:%d" % worker_id)
    cfg = setup(args)

    # Build the test set.
    # Fix: the original passed an undefined name `dataset` here (NameError at
    # runtime); use get_dataset_name(cfg), matching the sibling eval worker.
    test_loader = build_data_loader(
        cfg, get_dataset_name(cfg), "test", multigpu=False, num_workers=8
    )
    logger.info("test - %d" % len(test_loader))

    # Load checkpoint and build model.
    if cfg.MODEL.CHECKPOINT == "":
        raise ValueError("Invalid checkpoint provided")
    logger.info("Loading model from checkpoint: %s" % (cfg.MODEL.CHECKPOINT))
    cp = torch.load(PathManager.get_local_path(cfg.MODEL.CHECKPOINT))
    # The checkpoint keeps the best model under "best_states"; clean_state_dict
    # presumably strips DistributedDataParallel "module." prefixes — confirm.
    state_dict = clean_state_dict(cp["best_states"]["model"])
    model = build_model(cfg)
    model.load_state_dict(state_dict)
    logger.info("Model loaded")
    model.to(device)

    wandb.init(project='MeshRCNN', config=cfg, name='meshrcnn-eval')

    # Dispatch to the point-to-mesh evaluation or the default test evaluation.
    if args.eval_p2m:
        evaluate_test_p2m(model, test_loader)
    else:
        evaluate_test(model, test_loader)
def main_worker(worker_id, args):
    """Per-process training entry point.

    Builds the data loaders, model, optimizer and LR scheduler, optionally
    wraps the model in DistributedDataParallel, restores state from an
    existing checkpoint when one is found, and hands off to ``training_loop``.

    Args:
        worker_id: process rank; also used as the CUDA device index.
        args: parsed command-line arguments (``num_gpus``, ``dist_url``, and
            whatever ``setup`` consumes).
    """
    distributed = False
    if args.num_gpus > 1:
        distributed = True
        # One process per GPU; rank == worker_id within this node's world.
        dist.init_process_group(
            backend="NCCL", init_method=args.dist_url, world_size=args.num_gpus, rank=worker_id
        )
        torch.cuda.set_device(worker_id)
    device = torch.device("cuda:%d" % worker_id)

    cfg = setup(args)

    # data loaders
    loaders = setup_loaders(cfg)
    for split_name, loader in loaders.items():
        logger.info("%s - %d" % (split_name, len(loader)))

    # build the model
    model = build_model(cfg)
    model.to(device)
    if distributed:
        # NOTE(review): check_reduction was deprecated and later removed from
        # torch.nn.parallel.DistributedDataParallel — confirm against the
        # pinned PyTorch version before upgrading.
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[worker_id],
            output_device=worker_id,
            check_reduction=True,
            broadcast_buffers=False,
            # find_unused_parameters=True,
        )

    optimizer = build_optimizer(cfg, model)
    # Total iteration budget, derived from epochs * batches-per-epoch; the
    # scheduler below presumably reads this computed value from cfg.
    cfg.SOLVER.COMPUTED_MAX_ITERS = cfg.SOLVER.NUM_EPOCHS * len(loaders["train"])
    scheduler = build_lr_scheduler(cfg, optimizer)

    # NOTE(review): loss_fn_kwargs is built but never used below — it looks
    # like it was meant to construct a MeshLoss instance (e.g.
    # MeshLoss(**loss_fn_kwargs)) before the switch to the unbound
    # MeshLoss.voxel_loss. Confirm intent before deleting.
    loss_fn_kwargs = {
        "chamfer_weight": cfg.MODEL.MESH_HEAD.CHAMFER_LOSS_WEIGHT,
        "normal_weight": cfg.MODEL.MESH_HEAD.NORMALS_LOSS_WEIGHT,
        "edge_weight": cfg.MODEL.MESH_HEAD.EDGE_LOSS_WEIGHT,
        "voxel_weight": cfg.MODEL.VOXEL_HEAD.LOSS_WEIGHT,
        "gt_num_samples": cfg.MODEL.MESH_HEAD.GT_NUM_SAMPLES,
        "pred_num_samples": cfg.MODEL.MESH_HEAD.PRED_NUM_SAMPLES,
        "upsample_pred_mesh": cfg.MODEL.MESH_HEAD.UPSAMPLE_PRED_MESH,
    }
    loss_fn = MeshLoss.voxel_loss

    checkpoint_path = "checkpoint.pt"
    checkpoint_path = os.path.join(cfg.OUTPUT_DIR, checkpoint_path)
    cp = Checkpoint(checkpoint_path)
    if len(cp.restarts) == 0:
        # We are starting from scratch, so store some initial data in cp
        iter_per_epoch = len(loaders["train"])
        cp.store_data("iter_per_epoch", iter_per_epoch)
    else:
        # Resume: restore model/optimizer/scheduler from the latest states.
        logger.info("Loading model state from checkpoint")
        model.load_state_dict(cp.latest_states["model"])
        optimizer.load_state_dict(cp.latest_states["optim"])
        scheduler.load_state_dict(cp.latest_states["lr_scheduler"])

    training_loop(cfg, cp, model, optimizer, scheduler, loaders, device, loss_fn)
def main_worker_eval(worker_id, args):
    """Single-GPU evaluation of the best checkpointed model via ``evaluate_vox``.

    Runs two evaluation passes: one capped at 100 predictions (quick check)
    and one full pass that writes predictions under ``cfg.OUTPUT_DIR``.

    NOTE(review): this redefines ``main_worker_eval`` from earlier in the
    file; at import time this later definition is the one in effect.

    Args:
        worker_id: CUDA device index to evaluate on.
        args: parsed command-line arguments consumed by ``setup``.

    Raises:
        ValueError: if ``cfg.MODEL.CHECKPOINT`` is empty.
    """
    device = torch.device("cuda:%d" % worker_id)
    cfg = setup(args)

    # build test set
    test_loader = build_data_loader(
        cfg, get_dataset_name(cfg), "test", multigpu=False
    )
    logger.info("test - %d" % len(test_loader))

    # Load checkpoint and build model.
    if cfg.MODEL.CHECKPOINT == "":
        raise ValueError("Invalid checkpoint provided")
    logger.info("Loading model from checkpoint: %s" % (cfg.MODEL.CHECKPOINT))
    cp = torch.load(PathManager.get_local_path(cfg.MODEL.CHECKPOINT))
    state_dict = clean_state_dict(cp["best_states"]["model"])
    model = build_model(cfg)
    model.load_state_dict(state_dict)
    logger.info("Model loaded")
    model.to(device)

    def disable_running_stats(model):
        # Recursively freeze BatchNorm running statistics. Currently unused:
        # the invocation below is deliberately left commented out.
        if type(model).__name__.startswith("BatchNorm"):
            model.track_running_stats = False
        else:
            for m in model.children():
                disable_running_stats(m)

    # disable_running_stats(model)

    # NOTE(review): this loader is logged as "val" but is built from the
    # "test" split — confirm whether the "val" split was intended here.
    val_loader = build_data_loader(
        cfg, get_dataset_name(cfg), "test", multigpu=False
    )
    logger.info("val - %d" % len(val_loader))

    # Quick pass, capped at 100 predictions.
    test_metrics = evaluate_vox(model, val_loader, max_predictions=100)
    str_out = "Results on test"
    for k, v in test_metrics.items():
        str_out += "%s %.4f " % (k, v)
    logger.info(str_out)

    # Full pass; predictions are written to disk.
    prediction_dir = os.path.join(cfg.OUTPUT_DIR, "predictions")
    test_metrics = evaluate_vox(model, test_loader, prediction_dir)
    # (Removed a leftover debug print(test_metrics); the logger line below
    # already reports the metrics.)
    str_out = "Results on test"
    for k, v in test_metrics.items():
        str_out += "%s %.4f " % (k, v)
    logger.info(str_out)