def test_tree_map_zipped_wrong_structure(self):
  """tree_map_zipped raises ValueError when one nest has a different tree.

  NUM_NESTS - 1 nests share the structure {a, b}; the final nest has only
  key 'c', so the structures cannot be zipped together.
  """
  matching = dict(a=jnp.zeros((1, 3)), b=jnp.zeros((1, 5)))
  mismatched = dict(c=jnp.zeros((1, 3)))
  # Every matching entry is the same dict object; nothing mutates it, so
  # sharing is fine here.
  nests = [matching for _ in range(NUM_NESTS - 1)]
  nests.append(mismatched)

  def concat_leaves(*leaves):
    return jnp.concatenate(leaves)

  with self.assertRaisesRegex(ValueError, 'must share the same tree'):
    tree_util.tree_map_zipped(concat_leaves, nests)
def test_tree_map_zipped_empty(self):
  """Calling tree_map_zipped on an empty list of nests gives an empty output."""

  def concat_leaves(*leaves):
    return jnp.concatenate(leaves)

  result = tree_util.tree_map_zipped(concat_leaves, [])
  self.assertEmpty(result)
def test_tree_map_zipped(self):
  """Concatenating zipped leaves across NUM_NESTS nests stacks them on axis 0.

  Each nest holds a (1, 3) leaf under 'a' and a (1, 5) leaf under 'b'.
  The concatenated outputs should therefore be (NUM_NESTS, 3) and
  (NUM_NESTS, 5).
  """
  template = dict(a=jnp.zeros((1, 3)), b=jnp.zeros((1, 5)))
  # The same dict object is repeated; the test only reads from it.
  nests = [template for _ in range(NUM_NESTS)]

  def concat_leaves(*leaves):
    return jnp.concatenate(leaves)

  output = tree_util.tree_map_zipped(concat_leaves, nests)
  for key, width in (('a', 3), ('b', 5)):
    self.assertEqual(output[key].shape, (NUM_NESTS, width))