Example #1
    def step(self, closure=None, skip_overflow_check=False):
        loss = None
        if closure is not None:
            loss = closure()

        if self._last_step or not self._overlap_reductions or not self._full_pipeline:
            self._pipeline_step()

        with torch.cuda.stream(self._completion_st):
            # Check for overflow
            # Store state for loss scaler calculation
            has_overflow = False if skip_overflow_check else self.strided_check_finite(
                self._new_params,
                stride=self._shard_size,
                start=0,
                end=self._net_total_param_size)
            if has_overflow:
                self.revert_step()
            else:
                # Copy self._new_params to model params
                for p in self._model_params:
                    self.state[p]['step'] += 1
                multi_tensor_applier(fused_adam_cuda.maybe_cast_mt,
                                     self._overflow_buf,
                                     self._packed_flat_to_model_params)

        torch.cuda.current_stream().wait_stream(self._completion_st)

        self._reductions_works = [None] * self._num_blocks
        self._allgather_works = [None] * self._num_blocks

        return loss
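The block above overlaps the overflow check and the parameter copy with other work by queuing them on a dedicated completion stream and then making the default stream wait on it. A minimal standalone sketch of that CUDA-stream pattern (tensor names are illustrative, not part of the optimizer above):

import torch

# Sketch of the enqueue-on-a-side-stream / wait pattern used above.
side_stream = torch.cuda.Stream()
x = torch.randn(1024, device='cuda')

# Make the side stream wait until `x` is ready on the default stream.
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    y = x * 2.0  # queued asynchronously on side_stream

# Make the default stream wait for the side stream before reading `y`,
# mirroring torch.cuda.current_stream().wait_stream(self._completion_st) above.
torch.cuda.current_stream().wait_stream(side_stream)
print(y.sum().item())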
Example #2
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            # create lists for multi-tensor apply
            g_16, p_16, h_16 = [], [], []
            g_32, p_32, h_32 = [], [], []

            for p in group['params']:
                if p.grad is None:
                    continue
                if p.grad.data.is_sparse:
                    raise RuntimeError('FusedAdagrad does not support sparse gradients')

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    # Sum of squared gradient values
                    state['sum'] = torch.zeros_like(p.data)
                if p.dtype == torch.float16:
                    g_16.append(p.grad.data)
                    p_16.append(p.data)
                    h_16.append(state['sum'])
                elif p.dtype == torch.float32:
                    g_32.append(p.grad.data)
                    p_32.append(p.data)
                    h_32.append(state['sum'])
                else:
                    raise RuntimeError('FusedAdagrad only support fp16 and fp32.')

            if len(g_16) > 0:
                multi_tensor_applier(self.multi_tensor_adagrad,
                                     self._dummy_overflow_buf,
                                     [g_16, p_16, h_16],
                                     group['lr'],
                                     group['eps'],
                                     self.adagrad_w_mode,
                                     group['weight_decay'])
            if len(g_32) > 0:
                multi_tensor_applier(self.multi_tensor_adagrad,
                                     self._dummy_overflow_buf,
                                     [g_32, p_32, h_32],
                                     group['lr'],
                                     group['eps'],
                                     self.adagrad_w_mode,
                                     group['weight_decay'])

        return loss
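The fused kernel applies ordinary Adagrad to every tensor in the g/p/h lists in a single launch. For reference, an unfused per-tensor sketch of the same update (adagrad_w_mode switches between classic L2 decay and decoupled decay; this is an approximation, not the kernel itself):

import torch

def adagrad_update(p, g, h, lr, eps, weight_decay, adagrad_w_mode):
    """Unfused sketch of one Adagrad step on a single tensor."""
    if weight_decay != 0 and not adagrad_w_mode:
        g = g + weight_decay * p                      # classic (L2) weight decay
    h.addcmul_(g, g, value=1.0)                       # h += g * g
    p.addcdiv_(g, h.sqrt().add_(eps), value=-lr)      # p -= lr * g / (sqrt(h) + eps)
    if weight_decay != 0 and adagrad_w_mode:
        p.mul_(1.0 - lr * weight_decay)               # decoupled (AdamW-style) decay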
Example #3
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        explicit_master_params = (hasattr(self, "_amp_stash") and
                                  hasattr(self._amp_stash, "fp32_from_fp16_groups"))

        for gid, group in enumerate(self.param_groups):
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            # For each group, there are 3 possible combinations we need to consider:
            # grad_type, param_to_update_type, momentum_type, requires_fp16_model_copy
            # 1. fp16, fp16, fp16, No
            # 2. fp32, fp32, fp32, No
            # 3. fp16, fp32, fp32, Yes

            first_runs = [True, True]

            # I think a bit of code divergence in exchange for naming clarity is worthwhile
            if explicit_master_params:
                stash = self._amp_stash

                fp32_params = [p for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]
                fp32_grads = [p.grad for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]
                fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)

                if self.materialize_master_grads:
                    fp16_model_params = [p for i, p in enumerate(
                        stash.fp16_groups[gid]) if stash.fp32_from_fp16_groups[gid][i].grad is not None]
                    fp32_from_fp16_grads = [p.grad for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None]
                    fp32_from_fp16_params = [p for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None]
                    fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)

                    fp16_set = [fp32_from_fp16_grads, fp32_from_fp16_params,
                                fp32_from_fp16_momentums, fp16_model_params]
                else:
                    fp16_model_params = [p for p in stash.fp16_groups[gid] if p.grad is not None]
                    fp16_model_grads = [p.grad for p in stash.fp16_groups[gid] if p.grad is not None]
                    fp32_from_fp16_params = [p for i, p in enumerate(
                        stash.fp32_from_fp16_groups[gid]) if stash.fp16_groups[gid][i].grad is not None]
                    fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)

                    fp16_set = [fp16_model_grads, fp32_from_fp16_params,
                                fp32_from_fp16_momentums, fp16_model_params]

                launch_sets = [fp16_set, [fp32_grads, fp32_params, fp32_momentums]]
            else:
                fp16_params = [p for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)]
                fp16_grads = [p.grad for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)]
                fp16_momentums, first_runs[0] = self.get_momentums(fp16_params)

                fp32_params = [p for p in group['params'] if (p.dtype == torch.float32 and p.grad is not None)]
                fp32_grads = [p.grad for p in group['params'] if (p.dtype == torch.float32 and p.grad is not None)]
                fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)

                launch_sets = [[fp16_grads, fp16_params, fp16_momentums],
                               [fp32_grads, fp32_params, fp32_momentums]]

            for launch_set, first_run in zip(launch_sets, first_runs):
                assert len(launch_set[0]) == len(launch_set[1])
                assert len(launch_set[0]) == len(launch_set[2])
                if len(launch_set[0]) > 0:
                    multi_tensor_applier(
                        self.multi_tensor_sgd,
                        self._dummy_overflow_buf,
                        launch_set,
                        weight_decay,
                        momentum,
                        dampening,
                        group['lr'],
                        nesterov,
                        first_run,
                        self.wd_after_momentum,
                        1.0/self.most_recent_scale)

        self.most_recent_scale = 1.0
        self.scale_set_by_backward = False

        return loss
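Per tensor, multi_tensor_sgd performs plain momentum SGD; the bookkeeping above only decides which fp16/fp32 (and master-weight) lists get launched together. An unfused sketch of a single update, following torch.optim.SGD conventions plus the wd_after_momentum switch (first_run and loss-scale handling omitted):

import torch

def sgd_momentum_update(p, g, buf, lr, weight_decay, momentum,
                        dampening, nesterov, wd_after_momentum=False):
    """Unfused sketch of one momentum-SGD step on a single tensor."""
    if weight_decay != 0 and not wd_after_momentum:
        g = g + weight_decay * p
    if momentum != 0:
        buf.mul_(momentum).add_(g, alpha=1 - dampening)
        g = g + momentum * buf if nesterov else buf
    if weight_decay != 0 and wd_after_momentum:
        g = g + weight_decay * p
    p.add_(g, alpha=-lr)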
Example #4
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            bias_correction = 1 if group['bias_correction'] else 0
            beta1, beta2 = group['betas']
            grad_averaging = 1 if group['grad_averaging'] else 0

            # assume the same step across the group for now to simplify things
            # per-parameter step could easily be supported by making it a tensor, or by passing a list into the kernel
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            # create lists for multi-tensor apply
            g_16, p_16, m_16 = [], [], []
            g_32, p_32, m_32 = [], [], []

            for p in group['params']:
                if p.grad is None:
                    continue
                if p.grad.data.is_sparse:
                    raise RuntimeError(
                        'FusedNovoGrad does not support sparse gradients, please consider SparseAdam instead'
                    )

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)

                if p.dtype == torch.float16:
                    g_16.append(p.grad.data)
                    p_16.append(p.data)
                    m_16.append(state['exp_avg'])
                elif p.dtype == torch.float32:
                    g_32.append(p.grad.data)
                    p_32.append(p.data)
                    m_32.append(state['exp_avg'])
                else:
                    raise RuntimeError(
                        'FusedNovoGrad only support fp16 and fp32.')

            # we store the per-weight norm as one tensor per group/precision combination
            # unlike optim.Adam, we store the norm here (not norm^2) so we can unify the calculation across norm types
            if 'exp_avg_sq' not in group:
                group['exp_avg_sq'] = [None, None]
                if group['init_zero']:
                    group['exp_avg_sq'][0] = torch.cuda.FloatTensor(
                        len(g_16)).contiguous().fill_(0)
                    group['exp_avg_sq'][1] = torch.cuda.FloatTensor(
                        len(g_32)).contiguous().fill_(0)
                else:  # init with the first step's norm, so the first blend has no effect
                    if group['norm_type'] == 0:
                        v_16 = [
                            torch.max(torch.abs(g.to(torch.float32))).item()
                            for g in g_16
                        ]
                        v_32 = [torch.max(torch.abs(g)).item() for g in g_32]
                    elif group['norm_type'] == 2:
                        v_16 = [
                            torch.sum(torch.pow(g.to(torch.float32),
                                                2)).sqrt().item() for g in g_16
                        ]
                        v_32 = [
                            torch.sum(torch.pow(g, 2)).sqrt().item()
                            for g in g_32
                        ]
                    else:
                        raise RuntimeError(
                            'FusedNovoGrad only support l2/inf norm now.')
                    group['exp_avg_sq'][0] = torch.cuda.FloatTensor(v_16)
                    group['exp_avg_sq'][1] = torch.cuda.FloatTensor(v_32)
            else:
                assert (len(g_16) == group['exp_avg_sq'][0].numel())
                assert (len(g_32) == group['exp_avg_sq'][1].numel())

            if (len(g_16) > 0):
                multi_tensor_applier(
                    self.multi_tensor_novograd, self._dummy_overflow_buf,
                    [g_16, p_16, m_16], group['exp_avg_sq'][0], group['lr'],
                    beta1, beta2, group['eps'], group['step'], bias_correction,
                    group['weight_decay'], grad_averaging, self.moment_mode,
                    group['norm_type'])
            if (len(g_32) > 0):
                multi_tensor_applier(
                    self.multi_tensor_novograd, self._dummy_overflow_buf,
                    [g_32, p_32, m_32], group['exp_avg_sq'][1], group['lr'],
                    beta1, beta2, group['eps'], group['step'], bias_correction,
                    group['weight_decay'], grad_averaging, self.moment_mode,
                    group['norm_type'])

        return loss
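Unlike Adam, NovoGrad keeps a single second-moment value per weight tensor, which is why group['exp_avg_sq'] holds one entry per tensor rather than a tensor per parameter. A rough unfused sketch of the published per-layer rule (bias correction, grad_averaging, and the stored-norm convention used by the kernel are omitted):

import torch

def novograd_update(p, g, m, v, lr, beta1, beta2, eps, weight_decay):
    """Rough NovoGrad sketch: v is one scalar per tensor, not per element."""
    v.mul_(beta2).add_(g.pow(2).sum(), alpha=1 - beta2)   # layer-wise second moment
    step = g / (v.sqrt() + eps)
    if weight_decay != 0:
        step = step + weight_decay * p                    # decay folded into the moment
    m.mul_(beta1).add_(step)
    p.add_(m, alpha=-lr)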
Example #5
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        # create separate grad lists for fp32 and fp16 params
        g_all_32, g_all_16 = [], []
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                if p.dtype == torch.float32:
                    g_all_32.append(p.grad.data)
                elif p.dtype == torch.float16:
                    g_all_16.append(p.grad.data)
                else:
                    raise RuntimeError('FusedLAMB only support fp16 and fp32.')

        g_norm_32 = torch.zeros(1, device='cuda')
        g_norm_16 = torch.zeros(1, device='cuda')
        # compute grad norm for two lists
        if len(g_all_32) > 0:
            g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
                                             self._dummy_overflow_buf,
                                             [g_all_32], False)[0]
        if len(g_all_16) > 0:
            g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm,
                                             self._dummy_overflow_buf,
                                             [g_all_16], False)[0]

        # blend two grad norms to get global grad norm
        global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm,
                                                self._dummy_overflow_buf,
                                                [[g_norm_32, g_norm_16]],
                                                False)[0]
        max_grad_norm = self.defaults['max_grad_norm']

        for group in self.param_groups:
            bias_correction = 1 if group['bias_correction'] else 0
            beta1, beta2 = group['betas']
            grad_averaging = 1 if group['grad_averaging'] else 0

            # assume the same step across the group for now to simplify things
            # per-parameter step could easily be supported by making it a tensor, or by passing a list into the kernel
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            # create lists for multi-tensor apply
            g_16, p_16, m_16, v_16 = [], [], [], []
            g_32, p_32, m_32, v_32 = [], [], [], []

            for p in group['params']:
                if p.grad is None:
                    continue
                if p.grad.data.is_sparse:
                    raise RuntimeError(
                        'FusedLAMB does not support sparse gradients, please consider SparseAdam instead'
                    )

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                if p.dtype == torch.float16:
                    g_16.append(p.grad.data)
                    p_16.append(p.data)
                    m_16.append(state['exp_avg'])
                    v_16.append(state['exp_avg_sq'])
                elif p.dtype == torch.float32:
                    g_32.append(p.grad.data)
                    p_32.append(p.data)
                    m_32.append(state['exp_avg'])
                    v_32.append(state['exp_avg_sq'])
                else:
                    raise RuntimeError('FusedLAMB only support fp16 and fp32.')

            if (len(g_16) > 0):
                multi_tensor_applier(
                    self.multi_tensor_lamb, self._dummy_overflow_buf,
                    [g_16, p_16, m_16, v_16], group['lr'], beta1, beta2,
                    group['eps'], group['step'], bias_correction,
                    group['weight_decay'], grad_averaging, self.adam_w_mode,
                    global_grad_norm, max_grad_norm, self.use_nvlamb)
            if (len(g_32) > 0):
                multi_tensor_applier(
                    self.multi_tensor_lamb, self._dummy_overflow_buf,
                    [g_32, p_32, m_32, v_32], group['lr'], beta1, beta2,
                    group['eps'], group['step'], bias_correction,
                    group['weight_decay'], grad_averaging, self.adam_w_mode,
                    global_grad_norm, max_grad_norm, self.use_nvlamb)

        return loss
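The two fp16/fp32 norms are blended into one global gradient norm, which the kernel uses for clipping against max_grad_norm; the per-tensor LAMB rule then rescales an Adam-style step by a trust ratio. A rough unfused sketch (bias correction, clipping, and the NVLAMB variant omitted):

import torch

def lamb_update(p, g, m, v, lr, beta1, beta2, eps, weight_decay):
    """Rough LAMB sketch: Adam direction rescaled by ||p|| / ||update||."""
    m.mul_(beta1).add_(g, alpha=1 - beta1)
    v.mul_(beta2).addcmul_(g, g, value=1 - beta2)
    update = m / (v.sqrt() + eps) + weight_decay * p
    w_norm, u_norm = p.norm(), update.norm()
    trust_ratio = torch.where((w_norm > 0) & (u_norm > 0),
                              w_norm / u_norm, torch.ones_like(w_norm))
    p.add_(-lr * trust_ratio * update)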
Example #6
    def _flatten_grad_mt(self, scale):
        if self._flat_mt and len(self._grads) > 0:
            self._overflow_buf.zero_()
            multi_tensor_applier(amp_C.multi_tensor_scale, self._overflow_buf,
                                 list(zip(*self._grads)), scale)
            self._grads = []
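Here self._grads is a list of (source, destination) pairs, so list(zip(*self._grads)) yields exactly the two tensor lists that amp_C.multi_tensor_scale expects. A standalone sketch of the same call (requires apex built with its CUDA extensions; tensor names are illustrative):

import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

overflow_buf = torch.zeros(1, dtype=torch.int, device='cuda')
srcs = [torch.randn(10, device='cuda') for _ in range(3)]
dsts = [torch.empty_like(t) for t in srcs]

# dst = src * scale for every pair in one fused launch; overflow_buf becomes
# non-zero if any scaled value is non-finite.
multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf, [srcs, dsts], 1.0 / 128.0)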
Example #7
    def step(self,
             closure=None,
             grads=None,
             output_params=None,
             scale=1.,
             grad_norms=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
            grads (list of tensors, optional): weight gradient to use for the
                optimizer update. If gradients have type torch.half, parameters
                are expected to be in type torch.float. (default: None)
            output_params (list of tensors, optional): A reduced precision copy
                of the updated weights written out in addition to the regular
                updated weights. Must be of the same type as the gradients. (default: None)
            scale (float, optional): factor to divide gradient tensor values
                by before applying to weights. (default: 1)
        """
        if hasattr(self, "_amp_stash"):
            raise RuntimeError(
                'apex.contrib.optimizers.FusedSGD should not be used with AMP.'
            )

        loss = None
        if closure is not None:
            loss = closure()

        if grads is None:
            raise RuntimeError(
                'apex.contrib.optimizers.FusedSGD must be wrapped '
                'with apex.contrib.optimizers.FP16_Optimizer '
                'which provides grads.')
        # backward compatibility:
        # assuming a list/generator of parameters means a single group
        elif isinstance(grads, types.GeneratorType):
            grads_group = [grads]
        elif type(grads[0]) != list:
            grads_group = [grads]
        else:
            grads_group = grads

        if output_params is None:
            raise RuntimeError(
                'apex.contrib.optimizers.FusedSGD must be wrapped '
                'with apex.contrib.optimizers.FP16_Optimizer '
                'which provides output_params.')
        elif isinstance(output_params, types.GeneratorType):
            output_params_group = [output_params]
        elif type(output_params[0]) != list:
            output_params_group = [output_params]
        else:
            output_params_group = output_params

        for group, grads_this_group, output_params_this_group in zip(
                self.param_groups, grads_group, output_params_group):
            if grads_this_group is None or output_params_this_group is None:
                raise RuntimeError(
                    'apex.contrib.optimizers.FusedSGD only works '
                    'when all parameters require grad.')

            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
            lr = group['lr']

            first_runs = [True, True]

            # output_params_this_group: original weights (either fp16 or fp32)
            # group['params']: master weights (fp32)

            # grad_type, param_to_update_type, momentum_type, requires_fp16_model_copy
            # fp32, fp32, fp32, No
            fp32_grads = [
                g for (p, g) in zip(output_params_this_group, grads_this_group)
                if p.dtype == torch.float32
            ]
            fp32_params = [
                p2
                for (p1, p2) in zip(output_params_this_group, group['params'])
                if p1.dtype == torch.float32
            ]
            fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)
            fp32_set = [fp32_grads, fp32_params, fp32_momentums]

            # fp16, fp32, fp32, Yes
            fp16_grads = [
                g for (p, g) in zip(output_params_this_group, grads_this_group)
                if p.dtype == torch.float16
            ]
            fp32_from_fp16_params = [
                p2
                for (p1, p2) in zip(output_params_this_group, group['params'])
                if p1.dtype == torch.float16
            ]
            fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(
                fp32_from_fp16_params)
            fp16_params = [
                p1
                for (p1, p2) in zip(output_params_this_group, group['params'])
                if p1.dtype == torch.float16
            ]
            fp16_set = [
                fp16_grads, fp32_from_fp16_params, fp32_from_fp16_momentums,
                fp16_params
            ]

            launch_sets = [fp16_set, fp32_set]

            for launch_set, first_run in zip(launch_sets, first_runs):
                assert len(launch_set[0]) == len(launch_set[1])
                assert len(launch_set[0]) == len(launch_set[2])
                if len(launch_set[0]) > 0:
                    multi_tensor_applier(self.multi_tensor_sgd,
                                         self._dummy_overflow_buf, launch_set,
                                         weight_decay, momentum, dampening, lr,
                                         nesterov, first_run,
                                         self.wd_after_momentum, 1.0 / scale)

        return loss
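For the fp16 launch set the kernel receives a fourth list (fp16_model_params) and writes back a half-precision copy of each updated fp32 master weight in the same launch. An unfused sketch of that combined step, including the 1/scale gradient unscaling passed above (illustrative only):

import torch

def sgd_step_with_fp16_copy(master_p, g, buf, fp16_p, lr, weight_decay,
                            momentum, dampening, nesterov, scale):
    """Sketch: momentum SGD on the fp32 master weight, then refresh the fp16 copy."""
    g = g.float() / scale                              # unscale the fp16 gradient
    if weight_decay != 0:
        g = g + weight_decay * master_p
    if momentum != 0:
        buf.mul_(momentum).add_(g, alpha=1 - dampening)
        g = g + momentum * buf if nesterov else buf
    master_p.add_(g, alpha=-lr)
    fp16_p.copy_(master_p)                             # keep the fp16 model weights in sync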
Example #8
    def step(self,
             closure=None,
             grads=None,
             output_params=None,
             scale=None,
             grad_norms=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        The remaining arguments are deprecated, and are only retained (for the moment) for error-checking purposes.
        """
        if any(p is not None
               for p in [grads, output_params, scale, grad_norms]):
            raise RuntimeError(
                'FusedAdam has been updated.  Simply initialize it identically to torch.optim.Adam, and call step() with no arguments.'
            )
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            bias_correction = 1 if group['bias_correction'] else 0
            beta1, beta2 = group['betas']

            # assume the same step across the group for now to simplify things
            # per-parameter step could easily be supported by making it a tensor, or by passing a list into the kernel
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            # create lists for multi-tensor apply
            g_16, p_16, m_16, v_16 = [], [], [], []
            g_32, p_32, m_32, v_32 = [], [], [], []

            for p in group['params']:
                if p.grad is None:
                    continue
                if p.grad.data.is_sparse:
                    raise RuntimeError(
                        'FusedAdam does not support sparse gradients, please consider SparseAdam instead'
                    )

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                if p.dtype == torch.float16:
                    g_16.append(p.grad.data)
                    p_16.append(p.data)
                    m_16.append(state['exp_avg'])
                    v_16.append(state['exp_avg_sq'])
                elif p.dtype == torch.float32:
                    g_32.append(p.grad.data)
                    p_32.append(p.data)
                    m_32.append(state['exp_avg'])
                    v_32.append(state['exp_avg_sq'])
                else:
                    raise RuntimeError('FusedAdam only support fp16 and fp32.')

            if (len(g_16) > 0):
                multi_tensor_applier(self.multi_tensor_adam,
                                     self._dummy_overflow_buf,
                                     [g_16, p_16, m_16, v_16], group['lr'],
                                     beta1, beta2, group['eps'], group['step'],
                                     self.adam_w_mode, bias_correction,
                                     group['weight_decay'])
            if (len(g_32) > 0):
                multi_tensor_applier(self.multi_tensor_adam,
                                     self._dummy_overflow_buf,
                                     [g_32, p_32, m_32, v_32], group['lr'],
                                     beta1, beta2, group['eps'], group['step'],
                                     self.adam_w_mode, bias_correction,
                                     group['weight_decay'])

        return loss
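This newer FusedAdam takes no extra step() arguments: it is constructed like torch.optim.Adam and driven by an ordinary training loop. A minimal usage sketch (assumes apex is installed with its CUDA extensions and a GPU is available):

import torch
from apex.optimizers import FusedAdam

model = torch.nn.Linear(128, 10).cuda()
optimizer = FusedAdam(model.parameters(), lr=1e-3, weight_decay=0.01)

x = torch.randn(32, 128, device='cuda')
loss = model(x).sum()

optimizer.zero_grad()
loss.backward()
optimizer.step()   # no grads/output_params/scale arguments, as the check above enforces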
Example #9
    def step(self,
             closure=None,
             grads=None,
             output_params=None,
             scale=1.,
             grad_norms=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
            grads (list of tensors, optional): weight gradient to use for the
                optimizer update. If gradients have type torch.half, parameters
                are expected to be in type torch.float. (default: None)
            output_params (list of tensors, optional): A reduced precision copy
                of the updated weights written out in addition to the regular
                updated weights. Must be of the same type as the gradients. (default: None)
            scale (float, optional): factor to divide gradient tensor values
                by before applying to weights. (default: 1)
        """
        loss = None
        if closure is not None:
            loss = closure()

        if hasattr(self, "_amp_stash"):
            grads = self._amp_stash.grads
            output_params = self._amp_stash.output_params
            scale = self._amp_stash.scale * self._amp_scale_adjustment
            grad_norms = self._amp_stash.grad_norms

        if grads is None:
            grads_group = [None] * len(self.param_groups)
        # backward compatibility:
        # assuming a list/generator of parameters means a single group
        elif isinstance(grads, types.GeneratorType):
            grads_group = [grads]
        elif type(grads[0]) != list:
            grads_group = [grads]
        else:
            grads_group = grads

        if output_params is None:
            output_params_group = [None] * len(self.param_groups)
        elif isinstance(output_params, types.GeneratorType):
            output_params_group = [output_params]
        elif type(output_params[0]) != list:
            output_params_group = [output_params]
        else:
            output_params_group = output_params

        if grad_norms is None:
            grad_norms = [None] * len(self.param_groups)

        for group, grads_this_group, output_params_this_group, grad_norm in zip(
                self.param_groups, grads_group, output_params_group,
                grad_norms):
            if grads_this_group is None:
                grads_this_group = [None] * len(group['params'])
            if output_params_this_group is None:
                output_params_this_group = [None] * len(group['params'])

            # compute combined scale factor for this group
            combined_scale = scale
            if group['max_grad_norm'] > 0:
                # norm is in fact norm*scale
                clip = ((grad_norm / scale) + 1e-6) / group['max_grad_norm']
                if clip > 1:
                    combined_scale = clip * scale

            bias_correction = 1 if group['bias_correction'] else 0

            if self._use_multi_tensor:
                if output_params:
                    tensorlists = [[], [], [], [], []]
                else:
                    tensorlists = [[], [], [], []]
                tensordevice = None

            for p, grad, output_param in zip(group['params'], grads_this_group,
                                             output_params_this_group):
                # note: p.grad should never be set for correct operation of a mixed-precision optimizer that sometimes sends None gradients
                if p.grad is None and grad is None:
                    continue
                if grad is None:
                    grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'FusedAdam does not support sparse gradients, please consider SparseAdam instead'
                    )

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                out_p = torch.tensor(
                    [], dtype=torch.float
                ) if output_param is None else output_param
                if self._use_multi_tensor:
                    pl = [p.data, exp_avg, exp_avg_sq, grad]
                    if output_param is not None:
                        pl.append(out_p)

                    for tl, t in zip(tensorlists, pl):
                        tl.append(t)

                    if tensordevice is None:
                        tensordevice = p.device
                    elif tensordevice != p.device:
                        raise RuntimeError(
                            'FusedAdam does not support use_mt with tensors on multiple device'
                        )

                else:
                    with torch.cuda.device(p.device):
                        fused_adam_cuda.adam(p.data, out_p, exp_avg,
                                             exp_avg_sq, grad, group['lr'],
                                             beta1, beta2, group['eps'],
                                             combined_scale, state['step'],
                                             self.eps_mode, bias_correction,
                                             group['weight_decay'])

            if self._use_multi_tensor:
                with torch.cuda.device(tensordevice):
                    multi_tensor_applier(
                        fused_adam_cuda.adam_mt, self._overflow_buf,
                        tensorlists, group['lr'], beta1, beta2, group['eps'],
                        combined_scale, state['step'], self.eps_mode,
                        bias_correction, group['weight_decay'])

        return loss
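The combined_scale logic implements global-norm clipping while gradients are still loss-scaled: dividing by combined_scale both unscales the gradient and, when the unscaled norm exceeds max_grad_norm, shrinks the step. A small sketch of the arithmetic (names are illustrative):

def combined_scale_for(grad_norm, loss_scale, max_grad_norm):
    """Return the divisor applied to the raw (still scaled) gradients."""
    combined_scale = loss_scale
    if max_grad_norm > 0 and grad_norm is not None:
        # grad_norm was computed on scaled gradients, so unscale it first
        clip = ((grad_norm / loss_scale) + 1e-6) / max_grad_norm
        if clip > 1:
            combined_scale = clip * loss_scale
    return combined_scale

# e.g. loss_scale=1024, scaled norm=5120 -> unscaled norm 5.0; with max_grad_norm=1.0
# the divisor becomes ~5 * 1024, so the applied update has norm ~1.0.
print(combined_scale_for(5120.0, 1024.0, 1.0))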