Example #1

A functional Adam optimizer: step() takes the gradients as an explicit argument, one optional tensor per parameter, instead of reading them from each parameter's .grad field. The excerpt assumes the imports used by PyTorch's functional optimizers: from typing import List, Optional; import torch; from torch import Tensor; import torch.optim._functional as F.
    def step(self, gradients: List[Optional[Tensor]]):
        """
        Perform a single Adam update over all parameters, using the
        supplied gradients (one optional tensor per parameter, in order;
        parameters whose entry is None are skipped).
        """
        params = self.param_group['params']
        params_with_grad = []
        grads = []
        exp_avgs = []
        exp_avg_sqs = []
        max_exp_avg_sqs = []
        state_steps: List[Tensor] = []

        if len(params) != len(gradients):
            raise ValueError(
                "The number of gradients passed in does not match the "
                f"number of parameters! Params length: {len(params)}. "
                f"Gradients length: {len(gradients)}")

        for param, gradient in zip(params, gradients):
            if gradient is not None:
                params_with_grad.append(param)
                grads.append(gradient)
                # Lazy state initialization
                if param not in self.state:
                    self.state[param] = {}
                    state = self.state[param]
                    # Step count, kept as a singleton tensor
                    state['step'] = torch.tensor(0.0)
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(
                        param, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(
                        param, memory_format=torch.preserve_format)
                    if self.amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(
                            param, memory_format=torch.preserve_format)

                state = self.state[param]

                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])

                if self.amsgrad:
                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                # The step tensor itself is collected here; F.adam
                # advances it in place during the update, so there is
                # no manual increment in this method
                state_steps.append(state['step'])

        with torch.no_grad():
            F.adam(params_with_grad,
                   grads,
                   exp_avgs,
                   exp_avg_sqs,
                   max_exp_avg_sqs,
                   state_steps,
                   amsgrad=self.amsgrad,
                   maximize=self.maximize,
                   beta1=self.defaults['beta1'],
                   beta2=self.defaults['beta2'],
                   lr=self.defaults['lr'],
                   weight_decay=self.defaults['weight_decay'],
                   eps=self.defaults['eps'],
                   foreach=self.foreach)
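
    # Usage note for step(): the caller computes the gradients externally,
    # e.g. via torch.autograd.grad(loss, params, allow_unused=True), which
    # returns one Optional[Tensor] per parameter in parameter order, and
    # passes them in explicitly; entries that are None are skipped.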

    def step_param(self, param: Tensor, grad: Optional[Tensor]):
        """
        Similar to step(), but operates on a single parameter and its
        (optional) gradient tensor.
        """
        params_with_grad = []
        grads = []
        exp_avgs = []
        exp_avg_sqs = []
        max_exp_avg_sqs = []
        state_steps: List[Tensor] = []
        if grad is not None:
            params_with_grad.append(param)
            grads.append(grad)
        if param not in self.state:
            self.state[param] = {}
            state = self.state[param]
            state['step'] = torch.tensor(0.0)
            state['exp_avg'] = torch.zeros_like(
                param, memory_format=torch.preserve_format)
            state['exp_avg_sq'] = torch.zeros_like(
                param, memory_format=torch.preserve_format)
            if self.amsgrad:
                state['max_exp_avg_sq'] = torch.zeros_like(
                    param, memory_format=torch.preserve_format)

        state = self.state[param]
        exp_avgs.append(state['exp_avg'])
        exp_avg_sqs.append(state['exp_avg_sq'])

        if self.amsgrad:
            max_exp_avg_sqs.append(state['max_exp_avg_sq'])

        # As in step(), pass the step tensor itself; F.adam advances it
        # in place, so no manual increment or .item() conversion is needed
        state_steps.append(state['step'])
        with torch.no_grad():
            F.adam(params_with_grad,
                   grads,
                   exp_avgs,
                   exp_avg_sqs,
                   max_exp_avg_sqs,
                   state_steps,
                   amsgrad=self.amsgrad,
                   maximize=self.maximize,
                   beta1=self.defaults['beta1'],
                   beta2=self.defaults['beta2'],
                   lr=self.defaults['lr'],
                   weight_decay=self.defaults['weight_decay'],
                   eps=self.defaults['eps'],
                   foreach=self.foreach)
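
For illustration, here is a minimal usage sketch of both entry points. Everything in it is an assumption layered on the excerpt: FunctionalAdam is a stand-in name for the surrounding optimizer class (its constructor is not shown above), and the model and loss are arbitrary placeholders.

import torch
import torch.nn as nn

# Hypothetical setup: FunctionalAdam stands in for the class that owns
# step() and step_param(); its constructor signature is assumed.
model = nn.Linear(4, 2)
opt = FunctionalAdam(list(model.parameters()), lr=1e-3)

inputs = torch.randn(8, 4)
loss = model(inputs).pow(2).mean()

# step(): compute all gradients explicitly (no loss.backward()) and
# pass them in as one Optional[Tensor] per parameter, in order.
gradients = list(torch.autograd.grad(
    loss, list(model.parameters()), allow_unused=True))
opt.step(gradients)

# step_param(): update a single parameter at a time. Here the gradients
# come from an ordinary backward pass.
loss = model(inputs).pow(2).mean()
loss.backward()
for p in model.parameters():
    opt.step_param(p, p.grad)

Per-parameter updates via step_param are what make it possible to begin updating parameters while the backward pass is still producing gradients for the rest, instead of waiting for all of them.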