Example no. 1
    def forward(ctx, fenics_solver, *args):
        """Computes the output of a FEniCS model and saves a corresponding gradient tape

        Input:
            ctx: Context used for storing information for the backward pass
            fenics_solver (FEniCSSolver): FEniCSSolver to be executed during the forward pass
            args (tuple): tensor representation of the input to fenics_solver.solve

        Output:
            tensor representation of the output from fenics_solver.solve
        """
        # Check that the number of inputs arguments is correct
        n_args = len(args)
        expected_nargs = len(fenics_solver.fenics_input_templates())
        if n_args != expected_nargs:
            raise ValueError(f'Wrong number of arguments to {fenics_solver}.'
                             f' Expected {expected_nargs} got {n_args}.')

        # Check that each input argument has correct dimensions
        for i, (arg, template) in enumerate(
                zip(args, fenics_solver.numpy_input_templates())):
            if arg.shape != template.shape:
                raise ValueError(
                    f'Expected input shape {template.shape} for input'
                    f' {i} but got {arg.shape}.')

        # Check that the inputs are of double precision
        for i, arg in enumerate(args):
            if (isinstance(arg, np.ndarray) and arg.dtype != np.float64) or \
               (torch.is_tensor(arg) and arg.dtype != torch.float64):
                raise TypeError(f'All inputs must be type {torch.float64},'
                                f' but got {arg.dtype} for input {i}.')

        # Convert input tensors to corresponding FEniCS variables
        fenics_inputs = []
        for inp, template in zip(args, fenics_solver.fenics_input_templates()):
            if torch.is_tensor(inp):
                inp = inp.detach().numpy()
            fenics_inputs.append(numpy_fenics.numpy_to_fenics(inp, template))

        # Create tape associated with this forward pass
        tape = fenics_adjoint.Tape()
        fenics_adjoint.set_working_tape(tape)

        # Execute forward pass
        fenics_outputs = fenics_solver.solve(*fenics_inputs)

        # If single output
        if not isinstance(fenics_outputs, tuple):
            fenics_outputs = (fenics_outputs, )

        # Save variables to be used for backward pass
        ctx.tape = tape
        ctx.fenics_inputs = fenics_inputs
        ctx.fenics_outputs = fenics_outputs

        # Return tensor representation of outputs
        return tuple(
            torch.from_numpy(numpy_fenics.fenics_to_numpy(fenics_output))
            for fenics_output in fenics_outputs)

    def backward(ctx, *grad_outputs):
        """Computes the gradients of the outputs with respect to the inputs

        Input:
            ctx: Context used for storing information from the forward pass
            grad_outputs (tuple): gradients of the outputs from successive operations
        """
        # Convert gradient of output to a FEniCS variable
        adj_values = []
        for grad_output, fenics_output in zip(grad_outputs, ctx.fenics_outputs):
            adj_value = numpy_to_fenics(grad_output.numpy(), fenics_output)
            # Special case
            if isinstance(adj_value, Function):
                adj_value = adj_value.vector()
            adj_values.append(adj_value)

        # Skip first input since this corresponds to the FEniCSSolver
        needs_input_grad = ctx.needs_input_grad[1:]

        # Check which gradients need to be computed
        controls = []
        for needs_grad, fenics_input in zip(needs_input_grad, ctx.fenics_inputs):
            if needs_grad:
                controls.append(Control(fenics_input))

        # Compute and accumulate gradient for each output with respect to each input
        accumulated_grads = [None for _ in range(len(needs_input_grad))]
        for fenics_output, adj_value in zip(ctx.fenics_outputs, adj_values):
            fenics_grads = compute_gradient(fenics_output,
                                            controls,
                                            tape=ctx.tape,
                                            adj_value=adj_value)

            # Convert FEniCS gradients to tensor representation
            numpy_grads = []
            for fenics_grad in fenics_grads:
                if fenics_grad is None:
                    numpy_grads.append(None)
                else:
                    numpy_grads.append(torch.from_numpy(fenics_to_numpy(fenics_grad)))

            # Accumulate gradients
            i = 0
            for j, needs_grad in enumerate(needs_input_grad):
                if needs_grad:
                    if numpy_grads[i] is not None:
                        if accumulated_grads[j] is None:
                            accumulated_grads[j] = numpy_grads[i]
                        else:
                            accumulated_grads[j] += numpy_grads[i]
                    i += 1

        # Prepend None gradient corresponding to the FEniCSSolver input
        return tuple([None] + accumulated_grads)
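
The pair of methods above follows the interface of PyTorch's torch.autograd.Function: forward stashes whatever the adjoint computation will need on ctx, and backward reads it back together with ctx.needs_input_grad. The sketch below is an illustrative toy, not taken from the listings: the PDE solve is replaced by a plain matrix product y = A @ x and all names (ToySolverFunction, A, x) are made up, so it runs without FEniCS while keeping the same ctx-based forward/backward shape.

import torch


class ToySolverFunction(torch.autograd.Function):
    """Stand-in for the solver above: y = A @ x plays the role of the PDE solve."""

    @staticmethod
    def forward(ctx, A, x):
        # Save what the backward pass needs on ctx, as the FEniCS version saves its tape
        ctx.save_for_backward(A, x)
        return A @ x

    @staticmethod
    def backward(ctx, grad_output):
        A, x = ctx.saved_tensors
        grad_A = grad_x = None
        # Only compute the gradients that were requested, mirroring ctx.needs_input_grad above
        if ctx.needs_input_grad[0]:
            grad_A = grad_output.unsqueeze(-1) * x.unsqueeze(0)
        if ctx.needs_input_grad[1]:
            grad_x = A.t() @ grad_output
        return grad_A, grad_x


A = torch.randn(3, 2, dtype=torch.float64, requires_grad=True)
x = torch.randn(2, dtype=torch.float64, requires_grad=True)
ToySolverFunction.apply(A, x).sum().backward()
torch.autograd.gradcheck(ToySolverFunction.apply, (A, x))  # checks the hand-written backward

One difference from the toy: in the listing above, slot 0 of ctx.needs_input_grad corresponds to the solver object itself rather than a tensor, which is why the real backward skips it and prepends None to the returned gradients.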
Example no. 3
    def backward(ctx, *grad_outputs):
        """Computes the gradients of the output with respect to the input

        Input:
            ctx: Context used for storing information from the forward pass
            grad_output: gradient of the output from successive operations
        """
        # Convert gradient of output to a FEniCS variable
        adj_values = []
        for grad_output, fenics_output in zip(grad_outputs,
                                              ctx.fenics_outputs):
            adj_value = numpy_fenics.numpy_to_fenics(grad_output.numpy(),
                                                     fenics_output)
            # Special case
            if isinstance(adj_value,
                          (fenics.Function, fenics_adjoint.Function)):
                adj_value = adj_value.vector()
            adj_values.append(adj_value)

        # Check which gradients need to be computed
        controls = list(
            map(fenics_adjoint.Control,
                (c for g, c in zip(ctx.needs_input_grad[1:], ctx.fenics_inputs)
                 if g)))

        # Compute and accumulate gradient for each output with respect to each input
        accumulated_grads = [None] * len(controls)
        for fenics_output, adj_value in zip(ctx.fenics_outputs, adj_values):
            fenics_grads = fenics_adjoint.compute_gradient(fenics_output,
                                                           controls,
                                                           tape=ctx.tape,
                                                           adj_value=adj_value)

            # Convert FEniCS gradients to tensor representation
            numpy_grads = [
                g if g is None else torch.from_numpy(
                    numpy_fenics.fenics_to_numpy(g)) for g in fenics_grads
            ]
            for i, (acc_g, g) in enumerate(zip(accumulated_grads,
                                               numpy_grads)):
                if g is None:
                    continue
                if acc_g is None:
                    accumulated_grads[i] = g
                else:
                    accumulated_grads[i] += g

        # Insert None for not computed gradients
        acc_grad_iter = iter(accumulated_grads)
        return tuple(None if not g else next(acc_grad_iter)
                     for g in ctx.needs_input_grad)

    def forward(ctx, fenics_model, *args):
        """Computes the output of a FEniCS model and saves a corresponding tape

        Input:
            ctx: Context used for storing information for the backward pass
            fenics_model (FEniCSModel): Defines the forward pass
            args (tuple): tensor representation of the input to the forward pass

        Output:
            tuple of tensor representations of the outputs from fenics_model.forward
        """
        # Check that the number of inputs arguments is correct
        n_args = len(args)
        expected_nargs = len(fenics_model.fenics_input_templates())
        if n_args != expected_nargs:
            err_msg = 'Wrong number of arguments to {}.' \
                      ' Expected {} got {}.'.format(type(fenics_model), expected_nargs, n_args)
            raise ValueError(err_msg)

        # Convert input tensors to corresponding FEniCS variables
        fenics_inputs = []
        for inp, template in zip(args, fenics_model.fenics_input_templates()):
            if torch.is_tensor(inp):
                inp = inp.detach().numpy()
            fenics_inputs.append(numpy_to_fenics(inp, template))

        # Create tape associated with this forward pass
        tape = Tape()
        set_working_tape(tape)

        # Execute forward pass
        fenics_outputs = fenics_model.forward(*fenics_inputs)

        # If single output
        if not isinstance(fenics_outputs, tuple):
            fenics_outputs = (fenics_outputs,)

        # Save variables to be used for backward pass
        ctx.tape = tape
        ctx.fenics_inputs = fenics_inputs
        ctx.fenics_outputs = fenics_outputs

        # Return tensor representation of outputs
        return tuple(torch.from_numpy(fenics_to_numpy(fenics_output)) for fenics_output in fenics_outputs)
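
When the forward pass returns several outputs, backward receives one grad_output per output, and the contributions of each output must be summed per input; that is what the accumulated_grads loops in both backward implementations above do. The toy below is again illustrative only (SplitSquare and its two outputs are made up, and no FEniCS is involved); it shows the same per-output accumulation in a runnable form.

import torch


class SplitSquare(torch.autograd.Function):
    """Toy two-output Function: backward gets one incoming gradient per output."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 2, 3.0 * x  # two outputs from a single input

    @staticmethod
    def backward(ctx, grad_sq, grad_lin):
        x, = ctx.saved_tensors
        grad_x = None
        if ctx.needs_input_grad[0]:
            # Sum the contribution of each output, like the accumulated_grads loops above
            grad_x = grad_sq * 2.0 * x + grad_lin * 3.0
        return grad_x


x = torch.randn(4, dtype=torch.float64, requires_grad=True)
y, z = SplitSquare.apply(x)
(y.sum() + z.sum()).backward()
assert torch.allclose(x.grad, 2.0 * x.detach() + 3.0)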