Example #1
    def conjugate_gradient(self,
                           A,
                           y,
                           x0,
                           solve_params=LinearSolve(),
                           callback=None):
        # y and x0 may each have batch size 1 or the full batch size; take the combined size
        bs_y = self.staticshape(y)[0]
        bs_x0 = self.staticshape(x0)[0]
        batch_size = combined_dim(bs_y, bs_x0)

        if isinstance(A, (tuple, list)) or self.ndims(A) == 3:  # A carries its own batch dimension
            batch_size = combined_dim(batch_size, self.staticshape(A)[0])

        results = []

        for batch in range(batch_size):
            # entries with batch size 1 are broadcast across the full batch
            y_ = y[min(batch, bs_y - 1)]
            x0_ = x0[min(batch, bs_x0 - 1)]
            # ret_val holds the solver's convergence flag; it is not checked here
            x, ret_val = cg(A,
                            y_,
                            x0_,
                            tol=solve_params.relative_tolerance,
                            atol=solve_params.absolute_tolerance,
                            maxiter=solve_params.max_iterations)

            results.append(x)
        # per-solve convergence info is not propagated and iterations are not counted, hence success=True / iterations=-1
        solve_params.result = SolveResult(success=True, iterations=-1)
        return self.stack(results)
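
For reference, here is a minimal standalone sketch of the same batched-CG pattern using SciPy directly. The helper name solve_batched and its parameters are illustrative, not part of the backend above.

import numpy as np
from scipy.sparse.linalg import cg

def solve_batched(A, y, x0, rel_tol=1e-5, abs_tol=1e-10, max_iter=1000):
    """Solve A x = y row by row, broadcasting a single initial guess if needed."""
    batch_size = max(y.shape[0], x0.shape[0])
    results = []
    for b in range(batch_size):
        y_ = y[min(b, y.shape[0] - 1)]
        x0_ = x0[min(b, x0.shape[0] - 1)]
        # note: the `tol` keyword was renamed to `rtol` in recent SciPy releases
        x, info = cg(A, y_, x0=x0_, tol=rel_tol, atol=abs_tol, maxiter=max_iter)
        results.append(x)
    return np.stack(results)

A = 2.0 * np.eye(4)
y = np.ones((3, 4))      # batch of three right-hand sides
x0 = np.zeros((1, 4))    # one initial guess, reused for every batch entry
print(solve_batched(A, y, x0))  # each row -> [0.5, 0.5, 0.5, 0.5]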
Example #2
    def batched_gather_nd(self, values, indices):
        # values: (batch, *spatial, channels), indices: (batch, ..., len(spatial))
        assert indices.shape[-1] == self.ndims(values) - 2
        batch_size = combined_dim(values.shape[0], indices.shape[0])
        results = []
        for b in range(batch_size):
            # entries with batch size 1 are broadcast across the full batch
            b_values = values[min(b, values.shape[0] - 1)]
            b_indices = self.unstack(indices[min(b, indices.shape[0] - 1)], -1)
            results.append(b_values[b_indices])  # advanced indexing: one index array per spatial dimension
        return jnp.stack(results)  # jnp = jax.numpy
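
A standalone sketch of the same gather pattern in plain JAX; the helper and the assumed shapes are illustrative, not the backend API above.

import jax.numpy as jnp

def batched_gather_nd(values, indices):
    # values: (batch, *spatial, channels); indices: (batch, n, len(spatial))
    batch_size = max(values.shape[0], indices.shape[0])
    results = []
    for b in range(batch_size):
        b_values = values[min(b, values.shape[0] - 1)]
        # split the last index axis into one index array per spatial dimension
        b_indices = tuple(jnp.moveaxis(indices[min(b, indices.shape[0] - 1)], -1, 0))
        results.append(b_values[b_indices])
    return jnp.stack(results)

values = jnp.arange(2 * 3 * 3 * 1).reshape(2, 3, 3, 1)
indices = jnp.array([[[0, 0], [2, 2]]])  # one set of (y, x) indices, shared by both batch entries
print(batched_gather_nd(values, indices).shape)  # (2, 2, 1)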
Example #3
    def batched_gather_nd(self, values, indices):
        # values: (batch, *spatial, channels), indices: (batch, ..., len(spatial))
        values = self.as_tensor(values)
        indices = self.as_tensor(indices).long()  # PyTorch requires integer (long) index tensors
        batch_size = combined_dim(values.shape[0], indices.shape[0])
        result = []
        for b in range(batch_size):
            # entries with batch size 1 are broadcast across the full batch
            b_indices = self.unstack(indices[min(b, indices.shape[0] - 1)], -1)
            result.append(values[(min(b, values.shape[0] - 1),) + b_indices])
        return self.stack(result, axis=0)
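
The equivalent standalone PyTorch sketch (illustrative helper, assumed shapes): prepending the batch index to the unbound spatial indices gathers one batch entry at a time.

import torch

def batched_gather_nd(values, indices):
    # values: (batch, *spatial, channels); indices: (batch, n, len(spatial))
    indices = indices.long()
    batch_size = max(values.shape[0], indices.shape[0])
    result = []
    for b in range(batch_size):
        b_indices = tuple(indices[min(b, indices.shape[0] - 1)].unbind(-1))
        result.append(values[(min(b, values.shape[0] - 1),) + b_indices])
    return torch.stack(result, dim=0)

values = torch.arange(2 * 4.).reshape(2, 4, 1)    # (batch=2, x=4, channels=1)
indices = torch.tensor([[[1], [3]], [[0], [2]]])  # separate x-indices per batch entry
print(batched_gather_nd(values, indices))         # gathers x=1,3 from entry 0 and x=0,2 from entry 1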
Example #4
    def grid_sample(self,
                    grid,
                    spatial_dims: tuple,
                    coordinates,
                    extrapolation='constant'):
        assert extrapolation in ('undefined', 'zeros', 'boundary', 'periodic',
                                 'symmetric', 'reflect'), extrapolation
        # map to the padding modes of torch.nn.functional.grid_sample;
        # unmapped modes return NotImplemented so a fallback can handle them
        extrapolation = {
            'undefined': 'zeros',
            'zeros': 'zeros',
            'boundary': 'border',
            'reflect': 'reflection'
        }.get(extrapolation, None)
        if extrapolation is None:
            return NotImplemented
        grid = channels_first(self.as_tensor(grid))
        coordinates = self.as_tensor(coordinates)
        if coordinates.shape[0] != grid.shape[0]:  # repeating yields wrong result
            return NotImplemented
        resolution = torch.tensor(self.staticshape(grid)[2:],
                                  dtype=coordinates.dtype,
                                  device=coordinates.device)
        # rescale pixel coordinates to [-1, 1] and flip to the (x, y, ...) component
        # order expected by grid_sample
        coordinates = 2 * coordinates / (resolution - 1) - 1
        coordinates = torch.flip(coordinates, dims=[-1])
        # broadcast batch dimensions of size 1 up to the combined batch size
        batch_size = combined_dim(coordinates.shape[0], grid.shape[0])
        if coordinates.shape[0] < batch_size:
            coordinates = coordinates.repeat(batch_size, *[1] * (len(coordinates.shape) - 1))
        if grid.shape[0] < batch_size:
            grid = grid.repeat(batch_size, *[1] * (len(grid.shape) - 1))
        result = torchf.grid_sample(
            grid,
            coordinates,
            mode='bilinear',
            padding_mode=extrapolation,
            align_corners=True
        )  # can cause segmentation violation if NaN or inf are present
        result = channels_last(result)
        return result
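
A standalone sketch of the coordinate handling above, assuming integer pixel coordinates and a single-channel 2D grid: coordinates are rescaled to [-1, 1] and flipped to (x, y) order before calling torch.nn.functional.grid_sample.

import torch
import torch.nn.functional as torchf

grid = torch.arange(16.).reshape(1, 1, 4, 4)     # (batch, channels, H, W)
coords = torch.tensor([[[[0., 0.], [3., 3.]]]])  # pixel coordinates, shape (batch=1, 1, 2, 2)
resolution = torch.tensor(grid.shape[2:], dtype=coords.dtype)
coords = 2 * coords / (resolution - 1) - 1       # rescale pixel indices to [-1, 1] (align_corners=True convention)
coords = torch.flip(coords, dims=[-1])           # grid_sample expects (x, y) order in the last dimension
sampled = torchf.grid_sample(grid, coords, mode='bilinear',
                             padding_mode='border', align_corners=True)
print(sampled)  # tensor([[[[ 0., 15.]]]]) - the two corners of the grid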
Example #5
    def conjugate_gradient(self,
                           A,
                           y,
                           x0,
                           solve_params=LinearSolve(),
                           callback=None):
        if callable(A):
            function = A
        else:
            A = self.as_tensor(A)
            A_shape = self.staticshape(A)
            assert len(A_shape) == 2, f"A must be a square matrix but got shape {A_shape}"
            assert A_shape[0] == A_shape[1], f"A must be a square matrix but got shape {A_shape}"

            def function(vec):
                return self.matmul(A, vec)

        y = self.to_float(y)
        x0 = self.to_float(x0)
        batch_size = combined_dim(x0.shape[0], y.shape[0])
        if x0.shape[0] < batch_size:
            x0 = x0.repeat([batch_size, 1])  # broadcast a single initial guess across the batch

        def cg_forward(y, x0, params: LinearSolve):
            # squared convergence threshold per batch entry:
            # max(relative_tolerance**2 * |y|**2, absolute_tolerance**2)
            tolerance_sq = self.maximum(
                params.relative_tolerance**2 * torch.sum(y**2, -1),
                params.absolute_tolerance**2)
            x = x0
            dx = residual = y - function(x)  # initial search direction = initial residual
            dy = function(dx)
            iterations = 0
            converged = True
            while self.all(self.sum(residual**2, -1) > tolerance_sq):
                if iterations == params.max_iterations:
                    converged = False
                    break
                iterations += 1
                dx_dy = self.sum(dx * dy, axis=-1, keepdims=True)
                step_size = self.divide_no_nan(
                    self.sum(dx * residual, axis=-1, keepdims=True), dx_dy)
                # out-of-place updates: `residual` aliases `dx` and `x` aliases `x0` on entry,
                # so in-place += / -= here would also overwrite those tensors
                x = x + step_size * dx
                residual = residual - step_size * dy
                dx = residual - self.divide_no_nan(
                    self.sum(residual * dy, axis=-1, keepdims=True) * dx, dx_dy)
                dy = function(dx)
            params.result = SolveResult(converged, iterations)
            return x

        class CGVariant(torch.autograd.Function):
            @staticmethod
            def forward(ctx, y):
                return cg_forward(y, x0, solve_params)

            @staticmethod
            def backward(ctx, dX):
                # adjoint of the (symmetric) linear solve: solve the same system again with the
                # incoming gradient as right-hand side, using the gradient_solve settings
                return cg_forward(dX, torch.zeros_like(x0),
                                  solve_params.gradient_solve)

        result = CGVariant.apply(y)
        return result
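
A minimal standalone sketch of the same implicit-gradient pattern (hypothetical names cg_solve / CGSolve, not the backend API above): the backward pass of the linear solve reuses CG, since for a symmetric matrix the adjoint of x = A⁻¹y applied to a cotangent dX is again A⁻¹dX.

import torch

def cg_solve(matmul, y, x0, max_iter=100, tol=1e-10):
    """Plain conjugate gradient for a symmetric positive-definite operator `matmul`."""
    x = x0.clone()
    r = y - matmul(x)        # residual
    p = r.clone()            # search direction
    rs = (r * r).sum(-1, keepdim=True)
    for _ in range(max_iter):
        if float(rs.max()) < tol:
            break
        Ap = matmul(p)
        alpha = rs / (p * Ap).sum(-1, keepdim=True)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = (r * r).sum(-1, keepdim=True)
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x

class CGSolve(torch.autograd.Function):
    @staticmethod
    def forward(ctx, y, A):
        ctx.A = A  # stored on ctx for simplicity; save_for_backward is the recommended API
        return cg_solve(lambda v: v @ A.T, y, torch.zeros_like(y))

    @staticmethod
    def backward(ctx, dX):
        # gradient w.r.t. y: solve the same symmetric system with dX as right-hand side
        return cg_solve(lambda v: v @ ctx.A.T, dX, torch.zeros_like(dX)), None

A = torch.tensor([[4., 1.], [1., 3.]])
y = torch.tensor([[1., 2.]], requires_grad=True)
x = CGSolve.apply(y, A)
x.sum().backward()
print(x, y.grad)  # solution of A x = y and the implicit gradient obtained via a second solve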