Esempio n. 1
0
def load_class(filename, paths, concat=True):
    """
    Look for a module in different locations and return its attribute with the same name
    Optionally, you can use concat to search in path.filename instead

    Parameters
    ----------
    filename : str
        Name of the module we are searching for (also the attribute that is returned)
    paths : str or list of str
        Dotted module paths in which the module will be searched
    concat : bool
        Flag to concatenate filename to each path during the search
        (if False, each path is used as the full module name)

    Returns
    -------
    method : Function
        Attribute named ``filename`` taken from the first module that is found

    Raises
    ------
    ValueError
        If none of the candidate modules can be found
    """
    # for each path in paths
    for path in make_list(paths):
        # Create full path
        full_path = '{}.{}'.format(path, filename) if concat else path
        try:
            spec = importlib.util.find_spec(full_path)
        except ModuleNotFoundError as exc:
            # find_spec imports the parent package of a dotted name and raises
            # when it is missing. That only means this candidate does not
            # exist, so keep searching the remaining paths.
            missing = exc.name or ''
            if missing and (full_path == missing or
                            full_path.startswith(missing + '.')):
                continue
            # The failure comes from an import inside an existing package:
            # that is a genuine error and must not be hidden.
            raise
        if spec:
            # Return method with same name as the file
            return getattr(importlib.import_module(full_path), filename)
    raise ValueError('Unknown class {}'.format(filename))
Esempio n. 2
0
def load_network(network, path, prefixes=''):
    """
    Loads a pretrained network

    Parameters
    ----------
    network : nn.Module
        Network that will receive the pretrained weights
    path : str or dict
        File containing a 'state_dict' key with pretrained network weights,
        or an already loaded state dict (anything that is not a string is
        used directly as the state dict)
    prefixes : str or list of str
        Layer name prefixes to consider when loading the network

    Returns
    -------
    network : nn.Module
        Updated network with pretrained weights
    """
    prefixes = make_list(prefixes)
    # If path is a string
    if is_str(path):
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources.
        saved_state_dict = torch.load(path, map_location='cpu')['state_dict']
        if path.endswith('.pth.tar'):
            saved_state_dict = backwards_state_dict(saved_state_dict)
    # If state dict is already provided
    else:
        saved_state_dict = path
    # Get network state dict
    network_state_dict = network.state_dict()

    # Build the "prefix." markers once instead of once per saved tensor
    dotted_prefixes = [prefix + '.' for prefix in prefixes]

    updated_state_dict = OrderedDict()
    for key, val in saved_state_dict.items():
        for prefix in dotted_prefixes:
            idx = key.find(prefix)
            if idx < 0:
                continue
            # Strip everything up to (and including) the prefix. The result is
            # kept in its own variable so that the remaining prefixes are
            # still matched against the original saved key.
            name = key[idx + len(prefix):]
            if name in network_state_dict and \
                    same_shape(val.shape, network_state_dict[name].shape):
                updated_state_dict[name] = val
    # Count unique loaded tensors so the report can never exceed the total
    n, n_total = len(updated_state_dict), len(network_state_dict)

    try:
        network.load_state_dict(updated_state_dict, strict=True)
    except Exception as e:
        # Deliberate best effort: report the mismatch and load what matches
        print(e)
        network.load_state_dict(updated_state_dict, strict=False)
    base_color, attrs = 'cyan', ['bold', 'dark']
    color = 'green' if n == n_total else 'yellow' if n > 0 else 'red'
    # Guard against an empty prefix list when building the report
    first_prefix = prefixes[0] if prefixes else ''
    print0(pcolor('=====###### Pretrained {} loaded:'.format(first_prefix), base_color, attrs=attrs) +
          pcolor(' {}/{} '.format(n, n_total), color, attrs=attrs) +
          pcolor('tensors', base_color, attrs=attrs))
    return network
Esempio n. 3
0
 def compute_inv_depths(self, image):
     """
     Run the depth network on an image and return its outputs as a list.

     While training, a left-right flip flag is drawn at random (true with
     probability ``self.flip_lr_prob``) and forwarded to ``flip_model``;
     outside training the flag is always False. When
     ``self.upsample_depth_maps`` is set, the outputs go through
     ``interpolate_scales`` with nearest-neighbour mode before being returned.
     """
     # The random draw only happens in training mode
     flip_lr = False
     if self.training:
         flip_lr = random.random() < self.flip_lr_prob
     predictions = make_list(flip_model(self.depth_net, image, flip_lr))
     # Nothing else to do unless upsampling was requested
     if not self.upsample_depth_maps:
         return predictions
     return interpolate_scales(
         predictions, mode='nearest', align_corners=None)
Esempio n. 4
0
 def compute_inv_depths(self, image, ref_imgs, intrinsics):
     """
     Run the depth network on an image plus reference images and return
     the resulting inverse depth maps (as a list) together with the poses.

     While training, a left-right flip flag is drawn at random (true with
     probability ``self.flip_lr_prob``). When it is set, the intrinsics are
     flipped with ``flip_lr_intr`` before the network call and every
     returned inverse depth map is flipped with ``flip_lr_img`` afterwards.
     When ``self.upsample_depth_maps`` is set, the maps go through
     ``interpolate_scales`` with nearest-neighbour mode.
     """
     # The random draw only happens in training mode
     flip_lr = False
     if self.training:
         flip_lr = random.random() < self.flip_lr_prob
     # Intrinsics must match the flipped image (width is image.shape[3])
     if flip_lr:
         intrinsics = flip_lr_intr(intrinsics, width=image.shape[3])
     inv_depths, poses = flip_mf_model(
         self.depth_net, image, ref_imgs, intrinsics, flip_lr)
     inv_depths = make_list(inv_depths)
     # Flip every predicted map when the input was flipped
     if flip_lr:
         inv_depths = list(map(flip_lr_img, inv_depths))
     # Optional upsampling of the predicted maps
     if self.upsample_depth_maps:
         inv_depths = interpolate_scales(
             inv_depths, mode='nearest', align_corners=None)
     return inv_depths, poses
Esempio n. 5
0
    def __getitem__(self, idx):
        """
        Build the sample at position ``idx``.

        The raw dataset entry is stored in ``self.sample_dgp`` (each element
        wrapped with ``make_list``), then one dictionary is assembled per
        camera. Depth, pose and context entries are added according to
        ``self.with_depth``, ``self.with_pose`` and ``self.has_context``.
        Context poses and extrinsics are expressed relative to the current
        ones. Each per-camera dictionary goes through ``self.data_transform``
        when one is set, and the result is passed to ``stack_sample``.
        """
        # Keep the raw sample on the instance (every element as a list)
        raw_sample = self.dataset[idx]
        self.sample_dgp = [make_list(datum) for datum in raw_sample]

        per_camera = []
        for cam in range(self.num_cameras):
            # Entries that are always present
            entry = {}
            entry['idx'] = idx
            entry['dataset_idx'] = self.dataset_idx
            entry['sensor_name'] = self.get_current('datum_name', cam)
            entry['filename'] = self.get_filename(idx, cam)
            entry['splitname'] = '%s_%010d' % (self.split, idx)
            entry['rgb'] = self.get_current('rgb', cam)
            entry['intrinsics'] = self.get_current('intrinsics', cam)

            # Optional depth map
            if self.with_depth:
                entry['depth'] = self.generate_depth_map(
                    idx, cam, entry['filename'])

            # Optional extrinsics and pose of the current frame
            if self.with_pose:
                entry['extrinsics'] = \
                    self.get_current('extrinsics', cam).matrix
                entry['pose'] = self.get_current('pose', cam).matrix

            # Optional context frames
            if self.has_context:
                entry['rgb_context'] = self.get_context('rgb', cam)
                if self.with_pose:
                    # Current values are the reference for relative motion
                    ref_extrinsics = Pose.from_matrix(entry['extrinsics'])
                    ref_pose = Pose.from_matrix(entry['pose'])
                    relative_extrinsics = [
                        (ref_extrinsics.inverse() * ctx_extrinsics).matrix
                        for ctx_extrinsics
                        in self.get_context('extrinsics', cam)
                    ]
                    relative_poses = [
                        (ref_pose.inverse() * ctx_pose).matrix
                        for ctx_pose in self.get_context('pose', cam)
                    ]
                    entry['extrinsics_context'] = relative_extrinsics
                    entry['pose_context'] = relative_poses

            per_camera.append(entry)

        # The same transformation is applied to every camera
        if self.data_transform:
            per_camera = list(map(self.data_transform, per_camera))

        # Return sample (stacked if necessary)
        return stack_sample(per_camera)