示例#1
0
文件: eval.py 项目: zxhou/packnet-sfm
def test(ckpt_file, cfg_file, half):
    """
    Evaluate a monocular depth estimation model restored from a checkpoint.

    Parameters
    ----------
    ckpt_file : str
        Checkpoint path for a pretrained model
    cfg_file : str
        Configuration file
    half : bool
        If true, ``torch.float16`` is passed to the trainer as ``dtype``;
        otherwise ``dtype`` is set to ``None``
    """
    # Horovod is initialized first, before any parsing takes place
    hvd_init()

    # Recover the configuration and the saved weights
    cfg, weights = parse_test_file(ckpt_file, cfg_file)

    # Debug mode follows the configuration flag
    set_debug(cfg.debug)

    # Build the wrapper from the configuration, then load the saved state
    wrapper = ModelWrapper(cfg)
    wrapper.load_state_dict(weights)

    # The trainer receives its dtype through the arch section of the config
    if half:
        cfg.arch["dtype"] = torch.float16
    else:
        cfg.arch["dtype"] = None

    # Instantiate the trainer from the arch parameters and run the test pass
    HorovodTrainer(**cfg.arch).test(wrapper)
示例#2
0
def train(file):
    """
    Monocular depth estimation training script.

    Parameters
    ----------
    file : str
        Filepath, can be either a
        **.yaml** for a yacs configuration file or a
        **.ckpt** for a pre-trained checkpoint file.
    """
    # Initialize horovod
    hvd_init()

    # Produce configuration and checkpoint from filename
    config, ckpt = parse_train_file(file)

    # Set debug if requested
    set_debug(config.debug)

    # Model checkpoint: skipped when no filepath is configured or when this
    # process is not rank 0. The filepath is compared by value (==); an
    # identity test against a string literal depends on interning and is
    # not a reliable emptiness check.
    if config.checkpoint.filepath == "" or rank() > 0:
        checkpoint = None
    else:
        checkpoint = filter_args_create(ModelCheckpoint, config.checkpoint)

    # Initialize model wrapper
    model_wrapper = ModelWrapper(config, resume=ckpt)

    # Create trainer with args.arch parameters
    trainer = HorovodTrainer(**config.arch, checkpoint=checkpoint)

    # Train model
    trainer.fit(model_wrapper)
示例#3
0
def train(file):
    """
    Monocular depth estimation training script.

    Uses the horovod trainer when ``HAS_HOROVOD`` is true and the plain
    PyTorch trainer otherwise.

    Parameters
    ----------
    file : str
        Filepath, can be either a
        **.yaml** for a yacs configuration file or a
        **.ckpt** for a pre-trained checkpoint file.
    """
    # Produce configuration and checkpoint from filename
    config, ckpt = parse_train_file(file)

    # Restrict the visible GPUs when the configuration lists them, e.g. "0,1".
    # This is done before hvd_init() -- presumably so the device selection is
    # in place when horovod starts (TODO confirm).
    if hasattr(config, "gpu"):
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
            str(x) for x in config.gpu)

    # Initialize horovod
    hvd_init()

    # Set debug if requested
    set_debug(config.debug)

    # Logging is disabled in this script: no logger is created
    logger = None

    # Model checkpoint: skipped when no filepath is configured or when this
    # process is not rank 0. The filepath is compared by value (==); an
    # identity test against a string literal depends on interning and is
    # not a reliable emptiness check.
    if config.checkpoint.filepath == '' or rank() > 0:
        checkpoint = None
    else:
        checkpoint = filter_args_create(ModelCheckpoint, config.checkpoint)

    # Initialize model wrapper
    model_wrapper = ModelWrapper(config,
                                 resume=ckpt,
                                 logger=logger,
                                 use_horovod=HAS_HOROVOD)

    # Create trainer with args.arch parameters
    if HAS_HOROVOD:
        # Imported here so the horovod trainer module is only loaded when
        # horovod is in use
        from packnet_sfm.trainers.horovod_trainer import HorovodTrainer
        trainer = HorovodTrainer(checkpoint=checkpoint, **config.arch)
    else:
        trainer = PytorchTrainer(checkpoint=checkpoint, **config.arch)

    # Train model
    trainer.fit(model_wrapper)
示例#4
0
def train(file):
    """
    Monocular depth estimation training script.

    Also prints the parameter counts of the depth and pose networks before
    training starts.

    Parameters
    ----------
    file : str
        Filepath, can be either a
        **.yaml** for a yacs configuration file or a
        **.ckpt** for a pre-trained checkpoint file.
    """
    # Initialize horovod
    hvd_init()

    # Produce configuration and checkpoint from filename
    config, ckpt = parse_train_file(file)

    # Set debug if requested
    set_debug(config.debug)

    # Wandb Logger: skipped on dry runs and on processes other than rank 0
    if config.wandb.dry_run or rank() > 0:
        logger = None
    else:
        logger = filter_args_create(WandbLogger, config.wandb)

    # Model checkpoint: skipped when no filepath is configured or when this
    # process is not rank 0. The filepath is compared by value (==); an
    # identity test against a string literal depends on interning and is
    # not a reliable emptiness check.
    if config.checkpoint.filepath == '' or rank() > 0:
        checkpoint = None
    else:
        checkpoint = filter_args_create(ModelCheckpoint, config.checkpoint)

    # Initialize model wrapper
    model_wrapper = ModelWrapper(config, resume=ckpt, logger=logger)

    # Report the size of each network
    print("Depth Net - %d parameters" % sum(p.numel() for p in model_wrapper.depth_net.parameters()))
    print("Pose Net - %d parameters" % sum(p.numel() for p in model_wrapper.pose_net.parameters()))

    # Create trainer with args.arch parameters
    trainer = HorovodTrainer(**config.arch, checkpoint=checkpoint)

    # Train model
    trainer.fit(model_wrapper)