def __enter__(self):
    """Enter forward-mode autodiff: allocate duals, seed parameters, attach.

    The manager may only be entered once. On entry it:

    1. materializes the runtime;
    2. normalizes ``self.loss`` to a list;
    3. allocates dual fields (in one new ``FieldsBuilder``) for every loss or
       parameter field that does not have one yet;
    4. writes ``self.seed`` into ``self.parameters.dual`` -- when no seed was
       given, a one-hot seed selecting the first parameter entry is used;
    5. zero-fills each loss field's dual if ``self.clear_gradients`` is set;
    6. registers itself as ``self.runtime.fwd_mode_manager``.

    Raises:
        AssertionError: if entered more than once, if ``self.parameters`` is
            not a ``ScalarField``, or if a user-supplied ``self.seed`` does
            not have one value per parameter entry.
    """
    assert not self.entered, "Forward mode manager can be entered only once."
    self.entered = True
    impl.get_runtime().materialize()
    if not isinstance(self.loss, list):
        self.loss = [self.loss]

    # Currently only one N-D field is supported as a group of parameters,
    # which is sufficient for computing a Jacobian-vector product (JVP).
    # Multiple groups of parameters would require running forward AD
    # multiple times, which is out of scope for this interface.
    # TODO: support vector field and matrix field
    assert isinstance(self.parameters, ScalarField)

    # Allocate duals only for the fields that lack them, all under one root.
    all_fields = [*self.loss, self.parameters]
    fields_without_dual = [x for x in all_fields if not x.snode.ptr.has_dual()]
    if fields_without_dual:
        dual_root = FieldsBuilder()
        for x in fields_without_dual:
            allocate_dual(x, dual_root)
        dual_root.finalize()

    shape = tuple(self.parameters.shape)
    ndim = len(shape)
    # Number of scalar entries in the parameter field. The initializer ``1``
    # makes a 0-D field (shape ``()``) count as one entry; without it,
    # ``reduce`` raises TypeError on an empty shape.
    parameters_shape_flatten = reduce(lambda x, y: x * y, shape, 1)

    def dual_index(flat_idx):
        # Translate a flat position into the key expected by field indexing:
        # ``None`` for a 0-D field, a plain int for a 1-D field, and an index
        # tuple (row-major / C order, matching numpy's ``flatten``) for N-D.
        if ndim == 0:
            return None
        if ndim == 1:
            return flat_idx
        idx = []
        for extent in reversed(shape):
            flat_idx, rem = divmod(flat_idx, extent)
            idx.append(rem)
        return tuple(reversed(idx))

    # ``len`` instead of truthiness so numpy-array seeds are accepted.
    if self.seed is None or len(self.seed) == 0:
        # Compute the derivative with respect to the first entry by default.
        self.seed = [1.0] + [0.0] * (parameters_shape_flatten - 1)
    else:
        assert parameters_shape_flatten == len(self.seed), (
            f"Expected a seed of length {parameters_shape_flatten}, "
            f"got {len(self.seed)}."
        )

    # Set the seed for each entry of the parameter field. The index form is
    # chosen from the field's dimensionality, not from the seed length, so
    # 1-D fields of shape (1,) are indexed correctly.
    for flat_idx, s in enumerate(self.seed):
        self.parameters.dual[dual_index(flat_idx)] = 1.0 * s

    # Zero each loss field's dual on entry.
    if self.clear_gradients:
        for ls in self.loss:
            ls.dual.fill(0)

    # Attach the context manager to the runtime
    self.runtime.fwd_mode_manager = self
def materialize_root_fb(is_first_call):
    """Finalize the global ``root`` fields builder if needed, then replace it.

    Nothing happens when ``root`` is already finalized. On calls other than
    the first, an empty ``root`` is also left untouched. Otherwise ``root``
    is finalized (with ``raise_warning=True`` only on non-first calls) and
    the module global ``_root_fb`` is rebound to a fresh ``FieldsBuilder``.

    Args:
        is_first_call: True for the first materialization; forces
            finalization even when no field has been declared yet.
    """
    global _root_fb
    if root.finalized:
        return
    # The first call must finalize even an empty root, so that a valid struct
    # llvm::Module exists when no field was declared before the first kernel
    # invocation. Example case:
    # https://github.com/taichi-dev/taichi/blob/27bb1dc3227d9273a79fcb318fdb06fd053068f5/tests/python/test_ad_basics.py#L260-L266
    must_finalize = is_first_call or not root.empty
    if must_finalize:
        root.finalize(raise_warning=not is_first_call)
        _root_fb = FieldsBuilder()
def deactivate_all_snodes():
    """Recursively deactivate all SNodes.

    Calls ``deactivate_all()`` on each root returned by
    ``FieldsBuilder._finalized_roots()``.
    """
    finalized_roots = FieldsBuilder._finalized_roots()
    for fb in finalized_roots:
        fb.deactivate_all()