def test_distributed_zero_mixed_precision_lamb(world_rank, world_size, device, checkpoint_dir):
    """Check model/optimizer state of a mixed-precision ZeRO-stage-1 LAMB run.

    Builds a distributed trainer with fp16 and DeepSpeed ZeRO stage 1 enabled,
    pulls the session state, and verifies model weights, optimizer state, and
    partition info against it.
    """
    mixed_precision_enabled = True
    zero_enabled = True
    options = {
        "device": {"id": device},
        "mixed_precision": {"enabled": mixed_precision_enabled},
        "distributed": {
            "world_rank": world_rank,
            "world_size": world_size,
            "allreduce_post_accumulation": True,
            "deepspeed_zero_optimization": {"stage": 1},
        },
        # deterministic compute so state comparisons are exact across runs
        "debug": {"deterministic_compute": True},
    }
    trainer = create_initialized_orttrainer(device, options, True)
    expected = split_state_dict(trainer._training_session.get_state())
    verify_model_state(trainer, expected, mixed_precision_enabled)
    verify_opt_state(trainer, expected)
    verify_part_info(trainer, expected, mixed_precision_enabled, zero_enabled)
def test_distributed_zero_mixed_precision_lamb(world_rank, world_size, device, checkpoint_dir):
    """Verify trainer state for a distributed ZeRO-stage-1 run with fp16 and LAMB.

    NOTE(review): this appears to be a byte-identical duplicate of an earlier
    definition with the same name in this file (only quote style differed);
    the later definition shadows the earlier one — confirm and remove one.
    """
    is_mixedprecision = True
    is_zero_run = True
    distributed_opts = {
        'world_rank': world_rank,
        'world_size': world_size,
        'allreduce_post_accumulation': True,
        'deepspeed_zero_optimization': {'stage': 1},
    }
    opts_dict = {
        'device': {'id': device},
        'mixed_precision': {'enabled': is_mixedprecision},
        'distributed': distributed_opts,
        # exact reproducibility is required for state comparison
        'debug': {'deterministic_compute': True},
    }
    trainer = create_initialized_orttrainer(device, opts_dict, True)
    expected_state_dict = trainer._training_session.get_state()
    expected_state_dict = split_state_dict(expected_state_dict)
    verify_model_state(trainer, expected_state_dict, is_mixedprecision)
    verify_opt_state(trainer, expected_state_dict)
    verify_part_info(trainer, expected_state_dict, is_mixedprecision, is_zero_run)
def verify_optimizer_state_match(device, opts, checkpoint_dir, world_rank, use_lamb=False):
    """Compare the sharded on-device optimizer state against the expected state.

    Each rank pickles its (split) trainer state into ``checkpoint_dir``; after a
    barrier, rank 0 reads every rank's file, aggregates the sharded optimizer
    tensors, and asserts they match the expected optimizer state. All ranks
    then clean up their own temporary file.
    """
    expected_optim_state, trainer_state = load_model_optim_state_and_eval(device, opts, use_lamb)
    trainer_state = split_state_dict(trainer_state)

    # Roundabout way of checking optimizer states: dump per-rank state dicts
    # to a shared folder, then read them back and aggregate on rank 0.
    own_file = os.path.join(checkpoint_dir, 'distributed_state_' + str(world_rank) + '.pkl')
    with open(own_file, "wb") as f:
        pickle.dump(trainer_state, f)
    dist.barrier()

    if world_rank == 0:
        num_states = len(glob.glob1(checkpoint_dir, "distributed_state*"))
        optimizer_states = {}
        for rank in range(num_states):
            rank_file = os.path.join(checkpoint_dir, 'distributed_state_' + str(rank) + '.pkl')
            with open(rank_file, 'rb') as f:
                rank_state_dict = pickle.load(f)
            # collect optimizer states for later comparison since they are sharded
            aggregate_states(optimizer_states, rank_state_dict['optimizer'])

        # compare aggregated optimizer state against the expected one
        if use_lamb:
            optimizer_config = optim.LambConfig()
        else:
            optimizer_config = optim.AdamConfig()
        actual_optim_state = get_optim_state_from_state_dict(optimizer_states, optimizer_config)
        assert actual_optim_state.keys() == expected_optim_state.keys()
        for param_name, a_state in actual_optim_state.items():
            for k, v in a_state.items():
                expected = expected_optim_state[param_name][k]
                assert_allclose(
                    v.reshape(expected.shape),
                    expected,
                    err_msg=f"Optimizer state mismatch for param {param_name}, key {k}")

    dist.barrier()
    os.remove(own_file)
def test_single_node_full_precision_lamb(device="cuda", checkpoint_dir=""):
    """Check model/optimizer state of a single-node fp32 LAMB run.

    No mixed precision and no ZeRO partitioning: the trainer's own session
    state serves as the expected reference for all verification helpers.
    """
    mixed_precision_enabled = False
    zero_enabled = False
    options = {
        "device": {"id": device},
        # deterministic compute so state comparisons are exact across runs
        "debug": {"deterministic_compute": True},
    }
    trainer = create_initialized_orttrainer(device, options, True)
    expected = split_state_dict(trainer._training_session.get_state())
    verify_model_state(trainer, expected, mixed_precision_enabled)
    verify_opt_state(trainer, expected)
    verify_part_info(trainer, expected, mixed_precision_enabled, zero_enabled)
def test_zero_aggregation(checkpoint_dir, loaded_state_dict, is_mixedprecision):
    """Validate ZeRO checkpoint aggregation against a fully loaded state dict.

    First aggregates the per-rank ORT checkpoints independently and checks the
    result matches ``loaded_state_dict`` exactly. Then re-reads each rank's
    pickled state dict and verifies that every sharded tensor (fp32 params and
    optimizer moments, identified by a ``_view_`` suffix) corresponds to the
    right contiguous slice of the aggregated tensor, tracking a running offset
    per sharded key.
    """
    # get aggregated state dict independently
    checkpoint_files = checkpoint._list_checkpoint_files(checkpoint_dir, "ORT_checkpoint")
    agg_checkpoint = checkpoint._CombineZeroCheckpoint(checkpoint_files)
    aggregate_state_dict = agg_checkpoint.aggregate_checkpoints()

    # verify loaded state and aggregated states match
    assert aggregate_state_dict.keys() == loaded_state_dict.keys()
    for key, tensor in loaded_state_dict.items():
        assert_allclose(tensor, aggregate_state_dict[key])

    # split state for the per-rank shard checks below
    loaded_state_dict = split_state_dict(loaded_state_dict)

    # verify that aggregation is done correctly: walk every rank's shard
    num_states = len(glob.glob1(checkpoint_dir, "state_dict*"))
    shard_offsets = {}  # running flat-offset per sharded key
    for rank in range(num_states):
        with open(os.path.join(checkpoint_dir, 'state_dict_' + str(rank) + '.pkl'), 'rb') as f:
            state = pickle.load(f)
        rank_state_dict = split_state_dict(state['state_dict_' + str(rank)])

        if is_mixedprecision:
            for key, tensor in rank_state_dict['fp16_param'].items():
                # verify fp16 weights match
                assert_allclose(tensor, loaded_state_dict['fp16_param'][key])
                # verify rank fp16 weights match loaded fp32 correctly
                fp32_name = key.split('_fp16')[0]
                assert_allclose(tensor, loaded_state_dict['fp32_param'][fp32_name],
                                atol=global_fp16_fp32_atol)

        for key, tensor in rank_state_dict['fp32_param'].items():
            if key in loaded_state_dict['fp32_param']:
                # unsharded weight: must match the aggregated one directly
                assert_allclose(tensor, loaded_state_dict['fp32_param'][key])
            else:
                # sharded weight: compare against the matching flat slice
                assert '_view_' in key
                weight_key = key.split('_view_')[0]
                offset = shard_offsets.get(weight_key, 0)
                length = tensor.numel()
                full_tensor = loaded_state_dict['fp32_param'][weight_key].view(-1)
                assert offset + length <= full_tensor.numel()
                assert_allclose(tensor, full_tensor[offset: offset + length])
                # advance the offset for this key's next shard
                shard_offsets[weight_key] = offset + length

        for key, tensor in rank_state_dict['optimizer'].items():
            if key in loaded_state_dict['optimizer']:
                assert_allclose(tensor, loaded_state_dict['optimizer'][key])
            else:
                assert '_view_' in key
                if key.startswith('Moment_'):
                    # verify sharded moment tensors against the flat slice
                    optim_key = key.split('_view_')[0]
                    offset = shard_offsets.get(optim_key, 0)
                    length = tensor.numel()
                    full_tensor = loaded_state_dict['optimizer'][optim_key].view(-1)
                    assert offset + length <= full_tensor.numel()
                    assert_allclose(tensor, full_tensor[offset: offset + length])
                    shard_offsets[optim_key] = offset + length