Example 1
    def test_cross_entropy(self):
        batch_size, sequence_length, vocab_size_per_partition = 13, 17, 11
        logits_scale = 1000.0
        seed = 1234
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size:
                continue
            with self.subTest(tensor_model_parallel_world_size=
                              tensor_model_parallel_world_size):
                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=
                    tensor_model_parallel_world_size, )
                vocab_size = vocab_size_per_partition * tensor_model_parallel_world_size
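                # Compare the plain PyTorch cross entropy against the
                # tensor-parallel sharded version; loss and gradient must match.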
                loss_torch, grad_torch = torch_cross_entropy(
                    batch_size, sequence_length, vocab_size, logits_scale,
                    seed)
                (
                    loss_tensor_parallel,
                    grad_tensor_parallel,
                ) = tensor_sharded_cross_entropy(batch_size, sequence_length,
                                                 vocab_size, logits_scale,
                                                 seed)

                torch.testing.assert_close(loss_torch, loss_tensor_parallel)
                torch.testing.assert_close(grad_torch, grad_tensor_parallel)

                parallel_state.destroy_model_parallel()
Example 2
def parallel_self_attention(tensor_model_parallel_size,
                            num_att_heads_per_partition,
                            hidden_size_per_att_head, dropout_prob, batch_size,
                            sequence_length):
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size(
    )

    seed = 12345
    set_random_seed(seed)

    num_att_heads = num_att_heads_per_partition * \
        torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * num_att_heads

    # Network
    identity_layer = IdentityLayer3D(batch_size, sequence_length,
                                     hidden_size).cuda()
    attention_layer = parallel_state.BertParallelSelfAttention(
        hidden_size, num_att_heads, dropout_prob).cuda()
    loss_weight = torch.randn([batch_size, sequence_length,
                               hidden_size]).cuda()
    attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
    # Forward
    input_ = identity_layer()
    output = attention_layer(input_, attention_mask)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    rank = parallel_state.get_tensor_model_parallel_rank()
    parallel_state.destroy_model_parallel()
    return rank, hidden_size, tensor_model_parallel_size, loss, \
        attention_layer, identity_layer
Example 3
def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_):

    if torch.distributed.get_rank() == 0:
        print('> testing get_tensor_model_parallel_src_rank with size {} ...'.
              format(tensor_model_parallel_size_))
    tensor_model_parallel_size = min(
        tensor_model_parallel_size_,
        torch.distributed.get_world_size(),
    )
    assert not parallel_state.model_parallel_is_initialized()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    assert parallel_state.model_parallel_is_initialized()

    # Checks
    src_rank = torch.distributed.get_rank(
    ) - parallel_state.get_tensor_model_parallel_rank()
    assert parallel_state.get_tensor_model_parallel_src_rank() == src_rank
    split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
    assert split_rank is None

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
Example 4
def parallel_transformer(tensor_model_parallel_size, num_att_heads_per_partition,
                         hidden_size_per_att_head, batch_size, sequence_length):

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)

    num_att_heads = num_att_heads_per_partition * \
        torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * num_att_heads
    intermediate_size = 4 * hidden_size

    # Network
    identity_layer = IdentityLayer3D(batch_size, sequence_length,
                                     hidden_size).cuda()
    transformer_layer = parallel_state.BertParallelTransformerLayer(
        hidden_size, intermediate_size, num_att_heads, 0.0, 0.0,
        torch.nn.functional.relu, 1.0e-5).cuda()

    loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
    attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
    # Forward
    input_ = identity_layer()
    output = transformer_layer(input_, attention_mask)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    rank = parallel_state.get_tensor_model_parallel_rank()
    parallel_state.destroy_model_parallel()
    return rank, hidden_size, tensor_model_parallel_size, loss, \
        transformer_layer, identity_layer
Example 5
def test_model_parallel_cuda_manual_seed(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing model parallel cuda manual seed with size {} ...'.
              format(tensor_model_parallel_size))

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size(
    )

    tensor_parallel.random.model_parallel_cuda_manual_seed(12345)
    assert torch.cuda.initial_seed() == 12345
    with tensor_parallel.random.get_cuda_rng_tracker().fork():
        assert (torch.cuda.initial_seed() == 12345 + 2718 +
                parallel_state.get_tensor_model_parallel_rank())

    # Reset the tracker
    tensor_parallel.random.get_cuda_rng_tracker().reset()

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
Example 6
    def test_initialize_model_parallel_with_virtual_and_split(self) -> None:
        if self.world_size < 4:
            self.skipTest("requires >= 4 GPUs")
        self.assertFalse(parallel_state.model_parallel_is_initialized())

        tensor_model_parallel_world_size = 1 + int(self.world_size > 4)
        pipeline_model_parallel_world_size = (self.world_size //
                                              tensor_model_parallel_world_size)
        virtual_pipeline_model_parallel_world_size = 2
        pipeline_model_parallel_split_rank = pipeline_model_parallel_world_size // 2

        parallel_state.initialize_model_parallel(
            tensor_model_parallel_size_=tensor_model_parallel_world_size,
            pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
            virtual_pipeline_model_parallel_size_=
            virtual_pipeline_model_parallel_world_size,
            pipeline_model_parallel_split_rank_=
            pipeline_model_parallel_split_rank,
        )
        self.assertEqual(
            calc_expected_tensor_model_paralell_rank(
                self.rank, tensor_model_parallel_world_size),
            parallel_state.get_tensor_model_parallel_rank(),
        )
        self.assertEqual(
            pipeline_model_parallel_world_size,
            parallel_state.get_pipeline_model_parallel_world_size(),
        )
        self.assertEqual(
            virtual_pipeline_model_parallel_world_size,
            parallel_state.get_virtual_pipeline_model_parallel_world_size(),
        )

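        # Derive the expected pipeline rank from the global rank with the
        # tensor-model-parallel component stripped out.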
        expected_pipeline_rank = (self.rank -
                                  (self.rank % tensor_model_parallel_world_size
                                   )) % pipeline_model_parallel_world_size
        self.assertEqual(
            expected_pipeline_rank,
            parallel_state.get_pipeline_model_parallel_rank(),
        )
        # The virtual pipeline model parallel rank is set lazily; right after the
        # call to `initialize_model_parallel` it is 0.
        self.assertEqual(
            0,
            parallel_state.get_virtual_pipeline_model_parallel_rank(),
        )
        self.assertEqual(
            pipeline_model_parallel_split_rank,
            parallel_state.get_pipeline_model_parallel_split_rank(),
        )

        fake_split_rank = 77
        parallel_state.set_pipeline_model_parallel_split_rank(fake_split_rank)
        self.assertEqual(
            fake_split_rank,
            parallel_state.get_pipeline_model_parallel_split_rank())

        parallel_state.destroy_model_parallel()
Example 7
def test_broadcast_data(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print(
            '> testing broadcast_data with model parallel size {} ...'.format(
                tensor_model_parallel_size))

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    torch.manual_seed(1234 + parallel_state.get_data_parallel_rank())
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size(
    )

    key_size_t = {
        'key1': [7, 11],
        'key2': [8, 2, 1],
        'key3': [13],
        'key4': [5, 1, 2],
        'key5': [5, 12],
    }
    keys = list(key_size_t.keys())

    data = {}
    data_t = {}
    for key in key_size_t:
        data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000)
        data_t[key] = data[key].clone()
    data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000)
    data_t['keyX'] = data['keyX'].clone()
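    # Only tensor-parallel rank 0 keeps the data; the other ranks receive it via broadcast.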
    if parallel_state.get_tensor_model_parallel_rank() != 0:
        data = None

    data_utils._check_data_types(keys, data_t, torch.int64)
    key_size, key_numel, \
        total_numel = data_utils._build_key_size_numel_dictionaries(keys, data)
    for key in keys:
        assert key_size[key] == key_size_t[key]
    total_numel_t = 0
    for key in keys:
        target_size = functools.reduce(operator.mul, key_size_t[key], 1)
        assert key_numel[key] == target_size
        total_numel_t += target_size
    assert total_numel == total_numel_t

    data_b = data_utils.broadcast_data(keys, data, torch.int64)
    for key in keys:
        tensor = data_t[key].cuda()
        assert data_b[key].sub(tensor).abs().max() == 0

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
Example 8
def test__reduce(args, tensor_model_parallel_size):
    print("Testing reduction size =", tensor_model_parallel_size)
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size(
    )
    assert torch.equal(
        mappings._reduce(torch.full((10, 10, 10, 10), 50)),
        torch.full((10, 10, 10, 10), 50 * tensor_model_parallel_size),
    )
    parallel_state.destroy_model_parallel()
    print("Passed!")
Example 9
    def test_cuda_rng_tracker(self):
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size:
                continue
            with self.subTest(tensor_model_parallel_world_size=
                              tensor_model_parallel_world_size):
                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=tensor_model_parallel_world_size
                )

                seed_1, seed_2, size = 1234, 4321, [12, 21]
                tensor = torch.cuda.FloatTensor(size)

                torch.cuda.manual_seed(seed_1)
                torch.randn(size, out=tensor)
                target_11 = tensor.clone()
                torch.randn(size, out=tensor)
                target_12 = tensor.clone()

                torch.cuda.manual_seed(seed_2)
                torch.randn(size, out=tensor)
                target_21 = tensor.clone()
                torch.randn(size, out=tensor)
                target_22 = tensor.clone()

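                # Interleave draws from the default CUDA generator (seed_1) with
                # draws from the tracked "test" generator (seed_2); each stream
                # must reproduce its reference tensors.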
                torch.cuda.manual_seed(seed_1)
                tensor_parallel.random.get_cuda_rng_tracker().add(
                    "test", seed_2)

                torch.randn(size, out=tensor)
                result_11 = tensor.clone()

                with tensor_parallel.random.get_cuda_rng_tracker().fork(
                        "test"):
                    torch.randn(size, out=tensor)
                    result_21 = tensor.clone()

                torch.randn(size, out=tensor)
                result_12 = tensor.clone()

                with tensor_parallel.random.get_cuda_rng_tracker().fork(
                        "test"):
                    torch.randn(size, out=tensor)
                    result_22 = tensor.clone()

                self.assertEqual(target_11, result_11)
                self.assertEqual(target_12, result_12)
                self.assertEqual(target_21, result_21)
                self.assertEqual(target_22, result_22)
                self.assertNotEqual(result_11, result_21)
                self.assertNotEqual(result_21, result_22)

                tensor_parallel.random.get_cuda_rng_tracker().reset()
                parallel_state.destroy_model_parallel()
Example 10
    def _test(self, rampup_batch_size: Optional[List[int]]) -> None:
        for data_parallel_size in range(1, self.world_size + 1):

            expected_global_batch_size = self.GLOBAL_BATCH_SIZE
            expected_micro_batch_size = self.MICRO_BATCH_SIZE
            if rampup_batch_size:
                expected_global_batch_size = rampup_batch_size[0]
                num_consumed_samples = 0
                step_of_global_batch_size = rampup_batch_size[1]
                threshold = rampup_batch_size[2]

            if data_parallel_size > 1 and data_parallel_size % 2 != 0:
                continue
            if self.world_size % data_parallel_size != 0:
                continue
            with self.subTest(data_parallel_size=data_parallel_size):
                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=self.world_size //
                    data_parallel_size,
                    pipeline_model_parallel_size_=1,
                )
                self.assertEqual(data_parallel_size,
                                 parallel_state.get_data_parallel_world_size())

                _reconfigure_microbatch_calculator(
                    self.rank,
                    rampup_batch_size,
                    self.GLOBAL_BATCH_SIZE,
                    self.MICRO_BATCH_SIZE,
                    data_parallel_size,
                )

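                # The number of microbatches equals the (possibly ramped-up) global
                # batch size divided by micro_batch_size * data_parallel_size.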
                self.assertEqual(get_micro_batch_size(),
                                 expected_micro_batch_size)
                self.assertEqual(
                    get_num_microbatches(), expected_global_batch_size /
                    expected_micro_batch_size / data_parallel_size)
                current_global_batch_size = get_current_global_batch_size()
                self.assertEqual(current_global_batch_size,
                                 expected_global_batch_size)

                # Make sure the current global batch size reaches the final global
                # batch size after a certain number of updates.
                if rampup_batch_size:
                    update_num_microbatches(current_global_batch_size)
                    for i in range(100):
                        current_global_batch_size = get_current_global_batch_size(
                        )
                        update_num_microbatches(current_global_batch_size)
                    current_global_batch_size = get_current_global_batch_size()
                    self.assertEqual(get_current_global_batch_size(),
                                     self.GLOBAL_BATCH_SIZE)
                parallel_state.destroy_model_parallel()
Example 11
    def test_row_parallel_linear(self) -> None:
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size:
                continue
            with self.subTest(tensor_model_parallel_world_size=
                              tensor_model_parallel_world_size):
                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=tensor_model_parallel_world_size
                )

                input_size: int = self.INPUT_SIZE_COEFF * tensor_model_parallel_world_size
                output_size: int = self.OUTPUT_SIZE_COEFF * tensor_model_parallel_world_size

                set_random_seed(self.SEED)
                linear_layer = layers.RowParallelLinear(
                    input_size,
                    output_size,
                    keep_master_weight_for_test=True,
                    params_dtype=torch.float32,
                    use_cpu_initialization=True,
                ).cuda()
                loss_weight = torch.randn(
                    (self.BATCH_SIZE, output_size)).cuda()

                # Forward and backward
                input_tensor = torch.randn(self.BATCH_SIZE,
                                           input_size,
                                           requires_grad=True).cuda()
                input_tensor.retain_grad()
                output, _ = linear_layer(input_tensor)
                loss = torch.mul(output, loss_weight).sum()
                loss.backward()
                self.assertIsNotNone(input_tensor.grad)

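                # Reconstruct the analytic gradients dL/dA, dL/db, and dL/dx from
                # the master weight and compare them against autograd below.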
                with torch.no_grad():
                    dldy = loss_weight.clone()
                    x = input_tensor.clone()
                    a = linear_layer.master_weight.cuda()
                dlda = torch.matmul(dldy.t(), x)
                dldb = torch.matmul(
                    torch.ones(self.BATCH_SIZE, 1).cuda().t(), dldy).view(-1)
                dldx = torch.matmul(dldy, a)

                with torch.no_grad():
                    curr_dlda = torch.split(
                        dlda, self.INPUT_SIZE_COEFF, dim=1
                    )[parallel_state.get_tensor_model_parallel_rank()].clone()
                self.assertEqual(linear_layer.weight.grad, curr_dlda)
                self.assertEqual(input_tensor.grad, dldx)
                self.assertEqual(linear_layer.bias.grad, dldb)

                parallel_state.destroy_model_parallel()
Example 12
def test__gather(args, tensor_model_parallel_size):

    print("Testing gathering size =", tensor_model_parallel_size)
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size(
    )
    assert torch.equal(
        mappings._gather(
            torch.tensor([parallel_state.get_tensor_model_parallel_rank()])),
        torch.tensor(list(range(tensor_model_parallel_size))),
    )
    parallel_state.destroy_model_parallel()
    print("Passed!")
Example 13
    def _affine_weight_init_test_impl(self, init_device: str,
                                      is_column_parallel: bool) -> None:
        dim = int(not is_column_parallel)
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size:
                continue
            with self.subTest(tensor_model_parallel_world_size=
                              tensor_model_parallel_world_size):
                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=tensor_model_parallel_world_size
                )
                input_size: int = self.INPUT_SIZE_COEFF * tensor_model_parallel_world_size
                output_size: int = self.OUTPUT_SIZE_COEFF * tensor_model_parallel_world_size

                weight_shape = ((self.OUTPUT_SIZE_COEFF,
                                 input_size) if is_column_parallel else
                                (output_size, self.INPUT_SIZE_COEFF))
                weight = torch.empty(weight_shape)
                set_random_seed(self.SEED)

                sharding_dim_size = (self.OUTPUT_SIZE_COEFF
                                     if is_column_parallel else
                                     self.INPUT_SIZE_COEFF)

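                # Initialize the sharded weight either on CPU (split from a full
                # master weight) or directly on GPU.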
                if init_device == "cpu":
                    layers._initialize_affine_weight_cpu(
                        weight,
                        output_size,
                        input_size,
                        sharding_dim_size,
                        dim,
                        nn.init.normal_,
                        params_dtype=torch.float32,
                    )
                else:
                    layers._initialize_affine_weight_gpu(
                        weight, torch.nn.init.normal_, dim)
                # Target
                set_random_seed(self.SEED)
                if init_device == "cpu":
                    main_weight = torch.empty(output_size, input_size)
                    nn.init.normal_(main_weight)
                    curr_weight = torch.split(
                        main_weight, sharding_dim_size, dim=dim)[
                            parallel_state.get_tensor_model_parallel_rank()]
                else:
                    curr_weight = torch.empty(*weight_shape)
                    nn.init.normal_(curr_weight)
                self.assertEqual(curr_weight, weight)
                parallel_state.destroy_model_parallel()
Example 14
def test__split(args, tensor_model_parallel_size):
    print("Testing splitting size =", tensor_model_parallel_size)
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size(
    )
    listy = []
    for i in range(tensor_model_parallel_size):
        listy.append(torch.randn(10, 1))
    x = torch.cat(tuple(listy), 1)
    out = mappings._split(x)
    assert torch.equal(out,
                       listy[parallel_state.get_tensor_model_parallel_rank()])
    parallel_state.destroy_model_parallel()
    print("Passed!")
Example 15
def test_column_parallel_linear_with_async_allreduce_custom_amp(
        tensor_model_parallel_size):
    dtypes = (torch.half,
              torch.bfloat16) if torch.cuda.is_bf16_supported() else (
                  torch.half, )

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size(
    )

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * tensor_model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * tensor_model_parallel_size
    batch_size = 7

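    # Run one forward/backward pass per dtype and check that the output dtype is preserved.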
    for dtype in dtypes:
        # Network
        identity_layer = IdentityLayer3D(batch_size, batch_size,
                                         input_size).to(device="cuda",
                                                        dtype=dtype)
        linear_layer = layers.ColumnParallelLinear(
            input_size,
            output_size,
            keep_master_weight_for_test=True,
            params_dtype=global_vars.get_args().params_dtype,
            use_cpu_initialization=global_vars.get_args().
            use_cpu_initialization,
        ).to(device="cuda", dtype=dtype)
        # Forward
        loss_weight = torch.randn([batch_size, output_size]).cuda()
        output, _ = linear_layer(identity_layer())
        loss = torch.mul(output, loss_weight).sum()
        loss.backward()
        torch.distributed.barrier()

        assert output.dtype == dtype

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
Example 16
    def test_reduce(self):
        for tensor_model_paralell_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_paralell_world_size > 0:
                continue
            with self.subTest(tensor_model_paralell_world_size=
                              tensor_model_paralell_world_size):
                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=tensor_model_paralell_world_size
                )
                t = torch.full((10, 10, 10, 10),
                               50,
                               device=f"cuda:{self.rank}")
                expected = torch.full(
                    (10, 10, 10, 10),
                    50 * tensor_model_paralell_world_size,
                    device=f"cuda:{self.rank}",
                )
                self.assertTrue(torch.equal(mappings._reduce(t), expected))
                parallel_state.destroy_model_parallel()
Example 17
    def init_model_parallel(self, global_rank: int, world_size: int) -> None:
        """ Initializes Megatron-LM model parallel if using model parallelism.

        Args:
            global_rank (int): the global process index.
            world_size (int): the total number of GPUs, i.e., num_nodes * num_devices.
        """
        app_state = AppState()

        # we initialize megatron-lm model parallel and data parallel groups
        # after initializing DDP with PTL.
        if app_state.model_parallel_size is not None:
            # destroy groups in case they have already been created
            # this happens with multiple calls to trainer.test for example
            parallel_state.destroy_model_parallel()
            if torch.distributed.is_initialized():
                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=app_state.
                    tensor_model_parallel_size,
                    pipeline_model_parallel_size_=app_state.
                    pipeline_model_parallel_size,
                    pipeline_model_parallel_split_rank_=app_state.
                    pipeline_model_parallel_split_rank,
                )

                # assert that fake tp and pp rank match after model parallel init
                assert app_state.tensor_model_parallel_rank == parallel_state.get_tensor_model_parallel_rank(
                )
                assert app_state.pipeline_model_parallel_rank == parallel_state.get_pipeline_model_parallel_rank(
                )

                app_state.tensor_model_parallel_group = parallel_state.get_tensor_model_parallel_group(
                )
                app_state.data_parallel_group = parallel_state.get_data_parallel_group(
                )
                app_state.data_parallel_rank = parallel_state.get_data_parallel_rank(
                )
                app_state.data_parallel_size = parallel_state.get_data_parallel_world_size(
                )
                app_state.pipeline_model_parallel_group = parallel_state.get_pipeline_model_parallel_group(
                )
Example 18
    def test_gather(self):
        for tensor_model_paralell_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_paralell_world_size > 0:
                continue
            with self.subTest(tensor_model_paralell_world_size=
                              tensor_model_paralell_world_size):
                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=tensor_model_paralell_world_size
                )
                device = f"cuda:{self.rank}"
                gathered = mappings._gather(
                    torch.tensor(
                        [parallel_state.get_tensor_model_parallel_rank()],
                        device=device))
                expected = torch.tensor(
                    [rank for rank in range(tensor_model_paralell_world_size)],
                    device=device,
                )
                self.assertTrue(torch.equal(gathered, expected))
                parallel_state.destroy_model_parallel()
Example 19
    def test_set_cuda_rng_state(self):
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size:
                continue
            with self.subTest(tensor_model_parallel_world_size=
                              tensor_model_parallel_world_size):
                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=tensor_model_parallel_world_size
                )

                size, seed = 123, 1234
                torch.cuda.manual_seed(seed)
                tensor = torch.cuda.FloatTensor(size)

                rng_state = torch.cuda.get_rng_state()
                rng_state_clone = rng_state.clone()

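                # Drawing random numbers advances the live CUDA RNG state, while
                # the captured snapshot must remain unchanged.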
                for _ in range(5):
                    torch.randn(size, out=tensor)
                result_1 = tensor.clone()

                self.assertEqual(rng_state.sub(rng_state_clone).max(), 0)
                self.assertGreater(
                    torch.cuda.get_rng_state().sub(rng_state_clone).max(), 0)

                new_rng_state = torch.cuda.get_rng_state()
                self.assertGreater(new_rng_state.sub(rng_state).max(), 0)

                tensor_parallel.random._set_cuda_rng_state(rng_state)
                for _ in range(5):
                    torch.randn(size, out=tensor)
                tensor_parallel.random._set_cuda_rng_state(rng_state)
                for _ in range(5):
                    torch.randn(size, out=tensor)
                result_2 = tensor.clone()

                torch.testing.assert_close(result_2, result_1)

                self.assertEqual(rng_state.sub(rng_state_clone).max(), 0)

                parallel_state.destroy_model_parallel()
Example 20
    def test_initialize_model_parallel(self) -> None:

        self.assertFalse(parallel_state.model_parallel_is_initialized())

        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            with self.subTest(tensor_model_parallel_world_size=
                              tensor_model_parallel_world_size):
                if self.world_size % tensor_model_parallel_world_size:
                    continue

                pipeline_model_parallel_world_size = (
                    self.world_size // tensor_model_parallel_world_size)

                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=
                    tensor_model_parallel_world_size,
                    pipeline_model_parallel_size_=
                    pipeline_model_parallel_world_size,
                )
                self.assertEqual(
                    tensor_model_parallel_world_size,
                    parallel_state.get_tensor_model_parallel_world_size(),
                )
                expected_tensor_model_parallel_rank = calc_expected_tensor_model_paralell_rank(
                    self.rank, tensor_model_parallel_world_size)
                self.assertEqual(
                    expected_tensor_model_parallel_rank,
                    parallel_state.get_tensor_model_parallel_rank(),
                )

                expected_tensor_model_parallel_src_rank = (
                    self.rank // tensor_model_parallel_world_size
                ) * tensor_model_parallel_world_size
                self.assertEqual(
                    expected_tensor_model_parallel_src_rank,
                    parallel_state.get_tensor_model_parallel_src_rank(),
                )

                parallel_state.destroy_model_parallel()
                self.assertFalse(
                    parallel_state.model_parallel_is_initialized())
Example 21
    def test_split(self):
        for tensor_model_paralell_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_paralell_world_size > 0:
                continue
            with self.subTest(tensor_model_paralell_world_size=
                              tensor_model_paralell_world_size):
                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=tensor_model_paralell_world_size
                )

                tensors = [
                    torch.randn(10, 1)
                    for rank in range(tensor_model_paralell_world_size)
                ]
                x = torch.cat(tensors, 1)
                out = mappings._split(x)
                self.assertTrue(
                    torch.equal(
                        out, tensors[
                            parallel_state.get_tensor_model_parallel_rank()]))
                parallel_state.destroy_model_parallel()
Example 22
def test_cross_entropy(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing cross entropy with model parallel size {} ...'.format(
            tensor_model_parallel_size))

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size(
    )

    batch_size = 13
    seq_length = 17
    vocab_size_per_partition = 11
    logits_scale = 1000.0
    vocab_size = vocab_size_per_partition * tensor_model_parallel_size
    seed = 1234

    loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length,
                                                 vocab_size, logits_scale,
                                                 seed)
    loss_mpu, grad_mpu = tensor_sharded_cross_entropy(batch_size, seq_length,
                                                      vocab_size, logits_scale,
                                                      seed)

    error = loss_torch.sub_(loss_mpu).abs().max()
    print('   max error in loss on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = grad_torch.sub_(grad_mpu).abs().max()
    print('   max error in grad on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
Example 23
    def test_broadcast_data(self):
        tensor_model_parallel_world_size: int = self.world_size // (
            1 + int(self.world_size > 1))
        parallel_state.initialize_model_parallel(
            tensor_model_parallel_size_=tensor_model_parallel_world_size)

        target_key_size = {
            "key1": [7, 11],
            "key2": [8, 2, 1],
            "key3": [13],
            "key4": [5, 1, 2],
            "key5": [5, 12],
        }
        keys = [k for k in target_key_size]

        data = {}
        data_t = {}
        with torch.no_grad():
            for key in target_key_size:
                data[key] = torch.randint(0, 1000, size=target_key_size[key])
                data_t[key] = data[key].clone()
            # "key_x" is supposed to be ignored.
            data["key_x"] = torch.rand(5)
            data_t["key_x"] = data["key_x"].clone()
        if parallel_state.get_tensor_model_parallel_rank() != 0:
            data = None

        data_utils._check_data_types(keys, data_t, torch.int64)
        key_size, _, _ = data_utils._build_key_size_numel_dictionaries(
            keys, data)

        for key in keys:
            self.assertEqual(target_key_size[key], key_size[key])

        broadcasted_data = data_utils.broadcast_data(keys, data, torch.int64)
        for key in keys:
            torch.testing.assert_close(broadcasted_data[key],
                                       data_t[key].cuda())

        parallel_state.destroy_model_parallel()
Example 24
def test_initialize_model_parallel(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing initialize_model_parallel with size {} ...'.format(
            tensor_model_parallel_size))
    tensor_model_parallel_size_ = min(
        tensor_model_parallel_size,
        torch.distributed.get_world_size(),
    )
    assert not parallel_state.model_parallel_is_initialized()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_)
    assert parallel_state.model_parallel_is_initialized()

    # Checks.
    def check(group, world_size, rank):
        assert world_size == torch.distributed.get_world_size(group=group)
        assert rank == torch.distributed.get_rank(group=group)

    # Model parallel.
    world_size = tensor_model_parallel_size_
    rank = torch.distributed.get_rank() % tensor_model_parallel_size_
    assert world_size == parallel_state.get_tensor_model_parallel_world_size()
    assert rank == parallel_state.get_tensor_model_parallel_rank()
    check(parallel_state.get_tensor_model_parallel_group(), world_size, rank)

    # Data parallel.
    world_size = torch.distributed.get_world_size(
    ) // tensor_model_parallel_size_
    rank = torch.distributed.get_rank() // tensor_model_parallel_size_
    assert world_size == parallel_state.get_data_parallel_world_size()
    assert rank == parallel_state.get_data_parallel_rank()
    check(parallel_state.get_data_parallel_group(), world_size, rank)

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
Example 25
def test_pipeline_model_parallel_split_rank():
    pipeline_model_parallel_split_rank_ = 1
    assert not parallel_state.model_parallel_is_initialized()
    parallel_state.initialize_model_parallel(
        pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank_
    )
    assert parallel_state.model_parallel_is_initialized()

    split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
    assert split_rank is pipeline_model_parallel_split_rank_

    fake_split_rank = 7
    parallel_state.set_pipeline_model_parallel_split_rank(fake_split_rank)
    split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
    assert split_rank == fake_split_rank

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
Example 26
    def test_split_tensor_along_last_dim(self):
        for tensor_model_paralell_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_paralell_world_size > 0:
                continue
            with self.subTest(tensor_model_paralell_world_size=
                              tensor_model_paralell_world_size):
                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=tensor_model_paralell_world_size
                )

                device = "cpu"
                input_tensor = torch.randn((100, 100, 100), device=device)
                splits = utils.split_tensor_along_last_dim(input_tensor, 10)
                last_dim_shapes = torch.tensor(
                    [int(split.size()[-1]) for split in splits])

                self.assertTrue(
                    torch.equal(
                        last_dim_shapes,
                        torch.full((10, ), 10),
                    ))

                parallel_state.destroy_model_parallel()
Example 27
        print(args.tensor_model_parallel_size, "MODEL PARALLEL SIZE")

        parallel_state.initialize_model_parallel(
            tensor_model_parallel_size_=args.tensor_model_parallel_size,
            pipeline_model_parallel_size_=args.pipeline_model_parallel_size,
            default_backend="nccl",
            p2p_backend="ucc" if HAS_TORCH_UCC else "nccl",
        )

        pipeline_model_parallel_size = (
            parallel_state.get_pipeline_model_parallel_world_size()
        )
        model_parallel_cuda_manual_seed(0)
        model = build_model(
            gpt_model_provider,
            wrap_with_ddp=True,
            virtual_pipeline_model_parallel_size=None,
            cpu_offload=args.cpu_offload,
        )
        assert isinstance(model, list), model
        _param_groups = _get_params_for_weight_decay_optimization(model)
        optim = torch.optim.Adam(_param_groups)
        runtime = train(model, optim, args.pipeline_model_parallel_size, async_comm)

        parallel_state.destroy_model_parallel()
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
        print("Average Iteration Time:", runtime)
Example 28
def test_cuda_rng_tracker(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing cuda rng tracker with size {} ...'.format(
            tensor_model_parallel_size))

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size(
    )

    seed_1 = 1234
    seed_2 = 4321
    size = [12, 21]
    tensor = torch.cuda.FloatTensor(size)

    # Set to seed_1 and generate two tensors.
    torch.cuda.manual_seed(seed_1)
    torch.randn(size, out=tensor)
    target_11 = tensor.clone()
    torch.randn(size, out=tensor)
    target_12 = tensor.clone()

    # Set to seed_2 and generate two tensors.
    torch.cuda.manual_seed(seed_2)
    torch.randn(size, out=tensor)
    target_21 = tensor.clone()
    torch.randn(size, out=tensor)
    target_22 = tensor.clone()

    # Now if we interleave seed_1 and seed_2,
    # we should still get the same tensors
    torch.cuda.manual_seed(seed_1)
    tensor_parallel.random.get_cuda_rng_tracker().add('test', seed_2)

    torch.randn(size, out=tensor)
    result_11 = tensor.clone()

    with tensor_parallel.random.get_cuda_rng_tracker().fork('test'):
        torch.randn(size, out=tensor)
        result_21 = tensor.clone()

    torch.randn(size, out=tensor)
    result_12 = tensor.clone()

    with tensor_parallel.random.get_cuda_rng_tracker().fork('test'):
        torch.randn(size, out=tensor)
        result_22 = tensor.clone()

    diff = result_11.sub(result_21).abs().max()
    diff = min(diff, result_12.sub(result_22).abs().max())
    print('   max diff in generated tensors (should be non-zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), diff))
    assert diff > 1.0e-6
    error = max(
        result_11.sub(target_11).abs().max(),
        result_12.sub(target_12).abs().max())
    error = max(error, result_21.sub(target_21).abs().max())
    error = max(error, result_22.sub(target_22).abs().max())
    print('   max error in generated tensors (should be zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset the tracker
    tensor_parallel.random.get_cuda_rng_tracker().reset()

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
Example 29
def test_set_cuda_rng_state(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing set_rng_state with size {} ...'.format(
            tensor_model_parallel_size))

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size(
    )

    size = 123
    seed = 1234
    torch.cuda.manual_seed(seed)
    tensor = torch.cuda.FloatTensor(size)

    # Get the state
    rng_state = torch.cuda.get_rng_state()
    rng_state_copy = rng_state.clone()

    # Do some stuff.
    for _ in range(5):
        torch.randn(size, out=tensor)
    result_1 = tensor.clone()

    assert rng_state.sub(rng_state_copy).max() == 0
    assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0

    # State should be different.
    new_rng_state = torch.cuda.get_rng_state()
    max_diff = new_rng_state.sub(rng_state).max()
    print(
        '   max diff in rng state (should be non-zero) on global rank {}: {}'.
        format(torch.distributed.get_rank(), max_diff))
    assert max_diff > 0

    # Reset the rng state and do the same stuff.
    tensor_parallel.random._set_cuda_rng_state(rng_state)
    for _ in range(5):
        torch.randn(size, out=tensor)
    tensor_parallel.random._set_cuda_rng_state(rng_state)
    for _ in range(5):
        torch.randn(size, out=tensor)
    result_2 = tensor.clone()

    # Results should be the same
    error = result_2.sub(result_1).abs().max()
    print('   max error in generated tensors (should be zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Input state should have remained intact.
    error = rng_state.sub(rng_state_copy).max()
    print('   max error in rng state (should be zero) on global rank {}: {}'.
          format(torch.distributed.get_rank(), error))
    assert error == 0

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
Example 30
    def _forward_backward_test_impl(
        self,
        forward_only: bool,
        fwd_bwd_func: FwdStepFunc,
        pipeline_model_parallel_world_size: Optional[int],
        virtual_pipeline_model_parallel_size: Optional[int],
        async_comm: bool = False,
        *,
        default_backend: Optional[str] = None,
        p2p_backend: Optional[str] = None,
    ) -> None:
        if fwd_bwd_func == _forward_backward_pipelining_with_interleaving:
            self.assertIsNotNone(virtual_pipeline_model_parallel_size)
            self.assertGreater(virtual_pipeline_model_parallel_size, 1)
        dtype_options = self.dtypes or [torch.float32, torch.double
                                        ] + _get_autocast_dtypes()

        for dtype, deallocate_pipeline_outputs in itertools.product(
                dtype_options,
                self.deallocate_options,
        ):
            grad_scaler = (torch.cuda.amp.GradScaler(
                init_scale=4.0) if dtype == torch.half else None)

            (tensor_model_parallel_world_size, data_parallel_size,
             pipeline_model_parallel_world_size
             ) = _get_default_world_sizes_model_parallel_world_size(
                 pipeline_model_parallel_world_size)

            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size,
                pipeline_model_parallel_size_=
                pipeline_model_parallel_world_size,
                virtual_pipeline_model_parallel_size_=
                virtual_pipeline_model_parallel_size,
                default_backend=default_backend,
                p2p_backend=p2p_backend,
            )
            pp_utils._reconfigure_microbatch_calculator(
                rank=parallel_state.get_tensor_model_parallel_rank(),
                rampup_batch_size=None,
                global_batch_size=self.GLOBAL_BATCH_SIZE,
                micro_batch_size=self.MICRO_BATCH_SIZE,
                data_parallel_size=parallel_state.get_data_parallel_world_size(
                ),
            )

            global_batch_shape = (
                self.GLOBAL_BATCH_SIZE //
                parallel_state.get_data_parallel_world_size(),
                self.HIDDEN_SIZE,
                self.HIDDEN_SIZE,
            )

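            # Only the first pipeline stage receives an input batch; later stages
            # get their inputs through p2p communication.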
            batch = None
            if parallel_state.is_pipeline_first_stage():
                batch = (torch.ones(global_batch_shape, dtype=dtype).cuda(), )

            model = build_model(
                testing_utils.model_provider_func,
                # Wrap with DDP only when data parallelism is actually used.
                wrap_with_ddp=data_parallel_size > 1,
                virtual_pipeline_model_parallel_size=
                virtual_pipeline_model_parallel_size,
                hidden_size=self.HIDDEN_SIZE,
            )

            offset = pipeline_model_parallel_world_size if virtual_pipeline_model_parallel_size is not None else 0
            for idx, model_module in enumerate(model):
                model_module = model_module.to(dtype)
                model_module.apply(get_init_weights_func(idx * offset))

            _param_groups = _get_params_for_weight_decay_optimization(model)
            optimizer = torch.optim.Adam(_param_groups, lr=1e-3)

            pp_utils.update_num_microbatches(0)

            loss = fwd_bwd_func(
                testing_utils.fwd_step_func,
                batch,
                model,
                forward_only=forward_only,
                # `tensor_shape` is the shape of micro batch.
                tensor_shape=(
                    self.MICRO_BATCH_SIZE,
                    self.HIDDEN_SIZE,
                    self.HIDDEN_SIZE,
                ),
                dtype=dtype,
                async_comm=async_comm,
                grad_scaler=grad_scaler,
                deallocate_pipeline_output=deallocate_pipeline_outputs,
            )

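            # For double precision, compare the pipeline-parallel loss and gradients
            # against a single-process reference model.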
            if dtype == torch.double:
                hidden_size = self.HIDDEN_SIZE
                microbatch_size = self.MICRO_BATCH_SIZE
                total_layers = pipeline_model_parallel_world_size
                if virtual_pipeline_model_parallel_size is not None:
                    total_layers *= virtual_pipeline_model_parallel_size
                target_loss, target_model = get_target_loss_and_model(
                    global_batch_shape, hidden_size, total_layers)

                for loss_item in loss:
                    x = loss_item['avg']
                    torch.testing.assert_close(x.item() / microbatch_size,
                                               target_loss.item())

                if not forward_only:
                    for vm_id, model_module in enumerate(model):
                        params = list(model_module.parameters())
                        rank = params[0].get_device()
                        offset = pipeline_model_parallel_world_size
                        param_id = rank // data_parallel_size + vm_id * offset
                        target_params = target_model[param_id]

                        torch.testing.assert_close(params[0].cpu(),
                                                   target_params[0])
                        torch.testing.assert_close(params[1].cpu(),
                                                   target_params[1])
                        torch.testing.assert_close(
                            params[0].grad.cpu() / microbatch_size,
                            target_params[0].grad)
                        torch.testing.assert_close(
                            params[1].grad.cpu() / microbatch_size,
                            target_params[1].grad)

            if not forward_only:
                for m in model:
                    for p in m.parameters():
                        self.assertIsNotNone(p.grad)
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

            parallel_state.destroy_model_parallel()