def main(args):
    # Initialize multi-processing
    distributed.init_process_group(backend='nccl', init_method='env://')
    device_id, device = args.local_rank, torch.device(args.local_rank)
    rank, world_size = distributed.get_rank(), distributed.get_world_size()
    torch.cuda.set_device(device_id)

    # Initialize logging
    if rank == 0:
        logging.init(args.log_dir, "test")

    # Load configuration
    config = make_config(args)

    # Create dataloader
    test_dataloader = make_dataloader(args, config, rank, world_size)
    meta = load_meta(args.meta)

    # Create model
    model = make_model(config, meta["num_thing"], meta["num_stuff"])

    # Load snapshot
    log_debug("Loading snapshot from %s", args.model)
    resume_from_snapshot(model, args.model, ["body", "rpn_head", "roi_head", "sem_head"])

    # Init GPU stuff
    torch.backends.cudnn.benchmark = config["general"].getboolean("cudnn_benchmark")
    model = DistributedDataParallel(model.cuda(device), device_ids=[device_id], output_device=device_id)

    # Panoptic processing parameters
    panoptic_preprocessing = PanopticPreprocessing(args.score_threshold, args.iou_threshold, args.min_area)

    if args.raw:
        save_function = partial(save_prediction_raw, out_dir=args.out_dir)
    else:
        # Pad the palette to 256 entries, filling missing classes with black
        palette = []
        for i in range(256):
            if i < len(meta["palette"]):
                palette.append(meta["palette"][i])
            else:
                palette.append((0, 0, 0))
        palette = np.array(palette, dtype=np.uint8)

        save_function = partial(
            save_prediction_image, out_dir=args.out_dir, colors=palette, num_stuff=meta["num_stuff"])

    test(model, test_dataloader, device=device, summary=None,
         log_interval=config["general"].getint("log_interval"),
         save_function=save_function,
         make_panoptic=panoptic_preprocessing,
         num_stuff=meta["num_stuff"])
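
# Minimal entry-point sketch (not part of the original script). The repository's actual CLI
# is not shown in this section, so the flag names below are assumptions inferred only from
# the attributes main() reads (args.local_rank, args.log_dir, args.meta, args.model,
# args.out_dir, args.raw, args.score_threshold, args.iou_threshold, args.min_area);
# make_config(args) and make_dataloader(args, ...) may read additional fields not listed here,
# and the default threshold values are placeholders.
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Panoptic testing entry point (sketch)")
    parser.add_argument("--local_rank", type=int, default=0)      # normally injected by torch.distributed.launch
    parser.add_argument("--log_dir", required=True)               # where logging.init writes its files
    parser.add_argument("--meta", required=True)                  # metadata with num_thing / num_stuff / palette
    parser.add_argument("--model", required=True)                 # snapshot to load
    parser.add_argument("--out_dir", required=True)               # output directory for predictions
    parser.add_argument("--raw", action="store_true")             # save raw predictions instead of color images
    parser.add_argument("--score_threshold", type=float, default=0.5)   # placeholder default
    parser.add_argument("--iou_threshold", type=float, default=0.5)     # placeholder default
    parser.add_argument("--min_area", type=int, default=4096)           # placeholder default
    main(parser.parse_args())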
def main(args):
    # Initialize multi-processing
    distributed.init_process_group(backend='nccl', init_method='env://')
    device_id, device = args.local_rank, torch.device(args.local_rank)
    rank, world_size = distributed.get_rank(), distributed.get_world_size()
    torch.cuda.set_device(device_id)

    # Initialize logging
    if rank == 0:
        logging.init(args.log_dir, "test")

    # Load configuration
    config = make_config(args)

    # Create dataloader
    test_dataloader = make_dataloader(args, config, rank, world_size)
    meta = load_meta(args.meta)

    # Create model
    model = make_model(config, meta["num_thing"], meta["num_stuff"])

    # Load snapshot
    log_debug("Loading snapshot from %s", args.model)
    resume_from_snapshot(model, args.model, ["body", "rpn_head", "roi_head"])

    # Init GPU stuff
    torch.backends.cudnn.benchmark = config["general"].getboolean("cudnn_benchmark")
    model = DistributedDataParallel(model.cuda(device), device_ids=[device_id], output_device=device_id)

    if args.raw:
        save_function = partial(save_prediction_raw, out_dir=args.out_dir,
                                threshold=args.threshold, obj_cls=args.person)
    else:
        save_function = partial(save_prediction_image, out_dir=args.out_dir,
                                colors=meta["palette"], num_stuff=meta["num_stuff"],
                                threshold=args.threshold, obj_cls=args.person)

    test(model, test_dataloader, device=device, summary=None,
         log_interval=config["general"].getint("log_interval"),
         save_function=save_function)
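
# Sketch illustrating the functools.partial pattern used above: the output directory,
# palette, thresholds, etc. are bound once here, so the test loop only has to supply the
# per-image data. The dummy function below is hypothetical; the real save_prediction_raw /
# save_prediction_image signatures are assumed, not taken from this section.
from functools import partial

def _demo_save(prediction, img_id, *, out_dir, threshold):
    # Stand-in for a save function: bound keywords (out_dir, threshold) arrive automatically,
    # only per-image arguments are passed by the caller.
    print("saving", img_id, "to", out_dir, "with threshold", threshold)

_demo_save_fn = partial(_demo_save, out_dir="./outputs", threshold=0.5)
_demo_save_fn({"boxes": []}, "frame_0001")   # -> saving frame_0001 to ./outputs with threshold 0.5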
def main(args):
    # Initialize multi-processing
    distributed.init_process_group(backend='nccl', init_method='env://')
    device_id, device = args.local_rank, torch.device(args.local_rank)
    rank, world_size = distributed.get_rank(), distributed.get_world_size()
    torch.cuda.set_device(device_id)

    # Initialize logging
    if rank == 0:
        logging.init(args.log_dir, "training" if not args.eval else "eval")
        summary = tensorboard.SummaryWriter(args.log_dir)
    else:
        summary = None

    # Load configuration
    config = make_config(args)

    # Create dataloaders
    train_dataloader, val_dataloader = make_dataloader(args, config, rank, world_size)

    # Create model
    model = make_model(config, train_dataloader.dataset.num_thing, train_dataloader.dataset.num_stuff)
    if args.resume:
        assert not args.pre_train, "resume and pre_train are mutually exclusive"
        log_debug("Loading snapshot from %s", args.resume)
        snapshot = resume_from_snapshot(model, args.resume, ["body", "rpn_head", "roi_head"])
    elif args.pre_train:
        assert not args.resume, "resume and pre_train are mutually exclusive"
        log_debug("Loading pre-trained model from %s", args.pre_train)
        pre_train_from_snapshots(model, args.pre_train, ["body", "rpn_head", "roi_head"])
    else:
        assert not args.eval, "--resume is needed in eval mode"
        snapshot = None

    # Init GPU stuff
    torch.backends.cudnn.benchmark = config["general"].getboolean("cudnn_benchmark")
    model = DistributedDataParallel(model.cuda(device), device_ids=[device_id],
                                    output_device=device_id, find_unused_parameters=True)

    # Create optimizer
    optimizer, scheduler, batch_update, total_epochs = make_optimizer(config, model, len(train_dataloader))
    if args.resume:
        optimizer.load_state_dict(snapshot["state_dict"]["optimizer"])

    # Training loop
    momentum = 1. - 1. / len(train_dataloader)
    meters = {
        "loss": AverageMeter((), momentum),
        "obj_loss": AverageMeter((), momentum),
        "bbx_loss": AverageMeter((), momentum),
        "roi_cls_loss": AverageMeter((), momentum),
        "roi_bbx_loss": AverageMeter((), momentum),
        "roi_msk_loss": AverageMeter((), momentum)
    }

    if args.resume:
        starting_epoch = snapshot["training_meta"]["epoch"] + 1
        best_score = snapshot["training_meta"]["best_score"]
        global_step = snapshot["training_meta"]["global_step"]
        for name, meter in meters.items():
            meter.load_state_dict(snapshot["state_dict"][name + "_meter"])
        del snapshot
    else:
        starting_epoch = 0
        best_score = 0
        global_step = 0

    # Optional: evaluation only
    if args.eval:
        log_info("Validating epoch %d", starting_epoch - 1)
        validate(model, val_dataloader, config["optimizer"].getstruct("loss_weights"),
                 device=device, summary=summary, global_step=global_step,
                 epoch=starting_epoch - 1, num_epochs=total_epochs,
                 log_interval=config["general"].getint("log_interval"),
                 coco_gt=config["dataloader"]["coco_gt"], log_dir=args.log_dir)
        exit(0)

    for epoch in range(starting_epoch, total_epochs):
        log_info("Starting epoch %d", epoch + 1)
        if not batch_update:
            scheduler.step(epoch)

        # Run training epoch
        global_step = train(model, optimizer, scheduler, train_dataloader, meters,
                            batch_update=batch_update, epoch=epoch, summary=summary, device=device,
                            log_interval=config["general"].getint("log_interval"),
                            num_epochs=total_epochs, global_step=global_step,
                            loss_weights=config["optimizer"].getstruct("loss_weights"))

        # Save snapshot (only on rank 0)
        if rank == 0:
            snapshot_file = path.join(args.log_dir, "model_last.pth.tar")
            log_debug("Saving snapshot to %s", snapshot_file)
            meters_out_dict = {k + "_meter": v.state_dict() for k, v in meters.items()}
            save_snapshot(snapshot_file, config, epoch, 0, best_score, global_step,
                          body=model.module.body.state_dict(),
                          rpn_head=model.module.rpn_head.state_dict(),
                          roi_head=model.module.roi_head.state_dict(),
                          optimizer=optimizer.state_dict(),
                          **meters_out_dict)

        if (epoch + 1) % config["general"].getint("val_interval") == 0:
            log_info("Validating epoch %d", epoch + 1)
            score = validate(model, val_dataloader, config["optimizer"].getstruct("loss_weights"),
                             device=device, summary=summary, global_step=global_step,
                             epoch=epoch, num_epochs=total_epochs,
                             log_interval=config["general"].getint("log_interval"),
                             coco_gt=config["dataloader"]["coco_gt"], log_dir=args.log_dir)

            # Update the score on the last saved snapshot
            if rank == 0:
                snapshot = torch.load(snapshot_file, map_location="cpu")
                snapshot["training_meta"]["last_score"] = score
                torch.save(snapshot, snapshot_file)
                del snapshot

            if score > best_score:
                best_score = score
                if rank == 0:
                    shutil.copy(snapshot_file, path.join(args.log_dir, "model_best.pth.tar"))
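
# Sketch of an exponential-moving-average meter compatible with how the training loop above
# uses its meters: constructed with a value shape and a momentum of 1 - 1/len(train_dataloader),
# and exposing state_dict / load_state_dict so it can be snapshotted and resumed. This is an
# assumption about the real AverageMeter, whose implementation is not shown in this section.
import torch

class EMAMeter:
    def __init__(self, shape=(), momentum=0.9):
        self.momentum = momentum
        self.value = torch.zeros(shape)

    def update(self, x):
        # Exponentially decaying running average of the observed values
        x = torch.as_tensor(x, dtype=self.value.dtype)
        self.value = self.momentum * self.value + (1. - self.momentum) * x

    def state_dict(self):
        return {"value": self.value, "momentum": self.momentum}

    def load_state_dict(self, state):
        self.value = state["value"]
        self.momentum = state["momentum"]

# Usage sketch: m = EMAMeter((), momentum=0.99); m.update(loss.item()) after each batch.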