def clip_grad_norm_fp32(parameters,
                        param_is_distributed,
                        shard_optimizer_state,
                        max_norm,
                        norm_type=2):
    """Clips gradient norm of an iterable of parameters whose gradients
       are in fp32.
    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and adds
    functionality to handle model-parallel parameters. Note that the
    gradients are modified in place.
    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        param_is_distributed (dict): maps a parameter to whether it is
            distributed across tensor model-parallel ranks
        shard_optimizer_state (bool): whether optimizer state sharding is
            enabled, in which case the norm is reduced across all ranks
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
    Returns:
        Total norm of the parameters (viewed as a single vector).
    """

    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]

    # Filter parameters based on:
    #   - grad should not be none
    #   - parameter should not be shared
    #   - should not be a replica due to tensor model parallelism
    torch.cuda.set_device(smp.local_rank())
    grads = []
    grads_for_norm = []
    for param in parameters:
        grad_not_none = param.grad is not None
        is_not_shared = not hasattr(param, "shared") or not param.shared
        is_not_tp_duplicate = smp.tp_rank() == 0 or (
            param in param_is_distributed and param_is_distributed[param])
        if grad_not_none:
            grad = param.grad.detach()
            # Make sure the grads are in fp32
            assert param.grad.type() == "torch.cuda.FloatTensor"
            grads.append(grad)
            if is_not_shared and is_not_tp_duplicate:
                grads_for_norm.append(grad)

    # Norm parameters.
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    total_norm = torch.tensor(0.0, device=torch.device("cuda"))

    # Calculate norm.
    if norm_type == inf:
        if len(grads_for_norm) > 0:
            total_norm = max(grad.abs().max() for grad in grads_for_norm)
        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
        # Take max across all model-parallel GPUs.
        # Reducing across all ranks since gradients may be different across data parallel ranks
        # when optimizer state sharding is enabled.
        group = (smp.get_world_process_group()
                 if shard_optimizer_state else smp.get_mp_process_group())
        torch.distributed.all_reduce(total_norm_cuda,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=group)
        total_norm = total_norm_cuda[0].item()

    else:
        if norm_type == 2.0:
            dummy_overflow_buf = torch.cuda.IntTensor([0],
                                                      device=torch.device(
                                                          "cuda",
                                                          smp.local_rank()))
            # Use apex's multi-tensor applier for efficiency reasons.
            # Multi-tensor applier takes a function and a list of list
            # and performs the operation on that list all in one kernel.
            if len(grads_for_norm) > 0:
                grad_norm, _ = multi_tensor_applier(
                    amp_C.multi_tensor_l2norm,
                    dummy_overflow_buf,
                    [grads_for_norm],
                    False,  # no per-parameter norm
                )
                # Since we will be summing across data parallel groups,
                # we need the pow(norm-type).
                total_norm = grad_norm**norm_type

        else:
            for grad in grads_for_norm:
                grad_norm = torch.norm(grad, norm_type)
                total_norm += grad_norm**norm_type

        # Sum across all model-parallel GPUs.
        group = (smp.get_world_process_group()
                 if shard_optimizer_state else smp.get_mp_process_group())
        torch.distributed.all_reduce(total_norm,
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=group)
        total_norm = total_norm.item()**(1.0 / norm_type)

    # Scale.
    if len(grads) > 0:
        clip_coeff = max_norm / (total_norm + 1.0e-6)
        if clip_coeff < 1.0:
            dummy_overflow_buf = torch.cuda.IntTensor([0],
                                                      device=torch.device(
                                                          "cuda",
                                                          smp.local_rank()))
            multi_tensor_applier(amp_C.multi_tensor_scale, dummy_overflow_buf,
                                 [grads, grads], clip_coeff)

    return total_norm
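
The reduction above works because an L2 norm over disjoint shards can be assembled from per-shard partial results: each rank contributes its local norm raised to norm_type, the scalars are summed with an all-reduce, and the root is taken afterwards (the MAX all-reduce used for the infinity norm follows the same logic). A minimal single-process sketch of that identity, with no smp/apex dependencies and purely illustrative tensor names:

import torch

# Pretend these are the gradient shards held by two different ranks.
rank0_grads = [torch.randn(4), torch.randn(3)]
rank1_grads = [torch.randn(5)]

# Each rank would compute its local norm**2, then all-reduce(SUM) the scalars.
partial_sq = [
    sum(g.norm(2) ** 2 for g in rank0_grads),
    sum(g.norm(2) ** 2 for g in rank1_grads),
]
total_norm = torch.stack(partial_sq).sum().sqrt()

# Same result as concatenating everything into a single vector.
reference = torch.cat([g.flatten() for g in rank0_grads + rank1_grads]).norm(2)
assert torch.allclose(total_norm, reference)
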
Example #2
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        explicit_master_params = (hasattr(self, "_amp_stash") and hasattr(
            self._amp_stash, "fp32_from_fp16_groups"))

        for gid, group in enumerate(self.param_groups):
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            # For each group, there are 3 possible combinations we need to consider:
            # grad_type, param_to_update_type, momentum_type, requires_fp16_model_copy
            # 1. fp16, fp16, fp16, No
            # 2. fp32, fp32, fp32, No
            # 3. fp16, fp32, fp32, Yes

            first_runs = [True, True]

            # I think a bit of code divergence in exchange for naming clarity is worthwhile
            if explicit_master_params:
                stash = self._amp_stash

                fp32_params = [
                    p for p in stash.fp32_from_fp32_groups[gid]
                    if p.grad is not None
                ]
                fp32_grads = [
                    p.grad for p in stash.fp32_from_fp32_groups[gid]
                    if p.grad is not None
                ]
                fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)

                if self.materialize_master_grads:
                    fp16_model_params = [
                        p for i, p in enumerate(stash.fp16_groups[gid])
                        if stash.fp32_from_fp16_groups[gid][i].grad is not None
                    ]
                    fp32_from_fp16_grads = [
                        p.grad for p in stash.fp32_from_fp16_groups[gid]
                        if p.grad is not None
                    ]
                    fp32_from_fp16_params = [
                        p for p in stash.fp32_from_fp16_groups[gid]
                        if p.grad is not None
                    ]
                    fp32_from_fp16_momentums, first_runs[
                        0] = self.get_momentums(fp32_from_fp16_params)

                    fp16_set = [
                        fp32_from_fp16_grads, fp32_from_fp16_params,
                        fp32_from_fp16_momentums, fp16_model_params
                    ]
                else:
                    fp16_model_params = [
                        p for p in stash.fp16_groups[gid] if p.grad is not None
                    ]
                    fp16_model_grads = [
                        p.grad for p in stash.fp16_groups[gid]
                        if p.grad is not None
                    ]
                    fp32_from_fp16_params = [
                        p
                        for i, p in enumerate(stash.fp32_from_fp16_groups[gid])
                        if stash.fp16_groups[gid][i].grad is not None
                    ]
                    fp32_from_fp16_momentums, first_runs[
                        0] = self.get_momentums(fp32_from_fp16_params)

                    fp16_set = [
                        fp16_model_grads, fp32_from_fp16_params,
                        fp32_from_fp16_momentums, fp16_model_params
                    ]

                launch_sets = [
                    fp16_set, [fp32_grads, fp32_params, fp32_momentums]
                ]
            else:
                fp16_params = [
                    p for p in group['params']
                    if (p.dtype == torch.float16 and p.grad is not None)
                ]
                fp16_grads = [
                    p.grad for p in group['params']
                    if (p.dtype == torch.float16 and p.grad is not None)
                ]
                fp16_momentums, first_runs[0] = self.get_momentums(fp16_params)

                fp32_params = [
                    p for p in group['params']
                    if (p.dtype == torch.float32 and p.grad is not None)
                ]
                fp32_grads = [
                    p.grad for p in group['params']
                    if (p.dtype == torch.float32 and p.grad is not None)
                ]
                fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)

                launch_sets = [[fp16_grads, fp16_params, fp16_momentums],
                               [fp32_grads, fp32_params, fp32_momentums]]

            for s, (launch_set,
                    first_run) in enumerate(zip(launch_sets, first_runs)):
                assert len(launch_set[0]) == len(launch_set[1])
                assert len(launch_set[0]) == len(launch_set[2])
                if len(launch_set[0]) > 0:
                    multi_tensor_applier(self.multi_tensor_sgd,
                                         self._dummy_overflow_buf, launch_set,
                                         weight_decay, momentum, dampening,
                                         group['lr'], nesterov, first_run,
                                         self.wd_after_momentum,
                                         1.0 / self.most_recent_scale)

        self.most_recent_scale = 1.0
        self.scale_set_by_backward = False

        return loss
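
The closure argument above follows the standard torch.optim contract: when supplied, step() calls it to recompute the loss (and typically the gradients) and returns that loss. A minimal sketch of the pattern, using torch.optim.SGD as a stand-in since the fused optimizer class itself is not shown here:

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
data, target = torch.randn(8, 4), torch.randn(8, 1)

def closure():
    # Recompute the loss and gradients so step() can return the loss.
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(data), target)
    loss.backward()
    return loss

loss = optimizer.step(closure)
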
    def step(self,
             closure=None,
             grads=None,
             output_params=None,
             scale=1.,
             grad_norms=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
            grads (list of tensors, optional): weight gradient to use for the
                optimizer update. If gradients have type torch.half, parameters
                are expected to be in type torch.float. (default: None)
            output_params (list of tensors, optional): A reduced precision copy
                of the updated weights written out in addition to the regular
                updated weights. Must be of the same type as the gradients. (default: None)
            scale (float, optional): factor to divide gradient tensor values
                by before applying to weights. (default: 1)
        """
        loss = None
        if closure is not None:
            loss = closure()

        if hasattr(self, "_amp_stash"):
            grads = self._amp_stash.grads
            output_params = self._amp_stash.output_params
            scale = self._amp_stash.scale * self._amp_scale_adjustment
            grad_norms = self._amp_stash.grad_norms

        if grads is None:
            grads_group = [None] * len(self.param_groups)
        # Backward compatibility:
        # assuming a list/generator of parameters means a single group.
        elif isinstance(grads, types.GeneratorType):
            grads_group = [grads]
        elif type(grads[0]) != list:
            grads_group = [grads]
        else:
            grads_group = grads

        if output_params is None:
            output_params_group = [None] * len(self.param_groups)
        elif isinstance(output_params, types.GeneratorType):
            output_params_group = [output_params]
        elif type(output_params[0]) != list:
            output_params_group = [output_params]
        else:
            output_params_group = output_params

        if grad_norms is None:
            grad_norms = [None] * len(self.param_groups)

        for group, grads_this_group, output_params_this_group, grad_norm in zip(
                self.param_groups, grads_group, output_params_group,
                grad_norms):
            if grads_this_group is None:
                grads_this_group = [None] * len(group['params'])
            if output_params_this_group is None:
                output_params_this_group = [None] * len(group['params'])

            # compute combined scale factor for this group
            combined_scale = scale
            if group['max_grad_norm'] > 0:
                # norm is in fact norm*scale
                clip = ((grad_norm / scale) + 1e-6) / group['max_grad_norm']
                if clip > 1:
                    combined_scale = clip * scale

            bias_correction = 1 if group['bias_correction'] else 0

            if self._use_multi_tensor:
                if output_params:
                    tensorlists = [[], [], [], [], []]
                else:
                    tensorlists = [[], [], [], []]

            for p, grad, output_param in zip(group['params'], grads_this_group,
                                             output_params_this_group):
                # Note: for correct operation of the mixed-precision optimizer,
                # which sometimes sends None gradients, p.grad should never be set here.
                if p.grad is None and grad is None:
                    continue
                if grad is None:
                    grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'FusedAdam does not support sparse gradients, please consider SparseAdam instead'
                    )

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                out_p = torch.tensor(
                    [], dtype=torch.float
                ) if output_param is None else output_param
                if self._use_multi_tensor:
                    pl = [p.data, exp_avg, exp_avg_sq, grad]
                    if output_param is not None:
                        pl.append(out_p)

                    for tl, t in zip(tensorlists, pl):
                        tl.append(t)
                else:
                    fused_adam_cuda.adam(p.data, out_p, exp_avg, exp_avg_sq,
                                         grad, group['lr'], beta1, beta2,
                                         group['eps'], combined_scale,
                                         state['step'], self.eps_mode,
                                         bias_correction,
                                         group['weight_decay'])

            if self._use_multi_tensor:
                multi_tensor_applier(fused_adam_cuda.adam_mt,
                                     self._overflow_buf, tensorlists,
                                     group['lr'], beta1, beta2, group['eps'],
                                     combined_scale, state['step'],
                                     self.eps_mode, bias_correction,
                                     group['weight_decay'])

        return loss
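
The combined_scale logic above folds loss scaling and gradient clipping into a single divisor: grad_norm is the norm of the scaled gradients, so grad_norm / scale is the true norm, and when that exceeds max_grad_norm the kernel divides by clip * scale instead of scale alone. A small CPU-only sketch of the arithmetic (variable names are illustrative, not the optimizer's API):

import torch

scale, max_grad_norm = 1024.0, 1.0
true_grad = torch.randn(100)
scaled_grad = true_grad * scale           # what backward() produces under loss scaling

grad_norm = scaled_grad.norm(2)           # norm of the scaled gradients
clip = ((grad_norm / scale) + 1e-6) / max_grad_norm
combined_scale = clip * scale if clip > 1 else scale

# Dividing by combined_scale both unscales and (if needed) clips the gradients.
unscaled = scaled_grad / combined_scale
assert unscaled.norm(2) <= max_grad_norm + 1e-3
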
Example #4
    def _flatten_grad_mt(self, scale):
        if self._flat_mt and len(self._grads) > 0:
            self._overflow_buf.zero_()
            multi_tensor_applier(amp_C.multi_tensor_scale, self._overflow_buf,
                                 list(zip(*self._grads)), scale)
            self._grads = []
Example #5
        def step(self,
                 closure=None,
                 grads=None,
                 output_params=None,
                 scale=None,
                 grad_norms=None):
            """Performs a single optimization step."""
            loss = None
            if closure is not None:
                loss = closure()

            for group in self.param_groups:
                bias_correction = 1 if group['bias_correction'] else 0
                beta1, beta2 = group['betas']

                # Assume the same step across the group for now to simplify things;
                # a per-parameter step could be supported by making it a tensor or passing a list into the kernel.
                if 'step' in group:
                    group['step'] += 1
                else:
                    group['step'] = 1

                # create lists for multi-tensor apply
                g_16, p_16, orig_p_16, m_16, v_16 = [], [], [], [], []
                g_32, p_32, m_32, v_32 = [], [], [], []

                for p in group['params']:
                    if p.grad is None:
                        continue
                    if p.grad.data.is_sparse:
                        raise RuntimeError(
                            'FusedAdam does not support sparse gradients, '
                            'please consider SparseAdam instead')

                    state = self.state[p]
                    # State initialization
                    if len(state) == 0:
                        # Exponential moving average of gradient values
                        state['exp_avg'] = torch.zeros_like(p.data,
                                                            dtype=torch.float)
                        # Exponential moving average of squared gradient values
                        state['exp_avg_sq'] = torch.zeros_like(
                            p.data, dtype=torch.float)
                    else:
                        state['exp_avg'] = state['exp_avg'].to(
                            device=p.data.device, dtype=torch.float)
                        state['exp_avg_sq'] = state['exp_avg_sq'].to(
                            device=p.data.device, dtype=torch.float)

                    if p.dtype == torch.float16:
                        g_16.append(p.grad.data.float())
                        p_16.append(p.data.float())
                        orig_p_16.append(p.data)
                        m_16.append(state['exp_avg'])
                        v_16.append(state['exp_avg_sq'])
                    elif p.dtype == torch.float32:
                        g_32.append(p.grad.data)
                        p_32.append(p.data)
                        m_32.append(state['exp_avg'])
                        v_32.append(state['exp_avg_sq'])
                    else:
                        raise RuntimeError(
                            'FusedAdam only supports fp16 and fp32.')

                with torch.cuda.device(p.device):
                    if (len(g_16) > 0):
                        multi_tensor_applier(self.multi_tensor_adam,
                                             self._dummy_overflow_buf,
                                             [g_16, p_16, m_16, v_16],
                                             group['lr'], beta1, beta2,
                                             group['eps'], group['step'],
                                             self.adam_w_mode, bias_correction,
                                             group['weight_decay'])
                        for orig_p, p in zip(orig_p_16, p_16):
                            orig_p.copy_(p.data)
                    if (len(g_32) > 0):
                        multi_tensor_applier(self.multi_tensor_adam,
                                             self._dummy_overflow_buf,
                                             [g_32, p_32, m_32, v_32],
                                             group['lr'], beta1, beta2,
                                             group['eps'], group['step'],
                                             self.adam_w_mode, bias_correction,
                                             group['weight_decay'])

            return loss
Example #6
    def step(self, closure=None, scale=1.):
        """Apply Adam optimizer step

        Arguments:
            closure (callable, optional): closure to recompute loss
                (default: None)
            scale (float, optional): scaling factor to divide
                gradients (default: 1.0)

        """
        self.state['step'] += 1
        loss = None
        if closure is not None:
            loss = closure()

        # Make sure that gradients have been reduced
        self.grad_sync()

        # Scale gradient if L2 norm is too large
        if self.max_grad_norm > 0:
            grad_norm = self.grad_norm().item()
            if (math.isfinite(grad_norm)
                and grad_norm / scale > self.max_grad_norm):
                scale = grad_norm / self.max_grad_norm

        # Apply optimizer step to each bucket and synchronize params
        current_stream = torch.cuda.current_stream()
        for stream in self._pipeline_streams:
            stream.wait_stream(current_stream)
        for i, bucket in enumerate(self.state['buckets']):
            stream = self._pipeline_streams[i % self.pipeline_size]
            with torch.cuda.stream(stream):

                # Buffer for param sync
                params_shard_copy = torch.zeros(
                    [self.shard_size],
                    dtype=self.param_sync_dtype,
                    device=self.device,
                )

                # Find param fragments in local shard
                buffers = collections.defaultdict(list) # p, m, v, g, p_copy
                for fragment in bucket['fragments']:
                    if fragment['in_local_shard']:
                        param_group_id = fragment['param_group_id']
                        shard_start, shard_end = fragment['shard_range']
                        buffers[param_group_id].append([
                            bucket['params_shard'][shard_start:shard_end],
                            bucket['exp_avg_shard'][shard_start:shard_end],
                            bucket['exp_avg_sq_shard'][shard_start:shard_end],
                            bucket['grads_shard'][shard_start:shard_end],
                            params_shard_copy[shard_start:shard_end],
                        ])

                # Fuse param fragments if possible
                if len(buffers) == 1:
                    group_id = list(buffers.keys())[0]
                    buffers[group_id] = [(
                        bucket['params_shard'],
                        bucket['exp_avg_shard'],
                        bucket['exp_avg_sq_shard'],
                        bucket['grads_shard'],
                        params_shard_copy,
                    )]

                # Apply optimizer step to each param group
                for group_id, group_buffers in buffers.items():

                    # Get param group configs
                    group = self.param_groups[group_id]
                    beta1, beta2 = group['betas']
                    bias_correction = 1 if group['bias_correction'] else 0
                    eps = group['eps']
                    weight_decay = group['weight_decay']

                    # Copy param group configs to GPU
                    num_fragments = len(group_buffers)
                    beta1 = torch.full([num_fragments], beta1, dtype=self.dtype, device='cuda')
                    beta2 = torch.full([num_fragments], beta2, dtype=self.dtype, device='cuda')
                    bias_correction = torch.full([num_fragments], bias_correction, dtype=torch.int32, device='cuda')
                    eps = torch.full([num_fragments], eps, dtype=self.dtype, device='cuda')
                    weight_decay = torch.full([num_fragments], weight_decay, dtype=self.dtype, device='cuda')

                    # Apply Adam step
                    dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
                    multi_tensor_applier(
                        distributed_adam_cuda.multi_tensor_fused_adam,
                        dummy_overflow_buf,
                        list(zip(*group_buffers)),
                        beta1,
                        beta2,
                        bias_correction,
                        eps,
                        weight_decay,
                        group['lr'],
                        scale,
                        self.state['step'],
                        1, # Set to 0 to apply eps inside sqrt
                    )

                # Deallocate buffers
                del buffers
                bucket['grads_shard'] = None

                # Allgather updated parameters
                if self.distributed_size == 1:
                    params_bucket = params_shard_copy
                else:
                    params_bucket = torch.zeros(
                        [self.bucket_size],
                        dtype=self.param_sync_dtype,
                        device=self.device,
                    )
                    params_bucket_shards = [
                        params_bucket[i*self.shard_size:(i+1)*self.shard_size]
                        for i in range(self.distributed_size)
                    ]
                    params_bucket_shards[self.distributed_rank].copy_(params_shard_copy)
                    if self._all_gather_no_copy:
                        no_copy_kwarg = { 'no_copy': True }
                    else:
                        no_copy_kwarg = {}
                    torch.distributed.all_gather(
                        params_bucket_shards,
                        params_bucket_shards[self.distributed_rank],
                        group=self.distributed_process_group,
                        **no_copy_kwarg,
                    )
                del params_shard_copy

                # Copy values to param buffers
                params_in = []
                params_out = []
                for fragment in bucket['fragments']:
                    param_group_id = fragment['param_group_id']
                    param_id = fragment['param_id']
                    param = self.param_groups[param_group_id]['params'][param_id]
                    bucket_start, bucket_end = fragment['bucket_range']
                    param_start, param_end = fragment['param_range']
                    params_in.append(params_bucket[bucket_start:bucket_end])
                    params_out.append(param.view(-1)[param_start:param_end])
                if params_in:
                    dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
                    multi_tensor_applier(
                        fused_adam_cuda.maybe_cast_mt,
                        dummy_overflow_buf,
                        [params_in, params_out],
                    )
                del params_bucket, params_in, params_out

        # Synchronize pipeline streams
        for stream in self._pipeline_streams:
            current_stream.wait_stream(stream)

        return loss
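
The all-gather above relies on the flat bucket buffer being carved into per-rank shard views: each rank writes its updated shard into its own slice, and all_gather fills the slices owned by the other ranks. A single-process sketch of just that slicing layout, with made-up sizes and no torch.distributed call:

import torch

distributed_size, shard_size = 4, 6
bucket_size = distributed_size * shard_size

params_bucket = torch.zeros(bucket_size)
shards = [params_bucket[i * shard_size:(i + 1) * shard_size]
          for i in range(distributed_size)]

# Each rank fills only its own shard locally; all_gather would fill the rest.
my_rank = 2
shards[my_rank].copy_(torch.full([shard_size], float(my_rank)))

# The shards are views, so the flat bucket already reflects the local copy.
assert params_bucket[my_rank * shard_size] == my_rank
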
Example #7
    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        # create separate grad lists for fp32 and fp16 params
        g_all_32, g_all_16 = [], []
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                if p.dtype == torch.float32:
                    g_all_32.append(p.grad.data)
                elif p.dtype == torch.float16:
                    g_all_16.append(p.grad.data)
                else:
                    raise RuntimeError('FusedLAMB only supports fp16 and fp32.')

        g_norm_32, g_norm_16 = torch.zeros(1, device='cuda'), torch.zeros(
            1, device='cuda')
        # compute grad norm for two lists
        if len(g_all_32) > 0:
            g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
                                             self._dummy_overflow_buf,
                                             [g_all_32], False)[0]
        if len(g_all_16) > 0:
            g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm,
                                             self._dummy_overflow_buf,
                                             [g_all_16], False)[0]

        # blend two grad norms to get global grad norm
        global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm,
                                                self._dummy_overflow_buf,
                                                [[g_norm_32, g_norm_16]],
                                                False)[0]

        max_grad_norm = 1.0
        clipped_ratio = max_grad_norm / max(global_grad_norm, max_grad_norm)

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                p.grad.data *= clipped_ratio
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'Lamb does not support sparse gradients, consider SparseAdam instead.'
                    )

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['m'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['v'] = torch.zeros_like(p.data)

                m_t, v_t = state['m'], state['v']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # m_t = beta1 * m + (1 - beta1) * g_t
                m_t.mul_(beta1).add_(grad, alpha=1 - beta1)
                # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
                v_t.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Debiasing
                m_t_hat = m_t / (1.0 - beta1**state['step'])
                v_t_hat = v_t / (1.0 - beta2**state['step'])

                update = m_t_hat / v_t_hat.sqrt().add(group['eps'])

                if group['weight_decay'] != 0:
                    update.add_(p.data, alpha=group['weight_decay'])

                trust_ratio = 1.0
                w_norm = p.data.pow(2).sum().sqrt()
                g_norm = update.pow(2).sum().sqrt()
                if w_norm > 0 and g_norm > 0:
                    trust_ratio = w_norm / g_norm

                state['w_norm'] = w_norm
                state['g_norm'] = g_norm
                state['trust_ratio'] = trust_ratio

                step_size = group['lr']

                p.data.add_(update, alpha=-step_size * trust_ratio)

        return loss
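
The debiasing above has a simple consequence: on the very first step the bias-corrected moments reduce to the raw gradient and its elementwise square, so before the trust ratio is applied the update is roughly g / (|g| + eps) elementwise. A quick standalone numeric check of that property (not the optimizer's API):

import torch

beta1, beta2 = 0.9, 0.999
g = torch.randn(10)
m = torch.zeros(10)
v = torch.zeros(10)

# First step of the moving averages, as in the loop above.
m.mul_(beta1).add_(g, alpha=1 - beta1)
v.mul_(beta2).addcmul_(g, g, value=1 - beta2)

m_hat = m / (1 - beta1 ** 1)
v_hat = v / (1 - beta2 ** 1)

assert torch.allclose(m_hat, g, rtol=1e-4)
assert torch.allclose(v_hat, g * g, rtol=1e-4)
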
Example #8
def clip_grad_norm_(
        parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,
        error_if_nonfinite: bool = False) -> torch.Tensor:
    r"""Clips gradient norm of an iterable of parameters.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Gradients are modified in-place.

    This is identical to torch.nn.utils.clip_grad_norm_, except it
    uses a fused CUDA kernel when computing the 2-norm of GPU tensors
    in float32 and float16.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
        error_if_nonfinite (bool): if True, an error is thrown if the total
            norm of the gradients from :attr:`parameters` is ``nan``,
            ``inf``, or ``-inf``. Default: False (will switch to True in the future)

    Returns:
        Total norm of the parameters (viewed as a single vector).

    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)

    # Trivial case
    if len(parameters) == 0:
        return torch.tensor(0.)

    # Fallback implementation
    if not (_kernel_import_succeeded
            and norm_type == 2.0
            and any(p.is_cuda for p in parameters)):
        return torch.nn.utils.clip_grad_norm_(
            parameters,
            max_norm,
            norm_type=norm_type,
            error_if_nonfinite = error_if_nonfinite,
        )

    # Find fp32 and fp16 gradients on GPU
    device = next(p.device for p in parameters if p.is_cuda)
    grads_fp32, grads_fp16, grads_misc = [], [], []
    for p in parameters:
        grad = p.grad.detach()
        if p.dtype == torch.float32 and p.device == device:
            grads_fp32.append(grad)
        elif p.dtype == torch.float16 and p.device == device:
            grads_fp16.append(grad)
        else:
            grads_misc.append(grad)

    # Compute gradient L2 norms
    norms = []
    dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device=device)
    if grads_fp32:
        norms.append(
            multi_tensor_applier(
                amp_C.multi_tensor_l2norm,
                dummy_overflow_buf,
                [grads_fp32],
                False,
            )[0]
        )
    if grads_fp16:
        norms.append(
            multi_tensor_applier(
                amp_C.multi_tensor_l2norm,
                dummy_overflow_buf,
                [grads_fp16],
                False,
            )[0],
        )
    for g in grads_misc:
        norms.append(torch.linalg.norm(g).unsqueeze(0).to(device))
    total_norm = torch.linalg.norm(torch.cat(norms))

    # Check for non-finite values
    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
        raise RuntimeError(
            f'The total norm of order {norm_type} for gradients from '
            '`parameters` is non-finite, so it cannot be clipped. To disable '
            'this error and scale the gradients by the non-finite norm anyway, '
            'set `error_if_nonfinite=False`')

    # Scale gradients
    clip_coef = max_norm / (total_norm + 1e-6)
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    if grads_fp32:
        multi_tensor_applier(
            amp_C.multi_tensor_scale,
            dummy_overflow_buf,
            [grads_fp32, grads_fp32],
            clip_coef_clamped,
        )
    if grads_fp16:
        multi_tensor_applier(
            amp_C.multi_tensor_scale,
            dummy_overflow_buf,
            [grads_fp16, grads_fp16],
            clip_coef_clamped,
        )
    for g in grads_misc:
        g.mul_(clip_coef_clamped.to(g.device))

    return total_norm
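
A minimal usage sketch for the function above, assuming the clip_grad_norm_ defined here is in scope; CPU-only gradients take the torch.nn.utils fallback path, while CUDA fp16/fp32 gradients would go through the fused amp_C kernels:

import torch

model = torch.nn.Linear(16, 16)
loss = model(torch.randn(2, 16)).sum()
loss.backward()

# Gradients are clipped in place; the pre-clipping norm is returned.
total_norm = clip_grad_norm_(model.parameters(), max_norm=1.0)
print(float(total_norm))
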
Example #9
    def step(self, closure=None, grad_scaler=None):
        loss = None
        if closure is not None:
            loss = closure()

        assert grad_scaler is not None, "FusedLAMBAMP requires `GradScaler` instance"

        # create separate grad lists for fp32 and fp16 params
        g_all_32, g_all_16 = [], []
        for gid, (group, fp32_group) in enumerate(
                zip(self.param_groups, self.param_groups_fp32)):
            for pid, (p, fp32_p) in enumerate(
                    zip(group['params'], fp32_group['params'])):
                if p.grad is None:
                    continue
                assert p.dtype in (torch.float16, torch.float32)
                if p.dtype == torch.float32:
                    g_all_32.append(p.grad)
                else:  # p.dtype == torch.float16:
                    g_all_16.append(p.grad)
        device = self.param_groups[0]["params"][0].device
        g_norm_32, g_norm_16 = torch.zeros(1, device=device), torch.zeros(
            1, device=device)
        # compute grad norm for two lists
        if len(g_all_32) > 0:
            g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
                                             self._dummy_overflow_buf,
                                             [g_all_32], False)[0]
        if len(g_all_16) > 0:
            g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm,
                                             self._dummy_overflow_buf,
                                             [g_all_16], False)[0]

        # blend two grad norms to get global grad norm
        global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm,
                                                self._dummy_overflow_buf,
                                                [[g_norm_32, g_norm_16]],
                                                False)[0]
        max_grad_norm = self.defaults['max_grad_norm']

        found_inf_per_device = grad_scaler._check_inf_per_device(self)
        assert found_inf_per_device
        found_inf = found_inf_per_device[device]
        inv_scale = grad_scaler._get_scale_async().double().reciprocal().float(
        )
        # Run LAMB optimization math
        for gid, (group, fp32_group) in enumerate(
                zip(self.param_groups, self.param_groups_fp32)):
            bias_correction = 1 if group['bias_correction'] else 0
            beta1, beta2 = group['betas']
            grad_averaging = 1 if group['grad_averaging'] else 0

            # Assume the same step across the group for now to simplify things;
            # a per-parameter step could be supported by making it a tensor or passing a list into the kernel.
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            # create lists for multi-tensor apply
            g_16, p_16, m_16, v_16, dst_param_fp16 = [], [], [], [], []
            g_32, p_32, m_32, v_32, dst_param_fp32 = [], [], [], [], []

            for p, p_fp32 in zip(group['params'], fp32_group['params']):
                if p.grad is None:
                    continue
                assert not p.grad.is_sparse

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data,
                                                        dtype=torch.float32)
                    # Exponential moving average of gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data,
                                                           dtype=torch.float32)

                if p.dtype == torch.float16:
                    g_16.append(p.grad.data)
                    p_16.append(p_fp32.data)
                    m_16.append(state['exp_avg'])
                    v_16.append(state['exp_avg_sq'])
                    dst_param_fp16.append(p.data)
                elif p.dtype == torch.float32:
                    g_32.append(p.grad.data)
                    p_32.append(p_fp32.data)
                    m_32.append(state['exp_avg'])
                    v_32.append(state['exp_avg_sq'])
                    dst_param_fp32.append(p.data)
                else:
                    raise RuntimeError('FusedLAMB only supports fp16 and fp32.')

            if g_16:
                multi_tensor_applier(fused_lamb_CUDA.multi_tensor_lamb_out,
                                     self._dummy_overflow_buf,
                                     [g_16, p_16, m_16, v_16, dst_param_fp16],
                                     group['lr'], beta1, beta2, group['eps'],
                                     group['step'], bias_correction,
                                     group['weight_decay'], grad_averaging,
                                     self.adam_w_mode, global_grad_norm,
                                     max_grad_norm, self.use_nvlamb, found_inf,
                                     inv_scale)
            if g_32:
                multi_tensor_applier(fused_lamb_CUDA.multi_tensor_lamb_out,
                                     self._dummy_overflow_buf,
                                     [g_32, p_32, m_32, v_32, dst_param_fp32],
                                     group['lr'], beta1, beta2, group['eps'],
                                     group['step'], bias_correction,
                                     group['weight_decay'], grad_averaging,
                                     self.adam_w_mode, global_grad_norm,
                                     max_grad_norm, self.use_nvlamb, found_inf,
                                     inv_scale)

        return loss
Example #10
    def step(self, closure=None, grad_scaler=None):
        loss = None
        if closure is not None:
            loss = closure()

        # create separate grad lists for fp32 and fp16 params
        g_all_32, g_all_16 = [], []
        for gid, (group, fp32_group) in enumerate(
                zip(self.param_groups, self.param_groups_fp32)):
            for pid, (p, fp32_p) in enumerate(
                    zip(group['params'], fp32_group['params'])):
                if p.grad is None:
                    continue
                assert p.dtype in (torch.float16, torch.float32)
                if p.dtype == torch.float32:
                    g_all_32.append(p.grad)
                else:  # p.dtype == torch.float16:
                    g_all_16.append(p.grad)
        device = self.param_groups[0]["params"][0].device
        found_inf = (grad_scaler._check_inf_per_device(self)[device]
                     if grad_scaler is not None else torch.zeros(
                         (1, ), device=device))
        self._dummy_overflow_buf.copy_(found_inf)
        scale, inv_scale = None, None
        if grad_scaler:
            scale = grad_scaler._get_scale_async()
            inv_scale = scale.double().reciprocal().float()
        else:
            scale = torch.ones((1, ), device=device)
            inv_scale = torch.ones((1, ), device=device)
        # g_norm_32, g_norm_16 = torch.zeros(1, device=device), torch.zeros(1, device=device)
        g_norm_32, g_norm_16 = None, None
        # compute grad norm for two lists
        # NOTE(mkozuki): g_all_16, g_all_32, and global_grad_norm are norms of scaled gradients.
        # So, multiply `max_grad_norm` by scale.
        max_grad_norm = self.defaults['max_grad_norm'] * scale
        if len(g_all_32) > 0:
            g_norm_32 = multi_tensor_applier(
                fused_lamb_CUDA.multi_tensor_l2norm,
                self._dummy_overflow_buf,
                [g_all_32],
                False,
            )[0]
        else:
            g_norm_32 = torch.zeros((1, ), device=device)
        if len(g_all_16) > 0:
            g_norm_16 = multi_tensor_applier(
                fused_lamb_CUDA.multi_tensor_l2norm,
                self._dummy_overflow_buf,
                [g_all_16],
                False,
            )[0]
        else:
            g_norm_16 = torch.zeros((1, ), device=device)

        # blend two grad norms to get global grad norm
        global_grad_norm = multi_tensor_applier(
            fused_lamb_CUDA.multi_tensor_l2norm,
            self._dummy_overflow_buf,
            [[g_norm_32, g_norm_16]],
            False,
        )[0]

        # Run LAMB optimization math
        for gid, (group, fp32_group) in enumerate(
                zip(self.param_groups, self.param_groups_fp32)):
            bias_correction = 1 if group['bias_correction'] else 0
            beta1, beta2 = group['betas']
            grad_averaging = 1 if group['grad_averaging'] else 0

            # Assume the same step across the group for now to simplify things;
            # a per-parameter step could be supported by making it a tensor or passing a list into the kernel.
            if 'step' in group:
                group['step'] += (self._dummy_overflow_buf != 1).int()
            else:
                group['step'] = (self._dummy_overflow_buf != 1).int()

            # create lists for multi-tensor apply
            g_16, p_16, m_16, v_16, dst_param_fp16 = [], [], [], [], []
            g_32, p_32, m_32, v_32 = [], [], [], []

            for p, p_fp32 in zip(group['params'], fp32_group['params']):
                if p.grad is None:
                    continue
                assert not p.grad.is_sparse

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    dtype = torch.float if p.dtype == torch.half else p.dtype
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data, dtype=dtype)
                    # Exponential moving average of gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=dtype)

                if p.dtype == torch.float16:
                    g_16.append(p.grad.data)
                    p_16.append(p_fp32.data)
                    m_16.append(state['exp_avg'])
                    v_16.append(state['exp_avg_sq'])
                    dst_param_fp16.append(p.data)
                elif p.dtype == torch.float32:
                    assert p_fp32 is None
                    g_32.append(p.grad.data)
                    p_32.append(p.data)
                    m_32.append(state['exp_avg'])
                    v_32.append(state['exp_avg_sq'])
                else:
                    raise RuntimeError('FusedLAMB only supports fp16 and fp32.')

            if g_16:
                multi_tensor_applier(fused_lamb_CUDA.multi_tensor_lamb,
                                     self._dummy_overflow_buf,
                                     [g_16, p_16, m_16, v_16, dst_param_fp16],
                                     group['lr'], beta1, beta2, group['eps'],
                                     group['step'], bias_correction,
                                     group['weight_decay'], grad_averaging,
                                     self.adam_w_mode, global_grad_norm,
                                     max_grad_norm, self.use_nvlamb, found_inf,
                                     inv_scale)
            if g_32:
                multi_tensor_applier(fused_lamb_CUDA.multi_tensor_lamb,
                                     self._dummy_overflow_buf,
                                     [g_32, p_32, m_32, v_32], group['lr'],
                                     beta1, beta2, group['eps'], group['step'],
                                     bias_correction, group['weight_decay'],
                                     grad_averaging, self.adam_w_mode,
                                     global_grad_norm, max_grad_norm,
                                     self.use_nvlamb, found_inf, inv_scale)

        return loss
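
The NOTE above about multiplying max_grad_norm by scale works because norms are homogeneous: scaling every gradient by scale scales the global norm by the same factor, so comparing the scaled norm against a scaled threshold is equivalent to unscaling first. A tiny standalone check of that equivalence:

import torch

scale, max_grad_norm = 512.0, 1.0
grads = [torch.randn(7), torch.randn(3)]
scaled = [g * scale for g in grads]

norm = torch.cat([g.flatten() for g in grads]).norm(2)
scaled_norm = torch.cat([g.flatten() for g in scaled]).norm(2)

# Same clipping decision whether we unscale the norm or scale the threshold.
assert (scaled_norm > max_grad_norm * scale) == (norm > max_grad_norm)
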
Example #11
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            bias_correction = 1 if group['bias_correction'] else 0
            beta1, beta2 = group['betas']
            grad_averaging = 1 if group['grad_averaging'] else 0

            # Assume the same step across the group for now to simplify things;
            # a per-parameter step could be supported by making it a tensor or passing a list into the kernel.
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            # create lists for multi-tensor apply
            g_16, p_16, m_16 = [], [], []
            g_32, p_32, m_32 = [], [], []

            for p in group['params']:
                if p.grad is None:
                    continue
                if p.grad.data.is_sparse:
                    raise RuntimeError('FusedNovoGrad does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)

                if p.dtype == torch.float16:
                    g_16.append(p.grad.data)
                    p_16.append(p.data)
                    m_16.append(state['exp_avg'])
                elif p.dtype == torch.float32:
                    g_32.append(p.grad.data)
                    p_32.append(p.data)
                    m_32.append(state['exp_avg'])
                else:
                    raise RuntimeError('FusedNovoGrad only supports fp16 and fp32.')

            # We store the per-weight norm as one tensor per group/precision combination.
            # Unlike optim.Adam, we store the norm here (not norm^2) so we can unify the calculation for both norm types.
            if 'exp_avg_sq' not in group:
                group['exp_avg_sq'] = [None, None]
                if group['init_zero']:
                    # Creating the following parameters on the same device as the params tensors.
                    group['exp_avg_sq'][0] = torch.cuda.FloatTensor(len(g_16), device=self.param_groups[0]["params"][0].device).contiguous().fill_(0)
                    group['exp_avg_sq'][1] = torch.cuda.FloatTensor(len(g_32), device=self.param_groups[0]["params"][0].device).contiguous().fill_(0)
                else:  # init with the first-step norm, so the first blend has no effect
                    if group['norm_type'] == 0:
                        v_16 = [torch.max(torch.abs(g.to(torch.float32))).item() for g in g_16]
                        v_32 = [torch.max(torch.abs(g)).item() for g in g_32]
                    elif group['norm_type'] == 2:
                        v_16 = [torch.sum(torch.pow(g.to(torch.float32), 2)).sqrt().item() for g in g_16]
                        v_32 = [torch.sum(torch.pow(g, 2)).sqrt().item() for g in g_32]
                    else:
                        raise RuntimeError('FusedNovoGrad only support l2/inf norm now.')
                    # Creating the following parameters on the same device as the params tensors.
                    group['exp_avg_sq'][0] = torch.cuda.FloatTensor(v_16, device=self.param_groups[0]["params"][0].device)
                    group['exp_avg_sq'][1] = torch.cuda.FloatTensor(v_32, device=self.param_groups[0]["params"][0].device)
            else:
                assert(len(g_16) == group['exp_avg_sq'][0].numel())
                assert(len(g_32) == group['exp_avg_sq'][1].numel())

            if(len(g_16) > 0):
                multi_tensor_applier(self.multi_tensor_novograd,
                                     self._dummy_overflow_buf,
                                     [g_16, p_16, m_16],
                                     group['exp_avg_sq'][0],
                                     group['lr'],
                                     beta1,
                                     beta2,
                                     group['eps'],
                                     group['step'],
                                     bias_correction,
                                     group['weight_decay'],
                                     grad_averaging,
                                     self.moment_mode,
                                     group['norm_type'])
            if(len(g_32) > 0):
                multi_tensor_applier(self.multi_tensor_novograd,
                                     self._dummy_overflow_buf,
                                     [g_32, p_32, m_32],
                                     group['exp_avg_sq'][1],
                                     group['lr'],
                                     beta1,
                                     beta2,
                                     group['eps'],
                                     group['step'],
                                     bias_correction,
                                     group['weight_decay'],
                                     grad_averaging,
                                     self.moment_mode,
                                     group['norm_type'])


        return loss
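
The comment above about storing the per-weight norm as one tensor per group/precision pair is the key structural difference from Adam: instead of an elementwise second moment, NovoGrad keeps a single scalar norm per parameter tensor. A small standalone sketch of the two layouts (illustrative names only):

import torch

grads = [torch.randn(3, 4), torch.randn(10)]

# Adam-style second moment: one value per gradient element.
adam_v = [g * g for g in grads]

# NovoGrad-style: one scalar norm per weight tensor, kept in a single vector.
novograd_v = torch.stack([g.norm(2) for g in grads])
print(novograd_v.shape)  # torch.Size([2]) -- one entry per parameter tensor
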
Example #12
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            bias_correction = 1 if group['bias_correction'] else 0
            beta1, beta2 = group['betas']
            grad_averaging = 1 if group['grad_averaging'] else 0

            # Assume the same step across the group for now to simplify things;
            # a per-parameter step could be supported by making it a tensor or passing a list into the kernel.
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            # create lists for multi-tensor apply
            g_16, p_16, m_16, v_16 = [], [], [], []
            g_32, p_32, m_32, v_32 = [], [], [], []

            for p in group['params']:
                if p.grad is None:
                    continue
                if p.grad.data.is_sparse:
                    raise RuntimeError(
                        'FusedLAMB does not support sparse gradients, please consider SparseAdam instead'
                    )

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                if p.dtype == torch.float16:
                    g_16.append(p.grad.data)
                    p_16.append(p.data)
                    m_16.append(state['exp_avg'])
                    v_16.append(state['exp_avg_sq'])
                elif p.dtype == torch.float32:
                    g_32.append(p.grad.data)
                    p_32.append(p.data)
                    m_32.append(state['exp_avg'])
                    v_32.append(state['exp_avg_sq'])
                else:
                    raise RuntimeError('FusedLAMB only supports fp16 and fp32.')

            if (len(g_16) > 0):
                multi_tensor_applier(self.multi_tensor_lamb,
                                     self._dummy_overflow_buf,
                                     [g_16, p_16, m_16, v_16], group['lr'],
                                     beta1, beta2, group['eps'], group['step'],
                                     bias_correction, group['weight_decay'],
                                     grad_averaging, self.adam_w_mode,
                                     group['max_grad_norm'])
            if (len(g_32) > 0):
                multi_tensor_applier(self.multi_tensor_lamb,
                                     self._dummy_overflow_buf,
                                     [g_32, p_32, m_32, v_32], group['lr'],
                                     beta1, beta2, group['eps'], group['step'],
                                     bias_correction, group['weight_decay'],
                                     grad_averaging, self.adam_w_mode,
                                     group['max_grad_norm'])

        return loss
    def step(self,
             closure=None,
             grads=None,
             output_params=None,
             scale=1.,
             grad_norms=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
            grads (list of tensors, optional): weight gradient to use for the
                optimizer update. If gradients have type torch.half, parameters
                are expected to be in type torch.float. (default: None)
            output_params (list of tensors, optional): A reduced precision copy
                of the updated weights, written out in addition to the regular
                updated weights. Must be of the same type as the gradients. (default: None)
            scale (float, optional): factor to divide gradient tensor values
                by before applying to weights. (default: 1)
        """
        if hasattr(self, "_amp_stash"):
            raise RuntimeError(
                'apex.contrib.optimizers.FusedSGD should not be used with AMP.'
            )

        loss = None
        if closure is not None:
            loss = closure()

        if grads is None:
            raise RuntimeError(
                'apex.contrib.optimizers.FusedSGD must be wrapped '
                'with apex.contrib.optimizers.FP16_Optimizer '
                'which provides grads.')
        # backward compatibility
        # assuming a list/generator of parameter means single group
        elif isinstance(grads, types.GeneratorType):
            grads_group = [grads]
        elif type(grads[0]) != list:
            grads_group = [grads]
        else:
            grads_group = grads

        if output_params is None:
            raise RuntimeError(
                'apex.contrib.optimizers.FusedSGD must be wrapped '
                'with apex.contrib.optimizers.FP16_Optimizer '
                'which provides output_params.')
        elif isinstance(output_params, types.GeneratorType):
            output_params_group = [output_params]
        elif type(output_params[0]) != list:
            output_params_group = [output_params]
        else:
            output_params_group = output_params

        for group, grads_this_group, output_params_this_group in zip(
                self.param_groups, grads_group, output_params_group):
            if grads_this_group is None or output_params_this_group is None:
                raise RuntimeError(
                    'apex.contrib.optimizers.FusedSGD only works '
                    'when all parameters require grad.')

            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
            lr = group['lr']

            first_runs = [True, True]

            # output_params_this_group: original weights (either fp16 or fp32)
            # group['params']: master weights (fp32)

            # grad_type, param_to_update_type, momentum_type, requires_fp16_model_copy
            # fp32, fp32, fp32, No
            fp32_grads = [
                g for (p, g) in zip(output_params_this_group, grads_this_group)
                if p.dtype == torch.float32
            ]
            fp32_params = [
                p2
                for (p1, p2) in zip(output_params_this_group, group['params'])
                if p1.dtype == torch.float32
            ]
            fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)
            fp32_set = [fp32_grads, fp32_params, fp32_momentums]

            # fp16, fp32, fp32, Yes
            fp16_grads = [
                g for (p, g) in zip(output_params_this_group, grads_this_group)
                if p.dtype == torch.float16
            ]
            fp32_from_fp16_params = [
                p2
                for (p1, p2) in zip(output_params_this_group, group['params'])
                if p1.dtype == torch.float16
            ]
            fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(
                fp32_from_fp16_params)
            fp16_params = [
                p1
                for (p1, p2) in zip(output_params_this_group, group['params'])
                if p1.dtype == torch.float16
            ]
            fp16_set = [
                fp16_grads, fp32_from_fp16_params, fp32_from_fp16_momentums,
                fp16_params
            ]

            launch_sets = [fp16_set, fp32_set]

            for launch_set, first_run in zip(launch_sets, first_runs):
                assert len(launch_set[0]) == len(launch_set[1])
                assert len(launch_set[0]) == len(launch_set[2])
                if len(launch_set[0]) > 0:
                    multi_tensor_applier(self.multi_tensor_sgd,
                                         self._dummy_overflow_buf, launch_set,
                                         weight_decay, momentum, dampening, lr,
                                         nesterov, first_run,
                                         self.wd_after_momentum, 1.0 / scale)

        return loss
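
# An unfused, single-tensor reference of the update that multi_tensor_sgd
# performs above, assuming the classical torch.optim.SGD formulation with
# weight decay applied before momentum (i.e. wd_after_momentum=False). The
# fused kernel additionally rescales gradients by 1/scale and writes the fp16
# model copy; sgd_reference is an illustrative name, not part of apex.
def sgd_reference(p, g, buf, lr, weight_decay, momentum, dampening, nesterov,
                  first_run):
    d_p = g + weight_decay * p
    if first_run:
        # First call for this bucket: momentum buffer starts as the gradient
        buf.copy_(d_p)
    else:
        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
    step_dir = (d_p + momentum * buf) if nesterov else buf
    p.add_(step_dir, alpha=-lr)
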
    def step(self, loss, optimizer, scheduler, update=True):
        """
        Performs one step of the optimizer.
        Applies loss scaling, computes gradients in fp16, converts gradients to
        fp32, inverts scaling and applies optional gradient norm clipping.
        If gradients are finite, it applies update to fp32 master weights and
        copies updated parameters to fp16 model for the next iteration. If
        gradients are not finite, it skips the batch and adjusts scaling factor
        for the next iteration.

        :param loss: value of the loss function
        :param optimizer: optimizer applied to the fp32 master weights
        :param scheduler: learning rate scheduler, stepped only when gradients are finite
        :param update: if True, executes the weight update
        """
        loss *= self.loss_scale
        loss.backward()

        if not update:
            return

        # Average the all-reduced gradients by world size if APEX
        # doesn't do that
        scaling_factor = self.loss_scale
        if hasattr(self.fp16_model, 'gradient_average') and \
                not self.fp16_model.gradient_average:
            scaling_factor *= self.world_size

        # APEX DDP resets the gradients to be views into allreduce_buffers,
        # so downstream code can simply use the .grad attributes as usual
        if isinstance(optimizer, FusedAdam):
            if self.world_size != 1 and self.fp16_model.retain_allreduce_buffers:
                grads = [p.grad for p in self.fp16_params]
                norm, _ = multi_tensor_applier(
                        multi_tensor_l2norm,
                        self.dummy_overflow_buf,
                        [grads],
                        False)
                norm = norm.item() / scaling_factor
            else:
                self.fp16_to_fp16_flat_grad(self.fp16_params, self.fp16_model)
                grads = [self.fp16_params.grad]
                norm = self.fp16_params.grad.data.norm(p=2,
                    dtype=torch.float).item() / scaling_factor
        else:
            self.fp16_to_fp32_flat_grad(self.fp32_params, self.fp16_model)
            if scaling_factor != 1.0:
                self.fp32_params.grad.data /= scaling_factor

            norm = clip_grad_norm_([self.fp32_params], self.grad_clip)

        if math.isfinite(norm):
            if scheduler is not None:
                scheduler.step()

            if isinstance(optimizer, FusedAdam):
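                # `scale` is the factor FusedAdam divides gradients by, so the
                # coefficient below folds together removing the loss scale
                # (scaling_factor) and, when norm exceeds grad_clip, an extra
                # norm / grad_clip clipping factor.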
                clip_coef = self.grad_clip / (norm + 1e-6)
                clip_coef = scaling_factor / min(1, clip_coef)
                if self.use_mt:
                    optimizer.step(grads=grads, output_params=self.fp16_params, scale=clip_coef)
                else:
                    optimizer.step(grads=grads, scale=clip_coef)
            else:
                optimizer.step()

            # Unflatten params if not multi-tensor apply
            if not self.use_mt:
                self.fp32_to_fp16_params(self.fp16_model, self.fp32_params)
            self.since_last_invalid += 1
        else:
            self.loss_scale /= self.dls_downscale
            self.since_last_invalid = 0
            logging.info(f'Gradient norm: {norm}')
            logging.info(f'Skipped batch, new scale: {self.loss_scale}')

        if self.since_last_invalid >= self.dls_upscale_interval:
            self.loss_scale *= self.dls_upscale
            self.loss_scale = min(self.loss_scale, 8192.0)
            logging.info(f'Upscaling, new scale: {self.loss_scale}')
            self.since_last_invalid = 0

        for p in self.fp16_model.parameters():
            p.grad = None
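
# A minimal, self-contained sketch of the dynamic loss scaling policy the
# step() above implements: shrink the scale when a non-finite gradient norm is
# seen, grow it again after a run of good steps. Only the 8192.0 cap is taken
# from the code above; the other defaults here are illustrative.
import math

def update_loss_scale(norm, loss_scale, since_last_invalid,
                      dls_downscale=2.0, dls_upscale=2.0,
                      dls_upscale_interval=128, max_scale=8192.0):
    if math.isfinite(norm):
        since_last_invalid += 1
    else:
        # Non-finite gradients: the batch is skipped and the scale lowered
        loss_scale /= dls_downscale
        since_last_invalid = 0
    if since_last_invalid >= dls_upscale_interval:
        loss_scale = min(loss_scale * dls_upscale, max_scale)
        since_last_invalid = 0
    return loss_scale, since_last_invalid
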
    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        grad_list = []
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                # collect gradients (not parameters) for the global norm
                grad_list.append(p.grad.data)

        dummy_overflow_buf = torch.cuda.IntTensor([0])
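        # multi_tensor_l2norm returns a (total_norm, per_tensor_norms) pair;
        # index [0] picks out the scalar total L2 norm over all gradients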
        global_grad_norm = multi_tensor_applier(multi_tensor_l2norm,
                                                dummy_overflow_buf,
                                                [grad_list], False)[0].item()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'Adam does not support sparse gradients, please consider SparseAdam instead'
                    )

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['next_m'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['next_v'] = torch.zeros_like(p.data)

                next_m, next_v = state['next_m'], state['next_v']
                beta1, beta2 = group['b1'], group['b2']

                # Global gradient clipping: rescale the gradient (not the
                # parameter) when the global norm exceeds max_grad_norm
                if global_grad_norm > group['max_grad_norm']:
                    grad = grad * group['max_grad_norm'] / global_grad_norm

                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time
                next_m.mul_(beta1).add_(grad, alpha=1 - beta1)
                next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                update = next_m / (next_v.sqrt() + group['e'])

                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
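                # (An unfused sketch of this decoupled decay follows after this
                # step() method.)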
                if group['weight_decay'] > 0.0:
                    update += group['weight_decay'] * p.data

                if group['t_total'] != -1:
                    schedule_fct = SCHEDULES[group['schedule']]
                    lr_scheduled = group['lr'] * schedule_fct(
                        state['step'] / group['t_total'], group['warmup'])
                else:
                    lr_scheduled = group['lr']

                update_with_lr = lr_scheduled * update
                p.data.add_(-update_with_lr)

                state['step'] += 1

        return loss
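
# A minimal sketch of the decoupled (AdamW-style) weight decay used in the
# step() above, as opposed to adding an L2 term to the loss: the decay is added
# to the update itself and never flows through the m/v moment estimates.
# adamw_style_update is an illustrative name, not part of this codebase.
def adamw_style_update(p, m, v, lr, weight_decay, eps):
    # L2-in-the-loss would instead do grad = grad + weight_decay * p before
    # building m and v, entangling the decay with the moment estimates.
    update = m / (v.sqrt() + eps) + weight_decay * p
    return p - lr * update
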
    def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        The remaining arguments are deprecated, and are only retained (for the moment) for error-checking purposes.
        """
        if any(p is not None for p in [grads, output_params, scale, grad_norms]):
            raise RuntimeError('FusedAdam has been updated.  Simply initialize it identically to torch.optim.Adam, and call step() with no arguments.')
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            bias_correction = 1 if group['bias_correction'] else 0
            beta1, beta2 = group['betas']

            # assume the same step across the group for now to simplify things;
            # a per-parameter step could be supported by making this a tensor,
            # or by passing a list into the kernel
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            # create lists for multi-tensor apply
            g_16, p_16, m_16, v_16 = [], [], [], []
            g_32, p_32, m_32, v_32 = [], [], [], []

            for p in group['params']:
                if p.grad is None:
                    continue
                if p.grad.data.is_sparse:
                    raise RuntimeError('FusedAdam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                if p.dtype == torch.float16:
                    g_16.append(p.grad.data)
                    p_16.append(p.data)
                    m_16.append(state['exp_avg'])
                    v_16.append(state['exp_avg_sq'])
                elif p.dtype == torch.float32:
                    g_32.append(p.grad.data)
                    p_32.append(p.data)
                    m_32.append(state['exp_avg'])
                    v_32.append(state['exp_avg_sq'])
                else:
                    raise RuntimeError('FusedAdam only supports fp16 and fp32.')

            if len(g_16) > 0:
                multi_tensor_applier(self.multi_tensor_adam,
                                     self._dummy_overflow_buf,
                                     [g_16, p_16, m_16, v_16],
                                     group['lr'],
                                     beta1,
                                     beta2,
                                     group['eps'],
                                     group['step'],
                                     self.adam_w_mode,
                                     bias_correction,
                                     group['weight_decay'])
            if len(g_32) > 0:
                multi_tensor_applier(self.multi_tensor_adam,
                                     self._dummy_overflow_buf,
                                     [g_32, p_32, m_32, v_32],
                                     group['lr'],
                                     beta1,
                                     beta2,
                                     group['eps'],
                                     group['step'],
                                     self.adam_w_mode,
                                     bias_correction,
                                     group['weight_decay'])


        return loss
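
# A hypothetical usage sketch based on the docstring above: the updated
# FusedAdam is constructed like torch.optim.Adam and step() is called with no
# arguments. The import path (apex.optimizers) and the toy model are
# assumptions, not taken from this file; a CUDA build of apex is required.
import torch
from apex.optimizers import FusedAdam

model = torch.nn.Linear(16, 16).cuda()
optimizer = FusedAdam(model.parameters(), lr=1e-3, weight_decay=0.01)

inputs = torch.randn(4, 16, device='cuda')
loss = model(inputs).sum()
loss.backward()
optimizer.step()        # no grads / output_params / scale arguments
optimizer.zero_grad()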