def setup_model(config, prepared, **kwargs):
    """
    Create a model

    Parameters
    ----------
    config : CfgNode
        Model configuration (cf. configs/default_config.py)
    prepared : bool
        True if the model has been prepared before
    kwargs : dict
        Extra parameters for the model

    Returns
    -------
    model : nn.Module
        Created model
    """
    print0(pcolor('Model: %s' % config.name, 'yellow'))
    # Constructor arguments are config.loss merged with kwargs
    # (kwargs take precedence when the same key appears in both)
    model = load_class(config.name, paths=['packnet_sfm.models',])(
        **{**config.loss, **kwargs})
    # Add depth network if required
    if 'depth_net' in model.network_requirements:
        model.add_depth_net(setup_depth_net(config.depth_net, prepared))
    # Add pose network if required
    if 'pose_net' in model.network_requirements:
        model.add_pose_net(setup_pose_net(config.pose_net, prepared))
    # If a checkpoint is provided, load pretrained model.
    # Compare by value: an identity test against '' depends on string
    # interning and is not a reliable emptiness check.
    if not prepared and config.checkpoint_path != '':
        model = load_network(model, config.checkpoint_path, 'model')
    # Return model
    return model
示例#2
0
def get_default_config(cfg_default):
    """
    Get default configuration from file

    Parameters
    ----------
    cfg_default : str
        Slash-separated location of the defaults file; the slashes are
        turned into dots before it is handed to ``load_class``

    Returns
    -------
    config : object
        Result of calling the loaded ``get_cfg_defaults``, after
        ``['default', cfg_default]`` has been merged into it
    """
    # Convert the slash-separated location into dotted form for the lookup
    dotted_location = cfg_default.replace('/', '.')
    defaults_factory = load_class(
        'get_cfg_defaults', paths=[dotted_location], concat=False)
    defaults = defaults_factory()
    # Record which defaults file this configuration was built from
    defaults.merge_from_list(['default', cfg_default])
    return defaults
示例#3
0
def setup_model(config, prepared, **kwargs):
    """
    Create a model

    Parameters
    ----------
    config : CfgNode
        Model configuration (cf. configs/default_config.py)
    prepared : bool
        True if the model has been prepared before
    kwargs : dict
        Extra parameters for the model (also forwarded to the pose network setup)

    Returns
    -------
    model : nn.Module
        Created model
    """
    print0(pcolor('Model: %s' % config.name, 'yellow'))
    # SfmModel, SelfSupModel, VelSupModel loaded
    # Constructor arguments are config.loss merged with kwargs
    # (kwargs take precedence when the same key appears in both)
    model = load_class(config.name, paths=['packnet_sfm.models',])(
        **{**config.loss, **kwargs})
    # Add depth network if required
    if model.network_requirements['depth_net']:
        model.add_depth_net(setup_depth_net(config.depth_net, prepared,
                                            num_scales=config.loss.num_scales,
                                            min_depth=config.params.min_depth,
                                            max_depth=config.params.max_depth,
                                            upsample_depth_maps=config.loss.upsample_depth_maps
                                            ))
    # Add pose network if required
    if model.network_requirements['pose_net']:
        model.add_pose_net(
            setup_pose_net(config.pose_net,
                           prepared,
                           rotation_mode=config.loss.rotation_mode,
                           **kwargs))
    # If a checkpoint is provided, load pretrained model.
    # Compare by value: an identity test against '' depends on string
    # interning and is not a reliable emptiness check.
    if not prepared and config.checkpoint_path != '':
        model = load_network(model, config.checkpoint_path, 'model')
    # Return model
    return model