Example #1
    def test_shard_module_sub_process_group(self):
        megatron_lm = SimpleMegatronLM([[17, 12], [12, 29]], rank=self.rank)
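        # The "rank:N" entries in the placements below are relative to the
        # process group passed to shard_module: with the [2, 3] subgroup created
        # below, rank:0 and rank:1 resolve to global ranks 2 and 3 (cuda:2/cuda:3).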
        colwise_sharding_spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:2",
                "rank:1/cuda:3",
            ],
        )
        rowwise_sharding_spec = ChunkShardingSpec(
            dim=1,
            placements=[
                "rank:0/cuda:2",
                "rank:1/cuda:3",
            ],
        )
        sharding_plan = ShardingPlan(
            plan={
                "fc1.weight": colwise_sharding_spec,
                "fc2.weight": rowwise_sharding_spec
            })

        pg = dist.new_group([2, 3])

        if self.rank >= 2:
            shard_module(megatron_lm, sharding_plan, process_group=pg)
Example #2
def _generate_sharding_spec(world_size):
    """
    We first need to create a sharding spec for our sharding work.

    For now, we only support sharding on one dimension, so we use
    ``ChunkShardingSpec`` to split the given sharding dim into equal-sized
    chunks (the last chunk may be smaller). The behavior is similar to
    ``torch.chunk``.

    We also need to create an output sharding spec for the second nn layer
    because we need to aggregate (reduce) its partial results. This extra
    spec describes how the aggregation result is stored in a new sharded
    tensor.
    """
    placements = [f"rank:{idx}/cuda:{idx}" for idx in range(world_size)]
    # Shard the first nn module's weight by dim 0.
    # (nn.Linear transposes the weight internally so dim 0 actually means column)
    colwise_spec = ChunkShardingSpec(
        dim=0,
        placements=placements,
    )
    # Shard the second nn module's weight by dim 1.
    rowwise_spec = ChunkShardingSpec(
        dim=1,
        placements=placements,
    )
    # The result from the second nn.linear layer needs aggregation by dim 0.
    output_spec = ChunkShardingSpec(
        dim=0,
        placements=placements,
    )
    return colwise_spec, rowwise_spec, output_spec
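
The ``torch.chunk`` analogy in the docstring above can be checked with plain
tensors. A minimal standalone sketch (no process group needed; the size 21 is
just an illustrative value matching the sharding-dim size used in Example #3):

import torch

# ChunkShardingSpec splits the sharding dim the way torch.chunk does: a dim of
# size 21 over 4 placements yields chunks of size [6, 6, 6, 3], and chunk i is
# placed on placements[i].
chunk_sizes = [c.size(0) for c in torch.chunk(torch.arange(21), 4)]
assert chunk_sizes == [6, 6, 6, 3]
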
Example #3
    def test_get_chunk_sharding_params(self):
        ranks = [
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ]
        spec = ChunkShardingSpec(
            dim=0,
            placements=ranks,
        )
        # A dim of size 21 split over 4 ranks chunks into [6, 6, 6, 3]
        # (torch.chunk semantics); the result is (start_offset, chunk_size)
        # for the queried rank.
        result = get_chunk_sharding_params(21, 4, spec, 1)
        self.assertEqual(6, result[0])
        self.assertEqual(6, result[1])
        result = get_chunk_sharding_params(21, 4, spec, 3)
        self.assertEqual(18, result[0])
        self.assertEqual(3, result[1])
        # Reorder placements: rank 1 now owns the third chunk and rank 3
        # the first one.
        ranks[1], ranks[2] = ranks[2], ranks[1]
        ranks[0], ranks[3] = ranks[3], ranks[0]
        spec.placements = ranks
        result = get_chunk_sharding_params(21, 4, spec, 1)
        self.assertEqual(12, result[0])
        self.assertEqual(6, result[1])
        result = get_chunk_sharding_params(21, 4, spec, 3)
        self.assertEqual(0, result[0])
        self.assertEqual(6, result[1])
Example #4
def generate_chunk_sharding_specs_for_test(sharding_dim):
    return [
        ChunkShardingSpec(
            dim=sharding_dim,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        ),
        # Test different ordering. (Case 1)
        ChunkShardingSpec(
            dim=sharding_dim,
            placements=[
                "rank:2/cuda:2",
                "rank:3/cuda:3",
                "rank:0/cuda:0",
                "rank:1/cuda:1",
            ],
        ),
        # Test different ordering. (Case 2)
        ChunkShardingSpec(
            dim=sharding_dim,
            placements=[
                "rank:3/cuda:3",
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
            ],
        ),
    ]
Example #5
    def test_load_rowwise_to_colwise(self) -> None:
        path = self.get_file_path()
        self.assertEqual(self.world_size, dist.get_world_size())

        # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
        src_spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0",
                "rank:1",
            ],
        )

        # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
        dst_spec = ChunkShardingSpec(
            dim=1,
            placements=[
                "rank:0",
                "rank:1",
            ],
        )

        if dist.get_rank() == 0:
            shutil.rmtree(path, ignore_errors=True)
            os.makedirs(path)

        model_to_save = MyShardedModel3(src_spec).cuda(dist.get_rank())
        model_to_save._register_state_dict_hook(state_dict_hook)
        state_dict_to_save = model_to_save.state_dict()

        fs_writer = FileSystemWriter(path=path)
        save_state_dict(state_dict=state_dict_to_save,
                        storage_writer=fs_writer)

        model_to_load = MyShardedModel3(dst_spec).cuda(dist.get_rank())
        model_to_load._register_state_dict_hook(state_dict_hook)
        state_dict_to_load_to = model_to_load.state_dict()

        fs_reader = FileSystemReader(path=path)

        load_state_dict(state_dict=state_dict_to_load_to,
                        storage_reader=fs_reader)

        # We can't use torch.allclose since each ST has a different sharding spec
        store_tensor = self.load_tensor(model_to_save.sharded_tensor)
        load_tensor = self.load_tensor(model_to_load.sharded_tensor)

        if dist.get_rank() == 0:
            self.assertTrue(torch.allclose(store_tensor, load_tensor))
Example #6
    def test_custom_sharder_errors(self):
        custom_sharder = CustomSharder(
            devices=[f"rank:{i}/cuda:{i}" for i in range(TEST_GPU_NUM)],
            split_sharding_idx=TEST_GPU_NUM // 2)

        sharding_plan = ShardingPlan(plan={
            "": custom_sharder,
        })

        sharded_model = CustomEmbeddingBagCollection(10, 10, 8).cuda(self.rank)

        with self.assertRaisesRegex(
                KeyError, "path must not be empty for custom sharder!"):
            # shard the module with the provided sharding plan
            shard_module(sharded_model, sharding_plan)

        # test conflicted sharding plan
        spec = ChunkShardingSpec(dim=0,
                                 placements=["rank:0/cuda:0", "rank:1/cuda:1"])
        sharding_plan = ShardingPlan(
            plan={
                "embedding_bags.embedding_bag_0.weight": spec,
                "embedding_bags": custom_sharder,
            })

        with self.assertRaisesRegex(
                RuntimeError, "should not conflict with the submodule tree"):
            # shard the module with the provided sharding plan
            shard_module(sharded_model, sharding_plan)
Example #7
    def test_named_params_with_sharded_tensor(self):
        rowwise_spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )
        sharded_model = MyShardedModel(spec=rowwise_spec).cuda(self.rank)
        sharded_model_params = dict(named_params_with_sharded_tensor(sharded_model))
        param_keys = list(sharded_model_params.keys())
        self.assertEqual(len(param_keys), 2)
        self.assertTrue("param" in param_keys)
        self.assertTrue("sharded_param" in param_keys)

        sharded_linear = MyShardedLinear(rank=self.rank).cuda(self.rank)
        sharded_linear.shard_parameter()
        sharded_linear_params = dict(named_params_with_sharded_tensor(sharded_linear))
        param_keys = list(sharded_linear_params.keys())
        self.assertEqual(len(param_keys), 4)
        self.assertTrue("linear1.bias" in param_keys)
        self.assertTrue("linear2.bias" in param_keys)
        self.assertTrue("linear1.weight" in param_keys)
        self.assertTrue("linear2.weight" in param_keys)
        self.assertFalse("bias" in param_keys)
Example #8
    def get_spec(self):
        return ChunkShardingSpec(
            dim=0,
            placements=[
                f"rank:{r}/cuda:{r}" for r in range(dist.get_world_size())
            ]
        )
Example #9
    def test_init_sharded_tensor_with_normal(self):
        """ Test torch.nn.init.normal_(ShardedTensor, mean, std) """

        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )
        h, w = 8, 2
        expected_h = 2
        expected_device = torch.device(f"cuda:{self.rank}")
        mean, std = 10, 5

        seed = 1234
        dtype = torch.double

        st = sharded_tensor.empty(spec, h, w, dtype=dtype)
        self.assertEqual(1, len(st.local_shards()))

        # Clone local tensor to ensure torch.nn.init starts from the same input
        local_tensor_clone = torch.clone(st.local_shards()[0].tensor)
        torch.manual_seed(seed)
        torch.nn.init.normal_(st, mean=mean, std=std)

        torch.manual_seed(seed)
        torch.nn.init.normal_(local_tensor_clone, mean=mean, std=std)
        self.assertEqual(local_tensor_clone, st.local_shards()[0].tensor)
Example #10
    def build_plan(self, module: nn.Module) -> ShardingPlan:
        named_params = module.named_parameters()
        plan = {}
        for name, param in named_params:
            plan[name] = ChunkShardingSpec(self.dim, placements=self.devices)

        return ShardingPlan(plan=plan)
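
A hypothetical end-to-end sketch of how such a planner could be used. The
ChunkAllPlanner class and the commented usage below are illustrative stand-ins
(not from the source); shard_module would consume the resulting plan as in
Example #1, and the import paths assume the torch.distributed._shard package
used by these tests:

import torch.nn as nn
from torch.distributed._shard.sharding_plan import ShardingPlan
from torch.distributed._shard.sharding_spec import ChunkShardingSpec


class ChunkAllPlanner:
    """Illustrative planner that chunk-shards every parameter on one dim."""

    def __init__(self, dim, devices):
        self.dim = dim
        self.devices = devices

    def build_plan(self, module: nn.Module) -> ShardingPlan:
        return ShardingPlan(plan={
            name: ChunkShardingSpec(self.dim, placements=self.devices)
            for name, _ in module.named_parameters()
        })

# planner = ChunkAllPlanner(dim=0, devices=["rank:0/cuda:0", "rank:1/cuda:1"])
# plan = planner.build_plan(model)
# shard_module(model, plan)  # needs an initialized process group and GPUs
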
Example #11
    def test_basic_math_ops(self):
        ops = [
            "torch.add", "torch.sub", "torch.mul", "torch.div", "+", "-", "*",
            "/"
        ]

        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )

        sharded_lhs = sharded_tensor.rand(spec, (12, 3))
        sharded_rhs = sharded_tensor.rand(spec, (12, 3))
        current_rank = dist.get_rank()
        global_lhs = torch.empty((12, 3)) if current_rank == 0 else None
        global_rhs = torch.empty((12, 3)) if current_rank == 0 else None
        sharded_lhs.gather(dst=0, out=global_lhs)
        sharded_rhs.gather(dst=0, out=global_rhs)

        for op in ops:
            binary_op = gen_binary_op_func(op)
            binary_op_ = gen_binary_op_func(op, inplace=True)
            # test basic math ops between ShardedTensors
            sharded_output = binary_op(sharded_lhs, sharded_rhs)
            output = torch.empty((12, 3)) if current_rank == 0 else None
            sharded_output.gather(dst=0, out=output)

            if current_rank == 0:
                global_output = binary_op(global_lhs, global_rhs)

                self.assertEqual(output, global_output)

            # test basic math ops between ShardedTensor and scalar
            scalars = [3, 1.8]
            for scalar in scalars:
                sharded_output_lhs = binary_op(sharded_lhs, scalar)

                sharded_output_lhs_ = binary_op_(sharded_lhs, scalar)
                self.assertTrue(
                    torch.allclose(sharded_output_lhs, sharded_output_lhs_))
                output_lhs = torch.empty(
                    (12, 3)) if current_rank == 0 else None
                sharded_output_lhs.gather(dst=0, out=output_lhs)

                sharded_output_rhs = binary_op(scalar, sharded_lhs)
                output_rhs = torch.empty(
                    (12, 3)) if current_rank == 0 else None
                sharded_output_rhs.gather(dst=0, out=output_rhs)

                if current_rank == 0:
                    global_output_lhs = binary_op(global_lhs, scalar)
                    global_output_rhs = binary_op(scalar, global_lhs)

                    self.assertEqual(output_lhs, global_output_lhs)
                    self.assertEqual(output_rhs, global_output_rhs)
Example #12
def _handle_col_wise_sharding(input, world_size, weight, rank, local_shard_t,
                              bias, pg):
    """
    Entry-point function to handle the logic of col-wise sharding of weight
    for Linear. (Detailed explanations of the logic can be found in the
    comment for sharded_linear.)

    When the local tensor only has one dimension, we add an extra dimension
    for resharding and squeeze it out manually later on.

    For example, if we have:
    input: size[15]
    weight: size[15, 16]
    world_size: 4

    Each rank then holds 4 tensors of size [4]. We stack them into a [4, 4]
    tensor and generate a sharded tensor sharded by dim 1.

    In all other cases we simply concatenate the local tensors; no further
    action is needed afterward.

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: sharded weight tensor.
        rank: index of the current CUDA process.
        local_shard_t: row-wise sharded local weight used for the local matmul.
        bias: bias term of the linear op.
        pg: process group.

    Returns:
        A :class:`ShardedTensor` object filled with local intermediate results.
    """
    # allgather the inputs first.
    gathered_inputs = all_gather(input, group=pg)
    (start_pos,
     chunk_size) = get_chunk_sharding_params(bias.size(0), world_size,
                                             weight._sharding_spec, rank)
    local_bias = _BiasTensorNarrow.apply(world_size, start_pos, chunk_size,
                                         weight, pg, bias)
    results = []
    for i, inp in enumerate(gathered_inputs):
        results.append(inp.matmul(local_shard_t) + local_bias)
    # When the local result only has one dimension, we need to make sure
    # it is not sharded by dim 0 so that resharding can work properly.
    if results[0].dim() == 1:  # type: ignore[attr-defined]
        result = torch.stack(results)  # type: ignore[arg-type]
    else:
        result = torch.cat(results)  # type: ignore[arg-type]
    st_size = list(result.size())
    st_size[-1] = weight.size(0)
    new_sharding_spec = ChunkShardingSpec(
        dim=-1, placements=weight.sharding_spec().placements)
    return ShardedTensor._init_from_local_tensor(
        result,
        new_sharding_spec,
        *st_size,  # type: ignore[arg-type]
        process_group=pg,
    )
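
The stack-vs-cat branch at the end of the function can be illustrated with
plain tensors. A minimal shape sketch (the sizes follow the [4]-per-rank
example in the docstring and are only illustrative):

import torch

world_size = 4
# 1-D partial results: stacking adds a leading dimension, producing a
# [world_size, 4] tensor that can then be resharded on its last dim.
one_d_parts = [torch.zeros(4) for _ in range(world_size)]
assert torch.stack(one_d_parts).shape == torch.Size([4, 4])

# 2-D (or higher) partial results are simply concatenated along dim 0.
two_d_parts = [torch.zeros(2, 4) for _ in range(world_size)]
assert torch.cat(two_d_parts).shape == torch.Size([8, 4])
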
Example #13
    def _reshard_spec_for_subgroup(self, rank):
        if rank in [0, 1]:
            return ChunkShardingSpec(
                dim=0,
                placements=[
                    "rank:0/cuda:0",
                    "rank:1/cuda:1",
                ],
            )
        else:
            return ChunkShardingSpec(
                dim=0,
                placements=[
                    "rank:0/cuda:2",
                    "rank:1/cuda:3",
                ],
            )
Example #14
    def spec(self) -> ChunkShardingSpec:
        # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
        return ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
            ],
        )
Example #15
    def test_sharded_optim(self):
        rowwise_spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )
        local_model = MyShardedModel().cuda(self.rank)
        sharded_model = MyShardedModel(spec=rowwise_spec).cuda(self.rank)

        # copy the parameters from the local model
        sharded_model.sharded_param.local_shards()[0].tensor = \
            local_model.sharded_param.detach().clone().requires_grad_()

        local_optim = optim.SGD(local_model.parameters(), lr=0.1)
        sharded_model_params = dict(named_params_with_sharded_tensor(sharded_model))
        sharded_optim = ShardedOptimizer(sharded_model_params, optim.SGD, lr=0.1)

        local_optim.zero_grad()
        sharded_optim.zero_grad()

        before_update = deepcopy(sharded_optim.named_params)

        inp = torch.rand([5, 10]).cuda(self.rank).requires_grad_()

        # run forward
        local_output = local_model(inp)
        sharded_output = sharded_model(inp)
        # backward
        local_output.sum().backward()
        sharded_output.sum().backward()

        # optimizer update
        local_optim.step()
        sharded_optim.step()

        # make sure the parameters (including sharded param)
        # get updated by the optimizer, and the updated
        # local params are the same as the sharded params
        for key, val in before_update.items():
            new_val = sharded_optim.named_params[key]
            if isinstance(val, sharded_tensor.ShardedTensor):
                self.assertNotEqual(
                    val.local_shards()[0].tensor,
                    new_val.local_shards()[0].tensor
                )
                self.assertEqual(
                    new_val.local_shards()[0].tensor,
                    local_model.sharded_param
                )
            else:
                self.assertNotEqual(val, new_val)
                self.assertEqual(new_val, local_model.param)
Example #16
def _chunk_sharding_specs_list_for_test(sharding_dims, seed=0):
    spec_list = []
    for i in range(len(sharding_dims)):
        random.Random(seed + i).shuffle(PLACEMENTS)
        spec_list.append(
            ChunkShardingSpec(
                dim=sharding_dims[i],
                placements=copy.deepcopy(PLACEMENTS),
            ))
    return spec_list
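
A standalone sketch of what the seeded shuffle above produces, assuming
PLACEMENTS is the usual four-entry "rank:i/cuda:i" list used throughout these
examples (the exact orderings depend on the seed and are not asserted here):

import copy
import random

PLACEMENTS = [f"rank:{i}/cuda:{i}" for i in range(4)]

seed = 0
orderings = []
for i, dim in enumerate([0, 1]):
    # Seeding with seed + i gives each spec a deterministic but distinct
    # placement ordering; the shuffle mutates PLACEMENTS in place, which is
    # why the helper deep-copies it into every spec.
    random.Random(seed + i).shuffle(PLACEMENTS)
    orderings.append((dim, copy.deepcopy(PLACEMENTS)))
print(orderings)
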
Example #17
    def get_gpu_specs(self):
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )

        alt_spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:1/cuda:1",
                "rank:0/cuda:0",
                "rank:3/cuda:3",
                "rank:2/cuda:2",
            ],
        )
        return spec, alt_spec
Example #18
    def _test_common_failures(self, cmp_op):
        spec, alt_spec = self.get_gpu_specs()

        st1, st2 = self.get_random_tensors(spec, spec, 10, 10)
        if self.rank == 0:
            torch.nn.init.uniform_(st1.local_shards()[0].tensor)
        self.assertFalse(cmp_op(st1, st2))

        st1 = sharded_tensor.ones(spec, 10, 10)
        st2 = sharded_tensor.ones(spec, 10, 5)
        self.assertFalse(cmp_op(st1, st2))

        st1, st2 = self.get_random_tensors(spec, alt_spec, 10, 10)
        self.assertFalse(cmp_op(st1, st2))

        st1 = sharded_tensor.ones(spec, 10, 10)
        st2 = sharded_tensor.zeros(spec, 10, 10)
        self.assertFalse(cmp_op(st1, st2))

        st1 = sharded_tensor.ones(spec, 10, 10)
        st2 = sharded_tensor.ones(spec, 10, 10, dtype=torch.double)
        self.assertFalse(cmp_op(st1, st2))

        st1 = sharded_tensor.ones(spec, 10, 10)
        st2 = sharded_tensor.ones(spec, 10, 10, requires_grad=True)
        self.assertFalse(cmp_op(st1, st2))

        cpu_spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cpu",
                "rank:1/cpu",
                "rank:2/cpu",
                "rank:3/cpu",
            ],
        )
        st1 = sharded_tensor.ones(cpu_spec, 10, 10)
        st2 = sharded_tensor.ones(cpu_spec, 10, 10, pin_memory=True)
        self.assertFalse(cmp_op(st1, st2))

        pg = dist.new_group([1, 0, 3, 2])
        st1, st2 = self.get_random_tensors(spec, spec, 10, 10, pg2=pg)
        with self.assertRaisesRegex(
                RuntimeError,
                "All distributed tensors should use the same ProcessGroup"):
            cmp_op(st1, st2)

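        # even a new group covering the same ranks is a distinct ProcessGroup,
        # so the comparison is expected to fail in the same way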
        pg = dist.new_group([0, 1, 2, 3])
        st1, st2 = self.get_random_tensors(spec, spec, 10, 10, pg2=pg)
        with self.assertRaisesRegex(
                RuntimeError,
                "All distributed tensors should use the same ProcessGroup"):
            cmp_op(st1, st2)
Example #19
    def test_storage_key_mapping(self) -> None:
        device = f"cuda:{dist.get_rank()}"
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
            ],
        )

        state_dict = {
            'sharded': sharded_tensor.rand(spec, (
                10,
                10,
            )),
            'replicated': torch.rand(4, device=device),
            'bytes': [1, 2, 3, 4],
        }

        metadata, bytes_reqs, tensor_reqs = _prepare(
            state_dict, write_replicated_data=self.rank == 0)

        if self.rank == 0:
            self.assertEqual(1, len(bytes_reqs))
            self.assertEqual(2, len(tensor_reqs))

            self.assertTrue('bytes' in metadata.state_dict_metadata)
            self.assertEqual(bytes_reqs[0].storage_key,
                             metadata.state_dict_metadata['bytes'].storage_key)

            # tensor ordering is unspecified
            if len(tensor_reqs[0].tensor.size()) == 1:
                replicated = tensor_reqs[0]
                shard = tensor_reqs[1]
            else:
                replicated = tensor_reqs[1]
                shard = tensor_reqs[0]

            self.assertTrue('replicated' in metadata.state_dict_metadata)
            self.assertEqual(
                replicated.storage_key,
                metadata.state_dict_metadata['replicated'].storage_key)
        else:
            self.assertEqual(0, len(bytes_reqs))
            self.assertEqual(1, len(tensor_reqs))
            shard = tensor_reqs[0]

            self.assertTrue('sharded' in metadata.state_dict_metadata)
            shard_keys = [
                sm.storage_key for sm in
                metadata.state_dict_metadata['sharded'].storage_metadata
            ]
            self.assertTrue(shard.storage_key in shard_keys)
Example #20
    def shard_parameter(self):
        rowwise_sharding_spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )

        colwise_sharding_spec = ChunkShardingSpec(
            dim=1,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )

        sharded_tensor.shard_parameter(self.linear1, "weight", rowwise_sharding_spec)
        sharded_tensor.shard_parameter(self.linear2, "weight", colwise_sharding_spec)
Example #21
    def test_clone(self):
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )
        st = sharded_tensor.rand(spec, (12, 5))
        copied_st = st.clone()
        self.assertTrue(type(copied_st) is type(st))
        self.assertEqual(copied_st.local_tensor(), st.local_tensor())
        self.assertFalse(copied_st is st)
Example #22
    def test_tensor_metadata_with_missing_rank_spec(self) -> None:
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:1/cuda:1",
            ],
        )

        st = sharded_tensor.zeros(spec, 4, 4, dtype=torch.float64)
        mapping = dict()

        (_, md, storage_md) = _prepare_sharded_tensor_write("fqn", st, "tensor", mapping)

        self.assertEqual(1, len(storage_md))
        self.assertEqual(1, len(mapping))
Example #23
    def test_read_write_shard_tensor(self) -> None:
        paths = [tempfile.mkdtemp()]
        dist.broadcast_object_list(paths)

        path = paths[0]

        # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0",
                "rank:1",
            ],
        )

        model_to_save = MyShardedModel1(spec, init_rrefs=False)

        # Test save
        model_to_save._register_state_dict_hook(state_dict_hook)
        state_dict_to_save = model_to_save.state_dict()

        fs_writer = FileSystemWriter(path=path)
        save_state_dict(state_dict=state_dict_to_save,
                        storage_writer=fs_writer)

        dist.barrier()

        # Create a new model
        model_to_load = MyShardedModel1(spec, init_rrefs=False)
        # This is not the correct hook for loading the state dict
        # model_to_load._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
        model_to_load._register_state_dict_hook(state_dict_hook)
        state_dict_to_load_to = model_to_load.state_dict()

        dist.barrier()

        with self.assertRaises(AssertionError):
            assert_state_dict_equal(self, state_dict_to_load_to,
                                    state_dict_to_save)

        # Test load.
        fs_reader = FileSystemReader(path=path)
        load_state_dict(state_dict=state_dict_to_load_to,
                        storage_reader=fs_reader)

        assert_state_dict_equal(self, state_dict_to_load_to,
                                state_dict_to_save)
        dist.barrier()
Example #24
    def _test_sharded_softmax(self, softmax_dim, sharding_dim):
        torch.manual_seed(0)
        local_tensor = torch.rand(10, 10, device=self.rank)
        local_softmax = torch.nn.functional.softmax(local_tensor, softmax_dim)

        spec = ChunkShardingSpec(dim=sharding_dim,
                                 placements=[
                                     f'rank:{idx}/cuda:{idx}'
                                     for idx in range(self.world_size)
                                 ])
        st = _shard_tensor(local_tensor, spec)
        sharded_softmax = torch.nn.functional.softmax(st, softmax_dim)

        self.assertEqual(
            local_softmax.chunk(self.world_size, dim=sharding_dim)[self.rank],
            sharded_softmax.local_tensor())
Example #25
    def test_math_ops_errors(self):
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )
        sharded_lhs = sharded_tensor.rand(spec, (20, 3))
        sharded_rhs = sharded_tensor.rand(spec, (12, 3))

        with self.assertRaisesRegex(
            RuntimeError, "Implicit broadcasting not supported"
        ):
            torch.add(sharded_lhs, sharded_rhs)

        spec = EnumerableShardingSpec(
            [
                ShardMetadata(
                    shard_offsets=[0, 0],
                    shard_sizes=[5, 5],
                    placement="rank:0/cuda:0",
                ),
                ShardMetadata(
                    shard_offsets=[0, 5],
                    shard_sizes=[5, 5],
                    placement="rank:1/cuda:1",
                ),
                ShardMetadata(
                    shard_offsets=[5, 0],
                    shard_sizes=[5, 5],
                    placement="rank:2/cuda:2",
                ),
                ShardMetadata(
                    shard_offsets=[5, 5],
                    shard_sizes=[5, 5],
                    placement="rank:3/cuda:3",
                ),
            ]
        )

        st = sharded_tensor.rand(spec, 10, 10)

        with self.assertRaisesRegex(RuntimeError, "not supported"):
            torch.add(st, sharded_rhs)
Example #26
    def test_storage_key_mapping(self) -> None:
        device = f"cuda:{dist.get_rank()}"
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
            ],
        )

        state_dict = {
            'sharded': sharded_tensor.rand(spec, (10, 10, )),
            'replicated': torch.rand(4, device=device),
            'bytes': [1, 2, 3, 4],
        }

        metadata, bytes_reqs, tensor_reqs = _prepare(state_dict, write_replicated_data=self.rank == 0)

        if self.rank == 0:
            self.assertEqual(1, len(bytes_reqs))
            self.assertEqual(2, len(tensor_reqs))

            self.assertTrue('bytes' in metadata.state_dict_metadata)
            self.assertTrue(MetadataIndex('bytes') in metadata.storage_data)

            # tensor ordering is unspecified
            if len(tensor_reqs[0].tensor.size()) == 1:
                replicated = tensor_reqs[0]
                shard = tensor_reqs[1]
            else:
                replicated = tensor_reqs[1]
                shard = tensor_reqs[0]

            self.assertTrue('replicated' in metadata.state_dict_metadata)
            storage_key = MetadataIndex('replicated', torch.Size([0]))
            self.assertTrue(storage_key in metadata.storage_data)
            self.assertEqual(metadata.storage_data[storage_key], replicated.storage_key)
        else:
            self.assertEqual(0, len(bytes_reqs))
            self.assertEqual(1, len(tensor_reqs))
            shard = tensor_reqs[0]
            local_shard = state_dict["sharded"].local_shards()[0]

            self.assertTrue('sharded' in metadata.state_dict_metadata)
            storage_key = MetadataIndex('sharded', torch.Size(local_shard.metadata.shard_offsets))
            self.assertTrue(storage_key in metadata.storage_data)
            self.assertEqual(metadata.storage_data[storage_key], shard.storage_key)
Example #27
    def test_replicated_tensor_inter_op_sharded_tensor(self):
        torch.manual_seed(self.rank)

        local_tensor1 = torch.rand(12, 3, device=f"cuda:{self.rank}") * 4
        local_tensor2 = torch.ones(12, 3, device=f"cuda:{self.rank}") * 4

        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )

        st = _shard_tensor(local_tensor1, spec, src_rank=0)
        replica_tensor = ReplicatedTensor(local_tensor2)

        ops = [
            "torch.add", "torch.sub", "torch.mul", "torch.div", "+", "-", "*",
            "/"
        ]

        for op in ops:
            binary_op = gen_binary_op_func(op)
            res = binary_op(st, replica_tensor)
            self.assertIsInstance(res, sharded_tensor.ShardedTensor)
            self.assertNotIsInstance(res, ReplicatedTensor)
            output = torch.empty((12, 3)) if self.rank == 0 else None
            res.gather(dst=0, out=output)

            if self.rank == 0:
                local_output = binary_op(local_tensor1, local_tensor2)
                self.assertEqual(output, local_output)

            # reflective
            reflect_res = binary_op(replica_tensor, st)
            self.assertIsInstance(reflect_res, sharded_tensor.ShardedTensor)
            self.assertNotIsInstance(reflect_res, ReplicatedTensor)
            reflect_output = torch.empty((12, 3)) if self.rank == 0 else None
            reflect_res.gather(dst=0, out=reflect_output)

            if self.rank == 0:
                reflect_local_output = binary_op(local_tensor2, local_tensor1)
                self.assertEqual(reflect_output, reflect_local_output)
Example #28
    def test_detach(self):
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )
        st = sharded_tensor.rand(spec, (12, 5), requires_grad=True)
        local_shards = st.local_shards()
        # since the tensor was created with requires_grad=True, all local
        # shards should already require grads
        for local_shard in local_shards:
            self.assertTrue(local_shard.tensor.requires_grad)

        detached_st = st.detach()
        self.assertFalse(detached_st.requires_grad)

        for local_shard in detached_st.local_shards():
            self.assertFalse(local_shard.tensor.requires_grad)
Example #29
    def test_replicated_tensor_implicit_broadcasting(self):
        # seed each rank differently
        torch.manual_seed(self.rank)

        # test implicit broadcasting
        local_tensor1 = torch.rand(12, 3, device=f"cuda:{self.rank}") * 4
        # we use size (3) to trigger the implicit broadcasting logic,
        # and the op would fail if implicit broadcasting did not happen.
        local_tensor2 = torch.ones(3, device=f"cuda:{self.rank}")

        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )

        st = _shard_tensor(local_tensor1, spec, src_rank=0)
        replica_tensor = ReplicatedTensor(local_tensor2)

        ops = [
            "torch.add", "torch.sub", "torch.mul", "torch.div", "+", "-", "*",
            "/"
        ]

        for op in ops:
            binary_op = gen_binary_op_func(op)
            # the replicated tensor should be broadcast automatically
            res = binary_op(st, replica_tensor)

            self.assertIsInstance(res, sharded_tensor.ShardedTensor)
            output = torch.empty((12, 3)) if self.rank == 0 else None
            res.gather(dst=0, out=output)

            if self.rank == 0:
                local_output = binary_op(local_tensor1, local_tensor2)
                self.assertEqual(output, local_output)
Example #30
    def test_replicated_tensor_inter_op_sharded_tensor_errors(self):
        local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}") * 4
        replica_tensor = ReplicatedTensor(local_tensor)

        torch.manual_seed(self.rank)
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )

        st1 = sharded_tensor.rand(spec, (20, 3, 3))
        st2 = sharded_tensor.rand(spec, (30, 3, 3))

        with self.assertRaisesRegex(RuntimeError, 'Implicit broadcasting'):
            st1 + st2

        with self.assertRaisesRegex(RuntimeError, 'not supported for ShardedTensor'):
            st1 % replica_tensor