예제 #1
0
def setup_depth_net(config, prepared, **kwargs):
    """
    Create a depth network

    Parameters
    ----------
    config : CfgNode
        Network configuration
    prepared : bool
        True if the network has been prepared before
    kwargs : dict
        Extra parameters for the network

    Returns
    -------
    depth_net : nn.Module
        Created depth network
    """
    print0(pcolor("DepthNet: %s" % config.name, "yellow"))
    # Instantiate the class named in the config, passing config entries and
    # extra kwargs as constructor arguments
    depth_net = load_class_args_create(
        config.name,
        paths=[
            "packnet_sfm.networks.depth",
        ],
        args={
            **config,
            **kwargs
        },
    )
    # Load standalone pretrained weights only if the network was not prepared
    # before and a checkpoint path is configured. Truthiness replaces the
    # former identity check ``is not ""`` (SyntaxWarning on Python 3.8+) and
    # also treats a missing (None) path as "no checkpoint".
    if not prepared and config.checkpoint_path:
        depth_net = load_network(depth_net, config.checkpoint_path,
                                 ["depth_net", "disp_network"])
    return depth_net
def setup_model(config, prepared, **kwargs):
    """
    Create a model

    Parameters
    ----------
    config : CfgNode
        Model configuration (cf. configs/default_config.py)
    prepared : bool
        True if the model has been prepared before
    kwargs : dict
        Extra parameters for the model

    Returns
    -------
    model : nn.Module
        Created model
    """
    print0(pcolor('Model: %s' % config.name, 'yellow'))
    # Instantiate the model class; loss parameters and extra kwargs are merged
    # into its constructor arguments (kwargs take precedence)
    model = load_class(config.name, paths=['packnet_sfm.models',])(
        **{**config.loss, **kwargs})
    # Add depth network if required
    if 'depth_net' in model.network_requirements:
        model.add_depth_net(setup_depth_net(config.depth_net, prepared))
    # Add pose network if required
    if 'pose_net' in model.network_requirements:
        model.add_pose_net(setup_pose_net(config.pose_net, prepared))
    # If a checkpoint is provided (and the model was not prepared before),
    # load the pretrained model. Truthiness replaces the identity check
    # ``is not ''`` (SyntaxWarning on Python 3.8+) and skips a None path.
    if not prepared and config.checkpoint_path:
        model = load_network(model, config.checkpoint_path, 'model')
    # Return model
    return model
예제 #3
0
def setup_pose_net(config, prepared, **kwargs):
    """
    Create a pose network

    Parameters
    ----------
    config : CfgNode
        Network configuration
    prepared : bool
        True if the network has been prepared before
    kwargs : dict
        Extra parameters for the network

    Returns
    -------
    pose_net : nn.Module
        Created pose network
    """
    print0(pcolor('PoseNet: %s' % config.name, 'yellow'))
    # Instantiate the class named in the config, passing config entries and
    # extra kwargs as constructor arguments
    pose_net = load_class_args_create(
        config.name,
        paths=[
            'packnet_sfm.networks.pose',
        ],
        args={
            **config,
            **kwargs
        },
    )
    # Load standalone pretrained weights only if the network was not prepared
    # before and a checkpoint path is configured (truthiness instead of the
    # identity check ``is not ''``; also skips a None path).
    if not prepared and config.checkpoint_path:
        pose_net = load_network(pose_net, config.checkpoint_path,
                                ['pose_net', 'pose_network'])
    return pose_net
예제 #4
0
def main(args, N):
    """
    Run 3D point-cloud inference with ``N`` checkpoints at once.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments. Attributes read: ``checkpoints``,
        ``image_shape``, ``half``, ``input_folders``, ``input_imgs``,
        ``output``, ``save`` and ``stop``.
    N : int
        Number of checkpoints / input streams to use
    """
    # Initialize horovod
    hvd_init()

    # Parse arguments
    configs, state_dicts = [], []
    for i in range(N):
        config, state_dict = parse_test_file(args.checkpoints[i])
        configs.append(config)
        state_dicts.append(state_dict)

    # If no image shape is provided, use the first checkpoint's one
    image_shape = args.image_shape
    if image_shape is None:
        image_shape = configs[0].datasets.augmentation.image_shape

    # Set debug if requested
    set_debug(configs[0].debug)

    # Initialize each model wrapper from its checkpoint and restore its state
    model_wrappers = []
    for config, state_dict in zip(configs, state_dicts):
        model_wrapper = ModelWrapper(config, load_datasets=False)
        model_wrapper.load_state_dict(state_dict)
        model_wrappers.append(model_wrapper)

    # change to half precision for evaluation if requested
    dtype = torch.float16 if args.half else None

    # Send models to GPU if available
    if torch.cuda.is_available():
        device = 'cuda:{}'.format(rank())
        model_wrappers = [mw.to(device, dtype=dtype) for mw in model_wrappers]

    # Set to eval mode
    for model_wrapper in model_wrappers:
        model_wrapper.eval()

    if args.input_folders is None:
        files = [[args.input_imgs[i]] for i in range(N)]
    else:
        files = []
        for i in range(N):
            # Accumulate matches for every extension; assigning per extension
            # would keep only the last one (all .png files were dropped).
            found = []
            for ext in ['png', 'jpg']:
                found.extend(glob.glob(os.path.join(args.input_folders[i],
                                                    '*.{}'.format(ext))))
            found.sort()
            print0('Found {} files'.format(len(found)))
            files.append(found)

    # Process each file
    infer_plot_and_save_3D_pcl(files, args.output, model_wrappers, image_shape,
                               args.half, args.save, bool(int(args.stop)))
예제 #5
0
파일: load.py 프로젝트: zxhou/packnet-sfm
def load_network(network, path, prefixes=''):
    """
    Loads a pretrained network

    Parameters
    ----------
    network : nn.Module
        Network that will receive the pretrained weights
    path : str or dict
        File containing a 'state_dict' key with pretrained network weights,
        or an already-loaded state dict
    prefixes : str or list of str
        Layer name prefixes to consider when loading the network. For each
        saved key, everything up to and including the first occurrence of
        ``prefix + '.'`` is stripped. The first prefix whose stripped key
        exists in the network with a matching shape is used.

    Returns
    -------
    network : nn.Module
        Updated network with pretrained weights
    """
    prefixes = make_list(prefixes)
    # If path is a string, read the checkpoint from disk
    if is_str(path):
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints.
        saved_state_dict = torch.load(path, map_location='cpu')['state_dict']
        if path.endswith('.pth.tar'):
            saved_state_dict = backwards_state_dict(saved_state_dict)
    # If state dict is already provided
    else:
        saved_state_dict = path
    # Get network state dict
    network_state_dict = network.state_dict()

    updated_state_dict = OrderedDict()
    for key, val in saved_state_dict.items():
        for prefix in prefixes:
            prefix = prefix + '.'
            idx = key.find(prefix)
            if idx < 0:
                continue
            # Strip from the ORIGINAL key, so a match on one prefix cannot
            # change what the following prefixes are matched against
            new_key = key[idx + len(prefix):]
            if new_key in network_state_dict and \
                    same_shape(val.shape, network_state_dict[new_key].shape):
                updated_state_dict[new_key] = val
                break
    # Count distinct network tensors that received weights, so duplicate
    # matches cannot push the count above the total
    n, n_total = len(updated_state_dict), len(network_state_dict)

    network.load_state_dict(updated_state_dict, strict=False)
    base_color, attrs = 'cyan', ['bold', 'dark']
    color = 'green' if n == n_total else 'yellow' if n > 0 else 'red'
    print0(
        pcolor('###### Pretrained {} loaded:'.format(prefixes[0]),
               base_color,
               attrs=attrs) +
        pcolor(' {}/{} '.format(n, n_total), color, attrs=attrs) +
        pcolor('tensors', base_color, attrs=attrs))
    return network
 def prepare_model(self, resume=None):
     """Build ``self.model`` and optionally restore it from a resume checkpoint.

     Parameters
     ----------
     resume : dict or None
         When truthy, must provide 'file' (used for logging) and 'state_dict'
         (loaded under the 'model' prefix); 'epoch', if present, becomes
         ``self.current_epoch``.
     """
     print0(pcolor('### Preparing Model', 'green'))
     self.model = setup_model(self.config.model, self.config.prepared)
     # Without a resume checkpoint there is nothing left to restore
     if not resume:
         return
     message = '### Resuming from {}'.format(resume['file'])
     print0(pcolor(message, 'magenta', attrs=['bold']))
     self.model = load_network(self.model, resume['state_dict'], 'model')
     if 'epoch' in resume:
         self.current_epoch = resume['epoch']
예제 #7
0
def main(args, N):
    """
    Run optimal-calibration inference with ``N`` checkpoints / cameras.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments. Attributes read: ``checkpoints``,
        ``image_shape``, ``input_folder``, ``input_imgs`` and
        ``every_n_files``.
    N : int
        Number of checkpoints / cameras; camera ``i`` reads images from
        ``<input_folder>/cam_<i>/``.
    """
    # Initialize horovod
    hvd_init()

    # Parse arguments
    configs, state_dicts = [], []
    for i in range(N):
        config, state_dict = parse_test_file(args.checkpoints[i])
        configs.append(config)
        state_dicts.append(state_dict)

    # If no image shape is provided, use the first checkpoint's one
    image_shape = args.image_shape
    if image_shape is None:
        image_shape = configs[0].datasets.augmentation.image_shape

    # Set debug if requested
    set_debug(configs[0].debug)

    # Initialize each model wrapper from its checkpoint and restore its state
    model_wrappers = []
    for config, state_dict in zip(configs, state_dicts):
        model_wrapper = ModelWrapper(config, load_datasets=False)
        model_wrapper.load_state_dict(state_dict)
        model_wrappers.append(model_wrapper)

    # Send models to GPU if available
    if torch.cuda.is_available():
        device = 'cuda:{}'.format(rank())
        model_wrappers = [mw.to(device) for mw in model_wrappers]

    # Set to eval mode
    for model_wrapper in model_wrappers:
        model_wrapper.eval()

    if args.input_folder is None:
        files = [[args.input_imgs[i]] for i in range(N)]
    else:
        files = []
        for i in range(N):
            cam_folder = os.path.join(args.input_folder, 'cam_' + str(i) + '/')
            # Accumulate matches for every extension; assigning per extension
            # would keep only the last one (all .png files were dropped).
            found = []
            for ext in ['png', 'jpg']:
                found.extend(glob(os.path.join(cam_folder,
                                               '*.{}'.format(ext))))
            found.sort()
            # Keep only every n-th file
            found = found[::args.every_n_files]
            print0('Found {} files'.format(len(found)))
            files.append(found)

    # Process each file
    infer_optimal_calib(files, model_wrappers, image_shape)
예제 #8
0
def infer(ckpt_file, input_file, output_file, image_shape):
    """
    Monocular depth estimation test script.

    Parameters
    ----------
    ckpt_file : str
        Checkpoint path for a pretrained model
    input_file : str
        File or folder with input images
    output_file : str
        File or folder with output images
    image_shape : tuple
        Input image shape (H,W); if None, the checkpoint's training shape is
        used
    """
    # Initialize horovod
    hvd_init()

    # Parse arguments
    config, state_dict = parse_test_file(ckpt_file)

    # If no image shape is provided, use the checkpoint one
    if image_shape is None:
        image_shape = config.datasets.augmentation.image_shape

    # Set debug if requested
    set_debug(config.debug)

    # Initialize model wrapper from checkpoint arguments
    model_wrapper = ModelWrapper(config, load_datasets=False)
    # Restore monodepth_model state
    model_wrapper.load_state_dict(state_dict)

    # Send model to GPU if available
    if torch.cuda.is_available():
        model_wrapper = model_wrapper.to('cuda:{}'.format(rank()))

    # Set to eval mode; otherwise dropout / batch-norm layers would run in
    # training mode during inference
    model_wrapper.eval()

    if os.path.isdir(input_file):
        # If input file is a folder, search for image files
        files = []
        for ext in ['png', 'jpg']:
            files.extend(glob((os.path.join(input_file, '*.{}'.format(ext)))))
        files.sort()
        print0('Found {} files'.format(len(files)))
    else:
        # Otherwise, use it as is
        files = [input_file]

    # Process each file (each rank handles a strided share)
    for file in files[rank()::world_size()]:
        process(file, output_file, model_wrapper, image_shape)
예제 #9
0
    def prepare_datasets(self):
        """Create the train, validation and test datasets.

        All three splits share the options in
        ``self.config.datasets.augmentation``. Only the training split also
        receives ``self.model.requires_gt_depth`` as a positional argument.

        NOTE(review): the validation/test calls pass no third positional
        argument; confirm that the ``setup_dataset`` this version uses gives
        that parameter a default.
        """
        print0(pcolor('### Preparing Datasets', 'green'))

        datasets_cfg = self.config.datasets
        augmentation = datasets_cfg.augmentation
        self.train_dataset = setup_dataset(
            datasets_cfg.train, 'train', self.model.requires_gt_depth,
            **augmentation)
        self.validation_dataset = setup_dataset(
            datasets_cfg.validation, 'validation', **augmentation)
        self.test_dataset = setup_dataset(
            datasets_cfg.test, 'test', **augmentation)
예제 #10
0
 def prepare_model(self, resume=None):
     """Instantiate ``self.model`` and, if a resume dict is given, reload it.

     When ``resume`` is truthy it supplies the checkpoint "file" name (for
     logging) and its "state_dict" (loaded under the "model" prefix). It may
     also supply the "epoch" to continue from.
     """
     print0(pcolor("### Preparing Model", "green"))
     self.model = setup_model(self.config.model, self.config.prepared)
     if resume:
         banner = pcolor("### Resuming from {}".format(resume["file"]),
                         "magenta", attrs=["bold"])
         print0(banner)
         saved_state = resume["state_dict"]
         self.model = load_network(self.model, saved_state, "model")
         if "epoch" in resume:
             self.current_epoch = resume["epoch"]
예제 #11
0
    def prepare_datasets(self, validation_requirements, test_requirements):
        """Build the train, validation and test datasets.

        The training split uses the requirements declared by the model
        (``self.model.train_requirements``). Validation and test use the
        requirements passed in by the caller. All splits share the options in
        ``self.config.datasets.augmentation``.
        """
        print0(pcolor('### Preparing Datasets', 'green'))

        cfg = self.config.datasets
        augmentation = cfg.augmentation
        self.train_dataset = setup_dataset(
            cfg.train, 'train', self.model.train_requirements, **augmentation)
        self.validation_dataset = setup_dataset(
            cfg.validation, 'validation', validation_requirements,
            **augmentation)
        self.test_dataset = setup_dataset(
            cfg.test, 'test', test_requirements, **augmentation)
예제 #12
0
def main(args):
    """
    Run 3D point-cloud inference for a single checkpoint.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments. Attributes read: ``checkpoint``,
        ``image_shape``, ``half``, ``input`` (an image file or a folder of
        png/jpg images), ``output`` and ``save``.
    """
    hvd_init()

    config, state_dict = parse_test_file(args.checkpoint)

    # Fall back to the checkpoint's image shape when none is requested
    requested_shape = args.image_shape
    image_shape = (config.datasets.augmentation.image_shape
                   if requested_shape is None else requested_shape)

    set_debug(config.debug)

    # Build the wrapper from the checkpoint config and restore its weights
    model_wrapper = ModelWrapper(config, load_datasets=False)
    model_wrapper.load_state_dict(state_dict)

    # Optional half-precision evaluation
    dtype = None
    if args.half:
        dtype = torch.float16

    if torch.cuda.is_available():
        device = 'cuda:{}'.format(rank())
        model_wrapper = model_wrapper.to(device, dtype=dtype)

    model_wrapper.eval()

    if not os.path.isdir(args.input):
        # A single file is used as is
        files = [args.input]
    else:
        files = sorted(
            match
            for ext in ('png', 'jpg')
            for match in glob(os.path.join(args.input, '*.{}'.format(ext))))
        print0('Found {} files'.format(len(files)))

    # Each rank handles a strided share of the files
    for fn in files[rank()::world_size()]:
        infer_plot_and_save_3D_pcl(fn, args.output, model_wrapper, image_shape,
                                   args.half, args.save)
예제 #13
0
def setup_model(config, prepared, **kwargs):
    """
    Create a model

    Parameters
    ----------
    config : CfgNode
        Model configuration (cf. configs/default_config.py)
    prepared : bool
        True if the model has been prepared before
    kwargs : dict
        Extra parameters for the model (also forwarded to the pose network)

    Returns
    -------
    model : nn.Module
        Created model
    """
    print0(pcolor('Model: %s' % config.name, 'yellow'))
    # SfmModel, SelfSupModel, VelSupModel loaded
    model = load_class(config.name, paths=['packnet_sfm.models',])(
        **{**config.loss, **kwargs})
    # Add depth network if required
    if model.network_requirements['depth_net']:
        model.add_depth_net(setup_depth_net(config.depth_net, prepared,
                                            num_scales=config.loss.num_scales,
                                            min_depth=config.params.min_depth,
                                            max_depth=config.params.max_depth,
                                            upsample_depth_maps=config.loss.upsample_depth_maps
                                            ))
    # Add pose network if required
    if model.network_requirements['pose_net']:
        model.add_pose_net(
            setup_pose_net(config.pose_net,
                           prepared,
                           rotation_mode=config.loss.rotation_mode,
                           **kwargs))
    # If a checkpoint is provided (and the model was not prepared before),
    # load the pretrained model. Truthiness replaces the identity check
    # ``is not ''`` (SyntaxWarning on Python 3.8+) and skips a None path.
    if not prepared and config.checkpoint_path:
        model = load_network(model, config.checkpoint_path, 'model')
    # Return model
    return model
예제 #14
0
    def prepare_datasets(self, validation_requirements, test_requirements):
        """Create the train, validation and test datasets.

        Training requirements come from the model itself
        (``self.model.train_requirements``). The validation and test
        requirements are supplied by the caller. The augmentation options in
        ``self.config.datasets.augmentation`` apply to every split.
        """
        print0(pcolor("### Preparing Datasets", "green"))

        dataset_cfgs = self.config.datasets
        augmentation = dataset_cfgs.augmentation
        train_requirements = self.model.train_requirements
        self.train_dataset = setup_dataset(dataset_cfgs.train, "train",
                                           train_requirements, **augmentation)
        self.validation_dataset = setup_dataset(dataset_cfgs.validation,
                                                "validation",
                                                validation_requirements,
                                                **augmentation)
        self.test_dataset = setup_dataset(dataset_cfgs.test, "test",
                                          test_requirements, **augmentation)
예제 #15
0
def setup_dataset(config, mode, requirements, **kwargs):
    """
    Create a dataset class

    Parameters
    ----------
    config : CfgNode
        Configuration (cf. configs/default_config.py)
    mode : str {'train', 'validation', 'test'}
        Mode from which we want the dataset
    requirements : dict (string -> bool)
        Different requirements for dataset loading (gt_depth, gt_pose, etc)
    kwargs : dict
        Extra parameters for dataset creation (forwarded to the transforms)

    Returns
    -------
    datasets : list of Dataset or None
        One dataset per entry in ``config.split``. In 'train' mode this is a
        single-element list holding their concatenation. Returns None when
        ``config.path`` is empty.

    Raises
    ------
    ValueError
        If ``config.dataset[i]`` is not a supported dataset name
    """
    # If no dataset is given, return None
    if len(config.path) == 0:
        return None

    print0(pcolor('###### Setup %s datasets' % mode, 'red'))

    # Global shared dataset arguments
    dataset_args = {
        'back_context': config.back_context,
        'forward_context': config.forward_context,
        'with_geometric_context': config.with_geometric_context,
    }

    # Loop over all datasets
    datasets = []
    for i in range(len(config.split)):
        name = config.dataset[i]
        path_split = os.path.join(config.path[i], config.split[i])

        # Individual dataset arguments, shared by every dataset type
        depth_type = config.depth_type[i] if requirements['gt_depth'] else None
        with_pose = requirements['gt_pose']

        # Each camera model has its own transform pipeline
        if name == 'ValeoMultifocal':
            data_transform = get_transforms_multifocal(mode, **kwargs)
        elif name == 'KITTIValeoFisheye':
            data_transform = get_transforms_fisheye(mode, **kwargs)
        elif name == 'KITTIValeoDistorted':
            data_transform = get_transforms_distorted(mode, **kwargs)
        elif name == 'DGPvaleo':
            data_transform = get_transforms_dgp_valeo(mode, **kwargs)
        elif name == 'WoodscapeFisheye':
            data_transform = get_transforms_woodscape_fisheye(mode, **kwargs)
        else:
            data_transform = get_transforms(mode, **kwargs)

        dataset_args_i = {
            'depth_type': depth_type,
            'with_pose': with_pose,
            'data_transform': data_transform,
        }
        # Dataset-specific extra options
        if name == 'ValeoMultifocal':
            dataset_args_i['with_spatiotemp_context'] = \
                config.with_spatiotemp_context
        elif name == 'KITTIValeoFisheye':
            dataset_args_i.update(
                calibrations_suffix=config.calibrations_suffix,
                depth_suffix=config.depth_suffix,
                cam_convs=config.cam_convs,
            )

        # KITTI-based Valeo dataset (multifocal)
        if name == 'ValeoMultifocal':
            from packnet_sfm.datasets.kitti_based_valeo_dataset_multifocal import KITTIBasedValeoDatasetMultifocal
            dataset = KITTIBasedValeoDatasetMultifocal(
                config.path[i],
                path_split,
                **dataset_args,
                **dataset_args_i,
                cameras=config.cameras[i],
            )
        # KITTI dataset
        elif name == 'KITTI':
            from packnet_sfm.datasets.kitti_dataset import KITTIDataset
            dataset = KITTIDataset(
                config.path[i],
                path_split,
                **dataset_args,
                **dataset_args_i,
            )
        # DGP dataset
        elif name == 'DGP':
            from packnet_sfm.datasets.dgp_dataset import DGPDataset
            dataset = DGPDataset(
                config.path[i],
                config.split[i],
                **dataset_args,
                **dataset_args_i,
                cameras=config.cameras[i],
            )
        # DGP-based Valeo dataset
        elif name == 'DGPvaleo':
            from packnet_sfm.datasets.dgp_valeo_dataset import DGPvaleoDataset
            dataset = DGPvaleoDataset(
                config.path[i],
                config.split[i],
                **dataset_args,
                **dataset_args_i,
                cameras=config.cameras[i],
            )
        # Image dataset
        elif name == 'Image':
            from packnet_sfm.datasets.image_dataset import ImageDataset
            dataset = ImageDataset(
                config.path[i],
                config.split[i],
                **dataset_args,
                **dataset_args_i,
            )
        # KITTI-based Valeo dataset
        elif name == 'KITTIValeo':
            from packnet_sfm.datasets.kitti_based_valeo_dataset import KITTIBasedValeoDataset
            dataset = KITTIBasedValeoDataset(
                config.path[i],
                path_split,
                **dataset_args,
                **dataset_args_i,
                cameras=config.cameras[i],
            )
        # KITTI-based Valeo dataset (fisheye)
        elif name == 'KITTIValeoFisheye':
            from packnet_sfm.datasets.kitti_based_valeo_dataset_fisheye_singleView import \
                KITTIBasedValeoDatasetFisheye_singleView
            dataset = KITTIBasedValeoDatasetFisheye_singleView(
                config.path[i],
                path_split,
                **dataset_args,
                **dataset_args_i,
                cameras=config.cameras[i],
            )
        # KITTI-based Valeo dataset (distorted)
        elif name == 'KITTIValeoDistorted':
            from packnet_sfm.datasets.kitti_based_valeo_dataset_distorted_singleView import \
                KITTIBasedValeoDatasetDistorted_singleView
            dataset = KITTIBasedValeoDatasetDistorted_singleView(
                config.path[i],
                path_split,
                **dataset_args,
                **dataset_args_i,
                cameras=config.cameras[i],
            )
        # Woodscape fisheye dataset
        elif name == 'WoodscapeFisheye':
            from packnet_sfm.datasets.woodscape_fisheye import WoodscapeFisheye
            dataset = WoodscapeFisheye(
                config.path[i],
                path_split,
                **dataset_args,
                **dataset_args_i,
                cameras=config.cameras[i],
            )
        # Image-based Valeo dataset
        elif name == 'ImageValeo':
            from packnet_sfm.datasets.image_based_valeo_dataset import ImageBasedValeoDataset
            dataset = ImageBasedValeoDataset(
                config.path[i],
                config.split[i],
                **dataset_args,
                **dataset_args_i,
                cameras=config.cameras[i],
            )
        else:
            # The error must be raised, and formatted with %s: the name is a
            # string, so %d would itself fail with a TypeError
            raise ValueError('Unknown dataset %s' % name)

        # Repeat if needed
        if 'repeat' in config and config.repeat[i] > 1:
            dataset = ConcatDataset([dataset for _ in range(config.repeat[i])])
        datasets.append(dataset)

        # Display dataset information
        bar = '######### {:>7}'.format(len(dataset))
        if 'repeat' in config:
            bar += ' (x{})'.format(config.repeat[i])
        bar += ': {:<}'.format(path_split)
        print0(pcolor(bar, 'yellow'))

    # If training, concatenate all datasets into a single one
    if mode == 'train':
        datasets = [ConcatDataset(datasets)]

    return datasets
예제 #16
0
def setup_dataset(config, mode, requirements, **kwargs):
    """
    Create a dataset class

    Parameters
    ----------
    config : CfgNode
        Configuration (cf. configs/default_config.py)
    mode : str {'train', 'validation', 'test'}
        Mode from which we want the dataset
    requirements : dict (string -> bool)
        Different requirements for dataset loading (gt_depth, gt_pose, etc)
    kwargs : dict
        Extra parameters for dataset creation (forwarded to get_transforms)

    Returns
    -------
    datasets : list of Dataset or None
        One dataset per entry in ``config.split``. In "train" mode this is a
        single-element list holding their concatenation. Returns None when
        ``config.path`` is empty.

    Raises
    ------
    ValueError
        If ``config.dataset[i]`` is not "KITTI", "DGP" or "Image"
    """
    # If no dataset is given, return None
    if len(config.path) == 0:
        return None

    print0(pcolor("###### Setup %s datasets" % mode, "red"))

    # Global shared dataset arguments (one transform pipeline for all)
    dataset_args = {
        "back_context": config.back_context,
        "forward_context": config.forward_context,
        "data_transform": get_transforms(mode, **kwargs),
    }

    # Loop over all datasets
    datasets = []
    for i in range(len(config.split)):
        path_split = os.path.join(config.path[i], config.split[i])

        # Individual shared dataset arguments
        dataset_args_i = {
            "depth_type":
            config.depth_type[i] if requirements["gt_depth"] else None,
            "with_pose": requirements["gt_pose"],
        }

        # KITTI dataset
        if config.dataset[i] == "KITTI":
            from packnet_sfm.datasets.kitti_dataset import KITTIDataset

            dataset = KITTIDataset(
                config.path[i],
                path_split,
                **dataset_args,
                **dataset_args_i,
            )
        # DGP dataset
        elif config.dataset[i] == "DGP":
            from packnet_sfm.datasets.dgp_dataset import DGPDataset

            dataset = DGPDataset(
                config.path[i],
                config.split[i],
                **dataset_args,
                **dataset_args_i,
                cameras=config.cameras[i],
            )
        # Image dataset
        elif config.dataset[i] == "Image":
            from packnet_sfm.datasets.image_dataset import ImageDataset

            dataset = ImageDataset(
                config.path[i],
                config.split[i],
                **dataset_args,
                **dataset_args_i,
            )
        else:
            # The error must be raised, and formatted with %s: the name is a
            # string, so %d would itself fail with a TypeError
            raise ValueError("Unknown dataset %s" % config.dataset[i])

        # Repeat if needed
        if "repeat" in config and config.repeat[i] > 1:
            dataset = ConcatDataset([dataset for _ in range(config.repeat[i])])
        datasets.append(dataset)

        # Display dataset information
        bar = "######### {:>7}".format(len(dataset))
        if "repeat" in config:
            bar += " (x{})".format(config.repeat[i])
        bar += ": {:<}".format(path_split)
        print0(pcolor(bar, "yellow"))

    # If training, concatenate all datasets into a single one
    if mode == "train":
        datasets = [ConcatDataset(datasets)]

    return datasets
예제 #17
0
def _collect_image_files(path):
    """Return the sorted png/jpg files in folder ``path``, or ``[path]`` for a file."""
    if not os.path.isdir(path):
        # Otherwise, use it as is
        return [path]
    # If input is a folder, search for image files
    files = []
    for ext in ['png', 'jpg']:
        files.extend(glob(os.path.join(path, '*.{}'.format(ext))))
    files.sort()
    print0('Found {} files'.format(len(files)))
    return files


def main(args):
    """
    Run 3D point-cloud inference on four parallel checkpoint / input streams.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments. Attributes read: ``checkpoint1..4``,
        ``input1..4``, ``output1..4``, ``hasGTdepth1..4`` (integer-like
        flags), ``image_shape``, ``half`` and ``save``.
    """
    stream_ids = range(1, 5)

    def per_stream(attr):
        # Collect args.<attr>1 .. args.<attr>4 in stream order
        return [getattr(args, '{}{}'.format(attr, k)) for k in stream_ids]

    # Initialize horovod
    hvd_init()

    # Parse arguments
    parsed = [parse_test_file(ckpt) for ckpt in per_stream('checkpoint')]
    configs = [config for config, _ in parsed]
    state_dicts = [state_dict for _, state_dict in parsed]

    # If no image shape is provided, use the first checkpoint's one
    image_shape = args.image_shape
    if image_shape is None:
        image_shape = configs[0].datasets.augmentation.image_shape

    # Set debug if requested
    set_debug(configs[0].debug)

    # Initialize model wrappers from checkpoint arguments
    model_wrappers = [ModelWrapper(config, load_datasets=False)
                      for config in configs]
    # Restore monodepth_model state
    for model_wrapper, state_dict in zip(model_wrappers, state_dicts):
        model_wrapper.load_state_dict(state_dict)

    # change to half precision for evaluation if requested
    dtype = torch.float16 if args.half else None

    # Send models to GPU if available
    if torch.cuda.is_available():
        device = 'cuda:{}'.format(rank())
        model_wrappers = [mw.to(device, dtype=dtype) for mw in model_wrappers]

    # Set to eval mode
    for model_wrapper in model_wrappers:
        model_wrapper.eval()

    # Gather inputs / outputs / GT-depth flags of every stream
    files = [_collect_image_files(path) for path in per_stream('input')]
    outputs = per_stream('output')
    has_gt_depth = [bool(int(flag)) for flag in per_stream('hasGTdepth')]

    # Process each file: every rank takes a strided share of each stream, and
    # zip stops at the shortest stream
    shards = [stream[rank()::world_size()] for stream in files]
    for fns in zip(*shards):
        infer_plot_and_save_3D_pcl(*fns, *outputs, *model_wrappers,
                                   *has_gt_depth, image_shape,
                                   args.half, args.save)
def setup_dataset(config, mode, requirements, **kwargs):
    """
    Create the dataset(s) described by a dataset configuration

    Parameters
    ----------
    config : CfgNode
        Configuration (cf. configs/default_config.py)
    mode : str {'train', 'validation', 'test'}
        Mode from which we want the dataset
    requirements : container of str
        Requirements for dataset loading; only membership of 'gt_depth' and
        'gt_pose' is checked (e.g. a dict or set of requirement names)
    kwargs : dict
        Extra parameters forwarded to ``get_transforms``

    Returns
    -------
    datasets : list of Dataset or None
        One dataset per entry in ``config.split``. In 'train' mode they are
        concatenated into a single-element list. Returns None if
        ``config.path`` is empty.

    Raises
    ------
    ValueError
        If an entry of ``config.dataset`` is not 'KITTI', 'DGP' or 'Image'
    """
    # If no dataset is given, return None
    if len(config.path) == 0:
        return None

    print0(pcolor('###### Setup %s datasets' % mode, 'red'))

    # Global shared dataset arguments (identical for every dataset below)
    dataset_args = {
        'back_context': config.back_context,
        'forward_context': config.forward_context,
        'data_transform': get_transforms(mode, **kwargs)
    }

    # Loop over all datasets; config.path / split / dataset / ... are
    # parallel lists indexed by i
    datasets = []
    for i in range(len(config.split)):
        path_split = os.path.join(config.path[i], config.split[i])

        # Individual shared dataset arguments
        dataset_args_i = {
            'depth_type': config.depth_type[i] if 'gt_depth' in requirements else None,
            'input_depth_type': config.input_depth_type[i] if 'gt_depth' in requirements else None,
            'with_pose': 'gt_pose' in requirements,
        }

        # KITTI dataset
        if config.dataset[i] == 'KITTI':
            from packnet_sfm.datasets.kitti_dataset import KITTIDataset
            dataset = KITTIDataset(
                config.path[i], path_split,
                **dataset_args, **dataset_args_i,
            )
        # DGP dataset
        elif config.dataset[i] == 'DGP':
            from packnet_sfm.datasets.dgp_dataset import DGPDataset
            dataset = DGPDataset(
                config.path[i], config.split[i],
                **dataset_args, **dataset_args_i,
                cameras=config.cameras[i],
            )
        # Image dataset
        elif config.dataset[i] == 'Image':
            from packnet_sfm.datasets.image_dataset import ImageDataset
            dataset = ImageDataset(
                config.path[i], config.split[i],
                **dataset_args, **dataset_args_i,
            )
        else:
            # Must actually raise: otherwise `dataset` is unbound (first
            # iteration) or stale (a previous iteration's dataset is reused).
            # '%s' because dataset names are strings ('%d' would TypeError).
            raise ValueError('Unknown dataset %s' % config.dataset[i])

        # Repeat if needed
        if 'repeat' in config and config.repeat[i] > 1:
            dataset = ConcatDataset([dataset for _ in range(config.repeat[i])])
        datasets.append(dataset)

        # Display dataset information
        bar = '######### {:>7}'.format(len(dataset))
        if 'repeat' in config:
            bar += ' (x{})'.format(config.repeat[i])
        bar += ': {:<}'.format(path_split)
        print0(pcolor(bar, 'yellow'))

    # If training, concatenate all datasets into a single one
    if mode == 'train':
        datasets = [ConcatDataset(datasets)]

    return datasets
예제 #19
0
def main(args):
    """
    Run pose inference over a folder of images and write the chained trajectory.

    Loads a model from ``args.checkpoint`` and runs ``infer_and_save_pose`` on
    consecutive image triplets from ``args.input``, sharded across ranks. It
    then chains the relative poses found in the module-level ``poses``
    mapping and writes one line per pose to ``args.output + ".txt"``.

    Parameters
    ----------
    args : argparse.Namespace
        Uses ``checkpoint``, ``image_shape``, ``half``, ``input`` (must be a
        directory), ``output`` (must not be a directory), ``offset``,
        ``limit`` and ``save``.

    Raises
    ------
    RuntimeError
        If ``args.input`` is not a directory or ``args.output`` is a directory.
    """
    # Initialize horovod
    hvd_init()

    # Parse arguments
    config, state_dict = parse_test_file(args.checkpoint)

    # If no image shape is provided, use the checkpoint one
    image_shape = args.image_shape
    if image_shape is None:
        image_shape = config.datasets.augmentation.image_shape

    # Set debug if requested
    set_debug(config.debug)

    # Initialize model wrapper from checkpoint arguments
    model_wrapper = ModelWrapper(config, load_datasets=False)
    # Restore monodepth_model state
    model_wrapper.load_state_dict(state_dict)

    # change to half precision for evaluation if requested
    dtype = torch.float16 if args.half else None

    # Send model to GPU if available
    if torch.cuda.is_available():
        model_wrapper = model_wrapper.to('cuda:{}'.format(rank()), dtype=dtype)

    # Set to eval mode
    model_wrapper.eval()

    if os.path.isdir(args.input):
        # If input file is a folder, search for image files
        files = []
        for ext in ['png', 'jpg']:
            files.extend(glob((os.path.join(args.input, '*.{}'.format(ext)))))
        files.sort()
        print0('Found {} files'.format(len(files)))
    else:
        raise RuntimeError("Input needs directory, not file")

    if not os.path.isdir(args.output):
        root, _ = os.path.split(args.output)
        # os.makedirs('') raises FileNotFoundError, so only create a parent
        # directory when the output path actually contains one
        if root:
            os.makedirs(root, exist_ok=True)
    else:
        raise RuntimeError("Output needs to be a file")

    # Process each file: triplets (i, i+1, i+2), sharded across ranks with
    # stride world_size(); the three slices stay index-aligned
    list_of_files = list(zip(files[rank()  :-2:world_size()],
                              files[rank()+1:-1:world_size()],
                              files[rank()+2:  :world_size()]))
    if args.offset:
        list_of_files = list_of_files[args.offset:]
    if args.limit:
        list_of_files = list_of_files[:args.limit]
    for fn1, fn2, fn3 in list_of_files:
        # Outer frames are the context, middle frame is the target
        infer_and_save_pose([fn1, fn3], fn2, model_wrapper, image_shape, args.half, args.save)

    # Accumulate relative poses into an absolute trajectory
    position = np.zeros(3)
    orientation = np.eye(3)
    output_file = args.output + ".txt"

    # `with` guarantees the file is closed even if a pose step raises
    with open(output_file, 'w') as f:
        # NOTE(review): `poses` is a module-level mapping presumably filled by
        # infer_and_save_pose; sorted keys are assumed to give frame order.
        for key in sorted(poses.keys()):
            rot_matrix, translation = poses[key]

            # NOTE(review): translation is rotated by the already-updated
            # orientation; confirm this matches the relative-pose convention.
            orientation = orientation.dot(rot_matrix.tolist())
            position += orientation.dot(translation.tolist())

            q = transforms.matrix_to_quaternion(torch.tensor(orientation))
            q = q.numpy()

            # NOTE(review): quaternion components are written in reversed
            # index order (3, 2, 1, 0); confirm against the expected format.
            f.write("%.10f %.10f %.10f %.10f %.10f %.10f %.10f\n" % (position[0], position[1], position[2], q[0][3], q[0][2], q[0][1], q[0][0]))

    print(f"Written pose of {len(list_of_files)} images to {output_file}")