def stack_batch(batch):
    """
    Stack multi-camera batches (B,N,C,H,W becomes BN,C,H,W)

    Parameters
    ----------
    batch : dict
        Batch dictionary, possibly containing multi-camera tensors of shape (B,N,C,H,W)

    Returns
    -------
    batch : dict
        Stacked batch
    """
    # If there is multi-camera information
    if len(batch['rgb'].shape) == 5:
        assert batch['rgb'].shape[0] == 1, \
            'Only batch size 1 is supported for multi-cameras'
        # Loop over all keys
        for key in batch.keys():
            # If list, keep only the first (and only) batch element of every item
            if is_list(batch[key]):
                if is_tensor(batch[key][0]) or is_numpy(batch[key][0]):
                    batch[key] = [sample[0] for sample in batch[key]]
            # Else, keep only the first (and only) batch element
            else:
                batch[key] = batch[key][0]
    return batch
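
# Usage sketch (illustrative, not from the original source), assuming is_list,
# is_tensor and is_numpy come from the same utilities module as stack_batch:
import torch

example_batch = {
    'rgb': torch.rand(1, 2, 3, 192, 320),    # (B=1, N=2, C, H, W)
    'intrinsics': torch.rand(1, 2, 3, 3),
    'idx': torch.tensor([0]),
}
example_batch = stack_batch(example_batch)
print(example_batch['rgb'].shape)            # torch.Size([2, 3, 192, 320])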
Example no. 2
def stack_sample(sample):
    """Stack a sample from multiple sensors"""
    # If there is only one sensor don't do anything
    if len(sample) == 1:
        return sample[0]

    # Otherwise, stack sample
    stacked_sample = {}
    for key in sample[0]:
        # Global keys (do not stack)
        if key in ['idx', 'dataset_idx', 'sensor_name', 'filename']:
            stacked_sample[key] = sample[0][key]
        else:
            # Stack torch tensors
            if is_tensor(sample[0][key]):
                stacked_sample[key] = torch.stack([s[key] for s in sample], 0)
            # Stack numpy arrays
            elif is_numpy(sample[0][key]):
                stacked_sample[key] = np.stack([s[key] for s in sample], 0)
            # Stack list
            elif is_list(sample[0][key]):
                stacked_sample[key] = []
                # Stack list of torch tensors
                if is_tensor(sample[0][key][0]):
                    for i in range(len(sample[0][key])):
                        stacked_sample[key].append(
                            torch.stack([s[key][i] for s in sample], 0))
                # Stack list of numpy arrays
                elif is_numpy(sample[0][key][0]):
                    for i in range(len(sample[0][key])):
                        stacked_sample[key].append(
                            np.stack([s[key][i] for s in sample], 0))

    # Return stacked sample
    return stacked_sample
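
# Usage sketch (illustrative): two per-sensor samples sharing the same keys;
# 'idx' is a global key and is copied from the first sensor rather than stacked.
import torch

example_samples = [
    {'idx': 0, 'rgb': torch.rand(3, 192, 320), 'intrinsics': torch.rand(3, 3)},
    {'idx': 0, 'rgb': torch.rand(3, 192, 320), 'intrinsics': torch.rand(3, 3)},
]
stacked = stack_sample(example_samples)
print(stacked['rgb'].shape)   # torch.Size([2, 3, 192, 320])
print(stacked['idx'])         # 0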
Example no. 3
def prep_dataset(config):
    """
    Expand dataset configuration to match split length

    Parameters
    ----------
    config : CfgNode
        Dataset configuration

    Returns
    -------
    config : CfgNode
        Updated dataset configuration
    """
    # If there is no dataset, do nothing
    if len(config.path) == 0:
        return config
    # If cameras is not a double list, make it so
    if not config.cameras or not is_list(config.cameras[0]):
        config.cameras = [config.cameras]
    # Get maximum length and expand other arguments to the same length
    n = max(len(config.split), len(config.cameras), len(config.depth_type))
    config.dataset = make_list(config.dataset, n)
    config.path = make_list(config.path, n)
    config.split = make_list(config.split, n)
    config.depth_type = make_list(config.depth_type, n)
    config.cameras = make_list(config.cameras, n)
    if 'repeat' in config:
        config.repeat = make_list(config.repeat, n)
    # Return updated configuration
    return config
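
# Configuration sketch (illustrative), assuming a yacs-style CfgNode and the
# make_list/is_list helpers from this module; dataset names and paths are placeholders.
from yacs.config import CfgNode

example_config = CfgNode()
example_config.dataset = ['KITTI']
example_config.path = ['/data/datasets/kitti']
example_config.split = ['train.txt', 'val.txt']   # two splits drive the expansion
example_config.cameras = [['camera_01']]
example_config.depth_type = ['velodyne']
example_config.repeat = 1

example_config = prep_dataset(example_config)
print(example_config.dataset)   # ['KITTI', 'KITTI'], expanded to match the two splits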
def flip(tensor, flip_fn):
    """
    Flip tensors or list of tensors based on a function

    Parameters
    ----------
    tensor : torch.Tensor or list[torch.Tensor] or list[list[torch.Tensor]]
        Tensor to be flipped
    flip_fn : Function
        Flip function

    Returns
    -------
    tensor : torch.Tensor or list[torch.Tensor] or list[list[torch.Tensor]]
        Flipped tensor or list of tensors
    """
    if not is_list(tensor):
        return flip_fn(tensor)
    else:
        if not is_list(tensor[0]):
            return [flip_fn(val) for val in tensor]
        else:
            return [[flip_fn(v) for v in val] for val in tensor]
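
# Usage sketch (illustrative): horizontal flip over the width dimension, applied
# to a single (B,C,H,W) tensor, a list of tensors, or a nested list of tensors.
import torch

hflip = lambda t: torch.flip(t, [3])   # flip the W dimension of a (B,C,H,W) tensor
single = torch.arange(6.).reshape(1, 1, 2, 3)

print(flip(single, hflip).shape)                # torch.Size([1, 1, 2, 3])
print(len(flip([single, single], hflip)))       # 2
print(len(flip([[single], [single]], hflip)))   # 2 (each inner list flipped element-wise)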
Example no. 5
    def __call__(self, progress):
        """
        Call for an update in the number of scales

        Parameters
        ----------
        progress : float
            Training progress percentage

        Returns
        -------
        num_scales : int
            New number of scales
        """
        if is_list(self.progressive_scaling):
            return int(self.num_scales -
                       np.searchsorted(self.progressive_scaling, progress))
        else:
            return self.num_scales
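
# Standalone sketch (illustrative) of the scale-reduction rule above, with
# hypothetical breakpoints standing in for self.progressive_scaling and self.num_scales.
import numpy as np

progressive_scaling = [0.25, 0.5, 0.75]   # training-progress breakpoints
num_scales = 4
for progress in [0.0, 0.3, 0.6, 0.9]:
    n = int(num_scales - np.searchsorted(progressive_scaling, progress))
    print(progress, n)   # 0.0 -> 4, 0.3 -> 3, 0.6 -> 2, 0.9 -> 1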
Example no. 6
def make_list(var, n=None):
    """
    Wraps the input into a list, and optionally repeats it to be size n

    Parameters
    ----------
    var : Any
        Variable to be wrapped in a list
    n : int
        How much the wrapped variable will be repeated

    Returns
    -------
    var_list : list
        List generated from var
    """
    var = var if is_list(var) else [var]
    if n is None:
        return var
    else:
        assert len(var) == 1 or len(var) == n, 'Wrong list length for make_list'
        return var * n if len(var) == 1 else var
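
# Usage sketch (illustrative):
print(make_list('train.txt'))      # ['train.txt']
print(make_list('train.txt', 3))   # ['train.txt', 'train.txt', 'train.txt']
print(make_list([0.1, 0.2], 2))    # [0.1, 0.2] (length already matches)
# make_list([0.1, 0.2], 3) would fail the 'Wrong list length for make_list' assertion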
Example no. 7
def log_inv_depth(key, prefix, batch, i=0):
    """
    Converts an inverse depth map from a batch for logging

    Parameters
    ----------
    key : str
        Key from data containing the inverse depth map
    prefix : str
        Prefix added to the key for logging
    batch : dict
        Dictionary containing the key
    i : int
        Batch index from which to get the inverse depth map

    Returns
    -------
    image : wandb.Image
        Wandb image ready for logging
    """
    inv_depth = batch[key] if is_dict(batch) else batch
    inv_depth = inv_depth[0] if is_list(inv_depth) else inv_depth
    return prep_image(prefix, key, viz_inv_depth(inv_depth[i]))
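
# Usage sketch (illustrative), assuming is_dict, is_list, viz_inv_depth and prep_image
# are available from the same logging utilities as log_inv_depth, and that viz_inv_depth
# accepts a [1,H,W] inverse depth map.
import torch

example_batch = {'inv_depth': [torch.rand(4, 1, 96, 160)]}   # list of scales
wandb_image = log_inv_depth('inv_depth', 'train-', example_batch, i=0)
# wandb_image is a wandb.Image built from the first scale of batch element 0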
    def forward(self,
                image,
                context,
                inv_depths,
                poses,
                path_to_ego_mask,
                path_to_ego_mask_context,
                K,
                ref_K,
                extrinsics,
                ref_extrinsics,
                context_type,
                return_logs=False,
                progress=0.0):
        """
        Calculates training photometric loss.

        Parameters
        ----------
        image : torch.Tensor [B,3,H,W]
            Original image
        context : list of torch.Tensor [B,3,H,W]
            Context containing a list of reference images
        inv_depths : list of torch.Tensor [B,1,H,W]
            Predicted inverse depth maps for the original image, in all scales
        poses : list of Pose
            Camera transformations between original and context images
        path_to_ego_mask : list of str
            Paths to the ego-mask .npy files for the original camera
        path_to_ego_mask_context : list
            Paths to the ego-mask .npy files for the context cameras
        K : torch.Tensor [B,3,3]
            Original camera intrinsics
        ref_K : torch.Tensor [B,N,3,3]
            Reference (context) camera intrinsics
        extrinsics : torch.Tensor
            Camera extrinsics for the original image
        ref_extrinsics : torch.Tensor
            Camera extrinsics for the reference (context) images
        context_type : list
            Context type information, forwarded to the reference image warping
        return_logs : bool
            True if logs are saved for visualization
        progress : float
            Training progress percentage

        Returns
        -------
        losses_and_metrics : dict
            Output dictionary
        """
        # If using progressive scaling
        self.n = self.progressive_scaling(progress)
        # Prepare per-scale loss containers and match the target image to all scales
        photometric_losses = [[] for _ in range(self.n)]
        images = match_scales(image, inv_depths, self.n)

        # Override the predicted depths with a constant depth of 2.0 (debug/testing)
        inv_depths2 = []
        for k in range(len(inv_depths)):
            depths2 = torch.zeros_like(inv_depths[k])
            depths2[0, 0] = 2.0  #(2/H)**4 * (2/W)**4 * i**2 * (H - i)**2 * j**2 * (W - j)**2 * 20
            inv_depths2.append(depth2inv(depths2))

        if is_list(context_type[0][0]):
            n_context = len(context_type[0])
        else:
            n_context = len(context_type)

        #n_context = len(context)
        device = image.device  # works for both CPU and CUDA tensors
        B = len(path_to_ego_mask)
        if is_list(path_to_ego_mask[0]):
            H_full, W_full = np.load(path_to_ego_mask[0][0]).shape
        else:
            H_full, W_full = np.load(path_to_ego_mask[0]).shape

        # getting ego masks for target and source cameras
        # fullsize mask
        ego_mask_tensor = torch.ones(B, 1, H_full, W_full).to(device)
        ref_ego_mask_tensor = []
        for i_context in range(n_context):
            ref_ego_mask_tensor.append(
                torch.ones(B, 1, H_full, W_full).to(device))
        for b in range(B):
            if self.mask_ego:
                if is_list(path_to_ego_mask[b]):
                    ego_mask_tensor[b, 0] = torch.from_numpy(
                        np.load(path_to_ego_mask[b][0])).float()
                else:
                    ego_mask_tensor[b, 0] = torch.from_numpy(
                        np.load(path_to_ego_mask[b])).float()
                for i_context in range(n_context):
                    if is_list(path_to_ego_mask_context[0][0]):
                        paths_context_ego = [
                            p[i_context][0] for p in path_to_ego_mask_context
                        ]
                    else:
                        paths_context_ego = path_to_ego_mask_context[i_context]
                    ref_ego_mask_tensor[i_context][b, 0] = torch.from_numpy(
                        np.load(paths_context_ego[b])).float()
        # resized masks
        ego_mask_tensors = []
        ref_ego_mask_tensors = []
        for i_context in range(n_context):
            ref_ego_mask_tensors.append([])
        for i in range(self.n):
            Btmp, C, H, W = images[i].shape
            if W < W_full:
                #inv_scale_factor = int(W_full / W)
                #print(W_full / W)
                #ego_mask_tensors.append(-nn.MaxPool2d(inv_scale_factor, inv_scale_factor)(-ego_mask_tensor))
                ego_mask_tensors.append(
                    interpolate_image(ego_mask_tensor,
                                      shape=(Btmp, 1, H, W),
                                      mode='nearest',
                                      align_corners=None))
                for i_context in range(n_context):
                    #ref_ego_mask_tensors[i_context].append(-nn.MaxPool2d(inv_scale_factor, inv_scale_factor)(-ref_ego_mask_tensor[i_context]))
                    ref_ego_mask_tensors[i_context].append(
                        interpolate_image(ref_ego_mask_tensor[i_context],
                                          shape=(Btmp, 1, H, W),
                                          mode='nearest',
                                          align_corners=None))
            else:
                ego_mask_tensors.append(ego_mask_tensor)
                for i_context in range(n_context):
                    ref_ego_mask_tensors[i_context].append(
                        ref_ego_mask_tensor[i_context])
        for i_context in range(n_context):
            Btmp, C, H, W = context[i_context].shape  # avoid reusing a stale Btmp below
            if W < W_full:
                inv_scale_factor = int(W_full / W)
                #ref_ego_mask_tensor[i_context] = -nn.MaxPool2d(inv_scale_factor, inv_scale_factor)(-ref_ego_mask_tensor[i_context])
                ref_ego_mask_tensor[i_context] = interpolate_image(
                    ref_ego_mask_tensor[i_context],
                    shape=(Btmp, 1, H, W),
                    mode='nearest',
                    align_corners=None)

        print(ref_extrinsics)
        for j, (ref_image, pose) in enumerate(zip(context, poses)):
            print(ref_extrinsics[j])
            ref_context_type = [c[j][0] for c in context_type] if is_list(
                context_type[0][0]) else context_type[j]
            print(ref_context_type)
            print(pose.mat)
            # Calculate warped images
            ref_warped, ref_ego_mask_tensors_warped = self.warp_ref_image(
                inv_depths2, ref_image, ref_ego_mask_tensor[j], K,
                ref_K[:, j, :, :], pose, ref_extrinsics[j], ref_context_type)
            print(pose.mat)
            # Calculate and store image loss
            photometric_loss = self.calc_photometric_loss(ref_warped, images)

            tt = str(int(time.time() % 10000))

            def to_hwc_numpy(tensor):
                # Convert a (C, H, W) tensor into a (H, W, C) numpy array for saving
                return torch.transpose(tensor.unsqueeze(0).unsqueeze(4),
                                       1, 4).squeeze().detach().cpu().numpy()

            for i in range(self.n):
                B, C, H, W = images[i].shape
                for b in range(B):
                    debug_images = {
                        'orig_PIL_0': to_hwc_numpy(ref_image[b, :, :, :]),
                        'orig_PIL': to_hwc_numpy(ref_image[b, :, :, :] *
                                                 ref_ego_mask_tensor[j][b, :, :, :]),
                        'warped_PIL_0': to_hwc_numpy(ref_warped[i][b, :, :, :]),
                        'warped_PIL': to_hwc_numpy(ref_warped[i][b, :, :, :] *
                                                   ego_mask_tensors[i][b, :, :, :]),
                        'target_PIL_0': to_hwc_numpy(images[i][b, :, :, :]),
                        'target_PIL': to_hwc_numpy(images[i][b, :, :, :] *
                                                   ego_mask_tensors[i][b, :, :, :]),
                    }
                    # Save debug images to disk (hardcoded output path kept from the original)
                    for name, img in debug_images.items():
                        cv2.imwrite(
                            '/home/users/vbelissen/test' + '_' + str(j) + '_' +
                            tt + '_' + str(b) + '_' + str(i) + '_' + name + '.png',
                            img * 255)

            for i in range(self.n):
                photometric_losses[i].append(photometric_loss[i] *
                                             ego_mask_tensors[i] *
                                             ref_ego_mask_tensors_warped[i])
            # If using automask
            if self.automask_loss:
                # Calculate and store unwarped image loss
                ref_images = match_scales(ref_image, inv_depths, self.n)
                unwarped_image_loss = self.calc_photometric_loss(
                    ref_images, images)
                for i in range(self.n):
                    photometric_losses[i].append(unwarped_image_loss[i] *
                                                 ego_mask_tensors[i] *
                                                 ref_ego_mask_tensors[j][i])
        # Calculate reduced photometric loss
        loss = self.nonzero_reduce_photometric_loss(photometric_losses)
        # Include smoothness loss if requested
        if self.smooth_loss_weight > 0.0:
            loss += self.calc_smoothness_loss(
                [a * b for a, b in zip(inv_depths, ego_mask_tensors)],
                [a * b for a, b in zip(images, ego_mask_tensors)])
        # Return losses and metrics
        return {
            'loss': loss.unsqueeze(0),
            'metrics': self.metrics,
        }