Example #1
    def _clip_gradients(self, optimizer, clip_val=None):
        """ Override of PTL Gradient Clipping.
            Enables model parallel gradient clipping from Megatron-LM.

        Args:
            optimizer: optimizer whose gradients are clipped; only used by the
                non-model-parallel fallback to Accelerator._clip_gradients.
            clip_val (float, optional): max gradient norm. Defaults to None, in which
                case the trainer's gradient_clip_val is used.
        """
        app_state = AppState()

        # get clip_val from trainer if None is provided
        if clip_val is None:
            clip_val = float(self._trainer.gradient_clip_val)

        if app_state.model_parallel_size is not None:
            model = self._trainer.get_model()
            parameters = model.parameters()
            if mpu.model_parallel_is_initialized():
                mpu.grads.clip_grad_norm(parameters=parameters,
                                         max_norm=clip_val)
            else:
                raise ValueError(
                    'Model parallel groups must be initialized to use model parallel gradient clipping.'
                )

        else:
            return Accelerator._clip_gradients(self, optimizer, clip_val)
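In the non-model-parallel branch the call simply falls through to PyTorch Lightning's accelerator clipping. In the model-parallel branch each rank only holds a shard of the weights, so the gradient norm has to be accumulated across the model-parallel group before clipping, which is what mpu.grads.clip_grad_norm does. A minimal sketch of that idea (not the actual Megatron-LM implementation; the group handle and the handling of replicated parameters are simplified assumptions):

import torch
import torch.distributed as dist

def clip_grad_norm_model_parallel(parameters, max_norm, mp_group):
    """Sketch: clip by an L2 gradient norm accumulated across a model-parallel group."""
    grads = [p.grad for p in parameters if p.grad is not None]
    # Local contribution to the squared norm. A real implementation also skips
    # parameters that are replicated on every model-parallel rank so they are
    # not counted more than once.
    total_sq = torch.zeros(1, device=grads[0].device)
    for g in grads:
        total_sq += g.float().norm() ** 2
    # Combine the shards held by the other model-parallel ranks.
    dist.all_reduce(total_sq, op=dist.ReduceOp.SUM, group=mp_group)
    total_norm = total_sq.sqrt().item()
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        for g in grads:
            g.mul_(clip_coef)
    return total_norm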
Example #2
    def restore_weights(self, restore_path: str):
        """Restores module/model's weights.
           For model parallel checkpoints the directory structure
           should be restore_path/mp_rank_0X/model_optim_rng.pt

        Args:
            restore_path (str): restore_path should be a file, or a directory if using model parallel
        """
        self._restore_path = restore_path

        if os.path.isfile(restore_path):
            self._load_checkpoint(restore_path)
        elif os.path.isdir(restore_path):
            # need model parallel groups to restore model parallel checkpoints
            if model_parallel_is_initialized():
                model_parallel_rank = torch.distributed.get_rank(
                    group=get_model_parallel_group())
                mp_restore_path = f'{restore_path}/mp_rank_{model_parallel_rank:02d}/model_optim_rng.pt'
                self._load_checkpoint(mp_restore_path)
            else:
                logging.info(
                    'torch.distributed not initialized yet. Will not restore model parallel checkpoint'
                )
        else:
            logging.error(
                f'restore_path: {restore_path} must be a file or directory.')
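A usage sketch of the layout this expects on disk (the paths, the model-parallel size of 2, and the module instance are illustrative, not taken from the source):

# /ckpts/my_model/                           <- pass this directory when using model parallel
# ├── mp_rank_00/model_optim_rng.pt          <- loaded by model-parallel rank 0
# └── mp_rank_01/model_optim_rng.pt          <- loaded by model-parallel rank 1
# Passing a single checkpoint file instead restores it without model parallelism.
module.restore_weights('/ckpts/my_model')                 # directory: per-rank restore
module.restore_weights('/ckpts/my_model/checkpoint.pt')   # file: plain restore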
Example #3
    def restore_weights(self, restore_path: str):
        """Restores module/model's weights.
           For model parallel checkpoints the directory structure
           should be restore_path/mp_rank_0X/model_optim_rng.pt

        Args:
            restore_path (str): restore_path should be a file, or a directory if using model parallel
        """
        self._restore_path = restore_path
        if os.path.isfile(restore_path):
            logging.info(f'restore_path: {restore_path} is a file. Assuming no megatron model parallelism')
            state_dict = torch.load(restore_path)
            # to load from Megatron pretrained checkpoint
            if 'model' in state_dict:
                self.language_model.load_state_dict(state_dict['model'][self._language_model_key])
            else:
                self.load_state_dict(state_dict)
            logging.info(f"weights restored from {restore_path}")
        elif os.path.isdir(restore_path):
            # need model parallel groups to restore model parallel checkpoints
            if model_parallel_is_initialized():
                model_parallel_rank = torch.distributed.get_rank(group=get_model_parallel_group())
                mp_restore_path = f'{restore_path}/mp_rank_{model_parallel_rank:02d}/model_optim_rng.pt'
                logging.info(f'Restoring model parallel checkpoint from: {mp_restore_path}')
                state_dict = torch.load(mp_restore_path)
                # to load from Megatron pretrained checkpoint
                if 'model' in state_dict:
                    self.language_model.load_state_dict(state_dict['model'][self._language_model_key])
                else:
                    self.load_state_dict(state_dict)
            else:
                logging.info('torch.distributed not initialized yet. Will not restore model parallel checkpoint')
        else:
            logging.error(f'restore_path: {restore_path} must be a file or directory.')
Example #4
def _initialize_distributed():
    """Initialize torch.distributed and mpu."""
    args = get_args()

    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():

        if args.rank == 0:
            print(
                'torch distributed is already initialized, '
                'skipping initialization ...',
                flush=True)
        args.rank = torch.distributed.get_rank()
        args.world_size = torch.distributed.get_world_size()

    else:

        if args.rank == 0:
            print('> initializing torch distributed ...', flush=True)
        # Manually set the device ids.
        if device_count > 0:
            device = args.rank % device_count
            if args.local_rank is not None:
                assert args.local_rank == device, \
                    'expected local-rank to be the same as rank % device-count.'
            else:
                args.local_rank = device
            torch.cuda.set_device(device)
        # Call the init process
        init_method = 'tcp://'
        master_ip = os.getenv('MASTER_ADDR', 'localhost')
        master_port = os.getenv('MASTER_PORT', '6000')
        init_method += master_ip + ':' + master_port
        torch.distributed.init_process_group(backend=args.distributed_backend,
                                             world_size=args.world_size,
                                             rank=args.rank,
                                             init_method=init_method)

    # Set the model-parallel / data-parallel communicators.
    if device_count > 0:
        if mpu.model_parallel_is_initialized():
            print('model parallel is already initialized')
        else:
            mpu.initialize_model_parallel(args.model_parallel_size)

    # Optional DeepSpeed Activation Checkpointing Features
    if args.deepspeed and args.deepspeed_activation_checkpointing:
        setup_deepspeed_random_and_activation_checkpointing(args)
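The rendezvous above is the plain TCP init method driven by MASTER_ADDR and MASTER_PORT. A minimal standalone sketch of the same handshake (the backend choice and environment defaults are assumptions):

import os
import torch

master_ip = os.getenv('MASTER_ADDR', 'localhost')
master_port = os.getenv('MASTER_PORT', '6000')
torch.distributed.init_process_group(
    backend='nccl',                                # assumed; 'gloo' works for CPU-only tests
    init_method=f'tcp://{master_ip}:{master_port}',
    world_size=int(os.getenv('WORLD_SIZE', '1')),
    rank=int(os.getenv('RANK', '0')),
)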
Example #5
def _initialize_distributed():
    """Initialize torch.distributed and mpu."""
    args = get_args()

    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():

        if args.rank == 0:
            print(
                'torch distributed is already initialized, '
                'skipping initialization ...',
                flush=True)
        args.rank = torch.distributed.get_rank()
        args.world_size = torch.distributed.get_world_size()

    else:

        if args.rank == 0:
            print('> initializing torch distributed ...', flush=True)
        # Manually set the device ids.
        if device_count > 0:
            device = args.rank % device_count
            if args.local_rank is not None:
                assert args.local_rank == device, \
                    'expected local-rank to be the same as rank % device-count.'
            else:
                args.local_rank = device
            torch.cuda.set_device(device)
        # Call the init process
        torch.distributed.init_process_group(backend=args.distributed_backend,
                                             world_size=args.world_size,
                                             rank=args.rank,
                                             timeout=timedelta(minutes=10))

    # Set the tensor model-parallel, pipeline model-parallel, and
    # data-parallel communicators.
    if device_count > 0:
        if mpu.model_parallel_is_initialized():
            print('model parallel is already initialized')
        else:
            mpu.initialize_model_parallel(
                args.tensor_model_parallel_size,
                args.pipeline_model_parallel_size,
                args.virtual_pipeline_model_parallel_size,
                args.pipeline_model_parallel_split_rank)
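A worked example of the decomposition behind initialize_model_parallel (the numbers are illustrative): a world size of 16 with tensor_model_parallel_size=2 and pipeline_model_parallel_size=4 leaves 16 // (2 * 4) = 2 data-parallel replicas.

world_size = 16                     # assumed
tensor_model_parallel_size = 2      # assumed
pipeline_model_parallel_size = 4    # assumed
assert world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) == 0
data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)
print(data_parallel_size)  # 2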
Example #6
def _initialize_distributed():
    """Initialize torch.distributed and mpu."""
    args = get_args()

    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():

        if args.rank == 0:
            print('torch distributed is already initialized, '
                  'skipping initialization ...', flush=True)
        args.rank = torch.distributed.get_rank()
        args.world_size = torch.distributed.get_world_size()

    else:

        if args.rank == 0:
            print('> initializing torch distributed ...', flush=True)
        # Manually set the device ids.
        if device_count > 0:
            device = args.rank % device_count
            if args.local_rank is not None:
                assert args.local_rank == device, \
                    'expected local-rank to be the same as rank % device-count.'
            else:
                args.local_rank = device
            torch.cuda.set_device(device)
            
        distributed.init_distributed(
            dist_backend=args.distributed_backend,
            auto_mpi_discovery=True,
            distributed_port=os.getenv('MASTER_PORT', '6000'),
            verbose=True,
        )

    # Setup 3D topology.
    if args.pipe_parallel_size > 0:
        pp = args.pipe_parallel_size
        mp = args.model_parallel_size
        assert args.world_size % (pp * mp) == 0
        dp = args.world_size // (pp * mp)

        from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology
        # this does pipe on the most outside, then data, then model. 
        # PipeModelDataParallelTopology is just a wrapper over ProcessTopology that predefines this order.
        topo = PipeModelDataParallelTopology(num_pp=pp, num_mp=mp, num_dp=dp)

        # Offset base seeds for the interior pipeline stages.
        # TODO: adjust last stage too once IO is improved.
        stage_id = topo.get_coord(rank=torch.distributed.get_rank()).pipe
        if 0 < stage_id < topo.get_dim('pipe') - 1:
            offset = args.seed + 1138
            args.seed = offset + (stage_id * mp)
    else:
        topo = None

    # Set the model-parallel / data-parallel communicators.
    if device_count > 0:
        if mpu.model_parallel_is_initialized():
            print('model parallel is already initialized')
        else:
            mpu.initialize_model_parallel(args.model_parallel_size, topology=topo)

    # Optional DeepSpeed Activation Checkpointing Features
    if args.deepspeed and args.deepspeed_activation_checkpointing:
        setup_deepspeed_random_and_activation_checkpointing(args)
Example #7
def _initialize_distributed(neox_args):
    """Initialize torch.distributed and mpu."""

    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():

        if neox_args.rank == 0:
            print(
                "torch distributed is already initialized, "
                "skipping initialization ...",
                flush=True,
            )
        neox_args.rank = torch.distributed.get_rank()
        neox_args.world_size = torch.distributed.get_world_size()

    else:

        if neox_args.rank == 0:
            print("> initializing torch distributed ...", flush=True)
        # Manually set the device ids.
        if device_count > 0:
            device = neox_args.rank % device_count
            if neox_args.local_rank is not None:
                assert (
                    neox_args.local_rank == device
                ), "expected local-rank to be the same as rank % device-count."
            else:
                neox_args.local_rank = device
            torch.cuda.set_device(device)

        distributed.init_distributed(
            dist_backend=neox_args.distributed_backend,
            auto_mpi_discovery=True,
            distributed_port=os.getenv("MASTER_PORT", "6000"),
            verbose=True,
        )

    # Setup 3D topology.
    pp = neox_args.pipe_parallel_size if neox_args.pipe_parallel_size >= 1 else 1
    mp = neox_args.model_parallel_size if neox_args.model_parallel_size >= 1 else 1
    assert (
        neox_args.world_size %
        (pp * mp) == 0), f"world_size={neox_args.world_size}, pp={pp}, mp={mp}"
    dp = neox_args.world_size // (pp * mp)

    from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology

    # this does pipe on the most outside, then data, then model.
    # PipeModelDataParallelTopology is just a wrapper over ProcessTopology that predefines this order.
    topo = PipeModelDataParallelTopology(num_pp=pp, num_mp=mp, num_dp=dp)

    # Offset base seeds for the interior pipeline stages.
    # TODO: adjust last stage too once IO is improved.
    stage_id = topo.get_coord(rank=torch.distributed.get_rank()).pipe
    if 0 < stage_id < topo.get_dim("pipe") - 1:
        offset = neox_args.seed + 1138
        neox_args.seed = offset + (stage_id * mp)

    # Set the model-parallel / data-parallel communicators.
    if device_count > 0:
        if mpu.model_parallel_is_initialized():
            print(
                "_initialize_distributed() model parallel is already initialized",
                flush=True,
            )
        else:
            mpu.initialize_model_parallel(
                neox_args.model_parallel_size,
                topology=topo,
                fp32_allreduce=neox_args.fp32_allreduce,
            )

    # Init DeepSpeed Activation Checkpointing Features
    setup_deepspeed_random_and_activation_checkpointing(neox_args=neox_args)
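An illustrative walk-through of the seed offsetting above, using the same topology helper (the parallel sizes and base seed are assumptions):

from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology

pp, mp, dp = 4, 2, 2                  # assumed sizes; world size = pp * mp * dp = 16
base_seed = 1234                      # assumed base seed
topo = PipeModelDataParallelTopology(num_pp=pp, num_mp=mp, num_dp=dp)

for rank in range(pp * mp * dp):
    stage_id = topo.get_coord(rank=rank).pipe
    seed = base_seed
    # only interior pipeline stages (neither first nor last) get the offset seed
    if 0 < stage_id < topo.get_dim('pipe') - 1:
        seed = base_seed + 1138 + stage_id * mp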
Example #8
    def restore_weights(self, restore_path: str):
        """Restores module/model's weights.
           For model parallel checkpoints the directory structure
           should be restore_path/mp_rank_0X/model_optim_rng.pt

        Args:
            restore_path (str): restore_path should be a file, or a directory if using model parallel
        """
        self._restore_path = restore_path
        if os.path.isfile(restore_path):
            logging.info(
                f'restore_path: {restore_path} is a file. Assuming no megatron model parallelism'
            )
            state_dict = torch.load(restore_path, map_location='cpu')
            if 'checkpoint_version' in state_dict:
                if state_dict['checkpoint_version'] is not None:
                    set_checkpoint_version(state_dict['checkpoint_version'])
            else:
                logging.warning(
                    'Megatron-lm checkpoint version not found. Setting checkpoint_version to 0.'
                )
                set_checkpoint_version(0)
            # to load from Megatron pretrained checkpoint
            if 'model' in state_dict:
                self.language_model.load_state_dict(
                    state_dict['model'][self._language_model_key])
            else:
                self.load_state_dict(state_dict)
            logging.info(f"weights restored from {restore_path}")
        elif os.path.isdir(restore_path):
            # TODO: need to refactor this so we're not repeating code

            # need model parallel groups to restore model parallel checkpoints
            if model_parallel_is_initialized():
                model_parallel_rank = torch.distributed.get_rank(
                    group=get_model_parallel_group())
                mp_restore_path = f'{restore_path}/mp_rank_{model_parallel_rank:02d}/model_optim_rng.pt'
                logging.info(
                    f'Restoring model parallel checkpoint from: {mp_restore_path}'
                )
                state_dict = torch.load(mp_restore_path, map_location='cpu')
                if 'checkpoint_version' in state_dict:
                    if state_dict['checkpoint_version'] is not None:
                        set_checkpoint_version(
                            state_dict['checkpoint_version'])
                else:
                    logging.warning(
                        'Megatron-lm checkpoint version not found. Setting checkpoint_version to 0.'
                    )
                    set_checkpoint_version(0)
                # to load from Megatron pretrained checkpoint
                if 'model' in state_dict:
                    self.language_model.load_state_dict(
                        state_dict['model'][self._language_model_key])
                else:
                    self.load_state_dict(state_dict)
            else:
                logging.info(
                    'torch.distributed not initialized yet. Will not restore model parallel checkpoint'
                )
        else:
            logging.error(
                f'restore_path: {restore_path} must be a file or directory.')
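A short sketch of inspecting a checkpoint to see which branch above will apply (the path is assumed):

import torch

state_dict = torch.load('mp_rank_00/model_optim_rng.pt', map_location='cpu')  # assumed path
# Megatron-LM training checkpoints nest weights under 'model' and record a
# 'checkpoint_version'; a plain module checkpoint is just the state dict itself.
print('checkpoint_version' in state_dict, 'model' in state_dict)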