# Imports for the examples below; the create_initialized_orttrainer /
# split_state_dict / verify_* / aggregate_states helpers and
# global_fp16_fp32_atol are assumed to be provided by the surrounding
# checkpoint test module.
import glob
import os
import pickle

import torch.distributed as dist
from numpy.testing import assert_allclose

from onnxruntime.training import checkpoint, optim


def test_distributed_zero_mixed_precision_lamb(world_rank, world_size, device,
                                               checkpoint_dir):
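    # ZeRO stage-1 + mixed-precision LAMB run: check that the trainer state
    # exposes consistent model weights, optimizer state, and partition info.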
    is_mixedprecision = True
    is_zero_run = True
    opts_dict = {
        "device": {
            "id": device
        },
        "mixed_precision": {
            "enabled": is_mixedprecision
        },
        "distributed": {
            "world_rank": world_rank,
            "world_size": world_size,
            "allreduce_post_accumulation": True,
            "deepspeed_zero_optimization": {
                "stage": 1
            },
        },
        "debug": {
            "deterministic_compute": True
        },
    }

    trainer = create_initialized_orttrainer(device, opts_dict, True)

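    # split the raw session state into parameter and optimizer sections
    # (fp32_param / fp16_param / optimizer) for the verify_* helpers below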
    expected_state_dict = trainer._training_session.get_state()
    expected_state_dict = split_state_dict(expected_state_dict)

    verify_model_state(trainer, expected_state_dict, is_mixedprecision)

    verify_opt_state(trainer, expected_state_dict)

    verify_part_info(trainer, expected_state_dict, is_mixedprecision,
                     is_zero_run)


def verify_optimizer_state_match(device,
                                 opts,
                                 checkpoint_dir,
                                 world_rank,
                                 use_lamb=False):
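    # Each rank dumps its sharded trainer state to disk; rank 0 then loads all
    # shards, aggregates the optimizer entries, and compares them against the
    # expected optimizer state from load_model_optim_state_and_eval.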
    expected_optim_state, trainer_state = load_model_optim_state_and_eval(
        device, opts, use_lamb)
    trainer_state = split_state_dict(trainer_state)
    # Roundabout way of checking optimizer states: each rank saves its state
    # dict into a temporary folder, then rank 0 reads and aggregates them.
    with open(
            os.path.join(checkpoint_dir,
                         'distributed_state_' + str(world_rank) + '.pkl'),
            "wb") as f:
        pickle.dump(trainer_state, f)
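    # make sure every rank has written its shard before rank 0 reads them back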
    dist.barrier()

    if world_rank == 0:
        num_states = len(glob.glob1(checkpoint_dir, "distributed_state*"))
        optimizer_states = dict()
        for rank in range(num_states):
            rank_state_dict = None
            with open(
                    os.path.join(checkpoint_dir,
                                 'distributed_state_' + str(rank) + '.pkl'),
                    'rb') as f:
                rank_state_dict = pickle.load(f)

            # collect optimizer states for later comparison since they are sharded
            aggregate_states(optimizer_states, rank_state_dict['optimizer'])

        # compare optimizer states
        optimizer_config = optim.LambConfig() if use_lamb else optim.AdamConfig()
        actual_optim_state = get_optim_state_from_state_dict(
            optimizer_states, optimizer_config)
        assert actual_optim_state.keys() == expected_optim_state.keys()
        for param_name, a_state in actual_optim_state.items():
            for k, v in a_state.items():
                assert_allclose(
                    v.reshape(expected_optim_state[param_name][k].shape),
                    expected_optim_state[param_name][k],
                    err_msg=f"Optimizer state mismatch for param {param_name}, key {k}",
                )

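    # let rank 0 finish the comparison before any rank deletes its shard file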
    dist.barrier()
    os.remove(
        os.path.join(checkpoint_dir,
                     'distributed_state_' + str(world_rank) + '.pkl'))


def test_single_node_full_precision_lamb(device="cuda", checkpoint_dir=""):
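    # Single-node baseline: full precision, no ZeRO partitioning, deterministic
    # compute only.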
    opts_dict = {
        "device": {
            "id": device
        },
        "debug": {
            "deterministic_compute": True
        }
    }
    is_mixedprecision = False
    is_zero_run = False

    trainer = create_initialized_orttrainer(device, opts_dict, True)

    expected_state_dict = trainer._training_session.get_state()
    expected_state_dict = split_state_dict(expected_state_dict)

    verify_model_state(trainer, expected_state_dict, is_mixedprecision)

    verify_opt_state(trainer, expected_state_dict)

    verify_part_info(trainer, expected_state_dict, is_mixedprecision,
                     is_zero_run)


def test_zero_aggregation(checkpoint_dir, loaded_state_dict, is_mixedprecision):
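    # Aggregate the per-rank ZeRO checkpoints and confirm the result matches
    # loaded_state_dict, then check that every rank's shard maps onto the
    # correct slice of the full tensors.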
    # get aggregated state dict independently
    checkpoint_files = checkpoint._list_checkpoint_files(checkpoint_dir, "ORT_checkpoint")
    agg_checkpoint = checkpoint._CombineZeroCheckpoint(checkpoint_files)
    aggregate_state_dict = agg_checkpoint.aggregate_checkpoints()

    # verify loaded state and aggregated states match:
    assert aggregate_state_dict.keys() == loaded_state_dict.keys()
    for k, v in loaded_state_dict.items():
        assert_allclose(v, aggregate_state_dict[k])

    # split state for next few checks
    loaded_state_dict = split_state_dict(loaded_state_dict)

    # verify that aggregation is done correctly
    num_states = len(glob.glob1(checkpoint_dir, "state_dict*"))

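    # per-tensor offset into the flattened full tensor, advanced as each
    # rank's '_view_' shard is consumed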
    sharded_state_rank_offset = dict()
    for rank in range(num_states):
        state = None
        with open(os.path.join(checkpoint_dir, 'state_dict_'+str(rank)+'.pkl'), 'rb') as f:
            state = pickle.load(f)
        rank_state_dict = split_state_dict(state['state_dict_'+str(rank)])

        if is_mixedprecision:
            for k, v in rank_state_dict['fp16_param'].items():
                # verify fp16 weights match
                assert_allclose(v, loaded_state_dict['fp16_param'][k])
                # verify rank fp16 weights match loaded fp32 correctly
                fp32_name = k.split('_fp16')[0]
                assert_allclose(v, loaded_state_dict['fp32_param'][fp32_name], atol=global_fp16_fp32_atol)

        for k, v in rank_state_dict['fp32_param'].items():
            if k in loaded_state_dict['fp32_param']:
                assert_allclose(v, loaded_state_dict['fp32_param'][k])
            else:
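                # keys with a '_view_' suffix are ZeRO shards: compare against
                # the matching contiguous slice of the flattened fp32 tensor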
                assert '_view_' in k
                weight_key = k.split('_view_')[0]
                rank_offset = 0
                if weight_key in sharded_state_rank_offset:
                    rank_offset = sharded_state_rank_offset[weight_key]
                rank_len = v.numel()
                loaded_tensor = loaded_state_dict['fp32_param'][weight_key].view(-1)
                assert rank_offset + rank_len <= loaded_tensor.numel()
                assert_allclose(v, loaded_tensor[rank_offset: rank_offset + rank_len])
                # update offset
                sharded_state_rank_offset[weight_key] = rank_offset + rank_len

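        # optimizer entries: unsharded keys must match directly; sharded
        # Moment_* '_view_' slices must map onto the aggregated tensor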
        for k, v in rank_state_dict['optimizer'].items():
            if k in loaded_state_dict['optimizer']:
                assert_allclose(v, loaded_state_dict['optimizer'][k])
            else:
                assert '_view_' in k
                if k.startswith('Moment_'):  # verify moment tensors
                    optim_key = k.split('_view_')[0]
                    rank_offset = 0
                    if optim_key in sharded_state_rank_offset:
                        rank_offset = sharded_state_rank_offset[optim_key]
                    rank_len = v.numel()
                    loaded_tensor = loaded_state_dict['optimizer'][optim_key].view(-1)
                    assert rank_offset + rank_len <= loaded_tensor.numel()
                    assert_allclose(v, loaded_tensor[rank_offset: rank_offset + rank_len])

                    sharded_state_rank_offset[optim_key] = rank_offset + rank_len