def interpolate_image(image, shape, mode='bilinear', align_corners=True):
    """
    Interpolate an image to a different resolution

    Parameters
    ----------
    image : torch.Tensor [B,?,h,w]
        Image to be interpolated
    shape : tuple (H, W)
        Output shape. If more than two dimensions are given, only the last
        two are used, so a full tensor shape (e.g. [B,?,H,W]) is accepted.
    mode : str
        Interpolation mode
    align_corners : bool
        True if corners will be aligned after interpolation. Only forwarded
        for 'linear', 'bilinear', 'bicubic' and 'trilinear'; ignored for
        every other mode (e.g. 'nearest', 'area').

    Returns
    -------
    image : torch.Tensor [B,?,H,W]
        Interpolated image. The input tensor itself is returned when its last
        two dimensions already match the requested shape.
    """
    # Take last two dimensions as shape
    if len(shape) > 2:
        shape = shape[-2:]
    # If the shapes are the same, do nothing
    if same_shape(image.shape[-2:], shape):
        return image
    # torch.nn.functional.interpolate raises a ValueError when align_corners
    # is set for non-interpolating modes (nearest, area, ...), so only
    # forward it for the modes where it has an effect.
    if mode not in ('linear', 'bilinear', 'bicubic', 'trilinear'):
        align_corners = None
    # Interpolate image to match the shape
    return funct.interpolate(image, size=shape, mode=mode,
                             align_corners=align_corners)
def load_network(network, path, prefixes=''):
    """
    Loads a pretrained network

    Parameters
    ----------
    network : nn.Module
        Network that will receive the pretrained weights
    path : str or dict
        File containing a 'state_dict' key with pretrained network weights,
        or an already-loaded state dict (used as-is)
    prefixes : str or list of str
        Layer name prefixes to consider when loading the network. The
        prefixes are applied one after another. For each non-empty prefix,
        everything in a saved key up to and including the first occurrence
        of "<prefix>." is removed before matching against the network.
        Empty prefixes are ignored.

    Returns
    -------
    network : nn.Module
        Updated network with pretrained weights. Only tensors whose
        (stripped) name exists in the network with the same shape are
        loaded; all others are skipped (non-strict loading).
    """
    prefixes = make_list(prefixes)
    # If path is a string, load the checkpoint from disk
    if is_str(path):
        # NOTE(security): torch.load unpickles the file and can execute
        # arbitrary code; only load checkpoints from trusted sources.
        saved_state_dict = torch.load(path, map_location='cpu')['state_dict']
        if path.endswith('.pth.tar'):
            # presumably converts legacy checkpoint keys -- see
            # backwards_state_dict
            saved_state_dict = backwards_state_dict(saved_state_dict)
    # If state dict is already provided
    else:
        saved_state_dict = path
    # Get network state dict
    network_state_dict = network.state_dict()
    updated_state_dict = OrderedDict()
    n_total = len(network_state_dict)
    for key, val in saved_state_dict.items():
        for prefix in prefixes:
            # An empty prefix would turn into '.', which strips the first
            # dotted component of every key ('conv1.weight' -> 'weight')
            if not prefix:
                continue
            prefix = prefix + '.'
            if prefix in key:
                key = key[key.find(prefix) + len(prefix):]
        # Keep only tensors whose name and shape match the network
        if key in network_state_dict and \
                same_shape(val.shape, network_state_dict[key].shape):
            updated_state_dict[key] = val
    network.load_state_dict(updated_state_dict, strict=False)
    # Count distinct network tensors that received weights; several saved
    # keys that collapse onto the same name must not be counted twice
    n = len(updated_state_dict)
    base_color, attrs = 'cyan', ['bold', 'dark']
    color = 'green' if n == n_total else 'yellow' if n > 0 else 'red'
    # Guard against an empty prefix list when building the log label
    label = prefixes[0] if prefixes else ''
    print0(
        pcolor('###### Pretrained {} loaded:'.format(label),
               base_color, attrs=attrs) +
        pcolor(' {}/{} '.format(n, n_total), color, attrs=attrs) +
        pcolor('tensors', base_color, attrs=attrs))
    return network
def match_scales(image, targets, num_scales,
                 mode='bilinear', align_corners=True):
    """
    Interpolate one image to produce a list of images with the same shape as targets

    Parameters
    ----------
    image : torch.Tensor [B,?,h,w]
        Input image
    targets : list of torch.Tensor [B,?,?,?]
        Tensors with the target resolutions; only the first `num_scales`
        are used
    num_scales : int
        Number of considered scales (must not exceed len(targets))
    mode : str
        Interpolation mode
    align_corners : bool
        True if corners will be aligned after interpolation

    Returns
    -------
    images : list of torch.Tensor [B,?,?,?]
        List of images with the same resolutions as targets. Entries whose
        target already matches the image resolution are `image` itself.
    """
    # interpolate_image reduces the target shape to its last two (spatial)
    # dimensions, compares them with image.shape[-2:], and returns `image`
    # unchanged when they match. Comparing the image's 2-D shape against the
    # full target shape here instead would compare mismatched axes.
    return [interpolate_image(image, targets[i].shape, mode=mode,
                              align_corners=align_corners)
            for i in range(num_scales)]