Ejemplo n.º 1
0
    def __enter__(self):
        """Prepare forward-mode state for ``self.parameters`` and ``self.loss``.

        In order, this marks the manager as entered, materializes the runtime,
        wraps a single loss into a list, allocates dual fields for every
        loss/parameter field that lacks one, writes ``self.seed`` into
        ``self.parameters.dual``, optionally zero-fills the loss duals, and
        attaches this manager to ``self.runtime.fwd_mode_manager``.

        When ``self.seed`` is empty or ``None``, a default seed is created that
        is 1.0 for the first parameter element and 0.0 for all others.

        Raises:
            AssertionError: If the manager has already been entered, if
                ``self.parameters`` is not a ``ScalarField``, or if a
                user-supplied seed does not have one entry per parameter
                element.
        """
        assert not self.entered, "Forward mode manager can be entered only once."
        self.entered = True

        impl.get_runtime().materialize()
        if not isinstance(self.loss, list):
            self.loss = [self.loss]

        # Currently only one N-D field is supported as a group of parameters,
        # which is sufficient for computing a Jacobian-vector product (Jvp).
        # Multiple groups of parameters would require running forward AD
        # several times, which is out of scope for this interface.

        # TODO: support vector field and matrix field
        assert isinstance(self.parameters, ScalarField)

        # Dual fields are allocated lazily: only for fields that do not have
        # one yet, and all of them under a single new FieldsBuilder.
        fields_without_dual = [
            field
            for field in (*self.loss, self.parameters)
            if not field.snode.ptr.has_dual()
        ]
        if fields_without_dual:
            dual_root = FieldsBuilder()
            for field in fields_without_dual:
                allocate_dual(field, dual_root)
            dual_root.finalize()

        # Total number of parameter elements. The explicit initial value 1
        # makes an empty shape (0-D field) yield 1; without it ``reduce``
        # raises TypeError on an empty sequence.
        parameters_shape_flatten = reduce(
            lambda acc, dim: acc * dim, self.parameters.shape, 1
        )

        # A zero product is likewise treated as a single element.
        if parameters_shape_flatten == 0:
            parameters_shape_flatten = 1

        if not self.seed:
            # Compute the derivative with respect to the first variable by default.
            self.seed = [0.0 for _ in range(parameters_shape_flatten)]
            self.seed[0] = 1.0
        else:
            assert parameters_shape_flatten == len(self.seed)

        # Set the seed for each variable.
        # NOTE(review): the multi-element branch indexes the dual with a flat
        # integer, which presumably only fits 1-D parameter fields - confirm
        # before relying on it for higher-rank fields.
        if len(self.seed) == 1:
            self.parameters.dual[None] = 1.0 * self.seed[0]
        else:
            for idx, s in enumerate(self.seed):
                self.parameters.dual[idx] = 1.0 * s

        # Clear gradients accumulated on the losses, if requested.
        if self.clear_gradients:
            for ls in self.loss:
                ls.dual.fill(0)

        # Attach the context manager to the runtime.
        self.runtime.fwd_mode_manager = self
Ejemplo n.º 2
0
 def materialize_root_fb(is_first_call):
     """Finalize ``root`` and rebind the module-level ``_root_fb`` to a new builder.

     Nothing happens when ``root`` is already finalized, or when this is not
     the first call and ``root`` is empty.

     Args:
         is_first_call: Truthy on the first invocation; in that case ``root``
             is finalized even if it is empty, and ``finalize`` is called with
             ``raise_warning=False``.
     """
     global _root_fb

     # An empty root is still finalized on the first call so that a valid
     # struct llvm::Module exists even if no field has been declared before
     # the first kernel invocation. Example case:
     # https://github.com/taichi-dev/taichi/blob/27bb1dc3227d9273a79fcb318fdb06fd053068f5/tests/python/test_ad_basics.py#L260-L266
     needs_finalize = not root.finalized and (is_first_call or not root.empty)
     if needs_finalize:
         root.finalize(raise_warning=not is_first_call)
         _root_fb = FieldsBuilder()
Ejemplo n.º 3
0
def deactivate_all_snodes():
    """Call ``deactivate_all()`` on every finalized root builder.

    The builders are obtained from ``FieldsBuilder._finalized_roots()`` and
    processed in the order that call yields them.
    """
    finalized_roots = FieldsBuilder._finalized_roots()
    for builder in finalized_roots:
        builder.deactivate_all()