Example #1
    def __init__(self, config):
        BasePlayer.__init__(self, config)
        self.network = config['network']
        self.actions_num = self.action_space.shape[0]
        self.actions_low = self.action_space.low
        self.actions_high = self.action_space.high
        self.mask = [False]

        observation_shape = algos_torch.torch_ext.shape_whc_to_cwh(
            self.state_shape)

        self.normalize_input = self.config['normalize_input']
        print(self.state_shape)
        config = {
            'actions_num': self.actions_num,
            'input_shape':
            algos_torch.torch_ext.shape_whc_to_cwh(self.state_shape),
            'games_num': 1,
            'batch_num': 1,
        }
        self.model = self.network.build(config)
        self.model.cuda()
        self.model.eval()

        if self.normalize_input:
            self.running_mean_std = RunningMeanStd(observation_shape).cuda()
            self.running_mean_std.eval()
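
When normalize_input is set, observations are standardized with running statistics (RunningMeanStd) before they reach the network. Below is a minimal, self-contained sketch of that idea, not the RunningMeanStd class used above; the class name and the parallel mean/variance update are assumptions made for illustration.

import torch
import torch.nn as nn


class RunningMeanStdSketch(nn.Module):
    # Hypothetical stand-in for RunningMeanStd: tracks a running mean and
    # variance per feature and returns standardized observations.
    def __init__(self, shape, epsilon=1e-5):
        super().__init__()
        self.register_buffer('mean', torch.zeros(shape))
        self.register_buffer('var', torch.ones(shape))
        self.register_buffer('count', torch.tensor(float(epsilon)))

    def forward(self, x):
        if self.training:
            with torch.no_grad():
                batch_mean = x.mean(dim=0)
                batch_var = x.var(dim=0, unbiased=False)
                batch_count = x.shape[0]
                delta = batch_mean - self.mean
                total = self.count + batch_count
                # parallel update of the running moments (Chan et al.)
                self.mean += delta * batch_count / total
                m_a = self.var * self.count
                m_b = batch_var * batch_count
                self.var = (m_a + m_b +
                            delta**2 * self.count * batch_count / total) / total
                self.count = total
        return (x - self.mean) / torch.sqrt(self.var + 1e-5)


# usage: norm = RunningMeanStdSketch((4,)); norm.train(); norm(torch.randn(8, 4))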
Example #2
    def __init__(self, base_name, observation_space, action_space, config):
        common.a2c_common.ContinuousA2CBase.__init__(self, base_name,
                                                     observation_space,
                                                     action_space, config)
        obs_shape = algos_torch.torch_ext.shape_whc_to_cwh(self.state_shape)
        config = {
            'actions_num': self.actions_num,
            'input_shape': obs_shape,
            'games_num': 1,
            'batch_num': 1,
        }
        self.model = self.network.build(config)
        self.model.cuda()
        self.last_lr = float(self.last_lr)
        self.optimizer = optim.Adam(self.model.parameters(),
                                    float(self.last_lr))
        #self.optimizer = algos_torch.torch_ext.RangerQH(self.model.parameters(), float(self.last_lr))

        if self.normalize_input:
            self.running_mean_std = RunningMeanStd(obs_shape).cuda()
        if self.has_curiosity:
            self.rnd_curiosity = rnd_curiosity.RNDCurisityTrain(
                algos_torch.torch_ext.shape_whc_to_cwh(self.state_shape),
                self.curiosity_config['network'], self.curiosity_config,
                self.writer, lambda obs: self._preproc_obs(obs))
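
algos_torch.torch_ext.shape_whc_to_cwh is used throughout these examples only to reorder an observation shape tuple so that the channel dimension comes first, matching the NCHW layout expected by torch convolutions. A sketch of the assumed behavior (the function below is a hypothetical re-implementation for illustration):

def shape_whc_to_cwh_sketch(shape):
    # image shapes arrive channel-last (H, W, C); torch wants channel-first
    if len(shape) == 3:
        return (shape[2], shape[0], shape[1])
    # vector observations are left untouched
    return shape


print(shape_whc_to_cwh_sketch((84, 84, 4)))  # (4, 84, 84)
print(shape_whc_to_cwh_sketch((17,)))        # (17,)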
Example #3
class RNDCurisityTrain(nn.Module):
    def __init__(self, state_shape, model, config, writter, _preproc_obs):
        nn.Module.__init__(self)
        rnd_config = {
            'input_shape': state_shape,
        }
        self.model = RNDCuriosityNetwork(model.build('rnd',
                                                     **rnd_config)).cuda()
        self.config = config
        self.lr = config['lr']
        self.writter = writter
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          float(self.lr))
        self._preproc_obs = _preproc_obs
        self.output_normalization = RunningMeanStd((1, ),
                                                   norm_only=True).cuda()
        self.frame = 0
        self.exp_percent = config.get('exp_percent', 1.0)

    def get_loss(self, obs):
        obs = self._preproc_obs(obs)
        self.model.eval()
        self.output_normalization.train()
        with torch.no_grad():
            loss = self.model(obs)
            loss = loss.squeeze()
            loss = self.output_normalization(loss)

            return loss.cpu().numpy() * self.config['scale_value']

    def train(self, obs):
        self.model.train()
        mini_epoch = self.config['mini_epochs']
        mini_batch = self.config['minibatch_size']

        num_minibatches = np.shape(obs)[0] // mini_batch
        self.frame = self.frame + 1
        for _ in range(mini_epoch):
            # returning loss from last epoch
            avg_loss = 0
            for i in range(num_minibatches):
                obs_batch = obs[i * mini_batch:(i + 1) * mini_batch]
                obs_batch = self._preproc_obs(obs_batch)
                obs_batch = torch_ext.random_sample(obs_batch,
                                                    self.exp_percent)
                loss = self.model(obs_batch).mean()
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                avg_loss += loss.item()

        self.writter.add_scalar('rnd/train_loss', avg_loss, self.frame)
        return avg_loss / num_minibatches
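
RNDCurisityTrain implements Random Network Distillation: a frozen, randomly initialized target network and a trainable predictor are fed the same observation, and the predictor's error serves both as the training loss and as the intrinsic (curiosity) reward, so rarely visited states score higher. A minimal sketch of that core mechanism follows; the two small MLPs and their sizes are assumptions and not the networks built from curiosity_config['network'].

import torch
import torch.nn as nn


class TinyRND(nn.Module):
    # Hypothetical minimal RND module, for illustration only.
    def __init__(self, obs_dim, feat_dim=64):
        super().__init__()
        self.target = nn.Sequential(nn.Linear(obs_dim, 128), nn.ReLU(),
                                    nn.Linear(128, feat_dim))
        self.predictor = nn.Sequential(nn.Linear(obs_dim, 128), nn.ReLU(),
                                       nn.Linear(128, feat_dim))
        for p in self.target.parameters():
            p.requires_grad_(False)  # the target stays random and frozen

    def forward(self, obs):
        # per-sample prediction error; novel observations give larger values
        return ((self.predictor(obs) - self.target(obs))**2).mean(dim=-1)


rnd = TinyRND(obs_dim=8)
obs = torch.randn(32, 8)
intrinsic_reward = rnd(obs).detach()  # analogue of get_loss()
loss = rnd(obs).mean()                # analogue of the minibatch loss in train()
loss.backward()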
Example #4
    def __init__(self, state_shape, model, config, writter, _preproc_obs):
        nn.Module.__init__(self)
        rnd_config = {
            'input_shape': state_shape,
        }
        self.model = RNDCuriosityNetwork(model.build('rnd',
                                                     **rnd_config)).cuda()
        self.config = config
        self.lr = config['lr']
        self.writter = writter
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          float(self.lr))
        self._preproc_obs = _preproc_obs
        self.output_normalization = RunningMeanStd((1, ),
                                                   norm_only=True).cuda()
        self.frame = 0
        self.exp_percent = config.get('exp_percent', 1.0)
Example #5
    def __init__(self, base_name, observation_space, action_space, config):
        common.a2c_common.DiscreteA2CBase.__init__(self, base_name,
                                                   observation_space,
                                                   action_space, config)
        config = {
            'actions_num': self.actions_num,
            'input_shape':
            algos_torch.torch_ext.shape_whc_to_cwh(self.state_shape),
            'games_num': 1,
            'batch_num': 1,
        }
        self.model = self.network.build(config)
        self.model.cuda()
        self.last_lr = float(self.last_lr)
        self.optimizer = optim.Adam(self.model.parameters(),
                                    float(self.last_lr))
        #self.optimizer = algos_torch.torch_ext.RangerQH(self.model.parameters(), float(self.last_lr))
        if self.normalize_input:
            self.running_mean_std = RunningMeanStd(
                observation_space.shape).cuda()
Example #6
class DiscreteA2CAgent(common.a2c_common.DiscreteA2CBase):
    def __init__(self, base_name, observation_space, action_space, config):
        common.a2c_common.DiscreteA2CBase.__init__(self, base_name,
                                                   observation_space,
                                                   action_space, config)
        config = {
            'actions_num': self.actions_num,
            'input_shape':
            algos_torch.torch_ext.shape_whc_to_cwh(self.state_shape),
            'games_num': 1,
            'batch_num': 1,
        }
        self.model = self.network.build(config)
        self.model.cuda()
        self.last_lr = float(self.last_lr)
        self.optimizer = optim.Adam(self.model.parameters(),
                                    float(self.last_lr))
        #self.optimizer = algos_torch.torch_ext.RangerQH(self.model.parameters(), float(self.last_lr))
        if self.normalize_input:
            self.running_mean_std = RunningMeanStd(
                observation_space.shape).cuda()

    def set_eval(self):
        self.model.eval()
        if self.normalize_input:
            self.running_mean_std.eval()

    def set_train(self):
        self.model.train()
        if self.normalize_input:
            self.running_mean_std.train()

    def update_epoch(self):
        self.epoch_num += 1
        return self.epoch_num

    def _preproc_obs(self, obs_batch):
        if obs_batch.dtype == np.uint8:
            obs_batch = torch.cuda.ByteTensor(obs_batch)
            obs_batch = obs_batch.float() / 255.0
        else:
            obs_batch = torch.cuda.FloatTensor(obs_batch)
        if len(obs_batch.size()) == 3:
            obs_batch = obs_batch.permute((0, 2, 1))
        if len(obs_batch.size()) == 4:
            obs_batch = obs_batch.permute((0, 3, 1, 2))
        if self.normalize_input:
            obs_batch = self.running_mean_std(obs_batch)
        return obs_batch

    def save(self, fn):
        state = {
            'epoch': self.epoch_num,
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict()
        }
        if self.normalize_input:
            state['running_mean_std'] = self.running_mean_std.state_dict()
        algos_torch.torch_ext.save_scheckpoint(fn, state)

    def restore(self, fn):
        checkpoint = algos_torch.torch_ext.load_checkpoint(fn)
        self.epoch_num = checkpoint['epoch']
        self.model.load_state_dict(checkpoint['model'])
        if self.normalize_input:
            self.running_mean_std.load_state_dict(
                checkpoint['running_mean_std'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()

    def get_masked_action_values(self, obs, action_masks):
        obs = self._preproc_obs(obs)
        action_masks = torch.Tensor(action_masks).cuda()
        input_dict = {
            'is_train': False,
            'prev_actions': None,
            'inputs': obs,
            'action_masks': action_masks
        }
        with torch.no_grad():
            neglogp, value, action, logits = self.model(input_dict)
        return action.detach().cpu().numpy(), \
            value.detach().cpu().numpy(), \
            neglogp.detach().cpu().numpy(), \
            logits.detach().cpu().numpy(), None

    def get_action_values(self, obs):
        obs = self._preproc_obs(obs)
        self.model.eval()
        input_dict = {
            'is_train': False,
            'prev_actions': None,
            'inputs': obs,
        }
        with torch.no_grad():
            neglogp, value, action, logits = self.model(input_dict)
        return action.detach().cpu().numpy(), \
            value.detach().cpu().numpy(), \
            neglogp.detach().cpu().numpy(), None

    def get_values(self, obs):
        obs = self._preproc_obs(obs)
        self.model.eval()
        input_dict = {'is_train': False, 'prev_actions': None, 'inputs': obs}
        with torch.no_grad():
            neglogp, value, action, logits = self.model(input_dict)
        return value.detach().cpu().numpy()

    def get_weights(self):
        return torch.nn.utils.parameters_to_vector(self.model.parameters())

    def set_weights(self, weights):
        torch.nn.utils.vector_to_parameters(weights, self.model.parameters())

    def train_actor_critic(self, input_dict):
        self.model.train()
        value_preds_batch = torch.cuda.FloatTensor(input_dict['old_values'])
        old_action_log_probs_batch = torch.cuda.FloatTensor(
            input_dict['old_logp_actions'])
        advantage = torch.cuda.FloatTensor(input_dict['advantages'])
        return_batch = torch.cuda.FloatTensor(input_dict['returns'])
        actions_batch = torch.cuda.LongTensor(input_dict['actions'])
        obs_batch = input_dict['obs']
        obs_batch = self._preproc_obs(obs_batch)
        lr = self.last_lr
        kl = 1.0
        lr_mul = 1.0
        curr_e_clip = lr_mul * self.e_clip

        input_dict = {
            'is_train': True,
            'prev_actions': actions_batch,
            'inputs': obs_batch
        }
        action_log_probs, values, entropy = self.model(input_dict)

        if self.ppo:
            ratio = torch.exp(old_action_log_probs_batch - action_log_probs)
            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1.0 - curr_e_clip,
                                1.0 + curr_e_clip) * advantage
            a_loss = torch.max(-surr1, -surr2).mean()
        else:
            a_loss = (action_log_probs * advantage).mean()

        values = torch.squeeze(values)
        if self.clip_value:
            value_pred_clipped = value_preds_batch + \
                (values - value_preds_batch).clamp(-curr_e_clip, curr_e_clip)
            value_losses = (values - return_batch)**2
            value_losses_clipped = (value_pred_clipped - return_batch)**2
            c_loss = torch.max(value_losses, value_losses_clipped)
        else:
            c_loss = (return_batch - values)**2

        c_loss = c_loss.mean()
        loss = a_loss + 0.5 * c_loss * self.critic_coef - entropy * self.entropy_coef

        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)
        self.optimizer.step()
        with torch.no_grad():
            kl_dist = 0.5 * (
                (old_action_log_probs_batch - action_log_probs)**2).mean()
            kl_dist = kl_dist.item()
            if self.is_adaptive_lr:
                if kl_dist > (2.0 * self.lr_threshold):
                    self.last_lr = max(self.last_lr / 1.5, 1e-6)
                if kl_dist < (0.5 * self.lr_threshold):
                    self.last_lr = min(self.last_lr * 1.5, 1e-2)
        return a_loss.item(), c_loss.item(), entropy.item(), \
            kl_dist, self.last_lr, lr_mul
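
train_actor_critic above is the PPO clipped-surrogate update: since the model returns negative log-probabilities, exp(old - new) is the ratio p_new / p_old, which is clipped to [1 - e_clip, 1 + e_clip], and the pessimistic (elementwise maximum) of the two negated surrogate terms is minimized. A tiny standalone illustration with made-up numbers:

import torch

e_clip = 0.2
old_neglogp = torch.tensor([1.0, 0.5, 2.0])
new_neglogp = torch.tensor([0.6, 0.7, 1.0])
advantage = torch.tensor([1.0, -0.5, 2.0])

ratio = torch.exp(old_neglogp - new_neglogp)  # = p_new / p_old
surr1 = ratio * advantage
surr2 = torch.clamp(ratio, 1.0 - e_clip, 1.0 + e_clip) * advantage
a_loss = torch.max(-surr1, -surr2).mean()
print(ratio, a_loss)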
Example #7
class PpoPlayerDiscrete(BasePlayer):
    def __init__(self, config):
        BasePlayer.__init__(self, config)
        self.network = config['network']
        self.actions_num = self.action_space.n
        self.mask = [False]

        self.normalize_input = self.config['normalize_input']
        observation_shape = algos_torch.torch_ext.shape_whc_to_cwh(
            self.state_shape)
        config = {
            'actions_num': self.actions_num,
            'input_shape': observation_shape,
            'games_num': 1,
            'batch_num': 1,
        }
        self.model = self.network.build(config)
        self.model.cuda()
        self.model.eval()

        if self.normalize_input:
            self.running_mean_std = RunningMeanStd(observation_shape).cuda()
            self.running_mean_std.eval()

    def _preproc_obs(self, obs_batch):
        if obs_batch.dtype == np.uint8:
            obs_batch = torch.cuda.ByteTensor(obs_batch)
            obs_batch = obs_batch.float() / 255.0
        else:
            obs_batch = torch.cuda.FloatTensor(obs_batch)
        if len(obs_batch.size()) == 3:
            obs_batch = obs_batch.permute((0, 2, 1))
        if len(obs_batch.size()) == 4:
            obs_batch = obs_batch.permute((0, 3, 1, 2))
        if self.normalize_input:
            obs_batch = self.running_mean_std(obs_batch)
        return obs_batch

    def get_masked_action(self, obs, action_masks, is_determenistic=False):
        if self.num_agents == 1:
            obs = np.expand_dims(obs, axis=0)
        obs = self._preproc_obs(obs)
        action_masks = torch.Tensor(action_masks).cuda()
        input_dict = {
            'is_train': False,
            'prev_actions': None,
            'inputs': obs,
            'action_masks': action_masks
        }
        with torch.no_grad():
            neglogp, value, action, logits = self.model(input_dict)

        if is_determenistic:
            return np.argmax(logits.detach().cpu().numpy(), axis=1).squeeze()
        else:
            return action.squeeze().detach().cpu().numpy()

    def get_action(self, obs, is_determenistic=False):
        if self.num_agents == 1:
            obs = np.expand_dims(obs, axis=0)
        obs = self._preproc_obs(obs)
        self.model.eval()
        input_dict = {
            'is_train': False,
            'prev_actions': None,
            'inputs': obs,
        }
        with torch.no_grad():
            neglogp, value, action, logits = self.model(input_dict)

        if is_determenistic:
            return np.argmax(logits.detach().cpu().numpy(), axis=1).squeeze()
        else:
            return action.squeeze().detach().cpu().numpy()

    def restore(self, fn):
        checkpoint = algos_torch.torch_ext.load_checkpoint(fn)
        self.model.load_state_dict(checkpoint['model'])
        if self.normalize_input:
            self.running_mean_std.load_state_dict(
                checkpoint['running_mean_std'])

    def reset(self):
        if self.network.is_rnn():
            self.last_state = self.initial_state
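
The player's get_action and get_masked_action differ in two ways: invalid actions can be excluded by masking their logits before the distribution is built, and the final action is either sampled or taken greedily as the argmax. A small sketch of both choices (the names are illustrative; this is not the model's internal masking code):

import torch

logits = torch.tensor([[2.0, 0.5, -1.0, 0.0]])             # one agent, four actions
action_mask = torch.tensor([[True, True, False, True]])    # third action is illegal

masked_logits = logits.masked_fill(~action_mask, float('-inf'))
dist = torch.distributions.Categorical(logits=masked_logits)

sampled_action = dist.sample()                      # is_determenistic=False
greedy_action = torch.argmax(masked_logits, dim=1)  # is_determenistic=True
print(sampled_action.item(), greedy_action.item())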
Example #8
class PpoPlayerContinuous(BasePlayer):
    def __init__(self, config):
        BasePlayer.__init__(self, config)
        self.network = config['network']
        self.actions_num = self.action_space.shape[0]
        self.actions_low = self.action_space.low
        self.actions_high = self.action_space.high
        self.mask = [False]

        observation_shape = algos_torch.torch_ext.shape_whc_to_cwh(
            self.state_shape)

        self.normalize_input = self.config['normalize_input']
        print(self.state_shape)
        config = {
            'actions_num': self.actions_num,
            'input_shape':
            algos_torch.torch_ext.shape_whc_to_cwh(self.state_shape),
            'games_num': 1,
            'batch_num': 1,
        }
        self.model = self.network.build(config)
        self.model.cuda()
        self.model.eval()

        if self.normalize_input:
            self.running_mean_std = RunningMeanStd(observation_shape).cuda()
            self.running_mean_std.eval()

    def _preproc_obs(self, obs_batch):
        if obs_batch.dtype == np.uint8:
            obs_batch = torch.cuda.ByteTensor(obs_batch)
            obs_batch = obs_batch.float() / 255.0
        else:
            obs_batch = torch.cuda.FloatTensor(obs_batch)
        if len(obs_batch.size()) == 3:
            obs_batch = obs_batch.permute((0, 2, 1))
        if len(obs_batch.size()) == 4:
            obs_batch = obs_batch.permute((0, 3, 1, 2))
        if self.normalize_input:
            obs_batch = self.running_mean_std(obs_batch)
        return obs_batch

    def get_action(self, obs, is_determenistic=False):
        obs = self._preproc_obs(np.expand_dims(obs, axis=0))
        input_dict = {
            'is_train': False,
            'prev_actions': None,
            'inputs': obs,
        }
        with torch.no_grad():
            neglogp, value, action, mu, sigma = self.model(input_dict)
        if is_determenistic:
            current_action = mu
        else:
            current_action = action
        current_action = np.squeeze(current_action.detach().cpu().numpy())
        return rescale_actions(self.actions_low, self.actions_high,
                               np.clip(current_action, -1.0, 1.0))

    def restore(self, fn):
        checkpoint = algos_torch.torch_ext.load_checkpoint(fn)
        self.model.load_state_dict(checkpoint['model'])
        if self.normalize_input:
            self.running_mean_std.load_state_dict(
                checkpoint['running_mean_std'])

    def reset(self):
        if self.network.is_rnn():
            self.last_state = self.initial_state
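
get_action for the continuous player clips the network output to [-1, 1] and then calls rescale_actions to map it into the environment's action bounds. rescale_actions itself is not shown in these examples; the following is a plausible sketch of the linear rescaling it performs (an assumption, for illustration only):

import numpy as np


def rescale_actions_sketch(low, high, action):
    # hypothetical re-implementation: map an action in [-1, 1] linearly to [low, high]
    return low + (0.5 * (action + 1.0)) * (high - low)


low = np.array([-2.0, 0.0])
high = np.array([2.0, 10.0])
print(rescale_actions_sketch(low, high, np.array([-1.0, 1.0])))  # [-2. 10.]
print(rescale_actions_sketch(low, high, np.array([0.0, 0.0])))   # [ 0.  5.]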
class A2CAgent(common.a2c_common.ContinuousA2CBase):
    def __init__(self, base_name, observation_space, action_space, config):
        common.a2c_common.ContinuousA2CBase.__init__(self, base_name,
                                                     observation_space,
                                                     action_space, config)
        obs_shape = algos_torch.torch_ext.shape_whc_to_cwh(self.state_shape)
        config = {
            'actions_num': self.actions_num,
            'input_shape': obs_shape,
            'games_num': 1,
            'batch_num': 1,
        }
        self.model = self.network.build(config)
        self.model.cuda()
        self.last_lr = float(self.last_lr)
        self.optimizer = optim.Adam(self.model.parameters(),
                                    float(self.last_lr))
        #self.optimizer = algos_torch.torch_ext.RangerQH(self.model.parameters(), float(self.last_lr))

        if self.normalize_input:
            self.running_mean_std = RunningMeanStd(obs_shape).cuda()
        if self.has_curiosity:
            self.rnd_curiosity = rnd_curiosity.RNDCurisityTrain(
                algos_torch.torch_ext.shape_whc_to_cwh(self.state_shape),
                self.curiosity_config['network'], self.curiosity_config,
                self.writer, lambda obs: self._preproc_obs(obs))

    def update_epoch(self):
        self.epoch_num += 1
        return self.epoch_num

    def _preproc_obs(self, obs_batch):
        if obs_batch.dtype == np.uint8:
            obs_batch = torch.cuda.ByteTensor(obs_batch)
            obs_batch = obs_batch.float() / 255.0
        else:
            obs_batch = torch.cuda.FloatTensor(obs_batch)
        if len(obs_batch.size()) == 3:
            obs_batch = obs_batch.permute((0, 2, 1))
        if len(obs_batch.size()) == 4:
            obs_batch = obs_batch.permute((0, 3, 1, 2))
        if self.normalize_input:
            obs_batch = self.running_mean_std(obs_batch)
        return obs_batch

    def save(self, fn):
        state = {
            'epoch': self.epoch_num,
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict()
        }
        if self.normalize_input:
            state['running_mean_std'] = self.running_mean_std.state_dict()
        if self.has_curiosity:
            state['rnd_nets'] = self.rnd_curiosity.state_dict()
        algos_torch.torch_ext.save_scheckpoint(fn, state)

    def restore(self, fn):
        checkpoint = algos_torch.torch_ext.load_checkpoint(fn)
        self.epoch_num = checkpoint['epoch']
        self.model.load_state_dict(checkpoint['model'])
        if self.normalize_input:
            self.running_mean_std.load_state_dict(
                checkpoint['running_mean_std'])
        if self.has_curiosity:
            self.rnd_curiosity.load_state_dict(checkpoint['rnd_nets'])
            for state in self.rnd_curiosity.optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.cuda()
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()

    def get_masked_action_values(self, obs, action_masks):
        assert False

    def set_eval(self):
        self.model.eval()
        if self.normalize_input:
            self.running_mean_std.eval()

    def set_train(self):
        self.model.train()
        if self.normalize_input:
            self.running_mean_std.train()

    def get_action_values(self, obs):
        self.set_eval()
        obs = self._preproc_obs(obs)
        input_dict = {
            'is_train': False,
            'prev_actions': None,
            'inputs': obs,
        }
        with torch.no_grad():
            neglogp, value, action, mu, sigma = self.model(input_dict)
        return action.detach().cpu().numpy(), \
                value.detach().cpu().numpy(), \
                neglogp.detach().cpu().numpy(), \
                mu.detach().cpu().numpy(), \
                sigma.detach().cpu().numpy(), \
                None

    def get_values(self, obs):
        obs = self._preproc_obs(obs)
        self.model.eval()
        input_dict = {'is_train': False, 'prev_actions': None, 'inputs': obs}
        with torch.no_grad():
            neglogp, value, action, mu, sigma = self.model(input_dict)
        return value.detach().cpu().numpy()

    def get_weights(self):
        return torch.nn.utils.parameters_to_vector(self.model.parameters())

    def set_weights(self, weights):
        torch.nn.utils.vector_to_parameters(weights, self.model.parameters())

    def get_intrinsic_reward(self, obs):
        return self.rnd_curiosity.get_loss(obs)

    def train_intrinsic_reward(self, dict):
        obs = dict['obs']
        self.rnd_curiosity.train(obs)

    def train_actor_critic(self, input_dict):
        self.set_train()

        value_preds_batch = torch.cuda.FloatTensor(input_dict['old_values'])
        old_action_log_probs_batch = torch.cuda.FloatTensor(
            input_dict['old_logp_actions'])
        advantage = torch.cuda.FloatTensor(input_dict['advantages'])
        old_mu_batch = torch.cuda.FloatTensor(input_dict['mu'])
        old_sigma_batch = torch.cuda.FloatTensor(input_dict['sigma'])
        return_batch = torch.cuda.FloatTensor(input_dict['returns'])
        actions_batch = torch.cuda.FloatTensor(input_dict['actions'])
        obs_batch = input_dict['obs']
        obs_batch = self._preproc_obs(obs_batch)
        lr = self.last_lr
        kl = 1.0
        lr_mul = 1.0
        curr_e_clip = lr_mul * self.e_clip

        input_dict = {
            'is_train': True,
            'prev_actions': actions_batch,
            'inputs': obs_batch
        }
        action_log_probs, values, entropy, mu, sigma = self.model(input_dict)
        if self.ppo:
            ratio = torch.exp(old_action_log_probs_batch - action_log_probs)
            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1.0 - curr_e_clip,
                                1.0 + curr_e_clip) * advantage
            a_loss = torch.max(-surr1, -surr2).mean()
        else:
            a_loss = (action_log_probs * advantage).mean()

        values = torch.squeeze(values)
        if self.clip_value:
            value_pred_clipped = value_preds_batch + \
                (values - value_preds_batch).clamp(-curr_e_clip, curr_e_clip)
            value_losses = (values - return_batch)**2
            value_losses_clipped = (value_pred_clipped - return_batch)**2
            c_loss = torch.max(value_losses, value_losses_clipped)
        else:
            c_loss = (return_batch - values)**2

        if self.has_curiosity:
            c_loss = c_loss.sum(dim=1).mean()
        else:
            c_loss = c_loss.mean()

        if self.bounds_loss_coef is not None:
            # penalize action means that drift outside the soft bound on either side
            soft_bound = 1.1
            mu_loss_high = torch.clamp_min(mu - soft_bound, 0.0)**2
            mu_loss_low = torch.clamp_max(mu + soft_bound, 0.0)**2
            b_loss = (mu_loss_low + mu_loss_high).sum(axis=-1).mean()
            bounds_loss_coef = self.bounds_loss_coef
        else:
            b_loss = torch.zeros_like(a_loss)
            bounds_loss_coef = 0.0
        loss = a_loss + 0.5 * c_loss * self.critic_coef - entropy * self.entropy_coef + b_loss * bounds_loss_coef
        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)
        self.optimizer.step()

        with torch.no_grad():
            kl_dist = algos_torch.torch_ext.policy_kl(mu.detach(),
                                                      sigma.detach(),
                                                      old_mu_batch,
                                                      old_sigma_batch)
            kl_dist = kl_dist.item()
            if self.is_adaptive_lr:
                if kl_dist > (2.0 * self.lr_threshold):
                    self.last_lr = max(self.last_lr / 1.5, 1e-6)
                if kl_dist < (0.5 * self.lr_threshold):
                    self.last_lr = min(self.last_lr * 1.5, 1e-2)

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.last_lr

        return a_loss.item(), c_loss.item(), entropy.item(), \
            kl_dist, self.last_lr, lr_mul, \
            mu.detach().cpu().numpy(), sigma.detach().cpu().numpy(), b_loss.item()
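
The adaptive learning-rate rule at the end of train_actor_critic compares algos_torch.torch_ext.policy_kl, the KL divergence between the old and new diagonal-Gaussian policies, against lr_threshold. Below is a minimal sketch of that quantity; treating it as KL(old || new) summed over action dimensions and averaged over the batch is an assumption made for illustration.

import torch


def diag_gaussian_kl_sketch(mu, sigma, old_mu, old_sigma):
    # KL(old || new) for factorized Gaussians, per sample, then batch mean
    kl = (torch.log(sigma / old_sigma)
          + (old_sigma**2 + (old_mu - mu)**2) / (2.0 * sigma**2) - 0.5)
    return kl.sum(dim=-1).mean()


mu, sigma = torch.zeros(4, 2), torch.ones(4, 2)
old_mu, old_sigma = 0.1 * torch.ones(4, 2), 1.2 * torch.ones(4, 2)
print(diag_gaussian_kl_sketch(mu, sigma, old_mu, old_sigma))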