Example #1
0
def clear_all_gradients():
    """Clear the gradients of every non-primal place in all finalized SNode trees.

    The runtime is materialized first. Then each root returned by
    ``FieldsBuilder.finalized_roots()`` is walked depth-first. For every SNode
    visited, the expressions of its direct non-primal place children are
    gathered and handed to ``clear_gradients`` in a single call. Child
    subtrees are processed (and cleared) before their parent's own places.
    """
    impl.get_runtime().materialize()

    def _clear_subtree(snode):
        # Expressions of the non-primal (presumably gradient) places that
        # hang directly off ``snode``.
        grad_exprs = []
        child_count = snode.ptr.get_num_ch()
        for idx in range(child_count):
            child = snode.ptr.get_ch(idx)
            if child.is_place():
                if not child.is_primal():
                    grad_exprs.append(child.get_expr())
                continue
            # Inner (non-place) node: recurse right away, preserving the
            # original interleaving of traversal and clearing.
            _clear_subtree(SNode(child))

        if not grad_exprs:
            return
        # Imported lazily and only when there is work to do (presumably to
        # avoid an import cycle at module load time).
        from taichi.lang.meta import clear_gradients
        clear_gradients(tuple(grad_exprs))

    for root in FieldsBuilder.finalized_roots():
        _clear_subtree(root)
Example #2
0
def clear_all_gradients():
    """Reset the gradients of all fields to zero.

    The runtime is materialized first. Every finalized root from
    ``FieldsBuilder.finalized_roots()`` is then traversed recursively. At
    each SNode, the expressions of its direct non-primal place children are
    collected into a tuple and passed to ``clear_gradients`` in one call.
    Descendant subtrees are cleared before the current node's places.
    """
    impl.get_runtime().materialize()

    def _zero_grads_under(parent):
        pending = []
        for child_index in range(parent.ptr.get_num_ch()):
            child = parent.ptr.get_ch(child_index)
            if child.is_place():
                # Only non-primal places are cleared; primal data is kept.
                if not child.is_primal():
                    pending.append(child.get_expr())
            else:
                _zero_grads_under(SNode(child))

        pending = tuple(pending)
        if not pending:
            return
        # Deferred import: only performed when some place actually needs it.
        from taichi._kernels import clear_gradients  # pylint: disable=C0415
        clear_gradients(pending)

    for finalized_root in FieldsBuilder.finalized_roots():
        _zero_grads_under(finalized_root)
Example #3
0
def deactivate_all_snodes():
    """Recursively deactivate all SNodes.

    Delegates to ``deactivate_all()`` on every root returned by
    ``FieldsBuilder.finalized_roots()``, in the order they are returned.
    """
    roots = FieldsBuilder.finalized_roots()
    for fb in roots:
        fb.deactivate_all()