Example No. 1
 def scatter(self, base_grid, indices, values, mode: str):
     base_grid, values = self.auto_cast(base_grid, values)
     batch_size = combined_dim(combined_dim(indices.shape[0], values.shape[0]), base_grid.shape[0])
     spatial_dims = tuple(range(base_grid.ndim - 2))
     dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(1,),  # channel dim of updates (batch dim removed)
                                             inserted_window_dims=spatial_dims,  # spatial dims are addressed per point, so they form size-1 window dims absent from the updates
                                             scatter_dims_to_operand_dims=spatial_dims)  # spatial dims of base_grid (batch dim removed)
     scatter = jax.lax.scatter_add if mode == 'add' else jax.lax.scatter
     result = []
     for b in range(batch_size):
         b_grid = base_grid[b, ...]
         b_indices = indices[min(b, indices.shape[0] - 1), ...]
         b_values = values[min(b, values.shape[0] - 1), ...]
         result.append(scatter(b_grid, b_indices, b_values, dnums))
     return jnp.stack(result)
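For reference, a minimal standalone sketch (not part of the library) of the same per-example scatter-add, with hypothetical shapes; `jax.lax.scatter_add` accumulates values at duplicate indices:

 import jax
 import jax.numpy as jnp

 grid = jnp.zeros((4, 4, 1))                       # (spatial..., channels) for one batch entry
 idx = jnp.array([[0, 0], [1, 2], [1, 2]])         # (points, spatial_rank)
 vals = jnp.ones((3, 1))                           # (points, channels)
 dnums = jax.lax.ScatterDimensionNumbers(
     update_window_dims=(1,),                      # channel dim of vals forms the window
     inserted_window_dims=(0, 1),                  # spatial dims are addressed per point
     scatter_dims_to_operand_dims=(0, 1))          # index components map to spatial dims
 out = jax.lax.scatter_add(grid, idx, vals, dnums)
 # the duplicate index (1, 2) accumulates: out[1, 2, 0] == 2.0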
Example No. 2
 def batched_gather_nd(self, values, indices):
     values = self.as_tensor(values)
     indices = self.as_tensor(indices).long()
     batch_size = combined_dim(values.shape[0], indices.shape[0])
     result = []
     for b in range(batch_size):
         b_indices = self.unstack(indices[min(b, indices.shape[0] - 1)], -1)
         result.append(values[(min(b, values.shape[0] - 1),) + b_indices])
     return self.stack(result, axis=0)
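A minimal standalone sketch of the advanced-indexing step this loop relies on, with hypothetical shapes and without the backend helpers (`combined_dim`, `self.unstack`):

 import torch

 values = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4, 1)   # (batch, H, W, channels)
 indices = torch.tensor([[[0, 1], [2, 3]],
                         [[1, 0], [0, 0]]])                           # (batch, points, 2)
 b = 0
 b_indices = indices[b].unbind(-1)                 # one index tensor per spatial dim
 picked = values[(b,) + b_indices]                 # shape (points, channels)
 # picked[:, 0] == tensor([1., 11.])

Stacking these per-batch slices along axis 0 reproduces the batched result.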
Example No. 3
 def batched_gather_nd(self, values, indices):
     assert indices.shape[-1] == self.ndims(values) - 2
     batch_size = combined_dim(values.shape[0], indices.shape[0])
     results = []
     for b in range(batch_size):
         b_values = values[min(b, values.shape[0] - 1)]
         b_indices = self.unstack(indices[min(b, indices.shape[0] - 1)], -1)
         results.append(b_values[b_indices])
     return jnp.stack(results)
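The JAX variant uses the same trick; a compact standalone sketch with hypothetical shapes, assuming the last index dimension enumerates the spatial dims of `values`:

 import jax.numpy as jnp

 values = jnp.arange(24.0).reshape(2, 3, 4, 1)     # (batch, H, W, channels)
 indices = jnp.array([[[0, 1], [2, 3]],
                      [[1, 0], [0, 0]]])           # (batch, points, 2)
 gathered = jnp.stack([values[b][tuple(indices[b].T)] for b in range(values.shape[0])])
 # gathered.shape == (2, 2, 1)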
Example No. 4
 def scatter(self, base_grid, indices, values, mode: str):
     base_grid, values = self.auto_cast(base_grid, values)
     indices = self.as_tensor(indices)
     batch_size = combined_dim(combined_dim(indices.shape[0], values.shape[0]), base_grid.shape[0])
     scatter = torch.scatter_add if mode == 'add' else torch.scatter
     if indices.shape[0] < batch_size:
         indices = indices.repeat([batch_size] + [1] * (len(indices.shape)-1))
     if values.shape[0] < batch_size or values.shape[1] == 1:
         values = values.repeat([batch_size // values.shape[0], indices.shape[1] // values.shape[1]] + [1] * (len(values.shape)-2))
     if len(base_grid.shape) > 3:
         resolution = base_grid.shape[1:-1]
         ravel = [1]
         for i in range(1, len(resolution)):
             ravel.insert(0, ravel[0] * resolution[-i])
         ravel = self.to_int64(self.as_tensor(ravel, True))
         indices = torch.sum(indices * ravel, dim=-1, keepdim=True)
     base_grid_flat = torch.reshape(base_grid, [base_grid.shape[0], -1, base_grid.shape[-1]])
     indices = indices.long().repeat([1, 1, values.shape[-1]])
     result = scatter(base_grid_flat, dim=1, index=indices, src=values)
     return torch.reshape(result, base_grid.shape)
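A minimal standalone sketch of the flatten-then-scatter idea, with hypothetical shapes and without the backend helpers: multi-dimensional spatial indices are raveled into linear indices and scattered along the flattened spatial dimension.

 import torch

 grid = torch.zeros(1, 3, 4, 1)                                  # (batch, H, W, channels)
 indices = torch.tensor([[[0, 1], [2, 3], [2, 3]]])              # (batch, points, 2) as (row, col)
 values = torch.ones(1, 3, 1)                                    # (batch, points, channels)
 H, W, C = grid.shape[1:]
 linear = (indices[..., 0] * W + indices[..., 1]).unsqueeze(-1)  # ravel (row, col) -> row * W + col
 grid_flat = grid.reshape(1, H * W, C)
 out = torch.scatter_add(grid_flat, 1, linear.repeat(1, 1, C), values).reshape(grid.shape)
 # the duplicated point (2, 3) accumulates: out[0, 2, 3, 0] == 2.0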
Example No. 5
 def grid_sample(self, grid, coordinates, extrapolation: str):
     assert extrapolation in ('undefined', 'zeros', 'boundary', 'periodic', 'symmetric', 'reflect'), extrapolation
     if get_functional_derivative_order() > 1:
         return NotImplemented  # PyTorch's grid_sample operator does not define higher-order derivatives
     extrapolation = {'undefined': 'zeros', 'zeros': 'zeros', 'boundary': 'border', 'reflect': 'reflection'}.get(extrapolation, None)
     if extrapolation is None:
         return NotImplemented
     grid = channels_first(self.as_tensor(grid))
     coordinates = self.as_tensor(coordinates)
     if coordinates.shape[0] != grid.shape[0]:  # repeating yields wrong result
         return NotImplemented
     if coordinates.ndim != grid.ndim or coordinates.ndim not in (4, 5):
         return NotImplemented  # torchf.grid_sample cannot handle this case
     if coordinates.dtype.is_floating_point and not grid.dtype.is_complex and not grid.dtype.is_floating_point:
         grid = self.to_float(grid)
     resolution = torch.tensor(self.staticshape(grid)[2:], dtype=coordinates.dtype, device=coordinates.device)
     coordinates = 2 * coordinates / (resolution - 1) - 1
     coordinates = torch.flip(coordinates, dims=[-1])
     batch_size = combined_dim(coordinates.shape[0], grid.shape[0])
     coordinates = coordinates.repeat(batch_size, *[1] * (len(coordinates.shape)-1)) if coordinates.shape[0] < batch_size else coordinates
     grid = grid.repeat(batch_size, *[1] * (len(grid.shape)-1)) if grid.shape[0] < batch_size else grid
     result = torchf.grid_sample(grid, coordinates, mode='bilinear', padding_mode=extrapolation, align_corners=True)  # can cause segmentation violation if NaN or inf are present
     result = channels_last(result)
     return result
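A minimal standalone sketch of the coordinate handling, with hypothetical shapes and `torchf` assumed to be `torch.nn.functional`: index-space coordinates are rescaled to [-1, 1], flipped to (x, y) order, and interpolated with `torchf.grid_sample`.

 import torch
 import torch.nn.functional as torchf

 grid = torch.arange(16.0).reshape(1, 1, 4, 4)         # (batch, channels, H, W)
 coords = torch.tensor([[[[1.0, 2.0], [0.5, 0.5]]]])   # (batch, 1, points, 2) in (row, col) index space
 res = torch.tensor(grid.shape[2:], dtype=coords.dtype)
 coords = 2 * coords / (res - 1) - 1                   # rescale index space to [-1, 1]
 coords = torch.flip(coords, dims=[-1])                # grid_sample expects (x, y) order
 sampled = torchf.grid_sample(grid, coords, mode='bilinear', padding_mode='border', align_corners=True)
 # sampled[0, 0, 0] == tensor([6.0, 2.5])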