Example #1
 def train(self):
     '''
     Completes one training step for the agent if it is time to train.
     Otherwise this function does nothing.
     '''
     if util.in_eval_lab_modes():
         return np.nan
     clock = self.body.env.clock
     if self.to_train == 1:
         batch = self.sample()
         clock.set_batch_size(len(batch))
         loss = self.calc_q_loss(batch)
         self.net.train_step(loss,
                             self.optim,
                             self.lr_scheduler,
                             clock=clock,
                             global_net=self.global_net)
         # reset
         self.to_train = 0
         logger.debug(
             f'Trained {self.name} at epi: {clock.epi}, frame: {clock.frame}, t: {clock.t}, total_reward so far: {self.body.total_reward}, loss: {loss:g}'
         )
         return loss.item()
     else:
         return np.nan
Example #2
 def time_fn(*args, **kwargs):
     start = time.time()
     output = fn(*args, **kwargs)
     end = time.time()
     logger.debug(
         f'Timed: {fn.__name__} {round((end - start) * 1000, 4)}ms')
     return output
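Example #2 is only the inner wrapper of a timing decorator; fn and logger come from the enclosing scope. For context, here is a minimal self-contained sketch of such a decorator (standard library only; the decorator name is illustrative, not the framework's API):

    import functools
    import logging
    import time

    logger = logging.getLogger(__name__)

    def debug_time(fn):
        '''Log the wall-clock time of every call to fn.'''
        @functools.wraps(fn)
        def time_fn(*args, **kwargs):
            start = time.time()
            output = fn(*args, **kwargs)
            end = time.time()
            logger.debug(f'Timed: {fn.__name__} {round((end - start) * 1000, 4)}ms')
            return output
        return time_fn

    @debug_time
    def slow_add(a, b):
        time.sleep(0.01)
        return a + b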
Example #3
 def train(self):
     clock = self.body.env.clock
     if self.to_train == 1:
         # onpolicy update
         super_loss = super().train()
         # offpolicy sil update with random minibatch
         total_sil_loss = torch.tensor(0.0)
         for _ in range(self.training_iter):
             batch = self.replay_sample()
             for _ in range(self.training_batch_iter):
                 pdparams, _v_preds = self.calc_pdparam_v(batch)
                 sil_policy_loss, sil_val_loss = self.calc_sil_policy_val_loss(
                     batch, pdparams)
                 sil_loss = sil_policy_loss + sil_val_loss
                 self.net.train_step(sil_loss,
                                     self.optim,
                                     self.lr_scheduler,
                                     clock=clock,
                                     global_net=self.global_net)
                 total_sil_loss += sil_loss
         sil_loss = total_sil_loss / self.training_iter
         loss = super_loss + sil_loss
         logger.debug(
             f'Trained {self.name} at epi: {clock.epi}, frame: {clock.frame}, t: {clock.t}, total_reward so far: {self.body.total_reward}, loss: {loss:g}'
         )
         return loss.item()
     else:
         return np.nan
Example #4
 def train(self):
     '''
     Completes one training step for the agent if it is time to train.
     i.e. the environment timestep is greater than the minimum training timestep and a multiple of the training_frequency.
     Each training step consists of sampling n batches from the agent's memory.
     For each batch, the target Q values (q_targets) are computed and a training step is taken k times.
     Otherwise this function does nothing.
     '''
     if util.in_eval_lab_modes():
         return np.nan
     clock = self.body.env.clock
     if self.to_train == 1:
         total_loss = torch.tensor(0.0)
         for _ in range(self.training_iter):
             batch = self.sample()
             clock.set_batch_size(len(batch))
             for _ in range(self.training_batch_iter):
                 loss = self.calc_q_loss(batch)
                 self.net.train_step(loss,
                                     self.optim,
                                     self.lr_scheduler,
                                     clock=clock,
                                     global_net=self.global_net)
                 total_loss += loss
         loss = total_loss / (self.training_iter * self.training_batch_iter)
         # reset
         self.to_train = 0
         logger.debug(
             f'Trained {self.name} at epi: {clock.epi}, frame: {clock.frame}, t: {clock.t}, total_reward so far: {self.body.total_reward}, loss: {loss:g}'
         )
         return loss.item()
     else:
         return np.nan
Example #5
    def calc_q_loss(self, batch):
        '''Compute the Q value loss using predicted and target Q values from the appropriate networks'''
        states = batch['states']
        next_states = batch['next_states']
        q_preds = self.net(states)
        with torch.no_grad():
            # Use online_net to select actions in next state
            online_next_q_preds = self.online_net(next_states)
            # Use eval_net to calculate next_q_preds for actions chosen by online_net
            next_q_preds = self.eval_net(next_states)
        act_q_preds = q_preds.gather(
            -1, batch['actions'].long().unsqueeze(-1)).squeeze(-1)
        online_actions = online_next_q_preds.argmax(dim=-1, keepdim=True)
        max_next_q_preds = next_q_preds.gather(-1, online_actions).squeeze(-1)
        max_q_targets = batch['rewards'] + self.gamma * (
            1 - batch['dones']) * max_next_q_preds
        logger.debug(
            f'act_q_preds: {act_q_preds}\nmax_q_targets: {max_q_targets}')
        q_loss = self.net.loss_fn(act_q_preds, max_q_targets)

        # TODO use the same loss_fn but do not reduce yet
        if 'Prioritized' in util.get_class_name(self.body.memory):  # PER
            errors = (max_q_targets - act_q_preds.detach()).abs().cpu().numpy()
            self.body.memory.update_priorities(errors)
        return q_loss
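For reference, the double-DQN target used above can be reproduced standalone with dummy tensors in place of the framework's networks and batch (shapes and values below are made up for illustration):

    import torch

    gamma = 0.99
    # toy batch: 4 transitions, 3 discrete actions
    rewards = torch.tensor([1.0, 0.0, -1.0, 0.5])
    dones = torch.tensor([0.0, 0.0, 1.0, 0.0])
    online_next_q_preds = torch.randn(4, 3)  # Q(s', .) from the online (current) net
    eval_next_q_preds = torch.randn(4, 3)    # Q(s', .) from the eval/target net

    # double DQN: the online net selects the action, the eval net scores it
    online_actions = online_next_q_preds.argmax(dim=-1, keepdim=True)
    max_next_q_preds = eval_next_q_preds.gather(-1, online_actions).squeeze(-1)
    max_q_targets = rewards + gamma * (1 - dones) * max_next_q_preds  # shape (4,)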
Example #6
 def calc_gae_advs_v_targets(self, batch, v_preds):
     '''
     Calculate GAE, and advs = GAE, v_targets = advs + v_preds
     See GAE from Schulman et al. https://arxiv.org/pdf/1506.02438.pdf
     '''
     next_states = batch['next_states'][-1]
     if not self.body.env.is_venv:
         next_states = next_states.unsqueeze(dim=0)
     with torch.no_grad():
         next_v_pred = self.calc_v(next_states, use_cache=False)
     v_preds = v_preds.detach()  # adv does not accumulate grad
     if self.body.env.is_venv:
         v_preds = math_util.venv_pack(v_preds, self.body.env.num_envs)
         next_v_pred = next_v_pred.unsqueeze(dim=0)
     v_preds_all = torch.cat((v_preds, next_v_pred), dim=0)
     advs = math_util.calc_gaes(batch['rewards'], batch['dones'],
                                v_preds_all, self.gamma, self.lam)
     v_targets = advs + v_preds
     advs = math_util.standardize(
         advs)  # standardize only for advs, not v_targets
     if self.body.env.is_venv:
         advs = math_util.venv_unpack(advs)
         v_targets = math_util.venv_unpack(v_targets)
     logger.debug(f'advs: {advs}\nv_targets: {v_targets}')
     return advs, v_targets
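math_util.calc_gaes is not shown in this listing. Below is a minimal sketch of a GAE computation in the spirit of Schulman et al., assuming rewards and dones of length T and v_preds_all of length T + 1 (bootstrap value appended); the library's actual implementation may differ:

    import torch

    def calc_gaes(rewards, dones, v_preds_all, gamma, lam):
        '''GAE over a trajectory of length T; v_preds_all has length T + 1.'''
        T = len(rewards)
        assert len(v_preds_all) == T + 1
        gaes = torch.zeros_like(rewards)
        future_gae = 0.0
        for t in reversed(range(T)):
            not_done = 1.0 - dones[t]
            delta = rewards[t] + gamma * v_preds_all[t + 1] * not_done - v_preds_all[t]
            gaes[t] = future_gae = delta + gamma * lam * not_done * future_gae
        return gaes

    rewards = torch.tensor([1.0, 1.0, 0.0, 1.0])
    dones = torch.tensor([0.0, 0.0, 0.0, 1.0])
    v_preds_all = torch.tensor([0.5, 0.6, 0.4, 0.3, 0.0])  # T + 1 values
    advs = calc_gaes(rewards, dones, v_preds_all, gamma=0.99, lam=0.95)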
Example #7
    def _normalize_value(cls, domain, intent, slot, value):
        if intent == 'request':
            return DEF_VAL_UNK

        if domain not in cls.stand_value_dict.keys():
            return value

        if slot not in cls.stand_value_dict[domain]:
            return value

        value_list = cls.stand_value_dict[domain][slot]
        low_value_list = [item.lower() for item in value_list]
        value_list = list(set(value_list).union(set(low_value_list)))
        if value not in value_list:
            normalized_v = simple_fuzzy_match(value_list, value)
            if normalized_v is not None:
                return normalized_v
            # try some transformations
            cand_values = transform_value(value)
            for cv in cand_values:
                _nv = simple_fuzzy_match(value_list, cv)
                if _nv is not None:
                    return _nv
            if check_if_time(value):
                return value

            logger.debug('Value not found in standard value set: [%s] (slot: %s domain: %s)' % (value, slot, domain))
        return value
Example #8
 def reset(self):
     # _reward = np.nan
     env_info_dict = self.u_env.reset(train_mode=(util.get_lab_mode() != 'dev'), config=self.env_spec.get('multiwoz'))
     a, b = 0, 0  # default singleton aeb
     env_info_a = self._get_env_info(env_info_dict, a)
     state = env_info_a.states[b]
     self.done = False
     logger.debug(f'Env {self.e} reset state: {state}')
     return state
Example #9
 def calc_ret_advs(self, batch):
     '''Calculate plain returns, which generalize to advantages in ActorCritic'''
     batch_rewards = self.modify_batch_reward(batch)
     rets = math_util.calc_returns(batch_rewards, batch['dones'], self.gamma)
     advs = rets
     if self.body.env.is_venv:
         advs = math_util.venv_unpack(advs)
     logger.debug(f'advs: {advs}')
     return advs
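math_util.calc_returns is not shown above. A plausible minimal sketch, assuming it computes the standard discounted reward-to-go reset at episode boundaries:

    import torch

    def calc_returns(rewards, dones, gamma):
        '''Discounted reward-to-go: G_t = r_t + gamma * G_{t+1}, reset where done.'''
        rets = torch.zeros_like(rewards)
        future_ret = 0.0
        for t in reversed(range(len(rewards))):
            rets[t] = future_ret = rewards[t] + gamma * future_ret * (1 - dones[t])
        return rets

    rewards = torch.tensor([1.0, 1.0, 1.0])
    dones = torch.tensor([0.0, 0.0, 1.0])
    print(calc_returns(rewards, dones, gamma=0.9))  # tensor([2.7100, 1.9000, 1.0000])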
Example #10
 def step(self, action):
     env_info_dict = self.u_env.step(action)
     a, b = 0, 0  # default singleton aeb
     env_info_a = self._get_env_info(env_info_dict, a)
     reward = env_info_a.rewards[b]  # * self.reward_scale
     state = env_info_a.states[b]
     done = env_info_a.local_done[b]
     self.done = done = done or self.clock.t > self.max_t
     logger.debug(f'Env {self.e} step reward: {reward}, state: {state}, done: {done}')
     return state, reward, done, env_info_a 
Example #11
    def reset(self, obs):
        '''Do agent reset per session, such as memory pointer'''
        logger.debug(f'Agent {self.a} reset')
        if self.dst:
            self.dst.init_session()
        if hasattr(self.algorithm, "reset"):  # mainly for external policies that may need to reset their state
            self.algorithm.reset()

        input_act, state, encoded_state = self.state_update(obs, "null")  # "null" action to be compatible with MDBT

        self.body.state, self.body.encoded_state = state, encoded_state
Example #12
 def calc_policy_loss(self, batch, pdparams, advs):
     '''Calculate the actor's policy loss'''
     action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)
     actions = batch['actions']
     if self.body.env.is_venv:
         actions = math_util.venv_unpack(actions)
     log_probs = action_pd.log_prob(actions)
     policy_loss = - self.policy_loss_coef * (log_probs * advs).mean()
     if self.entropy_coef_spec:
         entropy = action_pd.entropy().mean()
         self.body.mean_entropy = entropy  # update logging variable
         policy_loss += (-self.body.entropy_coef * entropy)
     logger.debug(f'Actor policy loss: {policy_loss:g}')
     return policy_loss
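The same policy-gradient loss with an entropy bonus can be illustrated standalone, using torch.distributions.Categorical in place of policy_util.init_action_pd (the coefficients and tensors below are made up):

    import torch
    from torch.distributions import Categorical

    policy_loss_coef, entropy_coef = 1.0, 0.01  # made-up coefficients
    logits = torch.randn(4, 3, requires_grad=True)  # pdparams for 4 states, 3 actions
    actions = torch.tensor([0, 2, 1, 1])
    advs = torch.tensor([0.5, -0.2, 1.0, 0.1])

    action_pd = Categorical(logits=logits)
    log_probs = action_pd.log_prob(actions)
    policy_loss = -policy_loss_coef * (log_probs * advs).mean()
    entropy = action_pd.entropy().mean()
    policy_loss = policy_loss - entropy_coef * entropy  # entropy bonus encourages exploration
    policy_loss.backward()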
Example #13
 def calc_ret_advs_v_targets(self, batch, v_preds):
     '''Calculate plain returns, and advs = rets - v_preds, v_targets = rets'''
     v_preds = v_preds.detach()  # adv does not accumulate grad
     if self.body.env.is_venv:
         v_preds = math_util.venv_pack(v_preds, self.body.env.num_envs)
     rets = math_util.calc_returns(batch['rewards'], batch['dones'],
                                   self.gamma)
     advs = rets - v_preds
     v_targets = rets
     if self.body.env.is_venv:
         advs = math_util.venv_unpack(advs)
         v_targets = math_util.venv_unpack(v_targets)
     logger.debug(f'advs: {advs}\nv_targets: {v_targets}')
     return advs, v_targets
Example #14
    def check_fn(*args, **kwargs):
        if not to_check_train_step():
            return fn(*args, **kwargs)

        net = args[0]  # first arg self
        # get pre-update parameters to compare
        pre_params = [param.clone() for param in net.parameters()]

        # run train_step, get loss
        loss = fn(*args, **kwargs)
        assert not torch.isnan(loss).any(), loss

        # get post-update parameters to compare
        post_params = [param.clone() for param in net.parameters()]
        if loss == 0.0:
            # if loss is 0, there should be no updates
            # TODO if without momentum, parameters should not change too
            for p_name, param in net.named_parameters():
                assert param.grad.norm() == 0
        else:
            # check parameter updates
            try:
                assert not all(
                    torch.equal(w1, w2)
                    for w1, w2 in zip(pre_params, post_params)
                ), f'Model parameter is not updated in train_step(), check if your tensor is detached from graph. Loss: {loss:g}'
                logger.info(
                    f'Model parameter is updated in train_step(). Loss: {loss: g}'
                )
            except Exception as e:
                logger.error(e)
                if os.environ.get('PY_ENV') == 'test':
                    # raise error if in unit test
                    raise e

            # check grad norms
            min_norm, max_norm = 0.0, 1e5
            for p_name, param in net.named_parameters():
                try:
                    grad_norm = param.grad.norm()
                    assert min_norm < grad_norm < max_norm, f'Gradient norm for {p_name} is {grad_norm:g}, fails the extreme value check {min_norm} < grad_norm < {max_norm}. Loss: {loss:g}. Check your network and loss computation.'
                except Exception as e:
                    logger.warning(e)
            logger.info(f'Gradient norms passed value check.')
        logger.debug('Passed network parameter update check.')
        # store grad norms for debugging
        net.store_grad_norms()
        return loss
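Example #14 is the inner check_fn of a decorator that sanity-checks a network's train_step; the enclosing decorator is not shown. A trimmed, hypothetical sketch of how such a wrapper could be structured (to_check_train_step and its condition are illustrative assumptions, not the framework's API):

    import functools
    import os
    import torch

    def to_check_train_step():
        # illustrative condition: only run the expensive check in dev/test modes
        return os.environ.get('PY_ENV') in ('test', 'dev')

    def check_train_step(fn):
        '''Hypothetical decorator around a net's train_step; check_fn mirrors the example above.'''
        @functools.wraps(fn)
        def check_fn(*args, **kwargs):
            if not to_check_train_step():
                return fn(*args, **kwargs)
            net = args[0]  # first arg is self, the network
            pre_params = [param.clone() for param in net.parameters()]
            loss = fn(*args, **kwargs)
            assert not torch.isnan(loss).any(), loss
            post_params = [param.clone() for param in net.parameters()]
            if loss != 0.0:
                assert not all(torch.equal(w1, w2) for w1, w2 in zip(pre_params, post_params)), \
                    'parameters did not change in train_step()'
            return loss
        return check_fn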
Example #15
    def train(self):
        '''Train actor critic by computing the loss in batch efficiently'''
        if util.in_eval_lab_modes():
            return np.nan
        clock = self.body.env.clock
        if self.to_train == 1:
            batch = self.sample()
            """
            Add rewards over here.
            """
            batch = self.replace_reward_batch(batch)
            # batch = self.fetch_disc_reward(batch)
            clock.set_batch_size(len(batch))
            pdparams, v_preds = self.calc_pdparam_v(batch)

            # get loss of critic: advs and targets of critic v_targets.
            advs, v_targets = self.calc_advs_v_targets(batch, v_preds)
            policy_loss = self.calc_policy_loss(batch, pdparams,
                                                advs)  # from actor
            val_loss = self.calc_val_loss(v_preds, v_targets)  # from critic
            if self.shared:  # shared network
                loss = policy_loss + val_loss
                self.net.train_step(loss,
                                    self.optim,
                                    self.lr_scheduler,
                                    clock=clock,
                                    global_net=self.global_net)
            else:
                # not shared: train the actor and critic networks separately
                self.net.train_step(policy_loss,
                                    self.optim,
                                    self.lr_scheduler,
                                    clock=clock,
                                    global_net=self.global_net)
                self.critic_net.train_step(val_loss,
                                           self.critic_optim,
                                           self.critic_lr_scheduler,
                                           clock=clock,
                                           global_net=self.global_critic_net)
                loss = policy_loss + val_loss
            # reset
            self.to_train = 0
            logger.debug(
                f'Trained {self.name} at epi: {clock.epi}, frame: {clock.frame}, t: {clock.t}, total_reward so far: {self.body.total_reward}, loss: {loss:g}'
            )
            return loss.item()
        else:
            return np.nan
Example #16
    def calc_q_loss(self, batch):
        '''Compute the Q value loss using predicted and target Q values from the appropriate networks'''
        batch_rewards_ori, batch_rewards_log, batch_rewards_log_double, batch_reward_log_minus_one = self.fetch_irl_reward(
            batch)
        self.reward_count += batch_rewards_ori.mean().item()
        self.batch_count += 1
        # batch_rewards = batch_rewards_ori + batch['rewards']
        # batch_rewards = batch_reward_log_minus_one
        # batch_rewards = batch_rewards_log
        # batch_rewards = batch_rewards_log_double + batch['rewards']
        """
        here to change the reward function. From two to choose one. For me, baseline is not running over here.
        Specify the method of surgery
        change VAE function in the other place.
        """
        batch_rewards = batch_rewards_log.to("cpu") + batch['rewards']
        # batch_rewards = batch['rewards']
        # batch_rewards = batch_rewards_ori.to("cpu") + batch['rewards']
        # flag = copy.deepcopy(batch['rewards'])
        # flag[flag<=0]=0
        # flag[flag>0]=1
        # batch_rewards = batch_rewards_log + flag * batch['rewards']

        states = batch['states']
        next_states = batch['next_states']
        q_preds = self.net(states)
        with torch.no_grad():
            # Use online_net to select actions in next state
            online_next_q_preds = self.online_net(next_states)
            # Use eval_net to calculate next_q_preds for actions chosen by online_net
            next_q_preds = self.eval_net(next_states)
        act_q_preds = q_preds.gather(
            -1, batch['actions'].long().unsqueeze(-1)).squeeze(-1)
        online_actions = online_next_q_preds.argmax(dim=-1, keepdim=True)
        max_next_q_preds = next_q_preds.gather(-1, online_actions).squeeze(-1)
        max_q_targets = batch_rewards + self.gamma * (
            1 - batch['dones']) * max_next_q_preds
        logger.debug(
            f'act_q_preds: {act_q_preds}\nmax_q_targets: {max_q_targets}')
        q_loss = self.net.loss_fn(act_q_preds, max_q_targets)

        # TODO use the same loss_fn but do not reduce yet
        if 'Prioritized' in util.get_class_name(self.body.memory):  # PER
            errors = (max_q_targets - act_q_preds.detach()).abs().cpu().numpy()
            self.body.memory.update_priorities(errors)
        return q_loss
Example #17
def save_algorithm(algorithm, ckpt=None):
    '''Save all the nets for an algorithm'''
    agent = algorithm.agent
    net_names = algorithm.net_names
    model_prepath = agent.spec['meta']['model_prepath']
    if ckpt is not None:
        model_prepath = f'{model_prepath}_ckpt-{ckpt}'
    for net_name in net_names:
        net = getattr(algorithm, net_name)
        model_path = f'{model_prepath}_{net_name}_model.pt'
        save(net, model_path)
        optim_name = net_name.replace('net', 'optim')
        optim = getattr(algorithm, optim_name, None)
        if optim is not None:  # only trainable net has optim
            optim_path = f'{model_prepath}_{net_name}_optim.pt'
            save(optim, optim_path)
    logger.debug(f'Saved algorithm {util.get_class_name(algorithm)} nets {net_names} to {model_prepath}_*.pt')
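The save helper called above is not shown in this listing. Assuming it simply persists state_dicts, a minimal sketch of matching save/load helpers might look like this (file names are illustrative):

    import torch

    def save(net_or_optim, path):
        '''Persist a torch module or optimizer via its state_dict.'''
        torch.save(net_or_optim.state_dict(), path)

    def load(net_or_optim, path):
        '''Restore a torch module or optimizer in place from a saved state_dict.'''
        net_or_optim.load_state_dict(torch.load(path))

    net = torch.nn.Linear(4, 2)
    optim = torch.optim.Adam(net.parameters())
    save(net, 'demo_net_model.pt')
    save(optim, 'demo_net_optim.pt')
    load(net, 'demo_net_model.pt')
    load(optim, 'demo_net_optim.pt')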
Example #18
    def calc_q_loss(self, batch):
        '''Compute the Q value loss using predicted and target Q values from the appropriate networks'''
        states = batch['states']
        next_states = batch['next_states']
        q_preds = self.net(states)
        with torch.no_grad():
            next_q_preds = self.net(next_states)
        act_q_preds = q_preds.gather(-1, batch['actions'].long().unsqueeze(-1)).squeeze(-1)
        # Bellman equation: compute max_q_targets using reward and max estimated Q values (0 if no next_state)
        max_next_q_preds, _ = next_q_preds.max(dim=-1, keepdim=False)  # keepdim=False so the shape matches act_q_preds (batch,)
        max_q_targets = batch['rewards'] + self.gamma * (1 - batch['dones']) * max_next_q_preds
        logger.debug(f'act_q_preds: {act_q_preds}\nmax_q_targets: {max_q_targets}')
        q_loss = self.net.loss_fn(act_q_preds, max_q_targets)

        # TODO use the same loss_fn but do not reduce yet
        if 'Prioritized' in util.get_class_name(self.body.memory):  # PER
            errors = (max_q_targets - act_q_preds.detach()).abs().cpu().numpy()
            self.body.memory.update_priorities(errors)
        return q_loss
Example #19
 def calc_ret_advs_v_targets(self, batch, v_preds):
     """
     Args:
         batch:
         v_preds:
     Returns: advs(difference between prediction and original reward), v_targets, original reward.
     """
     '''Calculate plain returns, and advs = rets - v_preds, v_targets = rets'''
     v_preds = v_preds.detach()  # adv does not accumulate grad
     if self.body.env.is_venv:
         v_preds = math_util.venv_pack(v_preds, self.body.env.num_envs)
     rets = math_util.calc_returns(batch['rewards'], batch['dones'],
                                   self.gamma)
     advs = rets - v_preds
     v_targets = rets
     if self.body.env.is_venv:
         advs = math_util.venv_unpack(advs)
         v_targets = math_util.venv_unpack(v_targets)
     logger.debug(f'advs: {advs}\nv_targets: {v_targets}')
     return advs, v_targets
Example #20
    def calc_policy_loss(self, batch, pdparams, advs):
        '''
        The PPO loss function (subscript t is omitted)
        L^{CLIP+VF+S} = E[ L^CLIP - c1 * L^VF + c2 * S[pi](s) ]

        Breakdown piecewise,
        1. L^CLIP = E[ min(ratio * A, clip(ratio, 1-eps, 1+eps) * A) ]
        where ratio = pi(a|s) / pi_old(a|s)

        2. L^VF = E[ mse(V(s_t), V^target) ]

        3. S = E[ entropy ]
        '''
        clip_eps = self.body.clip_eps
        action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)
        states = batch['states']
        actions = batch['actions']
        if self.body.env.is_venv:
            states = math_util.venv_unpack(states)
            actions = math_util.venv_unpack(actions)

        # L^CLIP
        log_probs = action_pd.log_prob(actions)
        with torch.no_grad():
            old_pdparams = self.calc_pdparam(states, net=self.old_net)
            old_action_pd = policy_util.init_action_pd(self.body.ActionPD,
                                                       old_pdparams)
            old_log_probs = old_action_pd.log_prob(actions)
        assert log_probs.shape == old_log_probs.shape
        ratios = torch.exp(log_probs -
                           old_log_probs)  # clip to prevent overflow
        logger.debug(f'ratios: {ratios}')
        sur_1 = ratios * advs
        sur_2 = torch.clamp(ratios, 1.0 - clip_eps, 1.0 + clip_eps) * advs
        # flip sign because need to maximize
        clip_loss = -torch.min(sur_1, sur_2).mean()
        logger.debug(f'clip_loss: {clip_loss}')

        # L^VF (inherit from ActorCritic)

        # S entropy bonus
        entropy = action_pd.entropy().mean()
        self.body.mean_entropy = entropy  # update logging variable
        ent_penalty = -self.body.entropy_coef * entropy
        logger.debug(f'ent_penalty: {ent_penalty}')

        policy_loss = clip_loss + ent_penalty
        logger.debug(f'PPO Actor policy loss: {policy_loss:g}')
        return policy_loss
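The clipped surrogate L^CLIP from the docstring can be reproduced standalone with dummy log-probabilities and advantages:

    import torch

    clip_eps = 0.2
    advs = torch.tensor([0.5, -1.0, 0.3, 0.0])
    log_probs = torch.tensor([-0.9, -1.2, -0.4, -2.0], requires_grad=True)  # current policy
    old_log_probs = torch.tensor([-1.0, -1.0, -0.5, -2.1])  # frozen old policy

    ratios = torch.exp(log_probs - old_log_probs)  # pi(a|s) / pi_old(a|s)
    sur_1 = ratios * advs
    sur_2 = torch.clamp(ratios, 1.0 - clip_eps, 1.0 + clip_eps) * advs
    clip_loss = -torch.min(sur_1, sur_2).mean()  # negated because the objective is maximized
    clip_loss.backward()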
Example #21
 def calc_nstep_advs_v_targets(self, batch, v_preds):
     '''
     Calculate N-step returns, and advs = nstep_rets - v_preds, v_targets = nstep_rets
     See n-step advantage under http://rail.eecs.berkeley.edu/deeprlcourse-fa17/f17docs/lecture_5_actor_critic_pdf.pdf
     '''
     next_states = batch['next_states'][-1]
     if not self.body.env.is_venv:
         next_states = next_states.unsqueeze(dim=0)
     with torch.no_grad():
         next_v_pred = self.calc_v(next_states, use_cache=False)
     v_preds = v_preds.detach()  # adv does not accumulate grad
     if self.body.env.is_venv:
         v_preds = math_util.venv_pack(v_preds, self.body.env.num_envs)
     nstep_rets = math_util.calc_nstep_returns(batch['rewards'],
                                               batch['dones'], next_v_pred,
                                               self.gamma,
                                               self.num_step_returns)
     advs = nstep_rets - v_preds
     v_targets = nstep_rets
     if self.body.env.is_venv:
         advs = math_util.venv_unpack(advs)
         v_targets = math_util.venv_unpack(v_targets)
     logger.debug(f'advs: {advs}\nv_targets: {v_targets}')
     return advs, v_targets
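math_util.calc_nstep_returns is not shown here. A minimal sketch, under the assumption that the batch holds n consecutive steps and next_v_pred bootstraps the value after the last one:

    import torch

    def calc_nstep_returns(rewards, dones, next_v_pred, gamma, n):
        '''n-step returns over n consecutive steps, bootstrapped with the value after the last step.'''
        rets = torch.zeros_like(rewards)
        future_ret = next_v_pred
        for t in reversed(range(n)):
            rets[t] = future_ret = rewards[t] + gamma * future_ret * (1 - dones[t])
        return rets

    rewards = torch.tensor([1.0, 0.0, 1.0])
    dones = torch.tensor([0.0, 0.0, 0.0])
    next_v_pred = torch.tensor(0.5)  # critic estimate V(s_{t+n})
    print(calc_nstep_returns(rewards, dones, next_v_pred, gamma=0.99, n=3))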
Example #22
    def calc_sil_policy_val_loss(self, batch, pdparams):
        '''
        Calculate the SIL policy losses for actor and critic
        sil_policy_loss = -log_prob * max(R - v_pred, 0)
        sil_val_loss = (max(R - v_pred, 0)^2) / 2
        This is called on a randomly sampled batch from experience replay.
        '''
        v_preds = self.calc_v(batch['states'], use_cache=False)
        rets = math_util.calc_returns(batch['rewards'], batch['dones'],
                                      self.gamma)
        clipped_advs = torch.clamp(rets - v_preds, min=0.0)

        action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)
        actions = batch['actions']
        if self.body.env.is_venv:
            actions = math_util.venv_unpack(actions)
        log_probs = action_pd.log_prob(actions)

        sil_policy_loss = -self.sil_policy_loss_coef * (log_probs *
                                                        clipped_advs).mean()
        sil_val_loss = self.sil_val_loss_coef * clipped_advs.pow(2).mean() / 2
        logger.debug(f'SIL actor policy loss: {sil_policy_loss:g}')
        logger.debug(f'SIL critic value loss: {sil_val_loss:g}')
        return sil_policy_loss, sil_val_loss
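The two SIL terms from the docstring, sketched standalone with dummy tensors and torch.distributions.Categorical (the coefficients are made up):

    import torch
    from torch.distributions import Categorical

    sil_policy_loss_coef, sil_val_loss_coef = 1.0, 0.01  # made-up coefficients
    rets = torch.tensor([2.0, 0.5, 1.0])     # returns R from replayed transitions
    v_preds = torch.tensor([1.5, 1.0, 0.2])  # critic predictions V(s)
    logits = torch.randn(3, 4, requires_grad=True)
    actions = torch.tensor([0, 3, 1])

    clipped_advs = torch.clamp(rets - v_preds, min=0.0)  # max(R - V, 0): learn only from good outcomes
    log_probs = Categorical(logits=logits).log_prob(actions)
    sil_policy_loss = -sil_policy_loss_coef * (log_probs * clipped_advs).mean()
    sil_val_loss = sil_val_loss_coef * clipped_advs.pow(2).mean() / 2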
Example #23
 def calc_q_loss(self, batch):
     '''Compute the Q value loss using predicted and target Q values from the appropriate networks'''
     states = batch['states']
     next_states = batch['next_states']
     if self.body.env.is_venv:
         states = math_util.venv_unpack(states)
         next_states = math_util.venv_unpack(next_states)
     q_preds = self.net(states)
     with torch.no_grad():
         next_q_preds = self.net(next_states)
     if self.body.env.is_venv:
         q_preds = math_util.venv_pack(q_preds, self.body.env.num_envs)
         next_q_preds = math_util.venv_pack(next_q_preds,
                                            self.body.env.num_envs)
     act_q_preds = q_preds.gather(
         -1, batch['actions'].long().unsqueeze(-1)).squeeze(-1)
     act_next_q_preds = next_q_preds.gather(
         -1, batch['next_actions'].long().unsqueeze(-1)).squeeze(-1)
     act_q_targets = batch['rewards'] + self.gamma * (
         1 - batch['dones']) * act_next_q_preds
     logger.debug(
         f'act_q_preds: {act_q_preds}\nact_q_targets: {act_q_targets}')
     q_loss = self.net.loss_fn(act_q_preds, act_q_targets)
     return q_loss
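Example #23 forms the on-policy SARSA target r + gamma * Q(s', a') from the action actually taken in the next state. A standalone sketch with dummy tensors:

    import torch
    import torch.nn.functional as F

    gamma = 0.99
    q_preds = torch.randn(4, 3)       # Q(s, .) for 4 states, 3 actions
    next_q_preds = torch.randn(4, 3)  # Q(s', .)
    actions = torch.tensor([0, 2, 1, 0])
    next_actions = torch.tensor([1, 1, 0, 2])  # actions actually taken in s' (on-policy)
    rewards = torch.tensor([1.0, 0.0, -1.0, 0.5])
    dones = torch.tensor([0.0, 0.0, 1.0, 0.0])

    act_q_preds = q_preds.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
    act_next_q_preds = next_q_preds.gather(-1, next_actions.unsqueeze(-1)).squeeze(-1)
    act_q_targets = rewards + gamma * (1 - dones) * act_next_q_preds
    q_loss = F.mse_loss(act_q_preds, act_q_targets)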
Example #24
 def calc_val_loss(self, v_preds, v_targets):
     '''Calculate the critic's value loss'''
     assert v_preds.shape == v_targets.shape, f'{v_preds.shape} != {v_targets.shape}'
     val_loss = self.val_loss_coef * self.net.loss_fn(v_preds, v_targets)
     logger.debug(f'Critic value loss: {val_loss:g}')
     return val_loss
Example #25
    def train(self):
        # torch.save(self.net.state_dict(), './reward_model/policy_pretrain.mdl')
        # raise ValueError("policy pretrain stops")
        if util.in_eval_lab_modes():
            return np.nan
        clock = self.body.env.clock
        if self.body.env.clock.epi > 700:
            self.pretrain_finished = True
            # torch.save(self.discriminator.state_dict(), './reward_model/airl_pretrain.mdl')
            # raise ValueError("pretrain stops here")
        if self.to_train == 1:
            net_util.copy(self.net, self.old_net)  # update old net
            batch = self.sample()
            if self.reward_type == 'OFFGAN':
                batch = self.replace_reward_batch(batch)
            # if self.reward_type =='DISC':
            # batch = self.fetch_disc_reward(batch)
            # if self.reward_type =='AIRL':
            # batch = self.fetch_airl_reward(batch)
            # if self.reward_type == 'OFFGAN_update':
            # batch = self.fetch_offgan_reward(batch)

            clock.set_batch_size(len(batch))
            _pdparams, v_preds = self.calc_pdparam_v(batch)
            advs, v_targets = self.calc_advs_v_targets(batch, v_preds)
            # piggy back on batch, but remember to not pack or unpack
            batch['advs'], batch['v_targets'] = advs, v_targets
            if self.body.env.is_venv:  # unpack if venv for minibatch sampling
                for k, v in batch.items():
                    if k not in ('advs', 'v_targets'):
                        batch[k] = math_util.venv_unpack(v)
            total_loss = torch.tensor(0.0)
            for _ in range(self.training_epoch):
                minibatches = util.split_minibatch(batch, self.minibatch_size)

                # if not self.pretrain_finished or not self.policy_training_flag:
                #     break

                for minibatch in minibatches:
                    if self.body.env.is_venv:  # re-pack to restore proper shape
                        for k, v in minibatch.items():
                            if k not in ('advs', 'v_targets'):
                                minibatch[k] = math_util.venv_pack(
                                    v, self.body.env.num_envs)
                    advs, v_targets = minibatch['advs'], minibatch['v_targets']
                    pdparams, v_preds = self.calc_pdparam_v(minibatch)
                    policy_loss = self.calc_policy_loss(
                        minibatch, pdparams, advs)  # from actor
                    val_loss = self.calc_val_loss(v_preds,
                                                  v_targets)  # from critic
                    if self.shared:  # shared network
                        loss = policy_loss + val_loss
                        self.net.train_step(loss,
                                            self.optim,
                                            self.lr_scheduler,
                                            clock=clock,
                                            global_net=self.global_net)
                    else:
                        # while pretrain_finished is False, keep the policy fixed and update only the value net and discriminator
                        if not self.pretrain_finished:
                            self.critic_net.train_step(
                                val_loss,
                                self.critic_optim,
                                self.critic_lr_scheduler,
                                clock=clock,
                                global_net=self.global_critic_net)
                            loss = val_loss
                        if self.pretrain_finished and self.policy_training_flag:
                            self.net.train_step(policy_loss,
                                                self.optim,
                                                self.lr_scheduler,
                                                clock=clock,
                                                global_net=self.global_net)
                            self.critic_net.train_step(
                                val_loss,
                                self.critic_optim,
                                self.critic_lr_scheduler,
                                clock=clock,
                                global_net=self.global_critic_net)
                            loss = policy_loss + val_loss

                    total_loss += loss
            loss = total_loss / self.training_epoch / len(minibatches)
            if not self.pretrain_finished:
                logger.info(
                    "warmup Value net, epi: {}, frame: {}, loss: {}".format(
                        clock.epi, clock.frame, loss))
            # reset
            self.to_train = 0
            self.policy_training_flag = False
            logger.debug(
                f'Trained {self.name} at epi: {clock.epi}, frame: {clock.frame}, t: {clock.t}, total_reward so far: {self.body.total_reward}, loss: {loss:g}'
            )
            return loss.item()
        else:
            return np.nan
Example #26
 def train(self):
     if util.in_eval_lab_modes():
         return np.nan
     clock = self.body.env.clock
     if self.to_train == 1:
         net_util.copy(self.net, self.old_net)  # update old net
         batch = self.sample()
         clock.set_batch_size(len(batch))
         _pdparams, v_preds = self.calc_pdparam_v(batch)
         advs, v_targets = self.calc_advs_v_targets(batch, v_preds)
         # piggy back on batch, but remember to not pack or unpack
         batch['advs'], batch['v_targets'] = advs, v_targets
         if self.body.env.is_venv:  # unpack if venv for minibatch sampling
             for k, v in batch.items():
                 if k not in ('advs', 'v_targets'):
                     batch[k] = math_util.venv_unpack(v)
         total_loss = torch.tensor(0.0)
         for _ in range(self.training_epoch):
             minibatches = util.split_minibatch(batch, self.minibatch_size)
             for minibatch in minibatches:
                 if self.body.env.is_venv:  # re-pack to restore proper shape
                     for k, v in minibatch.items():
                         if k not in ('advs', 'v_targets'):
                             minibatch[k] = math_util.venv_pack(
                                 v, self.body.env.num_envs)
                 advs, v_targets = minibatch['advs'], minibatch['v_targets']
                 pdparams, v_preds = self.calc_pdparam_v(minibatch)
                 policy_loss = self.calc_policy_loss(
                     minibatch, pdparams, advs)  # from actor
                 val_loss = self.calc_val_loss(v_preds,
                                               v_targets)  # from critic
                 if self.shared:  # shared network
                     loss = policy_loss + val_loss
                     self.net.train_step(loss,
                                         self.optim,
                                         self.lr_scheduler,
                                         clock=clock,
                                         global_net=self.global_net)
                 else:
                     self.net.train_step(policy_loss,
                                         self.optim,
                                         self.lr_scheduler,
                                         clock=clock,
                                         global_net=self.global_net)
                     self.critic_net.train_step(
                         val_loss,
                         self.critic_optim,
                         self.critic_lr_scheduler,
                         clock=clock,
                         global_net=self.global_critic_net)
                     loss = policy_loss + val_loss
                 total_loss += loss
         loss = total_loss / self.training_epoch / len(minibatches)
         # reset
         self.to_train = 0
         logger.debug(
             f'Trained {self.name} at epi: {clock.epi}, frame: {clock.frame}, t: {clock.t}, total_reward so far: {self.body.total_reward}, loss: {loss:g}'
         )
         return loss.item()
     else:
         return np.nan
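The PPO-style training loops above rely on util.split_minibatch, which is not shown. A minimal sketch, under the assumption that it slices a dict of equally long tensors into chunks (the real helper may also shuffle):

    import torch

    def split_minibatch(batch, minibatch_size):
        '''Split a dict of equally long tensors into a list of dicts with at most minibatch_size rows.'''
        size = len(next(iter(batch.values())))
        return [{k: v[i:i + minibatch_size] for k, v in batch.items()}
                for i in range(0, size, minibatch_size)]

    batch = {
        'states': torch.randn(10, 4),
        'actions': torch.randint(0, 3, (10,)),
        'advs': torch.randn(10),
    }
    print(len(split_minibatch(batch, minibatch_size=4)))  # 3 minibatches of sizes 4, 4, 2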