Exemplo n.º 1
0
    def test_standardize(self):
        """Check ``Standardize`` on scalars, vectors, batches and external stats."""
        standardize = Standardize(eps=1.1920929e-07)

        # A bare Python scalar is rejected with an AssertionError.
        with pytest.raises(AssertionError):
            standardize(1.2)

        expected_z = [-1.3416406, -0.44721353, 0.44721353, 1.3416406]
        expected_raw = [0.9999999, 1.9999998, 2.9999998, 3.9999995]

        def _check_vector(vec):
            # Statistics derived from the data itself.
            assert np.allclose(standardize(x=vec), expected_z)
            assert np.allclose(standardize(x=vec).mean(), 0.0)
            assert np.allclose(standardize(x=vec).std(), 1.0)
            # Explicit mean=0 / std=1 leaves the values (almost) unchanged.
            assert np.allclose(standardize(x=vec, mean=0, std=1), expected_raw)

        # Same data as a plain list, then as an ndarray.
        for vec in ([1, 2, 3, 4], np.array([1, 2, 3, 4])):
            _check_vector(vec)

        # 2-D input is standardized column-wise (statistics along axis 0).
        batch_out = standardize([[1, 2], [3, 2]])
        assert batch_out.dtype == np.float32
        assert np.allclose(batch_out, [[-1, 0], [1, 0]])
        assert np.allclose(batch_out.mean(0), 0.0)

        # Statistics supplied by the caller instead of computed from the data.
        ext_out = standardize([1, 2, 3, 4], mean=0, std=1)
        assert ext_out.dtype == np.float32
        assert np.allclose(ext_out, [1, 2, 3, 4])
    def test_standardize(self):
        """Check ``Standardize`` on scalars, 1-D sequences and invalid 2-D input."""
        standardize = Standardize(eps=1.1920929e-07)

        # Scalars -- first plain Python ints, then 0-d ndarrays -- come back
        # unchanged, with data-derived statistics and with explicit ones.
        for wrap in (lambda value: value, np.array):
            for value in (-1, 0, 1):
                assert standardize(x=wrap(value)) == value
            for value in (-1, 0, 1):
                assert standardize(x=wrap(value), mean=0, std=1) == value

        expected_z = [-1.34164064, -0.44721355, 0.44721355, 1.34164064]
        expected_raw = [0.99999988, 1.99999976, 2.99999964, 3.99999952]

        def _check_vector(seq):
            assert np.allclose(standardize(x=seq), expected_z)
            assert np.allclose(standardize(x=seq, mean=0, std=1), expected_raw)

        # The same 1-D data as a tuple, a list and an ndarray.
        _check_vector((1, 2, 3, 4))
        _check_vector([1, 2, 3, 4])
        _check_vector(np.array([1, 2, 3, 4]))

        # Arrays with more than one dimension are rejected with ValueError.
        two_dim = np.array([[1, 2, 3, 4]])
        with pytest.raises(ValueError):
            standardize(x=two_dim)
Exemplo n.º 3
0
    def tell(self, solutions, function_values):
        """Update the search distribution from evaluated candidate solutions.

        Records the best candidate of this generation and of the whole run,
        sets ``self.mu.grad`` to the evolution-strategies gradient estimate,
        steps the learning-rate scheduler and the optimizer, and finally decays
        ``self.std`` geometrically down to (never below) ``self.min_std``.

        Args:
            solutions: candidate parameter vectors, indexable by position.
                Presumably they correspond row-by-row to ``self.eps`` from the
                matching sampling step -- TODO confirm against the caller.
            function_values: objective value of each candidate. The objective
                is minimized, so lower is better.
        """
        # Enforce ndarray of function values (np.array makes a copy).
        function_values = np.array(function_values)
        # True objective values, used for the recorded best results even when
        # the update itself is driven by rank-transformed values.
        raw_values = function_values
        if self.rank_transform:
            raw_values = np.copy(function_values)
            # Use centered ranks instead of raw values, combat with outliers.
            function_values = self.rank_transformer(function_values,
                                                    centered=True)

        # Best candidate of this generation (we are minimizing). np.argmin is
        # O(n) instead of a full argsort and returns the first minimum on ties.
        idx = int(np.argmin(function_values))
        self.best_param = solutions[idx]
        self.best_f_val = raw_values[idx]

        # Update the historical best result
        first_iteration = self.hist_best_param is None or self.hist_best_f_val is None
        if first_iteration or self.best_f_val < self.hist_best_f_val:
            self.hist_best_f_val = self.best_f_val
            self.hist_best_param = self.best_param

        # Standardize the (possibly rank-transformed) fitness before
        # estimating the gradient, as in the original paper.
        fitness = Standardize()(function_values)
        # fitness: [popsize], self.eps: [popsize, num_params]
        grad = (1 / self.std) * np.mean(np.expand_dims(fitness, 1) * self.eps,
                                        axis=0)
        self.mu.grad = torch.from_numpy(grad).float()

        # NOTE(review): since PyTorch 1.1, lr_scheduler.step() is expected
        # after optimizer.step(); calling it first skips the schedule's initial
        # learning rate. Order kept pending confirmation of the PyTorch version.
        self.lr_scheduler.step()
        # Take a gradient step
        self.optimizer.step()

        # Adaptive std: geometric decay, clamped so that the final decay step
        # cannot undershoot the min_std floor.
        if self.std > self.min_std:
            self.std = max(self.min_std, self.std_decay * self.std)
Exemplo n.º 4
0
    def learn(self, D, info=None):
        """Run one policy-gradient update (with entropy bonus) on trajectories ``D``.

        Per trajectory, the time-averaged ``-logprob * Q`` policy loss and
        ``-entropy`` loss are combined as ``policy + entropy_coef * entropy``.
        The loss that is backpropagated is the mean of these totals.

        Args:
            D: collection of trajectories. It is iterated twice (once for the
                losses, once for the timestep count), so it should be a
                re-iterable container. Each trajectory must expose
                ``all_info(key)``, ``all_discounted_returns`` and ``T``.
            info: unused; kept for interface compatibility. Defaults to
                ``None`` instead of a shared mutable ``{}``.

        Returns:
            dict with Python numbers ``loss``, ``policy_loss``,
            ``entropy_loss`` and, when a scheduler is attached, ``current_lr``.

        Raises:
            KeyError: a scheduler is attached but the config contains neither
                ``train.iter`` nor ``train.timestep``.
        """
        batch_policy_loss = []
        batch_entropy_loss = []
        batch_total_loss = []

        for trajectory in D:
            logprobs = trajectory.all_info('action_logprob')
            entropies = trajectory.all_info('entropy')
            Qs = trajectory.all_discounted_returns

            # Standardize: encourage/discourage half of performed actions
            if self.config['agent.standardize_Q']:
                Qs = Standardize()(Qs).tolist()

            # Estimate policy gradient for all time steps and record all losses
            policy_loss = []
            entropy_loss = []
            for logprob, entropy, Q in zip(logprobs, entropies, Qs):
                policy_loss.append(-logprob * Q)
                entropy_loss.append(-entropy)

            # Average losses over all time steps
            policy_loss = torch.stack(policy_loss).mean()
            entropy_loss = torch.stack(entropy_loss).mean()

            # Calculate total loss
            entropy_coef = self.config['agent.entropy_coef']
            total_loss = policy_loss + entropy_coef * entropy_loss

            # Record all losses
            batch_policy_loss.append(policy_loss)
            batch_entropy_loss.append(entropy_loss)
            batch_total_loss.append(total_loss)

        # Average loss over list of Trajectory
        policy_loss = torch.stack(batch_policy_loss).mean()
        entropy_loss = torch.stack(batch_entropy_loss).mean()
        loss = torch.stack(batch_total_loss).mean()

        # Train with estimated policy gradient
        self.optimizer.zero_grad()
        loss.backward()

        if self.config['agent.max_grad_norm'] is not None:
            clip_grad_norm_(parameters=self.policy.network.parameters(),
                            max_norm=self.config['agent.max_grad_norm'],
                            norm_type=2)

        if hasattr(self, 'lr_scheduler'):
            if 'train.iter' in self.config:  # iteration-based
                self.lr_scheduler.step()
            elif 'train.timestep' in self.config:  # timestep-based
                # Timesteps trained so far, i.e. before this batch is counted.
                self.lr_scheduler.step(self.total_T)
            else:
                raise KeyError(
                    'expected `train.iter` or `train.timestep` in config, but got none of them'
                )

        self.optimizer.step()

        # Accumulate trained timesteps (second pass over D)
        self.total_T += sum(trajectory.T for trajectory in D)

        out = {}
        out['loss'] = loss.item()
        out['policy_loss'] = policy_loss.item()
        out['entropy_loss'] = entropy_loss.item()
        if hasattr(self, 'lr_scheduler'):
            out['current_lr'] = self.lr_scheduler.get_lr()

        return out
Exemplo n.º 5
0
    def learn(self, D, info=None):
        """Run one actor-critic update (policy, value and entropy losses) on ``D``.

        Per trajectory:
        - Discounted returns are the Q targets.
        - Advantages are ``Q - V(s)``.
        - The time-averaged policy, value and entropy losses are combined as
          ``policy + value_coef * value + entropy_coef * entropy``.

        The loss that is backpropagated is the mean of these per-trajectory
        totals.

        Args:
            D: collection of trajectories. It is iterated twice (once for the
                losses, once for the timestep count), so it should be a
                re-iterable container. Each trajectory must expose
                ``all_info(key)``, ``all_discounted_returns``,
                ``transitions`` and ``T``.
            info: unused; kept for interface compatibility. Defaults to
                ``None`` instead of a shared mutable ``{}``.

        Returns:
            dict with Python numbers ``loss``, ``policy_loss``,
            ``entropy_loss``, ``value_loss`` and, when a scheduler is attached,
            ``current_lr``.

        Raises:
            KeyError: a scheduler is attached but the config contains neither
                ``train.iter`` nor ``train.timestep``.
        """
        batch_policy_loss = []
        batch_entropy_loss = []
        batch_value_loss = []
        batch_total_loss = []

        for trajectory in D:
            logprobs = trajectory.all_info('action_logprob')
            entropies = trajectory.all_info('entropy')
            Qs = trajectory.all_discounted_returns

            # Standardize: encourage/discourage half of performed actions
            if self.config['agent.standardize_Q']:
                Qs = Standardize()(Qs).tolist()

            # State values, plus the value stored for the state after the
            # last transition and whether that transition ended the episode
            Vs = trajectory.all_info('V_s')
            final_V = trajectory.transitions[-1].V_s_next
            final_done = trajectory.transitions[-1].done

            # Advantage estimates; .item() keeps the policy loss from
            # backpropagating into the value estimates
            As = [Q - V.item() for Q, V in zip(Qs, Vs)]
            if self.config['agent.standardize_adv']:
                As = Standardize()(As).tolist()

            # Estimate policy gradient for all time steps and record all losses
            policy_loss = []
            entropy_loss = []
            value_loss = []
            for logprob, entropy, A, Q, V in zip(logprobs, entropies, As, Qs,
                                                 Vs):
                policy_loss.append(-logprob * A)
                entropy_loss.append(-entropy)
                # Build the target with V's dtype/device/shape so the loss
                # never mixes float32 and float64 tensors.
                target = torch.tensor(Q, dtype=V.dtype,
                                      device=V.device).view_as(V)
                value_loss.append(F.mse_loss(V, target))
            if final_done:  # learn terminal state value as zero
                # The target follows final_V itself rather than the loop
                # variable V leaked from the loop above.
                value_loss.append(
                    F.mse_loss(final_V, torch.zeros_like(final_V)))

            # Average losses over all time steps
            policy_loss = torch.stack(policy_loss).mean()
            entropy_loss = torch.stack(entropy_loss).mean()
            value_loss = torch.stack(value_loss).mean()

            # Calculate total loss
            entropy_coef = self.config['agent.entropy_coef']
            value_coef = self.config['agent.value_coef']
            total_loss = policy_loss + value_coef * value_loss + entropy_coef * entropy_loss

            # Record all losses
            batch_policy_loss.append(policy_loss)
            batch_entropy_loss.append(entropy_loss)
            batch_value_loss.append(value_loss)
            batch_total_loss.append(total_loss)

        # Average loss over list of Trajectory
        policy_loss = torch.stack(batch_policy_loss).mean()
        entropy_loss = torch.stack(batch_entropy_loss).mean()
        value_loss = torch.stack(batch_value_loss).mean()
        loss = torch.stack(batch_total_loss).mean()

        # Train with estimated policy gradient
        self.optimizer.zero_grad()
        loss.backward()

        if self.config['agent.max_grad_norm'] is not None:
            clip_grad_norm_(parameters=self.policy.network.parameters(),
                            max_norm=self.config['agent.max_grad_norm'],
                            norm_type=2)

        if hasattr(self, 'lr_scheduler'):
            if 'train.iter' in self.config:  # iteration-based
                self.lr_scheduler.step()
            elif 'train.timestep' in self.config:  # timestep-based
                # Timesteps trained so far, i.e. before this batch is counted.
                self.lr_scheduler.step(self.total_T)
            else:
                raise KeyError(
                    'expected `train.iter` or `train.timestep` in config, but got none of them'
                )

        self.optimizer.step()

        # Accumulate trained timesteps (second pass over D)
        self.total_T += sum(trajectory.T for trajectory in D)

        out = {}
        out['loss'] = loss.item()
        out['policy_loss'] = policy_loss.item()
        out['entropy_loss'] = entropy_loss.item()
        out['value_loss'] = value_loss.item()
        if hasattr(self, 'lr_scheduler'):
            out['current_lr'] = self.lr_scheduler.get_lr()

        return out
Exemplo n.º 6
0
    def learn(self, D):
        """Apply one policy-gradient (REINFORCE) update from trajectories ``D``.

        Each trajectory contributes the time-averaged ``-logprob * Q`` policy
        loss and ``-entropy`` loss, combined as
        ``policy + entropy_coef * entropy``. The loss that is backpropagated
        is the mean of these per-trajectory totals.

        Args:
            D: iterable of trajectories exposing ``all_discounted_returns``
                and ``all_info(key)``.

        Returns:
            dict with the averaged ``loss`` tensor; the per-trajectory tensor
            lists ``batch_policy_loss``, ``batch_entropy_loss`` and
            ``batch_total_loss`` (still attached to the autograd graph); and
            ``current_lr`` when a scheduler is attached.
        """
        cfg = self.config
        policy_per_traj = []
        entropy_per_traj = []
        total_per_traj = []

        for traj in D:
            # Discounted returns serve as the Q estimate.
            q_values = traj.all_discounted_returns
            # TODO: when use GAE of TDs, really standardize it ? biased magnitude of learned value get wrong TD error
            if cfg['agent:standardize']:
                # Standardized Q encourages roughly half of the performed
                # actions and discourages the other half.
                q_values = Standardize()(q_values)

            steps = list(zip(traj.all_info('action_logprob'),
                             traj.all_info('entropy'),
                             q_values))
            traj_policy = torch.stack([-lp * q for lp, _, q in steps]).mean()
            traj_entropy = torch.stack([-ent for _, ent, _ in steps]).mean()
            traj_total = traj_policy + cfg['agent:entropy_coef'] * traj_entropy

            policy_per_traj.append(traj_policy)
            entropy_per_traj.append(traj_entropy)
            total_per_traj.append(traj_total)

        # Mean over trajectories; stack because each element is zero-dim.
        loss = torch.stack(total_per_traj).mean()

        self.optimizer.zero_grad()
        loss.backward()

        grad_limit = cfg['agent:max_grad_norm']
        if grad_limit is not None:
            nn.utils.clip_grad_norm_(parameters=self.policy.network.parameters(),
                                     max_norm=grad_limit,
                                     norm_type=2)

        scheduled = hasattr(self, 'lr_scheduler')
        if scheduled:
            self.lr_scheduler.step()

        self.optimizer.step()

        output = {
            'loss': loss,
            'batch_policy_loss': policy_per_traj,
            'batch_entropy_loss': entropy_per_traj,
            'batch_total_loss': total_per_traj,
        }
        if scheduled:
            output['current_lr'] = self.lr_scheduler.get_lr()
        return output
Exemplo n.º 7
0
    def learn(self, x):
        """Perform one policy-gradient update from a list of trajectories.

        The per-step ``-logprob * R`` and ``-entropy`` losses are summed (not
        averaged) over each trajectory and combined as
        ``policy + entropy_coef * entropy``. The loss that is backpropagated
        is the mean of these per-trajectory totals.

        Args:
            x: iterable of trajectories exposing ``all_discounted_returns``
                and ``all_info(key)``.

        Returns:
            dict with the averaged ``loss`` tensor; the per-trajectory tensor
            lists ``batch_policy_loss``, ``batch_entropy_loss`` and
            ``batch_total_loss`` (still attached to the autograd graph); and
            ``current_lr`` when a scheduler is attached.
        """
        cfg = self.config
        policy_per_traj = []
        entropy_per_traj = []
        total_per_traj = []

        for traj in x:
            returns = traj.all_discounted_returns
            # Standardized returns encourage/discourage roughly half of the
            # performed actions, i.e. values around [-1, 1].
            if cfg['standardize_r']:
                returns = Standardize()(returns)

            steps = list(zip(traj.all_info('action_logprob'),
                             traj.all_info('entropy'),
                             returns))
            # Each return is cast with float(). TODO: supports VecEnv
            traj_policy = torch.stack(
                [-lp * float(ret) for lp, _, ret in steps]).sum()
            traj_entropy = torch.stack([-ent for _, ent, _ in steps]).sum()
            traj_total = traj_policy + cfg['entropy_coef'] * traj_entropy

            policy_per_traj.append(traj_policy)
            entropy_per_traj.append(traj_entropy)
            total_per_traj.append(traj_total)

        # Mean over trajectories; stack because each element is zero-dim.
        loss = torch.stack(total_per_traj).mean()

        self.optimizer.zero_grad()
        loss.backward()

        grad_limit = cfg['max_grad_norm']
        if grad_limit is not None:
            nn.utils.clip_grad_norm_(
                parameters=self.policy.network.parameters(),
                max_norm=grad_limit,
                norm_type=2)

        scheduled = hasattr(self, 'lr_scheduler')
        if scheduled:
            self.lr_scheduler.step()

        self.optimizer.step()

        output = {
            'loss': loss,
            'batch_policy_loss': policy_per_traj,
            'batch_entropy_loss': entropy_per_traj,
            'batch_total_loss': total_per_traj,
        }
        if scheduled:
            output['current_lr'] = self.lr_scheduler.get_lr()
        return output
Exemplo n.º 8
0
    def learn(self, D):
        """Run one REINFORCE-style update (with entropy bonus) on trajectories ``D``.

        Per trajectory, the time-averaged ``-logprob * Q`` and ``-entropy``
        losses are combined as ``policy + entropy_coef * entropy``. The
        backpropagated loss averages these per-trajectory totals.

        Args:
            D: re-iterable collection of trajectories (iterated once for the
                losses and once for the timestep count). Each trajectory
                exposes ``all_discounted_returns``, ``all_info(key)`` and ``T``.

        Returns:
            dict with Python numbers ``loss``, ``policy_loss``,
            ``entropy_loss`` and, if a scheduler is attached, ``current_lr``.

        Raises:
            KeyError: a scheduler is attached but neither ``train.iter`` nor
                ``train.timestep`` is in the config.
        """
        out = {}
        cfg = self.config
        policy_per_traj = []
        entropy_per_traj = []
        total_per_traj = []

        for traj in D:
            q_values = traj.all_discounted_returns
            if cfg['agent.standardize_Q']:
                # Standardized Q: about half of the actions are encouraged,
                # the other half discouraged.
                q_values = Standardize()(q_values).tolist()

            steps = list(zip(traj.all_info('action_logprob'),
                             traj.all_info('entropy'),
                             q_values))
            traj_policy = torch.stack([-lp * q for lp, _, q in steps]).mean()
            traj_entropy = torch.stack([-ent for _, ent, _ in steps]).mean()
            traj_total = traj_policy + cfg['agent.entropy_coef'] * traj_entropy

            policy_per_traj.append(traj_policy)
            entropy_per_traj.append(traj_entropy)
            total_per_traj.append(traj_total)

        # Averages over trajectories; stack because every element is zero-dim.
        loss = torch.stack(total_per_traj).mean()
        policy_loss = torch.stack(policy_per_traj).mean()
        entropy_loss = torch.stack(entropy_per_traj).mean()

        self.optimizer.zero_grad()
        loss.backward()

        grad_limit = cfg['agent.max_grad_norm']
        if grad_limit is not None:
            nn.utils.clip_grad_norm_(
                parameters=self.policy.network.parameters(),
                max_norm=grad_limit,
                norm_type=2)

        scheduled = hasattr(self, 'lr_scheduler')
        if scheduled:
            if 'train.iter' in cfg:
                # Iteration-based training: advance the schedule by one step.
                self.lr_scheduler.step()
            elif 'train.timestep' in cfg:
                # Timestep-based training: pass the timesteps trained so far.
                self.lr_scheduler.step(self.total_T)
            else:
                raise KeyError(
                    'expected `train.iter` or `train.timestep` in config, but none of them exist'
                )

        self.optimizer.step()

        # Count the timesteps consumed by this batch.
        self.total_T += sum(traj.T for traj in D)

        # Plain numbers only: no graph needs to be kept alive by the caller.
        out['loss'] = loss.item()
        out['policy_loss'] = policy_loss.item()
        out['entropy_loss'] = entropy_loss.item()
        if scheduled:
            out['current_lr'] = self.lr_scheduler.get_lr()
        return out
Exemplo n.º 9
0
    def learn(self, D):
        """Run one actor-critic update (policy, value and entropy losses) on ``D``.

        Per trajectory:
        - Discounted returns are the Q targets.
        - Advantages are ``Q - V(s)``.
        - The time-averaged policy, value and entropy losses are combined as
          ``policy + value_coef * value + entropy_coef * entropy``.

        The backpropagated loss averages these per-trajectory totals.

        Args:
            D: collection of trajectories. It is iterated twice (once for the
                losses, once for the timestep count), so it should be a
                re-iterable container. Each trajectory must expose
                ``all_info(key)``, ``all_discounted_returns``,
                ``transitions`` and ``T``.

        Returns:
            dict with Python numbers ``loss``, ``policy_loss``,
            ``entropy_loss``, ``value_loss`` and, when a scheduler is attached,
            ``current_lr``.

        Raises:
            KeyError: a scheduler is attached but the config contains neither
                ``train.iter`` nor ``train.timestep``.
        """
        out = {}

        batch_policy_loss = []
        batch_value_loss = []
        batch_entropy_loss = []
        batch_total_loss = []

        for trajectory in D:  # iterate over trajectories
            # Get all discounted returns as estimate of Q
            Qs = trajectory.all_discounted_returns
            # Standardize: encourage/discourage half of performed actions
            if self.config['agent.standardize_Q']:
                Qs = Standardize()(Qs).tolist()

            # Get all state values (not use `all_V` for tuple result as trajectory only has one episode data)
            Vs = trajectory.all_info('V_s')
            final_V = trajectory.transitions[-1].V_s_next
            final_done = trajectory.transitions[-1].done

            # Advantage estimates; .item() keeps the policy loss from
            # backpropagating into the value estimates
            As = [Q - V.item() for Q, V in zip(Qs, Vs)]
            # Standardize advantage: encourage/discourage half of performed actions
            if self.config['agent.standardize_adv']:
                As = Standardize()(As).tolist()

            # Get all log-probabilities and entropies
            logprobs = trajectory.all_info('action_logprob')
            entropies = trajectory.all_info('entropy')

            # Estimate policy gradient for all time steps and record all losses
            policy_loss = []
            entropy_loss = []
            value_loss = []
            for logprob, entropy, A, Q, V in zip(logprobs, entropies, As, Qs,
                                                 Vs):
                policy_loss.append(-logprob * A)
                entropy_loss.append(-entropy)
                # Build the target with V's dtype/device/shape so the loss
                # never mixes float32 and float64 tensors.
                target = torch.tensor(Q, dtype=V.dtype,
                                      device=V.device).view_as(V)
                value_loss.append(F.mse_loss(V, target))
            if final_done:  # learn terminal state value as zero
                # The target follows final_V itself rather than the loop
                # variable V leaked from the loop above.
                value_loss.append(
                    F.mse_loss(final_V, torch.zeros_like(final_V)))

            # Average over losses for all time steps
            policy_loss = torch.stack(policy_loss).mean()
            entropy_loss = torch.stack(entropy_loss).mean()
            value_loss = torch.stack(value_loss).mean()

            # Calculate total loss
            entropy_coef = self.config['agent.entropy_coef']
            value_coef = self.config['agent.value_coef']
            total_loss = policy_loss + value_coef * value_loss + entropy_coef * entropy_loss

            # Record all losses
            batch_policy_loss.append(policy_loss)
            batch_entropy_loss.append(entropy_loss)
            batch_value_loss.append(value_loss)
            batch_total_loss.append(total_loss)

        # Compute loss (average over trajectories): use `stack` as each is zero-dim
        loss = torch.stack(batch_total_loss).mean()
        policy_loss = torch.stack(batch_policy_loss).mean()
        entropy_loss = torch.stack(batch_entropy_loss).mean()
        value_loss = torch.stack(batch_value_loss).mean()

        # Zero-out gradient buffer
        self.optimizer.zero_grad()
        # Backward pass and compute gradients
        loss.backward()

        # Clip gradient norms if required
        if self.config['agent.max_grad_norm'] is not None:
            nn.utils.clip_grad_norm_(
                parameters=self.policy.network.parameters(),
                max_norm=self.config['agent.max_grad_norm'],
                norm_type=2)

        # Decay learning rate if required
        if hasattr(self, 'lr_scheduler'):
            if 'train.iter' in self.config:  # iteration-based training, so just increment epoch by default
                self.lr_scheduler.step()
            elif 'train.timestep' in self.config:  # timestep-based training, increment with timesteps
                self.lr_scheduler.step(self.total_T)
            else:
                raise KeyError(
                    'expected `train.iter` or `train.timestep` in config, but none of them exist'
                )

        # Take a gradient step
        self.optimizer.step()

        # Accumulate trained timesteps (second pass over D)
        self.total_T += sum(trajectory.T for trajectory in D)

        # Output dictionary: use `item()` to save memory if no more backprop needed
        out['loss'] = loss.item()
        out['policy_loss'] = policy_loss.item()
        out['entropy_loss'] = entropy_loss.item()
        out['value_loss'] = value_loss.item()
        if hasattr(self, 'lr_scheduler'):
            out['current_lr'] = self.lr_scheduler.get_lr()

        return out
Exemplo n.º 10
0
    def learn(self, D):
        """Run one A2C-style update (policy, value and entropy losses) on segments ``D``.

        Per segment:
        - Bootstrapped discounted returns are the Q targets.
        - Advantages are ``Q - V(s)``, optionally standardized.
        - The time-averaged policy, value and entropy losses are combined as
          ``policy + value_coef * value + entropy_coef * entropy``.

        The backpropagated loss averages these per-segment totals.

        Args:
            D: iterable of segments exposing
                ``all_bootstrapped_discounted_returns`` and ``all_info(key)``.

        Returns:
            dict with the averaged ``loss`` tensor; the per-segment tensor
            lists ``batch_policy_loss``, ``batch_value_loss``,
            ``batch_entropy_loss`` and ``batch_total_loss`` (still attached to
            the autograd graph); and ``current_lr`` when a scheduler is
            attached.
        """
        batch_policy_loss = []
        batch_value_loss = []
        batch_entropy_loss = []
        batch_total_loss = []

        # Iterate over list of Segment in D
        for segment in D:
            # Get all bootstrapped discounted returns as estimate of Q
            Qs = segment.all_bootstrapped_discounted_returns

            # Get all state values (per the original note: V_s_next of done
            # transitions and of the final transition are not included)
            Vs = segment.all_info('V_s')

            # Advantage estimates; .item() keeps the policy loss from
            # backpropagating into the value estimates
            As = [Q - V.item() for Q, V in zip(Qs, Vs)]

            # A2C: the advantage (not Q) is standardized, encouraging/
            # discouraging roughly half of the performed actions.
            # TODO: when use GAE of TDs, really standardize it ? biased magnitude of learned value get wrong TD error
            if self.config['agent:standardize']:
                As = Standardize()(As).tolist()

            # Get all log-probabilities and entropies
            logprobs = segment.all_info('action_logprob')
            entropies = segment.all_info('entropy')

            # Estimate policy gradient for all time steps and record all losses
            policy_loss = []
            value_loss = []
            entropy_loss = []
            for logprob, entropy, A, Q, V in zip(logprobs, entropies, As, Qs,
                                                 Vs):
                policy_loss.append(-logprob * A)
                # Shape the target like V (avoids mse_loss broadcasting a
                # 0-dim target against V) and match V's dtype/device.
                target = torch.tensor(Q, dtype=V.dtype,
                                      device=V.device).view_as(V)
                value_loss.append(F.mse_loss(V, target))
                entropy_loss.append(-entropy)

            # Average over losses for all time steps
            policy_loss = torch.stack(policy_loss).mean()
            value_loss = torch.stack(value_loss).mean()
            entropy_loss = torch.stack(entropy_loss).mean()

            # Calculate total loss
            value_coef = self.config['agent:value_coef']
            entropy_coef = self.config['agent:entropy_coef']
            total_loss = policy_loss + value_coef * value_loss + entropy_coef * entropy_loss

            # Record all losses
            batch_policy_loss.append(policy_loss)
            batch_value_loss.append(value_loss)
            batch_entropy_loss.append(entropy_loss)
            batch_total_loss.append(total_loss)

        # Compute loss (average over segments); stack zero-dim tensors
        loss = torch.stack(batch_total_loss).mean()

        # Zero-out gradient buffer
        self.optimizer.zero_grad()
        # Backward pass and compute gradients
        loss.backward()

        # Clip gradient norms if required
        if self.config['agent:max_grad_norm'] is not None:
            nn.utils.clip_grad_norm_(
                parameters=self.policy.network.parameters(),
                max_norm=self.config['agent:max_grad_norm'],
                norm_type=2)

        # Decay learning rate if required
        if hasattr(self, 'lr_scheduler'):
            self.lr_scheduler.step()

        # Take a gradient step
        self.optimizer.step()

        # Output dictionary for different losses
        # TODO: if no more backprop needed, record with .item(), save memory without store computation graph
        output = {}
        output['loss'] = loss  # TODO: maybe item()
        output['batch_policy_loss'] = batch_policy_loss
        output['batch_value_loss'] = batch_value_loss
        output['batch_entropy_loss'] = batch_entropy_loss
        output['batch_total_loss'] = batch_total_loss
        if hasattr(self, 'lr_scheduler'):
            output['current_lr'] = self.lr_scheduler.get_lr()

        return output