def test_cat_errors(self):
    """Check that ``torch.cat`` rejects invalid ``_PartialTensor`` inputs.

    Three failure modes are exercised, each expected to raise
    ``RuntimeError`` with a specific message: a plain tensor mixed into the
    input list, inputs constructed with differing ``reduce_op`` arguments,
    and use of the ``out`` keyword argument.
    """
    # A plain tensor sits next to a partial tensor.
    with self.assertRaisesRegex(
        RuntimeError, 'All inputs need to be an instance of _PartialTensor'
    ):
        mixed_inputs = [_PartialTensor(torch.rand(10)), torch.rand(10)]
        torch.cat(mixed_inputs)

    # The first input omits reduce_op, the second explicitly requests MAX.
    with self.assertRaisesRegex(RuntimeError, 'reduce_ops need to be the same'):
        implicit_op_tensor = _PartialTensor(torch.rand(10))
        max_op_tensor = _PartialTensor(
            torch.rand(10), reduce_op=dist.ReduceOp.MAX
        )
        torch.cat([implicit_op_tensor, max_op_tensor])

    # Both inputs are valid; only the ``out`` kwarg is the problem.
    with self.assertRaisesRegex(RuntimeError, '"out" kwarg is not supported'):
        partial_inputs = [_PartialTensor(torch.rand(10)) for _ in range(2)]
        torch.cat(partial_inputs, out=torch.rand(10))
def _handle_row_wise_sharding_sharded_tensor(input, world_size, weight, local_shard_t, bias, pg): """ Entry-point function to handle the logic of row-wise sharding of weight for Linear when the input is a sharded tensor. (Detailed explanations of the logic can be found in the comment for sharded_linear.) Args: input: matrix to be multiplied with the sharded weight. world_size: number of ranks. weight: shareded weight tensor. local_shard_t: row-wise shared local weight used for lookup. bias: bias term of linear op. pg: process group. Returns: A :class:`_PartialTensor` object which stores the partial local result. """ results = [] local_shard = input.local_shards()[0].tensor if input.sharding_spec().dim not in (-1, len(input.size()) - 1): raise NotImplementedError( "The case when the input does not come from col-wise sharded " "linear is not supported for row-wise sharded linear.") for tensor in torch.tensor_split(local_shard, world_size): results.append( tensor.matmul(local_shard_t) + _BiasTensorPartial.apply(world_size, bias)) # Return the partial local result. return _PartialTensor(torch.cat(results), pg)
def _run_partial_tensor_n_reshard(
    self, reshard_spec, input_size, world_size, reduce_op, dtype=torch.float
):
    """Reshard a ``_PartialTensor`` and compare against a local reference.

    For every rank index of the default process group, ``world_size`` random
    tensors are generated after seeding with that rank index. The tensors
    generated for this process's own rank feed the ``_PartialTensor``; the
    tensors of all ranks are reduced locally to form the expected result.

    Args:
        reshard_spec: spec handed to ``reshard``; its ``placements`` are
            scanned to find which chunk of the reference belongs to
            ``self.rank``.
        input_size: sizes unpacked into ``torch.rand`` for each tensor.
        world_size: number of tensors generated per rank index.
        reduce_op: reduce op given to ``_PartialTensor``. ``ReduceOp.SUM``
            selects a sum reference; any other value selects an
            element-wise max reference.
        dtype: dtype of the generated tensors.
    """
    pg = dist.distributed_c10d._get_default_group()

    per_rank_inputs = []
    own_tensors = []
    for rank in range(pg.size()):
        # Seeding with the rank index makes the data for each rank
        # reproducible — presumably so every process builds the same
        # reference data.
        torch.manual_seed(rank)
        rank_tensors = [
            torch.rand(*input_size, dtype=dtype).cuda(self.rank)
            for _ in range(world_size)
        ]
        if self.rank == rank:
            own_tensors.extend(t.clone().detach() for t in rank_tensors)
        per_rank_inputs.append(torch.cat(rank_tensors))

    partial_tensor = _PartialTensor(
        torch.cat(own_tensors), pg, reduce_op=reduce_op
    )
    resharded = partial_tensor.reshard(reshard_spec)
    local_shards = resharded.local_shards()

    # Reduce across the rank axis to get the expected full result.
    stacked = torch.stack(per_rank_inputs)
    if reduce_op == dist.ReduceOp.SUM:
        expected = torch.sum(stacked, dim=0)
    else:
        expected = torch.max(stacked, dim=0).values

    # Index of the last placement that targets this rank (None if absent).
    matching = [
        idx
        for idx, placement in enumerate(reshard_spec.placements)
        if placement.rank() == self.rank
    ]
    rank_idx = matching[-1] if matching else None
    expected_local = expected.chunk(pg.size())[rank_idx]

    self.assertEqual(1, len(local_shards))
    self.assertEqual(local_shards[0].tensor, expected_local)
def test_cat(self):
    """Check ``torch.cat`` on ``_PartialTensor`` inputs against plain tensors.

    For each scenario the size of the concatenated partial tensors must
    equal the size obtained by concatenating the underlying tensors with
    the same arguments.
    """
    scenarios = (
        # Default dim: the inputs differ only in their first dimension.
        (((5, 10), (3, 10), (4, 10)), {}),
        # Test dim kwarg
        (((5, 10), (5, 12), (5, 11)), {"dim": 1}),
    )
    for shapes, cat_kwargs in scenarios:
        local_tensors = [torch.rand(*shape) for shape in shapes]
        wrapped = [_PartialTensor(tensor) for tensor in local_tensors]
        partial_concat = torch.cat(wrapped, **cat_kwargs)
        local_concat = torch.cat(local_tensors, **cat_kwargs)
        self.assertEqual(local_concat.size(), partial_concat.size())
def _handle_row_wise_sharding_tensor(
    input, world_size, weight, rank, local_shard_t, bias, pg
):
    """
    Entry-point function to handle the logic of row-wise sharding of weight
    for Linear.
    (Detailed explanations of the logic can be found in the comment for
    sharded_linear.)

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: sharded weight tensor; the placements of its sharding spec
            decide how the input is split across ranks.
        rank: rank of the current process; selects this rank's entry of the
            computed split sizes.
        local_shard_t: row-wise sharded local weight that each gathered
            slice is multiplied with.
        bias: bias term of linear op.
        pg: process group.

    Returns:
        A :class:`_PartialTensor` object which stores the partial local result.
    """
    # alltoall to gather all the appropriate inputs. The dimension to split
    # is moved to the front because the all-to-all splits along dim 0.
    input_t = input.transpose(0, -1).contiguous()
    input_t_size = input_t.size()

    # Compute expected size
    split_size = get_split_size(input_t_size[0], world_size)
    input_split_sizes = [0] * world_size
    rearrange_rows = False

    for idx, placement in enumerate(weight._sharding_spec.placements):
        sharded_dim_size = get_chunked_dim_size(input_t_size[0], split_size, idx)
        input_split_sizes[placement.rank()] = sharded_dim_size
        # Placement order differing from rank order means the rows have to
        # be permuted before the all-to-all.
        if placement.rank() != idx:
            rearrange_rows = True

    if rearrange_rows:
        # Need to re-arrange rows of input_t for all2all.
        # Every entry is replaced (never mutated in place) below, so sharing
        # one placeholder list across slots is harmless.
        indices: List[List[int]] = [[0]] * world_size
        # When we do the chunk split, we always ensure the first N - 1 chunks get max out
        # and then the Nth chunk gets the rest. So input_split_sizes like [3, 3, 3, 4]
        # are not possible. The expected split size will be [4, 4, 4, 1].
        sharded_dim_size_max = max(input_split_sizes)
        for idx, placement in enumerate(weight._sharding_spec.placements):
            chunk_len = input_split_sizes[placement.rank()]
            offset_start_idx = idx * sharded_dim_size_max
            indices[placement.rank()] = list(
                range(offset_start_idx, offset_start_idx + chunk_len))
        indices_flatten = [idx for indice in indices for idx in indice]

        input_t = input_t.index_select(
            0, torch.tensor(indices_flatten, device=input_t.device))

    gathered_input_size = [input_split_sizes[rank] * world_size] + list(
        input_t_size[1:])
    # The receive buffer must match the send buffer's dtype; without an
    # explicit dtype torch.empty falls back to the global default dtype,
    # which mismatches for e.g. half or double inputs.
    gathered_input = torch.empty(
        gathered_input_size, device=input_t.device, dtype=input_t.dtype)

    # Perform autograd enabled alltoall
    all_to_all_single(gathered_input,
                      input_t,
                      input_split_sizes=input_split_sizes,
                      group=pg)
    # Undo the initial transpose so the gathered dimension is last again.
    gathered_input = gathered_input.transpose(0, -1)

    # Perform local matmuls for all shards
    results = []
    shard_size = local_shard_t.size()[0]
    for r in range(world_size):
        # Slice of the last dimension that was received from rank ``r``.
        inp = torch.narrow(gathered_input, -1, r * shard_size, shard_size)
        results.append(
            inp.matmul(local_shard_t) + _BiasTensorPartial.apply(world_size, bias))

    # Return the partial local result.
    return _PartialTensor(torch.cat(results), pg)
def test_transpose(self):
    """Check that transposing a (5, 10) ``_PartialTensor`` reports size (10, 5)."""
    original = _PartialTensor(torch.rand(5, 10))
    transposed = original.transpose(0, 1)
    expected_size = torch.Size((10, 5))
    self.assertEqual(transposed.size(), expected_size)