示例#1
0
    def testAxisGroups(self):
        """Check xla.axis_groups for AxisEnv(8, ('i', 'j'), (4, 2)).

        Covers each axis name alone and both orderings of the axis pair.
        """
        env = xla.AxisEnv(8, ('i', 'j'), (4, 2))
        all_ids = tuple(range(8))  # (0, 1, ..., 7)

        # Grouping over a single named axis.
        self.assertEqual(
            xla.axis_groups(env, 'i'),
            ((0, 2, 4, 6), (1, 3, 5, 7)),
        )
        self.assertEqual(
            xla.axis_groups(env, 'j'),
            ((0, 1), (2, 3), (4, 5), (6, 7)),
        )

        # Grouping over both axes is expected to give one group holding
        # every id, in ascending order for ('i', 'j').
        self.assertEqual(xla.axis_groups(env, ('i', 'j')), (all_ids, ))

        # With the axis names swapped only membership is checked: the single
        # group is sorted before comparison because order doesn't matter.
        swapped = xla.axis_groups(env, ('j', 'i'))
        self.assertEqual(len(swapped), 1)
        self.assertEqual((tuple(sorted(swapped[0])), ), (all_ids, ))
示例#2
0
def _replica_groups(axis_env, axis_name, axis_index_groups):
  """Return the groups for `axis_name`, optionally split by index groups.

  Args:
    axis_env: passed through unchanged to `xla.axis_groups`.
    axis_name: passed through unchanged to `xla.axis_groups`.
    axis_index_groups: None, or an iterable of index collections. Each
      index is used to subscript every group that `xla.axis_groups` returns.

  Returns:
    The `xla.axis_groups` result itself when `axis_index_groups` is None.
    Otherwise a list of lists: for each group from `xla.axis_groups` (outer
    order) and each index group (inner order), the members of that group
    selected at the listed positions.
  """
  base_groups = xla.axis_groups(axis_env, axis_name)
  if axis_index_groups is None:
    return base_groups

  selected = []
  for base_group in base_groups:
    for index_group in axis_index_groups:
      # Pick the members of this group at the requested positions.
      selected.append([base_group[pos] for pos in index_group])
  return selected