Example #1
def count_zeros_fp32(parameters):
    """Count zero-valued entries in the gradients of ``parameters``,
    summed across all model-parallel GPUs."""

    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]

    # Filter parameters based on:
    #   - grad should not be none
    #   - parameter should not be shared
    #   - should not be a replica due to tensor model parallelism
    # Accumulate in a CUDA tensor so the all_reduce below works even if no parameter qualifies.
    total_num_zeros = torch.cuda.FloatTensor([0.0])
    for param in parameters:
        grad_not_none = param.grad is not None
        is_not_shared = param_is_not_shared(param)
        is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
        if grad_not_none and is_not_shared and is_not_tp_duplicate:
            grad = param.grad.detach()
            num_zeros = grad.numel() - torch.count_nonzero(grad)
            total_num_zeros = num_zeros + total_num_zeros

    # Sum across all model-parallel GPUs.
    torch.distributed.all_reduce(
        total_num_zeros,
        op=torch.distributed.ReduceOp.SUM,
        group=parallel_state.get_model_parallel_group())
    total_num_zeros = total_num_zeros.item()

    return total_num_zeros
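For illustration, a minimal single-process sketch of the same zero-counting idea, without the shared/tensor-parallel filtering and the model-parallel all_reduce; the helper name and the toy layer are hypothetical:

import torch

def count_zero_grads(parameters):
    # Accumulate the number of zero-valued gradient entries across parameters.
    total = 0
    for p in parameters:
        if p.grad is not None:
            g = p.grad.detach()
            total += int(g.numel() - torch.count_nonzero(g))
    return total

layer = torch.nn.Linear(4, 2)
layer(torch.zeros(3, 4)).sum().backward()      # zero inputs make the weight grads all zero
print(count_zero_grads(layer.parameters()))    # 8 zero weight-grad entries; the bias grads are nonzero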
Example #2
def calc_params_l2_norm(model: torch.nn.Module, bf16: bool):
    """Calculate l2 norm of parameters """
    # args = get_args()
    if not isinstance(model, list):
        model = [model]
    # Remove duplicate params.
    params_data = []
    for model_ in model:
        for param in model_.parameters():
            is_not_shared = param_is_not_shared(param)
            is_not_tp_duplicate = parallel_state.param_is_not_tensor_parallel_duplicate(
                param)
            if is_not_shared and is_not_tp_duplicate:
                if bf16:
                    params_data.append(param.data.float())
                else:
                    params_data.append(param.data)
    # Calculate norm
    dummy_overflow_buf = torch.cuda.IntTensor([0])
    norm, _ = multi_tensor_applier(
        amp_C.multi_tensor_l2norm,
        dummy_overflow_buf,
        [params_data],
        False  # no per-parameter norm
    )
    norm_2 = norm * norm
    # Sum across all model-parallel GPUs.
    torch.distributed.all_reduce(
        norm_2,
        op=torch.distributed.ReduceOp.SUM,
        group=parallel_state.get_model_parallel_group())
    return norm_2.item()**0.5
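For comparison, a plain-PyTorch sketch of the same parameter L2 norm, replacing the apex multi_tensor_l2norm kernel with torch.linalg.vector_norm (recent PyTorch) and omitting the model-parallel all_reduce; the helper name and the toy model are illustrative:

import torch

def params_l2_norm(model: torch.nn.Module) -> float:
    # The norm of the per-parameter norms equals the norm of all parameters viewed as one vector.
    per_param = [torch.linalg.vector_norm(p.data.float()) for p in model.parameters()]
    return torch.linalg.vector_norm(torch.stack(per_param)).item()

model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.Linear(4, 2))
print(params_l2_norm(model))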
Example #3
    def restore_weights(self, restore_path: str):
        """Restores module/model's weights.
           For model parallel checkpoints the directory structure
           should be restore_path/mp_rank_0X/model_optim_rng.pt

        Args:
            restore_path (str): restore_path should a file or a directory if using model parallel
        """
        self._restore_path = restore_path

        if os.path.isfile(restore_path):
            self._load_checkpoint(restore_path)
        elif os.path.isdir(restore_path):
            # need model parallel groups to restore model parallel checkpoints
            if model_parallel_is_initialized():
                model_parallel_rank = torch.distributed.get_rank(
                    group=get_model_parallel_group())
                mp_restore_path = f'{restore_path}/mp_rank_{model_parallel_rank:02d}/model_optim_rng.pt'
                self._load_checkpoint(mp_restore_path)
            else:
                logging.info(
                    'Model parallel groups are not initialized yet. Will not restore the model parallel checkpoint.'
                )
        else:
            logging.error(
                f'restore_path: {restore_path} must be a file or directory.')
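A small sketch of the checkpoint path convention described in the docstring above; the directory and rank value are hypothetical, and the rank would normally come from torch.distributed.get_rank on the model parallel group:

import os

restore_path = "/checkpoints/my_model"   # hypothetical checkpoint directory
model_parallel_rank = 3                  # would come from torch.distributed.get_rank(group=...)
mp_restore_path = os.path.join(restore_path, f"mp_rank_{model_parallel_rank:02d}", "model_optim_rng.pt")
print(mp_restore_path)                   # /checkpoints/my_model/mp_rank_03/model_optim_rng.pt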
Example #4
    def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
        """Step the optimizer only if no inf/NaN was found on any model-parallel rank."""
        retval = None
        found_inf = torch.cuda.FloatTensor([sum(v.item() for v in optimizer_state["found_inf_per_device"].values())])

        # Update across all model parallel instances.
        torch.distributed.all_reduce(
            found_inf, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group()
        )

        if found_inf.item() == 0:
            retval = optimizer.step(*args, **kwargs)
        return retval
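A minimal single-process sketch of the skip-on-overflow pattern above, omitting the cross-rank MAX all_reduce; the function and argument names are illustrative:

import torch

def maybe_opt_step(optimizer: torch.optim.Optimizer, found_inf_per_device: dict):
    # Sum the per-device inf/NaN flags; any value > 0 means at least one overflow was found.
    found_inf = sum(float(v) for v in found_inf_per_device.values())
    if found_inf == 0:
        return optimizer.step()
    return None  # skip the step; the scaler's update() will back off the loss scale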
Example #5
    def decode(self, tokens_enc, enc_mask, num_tokens_to_generate, encoder_input=None):
        """Greedily decode up to ``num_tokens_to_generate`` tokens, one token per step."""
        app_state = AppState()
        global_batch_per_gpu = tokens_enc.size(0)

        num_micro_batches_before_decode = get_num_microbatches()
        # Reconfigure microbatch calculator here to set num microbatches to 1 while decoding, since it's not clear how to decode with "grad acc".
        # TODO: reconfigure back to how things were before decode?
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
            micro_batch_size=global_batch_per_gpu,  # Make sure that there is no "grad acc" while decoding.
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )
        predicted_tokens_dec = (
            torch.LongTensor([self.tokenizer.bos_id] * global_batch_per_gpu).unsqueeze(1).to(tokens_enc.device)
        )
        encoder_seq_length = tokens_enc.size(1)
        tensor_shape = [encoder_seq_length, global_batch_per_gpu, self.cfg.hidden_size]
        assert predicted_tokens_dec.size(0) == global_batch_per_gpu

        for i in range(num_tokens_to_generate):
            # No microbatches in decoding. Just the global batch.
            decoder_seq_length = predicted_tokens_dec.size(1)
            dec_mask = predicted_tokens_dec != self.tokenizer.pad_id

            if encoder_input is not None:
                batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask, encoder_input]
            else:
                batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask]
            if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
                output_tensor = forward_backward_pipelining_without_interleaving(
                    forward_step_func=self.get_forward_output_only_func(),
                    batch=batch_for_pipeline,
                    model=self.enc_dec_model,
                    forward_only=True,
                    tensor_shape=tensor_shape,
                    decoder_sequence_length=decoder_seq_length,
                    dtype=self.autocast_dtype,
                )
            else:
                output_tensor = forward_backward_no_pipelining(
                    forward_step_func=self.get_forward_output_only_func(),
                    batch=batch_for_pipeline,
                    model=self.enc_dec_model,
                    forward_only=True,
                    tensor_shape=tensor_shape,
                    decoder_sequence_length=decoder_seq_length,
                    dtype=self.autocast_dtype,
                )

            # get output tensor
            if parallel_state.is_pipeline_last_stage():
                output_tensor = output_tensor[0]['logits']
                output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(output_tensor)
                log_probs, token_ids = torch.max(torch.nn.functional.log_softmax(output_tensor, dim=-1), dim=-1)
                predicted_tokens_dec = torch.cat(
                    [predicted_tokens_dec.to(token_ids.device), token_ids[:, -1].unsqueeze(1)], dim=1
                )
            else:
                log_probs = torch.zeros(
                    (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1]), dtype=self.autocast_dtype
                ).cuda()
                predicted_tokens_dec = torch.zeros(
                    (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1] + 1),
                    dtype=predicted_tokens_dec.dtype,
                ).cuda()

            if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
                # Broadcast from the last pipeline stage to all other model-parallel ranks.
                torch.distributed.broadcast(
                    predicted_tokens_dec,
                    parallel_state.get_pipeline_model_parallel_last_rank(),
                    group=parallel_state.get_model_parallel_group(),
                )
                torch.distributed.broadcast(
                    log_probs,
                    parallel_state.get_pipeline_model_parallel_last_rank(),
                    group=parallel_state.get_model_parallel_group(),
                )

        # Reset microbatch calculator to what it was before decoding.
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
            micro_batch_size=global_batch_per_gpu // num_micro_batches_before_decode,
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )
        return predicted_tokens_dec, log_probs
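A small sketch of the greedy decoding step used inside the loop above: take the argmax of the log-softmax at the last position and append it to the running decoder input. The batch size, vocabulary size, and random logits are stand-ins for the real pipeline output:

import torch

batch, vocab = 2, 10
predicted_tokens_dec = torch.zeros(batch, 1, dtype=torch.long)      # stands in for the BOS ids
logits = torch.randn(batch, predicted_tokens_dec.size(1), vocab)    # stands in for the model logits

log_probs, token_ids = torch.max(torch.nn.functional.log_softmax(logits, dim=-1), dim=-1)
predicted_tokens_dec = torch.cat([predicted_tokens_dec, token_ids[:, -1].unsqueeze(1)], dim=1)
print(predicted_tokens_dec.shape)                                   # torch.Size([2, 2])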
Example #6
    def update(self, new_scale=None):
        """
        Updates the scale factor.

        If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
        to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
        the scale is multiplied by ``growth_factor`` to increase it.

        Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
        used directly, it's used to fill GradScaler's internal scale tensor. So if
        ``new_scale`` was a tensor, later in-place changes to that tensor will not further
        affect the scale GradScaler uses internally.)

        Args:
            new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None):  New scale factor.

        .. warning::
            :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
            been invoked for all optimizers used this iteration.
        """
        if not self._enabled:
            return

        _scale, _growth_tracker = self._check_scale_growth_tracker("update")

        if new_scale is not None:
            # Accept a new user-defined scale.
            if isinstance(new_scale, float):
                self._scale.fill_(new_scale)  # type: ignore[union-attr]
            else:
                reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
                assert isinstance(new_scale, torch.cuda.FloatTensor
                                  ), reason  # type: ignore[attr-defined]
                assert new_scale.numel() == 1, reason
                assert new_scale.requires_grad is False, reason
                self._scale.copy_(new_scale)  # type: ignore[union-attr]
        else:
            # Consume shared inf/nan data collected from optimizers to update the scale.
            # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
            found_infs = [
                found_inf.to(device=_scale.device, non_blocking=True)
                for state in self._per_optimizer_states.values()
                for found_inf in state["found_inf_per_device"].values()
            ]

            assert len(found_infs) > 0, "No inf checks were recorded prior to update."

            found_inf_combined = found_infs[0]

            # Update across all model parallel instances.
            torch.distributed.all_reduce(
                found_inf_combined,
                op=torch.distributed.ReduceOp.MAX,
                group=parallel_state.get_model_parallel_group())

            if len(found_infs) > 1:
                for i in range(1, len(found_infs)):
                    found_inf = found_infs[i]
                    # Update across all model parallel instances.
                    torch.distributed.all_reduce(
                        found_inf,
                        op=torch.distributed.ReduceOp.MAX,
                        group=parallel_state.get_model_parallel_group())
                    found_inf_combined += found_inf

            torch._amp_update_scale_(
                _scale,
                _growth_tracker,
                found_inf_combined,
                self._growth_factor,
                self._backoff_factor,
                self._growth_interval,
            )

        # To prepare for next iteration, clear the data collected from optimizers this iteration.
        self._per_optimizer_states = defaultdict(
            torch.cuda.amp.grad_scaler._refresh_per_optimizer_state)
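For context, a standard single-GPU torch.cuda.amp.GradScaler training loop (requires a CUDA device); the override above plugs into the same scaler.update() call and adds the MAX all_reduce of found_inf across model-parallel ranks. The toy model and optimizer are illustrative:

import torch

model = torch.nn.Linear(8, 2).cuda()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

for _ in range(3):
    opt.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(torch.randn(4, 8, device="cuda")).sum()
    scaler.scale(loss).backward()   # scale the loss before backward
    scaler.step(opt)                # unscales grads, skips the step on inf/NaN
    scaler.update()                 # the method overridden above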
Example #7
def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
    """Clips gradient norm of an iterable of parameters whose gradients
       are in fp32.
    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_, with added
    functionality to handle model parallel parameters. Note that the gradients
    are modified in place.
    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
    Returns:
        Total norm of the parameters (viewed as a single vector).
    """

    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]

    # Filter parameters based on:
    #   - grad should not be none
    #   - parameter should not be shared
    #   - should not be a replica due to tensor model parallelism
    grads = []
    grads_for_norm = []
    for param in parameters:
        grad_not_none = param.grad is not None
        is_not_shared = param_is_not_shared(param)
        is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
        if grad_not_none:
            # Detach only when a gradient exists; calling .detach() on None would raise.
            grad = param.grad.detach()
            # Make sure the grads are in fp32
            assert isinstance(param.grad, torch.cuda.FloatTensor)
            grads.append(grad)
        if grad_not_none and is_not_shared and is_not_tp_duplicate:
            grads_for_norm.append(grad)

    # Norm parameters.
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    total_norm = 0.0

    # Calculate norm.
    if norm_type == inf:
        total_norm = max(grad.abs().max() for grad in grads_for_norm)
        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
        # Take max across all model-parallel GPUs.
        torch.distributed.all_reduce(
            total_norm_cuda,
            op=torch.distributed.ReduceOp.MAX,
            group=parallel_state.get_model_parallel_group())
        total_norm = total_norm_cuda[0].item()

    else:
        if norm_type == 2.0:
            dummy_overflow_buf = torch.cuda.IntTensor([0])
            # Use apex's multi-tensor applier for efficiency reasons.
            # Multi-tensor applier takes a function and a list of list
            # and performs the operation on that list all in one kernel.
            grad_norm, _ = multi_tensor_applier(
                amp_C.multi_tensor_l2norm,
                dummy_overflow_buf,
                [grads_for_norm],
                False  # no per-parameter norm
            )
            # Since we will be summing across data parallel groups,
            # we need the pow(norm-type).
            total_norm = grad_norm**norm_type

        else:
            for grad in grads_for_norm:
                grad_norm = torch.norm(grad, norm_type)
                total_norm += grad_norm**norm_type

        # Sum across all model-parallel GPUs.
        torch.distributed.all_reduce(
            total_norm,
            op=torch.distributed.ReduceOp.SUM,
            group=parallel_state.get_model_parallel_group())
        total_norm = total_norm.item()**(1.0 / norm_type)

    # Scale.
    clip_coeff = max_norm / (total_norm + 1.0e-6)
    if clip_coeff < 1.0:
        dummy_overflow_buf = torch.cuda.IntTensor([0])
        multi_tensor_applier(amp_C.multi_tensor_scale, dummy_overflow_buf,
                             [grads, grads], clip_coeff)

    return total_norm
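A plain-PyTorch sketch of the L2 clipping math above, without the apex multi-tensor kernels or the model-parallel all_reduce; the helper name and the toy layer are illustrative:

import torch

def clip_grads_l2(parameters, max_norm: float) -> float:
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    # L2 norm of all gradients viewed as a single vector.
    total_norm = torch.linalg.vector_norm(torch.stack([torch.linalg.vector_norm(g) for g in grads])).item()
    clip_coeff = max_norm / (total_norm + 1.0e-6)
    if clip_coeff < 1.0:
        for g in grads:
            g.mul_(clip_coeff)      # scale the gradients in place, as above
    return total_norm

layer = torch.nn.Linear(4, 2)
layer(torch.ones(3, 4)).sum().backward()
print(clip_grads_l2(list(layer.parameters()), max_norm=1.0))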
Example #8
    def decode(self, tokens_enc, enc_mask, num_tokens_to_generate, encoder_input=None, tokenizer=None):
        """Greedily decode up to ``num_tokens_to_generate`` tokens, optionally with a custom tokenizer."""
        # Check whether the DDP is initialized. This is needed when running inference outside of training loop.
        if parallel_state.is_unitialized():

            def dummy():
                return

            if self.trainer.strategy.launcher is not None:
                self.trainer.strategy.launcher.launch(dummy, trainer=self.trainer)
            self.trainer.strategy.setup_environment()

            # Reconfigure microbatch sizes here because on model restore, this will contain the micro/global batch configuration used while training.
            _reconfigure_microbatch_calculator(
                rank=0,  # This doesn't matter since it is only used for logging
                rampup_batch_size=None,
                global_batch_size=1,
                micro_batch_size=1,  # Make sure that there is no "grad acc" while decoding.
            data_parallel_size=1,  # We check above to make sure that the data parallel size is always 1 at inference.
            )

        # Allow classes that inherit from this class to pass a different tokenizer; fall back to self.tokenizer otherwise.
        tokenizer = self.tokenizer if tokenizer is None else tokenizer
        app_state = AppState()
        global_batch_per_gpu = tokens_enc.size(0)

        num_micro_batches_before_decode = get_num_microbatches()
        # Reconfigure microbatch calculator here to set num microbatches to 1 while decoding, since it's not clear how to decode with "grad acc".
        # TODO: reconfigure back to how things were before decode?
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
            micro_batch_size=global_batch_per_gpu,  # Make sure that there is no "grad acc" while decoding.
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )
        predicted_tokens_dec = (
            torch.LongTensor([tokenizer.bos_id] * global_batch_per_gpu).unsqueeze(1).to(tokens_enc.device)
        )
        encoder_seq_length = tokens_enc.size(1)
        tensor_shape = [encoder_seq_length, global_batch_per_gpu, self.cfg.hidden_size]
        assert predicted_tokens_dec.size(0) == global_batch_per_gpu

        for i in range(num_tokens_to_generate):
            # No microbatches in decoding. Just the global batch.
            decoder_seq_length = predicted_tokens_dec.size(1)
            dec_mask = predicted_tokens_dec != tokenizer.pad_id

            if encoder_input is not None:
                batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask, encoder_input]
            else:
                batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask]
            if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
                output_tensor = forward_backward_pipelining_without_interleaving(
                    forward_step_func=self.get_forward_output_only_func(),
                    batch=batch_for_pipeline,
                    model=self.enc_dec_model,
                    forward_only=True,
                    tensor_shape=tensor_shape,
                    decoder_sequence_length=decoder_seq_length,
                    dtype=self.autocast_dtype,
                )
            else:
                output_tensor = forward_backward_no_pipelining(
                    forward_step_func=self.get_forward_output_only_func(),
                    batch=batch_for_pipeline,
                    model=self.enc_dec_model,
                    forward_only=True,
                    tensor_shape=tensor_shape,
                    decoder_sequence_length=decoder_seq_length,
                    dtype=self.autocast_dtype,
                )

            # get output tensor
            if parallel_state.is_pipeline_last_stage():
                output_tensor = output_tensor[0]['logits']
                output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(output_tensor)
                log_probs, token_ids = torch.max(torch.nn.functional.log_softmax(output_tensor, dim=-1), dim=-1)
                predicted_tokens_dec = torch.cat(
                    [predicted_tokens_dec.to(token_ids.device), token_ids[:, -1].unsqueeze(1)], dim=1
                )
            else:
                log_probs = torch.zeros(
                    (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1]), dtype=self.autocast_dtype
                ).cuda()
                predicted_tokens_dec = torch.zeros(
                    (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1] + 1),
                    dtype=predicted_tokens_dec.dtype,
                ).cuda()

            if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
                # Broadcast from the last pipeline stage to all other model-parallel ranks.
                torch.distributed.broadcast(
                    predicted_tokens_dec,
                    parallel_state.get_pipeline_model_parallel_last_rank(),
                    group=parallel_state.get_model_parallel_group(),
                )
                torch.distributed.broadcast(
                    log_probs,
                    parallel_state.get_pipeline_model_parallel_last_rank(),
                    group=parallel_state.get_model_parallel_group(),
                )

        # Reset microbatch calculator to what it was before decoding.
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
            micro_batch_size=global_batch_per_gpu // num_micro_batches_before_decode,
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )
        return predicted_tokens_dec, log_probs
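A small arithmetic sketch of the microbatch bookkeeping above: during decoding the micro batch size is set to the per-GPU global batch, giving a single microbatch, and afterwards it is divided back by the pre-decode microbatch count. The numbers are illustrative:

global_batch_per_gpu = 8
data_parallel_size = 4
num_micro_batches_before_decode = 2

global_batch_size = global_batch_per_gpu * data_parallel_size                          # 32
micro_batch_size_during_decode = global_batch_per_gpu                                  # 8 -> 1 microbatch per GPU
micro_batch_size_restored = global_batch_per_gpu // num_micro_batches_before_decode    # 4 -> 2 microbatches per GPU
print(global_batch_size, micro_batch_size_during_decode, micro_batch_size_restored)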
Example #9
    def update(self, new_scale=None):
        """
        Updates to native grad scaler update function.
        1. Check inf across model-parallel ranks.
        2. Update hysteresis tracker.
        3. Apply hysteresis to grad scale update.
        """
        if not self._enabled:
            return

        _scale, _growth_tracker = self._check_scale_growth_tracker("update")

        if new_scale is not None:
            # Accept a new user-defined scale.
            if isinstance(new_scale, float):
                self._scale.fill_(new_scale)  # type: ignore[union-attr]
            else:
                reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
                assert isinstance(new_scale, torch.cuda.FloatTensor
                                  ), reason  # type: ignore[attr-defined]
                assert new_scale.numel() == 1, reason
                assert new_scale.requires_grad is False, reason
                self._scale.copy_(new_scale)  # type: ignore[union-attr]
        else:
            # Consume shared inf/nan data collected from optimizers to update the scale.
            # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
            found_infs = [
                found_inf.to(device=_scale.device, non_blocking=True)
                for state in self._per_optimizer_states.values()
                for found_inf in state["found_inf_per_device"].values()
            ]

            assert len(found_infs) > 0, "No inf checks were recorded prior to update."

            found_inf_combined = found_infs[0]

            # Update across all model parallel instances.
            torch.distributed.all_reduce(
                found_inf_combined,
                op=torch.distributed.ReduceOp.MAX,
                group=parallel_state.get_model_parallel_group())

            if len(found_infs) > 1:
                for i in range(1, len(found_infs)):
                    found_inf = found_infs[i]
                    # Update across all model parallel instances.
                    torch.distributed.all_reduce(
                        found_inf,
                        op=torch.distributed.ReduceOp.MAX,
                        group=parallel_state.get_model_parallel_group())
                    found_inf_combined += found_inf

            if found_inf_combined > 0:
                self._hysteresis_tracker -= 1
                if self._hysteresis_tracker <= 0:
                    # When hysteresis reaches zero, follow the native grad scale update rule:
                    # since an inf was found, decrease the scale and reset the growth tracker.
                    torch._amp_update_scale_(
                        _scale,
                        _growth_tracker,
                        found_inf_combined,
                        self._growth_factor,
                        self._backoff_factor,
                        self._growth_interval,
                    )
                else:
                    # Only reset the growth tracker when hysteresis is larger than zero
                    _growth_tracker.fill_(0.0)
            else:
                # When no inf found, follow the native grad scale update rule.
                # Increment growth_tracker, update scale when growth tracker reaches the interval, and
                # reset the hysteresis tracker.
                torch._amp_update_scale_(
                    _scale,
                    _growth_tracker,
                    found_inf_combined,
                    self._growth_factor,
                    self._backoff_factor,
                    self._growth_interval,
                )
                self._hysteresis_tracker = self.hysteresis

        # To prepare for next iteration, clear the data collected from optimizers this iteration.
        self._per_optimizer_states = defaultdict(
            torch.cuda.amp.grad_scaler._refresh_per_optimizer_state)
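A tiny pure-Python simulation of the hysteresis rule implemented above: the loss scale is only backed off after "hysteresis" consecutive overflow steps, and any clean step resets the tracker (scale growth is omitted here). The values are illustrative:

scale, hysteresis, backoff = 65536.0, 2, 0.5
tracker = hysteresis
for step_had_inf in [True, False, True, True, False]:
    if step_had_inf:
        tracker -= 1
        if tracker <= 0:
            scale *= backoff        # back off only once hysteresis is exhausted
    else:
        tracker = hysteresis        # a clean step resets the hysteresis tracker
    print(step_had_inf, tracker, scale)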