Esempio n. 1
0
 def call(*args, **kwargs):
     """Call ``jitted_func`` on a snapshot of ``vars`` and store the new state.

     ``jitted_func`` receives ``vars.dict()`` followed by ``*args``/``**kwargs``
     and must return ``(out, changes)``. ``changes`` is assigned back into
     ``vars`` and ``out`` is returned.

     Raises:
       errors.JaxTracerError: if ``jitted_func`` raises
         ``UnexpectedTracerError``. The snapshot is re-assigned to ``vars``
         before re-raising, and the original error is chained as the cause.
     """
     # NOTE: ``vars`` is an outer-scope collection (it has .dict()/.assign()),
     # not the ``vars`` builtin.
     snapshot = vars.dict()
     try:
         result, new_state = jitted_func(snapshot, *args, **kwargs)
     except UnexpectedTracerError as err:
         # Put the pre-call values back before surfacing the tracer error.
         vars.assign(snapshot)
         raise errors.JaxTracerError(variables=vars) from err
     else:
         vars.assign(new_state)
         return result
Esempio n. 2
0
 def call(x=None):
     """Run ``lax.while_loop`` with the current ``dyn_vars`` values as state.

     Args:
       x: Extra operand carried next to the dynamic values in the loop state.

     The final dynamic values produced by the loop are written back into
     ``dyn_vars``. Nothing is returned; the final ``x`` part of the loop state
     is discarded.

     Raises:
       errors.JaxTracerError: if the loop raises ``UnexpectedTracerError``.
         ``dyn_vars`` are reset to their initial values first.
     """
     def _write(values):
         # Assign values positionally onto dyn_vars.
         for var, val in zip(dyn_vars, values):
             var.value = val

     start_values = [var.value for var in dyn_vars]
     try:
         final_values, _ = lax.while_loop(
             cond_fun=_cond_fun,
             body_fun=_body_fun,
             init_val=(start_values, x),
         )
     except UnexpectedTracerError as err:
         _write(start_values)
         raise errors.JaxTracerError(variables=dyn_vars) from err
     _write(final_values)
Esempio n. 3
0
 def call(xs):
     """Scan ``fun2scan`` over ``xs``, using the ``dyn_vars`` values as the carry.

     Args:
       xs: The sequence passed to ``lax.scan`` as ``xs``.

     Returns:
       The per-step outputs from ``lax.scan``, rebuilt with
       ``tree_unflatten(tree, ...)``.

     The final carry is written back into ``dyn_vars``.

     Raises:
       errors.JaxTracerError: if the scan raises ``UnexpectedTracerError``.
         ``dyn_vars`` are reset to their initial values first.
     """
     carry0 = [var.value for var in dyn_vars]
     try:
         final_carry, step_outputs = lax.scan(f=fun2scan, init=carry0, xs=xs)
     except UnexpectedTracerError as err:
         for var, val in zip(dyn_vars, carry0):
             var.value = val
         raise errors.JaxTracerError(variables=dyn_vars) from err
     else:
         for var, val in zip(dyn_vars, final_carry):
             var.value = val
         return tree_unflatten(tree, step_outputs)
Esempio n. 4
0
 def call(pred, x=None):
     """Choose ``_true_fun`` or ``_false_fun`` with ``lax.cond`` on ``pred``.

     Args:
       pred: Predicate handed to ``lax.cond``.
       x: Extra operand passed to the branch together with the current
         ``dyn_vars`` values.

     Returns:
       The second element of the chosen branch's result. The first element
       (the new dynamic values) is written back into ``dyn_vars``.

     Raises:
       errors.JaxTracerError: if ``lax.cond`` raises ``UnexpectedTracerError``.
         ``dyn_vars`` are reset to their previous values first.
     """
     def _assign(values):
         # Assign values positionally onto dyn_vars.
         for var, val in zip(dyn_vars, values):
             var.value = val

     saved = [var.value for var in dyn_vars]
     try:
         branch_values, branch_result = lax.cond(
             pred=pred,
             true_fun=_true_fun,
             false_fun=_false_fun,
             operand=(saved, x),
         )
     except UnexpectedTracerError as err:
         _assign(saved)
         raise errors.JaxTracerError(variables=dyn_vars) from err
     _assign(branch_values)
     return branch_result
Esempio n. 5
0
 def call(*args, **kwargs):
     """Run ``vmapped_func`` over a batch, threading dynamic and random state.

     The batch size ``n`` is read as ``args[batch_idx[0]].shape[batch_idx[1]]``.
     Each random variable contributes ``val.split_keys(n)`` as its batched input.
     After the call, every changed dynamic or random value returned by
     ``vmapped_func`` is passed through ``reduce_func`` and stored back.

     Returns:
       The first element returned by ``vmapped_func``.

     Raises:
       errors.JaxTracerError: if ``vmapped_func`` raises
         ``UnexpectedTracerError``. Before re-raising, dynamic and random
         variables are re-assigned the values they held when ``call`` was
         entered.
     """
     dyn_data = dyn_vars.dict()
     n = args[batch_idx[0]].shape[batch_idx[1]]
     # Snapshot the random variables before split_keys() runs. Restoring from
     # ``rand_data`` (the split_keys() results) would write the split keys
     # back instead of the pre-call values.
     # NOTE(review): assumes rand_vars is the same collection type as dyn_vars
     # (i.e. it provides .dict()); confirm against the caller.
     rand_snapshot = rand_vars.dict()
     rand_data = {key: val.split_keys(n) for key, val in rand_vars.items()}
     try:
         out, dyn_changes, rand_changes = vmapped_func(
             dyn_data, rand_data, *args, **kwargs)
     except UnexpectedTracerError as e:
         dyn_vars.assign(dyn_data)
         rand_vars.assign(rand_snapshot)
         raise errors.JaxTracerError(variables=dyn_vars) from e
     # Collapse the batched per-variable results back to single values.
     for key, v in dyn_changes.items():
         dyn_vars[key] = reduce_func(v)
     for key, v in rand_changes.items():
         rand_vars[key] = reduce_func(v)
     return out
Esempio n. 6
0
    def call_func(*args, **kwargs):
        """Evaluate ``grad_func`` on the current variable values and shape its result.

        The current values of ``grad_vars`` and ``dyn_vars`` are passed to
        ``grad_func`` together with ``*args``/``**kwargs``. The new values it
        returns are written back into both groups of variables.

        Gradients returned, depending on the enclosing settings:
          * no ``grad_vars``: only the argument gradients (``grads[1]`` for an
            int ``argnums``, otherwise ``grads[1:]``);
          * ``grad_vars`` present and ``argnums is None``: the variable
            gradients, rebuilt as ``tree_unflatten(grad_tree, grads[0])``;
          * otherwise: ``(var_grads, arg_grads)``.

        Depending on ``return_value`` / ``has_aux`` the result is ``grads``,
        ``(grads, outputs)``, ``(grads, outputs[1])`` or
        ``(grads, outputs[0], outputs[1])``.

        Raises:
          errors.JaxTracerError: if ``grad_func`` raises
            ``UnexpectedTracerError``. All variables are first reset to the
            values they had on entry.
        """
        def _write(variables, values):
            # Assign values positionally onto the given variables.
            for var, val in zip(variables, values):
                var.value = val

        def _select_arg_grads(all_grads):
            # Skip entry 0 (it carries the variable gradients when there are any).
            return all_grads[1] if isinstance(argnums, int) else all_grads[1:]

        grad_before = [var.value for var in grad_vars]
        dyn_before = [var.value for var in dyn_vars]
        try:
            raw_grads, (fn_out, grad_after, dyn_after) = grad_func(
                grad_before, dyn_before, *args, **kwargs)
        except UnexpectedTracerError as err:
            _write(grad_vars, grad_before)
            _write(dyn_vars, dyn_before)
            raise errors.JaxTracerError(variables=dyn_vars + grad_vars) from err
        _write(grad_vars, grad_after)
        _write(dyn_vars, dyn_after)

        # Assemble the gradients to hand back.
        if len(grad_vars) == 0:
            result_grads = _select_arg_grads(raw_grads)
        elif argnums is None:
            result_grads = tree_unflatten(grad_tree, raw_grads[0])
        else:
            result_grads = (tree_unflatten(grad_tree, raw_grads[0]),
                            _select_arg_grads(raw_grads))

        # Attach the function value and/or auxiliary data as requested.
        if not return_value:
            return (result_grads, fn_out[1]) if has_aux else result_grads
        if has_aux:
            return result_grads, fn_out[0], fn_out[1]
        return result_grads, fn_out