Example #1
0
 def testTransposeWithCustomObject(self):
   """tree_transpose turns a FlatCache of 2-lists into a 2-list of FlatCaches.

   FlatCache is presumably registered as a custom pytree node elsewhere in
   this file (it is passed to tree_structure below) -- confirm there.
   """
   # Outer level: a FlatCache wrapping a two-key dict.
   outer = tree_util.tree_structure(FlatCache({"a": 1, "b": 2}))
   # Inner level: a plain two-element list.
   inner = tree_util.tree_structure([1, 2])
   want = [FlatCache({"a": 3, "b": 5}), FlatCache({"a": 4, "b": 6})]
   source = FlatCache({"a": [3, 4], "b": [5, 6]})
   got = tree_util.tree_transpose(outer, inner, source)
   self.assertEqual(want, got)
Example #2
0
 def jacfun(*args, **kwargs):
     """Compute the reverse-mode Jacobian of ``fun`` at ``args``.

     This is a closure: ``fun``, ``argnums``, ``has_aux``, ``holomorphic``,
     ``allow_int`` and ``return_value`` are read from the enclosing scope.

     Returns:
       ``jac`` when neither ``return_value`` nor ``has_aux`` is set;
       ``(jac, aux)`` when only ``has_aux`` is set;
       ``(jac, y)`` when only ``return_value`` is set;
       ``(jac, y, aux)`` when both are set.
     """
     wrapped = linear_util.wrap_init(fun, kwargs)
     f_partial, dyn_args = argnums_partial(
         wrapped, argnums, args, require_static_args_hashable=False)
     tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int),
              dyn_args)

     if not has_aux:
         y, pullback = _vjp(f_partial, *dyn_args, has_aux=False)
     else:
         y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True)
     tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)

     # Batch the pullback over _std_basis(y) (presumably one cotangent per
     # output element -- see _std_basis).
     stacked = vmap(pullback)(_std_basis(y))

     # A single integer argnums means the results are wrapped in a
     # one-element sequence; unwrap both the Jacobian and the example args.
     if isinstance(argnums, int):
         jac, example_args = stacked[0], dyn_args[0]
     else:
         jac, example_args = stacked, dyn_args

     unravel = partial(_unravel_array_into_pytree, y, 0, is_leaf=_isleaf)
     jac_tree = tree_map(unravel, jac, is_leaf=_isleaf)

     # jac_tree is argument-structured on the outside and y-structured on the
     # inside; transpose so the y structure becomes the outer level.
     outer_def = tree_structure(example_args)
     _, inner_def = tree_flatten(y, is_leaf=_isleaf)
     jac = tree_transpose(outer_def, inner_def, jac_tree)

     if not return_value:
         return (jac, aux) if has_aux else jac
     return (jac, y, aux) if has_aux else (jac, y)
Example #3
0
 def testTranspose(self, tree):
   """Transposing `tree` of 3-element lists yields a 3-element list of trees."""
   outer_structure = tree_util.tree_structure(tree)
   # A tree with no leaves cannot meaningfully be transposed.
   if outer_structure.num_leaves == 0:
     self.skipTest("Skipping empty tree")
   triple = tree_util.tree_structure([1, 1, 1])
   # Replace every leaf with a list holding that leaf three times.
   tripled = tree_util.tree_map(lambda leaf: [leaf] * 3, tree)
   result = tree_util.tree_transpose(outer_structure, triple, tripled)
   self.assertEqual(result, [tree] * 3)
Example #4
0
 def value_and_jacrev_f(*args, **kwargs):
     """Evaluate ``fun`` and its reverse-mode Jacobian in a single pass.

     This is a closure: ``fun``, ``argnums``, ``has_aux``, ``holomorphic`` and
     ``allow_int`` are read from the enclosing scope.

     Returns:
       ``(y, jac)`` when ``has_aux`` is falsy, otherwise ``((y, aux), jac)``.
       ``y`` is the primal output of ``fun``. ``jac`` is transposed so that
       the structure of ``y`` is the outer level and the structure of the
       differentiated argument(s) is the inner level.
     """
     f = lu.wrap_init(fun, kwargs)
     f_partial, dyn_args = argnums_partial(
         f, argnums, args, require_static_args_hashable=False)
     tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int),
              dyn_args)
     if not has_aux:
         y, pullback = _vjp(f_partial, *dyn_args)
     else:
         y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True)
     tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)
     jac = vmap(pullback)(_std_basis(y))
     # A single integer argnums means the results are wrapped in a
     # one-element sequence; unwrap both the Jacobian and the example args.
     jac = jac[0] if isinstance(argnums, int) else jac
     example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
     jac_tree = tree_map(partial(_jacrev_unravel, y), example_args, jac)
     # Transpose once here instead of repeating the call in each return
     # branch: argument structure outside / y structure inside becomes
     # y structure outside / argument structure inside.
     jac = tree_transpose(tree_structure(example_args), tree_structure(y),
                          jac_tree)
     if not has_aux:
         return y, jac
     return (y, aux), jac
Exemple #5
0
 def testTransposeMismatchInner(self):
   """tree_transpose raises TypeError when the inner treedef does not fit."""
   nested = {"a": [1, 2], "b": [3, 4]}
   outer = tree_util.tree_structure({"a": 1, "b": 2})
   # Deliberately wrong: a three-element inner structure, whereas every
   # value in `nested` is a two-element list.
   wrong_inner = tree_util.tree_structure([1, 2, 3])
   with self.assertRaisesRegex(TypeError, "Mismatch"):
     tree_util.tree_transpose(outer, wrong_inner, nested)