# Imports assumed by these excerpts; the enclosing test class and helpers
# (setup_torch_distributed, PretrainArguments, do_pretrain, aggregate_states,
# assert_all_states_close_ort) were elided from the original module.
import glob
import os
import shutil

import numpy as np
import torch
import torch.distributed

from onnxruntime.training import checkpoint
from onnxruntime.training.checkpoint import aggregate_checkpoints


def test_pretrain_zero(self):
        assert self.world_size > 0, "ZeRO test requires a distributed run."
        setup_torch_distributed(self.world_rank, self.world_size)
        per_gpu_batch_size = 32
        # global batch == per-GPU batch * world size, so no gradient accumulation
        optimization_batch_size = per_gpu_batch_size * self.world_size
        
        self.train_batch_size = optimization_batch_size
        self.gradient_accumulation_steps = 1
        self.deepspeed_zero_stage = 1
        self.force_num_hidden_layers = 2
        self.max_seq_length = 32
        self.output_dir = './bert_pretrain_ckpt'
        if self.world_rank == 0:
            if os.path.isdir(self.output_dir):
                shutil.rmtree(self.output_dir)
            os.makedirs(self.output_dir, exist_ok=True)
        
        torch.distributed.barrier()

        assert os.path.exists(self.output_dir)        
        
        # run a few optimization steps
        self.max_steps = 200
        args = PretrainArguments(
            output_dir=self.output_dir,
            bert_model=self.bert_model,
            local_rank=self.local_rank,
            world_rank=self.world_rank,
            world_size=self.world_size,
            max_steps=self.max_steps,
            learning_rate=self.learning_rate,
            max_seq_length=self.max_seq_length,
            max_predictions_per_seq=self.max_predictions_per_seq,
            train_batch_size=self.train_batch_size,
            gradient_accumulation_steps=self.gradient_accumulation_steps,
            input_dir=self.input_dir,
            fp16=self.fp16,
            allreduce_post_accumulation=self.allreduce_post_accumulation,
            force_num_hidden_layers=self.force_num_hidden_layers,
            deepspeed_zero_stage=self.deepspeed_zero_stage,
            save_checkpoint=True)
        train_loss = do_pretrain(args)

        # ensure all workers reach this point before loading the checkpointed state
        torch.distributed.barrier()

        # on rank 0, load the trained state
        if args.world_rank == 0:
            checkpoint_files = glob.glob(os.path.join(self.output_dir, 'checkpoint*.ortcp'))
            args.init_state_dict = aggregate_checkpoints(checkpoint_files, pytorch_format=True)

        torch.distributed.barrier()

        # run a single step to get the loss; on rank 0 it should be lower
        # than the starting loss
        args.save_checkpoint = False
        args.max_steps = 1
        args.deepspeed_zero_stage = 0  # resume from the aggregated, unsharded state without ZeRO
        final_loss = do_pretrain(args)
        return final_loss
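
# The aggregation step above also works standalone: a minimal sketch of
# restoring a full PyTorch-format state dict from the ZeRO shards written by
# the test, assuming the same checkpoint layout (the directory default and
# helper name here are hypothetical).
def restore_full_state_dict(ckpt_dir='./bert_pretrain_ckpt'):
    # collect the per-rank ZeRO shards written during training
    shards = glob.glob(os.path.join(ckpt_dir, 'checkpoint*.ortcp'))
    # merge the shards into a single state dict of plain torch tensors,
    # suitable for torch.nn.Module.load_state_dict
    return aggregate_checkpoints(shards, pytorch_format=True)
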
def test_megatron_aggregation(checkpoint_dir, loaded_state_dict,
                              expected_state_dict, is_mixedprecision):
    # get aggregated state dict independently
    aggregate_state_dict_from_checkpoint = checkpoint.aggregate_checkpoints(
        glob.glob(os.path.join(checkpoint_dir, "checkpoint*.ortcp")),
        pytorch_format=False)

    # verify loaded state and aggregated states match:
    assert_all_states_close_ort(loaded_state_dict,
                                aggregate_state_dict_from_checkpoint)

    # compare with the expected state dict
    assert_all_states_close_ort(expected_state_dict, loaded_state_dict)


def test_zero_aggregation(checkpoint_dir, loaded_state_dict,
                          is_mixedprecision):
    # get aggregated state dict independently
    aggregate_state_dict_from_checkpoint = checkpoint.aggregate_checkpoints(
        glob.glob(os.path.join(checkpoint_dir, "checkpoint*.ortcp")),
        pytorch_format=False)

    # verify loaded state and aggregated states match:
    assert_all_states_close_ort(loaded_state_dict,
                                aggregate_state_dict_from_checkpoint)

    # manually aggregate the states from the previously saved pickle file
    aggregate_state_dict_from_test = aggregate_states(checkpoint_dir)

    # compare the manually aggregated state dictionary with the one
    # aggregated from the ORTTrainer checkpoints
    assert_all_states_close_ort(aggregate_state_dict_from_test,
                                aggregate_state_dict_from_checkpoint,
                                reshape_states=True)
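
# For reference, manual aggregation of one ZeRO-sharded tensor reduces to
# concatenating the rank-ordered 1-D shards and reshaping to the original_dim
# recorded under partition_info. A sketch; the helper name is hypothetical,
# and aggregate_states above is assumed to do the equivalent over whole
# state dicts.
def aggregate_sharded_param(shards, original_dim):
    # ZeRO flattens a parameter and splits it across ranks, so the inverse is
    # concatenation in rank order followed by a reshape
    flat = np.concatenate([np.asarray(s).ravel() for s in shards])
    return flat.reshape(original_dim)
# e.g. aggregate_sharded_param([[1, 2, 3], [4, 5, 6]], (2, 3)) yields
# [[1, 2, 3], [4, 5, 6]], matching the expectations in the test below
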
def test_checkpoint_aggregation_mixed_precision(load_mock):
    trainer_options1 = {
        'mixed_precision': np.bool_(True),
        'world_rank': np.int64(0),
        'world_size': np.int64(2),
        'horizontal_parallel_size': np.int64(1),
        'data_parallel_size': np.int64(2),
        'zero_stage': np.int64(1),
        'optimizer_name': b'Adam'
    }
    trainer_options2 = {
        'mixed_precision': np.bool_(True),
        'world_rank': np.int64(1),
        'world_size': np.int64(2),
        'horizontal_parallel_size': np.int64(1),
        'data_parallel_size': np.int64(2),
        'zero_stage': np.int64(1),
        'optimizer_name': b'Adam'
    }

    state_dict1 = {
        'model': {
            'full_precision': {
                'sharded': np.array([1, 2, 3]),
                'non_sharded': np.array([11, 22, 33])
            }
        },
        'optimizer': {
            'sharded': {
                'Moment_1': np.array([9, 8, 7]),
                'Moment_2': np.array([99, 88, 77]),
                'Step': np.array([5])
            },
            'non_sharded': {
                'Moment_1': np.array([666, 555, 444]),
                'Moment_2': np.array([6666, 5555, 4444]),
                'Step': np.array([55])
            }
        },
        'trainer_options': {
            'mixed_precision': np.bool_(True),
            'world_rank': np.int64(0),
            'world_size': np.int64(1),
            'horizontal_parallel_size': np.int64(1),
            'data_parallel_size': np.int64(1),
            'zero_stage': np.int64(0),
            'optimizer_name': b'Adam'
        },
        'partition_info': {
            'sharded': {
                'original_dim': np.array([2, 3])
            }
        }
    }

    state_dict2 = {
        'model': {
            'full_precision': {
                'sharded': np.array([4, 5, 6]),
                'non_sharded': np.array([11, 22, 33])
            }
        },
        'optimizer': {
            'sharded': {
                'Moment_1': np.array([6, 5, 4]),
                'Moment_2': np.array([66, 55, 44]),
                'Step': np.array([5])
            },
            'non_sharded': {
                'Moment_1': np.array([666, 555, 444]),
                'Moment_2': np.array([6666, 5555, 4444]),
                'Step': np.array([55])
            }
        },
        'trainer_options': {
            'mixed_precision': np.bool_(True),
            'world_rank': np.int64(1),
            'world_size': np.int64(1),
            'horizontal_parallel_size': np.int64(1),
            'data_parallel_size': np.int64(1),
            'zero_stage': np.int64(0),
            'optimizer_name': b'Adam'
        },
        'partition_info': {
            'sharded': {
                'original_dim': np.array([2, 3])
            }
        }
    }

    # aggregate_checkpoints inspects the trainer options in the checkpoint
    # files before loading the full per-rank state dicts, so the mocked
    # return values are ordered accordingly
    load_mock.side_effect = [
        trainer_options1, trainer_options2, trainer_options1, state_dict1,
        state_dict2
    ]
    state_dict = checkpoint.aggregate_checkpoints(['abc', 'def'],
                                                  pytorch_format=False)

    assert (state_dict['model']['full_precision']['sharded'] == np.array(
        [[1, 2, 3], [4, 5, 6]])).all()
    assert (state_dict['model']['full_precision']['non_sharded'] == np.array(
        [11, 22, 33])).all()
    assert (state_dict['optimizer']['sharded']['Moment_1'] == np.array(
        [[9, 8, 7], [6, 5, 4]])).all()
    assert (state_dict['optimizer']['sharded']['Moment_2'] == np.array(
        [[99, 88, 77], [66, 55, 44]])).all()
    assert (state_dict['optimizer']['sharded']['Step'] == np.array([5])).all()
    assert (state_dict['optimizer']['non_sharded']['Moment_1'] == np.array(
        [666, 555, 444])).all()
    assert (state_dict['optimizer']['non_sharded']['Moment_2'] == np.array(
        [6666, 5555, 4444])).all()
    assert (state_dict['optimizer']['non_sharded']['Step'] == np.array(
        [55])).all()

    assert state_dict['trainer_options']['mixed_precision']
    assert state_dict['trainer_options']['world_rank'] == 0
    assert state_dict['trainer_options']['world_size'] == 1
    assert state_dict['trainer_options']['horizontal_parallel_size'] == 1
    assert state_dict['trainer_options']['data_parallel_size'] == 1
    assert state_dict['trainer_options']['zero_stage'] == 0
    assert state_dict['trainer_options']['optimizer_name'] == b'Adam'
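
# The load_mock parameter above implies a mock.patch decorator that was
# elided from this excerpt. A sketch of the presumed wiring; the dotted
# patch target is an assumption based on the module layout, not confirmed
# by this excerpt:
#
#     from unittest.mock import patch
#
#     @patch('onnxruntime.training.checkpoint._checkpoint_storage.load')
#     def test_checkpoint_aggregation_mixed_precision(load_mock):
#         ...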