def split_state_dict(self,
                         mp_world_size,
                         mp_rank,
                         quantize=False,
                         quantize_bits=8,
                         groups=64,
                         mlp_extra_grouping=True):
        #self.sanity_check(self.ckpt_list[0])

        sd, num_to_split, ckpt_offset = self.get_split_state_dict(
            mp_world_size, mp_rank)
        ds_sd = copy.deepcopy(sd)
        new_client_sd = collections.OrderedDict()

        client_sd = self.get_module(sd)

        ckpt_ver = self.get_checkpoint_version(ds_sd)
        logger.info(f"checkpoint version: {ckpt_ver}")

        if quantize:
            quantizer = WeightQuantization(
                mlp_extra_grouping=mlp_extra_grouping, mp_size=mp_world_size)

        for key in client_sd.keys():
            value = client_sd[key]

            if "attention.dense.weight" in key or "mlp.dense_4h_to_h.weight" in key:
                assert value.shape[1] % num_to_split == 0
                split_size = value.shape[1] // num_to_split
                if quantize:
                    q_vals = quantizer.Quantize([value], quantize_bits, groups,
                                                key)
                    value = q_vals[0]
                new_client_sd[key] = torch.split(value, split_size,
                                                 dim=1)[ckpt_offset]
            elif "attention.query_key_value" in key:
                if quantize and "attention.query_key_value.weight" in key:
                    q_vals = quantizer.Quantize([value], quantize_bits, groups,
                                                key)
                    value = q_vals[0]
                new_client_sd[key] = self.split_query_key_value(
                    value, num_to_split, ckpt_offset, ckpt_ver)
            elif "mlp.dense_h_to_4h.weight" in key or "word_embeddings.weight" in key or "mlp.dense_h_to_4h.bias" in key or "final_linear.weight" in key:
                assert value.shape[0] % num_to_split == 0
                split_size = value.shape[0] // num_to_split
                if quantize and "mlp.dense_h_to_4h.weight" in key:
                    q_vals = quantizer.Quantize([value], quantize_bits, groups,
                                                key)
                    value = q_vals[0]
                new_client_sd[key] = torch.split(value, split_size,
                                                 dim=0)[ckpt_offset]
            else:
                new_client_sd[key] = value

        if quantize:
            all_scales = quantizer.merge_scales_split(num_to_split)

        ds_sd = self.set_module(ds_sd, new_client_sd)

        return ds_sd, (all_scales if quantize else None)
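
A minimal sketch (plain PyTorch; shard_weight is an illustrative helper, not part of DeepSpeed) of the slicing rule used above: row-parallel weights such as attention.dense.weight are sliced along dim=1, while column-parallel weights such as mlp.dense_h_to_4h.weight are sliced along dim=0.

import torch

def shard_weight(weight, num_to_split, ckpt_offset, dim):
    # Split `weight` into `num_to_split` equal chunks along `dim` and keep the
    # chunk that belongs to this rank, mirroring the torch.split calls above.
    assert weight.shape[dim] % num_to_split == 0
    split_size = weight.shape[dim] // num_to_split
    return torch.split(weight, split_size, dim=dim)[ckpt_offset]

w = torch.randn(8, 12)
print(shard_weight(w, 4, 1, dim=1).shape)  # torch.Size([8, 3])  (attention.dense.weight style)
print(shard_weight(w, 4, 1, dim=0).shape)  # torch.Size([2, 12]) (mlp.dense_h_to_4h.weight style)
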
Example #2
    def load_state_dir(self, load_dir, strict=True):
        for idx, layer in enumerate(self.forward_funcs):
            # Functions, etc. will not have state_dicts
            if not hasattr(layer, 'load_state_dict'):
                continue

            # get all checkpoint files for the layer.
            model_ckpt_list = self.ckpt_layer_path_list(load_dir, idx)
            mp_rank = self._grid.get_slice_parallel_rank()
            mp_world_size = self._grid.get_slice_parallel_world_size()

            sd_loader = SDLoaderFactory.get_sd_loader(model_ckpt_list,
                                                      version=2.0)
            load_path, checkpoint, _ = sd_loader.load(mp_world_size,
                                                      mp_rank,
                                                      module_key=None,
                                                      is_pipe_parallel=True)

            layer.load_state_dict(checkpoint)

            if self._grid.data_parallel_id == 0:
                logger.info(
                    f'RANK={self.global_rank} Loaded layer={idx+self._local_start} file={load_path}'
                )

        self._synchronize_tied_weights()
 def save(self, state_dict, path: str):
     tag = _get_tag_from_path(path)
     partition_name = os.path.basename(path)
     logger.info(f"[Nebula] Saving {partition_name} under tag {tag}...")
     self.checkpoint.save(partition_name, state_dict)
     logger.info(f"[Nebula] Saved {partition_name} under tag {tag}.")
     return None
    def __init__(self,
                 gate: Module,
                 experts: Module,
                 num_local_experts: int,
                 group: Optional[Any] = None,
                 use_tutel: bool = False) -> None:
        super().__init__()
        self.gate = gate
        self.experts = experts
        self.group = group
        self.world_size = dist.get_world_size(group)
        self.num_local_experts = num_local_experts
        self.time_falltoall = 0.0
        self.time_salltoall = 0.0
        self.time_moe = 0.0
        self.timers = SynchronizedWallClockTimer()
        self.wall_clock_breakdown = False

        self.use_tutel = use_tutel and TUTEL_INSTALLED

        if self.use_tutel:
            logger.info('Using Tutel optimizations.')
        elif use_tutel and not TUTEL_INSTALLED:
            logger.warning("Tutel optimization requested but not installed. "
                           "Proceeding without Tutel.")
 def override_loss_scale(self, loss_scale):
     if loss_scale != self.external_loss_scale:
         logger.info(
             f'[deepspeed] setting loss scale from {self.external_loss_scale} -> {loss_scale}'
         )
     self.custom_loss_scaler = True
     self.external_loss_scale = loss_scale
    def merge_state_dict(self,
                         mp_world_size,
                         mp_rank,
                         quantize=False,
                         quantize_bits=8,
                         groups=64,
                         mlp_extra_grouping=True):
        self.sanity_check(self.ckpt_list[0])

        sd_list = self.get_merge_state_dicts(mp_world_size, mp_rank)
        ds_sd = copy.deepcopy(sd_list[0])
        new_client_sd = collections.OrderedDict()

        client_sd_list = [self.get_module(sd) for sd in sd_list]
        keys = client_sd_list[0].keys()

        ckpt_ver = self.get_checkpoint_version(ds_sd)
        logger.info(f"checkpoint version: {ckpt_ver}")
        if quantize:
            quantizer = WeightQuantization(
                mlp_extra_grouping=mlp_extra_grouping, mp_size=mp_world_size)

        for key in keys:
            value_list = [sd[key] for sd in client_sd_list]

            if "attention.dense.weight" in key or "mlp.dense_4h_to_h.weight" in key:
                if quantize:
                    value_list = quantizer.Quantize(value_list,
                                                    quantize_bits,
                                                    groups,
                                                    key=key,
                                                    merge_dim=1)
                new_client_sd[key] = torch.cat(value_list, axis=1)
            elif "attention.query_key_value" in key:
                if quantize and "attention.query_key_value.weight" in key:
                    value_list = quantizer.Quantize(value_list,
                                                    quantize_bits,
                                                    groups,
                                                    key=key)
                    new_client_sd[key] = torch.cat(value_list, axis=0)
                else:
                    if quantize:
                        new_client_sd[key] = torch.cat(value_list, axis=0)
                    else:
                        new_client_sd[key] = self.merge_query_key_value(
                            value_list, ckpt_ver)
            elif "mlp.dense_h_to_4h.weight" in key or "word_embeddings.weight" in key or "mlp.dense_h_to_4h.bias" in key:
                if quantize and "mlp.dense_h_to_4h.weight" in key:
                    value_list = quantizer.Quantize(value_list,
                                                    quantize_bits,
                                                    groups,
                                                    key=key)
                new_client_sd[key] = torch.cat(value_list, axis=0)
            else:
                new_client_sd[key] = value_list[0]
        if quantize:
            all_scales = quantizer.merge_scales()
        ds_sd = self.set_module(ds_sd, new_client_sd)

        return ds_sd, (all_scales if quantize else None), len(client_sd_list)
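
Merging is the inverse of the split above: per-rank shards are concatenated back along the same dimension. A small round-trip check (plain PyTorch, illustrative only):

import torch

full = torch.randn(8, 12)
shards = list(torch.split(full, 12 // 4, dim=1))  # what 4 tensor-parallel ranks would each hold
merged = torch.cat(shards, dim=1)                 # what merge_state_dict does per key (axis=1 here)
assert torch.equal(full, merged)
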
Example #7
    def __init__(self,
                 gate: Module,
                 experts: Module,
                 ep_group_name,
                 ep_size,
                 num_local_experts: int,
                 use_tutel: bool = False) -> None:
        super().__init__()
        self.gate = gate
        self.experts = experts
        self.ep_group = None
        self.ep_size = ep_size
        self.ep_group_name = ep_group_name
        self.num_local_experts = num_local_experts
        self.time_falltoall = 0.0
        self.time_salltoall = 0.0
        self.time_moe = 0.0
        self.timers = SynchronizedWallClockTimer()
        self.wall_clock_breakdown = False

        self.use_tutel = use_tutel and TUTEL_INSTALLED and gate.k == 1

        if self.use_tutel:
            logger.info('Using Tutel optimizations.')
        elif use_tutel and not TUTEL_INSTALLED:
            logger.warning("Tutel optimization requested but not installed. "
                           "Proceeding without Tutel.")
        elif use_tutel and TUTEL_INSTALLED and gate.k != 1:
            logger.warning(
                "To enable Tutel optimization, use top-1 instead of top-2 gate. "
                "Proceeding without Tutel.")
    def load(self,
             mp_world_size,
             mp_rank,
             module_key=AUTO_MODULE_KEY,
             is_pipe_parallel=False,
             quantize=False,
             quantize_bits=8,
             quantize_groups=64,
             mlp_extra_grouping=True):
        self.module_key = module_key
        num_ckpt = len(self.ckpt_list)
        idx = mp_rank * num_ckpt // mp_world_size

        logger.info(
            f'mp_world_size: {mp_world_size}, mp_rank: {mp_rank}, module_key: {module_key}'
        )
        """ We have multiple cases to handle here for both training and inference:
            1. PipeModule loading mp_rank_*.pt files, is_pipe_parallel=True, module_key is not None
                a. if no mp_size/pp_size resizing occurs, for both training & inference, loading
                   the mp_rank related checkpoint directly.
                b. if has mp_size/pp_size resizing, only Megatron model inference is supported,
                   in this case each mp_rank_*.pt have same content, we will load the first checkpoint
                   file (idx=0), to avoid idx exceeding file list boundary.

            2. PipeModule loading layer_*.pt files, is_pipe_parallel=True, module_key is None
                a. if no mp_size resizing occurs, for both training & inference, loading
                   the mp_rank related checkpoint directly.
                b. if has mp_size resizing, only Megatron model inference is supported,
                   checkpoint file(s) will be merged/splitted according to mp_rank, mp_world_size and
                   checkpoint file list.

            3. Non-PipeModule loading mp_rank_*.pt files, is_pipe_parallel=False
                Same with case (2).
        """
        if is_pipe_parallel and module_key is not None and mp_world_size != num_ckpt:
            mp_world_size = num_ckpt
            idx = 0

        load_path = self.ckpt_list[idx]

        merge_count = 1
        if num_ckpt == mp_world_size:
            assert os.path.exists(load_path)
            logger.info(f'rank: {mp_rank} loading checkpoint: {load_path}')
            sd = torch.load(load_path, map_location=lambda storage, loc: storage)

            if quantize:
                quantizer = WeightQuantization(mlp_extra_grouping=mlp_extra_grouping,
                                               mp_size=mp_world_size)
                sd_module, all_scales = quantizer.sd_quantize_megatron(self.get_module(sd), quantize_bits, quantize_groups)
                self.set_module(sd, sd_module)
            else:
                all_scales = None
        elif num_ckpt > mp_world_size:
            sd, all_scales, merge_count = self.merge_state_dict(mp_world_size, mp_rank, quantize, \
                quantize_bits, quantize_groups, mlp_extra_grouping)
        else:
            sd, all_scales = self.split_state_dict(mp_world_size, mp_rank, quantize, quantize_bits, \
                quantize_groups, mlp_extra_grouping)
        return load_path, sd, (all_scales, merge_count)
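
Ignoring the pipe-parallel special case, the choice between a direct load, a merge, and a split depends only on how the number of saved checkpoint files compares to the target model-parallel world size. A pure-Python sketch of that dispatch (choose_load_mode is a hypothetical name, for illustration):

def choose_load_mode(num_ckpt, mp_world_size):
    # Mirrors the branching in load(): equal counts load one file per rank,
    # more files than ranks triggers a merge, fewer triggers a split.
    if num_ckpt == mp_world_size:
        return 'direct'
    elif num_ckpt > mp_world_size:
        return 'merge'
    else:
        return 'split'

print(choose_load_mode(4, 4))  # direct
print(choose_load_mode(8, 2))  # merge
print(choose_load_mode(2, 8))  # split
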
Example #9
    def tune(self, sample_size=1, n_trials=1000, early_stopping=None):
        i = 0
        try:
            while i < n_trials and self.has_next():
                # Select the next batch of configurations for evaluation
                sampled_exps = self.next_batch(sample_size)
                # Generate experiments for measurement of performance
                exp_paths = write_experiments(sampled_exps, self.rm.exps_dir)
                self.rm.schedule_experiments(exp_paths)
                self.rm.run()
                exp, metric_val = self.rm.parse_results(self.metric)
                if self.best_exp is None or self.best_metric_val is None or (
                        metric_val and metric_val > self.best_metric_val):
                    # logger.info(f"tuner finds better = {exp}")
                    self.best_exp = exp
                    self.best_metric_val = metric_val
                    self.best_iter = i

                i += len(sampled_exps)

                # Update the tuner with evaluated performance results
                self.update()

                self.rm.clear()

                # Early stop if no more promising configurations are likely to be found
                if early_stopping and i >= self.best_iter + early_stopping:
                    logger.info(
                        f"Tuner early stopped at iteration {i}. Best iteration is {self.best_iter}. Early stopping threshold is {early_stopping}"
                    )
                    break
            return i
        except Exception:
            logger.info(f"Tuner error: {sys.exc_info()[0]}")
            return i
Example #10
    def step_fused_lamb(self, closure=None):
        """
        Not supporting closure.
        """
        # First compute the norm for all groups so we know if there is an overflow
        grads_groups_flat = []
        grads_groups = []
        norm_groups = []
        expert_norm_groups = []
        for i, group in enumerate(self.fp16_groups):
            grads = [
                torch.zeros(p.size(), dtype=p.dtype, device=p.device)
                if p.grad is None else p.grad for p in group
            ]
            grads_groups.append(grads)
            grads_groups_flat.append(_flatten_dense_tensors(grads))
            grads_for_norm, expert_grads_for_norm = split_params_grads_into_shared_and_expert_params(
                group)
            norm_group_value = 0.0
            if len(grads_for_norm) > 0:
                norm_group_value = get_weight_norm(
                    _flatten_dense_tensors(grads_for_norm), mpu=self.mpu)
            norm_groups.append(norm_group_value)
            expert_norm_group_value = 0.0
            if len(expert_grads_for_norm) > 0:
                expert_norm_group_value = get_weight_norm(
                    _flatten_dense_tensors(expert_grads_for_norm),
                    mpu=self.mpu)
            expert_norm_groups.append(expert_norm_group_value)

        self.overflow = self.overflow_checker.check_using_norm(
            norm_groups + expert_norm_groups)
        prev_scale = self.cur_scale

        self._update_scale(self.overflow)
        if self.overflow:
            if self.verbose:
                logger.info(
                    "[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
                    "scale: {}, reducing to {}".format(prev_scale,
                                                       self.cur_scale))
            return self.overflow

        combined_scale = self.unscale_and_clip_grads(norm_groups,
                                                     apply_scale=False)
        self.optimizer.step(grads=grads_groups,
                            output_params=self.fp16_groups,
                            scale=combined_scale)

        for fp32_group, fp16_group in zip(self.fp32_groups, self.fp16_groups):
            for idx, (fp32_param,
                      fp16_param) in enumerate(zip(fp32_group, fp16_group)):

                #remove the fp32 grad
                fp32_param.grad = None

                #copy data from fp32 to fp16
                fp16_param.data.copy_(fp32_param.data)

        return self.overflow
Example #11
def partition_activations_in_checkpoint(partition_activation):
    global PARTITION_ACTIVATIONS
    PARTITION_ACTIVATIONS = partition_activation
    if dist.get_rank() == 0:
        logger.info(
            f"**************Partition Activations {PARTITION_ACTIVATIONS}************"
        )
Example #12
    def _partition_layers(self, method='uniform'):
        num_stages = self._topo.get_dim('pipe')
        stage_id = self._topo.get_coord(self.global_rank).pipe

        if self.global_rank == 0:
            logger.info(f'Partitioning pipeline stages with method {method}')

        method = method.lower()

        # Each stage gets a simple uniform number of layers.
        if method == 'uniform':
            num_layers = len(self._layer_specs)
            self.parts = ds_utils.partition_uniform(num_items=num_layers,
                                                    num_parts=num_stages)
        elif method == 'parameters':
            param_counts = self._count_layer_params()
            self.parts = ds_utils.partition_balanced(weights=param_counts,
                                                     num_parts=num_stages)
        elif method.startswith('type:'):
            layertype = method.split(':')[1]
            binary_weights = [0] * len(self._layer_specs)
            for idx in self._find_layer_type(layertype):
                binary_weights[idx] = 1
            self.parts = ds_utils.partition_balanced(weights=binary_weights,
                                                     num_parts=num_stages)
        elif method == 'profile':
            raise NotImplementedError(
                f'Partitioning method {method} not implemented.')
        else:
            raise NotImplementedError(
                f'Partitioning method {method} not implemented.')

        # Print some information on the partitioning.
        if self.global_rank == 0:
            for stage in range(num_stages):
                start = self.parts[stage]
                stop = self.parts[stage + 1]
                print(f'stage={stage} layers={stop - start}')
                for idx, layer in enumerate(self._layer_specs[start:stop]):
                    name = str(layer)
                    if isinstance(layer, LayerSpec):
                        name = layer.typename.__name__
                    if isinstance(layer, nn.Module):
                        name = layer.__class__.__name__
                    else:
                        try:
                            name = layer.__name__
                        except AttributeError:
                            pass
                    print(f'    {idx + start:2d}: {name}')
            if self.loss_fn:
                try:
                    print(f'  loss: {self.loss_fn.__name__}')
                except AttributeError:
                    print(f'  loss: {self.loss_fn.__class__.__name__}')

        self._set_bounds(start=self.parts[stage_id],
                         stop=self.parts[stage_id + 1])
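
A minimal sketch of uniform layer partitioning (illustrative only; ds_utils.partition_uniform may distribute the remainder differently): it returns num_parts + 1 boundaries, and stage i owns layers parts[i]:parts[i + 1].

def partition_uniform_sketch(num_items, num_parts):
    # Stage p gets items [parts[p], parts[p + 1]); early stages absorb the remainder.
    parts = [0] * (num_parts + 1)
    base, extra = divmod(num_items, num_parts)
    for p in range(num_parts):
        parts[p + 1] = parts[p] + base + (1 if p < extra else 0)
    return parts

print(partition_uniform_sketch(10, 4))  # [0, 3, 6, 8, 10]
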
Example #13
 def __init__(self, exps, resource_manager, metric):
     self.all_exps = exps
     self.rm = resource_manager
     self.best_iter = 0
     self.best_exp = None
     self.best_metric_val = None
     self.metric = metric if metric else AUTOTUNING_METRIC_DEFAULT
     logger.info(f"total number of exps =  {len(self.all_exps)}")
    def check_ckpt_list(self):
        logger.info(f'checkpoint file list: {self.ckpt_list}')
        assert len(self.ckpt_list) > 0

        sd = torch.load(self.ckpt_list[0], map_location=lambda storage, loc: storage)

        # check that the checkpoint count matches the saved mp_world_size
        if 'mp_world_size' in sd.keys():
            assert len(self.ckpt_list) == sd['mp_world_size'], f"checkpoint count {len(self.ckpt_list)} is different from saved mp_world_size {sd['mp_world_size']}"
    def step(self, closure=None):
        """
        Not supporting closure.
        """

        if self.fused_lamb_legacy:
            return self.step_fused_lamb()

        self.overflow = self.overflow_checker.check()
        prev_scale = self.cur_scale

        self._update_scale(self.overflow)
        if self.overflow:
            if self.verbose:
                logger.info(
                    "[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
                    "scale: {}, reducing to {}".format(prev_scale,
                                                       self.cur_scale))
            return self.overflow

        norm_groups = []
        for i, group in enumerate(self.fp16_groups):
            grads_for_norm, _ = split_params_grads_into_shared_and_expert_params(
                group)
            norm_group_value = 0.0
            if len(grads_for_norm) > 0:
                norm_group_value = get_weight_norm(grads_for_norm,
                                                   mpu=self.mpu)
            norm_groups.append(norm_group_value)

            # copying gradients to fp32 to work with fp32 parameters
            for fp32_param, fp16_param in zip(self.fp32_groups[i],
                                              self.fp16_groups[i]):
                if fp16_param.grad is None:
                    fp32_param.grad = torch.zeros(fp16_param.size(),
                                                  dtype=fp32_param.dtype,
                                                  device=fp32_param.device)
                else:
                    fp32_param.grad = fp16_param.grad.to(fp32_param.dtype)

        self._global_grad_norm = get_global_norm(norm_list=norm_groups)
        self.unscale_and_clip_grads(self._global_grad_norm)

        self.optimizer.step()

        for fp32_group, fp16_group in zip(self.fp32_groups, self.fp16_groups):
            for idx, (fp32_param,
                      fp16_param) in enumerate(zip(fp32_group, fp16_group)):

                #remove the fp32 grad
                fp32_param.grad = None

                #copy data from fp32 to fp16
                fp16_param.data.copy_(fp32_param.data)

        return self.overflow
Example #16
def _configure_using_config_file(deepspeed_config, mpu=None):
    global num_layers, PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
            PA_TO_CPU, SYNCHRONIZE, PROFILE_TIME

    config = DeepSpeedConfig(deepspeed_config, mpu=mpu).activation_checkpointing_config
    logger.info(config.repr())
    PARTITION_ACTIVATIONS = config.partition_activations
    CONTIGUOUS_CHECKPOINTING = config.contiguous_memory_optimization
    num_layers = config.number_checkpoints
    PA_TO_CPU = config.cpu_checkpointing
    SYNCHRONIZE = config.synchronize_checkpoint_boundary
    PROFILE_TIME = config.profile
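
The attributes read here correspond to the activation_checkpointing section of a DeepSpeed config; a hedged example written as a Python dict (key names mirror the attribute reads above, values are purely illustrative):

# Illustrative config fragment, not a complete DeepSpeed config.
deepspeed_config = {
    "activation_checkpointing": {
        "partition_activations": True,
        "contiguous_memory_optimization": False,
        "number_checkpoints": None,
        "cpu_checkpointing": False,
        "synchronize_checkpoint_boundary": False,
        "profile": False,
    }
}
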
Example #17
def _handle_overflow(cpu_sum, x, i):
    import math
    rank = torch.distributed.get_rank()
    if rank == 0:
        t_i = -1
        for v_i, v in enumerate(x.data.contiguous().view(-1)):
            if not math.isfinite(float(v)):
                t_i = v_i
                break
        logger.info(
            f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}"
        )
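
A vectorized alternative (a sketch in plain PyTorch) to the element-by-element scan above for locating the first non-finite value:

import torch

x = torch.tensor([1.0, float('inf'), 3.0])
bad = torch.nonzero(~torch.isfinite(x.view(-1)), as_tuple=False)
t_i = int(bad[0]) if bad.numel() > 0 else -1
print(t_i)  # 1
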
 def commit(self, tag):
     # Nebula commit is called once all files under the given tag are ready to be persisted asynchronously.
     logger.info(
         f"[Nebula] all files for {tag} are saved in tier1. It is ready to start persisting"
     )
     commit_rls = self.checkpoint.commit()
     if not commit_rls:
         logger.error(
             f"[Nebula] failed to commit the checkpoint, please check the log."
         )
         return False
     return commit_rls
Example #19
    def _set_batch_related_parameters(self):

        train_batch = self.train_batch_size
        micro_batch = self.train_micro_batch_size_per_gpu
        grad_acc = self.gradient_accumulation_steps

        # all values are provided, nothing needs to be set
        if train_batch is not None and \
            micro_batch is not None and \
            grad_acc is not None:
            return

        # gradient_accumulation_steps needs to be set
        elif train_batch is not None and \
            micro_batch is not None:
            grad_acc = train_batch // micro_batch
            grad_acc //= self.world_size
            self.gradient_accumulation_steps = grad_acc

        #micro_batch_per_gpu needs to be set
        elif train_batch is not None and \
            grad_acc is not None:
            micro_batch = train_batch // self.world_size
            micro_batch //= grad_acc
            self.train_micro_batch_size_per_gpu = micro_batch

        #train_batch_size needs to be set
        elif micro_batch is not None and \
            grad_acc is not None:
            train_batch_size = micro_batch * grad_acc
            train_batch_size *= self.world_size
            self.train_batch_size = train_batch_size

        # gradient_accumulation_steps and micro_batch_per_gpu need to be set
        elif train_batch is not None:
            self.gradient_accumulation_steps = 1
            self.train_micro_batch_size_per_gpu = train_batch // self.world_size

        # train_batch_size and gradient_accumulation_steps need to be set
        elif micro_batch is not None:
            self.train_batch_size = micro_batch * self.world_size
            self.gradient_accumulation_steps = 1

        # either none of the three parameters is provided, or only gradient_accumulation_steps is provided
        else:
            assert False, \
                'Either train_batch_size or micro_batch_per_gpu needs to be provided'

        logger.info(
            f'After batch-size resolution: train_batch_size={self.train_batch_size}, '
            f'micro_batch_per_gpu={self.train_micro_batch_size_per_gpu}, '
            f'gradient_accumulation_steps={self.gradient_accumulation_steps}')
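
The invariant maintained here is train_batch_size = micro_batch_per_gpu * gradient_accumulation_steps * world_size; a short worked sketch (plain Python, illustrative values) of the two most common derivations:

# Assumed invariant: train_batch = micro_batch * grad_acc * world_size
world_size = 8

# train_batch and micro_batch given -> derive gradient_accumulation_steps
train_batch, micro_batch = 1024, 16
grad_acc = train_batch // micro_batch // world_size
print(grad_acc)  # 8

# micro_batch and grad_acc given -> derive train_batch_size
micro_batch, grad_acc = 16, 8
print(micro_batch * grad_acc * world_size)  # 1024
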
Example #20
def see_memory_usage(message):
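    # NOTE: the unconditional return below short-circuits this function, so the memory report that follows never runs.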
    return
    if torch.distributed.is_initialized(
    ) and not torch.distributed.get_rank() == 0:
        return

    # Print message except when distributed but not rank 0
    logger.info(message)
    logger.info(
        f"MA {round(torch.cuda.memory_allocated() / (1024 * 1024 * 1024),2 )} GB \
        Max_MA {round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024),2)} GB \
        CA {round(torch.cuda.memory_cached() / (1024 * 1024 * 1024),2)} GB \
        Max_CA {round(torch.cuda.max_memory_cached() / (1024 * 1024 * 1024))} GB "
    )
    def step_fused_adam(self, closure=None):
        """
        Not supporting closure.
        """

        # First compute the norm for all groups so we know if there is an overflow
        grads_groups_flat = []
        norm_groups = []
        for i, group in enumerate(self.fp16_groups):
            grads_groups_flat.append(
                _flatten_dense_tensors([
                    torch.zeros(p.size(), dtype=p.dtype, device=p.device)
                    if p.grad is None else p.grad for p in group
                ]))
            norm_groups.append(
                get_weight_norm(grads_groups_flat[i], mpu=self.mpu))

        self.overflow = self.overflow_checker.check_using_norm(norm_groups)
        prev_scale = self.cur_scale
        self._update_scale(self.overflow)

        if self.overflow:
            if self.verbose:
                logger.info(
                    "[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
                    "scale: {}, reducing to {}".format(prev_scale,
                                                       self.cur_scale))
            return self.overflow

        scaled_grad_norm = get_global_norm(norm_list=norm_groups)

        combined_scale = self.unscale_and_clip_grads(grads_groups_flat,
                                                     scaled_grad_norm,
                                                     apply_scale=False)

        # Stash unscaled gradient norm
        self._global_grad_norm = scaled_grad_norm / self.cur_scale

        # norm is in fact norm*cur_scale
        self.optimizer.step(grads=[[g] for g in grads_groups_flat],
                            output_params=[[p] for p in self.fp16_groups_flat],
                            scale=combined_scale,
                            grad_norms=norm_groups)
        # TODO: we probably don't need this? just to be safe
        for i in range(len(norm_groups)):
            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
                                                      self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data = q.data
        return self.overflow
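
The flatten/unflatten round trip used above, shown in isolation (a sketch relying on torch's internal _flatten_dense_tensors helpers, so treat it as illustrative):

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

params = [torch.randn(2, 3), torch.randn(4)]
flat = _flatten_dense_tensors(params)           # one contiguous 1-D buffer
views = _unflatten_dense_tensors(flat, params)  # tensors shaped like the originals
for p, q in zip(params, views):
    assert torch.equal(p, q)
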
Example #22
    def compute_quantization(self, input, index=0, factor=1):
        # Adjust the quantization bit-width based on the training step: each time
        # the bit-width is reduced by one, the period is increased so that we move
        # progressively more slowly toward the target bits. Both the period and the
        # starting bit-width are configurable.

        if input.start_bits != input.target_bits:
            if self.qsteps >= input.q_period:
                self.quantize_real_ratio = 1.0
                input.q_period <<= 1
                input.q_period *= factor
                input.start_bits -= 1
                if self.q_verbose:
                    logger.info(
                        f'Quantization settings: current bit-precision = {input.start_bits}, step = {self.qsteps}, quantization period = {input.q_period}, index = {index}'
                    )
        assert (input.start_bits >= input.target_bits), \
            'Quantization bit is lower than target precision bits!'

        if self.use_quantizer_kernel:
            if input.start_bits <= 2:
                raise ValueError(
                    'Quantization bit is too low, please do it without quantization kernel!'
                )
            input_q = ds_quantizer(
                input.data.clone(),
                self.q_groups,
                input.start_bits,
                asym=False if self.q_type == 'symmetric' else True,
                sr=False if self.q_rounding == 'nearest_neighbor' else True)
        else:
            if input.start_bits >= 3:
                input_flat = self.quantize_highbit(input.data,
                                                   input.start_bits)
            elif input.start_bits == 2:
                assert self.q_type == 'symmetric', 'Quantization type is not symmetric!'
                assert self.q_rounding == 'nearest', 'Quantization rounding is not nearest_neighbor!'
                input_flat = self.quantize_tenary(input.data)
            elif input.start_bits == 1:
                assert self.q_type == 'symmetric', 'Quantization type is not symmetric!'
                assert self.q_rounding == 'nearest', 'Quantization rounding is not nearest_neighbor!'
                input_flat = self.quantize_binary(input.data)
        if self.use_quantizer_kernel:
            return self.mixed_fp16_quantize(input.data, input_q, index)
        else:
            if self.q_mixed_fp16 and input.start_bits >= input.target_bits - 1:
                input_flat = self.quantize_real_ratio * input.data + \
                              (1 - self.quantize_real_ratio) * input_flat
            return input_flat
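
A pure-Python sketch of the bit schedule implemented above: whenever the period elapses, the bit-width drops by one and the period is doubled (and scaled by factor), so lower precisions are approached progressively more slowly.

def bit_schedule(start_bits, target_bits, q_period, total_steps, factor=1):
    # Illustrative only: reproduces the start_bits / q_period updates in
    # compute_quantization, using the loop index as the step counter.
    bits, period = start_bits, q_period
    drops = []
    for step in range(1, total_steps + 1):
        if bits != target_bits and step >= period:
            period = (period << 1) * factor
            bits -= 1
            drops.append((step, bits))
    return drops

print(bit_schedule(8, 4, 100, 1000))  # [(100, 7), (200, 6), (400, 5), (800, 4)]
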
Example #23
 def check_row_pruning(self):
     # check row pruning
     rp = self.different_compression_methods[ROW_PRUNING]
     if not rp[TECHNIQUE_ENABLED]:
         return
     else:
         shared_parameters = rp[SHARED_PARAMETERS]
         if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
             for group_name, module_name_list, method_parameters in rp[DIFFERENT_GROUPS]:
                 for module_name in module_name_list:
                     module = recursive_getattr(self.model, module_name)
                     module.row_pruning_enabled = True
             if not self.verbose[ROW_PRUNING]:
                 logger.info(f'Row pruning is enabled at step {self.training_steps}')
                 self.verbose[ROW_PRUNING] = True
Example #24
def get_group_alignment_padding(tensor_list, sub_partition_size, sub_partition_count):
    group_paddings = []
    flattened_size = sum([tensor.numel() for tensor in tensor_list])
    for i in range(sub_partition_count):
        padding = get_alignment_padding(flattened_size, i, sub_partition_size)
        group_paddings.append(padding)

    logger.info("****Padding information*****")
    logger.info(f"tensor_size = {flattened_size}")
    logger.info(f"sub_partition_size = {sub_partition_size}")
    logger.info(f"sub_partition_count = {sub_partition_count}")
    for i, padding in enumerate(group_paddings):
        logger.info(f"padding[{i}] = {padding}")

    return group_paddings
Example #25
def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
    data_parallel_size = int(dist.get_world_size())
    parameter_parallel_size = parameter_parallel_size or data_parallel_size
    logger.info("data_parallel_size: %s, parameter_parallel_size: %s",
                data_parallel_size, parameter_parallel_size)
    assert data_parallel_size % parameter_parallel_size == 0, \
        'world size should be divisible by parameter parallel size'
    rank = dist.get_rank()
    my_group = None
    for i in range(data_parallel_size // parameter_parallel_size):
        ranks = range(i * parameter_parallel_size,
                      (i + 1) * parameter_parallel_size)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            my_group = group
    return my_group
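
A sketch of the grouping arithmetic (pure Python, no torch.distributed): with data_parallel_size=8 and parameter_parallel_size=4 the ranks fall into groups [0..3] and [4..7], and each rank keeps only the group it belongs to.

def parameter_parallel_group_sketch(data_parallel_size, parameter_parallel_size, rank):
    assert data_parallel_size % parameter_parallel_size == 0
    my_group = None
    for i in range(data_parallel_size // parameter_parallel_size):
        ranks = range(i * parameter_parallel_size, (i + 1) * parameter_parallel_size)
        if rank in ranks:
            my_group = list(ranks)
    return my_group

print(parameter_parallel_group_sketch(8, 4, rank=5))  # [4, 5, 6, 7]
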
    def get_split_state_dict(self, mp_world_size, mp_rank):
        num_ckpt = len(self.ckpt_list)
        assert mp_world_size % num_ckpt == 0, 'Invalid checkpoints and world size for sd split'

        num_to_split = mp_world_size // num_ckpt
        ckpt_index = mp_rank // num_to_split
        ckpt_offset = mp_rank % num_to_split

        logger.info(
            f"mp_rank: {mp_rank}, ckpt_list: {self.ckpt_list[ckpt_index]}, offset: {ckpt_offset}"
        )

        sd = torch.load(self.ckpt_list[ckpt_index],
                        map_location=lambda storage, loc: storage)

        return sd, num_to_split, ckpt_offset
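
A worked example of the index arithmetic above: with 2 saved checkpoint files and mp_world_size=8, each file serves 4 ranks, so rank 5 reads file 1 at offset 1.

# Illustrative only: the same arithmetic as get_split_state_dict.
mp_world_size, num_ckpt = 8, 2
num_to_split = mp_world_size // num_ckpt          # 4 ranks share each checkpoint file
for mp_rank in range(mp_world_size):
    ckpt_index = mp_rank // num_to_split          # which file this rank reads
    ckpt_offset = mp_rank % num_to_split          # which slice of that file it keeps
    print(mp_rank, ckpt_index, ckpt_offset)       # rank 5 -> file 1, offset 1
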
Example #27
    def step(self, closure=None):
        """
        Not supporting closure.
        """
        if self.fused_lamb_legacy:
            return self.step_fused_lamb()

        self.overflow = self.overflow_checker.check()
        prev_scale = self.cur_scale

        self._update_scale(self.overflow)
        if self.overflow:
            if self.verbose:
                logger.info(
                    "[deepspeed] OVERFLOW! Skipping step. Attempted loss "
                    "scale: {}, reducing to {}".format(prev_scale,
                                                       self.cur_scale))
            return self.overflow

        norm_groups = []
        for i, group in enumerate(self.fp16_groups):
            norm_groups.append(get_grad_norm(group, mpu=self.mpu))

            # copying gradients to fp32 to work with fp32 parameters
            for fp32_param, fp16_param in zip(self.fp32_groups[i],
                                              self.fp16_groups[i]):
                if fp16_param.grad is None:
                    fp32_param.grad = torch.zeros(fp16_param.size(),
                                                  dtype=fp32_param.dtype,
                                                  device=fp32_param.device)
                else:
                    fp32_param.grad = fp16_param.grad.to(fp32_param.dtype)

        self.unscale_and_clip_grads(norm_groups)

        self.optimizer.step()

        for fp32_group, fp16_group in zip(self.fp32_groups, self.fp16_groups):
            for fp32_param, fp16_param in zip(fp32_group, fp16_group):

                #remove the fp32 grad
                fp32_param.grad = None

                #copy data from fp32 to fp16
                fp16_param.data.copy_(fp32_param.data)

        return self.overflow
    def get_merge_state_dicts(self, mp_world_size, mp_rank):
        num_ckpt = len(self.ckpt_list)
        assert num_ckpt % mp_world_size == 0, 'Invalid checkpoints and world size for sd merge'

        num_to_merge = num_ckpt // mp_world_size
        ckpt_list = [
            self.ckpt_list[i] for i in range(num_to_merge * mp_rank,
                                             num_to_merge * (mp_rank + 1))
        ]

        logger.info(f"mp_rank: {mp_rank}, ckpt_list: {ckpt_list}")
        sd_list = [
            torch.load(ckpt,
                       map_location=lambda storage,
                       loc: storage) for ckpt in ckpt_list
        ]
        return sd_list
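
The merge direction does the reverse bookkeeping: with 8 checkpoint files and mp_world_size=2, each rank owns a contiguous block of 4 files.

# Illustrative only: the same slicing as get_merge_state_dicts.
num_ckpt, mp_world_size = 8, 2
num_to_merge = num_ckpt // mp_world_size
for mp_rank in range(mp_world_size):
    files = list(range(num_to_merge * mp_rank, num_to_merge * (mp_rank + 1)))
    print(mp_rank, files)  # rank 0 -> [0, 1, 2, 3], rank 1 -> [4, 5, 6, 7]
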
Example #29
 def check_activation_quantization(self):
     # check activation quantization
     aq = self.different_compression_methods[ACTIVATION_QUANTIZATION]
     if not aq[TECHNIQUE_ENABLED]:
         return
     else:
         shared_parameters = aq[SHARED_PARAMETERS]
         if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
             for group_name, module_name_list, method_parameters in aq[DIFFERENT_GROUPS]:
                 for module_name in module_name_list:
                     module = recursive_getattr(self.model, module_name)
                     module.activation_quantization_enabled = True
             if not self.verbose[ACTIVATION_QUANTIZATION]:
                 logger.info(
                     f'Activation quantization is enabled at step {self.training_steps}'
                 )
                 self.verbose[ACTIVATION_QUANTIZATION] = True
Example #30
    def load_state_dir(self, load_dir, strict=True):
        rank = dist.get_rank()

        layer_offset = self._local_start
        for idx, layer in enumerate(self.forward_funcs):
            # Functions, etc. will not have state_dicts
            if not hasattr(layer, 'load_state_dict'):
                continue

            model_ckpt_path = self.ckpt_layer_path(load_dir, idx)
            layer.load_state_dict(torch.load(
                model_ckpt_path, map_location=lambda storage, loc: storage),
                                  strict=strict)
            if self._grid.data_parallel_id == 0:
                logger.info(
                    f'RANK={self.global_rank} Loaded layer={idx + layer_offset} file={model_ckpt_path}'
                )

        self._synchronize_tied_weights()