示例#1
0
def load_network(network, path, prefixes=''):
    """
    Copy matching pretrained weights into a network.

    Parameters
    ----------
    network : nn.Module
        Network that will receive the pretrained weights
    path : str or dict
        Checkpoint file containing a 'state_dict' key with pretrained
        weights, or an already loaded state dict
    prefixes : str or list of str
        Layer name prefixes to consider when loading the network

    Returns
    -------
    network : nn.Module
        Updated network with pretrained weights
    """
    prefixes = make_list(prefixes)
    if is_str(path):
        # Checkpoint stored on disk: read it on CPU and pull out the weights
        source_weights = torch.load(path, map_location='cpu')['state_dict']
        if path.endswith('.pth.tar'):
            source_weights = backwards_state_dict(source_weights)
    else:
        # Anything that is not a string is used directly as the state dict
        source_weights = path
    target_weights = network.state_dict()
    n_total = len(target_weights)

    matched = OrderedDict()
    n = 0
    for saved_name, tensor in source_weights.items():
        # NOTE: the stripped name carries over from one prefix to the next
        name = saved_name
        for prefix in prefixes:
            marker = prefix + '.'
            if marker not in name:
                continue
            # Keep only what follows the first occurrence of "<prefix>."
            name = name[name.find(marker) + len(marker):]
            if name not in target_weights:
                continue
            # Only tensors whose shape agrees with the network are copied
            if same_shape(tensor.shape, target_weights[name].shape):
                matched[name] = tensor
                n += 1

    # strict=False: tensors missing from `matched` keep their current values
    network.load_state_dict(matched, strict=False)

    base_color, attrs = 'cyan', ['bold', 'dark']
    if n == n_total:
        color = 'green'
    elif n > 0:
        color = 'yellow'
    else:
        color = 'red'
    header = pcolor('###### Pretrained {} loaded:'.format(prefixes[0]),
                    base_color, attrs=attrs)
    counts = pcolor(' {}/{} '.format(n, n_total), color, attrs=attrs)
    footer = pcolor('tensors', base_color, attrs=attrs)
    print0(header + counts + footer)
    return network
示例#2
0
 def prepare_model(self, resume=None):
     """Create self.model and, if a resume dict is given, restore its state."""
     print0(pcolor('### Preparing Model', 'green'))
     self.model = setup_model(self.config.model, self.config.prepared)
     # Nothing to restore: keep the freshly created model
     if not resume:
         return
     message = '### Resuming from {}'.format(resume['file'])
     print0(pcolor(message, 'magenta', attrs=['bold']))
     # Load the saved weights stored under the 'model' prefix
     self.model = load_network(self.model, resume['state_dict'], 'model')
     # Restore the epoch counter only when the checkpoint recorded one
     if 'epoch' in resume:
         self.current_epoch = resume['epoch']
示例#3
0
def setup_pose_net(config, prepared, **kwargs):
    """
    Create a pose network

    Parameters
    ----------
    config : CfgNode
        Network configuration
    prepared : bool
        True if the network has been prepared before
    kwargs : dict
        Extra parameters for the network

    Returns
    -------
    pose_net : nn.Module
        Created pose network
    """
    print0(pcolor('PoseNet: %s' % config.name, 'yellow'))
    pose_net = load_class_args_create(
        config.name,
        paths=[
            'outdoornet.networks.pose',
        ],
        args={
            **config,
            **kwargs
        },
    )
    # Compare by value: `is not ''` tests object identity, which is
    # implementation-dependent for strings and a SyntaxWarning on Python 3.8+
    if not prepared and config.checkpoint_path != '':
        pose_net = load_network(pose_net, config.checkpoint_path,
                                ['pose_net', 'pose_network'])
    return pose_net
示例#4
0
 def _sync_s3(self, filepath, model):
     # If it's not time to sync, do nothing
     if self.s3_enabled and (model.current_epoch +
                             1) % self.s3_frequency == 0:
         filepath = os.path.dirname(filepath)
         # Print message and links
         print(
             pcolor('###### Syncing: {} -> {}'.format(
                 filepath, model.config.checkpoint.s3_path),
                    'red',
                    attrs=['bold']))
         print(
             pcolor('###### URL: {}'.format(model.config.checkpoint.s3_url),
                    'red',
                    attrs=['bold']))
         # If it's time to save code
         if self.save_code:
             self.save_code = False
             save_code(filepath)
         # Sync model to s3
         sync_s3_data(filepath, model)
示例#5
0
    def prepare_datasets(self, validation_requirements, test_requirements):
        """Create the train, validation and test datasets from self.config."""
        print0(pcolor('### Preparing Datasets', 'green'))

        # The same augmentation settings are forwarded to every split
        augmentation = self.config.datasets.augmentation

        # Training requirements come from the model itself
        train_config = self.config.datasets.train
        self.train_dataset = setup_dataset(
            train_config, 'train', self.model.train_requirements,
            **augmentation)

        # Validation and test requirements are supplied by the caller
        validation_config = self.config.datasets.validation
        self.validation_dataset = setup_dataset(
            validation_config, 'validation', validation_requirements,
            **augmentation)

        test_config = self.config.datasets.test
        self.test_dataset = setup_dataset(
            test_config, 'test', test_requirements, **augmentation)
示例#6
0
def setup_model(config, prepared, **kwargs):
    """
    Create a model

    Parameters
    ----------
    config : CfgNode
        Model configuration (cf. configs/default_config.py)
    prepared : bool
        True if the model has been prepared before
    kwargs : dict
        Extra parameters for the model

    Returns
    -------
    model : nn.Module
        Created model
    """
    print0(pcolor('Model: %s' % config.name, 'yellow'))
    # Instantiate the model class with the loss settings plus extra kwargs
    model_class = load_class(config.name, paths=[
        'outdoornet.model',
    ])
    model = model_class(**{
        **config.loss,
        **kwargs
    })
    # Add depth network if required
    if model.network_requirements['depth_net']:
        model.add_depth_net(setup_depth_net(config.depth_net, prepared))
    # Add pose network if required
    if model.network_requirements['pose_net']:
        model.add_pose_net(setup_pose_net(config.pose_net, prepared))
    # If a checkpoint is provided, load pretrained model.
    # Compare by value: `is not ''` tests object identity, which is
    # implementation-dependent for strings and a SyntaxWarning on Python 3.8+
    if not prepared and config.checkpoint_path != '':
        model = load_network(model, config.checkpoint_path, 'model')
    return model
示例#7
0
def setup_dataset(config, mode, requirements, **kwargs):
    """
    Create the dataset objects for one mode

    Parameters
    ----------
    config : CfgNode
        Configuration (cf. configs/default_config.py)
    mode : str {'train', 'validation', 'test'}
        Mode from which we want the dataset
    requirements : dict (string -> bool)
        Different requirements for dataset loading (gt_depth, gt_pose, etc)
    kwargs : dict
        Extra parameters for dataset creation (forwarded to get_transforms)

    Returns
    -------
    datasets : list of Dataset or None
        One dataset per configured split; for mode 'train' they are merged
        into a single ConcatDataset wrapped in a list. None if no dataset
        path is configured.

    Raises
    ------
    ValueError
        If a configured dataset name is not one of 'KITTI', 'DGP' or 'Image'
    """
    # If no dataset is given, return None
    if len(config.path) == 0:
        return None

    print0(pcolor('###### Setup %s dataandgeo' % mode, 'red'))

    # Global shared dataset arguments
    dataset_args = {
        'back_context': config.back_context,
        'forward_context': config.forward_context,
        'data_transform': get_transforms(mode, **kwargs)
    }

    # Loop over all configured splits
    datasets = []
    for i, split in enumerate(config.split):
        path_split = os.path.join(config.path[i], split)

        # Individual shared dataset arguments
        dataset_args_i = {
            'depth_type':
            config.depth_type[i] if requirements['gt_depth'] else None,
            'with_pose': requirements['gt_pose'],
        }

        dataset_name = config.dataset[i]
        # KITTI dataset
        if dataset_name == 'KITTI':
            from outdoornet.dataandgeo.kitti_dataset import KITTIDataset
            dataset = KITTIDataset(
                config.path[i],
                path_split,
                **dataset_args,
                **dataset_args_i,
            )
        # DGP dataset
        elif dataset_name == 'DGP':
            from outdoornet.dataandgeo.dgp_dataset import DGPDataset
            dataset = DGPDataset(
                config.path[i],
                split,
                **dataset_args,
                **dataset_args_i,
                cameras=config.cameras[i],
            )
        # Image dataset
        elif dataset_name == 'Image':
            from outdoornet.dataandgeo.image_dataset import ImageDataset
            dataset = ImageDataset(
                config.path[i],
                split,
                **dataset_args,
                **dataset_args_i,
            )
        else:
            # The exception was previously built but never raised, and used
            # %d on a string name; raise it with a string format instead.
            raise ValueError('Unknown dataset %s' % dataset_name)

        # Repeat if needed
        if 'repeat' in config and config.repeat[i] > 1:
            dataset = ConcatDataset([dataset for _ in range(config.repeat[i])])
        datasets.append(dataset)

        # Display dataset information
        bar = '######### {:>7}'.format(len(dataset))
        if 'repeat' in config:
            bar += ' (x{})'.format(config.repeat[i])
        bar += ': {:<}'.format(path_split)
        print0(pcolor(bar, 'yellow'))

    # If training, concatenate all datasets into a single one
    if mode == 'train':
        datasets = [ConcatDataset(datasets)]

    return datasets
示例#8
0
    def print_metrics(self, metrics_data, dataset):
        """Print a table of metrics, one section per entry of metrics_data."""
        # Nothing to show when the first entry is empty
        if not metrics_data[0]:
            return

        separator = '|{:<}|'.format('*' * 93)
        header_fmt = '| {:^14} | {:^8} | {:^8} | {:^8} | {:^8} | {:^8} | {:^8} | {:^8} |'
        values_fmt = '{:<14} | {:^8.3f} | {:^8.3f} | {:^8.3f} | {:^8.3f} | {:^8.3f} | {:^8.3f} | {:^8.3f}'

        def boxed(text):
            # Surround a table row with the vertical borders
            return '| {} |'.format(text)

        for _ in range(3):
            print()
        print(separator)

        # Training summary (epoch, batch size, model, learning rates)
        if self.optimizer is not None:
            summary = 'E: {} BS: {}'.format(
                self.current_epoch + 1,
                self.config.datasets.train.batch_size)
            if self.model is not None:
                summary += ' - {}'.format(self.config.model.name)
            rates = 'LR ({}):'.format(self.config.model.optimizer.name) + \
                ''.join(' {} {:.2e}'.format(group['name'], group['lr'])
                        for group in self.optimizer.param_groups)
            print(boxed(pcolor('{:<40}{:>51}'.format(summary, rates),
                               'green', attrs=['bold', 'dark'])))
            print(separator)

        print(header_fmt.format('METRIC', *self.metrics_keys))
        for idx, metrics in enumerate(metrics_data):
            print(separator)
            title = '{}'.format(
                os.path.join(dataset.path[idx], dataset.split[idx]))
            cameras = dataset.cameras[idx]
            # The camera name is appended only for single-camera datasets
            if len(cameras) == 1:
                title += ' ({})'.format(cameras[0])
            print(boxed(pcolor('*** {:<87}'.format(title),
                               'magenta', attrs=['bold'])))
            print(separator)
            for key, metric in metrics.items():
                # Only rows whose key contains the metric name are shown
                if self.metrics_name not in key:
                    continue
                row = values_fmt.format(key.upper(), *metric.tolist())
                print(boxed(pcolor(row, 'cyan')))
        print(separator)

        # Run information from the wandb config, shown when a logger is set
        if self.logger:
            run_info = '{:<60}{:>31}'.format(self.config.wandb.url,
                                             self.config.wandb.name)
            print(boxed(pcolor(run_info, 'yellow', attrs=['dark'])))
            print(separator)

        print()
示例#9
0
def infer_and_save_depth(input_file, output_file, model_wrapper, image_shape,
                         half, save):
    """
    Process a single input file to produce and save depth or a visualization

    Parameters
    ----------
    input_file : str
        Image file
    output_file : str
        Output file, or folder where the output will be saved
    model_wrapper : nn.Module
        Model wrapper used for inference
    image_shape : Image shape
        Input image shape
    half: bool
        use half precision (fp16)
    save: str
        Save format ('npz' or 'png'); any other value saves an image with
        the RGB input stacked above the inverse depth visualization
    """
    if not is_image(output_file):
        # If not an image, assume it's a folder and append the input name
        os.makedirs(output_file, exist_ok=True)
        output_file = os.path.join(output_file, os.path.basename(input_file))

    # change to half precision for evaluation if requested
    dtype = torch.float16 if half else None

    # Load image
    image = load_image(input_file)
    # Resize and to tensor
    image = resize_image(image, image_shape)
    image = to_tensor(image).unsqueeze(0)

    # Send image to GPU if available
    if torch.cuda.is_available():
        image = image.to('cuda:{}'.format(rank()), dtype=dtype)

    # Depth inference (returns predicted inverse depth)
    pred_inv_depth = model_wrapper.depth(image)[0]

    if save in ('npz', 'png'):
        # Get depth from predicted depth map and save to different formats
        filename = '{}.{}'.format(os.path.splitext(output_file)[0], save)
        print('Saving {} to {}'.format(
            pcolor(input_file, 'cyan', attrs=['bold']),
            pcolor(filename, 'magenta', attrs=['bold'])))
        write_depth(filename, depth=inv2depth(pred_inv_depth))
    else:
        # Prepare RGB image
        rgb = image[0].permute(1, 2, 0).detach().cpu().numpy() * 255
        # Prepare inverse depth
        viz_pred_inv_depth = viz_inv_depth(pred_inv_depth[0]) * 255

        # Concatenate both vertically
        image = np.concatenate([rgb, viz_pred_inv_depth], 0)

        # Save visualization (channel order reversed for cv2.imwrite)
        print('Saving {} to {}'.format(
            pcolor(input_file, 'cyan', attrs=['bold']),
            pcolor(output_file, 'magenta', attrs=['bold'])))
        cv2.imwrite(output_file, image[:, :, ::-1])