def interpolate_image(image, shape, mode='bilinear', align_corners=True):
    """
    Interpolate an image to a different resolution

    Parameters
    ----------
    image : torch.Tensor [B,?,h,w]
        Image to be interpolated
    shape : tuple (H, W)
        Output shape. If more than two dimensions are given (e.g. the full
        shape of another tensor), only the last two are used.
    mode : str
        Interpolation mode, forwarded to ``funct.interpolate``
    align_corners : bool
        True if corners will be aligned after interpolation. Only forwarded
        for the interpolating modes ('linear', 'bilinear', 'bicubic',
        'trilinear'); it is ignored for any other mode (e.g. 'nearest',
        'area'), for which PyTorch rejects a non-None value.

    Returns
    -------
    image : torch.Tensor [B,?,H,W]
        Interpolated image, or the input image object itself (not a copy)
        if it already has the requested spatial shape
    """
    # Take last two dimensions as shape
    if len(shape) > 2:
        shape = shape[-2:]
    # If the shapes are the same, do nothing
    if same_shape(image.shape[-2:], shape):
        return image
    # PyTorch raises ValueError if align_corners is set for a
    # non-interpolating mode (e.g. 'nearest'), so only pass it when valid
    if mode not in ('linear', 'bilinear', 'bicubic', 'trilinear'):
        align_corners = None
    # Interpolate image to match the shape
    return funct.interpolate(image,
                             size=shape,
                             mode=mode,
                             align_corners=align_corners)
# Example 2
# 0
def load_network(network, path, prefixes=''):
    """
    Loads a pretrained network

    Every key of the saved state dict is checked against each prefix in
    turn. If ``prefix + '.'`` occurs in the key, everything up to and
    including its first occurrence is stripped. The remaining key is loaded
    if the network has an entry with that name and the same shape. The
    first prefix that yields such a compatible key wins, and every prefix is
    tested against the original saved key.

    Parameters
    ----------
    network : nn.Module
        Network that will receive the pretrained weights
    path : str or dict
        File containing a 'state_dict' key with pretrained network weights,
        or an already-loaded state dict
    prefixes : str or list of str
        Layer name prefixes to consider when loading the network

    Returns
    -------
    network : nn.Module
        Updated network with pretrained weights (loaded with strict=False,
        so network entries without a match keep their current values)
    """
    prefixes = make_list(prefixes)
    # If path is a string, load the checkpoint from disk
    if is_str(path):
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints (consider weights_only=True on recent PyTorch versions)
        saved_state_dict = torch.load(path, map_location='cpu')['state_dict']
        if path.endswith('.pth.tar'):
            saved_state_dict = backwards_state_dict(saved_state_dict)
    # If state dict is already provided
    else:
        saved_state_dict = path
    # Get network state dict
    network_state_dict = network.state_dict()

    updated_state_dict = OrderedDict()
    for saved_key, val in saved_state_dict.items():
        for prefix in prefixes:
            prefix = prefix + '.'
            if prefix not in saved_key:
                continue
            # Always strip from the original key, so that a previous
            # prefix's stripping does not leak into this prefix's check
            key = saved_key[saved_key.find(prefix) + len(prefix):]
            if key in network_state_dict and \
                    same_shape(val.shape, network_state_dict[key].shape):
                updated_state_dict[key] = val
                # First compatible prefix wins; don't re-match this tensor
                break

    network.load_state_dict(updated_state_dict, strict=False)
    # Count distinct network entries that received weights, so several saved
    # keys mapping to the same network key cannot inflate the count
    n, n_total = len(updated_state_dict), len(network_state_dict)
    base_color, attrs = 'cyan', ['bold', 'dark']
    color = 'green' if n == n_total else 'yellow' if n > 0 else 'red'
    # Guard against an empty prefix list when labelling the log message
    label = prefixes[0] if prefixes else ''
    print0(
        pcolor('###### Pretrained {} loaded:'.format(label),
               base_color,
               attrs=attrs) +
        pcolor(' {}/{} '.format(n, n_total), color, attrs=attrs) +
        pcolor('tensors', base_color, attrs=attrs))
    return network
def match_scales(image,
                 targets,
                 num_scales,
                 mode='bilinear',
                 align_corners=True):
    """
    Interpolate one image to produce a list of images with the same shape as targets

    Parameters
    ----------
    image : torch.Tensor [B,?,h,w]
        Input image
    targets : list of torch.Tensor [B,?,?,?]
        Tensors with the target resolutions; must contain at least
        ``num_scales`` elements (only the first ``num_scales`` are used)
    num_scales : int
        Number of considered scales
    mode : str
        Interpolation mode
    align_corners : bool
        True if corners will be aligned after interpolation

    Returns
    -------
    images : list of torch.Tensor [B,?,?,?]
        List of images with the same spatial resolutions as targets; entries
        whose resolution already matches are the input image object itself
    """
    # Compare spatial dimensions only: comparing the image's [H, W] against
    # a target's full [B, C, H, W] shape would never match
    image_shape = image.shape[-2:]
    images = []
    for i in range(num_scales):
        target_shape = targets[i].shape[-2:]
        if same_shape(image_shape, target_shape):
            # Already at the target resolution: reuse the image as-is
            images.append(image)
        else:
            # Otherwise, interpolate to the target resolution
            images.append(
                interpolate_image(image,
                                  target_shape,
                                  mode=mode,
                                  align_corners=align_corners))
    # Return scaled images
    return images