Code example #1
                                           logits_scale, seed)

    error = loss_torch.sub_(loss_mpu).abs().max()
    print('   max error in loss on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = grad_torch.sub_(grad_mpu).abs().max()
    print('   max error in grad on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')


if __name__ == '__main__':

    initialize_distributed()
    world_size = torch.distributed.get_world_size()

    model_parallel_size = 1
    while model_parallel_size <= world_size:
        print_separator('test cross entropy')
        test_cross_entropy(model_parallel_size)
        model_parallel_size *= 2
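
Every example on this page calls initialize_distributed() and print_separator() without showing them; in Megatron-style mpu test suites they normally live in a shared commons module. The following is only a minimal sketch of what such helpers could look like, assuming an NCCL backend and the usual MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE launcher environment variables; it is not the projects' actual code.

import os
import torch


def initialize_distributed(backend='nccl'):
    """Minimal torch.distributed setup driven by standard launcher env vars."""
    rank = int(os.getenv('RANK', '0'))
    world_size = int(os.getenv('WORLD_SIZE', '1'))
    # Pin this process to one GPU.
    torch.cuda.set_device(rank % torch.cuda.device_count())
    init_method = 'tcp://{}:{}'.format(os.getenv('MASTER_ADDR', 'localhost'),
                                       os.getenv('MASTER_PORT', '6000'))
    torch.distributed.init_process_group(backend=backend,
                                         world_size=world_size,
                                         rank=rank,
                                         init_method=init_method)


def print_separator(message):
    """Print a banner for the next test from rank 0 only."""
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('-' * 80)
        print(message)
        print('-' * 80)
    torch.distributed.barrier()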
Code example #2
File: test_data.py  Project: yrchen92/CoDIR
    total_numel_t = 0
    for key in keys:
        target_size = functools.reduce(operator.mul, key_size_t[key], 1)
        assert key_numel[key] == target_size
        total_numel_t += target_size
    assert total_numel == total_numel_t

    data_b = data_utils.broadcast_data(keys, data, torch.int64)
    for key in keys:
        tensor = data_t[key].cuda()
        assert data_b[key].sub(tensor).abs().max() == 0

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')


if __name__ == '__main__':

    initialize_distributed()
    world_size = torch.distributed.get_world_size()

    model_parallel_size = 1
    while model_parallel_size <= world_size:
        print_separator('test broadcast data')
        test_boradcast_data(model_parallel_size)
        model_parallel_size *= 2
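
The top of this snippet is truncated, so the keys, key_size_t, data, and data_t dictionaries it checks are not visible. As a purely hypothetical illustration of how inputs for such a broadcast test could be built (the names mirror the snippet, but the keys, shapes, and rank-0-only population are assumptions, not the project's values):

# Hypothetical input construction; every key maps to an int64 tensor of a
# known shape, and a reference copy is kept on every rank for the final check.
key_size_t = {'text': [7, 11], 'types': [8, 2, 1], 'mask': [13]}
keys = list(key_size_t.keys())

data_t = {key: torch.randint(0, 1000, key_size_t[key], dtype=torch.int64)
          for key in keys}

if mpu.get_model_parallel_rank() == 0:
    # Only the source rank of the model-parallel group supplies real data.
    data = {key: tensor.clone() for key, tensor in data_t.items()}
else:
    # The other ranks pass None and receive the broadcast result.
    data = None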
Code example #3
File: test_layers.py  Project: zavodptica/ru-gpts
    assert error < 5.0e-5, 'error: {}'.format(error)

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')


if __name__ == '__main__':

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    initialize_distributed()
    world_size = torch.distributed.get_world_size()

    print_separator('test initialize affine weight')
    model_parallel_size = 1
    while model_parallel_size <= world_size:
        test_initialize_affine_weight(model_parallel_size)
        model_parallel_size *= 2

    model_parallel_size = 1
    while model_parallel_size <= world_size:
        print_separator('test parallel embedding')
        test_parallel_embedding(model_parallel_size)
        model_parallel_size *= 2

    print_separator('test column-parallel linear')
    model_parallel_size = 1
    while model_parallel_size <= world_size:
        test_column_parallel_linear(model_parallel_size)
        model_parallel_size *= 2
Code example #4
    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')


if __name__ == '__main__':

    initialize_distributed()
    world_size = torch.distributed.get_world_size()

    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        print_separator('test set rng state')
        test_set_cuda_rng_state(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2

    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        print_separator('test cuda rng tracker')
        test_cuda_rng_tracker(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2

    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        print_separator('test model parallel cuda manual seed')
        test_model_parallel_cuda_manual_seed(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2
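
The three tests driven here exercise Megatron-style RNG handling: saving and restoring the CUDA RNG state, a named RNG tracker, and a per-model-parallel-rank seed. A hedged sketch of the usage pattern these utilities are meant to support, assuming the mpu module exposes model_parallel_cuda_manual_seed() and get_cuda_rng_tracker() as in Megatron-LM:

# Give every tensor-model-parallel rank its own seed for partitioned ops
# while keeping the default CUDA RNG identical across ranks.
mpu.model_parallel_cuda_manual_seed(1234)

x = torch.randn(4, 4, device='cuda')          # shared default RNG state

# Operations that must differ per model-parallel rank (e.g. dropout inside a
# partitioned layer) run under the forked, rank-specific RNG state.
with mpu.get_cuda_rng_tracker().fork():
    y = torch.nn.functional.dropout(x, p=0.1, training=True)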
Code example #5
            model_parallel_size_))
    model_parallel_size = min(model_parallel_size_,
                              torch.distributed.get_world_size())
    assert not mpu.model_parallel_is_initialized()
    mpu.initialize_model_parallel(model_parallel_size)
    assert mpu.model_parallel_is_initialized()

    # Checks
    src_rank = torch.distributed.get_rank() - mpu.get_model_parallel_rank()
    assert mpu.get_model_parallel_src_rank() == src_rank

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')


if __name__ == '__main__':

    initialize_distributed()
    world_size = torch.distributed.get_world_size()
    model_parallel_size = 1
    while model_parallel_size <= world_size:
        print_separator('test initialize model parallel')
        test_initialize_model_parallel(model_parallel_size)
        print_separator('test model parallel source rank')
        test_get_model_parallel_src_rank(model_parallel_size)
        model_parallel_size *= 2
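
The check in this last example, src_rank = torch.distributed.get_rank() - mpu.get_model_parallel_rank(), relies on model-parallel groups being formed from consecutive global ranks, so the source rank of a group is simply its first global rank. A small worked illustration of that arithmetic (the contiguous group layout is an assumption about how the groups are built):

# With world_size = 8 and model_parallel_size = 2, contiguous groups are
# {0, 1}, {2, 3}, {4, 5}, {6, 7}; their source ranks are 0, 2, 4, 6.
world_size = 8
model_parallel_size = 2
for global_rank in range(world_size):
    rank_in_group = global_rank % model_parallel_size    # get_model_parallel_rank()
    src_rank = global_rank - rank_in_group                # first rank of the group
    assert src_rank == (global_rank // model_parallel_size) * model_parallel_size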