Example #1
def load_class(filename, paths, concat=True):
    """
    Look for a file in different locations and return the method with the same
    name as the file. Optionally, use concat to search in path.filename instead
    of searching in path directly.

    Parameters
    ----------
    filename : str
        Name of the file we are searching for
    paths : str or list of str
        Folders in which the file will be searched
    concat : bool
        Flag to concatenate filename to each path during the search

    Returns
    -------
    method : Function
        Loaded method
    """
    # for each path in paths
    for path in make_list(paths):
        # Create full path
        full_path = '{}.{}'.format(path, filename) if concat else path
        if importlib.util.find_spec(full_path):
            # Return method with same name as the file
            return getattr(importlib.import_module(full_path), filename)
    raise ValueError('Unknown class {}'.format(filename))
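The helpers used above come from the surrounding repository and are not shown on this page. Below is a minimal, assumed sketch of make_list (wrap a scalar into a list, optionally broadcasting it to length n), together with the import that load_class needs; the module and class names in the usage comment are illustrative only.

import importlib.util  # load_class above relies on importlib.util.find_spec

def make_list(var, n=None):
    """Assumed helper: wrap scalars in a list; broadcast length-1 lists to n."""
    var = var if isinstance(var, (list, tuple)) else [var]
    if n is not None and len(var) == 1:
        var = var * n
    return list(var)

# Hypothetical usage: search the package 'networks.depth' for a module named
# 'DepthNet' and return the attribute of the same name defined inside it.
# DepthNet = load_class('DepthNet', 'networks.depth')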
Example #2
def load_network(network, path, prefixes=''):
    """
    Loads a pretrained network

    Parameters
    ----------
    network : nn.Module
        Network that will receive the pretrained weights
    path : str
        File containing a 'state_dict' key with pretrained network weights
    prefixes : str or list of str
        Layer name prefixes to consider when loading the network

    Returns
    -------
    network : nn.Module
        Updated network with pretrained weights
    """
    prefixes = make_list(prefixes)
    # If path is a string
    if is_str(path):
        saved_state_dict = torch.load(path, map_location='cpu')['state_dict']
        if path.endswith('.pth.tar'):
            saved_state_dict = backwards_state_dict(saved_state_dict)
    # If state dict is already provided
    else:
        saved_state_dict = path
    # Get network state dict
    network_state_dict = network.state_dict()

    updated_state_dict = OrderedDict()
    n, n_total = 0, len(network_state_dict.keys())
    for key, val in saved_state_dict.items():
        for prefix in prefixes:
            prefix = prefix + '.'
            if prefix in key:
                idx = key.find(prefix) + len(prefix)
                key = key[idx:]
                if key in network_state_dict.keys() and \
                        same_shape(val.shape, network_state_dict[key].shape):
                    updated_state_dict[key] = val
                    n += 1

    network.load_state_dict(updated_state_dict, strict=False)
    base_color, attrs = 'cyan', ['bold', 'dark']
    color = 'green' if n == n_total else 'yellow' if n > 0 else 'red'
    print0(
        pcolor('###### Pretrained {} loaded:'.format(prefixes[0]),
               base_color,
               attrs=attrs) +
        pcolor(' {}/{} '.format(n, n_total), color, attrs=attrs) +
        pcolor('tensors', base_color, attrs=attrs))
    return network
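The prefix loop above looks for prefix + '.' inside each checkpoint key and keeps everything after it, so keys saved under a wrapper module can still match the bare network keys. A small self-contained illustration of that key rewriting (the checkpoint keys are made up):

from collections import OrderedDict

# Made-up checkpoint keys for illustration only
saved_state_dict = OrderedDict([
    ('model.depth_net.encoder.conv1.weight', 'tensor A'),
    ('model.pose_net.conv1.weight', 'tensor B'),
])
prefix = 'depth_net' + '.'
for key in saved_state_dict:
    if prefix in key:
        stripped = key[key.find(prefix) + len(prefix):]
        print(stripped)  # -> encoder.conv1.weight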
Example #3
def prep_dataset(config):
    """
    Expand dataset configuration to match split length

    Parameters
    ----------
    config : CfgNode
        Dataset configuration

    Returns
    -------
    config : CfgNode
        Updated dataset configuration
    """
    # If there is no dataset, do nothing
    if len(config.path) == 0:
        return config
    # If cameras is not a nested (double) list, wrap it in one
    if not config.cameras or not is_list(config.cameras[0]):
        config.cameras = [config.cameras]
    # Get maximum length and expand other arguments to the same length
    n = max(len(config.split), len(config.cameras), len(config.depth_type))
    config.dataset = make_list(config.dataset, n)
    config.path = make_list(config.path, n)
    config.split = make_list(config.split, n)
    config.depth_type = make_list(config.depth_type, n)
    config.cameras = make_list(config.cameras, n)
    if 'repeat' in config:
        config.repeat = make_list(config.repeat, n)
    # Return updated configuration
    return config
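What the expansion amounts to, shown with illustrative plain values rather than a real CfgNode: a single entry is broadcast to the length of the longest of split/cameras/depth_type, so per-split settings stay index-aligned. This uses the assumed make_list sketched under Example #1.

# Illustrative values only; real configs come from the dataset CfgNode.
split = ['train', 'val', 'test']                  # longest list, n = 3
dataset = make_list('KITTI', len(split))          # ['KITTI', 'KITTI', 'KITTI']
depth_type = make_list('velodyne', len(split))    # ['velodyne', 'velodyne', 'velodyne']
cameras = make_list([['camera_01']], len(split))  # [['camera_01'], ['camera_01'], ['camera_01']]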
Example #4
    def compute_inv_depths(self, image):
        """Computes inverse depth maps from single images"""
        # Randomly flip and estimate inverse depth maps
        flip_lr = random.random() < self.flip_lr_prob if self.training else False
        inv_depths = make_list(flip_model(self.depth_net, image, flip_lr))
        # If upsampling depth maps
        if self.upsample_depth_maps:
            inv_depths = interpolate_scales(
                inv_depths, mode='nearest', align_corners=None)
        # Return inverse depth maps
        return inv_depths
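flip_model and interpolate_scales are repository helpers not shown here. A plausible sketch of the flip augmentation flip_model stands for (run the network on a horizontally flipped image and flip the predictions back), assuming 4D BCHW tensors; this is an assumption, not the repository's implementation:

import torch

def flip_lr_tensor(x):
    """Horizontally flip a BCHW tensor (assumed layout)."""
    return torch.flip(x, dims=[3])

def flip_model(model, image, flip):
    """Sketch: if flip is True, feed a flipped image and un-flip the outputs."""
    if not flip:
        return model(image)
    outputs = model(flip_lr_tensor(image))
    if isinstance(outputs, (list, tuple)):
        return [flip_lr_tensor(out) for out in outputs]
    return flip_lr_tensor(outputs)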
Example #5
    def __getitem__(self, idx):
        """Get a dataset sample"""
        # Get DGP sample (if single sensor, make it a list)
        self.sample_dgp = self.dataset[idx]
        self.sample_dgp = [make_list(sample) for sample in self.sample_dgp]

        # Loop over all cameras
        sample = []
        for i in range(self.num_cameras):
            data = {
                'idx': idx,
                'dataset_idx': self.dataset_idx,
                'sensor_name': self.get_current('datum_name', i),
                #
                'filename': self.get_filename(idx, i),
                'splitname': '%s_%010d' % (self.split, idx),
                #
                'rgb': self.get_current('rgb', i),
                'intrinsics': self.get_current('intrinsics', i),
            }

            # If depth is returned
            if self.with_depth:
                data.update({
                    'depth': self.generate_depth_map(idx, i, data['filename'])
                })

            # If pose is returned
            if self.with_pose:
                data.update({
                    'extrinsics': self.get_current('extrinsics', i).matrix,
                    'pose': self.get_current('pose', i).matrix,
                })

            # If context is returned
            if self.has_context:
                data.update({
                    'rgb_context': self.get_context('rgb', i),
                })
                # If context pose is returned
                if self.with_pose:
                    # Get original values to calculate relative motion
                    orig_extrinsics = Pose.from_matrix(data['extrinsics'])
                    orig_pose = Pose.from_matrix(data['pose'])
                    data.update({
                        'extrinsics_context':
                            [(orig_extrinsics.inverse() * extrinsics).matrix
                             for extrinsics in self.get_context('extrinsics', i)],
                        'pose_context':
                            [(orig_pose.inverse() * pose).matrix
                             for pose in self.get_context('pose', i)],
                    })

            sample.append(data)

        # Apply same data transformations for all sensors
        if self.data_transform:
            sample = [self.data_transform(smp) for smp in sample]

        # Return sample (stacked if necessary)
        return stack_sample(sample)
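stack_sample and the get_current/get_context accessors are dataset helpers not shown on this page. A minimal sketch of what stack_sample plausibly does with the per-camera dictionaries (return a single-camera sample unchanged, otherwise group values key by key); grouping into plain lists is an assumption:

def stack_sample(sample):
    """Sketch: merge a list of per-camera dicts into one sample."""
    if len(sample) == 1:
        # Single camera: nothing to merge
        return sample[0]
    stacked = {}
    for key in sample[0]:
        stacked[key] = [smp[key] for smp in sample]
    return stacked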