Ejemplo n.º 1
0
def main(args):
    """Run distributed panoptic testing and save the predictions.

    Joins the NCCL process group, builds the test dataloader and the model,
    restores the snapshot at ``args.model`` and calls ``test`` with the save
    callback selected by ``args.raw``.

    Args:
        args: Parsed command-line namespace. The attributes read directly
            here are ``local_rank``, ``log_dir``, ``meta``, ``model``,
            ``score_threshold``, ``iou_threshold``, ``min_area``, ``raw`` and
            ``out_dir``; the namespace is also forwarded to ``make_config``
            and ``make_dataloader``.
    """
    print("starting...")

    # Join the process group; rendezvous settings come from the environment
    distributed.init_process_group(backend='nccl', init_method='env://')
    device_id = args.local_rank
    device = torch.device(args.local_rank)
    rank = distributed.get_rank()
    world_size = distributed.get_world_size()
    torch.cuda.set_device(device_id)

    # Logging is only initialized on the rank-0 process
    if rank == 0:
        logging.init(args.log_dir, "test")

    # Configuration, data and dataset metadata
    config = make_config(args)
    test_dataloader = make_dataloader(args, config, rank, world_size)
    meta = load_meta(args.meta)

    # Build the network from the class counts stored in the metadata
    print("model 0 :\n\n\n\n\n\n\n\n")
    num_thing, num_stuff = meta["num_thing"], meta["num_stuff"]
    model = make_model(config, num_thing, num_stuff)

    # Restore the listed sub-modules from the snapshot
    log_debug("Loading snapshot from %s", args.model)
    snapshot_modules = ["body", "rpn_head", "roi_head", "sem_head"]
    resume_from_snapshot(model, args.model, snapshot_modules)

    # Move to the GPU and wrap for distributed execution
    torch.backends.cudnn.benchmark = config["general"].getboolean("cudnn_benchmark")
    model = DistributedDataParallel(
        model.cuda(device),
        device_ids=[device_id],
        output_device=device_id,
    )
    print("model:\n", model)

    # Panoptic processing parameters
    panoptic_preprocessing = PanopticPreprocessing(
        args.score_threshold, args.iou_threshold, args.min_area)

    if not args.raw:
        # Pad the dataset palette with black so that it always has 256 entries
        source_colors = meta["palette"]
        padded_colors = [
            source_colors[idx] if idx < len(source_colors) else (0, 0, 0)
            for idx in range(256)
        ]
        palette = np.array(padded_colors, dtype=np.uint8)

        save_function = partial(
            save_prediction_image,
            out_dir=args.out_dir,
            colors=palette,
            num_stuff=num_stuff,
        )
    else:
        save_function = partial(save_prediction_raw, out_dir=args.out_dir)

    test(
        model,
        test_dataloader,
        device=device,
        summary=None,
        log_interval=config["general"].getint("log_interval"),
        save_function=save_function,
        make_panoptic=panoptic_preprocessing,
        num_stuff=num_stuff,
    )
Ejemplo n.º 2
0
def main(args):
    """Run distributed testing and save the predictions.

    Joins the NCCL process group, builds the test dataloader and the model,
    restores the snapshot at ``args.model`` and calls ``test`` with the save
    callback selected by ``args.raw``.

    Args:
        args: Parsed command-line namespace. The attributes read directly
            here are ``local_rank``, ``log_dir``, ``meta``, ``model``,
            ``raw``, ``out_dir``, ``threshold`` and ``person``; the namespace
            is also forwarded to ``make_config`` and ``make_dataloader``.
    """
    # Join the process group; rendezvous settings come from the environment
    distributed.init_process_group(backend='nccl', init_method='env://')
    device_id = args.local_rank
    device = torch.device(args.local_rank)
    rank = distributed.get_rank()
    world_size = distributed.get_world_size()
    torch.cuda.set_device(device_id)

    # Logging is only initialized on the rank-0 process
    if rank == 0:
        logging.init(args.log_dir, "test")

    # Configuration, data and dataset metadata
    config = make_config(args)
    test_dataloader = make_dataloader(args, config, rank, world_size)
    meta = load_meta(args.meta)

    # Build the network and restore the listed sub-modules from the snapshot
    model = make_model(config, meta["num_thing"], meta["num_stuff"])
    log_debug("Loading snapshot from %s", args.model)
    snapshot_modules = ["body", "rpn_head", "roi_head"]
    resume_from_snapshot(model, args.model, snapshot_modules)

    # Move to the GPU and wrap for distributed execution
    general_cfg = config["general"]
    torch.backends.cudnn.benchmark = general_cfg.getboolean("cudnn_benchmark")
    model = DistributedDataParallel(
        model.cuda(device), device_ids=[device_id], output_device=device_id)

    # Choose how predictions are written out
    if not args.raw:
        save_function = partial(
            save_prediction_image,
            out_dir=args.out_dir,
            colors=meta["palette"],
            num_stuff=meta["num_stuff"],
            threshold=args.threshold,
            obj_cls=args.person,
        )
    else:
        save_function = partial(
            save_prediction_raw,
            out_dir=args.out_dir,
            threshold=args.threshold,
            obj_cls=args.person,
        )

    test(
        model,
        test_dataloader,
        device=device,
        summary=None,
        log_interval=config["general"].getint("log_interval"),
        save_function=save_function,
    )
Ejemplo n.º 3
0
def main(args):
    """Train the model with distributed data parallelism, or only validate it.

    Joins the NCCL process group, builds the dataloaders, model and optimizer,
    optionally restores state from ``args.resume`` or ``args.pre_train``, and
    then either runs a single validation pass (``args.eval``) or the full
    training loop with periodic validation and snapshotting.

    Args:
        args: Parsed command-line namespace. The attributes read directly
            here are ``local_rank``, ``log_dir``, ``eval``, ``resume`` and
            ``pre_train``; the namespace is also forwarded to ``make_config``
            and ``make_dataloader``.

    Raises:
        AssertionError: If both ``resume`` and ``pre_train`` are given, or if
            ``eval`` is set while neither of them is.
        SystemExit: With status 0 once the evaluation-only pass has finished.
    """
    # Imported locally so the explicit sys.exit() below does not depend on the
    # interactive ``exit`` helper that the ``site`` module injects.
    import sys

    # Initialize multi-processing
    distributed.init_process_group(backend='nccl', init_method='env://')
    device_id, device = args.local_rank, torch.device(args.local_rank)
    rank, world_size = distributed.get_rank(), distributed.get_world_size()
    torch.cuda.set_device(device_id)

    # Initialize logging (log file and summary writer exist on rank 0 only)
    if rank == 0:
        logging.init(args.log_dir, "training" if not args.eval else "eval")
        summary = tensorboard.SummaryWriter(args.log_dir)
    else:
        summary = None

    # Load configuration
    config = make_config(args)

    # Create dataloaders
    train_dataloader, val_dataloader = make_dataloader(args, config, rank,
                                                       world_size)

    # Create model
    model = make_model(config, train_dataloader.dataset.num_thing,
                       train_dataloader.dataset.num_stuff)
    if args.resume:
        assert not args.pre_train, "resume and pre_train are mutually exclusive"
        log_debug("Loading snapshot from %s", args.resume)
        snapshot = resume_from_snapshot(model, args.resume,
                                        ["body", "rpn_head", "roi_head"])
    elif args.pre_train:
        # Reaching this branch already implies args.resume is falsy, so the
        # mutual-exclusion check only needs to live in the branch above.
        log_debug("Loading pre-trained model from %s", args.pre_train)
        pre_train_from_snapshots(model, args.pre_train,
                                 ["body", "rpn_head", "roi_head"])
    else:
        assert not args.eval, "--resume is needed in eval mode"
        snapshot = None

    # Init GPU stuff
    torch.backends.cudnn.benchmark = config["general"].getboolean(
        "cudnn_benchmark")
    model = DistributedDataParallel(model.cuda(device),
                                    device_ids=[device_id],
                                    output_device=device_id,
                                    find_unused_parameters=True)

    # Create optimizer
    optimizer, scheduler, batch_update, total_epochs = make_optimizer(
        config, model, len(train_dataloader))
    if args.resume:
        optimizer.load_state_dict(snapshot["state_dict"]["optimizer"])

    # One running-average meter per tracked loss; the momentum is derived from
    # the number of batches in the training dataloader
    momentum = 1. - 1. / len(train_dataloader)
    meter_names = ("loss", "obj_loss", "bbx_loss", "roi_cls_loss",
                   "roi_bbx_loss", "roi_msk_loss")
    meters = {name: AverageMeter((), momentum) for name in meter_names}

    if args.resume:
        # Continue from the epoch after the one stored in the snapshot
        starting_epoch = snapshot["training_meta"]["epoch"] + 1
        best_score = snapshot["training_meta"]["best_score"]
        global_step = snapshot["training_meta"]["global_step"]
        for name, meter in meters.items():
            meter.load_state_dict(snapshot["state_dict"][name + "_meter"])
        del snapshot
    else:
        starting_epoch = 0
        best_score = 0
        global_step = 0

    # Optional: evaluation only
    # NOTE(review): with --pre_train and --eval together the "--resume is
    # needed" assertion above is never reached and this validates "epoch -1";
    # confirm whether that combination is meant to be supported.
    if args.eval:
        log_info("Validating epoch %d", starting_epoch - 1)
        validate(model,
                 val_dataloader,
                 config["optimizer"].getstruct("loss_weights"),
                 device=device,
                 summary=summary,
                 global_step=global_step,
                 epoch=starting_epoch - 1,
                 num_epochs=total_epochs,
                 log_interval=config["general"].getint("log_interval"),
                 coco_gt=config["dataloader"]["coco_gt"],
                 log_dir=args.log_dir)
        sys.exit(0)

    # Training loop
    for epoch in range(starting_epoch, total_epochs):
        log_info("Starting epoch %d", epoch + 1)
        # The scheduler is stepped here once per epoch unless batch_update
        # is set
        if not batch_update:
            scheduler.step(epoch)

        # Run training epoch
        global_step = train(
            model,
            optimizer,
            scheduler,
            train_dataloader,
            meters,
            batch_update=batch_update,
            epoch=epoch,
            summary=summary,
            device=device,
            log_interval=config["general"].getint("log_interval"),
            num_epochs=total_epochs,
            global_step=global_step,
            loss_weights=config["optimizer"].getstruct("loss_weights"))

        # Save snapshot (only on rank 0)
        if rank == 0:
            snapshot_file = path.join(args.log_dir, "model_last.pth.tar")
            log_debug("Saving snapshot to %s", snapshot_file)
            meters_out_dict = {
                k + "_meter": v.state_dict()
                for k, v in meters.items()
            }
            # The literal 0 is presumably the last-score slot that is patched
            # after validation below -- confirm against save_snapshot.
            save_snapshot(snapshot_file,
                          config,
                          epoch,
                          0,
                          best_score,
                          global_step,
                          body=model.module.body.state_dict(),
                          rpn_head=model.module.rpn_head.state_dict(),
                          roi_head=model.module.roi_head.state_dict(),
                          optimizer=optimizer.state_dict(),
                          **meters_out_dict)

        if (epoch + 1) % config["general"].getint("val_interval") == 0:
            log_info("Validating epoch %d", epoch + 1)
            score = validate(
                model,
                val_dataloader,
                config["optimizer"].getstruct("loss_weights"),
                device=device,
                summary=summary,
                global_step=global_step,
                epoch=epoch,
                num_epochs=total_epochs,
                log_interval=config["general"].getint("log_interval"),
                coco_gt=config["dataloader"]["coco_gt"],
                log_dir=args.log_dir)

            # Update the score on the last saved snapshot
            if rank == 0:
                snapshot = torch.load(snapshot_file, map_location="cpu")
                snapshot["training_meta"]["last_score"] = score
                torch.save(snapshot, snapshot_file)
                del snapshot

            # Every rank tracks the best score; only rank 0 copies the file
            if score > best_score:
                best_score = score
                if rank == 0:
                    shutil.copy(snapshot_file,
                                path.join(args.log_dir, "model_best.pth.tar"))