def testTransposeWithCustomObject(self):
  # tree_transpose should turn a FlatCache whose fields are 2-element lists
  # into a 2-element list of FlatCaches. FlatCache is presumably a custom
  # pytree node registered elsewhere in this file -- confirm.
  outer_def = tree_util.tree_structure(FlatCache({"a": 1, "b": 2}))
  inner_def = tree_util.tree_structure([1, 2])
  want = [
      FlatCache({"a": 3, "b": 5}),
      FlatCache({"a": 4, "b": 6}),
  ]
  cache_of_lists = FlatCache({"a": [3, 4], "b": [5, 6]})
  got = tree_util.tree_transpose(outer_def, inner_def, cache_of_lists)
  self.assertEqual(want, got)
def jacfun(*args, **kwargs):
  """Evaluate the reverse-mode Jacobian of ``fun`` at ``args``.

  Relies on ``fun``, ``argnums``, ``holomorphic``, ``allow_int``, ``has_aux``
  and ``return_value`` from the enclosing scope.

  Returns:
    ``jac`` alone, or a tuple that appends ``y`` (when ``return_value``) and
    ``aux`` (when ``has_aux``), in the order ``(jac, y, aux)``.
  """
  wrapped = linear_util.wrap_init(fun, kwargs)
  f_partial, dyn_args = argnums_partial(
      wrapped, argnums, args, require_static_args_hashable=False)
  tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int),
           dyn_args)

  if not has_aux:
    y, pullback = _vjp(f_partial, *dyn_args, has_aux=False)
  else:
    y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True)
  tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)

  # Map the pullback over the basis produced by _std_basis(y); presumably one
  # cotangent per output element -- see _std_basis to confirm.
  jac = vmap(pullback)(_std_basis(y))

  # With a single integer argnum, the pullback's tuple has one entry and the
  # example argument is that single differentiated input.
  single_arg = isinstance(argnums, int)
  if single_arg:
    jac = jac[0]
    example_args = dyn_args[0]
  else:
    example_args = dyn_args

  jac_tree = tree_map(partial(_unravel_array_into_pytree, y, 0,
                              is_leaf=_isleaf),
                      jac, is_leaf=_isleaf)
  output_def = tree_flatten(y, is_leaf=_isleaf)[1]
  jac = tree_transpose(tree_structure(example_args), output_def, jac_tree)

  if not return_value:
    return (jac, aux) if has_aux else jac
  return (jac, y, aux) if has_aux else (jac, y)
def testTranspose(self, tree):
  # Wrapping every leaf in a 3-element list and transposing should give a
  # list of three copies of the original tree.
  outer_def = tree_util.tree_structure(tree)
  if not outer_def.num_leaves:
    # Trees with no leaves are not exercised by this test.
    self.skipTest("Skipping empty tree")
  triple_def = tree_util.tree_structure([1, 1, 1])
  tripled = tree_util.tree_map(lambda leaf: [leaf] * 3, tree)
  result = tree_util.tree_transpose(outer_def, triple_def, tripled)
  self.assertEqual(result, [tree] * 3)
def value_and_jacrev_f(*args, **kwargs):
  """Return the value of ``fun`` together with its reverse-mode Jacobian.

  Relies on ``fun``, ``argnums``, ``holomorphic``, ``allow_int`` and
  ``has_aux`` from the enclosing scope.

  Returns:
    ``(y, jac)`` when ``has_aux`` is falsy, otherwise ``((y, aux), jac)``,
    where ``jac`` has been transposed by ``tree_transpose`` so that its outer
    structure follows the differentiated arguments and its inner structure
    follows ``y``.
  """
  f = lu.wrap_init(fun, kwargs)
  f_partial, dyn_args = argnums_partial(
      f, argnums, args, require_static_args_hashable=False)
  tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int),
           dyn_args)
  if not has_aux:
    y, pullback = _vjp(f_partial, *dyn_args)
  else:
    y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True)
  tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)

  # Map the pullback over the basis produced by _std_basis(y); presumably one
  # cotangent per output element -- see _std_basis to confirm.
  jac = vmap(pullback)(_std_basis(y))
  # A single integer argnum means a single differentiated input.
  jac = jac[0] if isinstance(argnums, int) else jac
  example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
  jac_tree = tree_map(partial(_jacrev_unravel, y), example_args, jac)

  # Transpose once instead of repeating the call in each branch. The former
  # trailing bare `return` was unreachable, since both branches returned.
  jac = tree_transpose(tree_structure(example_args), tree_structure(y),
                       jac_tree)
  return ((y, aux), jac) if has_aux else (y, jac)
def testTransposeMismatchInner(self):
  # The tree's inner lists have 2 elements but the inner treedef describes 3,
  # so tree_transpose must raise a TypeError whose message matches "Mismatch".
  outer_def = tree_util.tree_structure({"a": 1, "b": 2})
  inner_def = tree_util.tree_structure([1, 2, 3])
  mismatched = {"a": [1, 2], "b": [3, 4]}
  with self.assertRaisesRegex(TypeError, "Mismatch"):
    tree_util.tree_transpose(outer_def, inner_def, mismatched)