Example #1
0
 def test_qiteration(self):
     """Check that q_iteration and q_iteration_py agree on self.env.

     Both softq_iteration implementations are run with identical
     hyperparameters and their Q-value arrays are compared elementwise.
     np.testing.assert_allclose is used instead of
     assertTrue(np.allclose(...)) so a failure reports which entries differ.
     rtol/atol/equal_nan are pinned to np.allclose's defaults. Unlike
     np.allclose, a (non-scalar) shape mismatch is reported as a failure
     instead of being broadcast.
     """
     params = {
         'num_itrs': 50,
         'ent_wt': 1.0,
         'discount': 0.99,
     }
     qvals_py = q_iteration_py.softq_iteration(self.env, **params)
     qvals_cy = q_iteration.softq_iteration(self.env, **params)
     np.testing.assert_allclose(qvals_cy, qvals_py,
                                rtol=1e-05, atol=1e-08, equal_nan=False)
Example #2
0
 def test_q_iteration(self):
     """Greedy rollout on self.env_small should end on a reward of 1.0.

     Soft Q-values are computed for the small env. The test then takes 50
     steps, each choosing the argmax action of the current state's Q-values.
     It asserts that the reward returned by the final step equals 1.0.
     """
     env = self.env_small
     qvals = q_iteration.softq_iteration(env,
                                         num_itrs=1000,
                                         ent_wt=0.1,
                                         discount=0.95)
     env.reset()
     for _ in range(50):
         greedy_action = np.argmax(qvals[env.get_state()])
         _, rew, _, _ = env.step(greedy_action)
     self.assertEqual(rew, 1.0)
Example #3
0
 def test_q_iteration(self):
     """Greedy rollout under zero-entropy-weight Q-values earns reward.

     Runs softq_iteration on self.env with ent_wt=0.0, then takes 200 steps
     acting greedily (argmax) on the resulting Q-values. It checks that the
     summed reward is strictly positive.
     """
     env = self.env
     qvals = q_iteration.softq_iteration(env,
                                         num_itrs=1000,
                                         ent_wt=0.0,
                                         discount=0.95)
     env.reset()
     total_reward = 0
     for _ in range(200):
         greedy_action = np.argmax(qvals[env.get_state()])
         # NOTE(review): the done flag (third element) is discarded, so
         # stepping continues regardless of episode termination.
         _, reward, _, _ = env.step(greedy_action)
         total_reward += reward
     self.assertGreater(total_reward, 0.0)
Example #4
0
    def __init__(self,
                 env,
                 network,
                 min_project_steps=5,
                 max_project_steps=50,
                 lr=1e-3,
                 discount=0.99,
                 n_steps=1,
                 ent_wt=1.0,
                 stop_modes=tuple(),
                 backup_mode=BACKUP_EXACT,
                 n_eval_trials=50,
                 log_proj_qstar=False,
                 target_mode='tq',
                 optimizer='adam',
                 smooth_target_tau=1.0,
                 **kwargs):
        """Set up a projected soft Q-iteration run on a tabular environment.

        Computes exact soft Q-values (Q*) for ``env``, and measures expert
        and random return baselines used for normalization. It also records
        the network's initial Q-values and optionally computes the
        projection of Q* onto the network.

        Args:
            env: environment exposing ``num_states``. It is passed to
                ``q_iteration_cy.softq_iteration``.
            network: callable model with ``.parameters()``. It is called on
                an int64 tensor of all state indices to produce Q-values.
            min_project_steps: stored for the projection step.
            max_project_steps: stored for the projection step.
            lr: learning rate for the Q-network optimizer. It is also passed
                to ``compute_exact_projection`` when ``log_proj_qstar``.
            discount: discount factor for soft Q-iteration.
            n_steps: number of soft Q-iteration steps per backup (stored).
            ent_wt: entropy weight for soft Q-iteration.
            stop_modes: stopping criteria for projection (stored).
            backup_mode: ``BACKUP_EXACT`` selects the ``MULTIPLE_HEADS`` Q
                format; any other value selects ``FLAT``.
            n_eval_trials: number of evaluation rollouts. The expert and
                random baselines each use five times this many.
            log_proj_qstar: if True, compute Proj(Q*) and its weighted
                absolute error against Q*.
            target_mode: stored as-is.
            optimizer: ``'adam'`` (Adam) or ``'gd'`` (plain SGD).
            smooth_target_tau: target-smoothing weight (stored).
            **kwargs: accepted and ignored.

        Raises:
            ValueError: if ``optimizer`` is neither ``'adam'`` nor ``'gd'``.
        """
        self.env = env
        self.network = network
        self.discount = discount
        self.ent_wt = ent_wt
        self.n_eval_trials = n_eval_trials
        self.lr = lr
        self.max_q = 1e5  # upper clip for fitted Q-values

        self.target_mode = target_mode

        self.backup_mode = backup_mode
        if backup_mode == BACKUP_EXACT:
            self.q_format = MULTIPLE_HEADS
        else:
            self.q_format = FLAT

        self.min_project_steps = min_project_steps
        self.max_project_steps = max_project_steps
        self.stop_modes = stop_modes
        self.n_steps = n_steps

        if optimizer == 'adam':
            self.qnet_optimizer = torch.optim.Adam(network.parameters(), lr=lr)
        elif optimizer == 'gd':
            self.qnet_optimizer = torch.optim.SGD(network.parameters(), lr=lr)
        else:
            raise ValueError("Unknown optimizer: %s" % optimizer)
        # Index tensor 0..num_states-1, used to query the network for every
        # state at once.
        self.all_states = ptu.tensor(np.arange(env.num_states),
                                     dtype=torch.int64)

        # Ground-truth Q*: soft Q-iteration run for max(2 * num_states, 1000)
        # iterations.
        with log_utils.timer('ground_truth_q'):
            self.ground_truth_q = q_iteration_cy.softq_iteration(
                self.env,
                num_itrs=max(self.env.num_states * 2, 1000),
                discount=self.discount,
                ent_wt=self.ent_wt)
            self.ground_truth_q_torch = ptu.tensor(self.ground_truth_q)
        # Per-state weights. They are 1.0 where the row of Q* sums to a
        # nonzero value and 0 otherwise.
        # NOTE(review): presumably masks absorbing/unreachable states --
        # confirm.
        self.valid_weights = np.sum(self.ground_truth_q, axis=1)
        self.valid_weights[self.valid_weights != 0] = 1.0

        # Return baselines. self.current_q is set to Q* and then to all
        # zeros before each eval_policy() call. (Presumably eval_policy acts
        # on self.current_q -- confirm.) The order of these statements
        # therefore matters.
        self.current_q = self.ground_truth_q
        returns = self.eval_policy(render=False,
                                   n_rollouts=self.n_eval_trials * 5)
        self.expert_returns = returns
        self.current_q = np.zeros_like(self.ground_truth_q)
        returns = self.eval_policy(random=False,
                                   n_rollouts=self.n_eval_trials * 5) if False else self.eval_policy(render=False,
                                   n_rollouts=self.n_eval_trials * 5)
        self.random_returns = returns
        # Affine map sending random_returns -> 0 and expert_returns -> 1.
        # The baselines are read at call time.
        # NOTE(review): divides by zero if expert_returns == random_returns.
        self.normalize_returns = lambda x: (x - self.random_returns) / (
            self.expert_returns - self.random_returns)

        # Start from the network's own (untrained) Q-values.
        self.current_q = ptu.to_numpy(self.network(self.all_states))
        # compute Proj(Q*)
        self.log_proj_qstar = log_proj_qstar
        self.ground_truth_q_proj = np.zeros_like(self.ground_truth_q)
        if log_proj_qstar:
            with log_utils.timer('proj_qstar'):
                self.ground_truth_q_proj = compute_exact_projection(
                    self.ground_truth_q,
                    network,
                    self.all_states,
                    weights=self.valid_weights,
                    lr=lr)
            diff = weighted_q_diff(self.ground_truth_q,
                                   self.ground_truth_q_proj,
                                   self.valid_weights)
            self.qstar_abs_diff = np.abs(diff)

        self.smooth_target_tau = smooth_target_tau
        # NOTE(review): not read by update(); confirm whether still needed.
        self.smooth_previous_target = np.zeros_like(self.ground_truth_q)
Example #5
0
    def update(self, step=-1):
        """Run one outer iteration: backup, projection, evaluation, logging.

        The steps are:
        - Compute the target ``self.all_target_q_np`` by running
          ``self.n_steps`` soft Q-iteration steps warm-started from
          ``self.current_q``. When ``smooth_target_tau < 1`` the result is
          blended back toward ``current_q``.
        - Fit the network via ``self.project()``.
        - Refresh ``self.current_q``, clipped from above at ``self.max_q``.
        - Evaluate the policy, then record and dump tabular diagnostics.

        Args:
            step: iteration index; recorded under 'step' only when >= 0.
        """
        t_begin = time.time()

        # Backup (optionally smoothed toward the current estimate).
        with log_utils.timer('compute_backup'):
            self.all_target_q_np = q_iteration_cy.softq_iteration(
                self.env,
                num_itrs=self.n_steps,
                warmstart_q=self.current_q,
                discount=self.discount,
                ent_wt=self.ent_wt)
            tau = self.smooth_target_tau
            if tau < 1.0:
                self.all_target_q_np = (tau * self.all_target_q_np
                                        + (1 - tau) * self.current_q)
            self.all_target_q = ptu.tensor(self.all_target_q_np)

        # Projection onto the network.
        with log_utils.timer('pre_project'):
            self.pre_project()

        stop_mode, proj_loss, n_fit_steps = self.project()

        if isinstance(stop_mode, stopping.ValidationLoss):
            # Validation-based stopping supplies the best Q-values it saw.
            fitted_q = ptu.to_numpy(stop_mode.best_validation_qs)
            logger.record_tabular('validation_stop_step',
                                  stop_mode.validation_k)
        else:
            fitted_q = ptu.to_numpy(self.network(self.all_states))
        # Clip from above so a diverging fit cannot blow up the estimate.
        self.current_q = np.minimum(fitted_q, self.max_q)
        self.post_project()
        with log_utils.timer('eval_policy'):
            returns = self.eval_policy()

        logger.record_tabular('project_loss', ptu.to_numpy(proj_loss))
        logger.record_tabular('fit_steps', n_fit_steps)
        if step >= 0:
            logger.record_tabular('step', step)

        # Q-value and return statistics.
        logger.record_tabular('fit_q_value_mean', np.mean(self.current_q))
        logger.record_tabular('target_q_value_mean',
                              np.mean(self.all_target_q_np))
        logger.record_tabular('returns_expert', self.expert_returns)
        logger.record_tabular('returns_random', self.random_returns)
        logger.record_tabular('returns', returns)
        log_utils.record_tabular_moving('returns', returns, n=50)
        normalized = self.normalize_returns(returns)
        logger.record_tabular('returns_normalized', normalized)
        log_utils.record_tabular_moving('returns_normalized', normalized, n=50)

        # Weighted error of the backup target against the ground-truth Q*.
        target_err = weighted_q_diff(self.all_target_q_np,
                                     self.ground_truth_q,
                                     self.valid_weights)
        log_utils.record_tabular_stats('tq_q*_diff', target_err)
        log_utils.record_tabular_stats('tq_q*_diff_abs', np.abs(target_err))

        if self.log_proj_qstar:
            proj_err = weighted_q_diff(self.current_q,
                                       self.ground_truth_q_proj,
                                       self.valid_weights)
            log_utils.record_tabular_stats('q*_proj_diff', proj_err)
            log_utils.record_tabular_stats('q*_proj_diff_abs',
                                           np.abs(proj_err))
            log_utils.record_tabular_stats('ground_truth_error',
                                           self.qstar_abs_diff)

        logger.record_tabular('iteration_time', time.time() - t_begin)

        logger.dump_tabular()