Exemplo n.º 1
0
 def test_tree_map_zipped_wrong_structure(self):
     """tree_map_zipped raises ValueError when the nests' structures differ."""
     matching = dict(a=jnp.zeros((1, 3)), b=jnp.zeros((1, 5)))
     # All nests but the last share the {a, b} structure; the final nest has
     # only key `c`, so the collection cannot be zipped leaf-by-leaf.
     mismatched = dict(c=jnp.zeros((1, 3)))
     nests = [matching] * (NUM_NESTS - 1) + [mismatched]

     def concat(*leaves):
         return jnp.concatenate(leaves)

     with self.assertRaisesRegex(ValueError, 'must share the same tree'):
         tree_util.tree_map_zipped(concat, nests)
Exemplo n.º 2
0
 def test_tree_map_zipped_empty(self):
     """tree_map_zipped over an empty list of nests yields an empty result."""
     def concat(*leaves):
         return jnp.concatenate(leaves)

     self.assertEmpty(tree_util.tree_map_zipped(concat, []))
Exemplo n.º 3
0
 def test_tree_map_zipped(self):
     """Leaves under matching keys are combined across all NUM_NESTS nests.

     Each nest holds a (1, 3) leaf under 'a' and a (1, 5) leaf under 'b'.
     The mapped function concatenates the zipped leaves, so each output leaf
     is expected to have NUM_NESTS rows.
     """
     nest = dict(a=jnp.zeros((1, 3)), b=jnp.zeros((1, 5)))

     def concat(*leaves):
         return jnp.concatenate(leaves)

     zipped = tree_util.tree_map_zipped(concat, [nest] * NUM_NESTS)
     expected_shapes = {'a': (NUM_NESTS, 3), 'b': (NUM_NESTS, 5)}
     for key, shape in expected_shapes.items():
         self.assertEqual(zipped[key].shape, shape)