def testAxisGroups(self):
    """Check xla.axis_groups on an 8-device mesh with axes ('i', 'j') of sizes (4, 2)."""
    env = xla.AxisEnv(8, ('i', 'j'), (4, 2))
    every_device = (0, 1, 2, 3, 4, 5, 6, 7)

    # Single-axis and ordered multi-axis cases, checked in the original order.
    cases = [
        ('i', ((0, 2, 4, 6), (1, 3, 5, 7))),
        ('j', ((0, 1), (2, 3), (4, 5), (6, 7))),
        (('i', 'j'), (every_device,)),
    ]
    for names, expected in cases:
        self.assertEqual(xla.axis_groups(env, names), expected)

    # Reversed axis order: still a single group covering every device;
    # the order of devices inside that group is not asserted.
    reversed_groups = xla.axis_groups(env, ('j', 'i'))
    self.assertEqual(len(reversed_groups), 1)
    self.assertEqual((tuple(sorted(reversed_groups[0])),), (every_device,))
def _replica_groups(axis_env, axis_name, axis_index_groups):
    """Return the replica groups for ``axis_name``, optionally subdivided.

    Args:
      axis_env: environment passed through to ``xla.axis_groups``.
      axis_name: axis name (or names) passed through to ``xla.axis_groups``.
      axis_index_groups: ``None``, or an iterable of index sequences. Each
        index refers to a position within one group returned by
        ``xla.axis_groups``.

    Returns:
      If ``axis_index_groups`` is ``None``, the value of ``xla.axis_groups``
      unchanged. Otherwise a list of lists: for every axis group (outer order)
      and every index group (inner order), the elements of that axis group at
      the listed positions.

    Raises:
      IndexError: if an index in ``axis_index_groups`` is out of range for an
        axis group (raised by the indexing below).
    """
    groups = xla.axis_groups(axis_env, axis_name)
    if axis_index_groups is None:
        return groups

    subdivided = []
    for group in groups:
        # Every index group is applied to every axis group, so the result
        # has len(groups) * len(axis_index_groups) entries.
        for index_group in axis_index_groups:
            subdivided.append([group[idx] for idx in index_group])
    return subdivided