Example #1
    def all_reduce(self, overflow_buf, accum=1):
        scaler = amp.scaler.LossScaler(1.0)

        # 1. allocate an uninitialized buffer for flattened gradient
        master_grads = [
            p.grad for p in amp.master_params(self.optimizer)
            if p.grad is not None
        ]
        flat_grad_size = sum(p.numel() for p in master_grads)
        allreduce_dtype = torch.float32
        flat_raw = torch.empty(flat_grad_size,
                               device='cuda',
                               dtype=allreduce_dtype)
        # 2. combine unflattening and predivision of unscaled 'raw' gradient
        allreduced_views = apex_C.unflatten(flat_raw, master_grads)
        overflow_buf.zero_()
        amp_C.multi_tensor_scale(
            65536, overflow_buf, [master_grads, allreduced_views],
            scaler.loss_scale() / (self.team_size * accum))
        # 3. sum gradient across ranks. Because of the predivision,
        #    this averages the gradient
        torch.distributed.all_reduce(flat_raw, group=self.local_group)
        # 4. combine unscaling and unflattening of allreduced gradient
        overflow_buf.zero_()
        amp_C.multi_tensor_scale(65536, overflow_buf,
                                 [allreduced_views, master_grads],
                                 1. / scaler.loss_scale())
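
All four examples follow the same post-accumulation allreduce pattern: flatten every master gradient into one contiguous buffer, pre-divide by (world size * accumulation steps) so that the sum performed by all_reduce yields an average, run a single all_reduce, then scale the result back into the individual gradient tensors. The overflow_buf passed to amp_C.multi_tensor_scale is set whenever the kernel encounters an Inf/NaN, which is what the later examples use to detect overflow. For orientation, here is a minimal pure-PyTorch sketch of the same flatten / pre-divide / all_reduce / copy-back flow without Apex's fused kernels or loss scaling; the function name and arguments are illustrative, not part of the snippet above.

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def allreduce_and_average_grads(params, world_size, accum_steps=1, group=None):
    # 1. flatten all gradients into a single contiguous buffer
    grads = [p.grad for p in params if p.grad is not None]
    flat = _flatten_dense_tensors(grads)
    # 2. pre-divide so the sum-allreduce below produces an average
    flat.div_(world_size * accum_steps)
    # 3. one allreduce for the whole model instead of one per tensor
    dist.all_reduce(flat, group=group)
    # 4. copy the averaged values back into the original gradient tensors
    for grad, synced in zip(grads, _unflatten_dense_tensors(flat, grads)):
        grad.copy_(synced)
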
Example #2
def take_optimizer_step(args, optimizer, model, overflow_buf, global_step):

    global skipped_steps
    if args.allreduce_post_accumulation:
        # manually allreduce gradients after all accumulation steps
        # check for Inf/NaN
        # 1. allocate an uninitialized buffer for flattened gradient
        loss_scale = _amp_state.loss_scalers[0].loss_scale() if args.fp16 else 1
        master_grads = [p.grad for p in amp.master_params(optimizer) if p.grad is not None]
        flat_grad_size = sum(p.numel() for p in master_grads)
        allreduce_dtype = torch.float16 if args.allreduce_post_accumulation_fp16 else torch.float32
        flat_raw = torch.empty(flat_grad_size, device='cuda', dtype=allreduce_dtype)
        # 2. combine unflattening and predivision of unscaled 'raw' gradient
        allreduced_views = apex_C.unflatten(flat_raw, master_grads)
        overflow_buf.zero_()
        amp_C.multi_tensor_scale(65536,
            overflow_buf,
            [master_grads, allreduced_views],
            loss_scale / (get_world_size() * args.gradient_accumulation_steps))
        # 3. sum gradient across ranks. Because of the predivision, this averages the gradient
        torch.distributed.all_reduce(flat_raw)
        # 4. combine unscaling and unflattening of allreduced gradient
        overflow_buf.zero_()
        amp_C.multi_tensor_scale(65536,
            overflow_buf,
            [allreduced_views, master_grads],
            1./loss_scale)
        # 5. update loss scale
        if args.fp16:
            scaler = _amp_state.loss_scalers[0]
            old_overflow_buf = scaler._overflow_buf
            scaler._overflow_buf = overflow_buf
            had_overflow = scaler.update_scale()
            scaler._overflow_buf = old_overflow_buf  # restore the scaler's own overflow buffer
        else:
            had_overflow = 0
        # 6. call optimizer step function
        if had_overflow == 0:
            optimizer.step()
            global_step += 1
        else:
            # Overflow detected, print message and clear gradients
            skipped_steps += 1
            if is_main_process():
                scaler = _amp_state.loss_scalers[0]
                dllogger.log(step="PARAMETER", data={"loss_scale": scaler.loss_scale()})
            if _amp_state.opt_properties.master_weights:
                for param in optimizer._amp_stash.all_fp32_from_fp16_params:
                    param.grad = None
        for param in model.parameters():
            param.grad = None
    else:
        if args.apply_optimizer > 0:
            optimizer.step()
        # optimizer.zero_grad()
        for param in model.parameters():
            param.grad = None
        global_step += 1

    return global_step
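
This variant adds the loss-scale bookkeeping: the overflow flag accumulated in overflow_buf is handed to Apex's loss scaler through update_scale(), and the optimizer step is skipped (and counted in skipped_steps) when an overflow occurred. A hypothetical call site inside a gradient-accumulation loop is sketched below; train_loader, compute_loss, model, optimizer and the args fields are assumed to come from the surrounding training script, and delay_overflow_check defers Apex's per-backward overflow check since this function performs it after the allreduce.

import torch
from apex import amp

overflow_buf = torch.cuda.IntTensor([0])
global_step = 0
for step, batch in enumerate(train_loader):
    loss = compute_loss(model, batch) / args.gradient_accumulation_steps
    # let take_optimizer_step do the overflow check after the allreduce
    with amp.scale_loss(loss, optimizer,
                        delay_overflow_check=args.allreduce_post_accumulation) as scaled_loss:
        scaled_loss.backward()
    if (step + 1) % args.gradient_accumulation_steps == 0:
        global_step = take_optimizer_step(args, optimizer, model,
                                          overflow_buf, global_step)
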
Example #3
 def _step_distributed_fp16(self) -> None:
     # manually allreduce gradients after all accumulation steps
     # check for Inf/NaN
     # 1. allocate an uninitialized buffer for flattened gradient
     scaler = _amp_state.loss_scalers[0]
     master_grads = [
         p.grad for p in amp.master_params(self.optimizer)
         if p.grad is not None
     ]
     flat_grad_size = sum(p.numel() for p in master_grads)
     # allreduce_dtype = torch.float16 if args.allreduce_post_accumulation_fp16 else \
     # torch.float32
     allreduce_dtype = torch.float16
     flat_raw = torch.empty(flat_grad_size,
                            device='cuda',
                            dtype=allreduce_dtype)
     # 2. combine unflattening and predivision of unscaled 'raw' gradient
     allreduced_views = apex_C.unflatten(flat_raw, master_grads)
     self._overflow_buf.zero_()
     amp_C.multi_tensor_scale(
         65536, self._overflow_buf, [master_grads, allreduced_views],
         scaler.loss_scale() / (torch.distributed.get_world_size() *
                                self.gradient_accumulation_steps))
     # 3. sum gradient across ranks. Because of the predivision, this averages the gradient
     torch.distributed.all_reduce(flat_raw)
     # 4. combine unscaling and unflattening of allreduced gradient
     self._overflow_buf.zero_()
     amp_C.multi_tensor_scale(65536, self._overflow_buf,
                              [allreduced_views, master_grads],
                              1. / scaler.loss_scale())
     # 5. update loss scale
     scaler = _amp_state.loss_scalers[0]
     old_overflow_buf = scaler._overflow_buf
     scaler._overflow_buf = self._overflow_buf
     had_overflow = scaler.update_scale()
     scaler._overflow_buf = old_overflow_buf  # restore the scaler's own overflow buffer
     # 6. call optimizer step function
     if had_overflow == 0:
         self._step()
     else:
         # Overflow detected, print message and clear gradients
         logger.info(
             f"Gradient overflow.  Skipping step, reducing loss scale to "
             f"{scaler.loss_scale()}")
         if _amp_state.opt_properties.master_weights:
             for param in self.optimizer._amp_stash.all_fp32_from_fp16_params:
                 param.grad = None
     for param in self.model.parameters():
         param.grad = None
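
Here the allreduce buffer is hard-coded to fp16, which halves communication volume but makes Inf/NaN in the reduced gradients more likely; the fused multi_tensor_scale calls record any non-finite value in self._overflow_buf, and update_scale() then decides whether to skip the step and back off the loss scale. An equivalent overflow check without the fused kernel could look like the illustrative sketch below (not part of the class above); reducing with MAX makes every rank take the same skip decision.

import torch
import torch.distributed as dist

def grads_overflowed(flat_grad: torch.Tensor, group=None) -> bool:
    # any non-finite value marks the step as overflowed on this rank
    overflow = (~torch.isfinite(flat_grad)).any().to(torch.float32)
    # MAX-reduce so all ranks agree on whether to skip the optimizer step
    dist.all_reduce(overflow, op=dist.ReduceOp.MAX, group=group)
    return bool(overflow.item())
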
Example #4
def take_optimizer_step(args, optimizer, grad_scaler, model, overflow_buf,
                        global_step):

    global skipped_steps
    if args.allreduce_post_accumulation:
        # manually allreduce gradients after all accumulation steps
        # check for Inf/NaN
        # 1. allocate an uninitialized buffer for flattened gradient
        loss_scale = grad_scaler._get_scale_async() if args.fp16 else 1.
        master_grads = [
            p.grad for p in model.parameters() if p.grad is not None
        ]
        flat_grad_size = sum(p.numel() for p in master_grads)
        allreduce_dtype = torch.float16 if args.allreduce_post_accumulation_fp16 else torch.float32
        flat_raw = torch.empty(flat_grad_size,
                               device='cuda',
                               dtype=allreduce_dtype)
        # 2. combine unflattening and predivision of unscaled 'raw' gradient
        allreduced_views = apex_C.unflatten(flat_raw, master_grads)
        overflow_buf.zero_()
        amp_C.multi_tensor_scale(
            65536, overflow_buf, [master_grads, allreduced_views],
            loss_scale / (get_world_size() * args.gradient_accumulation_steps))
        # 3. sum gradient across ranks. Because of the predivision, this averages the gradient
        torch.distributed.all_reduce(flat_raw)
        # 4. combine unscaling and unflattening of allreduced gradient
        overflow_buf.zero_()
        amp_C.multi_tensor_scale(65536, overflow_buf,
                                 [allreduced_views, master_grads], 1.)
        # 5. update loss scale
        if args.fp16:
            had_overflow = overflow_buf.item()
        else:
            had_overflow = 0
        # 6. call optimizer step function
        if had_overflow == 0:
            global_step += 1
        else:
            # Overflow detected, print message and clear gradients
            skipped_steps += 1
    else:
        global_step += 1
    grad_scaler.step(optimizer)
    grad_scaler.update()
    optimizer.zero_grad(set_to_none=True)

    return global_step
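
This last variant targets native torch.cuda.amp rather than Apex's AMP state: the loss scale is read from GradScaler (the private _get_scale_async()), unscaling is left to grad_scaler.step(), and overflow is read directly from overflow_buf rather than through Apex's loss scaler. A hypothetical call site is sketched below; the loop, data loading and loss computation are illustrative and not taken from the snippet above.

import torch

grad_scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
overflow_buf = torch.cuda.IntTensor([0])
global_step = 0
for step, batch in enumerate(train_loader):
    with torch.cuda.amp.autocast(enabled=args.fp16):
        loss = compute_loss(model, batch) / args.gradient_accumulation_steps
    grad_scaler.scale(loss).backward()
    if (step + 1) % args.gradient_accumulation_steps == 0:
        global_step = take_optimizer_step(args, optimizer, grad_scaler, model,
                                          overflow_buf, global_step)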