Beispiel #1
0
def test_interpolate_tlm():
    """Check the TLM of a functional built from an interpolation into vector CG2.

    A vector CG1 control ``f`` and a vector DG0 field ``g`` are combined and
    interpolated into a vector CG2 function; the tangent-linear value of the
    assembled functional in the direction ``direction`` (every DOF set to 1)
    is propagated through the tape and then verified with a Taylor test.
    """
    from firedrake_adjoint import ReducedFunctional, Control, taylor_test

    mesh = UnitSquareMesh(10, 10)
    cg1 = VectorFunctionSpace(mesh, "CG", 1)
    dg0 = VectorFunctionSpace(mesh, "DG", 0)
    cg2 = VectorFunctionSpace(mesh, "CG", 2)

    xy = SpatialCoordinate(mesh)
    f = interpolate(as_vector((xy[0] * xy[1], xy[0] + xy[1])), cg1)
    g = interpolate(as_vector((sin(xy[1]) + xy[0], cos(xy[0]) * xy[1])), dg0)
    target = Function(cg2)

    # Building the UFL expression records nothing; only interpolate() tapes.
    source_expr = f - 0.5 * g + f / (1 + dot(f, g))
    target.interpolate(source_expr)
    J = assemble(inner(f, g) * target**2 * dx)
    Jhat = ReducedFunctional(J, Control(f))

    direction = Function(cg1)
    direction.vector()[:] = 1
    f.block_variable.tlm_value = direction

    get_working_tape().evaluate_tlm()

    dJdm = J.block_variable.tlm_value
    assert dJdm is not None
    assert taylor_test(Jhat, f, direction, dJdm=dJdm) > 1.9
Beispiel #2
0
def test_interpolate_tlm_wit_constant():
    """Check TLM values through ``u = interpolate(c * f**2)`` with a Constant control.

    The TLM of J = ∫ u**2 dx is propagated first from the Constant ``c``
    alone, then from ``c`` and the function ``f`` together. Both results are
    compared against closed-form values of the continuous problem.
    """
    from firedrake_adjoint import ReducedFunctional, Control, taylor_test
    mesh = IntervalMesh(10, 0, 1)
    V1 = FunctionSpace(mesh, "CG", 2)
    V2 = FunctionSpace(mesh, "DG", 1)

    x = SpatialCoordinate(mesh)
    f = interpolate(x[0], V1)
    g = interpolate(sin(x[0]), V1)
    c = Constant(5.0)
    u = Function(V2)
    u.interpolate(c * f**2)

    # test tlm w.r.t constant only:
    # With f = x and c = 5 the continuous value is
    #   dJ/dc * 1 = 2*c * ∫ f**4 dx = 2*5*(1/5) = 2,
    # which the first assert checks to 1e-5.
    c.block_variable.tlm_value = Constant(1.0)
    J = assemble(u**2 * dx)
    rf = ReducedFunctional(J, Control(c))
    h = Constant(1.0)

    tape = get_working_tape()
    tape.evaluate_tlm()
    assert abs(J.block_variable.tlm_value - 2.0) < 1e-5
    assert taylor_test(rf, c, h, dJdm=J.block_variable.tlm_value) > 1.9

    # test tlm w.r.t constant c and function f:
    # Seeding dc = 0.4 and df = g = sin(x) gives the continuous value
    #   2*c*∫ f**4 dx * 0.4  +  4*c**2 * ∫ f**3 g dx
    #   = 0.8 + 100 * ∫_0^1 x**3 sin(x) dx
    #   = 0.8 + 100 * (5*cos(1) - 3*sin(1)).
    tape.reset_tlm_values()
    c.block_variable.tlm_value = Constant(0.4)
    f.block_variable.tlm_value = g
    rf(c)  # replay to reset checkpoint values based on c=5
    tape.evaluate_tlm()
    assert abs(J.block_variable.tlm_value -
               (0.8 + 100. * (5 * cos(1.) - 3 * sin(1.)))) < 1e-4
Beispiel #3
0
def assemble(*args, **kwargs):
    """Overloaded :py:func:`dolfin.assemble` that keeps track of the assembled form.

    Assembling a form discards its nonlinear dependencies, which makes the
    result hard to manipulate afterwards. This wrapper therefore keeps the
    form with the result. Scalar (float) results become overloaded objects
    and, when annotating, are recorded on the tape via an ``AssembleBlock``.
    Vector/matrix results get the form attached as ``.form``, so that
    annotation still works when the user calls the lower-level
    :py:data:`solve(A, x, b)`.
    """
    annotate = annotate_tape(kwargs)
    with stop_annotating():
        result = backend.assemble(*args, **kwargs)

    form = args[0]
    if not isinstance(result, float):
        # Assembled a vector or matrix: remember where it came from.
        result.form = form
        return result

    result = create_overloaded_object(result)
    if annotate:
        block = AssembleBlock(form)
        get_working_tape().add_block(block)
        block.add_output(result.block_variable)
    return result
def test_supermesh_project_hessian(vector):
    """Check first and second derivatives through a supermesh projection.

    ``source`` (from ``supermesh_setup``) is projected onto ``target_space``.
    Adjoint, TLM and Hessian actions of J = ∫ (target·target)**2 dx are
    evaluated on the tape and fed to a Taylor test that includes the Hessian
    term, which must converge with rate > 2.9.

    ``vector`` is not used in the body (presumably a pytest parametrisation
    argument — TODO confirm).
    """
    from firedrake_adjoint import ReducedFunctional, Control, taylor_test
    source, target_space = supermesh_setup()
    control = Control(source)
    projected = project(source, target_space)
    J = assemble(inner(projected, projected)**2 * dx)
    Jhat = ReducedFunctional(J, control)

    space = source.function_space()
    direction = Function(space)
    direction.vector()[:] = 10 * rand(space.dim())

    # Seed the adjoint at the functional and the TLM at the control.
    J.block_variable.adj_value = 1.0
    source.block_variable.tlm_value = direction

    tape = get_working_tape()
    tape.evaluate_adj()
    tape.evaluate_tlm()

    J.block_variable.hessian_value = 0.0
    tape.evaluate_hessian()

    dJdm = J.block_variable.tlm_value
    adjoint = source.block_variable.adj_value
    hessian = source.block_variable.hessian_value
    assert isinstance(adjoint, Vector)
    assert isinstance(hessian, Vector)
    Hm = hessian.inner(direction.vector())
    assert taylor_test(Jhat, source, direction, dJdm=dJdm, Hm=Hm) > 2.9
Beispiel #5
0
    def solve(self, *args, **kwargs):
        """Run the backend Newton solve, annotating it as a variational solve.

        ``args[0]`` must provide ``F(b=..., x=...)`` and ``args[1]`` is the
        solution vector. When annotating, ``F`` is evaluated into a fresh
        vector of the same backend type. That vector's ``form`` and ``bcs``
        are used to record a ``SolveVarFormBlock`` for ``F == 0``, with a copy
        of this solver's parameters stored under ``"newton_solver"``.
        """
        annotate = annotate_tape(kwargs)

        if annotate:
            tape = get_working_tape()
            problem, solution_vec = args[0], args[1]
            residual = backend.as_backend_type(solution_vec).__class__()

            problem.F(b=residual, x=solution_vec)

            residual_form = residual.form
            residual_bcs = residual.bcs

            u = solution_vec.function

            sb_kwargs = SolveVarFormBlock.pop_kwargs(kwargs)
            params = {"newton_solver": self.parameters.copy()}
            block = SolveVarFormBlock(residual_form == 0,
                                      u,
                                      residual_bcs,
                                      solver_parameters=params,
                                      **sb_kwargs)
            tape.add_block(block)

        out = backend.NewtonSolver.solve(self, *args, **kwargs)

        if annotate:
            block.add_output(u.create_block_variable())

        return out
Beispiel #6
0
    def wrapper(*args, **kwargs):
        """Annotated ``project``: an equation solve recorded as a ``ProjectBlock``.

        Recording the projection lets pyadjoint build the adjoint and tangent
        linear models automatically. Pass :py:data:`annotate=False` to skip
        recording, e.g. when the projection only serves diagnostics or
        visualisation and is irrelevant to the adjoint computation."""

        annotate = annotate_tape(kwargs)
        if annotate:
            bcs = kwargs.get("bcs", [])
            sb_kwargs = ProjectBlock.pop_kwargs(kwargs)
            into_existing = isinstance(args[1], function.Function)
            if into_existing:
                # Build the block before projecting: the target Function may
                # also be an input whose current value must be checkpointed.
                target = args[1]
                block = ProjectBlock(args[0], target.function_space(), target,
                                     bcs, **sb_kwargs)

        with stop_annotating():
            output = project(*args, **kwargs)

        if annotate:
            tape = get_working_tape()
            if not into_existing:
                # Projection into a function space: the output only exists now.
                block = ProjectBlock(args[0], args[1], output, bcs, **sb_kwargs)
            tape.add_block(block)
            block.add_output(output.create_block_variable())

        return output
Beispiel #7
0
    def solve(self, **kwargs):
        """Solve the block nonlinear problem and record it on the tape.

        Pass ``annotate=False`` to skip recording. Any other keyword arguments
        are accepted but not forwarded to the parent ``solve``.

        When annotating, a ``NonlinearBlockSolveBlock`` for ``_ad_b == 0`` is
        added to the working tape before solving. The problem's ``_ad_u``
        becomes the block's output afterwards.
        """
        # Pass kwargs so an explicit ``annotate=`` keyword is honoured (and
        # removed) the same way as in the other overloaded solve routines.
        annotate = annotate_tape(kwargs)
        if annotate:
            block_helper = BlockSolveBlockHelper()
            tape = get_working_tape()
            problem = self._ad_problem

            block = NonlinearBlockSolveBlock(
                problem._ad_b == 0,
                problem._ad_u,
                problem._ad_bcs,
                block_helper=block_helper,
                problem_J=problem._ad_A,
                block_field=problem.block_field,
                block_split=problem.block_split)
            tape.add_block(block)

        with stop_annotating():
            out = super(NonlinearBlockSolver, self).solve()

        if annotate:
            block.add_output(problem._ad_u.create_block_variable())

        return out
Beispiel #8
0
    def wrapper(*args, **kwargs):
        """Annotated ``solve``.

        A call whose first argument is a UFL equation is taped as a
        ``SolveVarFormBlock``; any other call as a ``SolveLinearSystemBlock``.
        The solution (``args[1]``, or its ``.function`` when it has no
        ``create_block_variable``) becomes the block's output.
        """
        ad_block_tag = kwargs.pop("ad_block_tag", None)
        annotate = annotate_tape(kwargs)

        if annotate:
            tape = get_working_tape()
            is_variational = isinstance(args[0], ufl.equation.Equation)
            block_cls = SolveVarFormBlock if is_variational else SolveLinearSystemBlock

            sb_kwargs = block_cls.pop_kwargs(kwargs)
            sb_kwargs.update(kwargs)
            block = block_cls(*args, ad_block_tag=ad_block_tag, **sb_kwargs)
            tape.add_block(block)

        with stop_annotating():
            output = solve(*args, **kwargs)

        if annotate:
            solution = args[1]
            owner = (solution if hasattr(solution, "create_block_variable")
                     else solution.function)
            block.add_output(owner.create_block_variable())

        return output
Beispiel #9
0
        def wrapper(self, **kwargs):
            """Annotated replacement for the variational solver's ``solve``.

            With :py:data:`annotate=False` it behaves exactly like the Firedrake
            solve call and records nothing. That suits solves that are
            irrelevant to the adjoint computation, such as diagnostics or
            projections made only for visualisation."""

            annotate = annotate_tape(kwargs)
            if annotate:
                problem = self._ad_problem
                block_kwargs = NonlinearVariationalSolveBlock.pop_kwargs(kwargs)
                block_kwargs.update(kwargs)
                block = NonlinearVariationalSolveBlock(
                    problem._ad_F == 0,
                    problem._ad_u,
                    problem._ad_bcs,
                    problem_J=problem._ad_J,
                    solver_params=self.parameters,
                    solver_kwargs=self._ad_kwargs,
                    **block_kwargs)
                get_working_tape().add_block(block)

            with stop_annotating():
                out = solve(self, **kwargs)

            if annotate:
                solution = self._ad_problem._ad_u
                block.add_output(solution.create_block_variable())

            return out
Beispiel #10
0
    def solve(self, *args, **kwargs):
        """Run the backend LU solve, taping an ``LUSolveBlock`` when annotating.

        Accepts ``(A, x, b)`` or ``(x, b)``. In the two-argument form, the
        operator and block helper stored on the solver (``self.operator``,
        ``self.block_helper``) are used.

        Raises:
            TypeError: when annotating and the number of positional arguments
                is neither 2 nor 3.
        """
        annotate = annotate_tape(kwargs)

        if annotate:
            if len(args) == 3:
                block_helper = LUSolveBlockHelper()
                A, x, b = args
            elif len(args) == 2:
                block_helper = self.block_helper
                A = self.operator
                x, b = args
            else:
                # Fail clearly here instead of with an UnboundLocalError on
                # ``x`` further down.
                raise TypeError(
                    "LUSolver.solve expects (A, x, b) or (x, b), got %d "
                    "positional arguments" % len(args))

            u = x.function
            parameters = self.parameters.copy()

            tape = get_working_tape()
            sb_kwargs = LUSolveBlock.pop_kwargs(kwargs)
            block = LUSolveBlock(A,
                                 x,
                                 b,
                                 lu_solver_parameters=parameters,
                                 block_helper=block_helper,
                                 lu_solver_method=self.method,
                                 **sb_kwargs)
            tape.add_block(block)

        out = backend.LUSolver.solve(self, *args, **kwargs)

        if annotate:
            block.add_output(u.create_block_variable())

        return out
Beispiel #11
0
        def wrapper(self, b, *args, **kwargs):
            """Annotated projection of ``b`` into ``self``.

            A ``SupermeshProjectBlock`` is taped when ``b`` is a Firedrake
            Function whose domain differs from the mesh of ``self``'s function
            space; otherwise a ``ProjectBlock`` is taped.
            """
            ad_block_tag = kwargs.pop("ad_block_tag", None)
            annotate = annotate_tape(kwargs)

            if annotate:
                bcs = kwargs.get("bcs", [])
                on_other_mesh = (
                    isinstance(b, firedrake.Function)
                    and b.ufl_domain() != self.function_space().mesh()
                )
                block_cls = SupermeshProjectBlock if on_other_mesh else ProjectBlock
                block = block_cls(b,
                                  self.function_space(),
                                  self,
                                  bcs,
                                  ad_block_tag=ad_block_tag)
                get_working_tape().add_block(block)

            with stop_annotating():
                output = project(self, b, *args, **kwargs)

            if annotate:
                block.add_output(output.create_block_variable())

            return output
Beispiel #12
0
    def wrapper(*args, **kwargs):
        """Annotated ``project``, recorded as a ``ProjectBlock``.

        Recording the projection lets pyadjoint build the adjoint and tangent
        linear models automatically. Pass :py:data:`annotate=False` to skip
        recording, e.g. when the projection only serves diagnostics or
        visualisation."""

        annotate = annotate_tape(kwargs)
        with stop_annotating():
            result = project(*args, **kwargs)
        result = create_overloaded_object(result)

        if not annotate:
            return result

        # NOTE(review): ``bcs`` and the block-specific kwargs are popped only
        # after ``project`` has run, so they were also forwarded to it —
        # confirm ``project`` tolerates any block-only keywords.
        bcs = kwargs.pop("bcs", [])
        sb_kwargs = ProjectBlock.pop_kwargs(kwargs)
        block = ProjectBlock(args[0], args[1], result, bcs, **sb_kwargs)

        get_working_tape().add_block(block)
        block.add_output(result.block_variable)
        return result
Beispiel #13
0
    def wrapper(interpolator, *functions, **kwargs):
        """Annotated interpolation.

        With :py:data:`annotate=False` this acts exactly like the plain
        Firedrake interpolate call. Otherwise an ``InterpolateBlock`` is taped
        before interpolating, and the result is registered as its output."""
        ad_block_tag = kwargs.pop("ad_block_tag", None)
        annotate = annotate_tape(kwargs)

        if annotate:
            block_kwargs = InterpolateBlock.pop_kwargs(kwargs)
            block_kwargs.update(kwargs)
            block = InterpolateBlock(interpolator,
                                     *functions,
                                     ad_block_tag=ad_block_tag,
                                     **block_kwargs)
            get_working_tape().add_block(block)

        with stop_annotating():
            output = interpolate(interpolator, *functions, **kwargs)

        if annotate:
            from firedrake import Function
            # When ``interpolator.V`` is a Function, request a new block
            # variable for the output; otherwise use its existing one.
            if isinstance(interpolator.V, Function):
                output_bv = output.create_block_variable()
            else:
                output_bv = output.block_variable
            block.add_output(output_bv)

        return output
Beispiel #14
0
    def assign(self, *args, **kwargs):
        """Assign input function(s) to output function(s), taping the operation.

        ``args[0]`` holds the receiving function(s) and ``args[1]`` the
        assigning function(s). Each may be a single object or a list (both are
        normalised with ``Enlist``). When annotating, every output and input
        is made an ``OverloadedType`` and a ``FunctionAssignerBlock`` is
        recorded, with the outputs' block variables as its outputs.
        """
        annotate = annotate_tape(kwargs)
        outputs = Enlist(args[0])
        inputs = Enlist(args[1])

        if annotate:
            for idx, out in enumerate(outputs):
                if not isinstance(out, OverloadedType):
                    outputs[idx] = create_overloaded_object(out)

            # Iterate over the inputs themselves. They become the block's
            # dependencies, so each one must own a block_variable.
            for idx, inp in enumerate(inputs):
                if not isinstance(inp, OverloadedType):
                    inputs[idx] = create_overloaded_object(inp)

            block = FunctionAssignerBlock(self, inputs)
            tape = get_working_tape()
            tape.add_block(block)

        with stop_annotating():
            ret = backend.FunctionAssigner.assign(self, outputs.delist(), inputs.delist(), **kwargs)

        if annotate:
            for output in outputs:
                block.add_output(output.block_variable)
        return ret
Beispiel #15
0
 def sub(self, i, deepcopy=False, **kwargs):
     """Return sub-function ``i`` of this function.

     With ``deepcopy=True`` the backend ``sub`` result is wrapped as an
     overloaded object. When annotating, it is taped as a
     ``FunctionAssignerBlock`` that depends on ``self``. Otherwise the result
     comes from ``compat.create_function``; when annotating, it is given
     ``FunctionSplitBlock`` / ``FunctionMergeBlock`` metadata with ``self``
     and ``i`` as block arguments.
     """
     from .function_assigner import FunctionAssigner, FunctionAssignerBlock
     annotate = annotate_tape(kwargs)

     if not deepcopy:
         extra_kwargs = {
             "block_class": FunctionSplitBlock,
             "_ad_floating_active": True,
             "_ad_args": [self, i],
             "_ad_output_args": [i],
             "output_block_class": FunctionMergeBlock,
             "_ad_outputs": [self],
         } if annotate else {}
         return compat.create_function(self, i, **extra_kwargs)

     ret = create_overloaded_object(
         backend.Function.sub(self, i, deepcopy, **kwargs))
     if annotate:
         assigner = FunctionAssigner(ret.function_space(),
                                     self.function_space())
         block = FunctionAssignerBlock(assigner, Enlist(self))
         tape = get_working_tape()
         tape.add_block(block)
         block.add_output(ret.block_variable)
     return ret
Beispiel #16
0
 def __init__(
     self,
     functional,
     controls,
     level_set,
     scale=1.0,
     tape=None,
     eval_cb_pre=lambda *args: None,
     eval_cb_post=lambda *args: None,
     derivative_cb_pre=lambda *args: None,
     derivative_cb_post=lambda *args: None,
     hessian_cb_pre=lambda *args: None,
     hessian_cb_post=lambda *args: None,
 ):
     """Store the functional, controls, level set, scale, tape and callbacks.

     Args:
         functional: the functional; also exposed as ``cost_function``.
         controls: control(s), normalised to a list via ``Enlist``.
         level_set: level-set object(s), normalised via ``Enlist``.
         scale: stored unchanged as ``self.scale`` (default 1.0).
         tape: tape to use; the current working tape when ``None``.
         eval_cb_pre, eval_cb_post, derivative_cb_pre, derivative_cb_post,
         hessian_cb_pre, hessian_cb_post: callbacks stored as attributes of
             the same name; each defaults to a no-op.
     """
     self.functional = functional
     self.cost_function = self.functional  # alias kept for compatibility
     self.tape = tape if tape is not None else get_working_tape()
     self.controls = Enlist(controls)
     self.level_set = Enlist(level_set)
     self.scale = scale

     self.eval_cb_pre, self.eval_cb_post = eval_cb_pre, eval_cb_post
     self.derivative_cb_pre, self.derivative_cb_post = (derivative_cb_pre,
                                                        derivative_cb_post)
     self.hessian_cb_pre, self.hessian_cb_post = hessian_cb_pre, hessian_cb_post
Beispiel #17
0
    def wrapper(*args, **kwargs):
        """Annotated :func:`.assemble` that keeps track of the assembled form.

        Assembly discards a form's nonlinear dependencies. Real scalar results
        are therefore taped with an ``AssembleBlock`` when annotating.
        Vector/matrix results carry the form as ``.form``, which keeps
        annotation working for the lower-level :py:data:`solve(A, x, b)`.

        Raises:
            NotImplementedError: for a complex-valued 0-form while annotating.
        """
        ad_block_tag = kwargs.pop("ad_block_tag", None)
        annotate = annotate_tape(kwargs)
        with stop_annotating():
            output = assemble(*args, **kwargs)

        form = args[0]
        if not isinstance(output, numbers.Complex):
            # Assembled a vector or matrix.
            output.form = form
            return output

        if not annotate:
            return output
        if not isinstance(output, float):
            raise NotImplementedError(
                "Taping for complex-valued 0-forms not yet done!")

        output = create_overloaded_object(output)
        block = AssembleBlock(form, ad_block_tag=ad_block_tag)
        tape = get_working_tape()
        tape.add_block(block)
        block.add_output(output.block_variable)
        return output
Beispiel #18
0
def test_interpolate_hessian_nonlinear_expr_multi():
    """Hessian check through an interpolation of a 3-coefficient nonlinear expression.

    ``f**2 + w**2 + c**2`` (f, w in CG2, c a Constant) is interpolated into
    CG1 and used as a source in a nonlinear PDE. A Taylor test including the
    Hessian term, with ``f`` as the control, must show convergence > 2.9.
    """
    # Note this is a direct copy of
    # pyadjoint/tests/firedrake_adjoint/test_hessian.py::test_nonlinear
    # with modifications where indicated.

    from firedrake_adjoint import ReducedFunctional, Control, taylor_test, get_working_tape

    # Get tape instead of creating a new one for consistency with other tests
    tape = get_working_tape()

    mesh = UnitSquareMesh(10, 10)
    V = FunctionSpace(mesh, "Lagrange", 1)

    # Interpolate from f in another function space to force hessian evaluation
    # of interpolation. The control f, the perturbation direction h and the
    # Taylor-test base point g (a copy of f) all live in W.
    W = FunctionSpace(mesh, "Lagrange", 2)
    f = Function(W)
    f.vector()[:] = 5
    w = Function(W)
    w.vector()[:] = 4
    c = Constant(2.)
    # Note that we interpolate from a nonlinear expression with 3 coefficients
    expr_interped = Function(V).interpolate(f**2 + w**2 + c**2)

    u = Function(V)
    v = TestFunction(V)
    bc = DirichletBC(V, Constant(1), "on_boundary")

    F = inner(grad(u), grad(v)) * dx - u**2 * v * dx - expr_interped * v * dx
    solve(F == 0, u, bc)

    J = assemble(u**4 * dx)
    Jhat = ReducedFunctional(J, Control(f))

    # Note functions are in W, not V.
    h = Function(W)
    h.vector()[:] = 10 * rand(W.dim())

    J.block_variable.adj_value = 1.0
    # Only f's tlm_value is seeded: f is the sole Control of Jhat, while
    # w and c are not controls.
    f.block_variable.tlm_value = h

    tape.evaluate_adj()
    tape.evaluate_tlm()

    J.block_variable.hessian_value = 0

    tape.evaluate_hessian()

    g = f.copy(deepcopy=True)

    dJdm = J.block_variable.tlm_value
    assert isinstance(f.block_variable.adj_value, Vector)
    assert isinstance(f.block_variable.hessian_value, Vector)
    Hm = f.block_variable.hessian_value.inner(h.vector())
    # If the new interpolate block has the right hessian, taylor test
    # convergence rate should be as for the unmodified test.
    assert taylor_test(Jhat, g, h, dJdm=dJdm, Hm=Hm) > 2.9
Beispiel #19
0
        def wrapper(self, other, **kwargs):
            """In-place division that, when annotating, tapes ``self / other``.

            The wrapped ``__idiv__`` runs first. A ``FunctionAssignBlock``
            whose output is the returned function is then recorded.
            """
            annotate = annotate_tape(kwargs)
            result = __idiv__(self, other, **kwargs)

            if not annotate:
                return result

            block = FunctionAssignBlock(result, self / other)
            get_working_tape().add_block(block)
            block.add_output(result.create_block_variable())
            return result
Beispiel #20
0
def get_solve_blocks():
    """
    Return the blocks on the working tape that correspond to PDE solves.

    A block qualifies when it is a ``GenericSolveBlock`` but not a
    ``ProjectBlock``, so solves produced by the ``project`` operator are
    excluded. Blocks are returned in the order given by
    ``get_working_tape().get_blocks()``.
    """
    return [
        block for block in get_working_tape().get_blocks()
        if isinstance(block, GenericSolveBlock)
        and not isinstance(block, ProjectBlock)
    ]
Beispiel #21
0
    def __init__(self, *args, **kwargs):
        """Create the Constant and, when annotating, tape its initial value.

        A ``ConstantAssignBlock`` is recorded only when the first positional
        argument is an ``OverloadedType``, or is a tuple/list containing at
        least one ``OverloadedType``. A tuple/list is first converted to an
        object-dtype NumPy array.
        """
        ad_block_tag = kwargs.pop("ad_block_tag", None)
        annotate = annotate_tape(kwargs)
        super(Constant, self).__init__(*args, **kwargs)
        backend.Constant.__init__(self, *args, **kwargs)

        if not (annotate and args):
            return

        value = args[0]
        if isinstance(value, OverloadedType):
            depends_on_tape = True
        elif isinstance(value, (tuple, list)):
            value = numpy.array(value, dtype="O")
            depends_on_tape = any(isinstance(v, OverloadedType) for v in value.flat)
        else:
            depends_on_tape = False

        if depends_on_tape:
            block = ConstantAssignBlock(value, ad_block_tag=ad_block_tag)
            tape = get_working_tape()
            tape.add_block(block)
            block.add_output(self.block_variable)
Beispiel #22
0
def move(mesh, vector, **kwargs):
    """Move ``mesh`` by ``vector`` with the backend ALE move, taping it when annotating.

    When annotating, both arguments must be ``OverloadedType`` instances
    (checked with ``assert``). ``kwargs`` are passed to ``ALEMoveBlock`` only,
    not to the backend move, and the mesh receives a new block variable as
    the block's output.
    """
    annotate = annotate_tape(kwargs)
    if annotate:
        assert isinstance(mesh, OverloadedType)
        assert isinstance(vector, OverloadedType)
        block = ALEMoveBlock(mesh, vector, **kwargs)
        get_working_tape().add_block(block)

    with stop_annotating():
        output = __backend_ALE_move(mesh, vector)

    if annotate:
        block.add_output(mesh.create_block_variable())
    return output
Beispiel #23
0
        def wrapper(self, other, **kwargs):
            """In-place multiplication that, when annotating, tapes ``self * other``.

            The wrapped ``__imul__`` runs first. A ``FunctionAssignBlock``
            (tagged with ``ad_block_tag`` if given) is then recorded with the
            returned function as its output.
            """
            ad_block_tag = kwargs.pop("ad_block_tag", None)
            annotate = annotate_tape(kwargs)
            result = __imul__(self, other, **kwargs)

            if not annotate:
                return result

            product = self * other
            block = FunctionAssignBlock(result, product, ad_block_tag=ad_block_tag)
            get_working_tape().add_block(block)
            block.add_output(result.create_block_variable())
            return result
Beispiel #24
0
    def __getitem__(self, item):
        """Index the array, taping a ``NumpyArraySliceBlock`` when annotating.

        When annotating, the indexing result is wrapped with
        ``create_overloaded_object`` and registered as the block's output.
        Otherwise the plain ``numpy.ndarray.__getitem__`` result is returned.
        """
        annotate = annotate_tape()
        if annotate:
            block = NumpyArraySliceBlock(self, item)
            get_working_tape().add_block(block)

        with stop_annotating():
            out = numpy.ndarray.__getitem__(self, item)

        if not annotate:
            return out

        out = create_overloaded_object(out)
        block.add_output(out.create_block_variable())
        return out
Beispiel #25
0
def solve(*args, **kwargs):
    """Annotated wrapper around the Dolfin ``solve`` call.

    The solve is recorded on the tape, together with the forms involved, so
    that pyadjoint can construct the adjoint and tangent linear models
    automatically. A call whose first argument is a UFL equation is recorded
    as a ``SolveVarFormBlock``; any other call as a ``SolveLinearSystemBlock``.

    Passing :py:data:`annotate=False` disables recording, and the routine then
    acts exactly like the Dolfin solve. This helps for solves that do not
    matter to the adjoint computation, such as projecting fields to other
    function spaces purely for visualisation.

    Optional callbacks expose adjoint solutions. Each callback takes a single
    argument of type Function.

    Keyword Args:
        adj_cb (function, optional): receives the adjoint solution in the interior.
            The boundary values are zero.
        adj_bdy_cb (function, optional): receives the adjoint solution on the boundary.
            The interior values are not guaranteed to be zero.
        adj2_cb (function, optional): receives the second-order adjoint solution in the interior.
            The boundary values are zero.
        adj2_bdy_cb (function, optional): receives the second-order adjoint solution on
            the boundary. The interior values are not guaranteed to be zero.

    """
    ad_block_tag = kwargs.pop("ad_block_tag", None)
    annotate = annotate_tape(kwargs)
    if annotate:
        tape = get_working_tape()

        if isinstance(args[0], ufl.equation.Equation):
            block_cls = SolveVarFormBlock
        else:
            block_cls = SolveLinearSystemBlock

        sb_kwargs = block_cls.pop_kwargs(kwargs)
        sb_kwargs.update(kwargs)
        block = block_cls(*args, ad_block_tag=ad_block_tag, **sb_kwargs)
        tape.add_block(block)

    with stop_annotating():
        output = backend.solve(*args, **kwargs)

    if annotate:
        solution = args[1]
        if not hasattr(solution, "create_block_variable"):
            solution = solution.function
        block.add_output(solution.create_block_variable())

    return output
Beispiel #26
0
        def wrapper(self, *args, **kwargs):
            """Copy the function, taping deep copies as a ``FunctionAssignBlock``.

            ``deepcopy`` may be passed positionally (``f.copy(True)``) or as a
            keyword. When annotating, only deep copies are supported; a
            shallow copy raises ``NotImplementedError`` after the copy is made.
            """
            annotate = annotate_tape(kwargs)
            # The flag can arrive as the first positional argument; reading
            # only kwargs made ``f.copy(True)`` wrongly raise below.
            deepcopy = args[0] if args else kwargs.get("deepcopy", False)
            func = copy(self, *args, **kwargs)

            if annotate:
                if deepcopy:
                    block = FunctionAssignBlock(func, self)
                    tape = get_working_tape()
                    tape.add_block(block)
                    block.add_output(func.create_block_variable())
                else:
                    # TODO: Implement. Here we would need to use floating types.
                    raise NotImplementedError("Currently kwargs['deepcopy'] must be set True")

            return func
Beispiel #27
0
    def project(self, b, *args, **kwargs):
        """Project ``b`` onto this function's space, taping a ``ProjectBlock``.

        The backend projection runs with annotation stopped and its result is
        wrapped as an overloaded object. ``bcs`` (if given) is passed to the
        backend projection and also recorded on the block.
        """
        annotate = annotate_tape(kwargs)
        with stop_annotating():
            result = super(Function, self).project(b, *args, **kwargs)
        result = create_overloaded_object(result)

        if not annotate:
            return result

        bcs = kwargs.pop("bcs", [])
        block = ProjectBlock(b, self.function_space(), result, bcs)
        get_working_tape().add_block(block)
        block.add_output(result.create_block_variable())
        return result
Beispiel #28
0
    def copy(self, *args, **kwargs):
        """Copy the function, taping deep copies as a ``FunctionAssignBlock``.

        ``deepcopy`` may be passed positionally (``f.copy(True)``) or as a
        keyword. Shallow copies are returned without being taped.
        """
        annotate = annotate_tape(kwargs)
        # The flag can arrive as the first positional argument; reading only
        # kwargs left ``f.copy(True)`` untaped.
        deepcopy = args[0] if args else kwargs.get("deepcopy", False)
        c = backend.Function.copy(self, *args, **kwargs)
        func = create_overloaded_object(c)

        if annotate:
            if deepcopy:
                block = FunctionAssignBlock(func, self)
                tape = get_working_tape()
                tape.add_block(block)
                block.add_output(func.create_block_variable())
            else:
                # TODO: Implement. Here we would need to use floating types.
                pass

        return func
Beispiel #29
0
    def assign(self, *args, **kwargs):
        """Assign a new value to the Constant, taping an ``AssignBlock``.

        Recording is controlled by the ``annotate_tape`` keyword (default
        True), which is removed before the backend ``assign`` is called. A
        value that is not an ``OverloadedType`` is wrapped with
        ``create_overloaded_object`` for the block.
        """
        annotate = kwargs.pop("annotate_tape", True)
        if annotate:
            value = args[0]
            if not isinstance(value, OverloadedType):
                value = create_overloaded_object(value)

            block = AssignBlock(self, value)
            get_working_tape().add_block(block)

        ret = backend.Constant.assign(self, *args, **kwargs)

        if annotate:
            block.add_output(self.create_block_variable())
        return ret
Beispiel #30
0
        def wrapper(self, *args, **kwargs):
            """Assign to the Constant, taping a ``ConstantAssignBlock`` when annotating.

            The assigned value ``args[0]`` is wrapped with
            ``create_overloaded_object`` if it is not already an
            ``OverloadedType``.
            """
            annotate = annotate_tape(kwargs)
            block = None
            if annotate:
                source = args[0]
                if not isinstance(source, OverloadedType):
                    source = create_overloaded_object(source)
                block = ConstantAssignBlock(source)
                get_working_tape().add_block(block)

            ret = assign(self, *args, **kwargs)

            if block is not None:
                block.add_output(self.create_block_variable())
            return ret