def testRingPermutations(self): # 0 devices pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 0, []) self.assertEqual(pred_by_c_d, []) self.assertEqual(rank_by_c_d, []) # 1 worker, 1 subchunk cases pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 1, [0]) self.assertEqual(pred_by_c_d, [[0]]) self.assertEqual(rank_by_c_d, [[0]]) pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 1, [0, 1, 2]) self.assertEqual(pred_by_c_d, [[2, 0, 1]]) self.assertEqual(rank_by_c_d, [[0, 1, 2]]) # multiple workers, 1 subchunk cases pred_by_c_d, rank_by_c_d = ar._ring_permutations(2, 1, [0, 1, 2]) self.assertEqual(pred_by_c_d, [[5, 0, 1, 2, 3, 4]]) self.assertEqual(rank_by_c_d, [[0, 1, 2, 3, 4, 5]]) pred_by_c_d, rank_by_c_d = ar._ring_permutations(3, 1, [0, 1, 2]) self.assertEqual(pred_by_c_d, [[8, 0, 1, 2, 3, 4, 5, 6, 7]]) self.assertEqual(rank_by_c_d, [[0, 1, 2, 3, 4, 5, 6, 7, 8]]) pred_by_c_d, rank_by_c_d = ar._ring_permutations(2, 1, [2, 1, 0]) self.assertEqual(pred_by_c_d, [[1, 2, 3, 4, 5, 0]]) self.assertEqual(rank_by_c_d, [[2, 1, 0, 5, 4, 3]]) # 1 worker, multiple subchunk cases pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 2, [0, 1, 2, 3]) self.assertEqual(pred_by_c_d, [[3, 0, 1, 2], [3, 0, 1, 2]]) self.assertEqual(rank_by_c_d, [[0, 1, 2, 3], [2, 3, 0, 1]]) pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 4, [0, 1, 2, 3]) self.assertEqual( pred_by_c_d, [[3, 0, 1, 2], [3, 0, 1, 2], [3, 0, 1, 2], [3, 0, 1, 2]]) self.assertEqual( rank_by_c_d, [[0, 1, 2, 3], [3, 0, 1, 2], [2, 3, 0, 1], [1, 2, 3, 0]]) # multiple worker, multiple subchunk cases pred_by_c_d, rank_by_c_d = ar._ring_permutations(2, 2, [0, 1, 2, 3]) self.assertEqual(pred_by_c_d, [[7, 0, 1, 2, 3, 4, 5, 6], [3, 0, 5, 2, 7, 4, 1, 6]]) self.assertEqual(rank_by_c_d, [[0, 1, 2, 3, 4, 5, 6, 7], [2, 3, 0, 1, 6, 7, 4, 5]]) pred_by_c_d, rank_by_c_d = ar._ring_permutations(2, 2, [0, 3, 2, 1]) self.assertEqual(pred_by_c_d, [[5, 2, 3, 0, 1, 6, 7, 4], [1, 2, 7, 0, 5, 6, 3, 4]]) self.assertEqual(rank_by_c_d, [[0, 3, 2, 1, 4, 7, 6, 5], [2, 1, 0, 3, 6, 5, 4, 7]])
def testRingPermutations(self): # 0 devices pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 0, []) self.assertEqual(pred_by_c_d, []) self.assertEqual(rank_by_c_d, []) # 1 worker, 1 subchunk cases pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 1, [0]) self.assertEqual(pred_by_c_d, [[0]]) self.assertEqual(rank_by_c_d, [[0]]) pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 1, [0, 1, 2]) self.assertEqual(pred_by_c_d, [[2, 0, 1]]) self.assertEqual(rank_by_c_d, [[0, 1, 2]]) # multiple workers, 1 subchunk cases pred_by_c_d, rank_by_c_d = ar._ring_permutations(2, 1, [0, 1, 2]) self.assertEqual(pred_by_c_d, [[5, 0, 1, 2, 3, 4]]) self.assertEqual(rank_by_c_d, [[0, 1, 2, 3, 4, 5]]) pred_by_c_d, rank_by_c_d = ar._ring_permutations(3, 1, [0, 1, 2]) self.assertEqual(pred_by_c_d, [[8, 0, 1, 2, 3, 4, 5, 6, 7]]) self.assertEqual(rank_by_c_d, [[0, 1, 2, 3, 4, 5, 6, 7, 8]]) pred_by_c_d, rank_by_c_d = ar._ring_permutations(2, 1, [2, 1, 0]) self.assertEqual(pred_by_c_d, [[1, 2, 3, 4, 5, 0]]) self.assertEqual(rank_by_c_d, [[2, 1, 0, 5, 4, 3]]) # 1 worker, multiple subchunk cases pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 2, [0, 1, 2, 3]) self.assertEqual(pred_by_c_d, [[3, 0, 1, 2], [3, 0, 1, 2]]) self.assertEqual(rank_by_c_d, [[0, 1, 2, 3], [2, 3, 0, 1]]) pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 4, [0, 1, 2, 3]) self.assertEqual(pred_by_c_d, [[3, 0, 1, 2], [3, 0, 1, 2], [3, 0, 1, 2], [3, 0, 1, 2]]) self.assertEqual(rank_by_c_d, [[0, 1, 2, 3], [3, 0, 1, 2], [2, 3, 0, 1], [1, 2, 3, 0]]) # multiple worker, multiple subchunk cases pred_by_c_d, rank_by_c_d = ar._ring_permutations(2, 2, [0, 1, 2, 3]) self.assertEqual(pred_by_c_d, [[7, 0, 1, 2, 3, 4, 5, 6], [3, 0, 5, 2, 7, 4, 1, 6]]) self.assertEqual(rank_by_c_d, [[0, 1, 2, 3, 4, 5, 6, 7], [2, 3, 0, 1, 6, 7, 4, 5]]) pred_by_c_d, rank_by_c_d = ar._ring_permutations(2, 2, [0, 3, 2, 1]) self.assertEqual(pred_by_c_d, [[5, 2, 3, 0, 1, 6, 7, 4], [1, 2, 7, 0, 5, 6, 3, 4]]) self.assertEqual(rank_by_c_d, [[0, 3, 2, 1, 4, 7, 6, 5], [2, 1, 0, 3, 6, 5, 4, 7]])
def testBuildRingGatherPassStructure(self): # 1 worker, 1 device input_tensors, device_names = self._buildInput(1, 1) pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 1, [0]) output_tensors = ar._build_ring_gather(input_tensors, device_names, 1, pred_by_c_d, rank_by_c_d, math_ops.add) self.assertEqual(output_tensors, input_tensors) # 1 worker, 4 devices, 2 subchunks input_tensors, device_names = self._buildInput(1, 4) pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 2, [0, 1, 2, 3]) output_tensors, pad_len = ar._build_ring_gather( input_tensors, device_names, 2, pred_by_c_d, rank_by_c_d, math_ops.add) self.assertEqual(0, pad_len) # same number outputs as inputs self.assertEqual(len(output_tensors), len(input_tensors)) num_chunks = 2 * len(input_tensors) tlen = tensor_shape.dimension_value(input_tensors[0].shape[0]) for otl in output_tensors: self.assertEqual(len(otl), num_chunks) for ot in otl: self.assertEqual(ot.shape, [tlen/num_chunks])