Example #1
0
 def _init_quantization_setting(self, quantization_setting):
     """Initialize quantization attributes from *quantization_setting*.

     Args:
         quantization_setting: Either a ``(mlp_extra_grouping,
             quantize_groups)`` tuple, a bare ``quantize_groups`` value,
             or ``None`` to keep the defaults below.
     """
     # Defaults used when the setting is absent or only partial.
     self.quantize_bits = 8
     self.mlp_extra_grouping = False
     self.quantize_groups = 1
     # isinstance (rather than `type(...) is tuple`) also accepts tuple
     # subclasses such as NamedTuple settings.
     if isinstance(quantization_setting, tuple):
         self.mlp_extra_grouping, \
         self.quantize_groups = quantization_setting
     elif quantization_setting is not None:
         # A single value overrides only the group count.
         self.quantize_groups = quantization_setting
     log_dist(
         f"quantize_bits = {self.quantize_bits}, "
         f"mlp_extra_grouping = {self.mlp_extra_grouping}, "
         f"quantize_groups = {self.quantize_groups}",
         [0])
Example #2
0
    def log(self,
            names,
            normalizer=1.0,
            reset=True,
            memory_breakdown=False,
            ranks=None):
        """Log a group of timers.

        Builds a single summary line containing the elapsed time (in ms,
        divided by ``normalizer``) of every timer in *names* that exists in
        ``self.timers``, then emits it via ``log_dist`` on *ranks*
        (rank 0 when not given).
        """
        assert normalizer > 0.0
        segments = [f'rank={torch.distributed.get_rank()} time (ms)']
        for timer_name in names:
            if timer_name not in self.timers:
                continue
            # elapsed() returns seconds; scale to ms before normalizing.
            elapsed = self.timers[timer_name].elapsed(
                reset=reset) * 1000.0 / normalizer
            segments.append('{}: {:.2f}'.format(timer_name, elapsed))

        log_dist(' | '.join(segments), ranks=ranks or [0])
Example #3
0
    def log(self,
            names,
            normalizer=1.0,
            reset=True,
            memory_breakdown=False,
            ranks=None):
        """Log a group of timers.

        Appends each requested timer's elapsed reading (divided by
        ``normalizer``) to one summary line and emits it through
        ``log_dist`` on *ranks* (rank 0 when not given).
        """
        assert normalizer > 0.0
        segments = [f"rank={dist.get_rank()} time (ms)"]
        for key in names:
            if key not in self.timers:
                continue
            # NOTE(review): unlike other variants there is no *1000.0 here —
            # presumably elapsed() already reports milliseconds; confirm.
            value = self.timers[key].elapsed(reset=reset) / normalizer
            segments.append("{}: {:.2f}".format(key, value))

        log_dist(" | ".join(segments), ranks=ranks or [0])
Example #4
0
    def __init__(self,
                 config,
                 mp_group=None,
                 quantize_scales=None,
                 quantize_groups=1,
                 merge_count=1,
                 mlp_extra_grouping=False,
                 qkv_merging=False):
        """Build one DeepSpeed inference transformer layer.

        Assigns this instance the next class-level layer id, lazily loads
        the fused inference CUDA kernel module (shared by all layers),
        constructs the attention and MLP sub-modules, and allocates the
        final layer-norm weight/bias parameters.
        """
        super(DeepSpeedTransformerInference, self).__init__()

        self.config = config
        # Hand out sequential layer ids from the class-level counter.
        self.config.layer_id = DeepSpeedTransformerInference.layer_id
        DeepSpeedTransformerInference.layer_id += 1

        dtype = torch.half if config.fp16 else torch.float

        # Load the fused CUDA kernels exactly once and cache them in the
        # module-level global shared by every layer instance.
        global inference_cuda_module
        if inference_cuda_module is None:
            inference_cuda_module = op_builder.InferenceBuilder().load()

        # Only the first constructed layer reports the configuration.
        if DeepSpeedTransformerInference.layer_id == 1:
            log_dist(f"DeepSpeed-Inference config: {self.config.__dict__}",
                     [0])

        self.attention = DeepSpeedSelfAttention(self.config, mp_group,
                                                quantize_scales,
                                                quantize_groups, merge_count,
                                                qkv_merging)
        self.mlp = DeepSpeedMLP(self.config, mp_group, quantize_scales,
                                quantize_groups, merge_count,
                                mlp_extra_grouping)

        # Bloom layers allocate the norm params on the current CUDA device;
        # everything else starts on CPU.
        target_device = (torch.cuda.current_device()
                         if config.bigscience_bloom else 'cpu')
        self.norm_w = nn.Parameter(
            torch.empty(self.config.hidden_size,
                        dtype=dtype,
                        device=target_device))
        self.norm_b = nn.Parameter(
            torch.empty(self.config.hidden_size,
                        dtype=dtype,
                        device=target_device))
        # Cached key/value state from the previous decoding step.
        self.layer_past = None
Example #5
0
 def log(self,
         names,
         normalizer=1.0,
         reset=True,
         memory_breakdown=False,
         ranks=None,
         return_values=True):
     """Log a group of timers.

     Emits one summary line with each found timer's elapsed time in ms
     (divided by ``normalizer``) via ``log_dist`` on *ranks* (rank 0 by
     default); unknown timer names are reported instead of raising.
     When ``return_values`` is true, also returns a ``{name: elapsed_ms}``
     dict for the timers that were found.
     """
     assert normalizer > 0.0
     line = f'rank={torch.distributed.get_rank()} time (ms)'
     elapsed_by_name = {} if return_values else None
     for timer_name in names:
         if timer_name not in self.timers:
             # Unknown timer: warn on rank 0 and keep going.
             log_dist(f'logging failed for timer {timer_name}', ranks=[0])
             continue
         # elapsed() returns seconds; convert to ms before normalizing.
         ms = self.timers[timer_name].elapsed(
             reset=reset) * 1000.0 / normalizer
         line += ' | {}: {:.2f}'.format(timer_name, ms)
         if return_values:
             elapsed_by_name[timer_name] = ms
     log_dist(line, ranks=ranks or [0])
     if return_values:
         return elapsed_by_name