def test_qiteration(self):
    params = {
        'num_itrs': 50,
        'ent_wt': 1.0,
        'discount': 0.99,
    }
    qvals_py = q_iteration_py.softq_iteration(self.env, **params)
    qvals_cy = q_iteration.softq_iteration(self.env, **params)
    self.assertTrue(np.allclose(qvals_cy, qvals_py))
def test_q_iteration_small(self):
    params = {
        'num_itrs': 1000,
        'ent_wt': 0.1,
        'discount': 0.95,
    }
    qvals = q_iteration.softq_iteration(self.env_small, **params)
    self.env_small.reset()
    for _ in range(50):
        # self.env_small.render()
        a_qvals = qvals[self.env_small.get_state()]
        _, rew, _, _ = self.env_small.step(np.argmax(a_qvals))
    self.assertEqual(rew, 1.0)
def test_q_iteration(self):
    params = {
        'num_itrs': 1000,
        'ent_wt': 0.0,
        'discount': 0.95,
    }
    qvals = q_iteration.softq_iteration(self.env, **params)
    self.env.reset()
    rews = 0
    for _ in range(200):
        # self.env.render()
        a_qvals = qvals[self.env.get_state()]
        _, rew, _, _ = self.env.step(np.argmax(a_qvals))
        rews += rew
    self.assertGreater(rews, 0.0)
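# For reference: a minimal NumPy sketch of one soft Bellman backup, i.e. the
# update that softq_iteration is assumed to iterate to convergence. The tabular
# `reward`/`transitions` arrays are illustrative inputs, not the env API used
# by the tests above.
from scipy.special import logsumexp

def _softq_backup_sketch(q, reward, transitions, discount, ent_wt):
    # q, reward: [num_states, num_actions]; transitions: [S, A, S'].
    if ent_wt > 0:
        # Entropy-regularized ("soft") state values.
        v = ent_wt * logsumexp(q / ent_wt, axis=1)
    else:
        # ent_wt == 0 recovers the standard hard-max backup.
        v = np.max(q, axis=1)
    return reward + discount * transitions.dot(v)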
def __init__(self, env, network,
             min_project_steps=5,
             max_project_steps=50,
             lr=1e-3,
             discount=0.99,
             n_steps=1,
             ent_wt=1.0,
             stop_modes=tuple(),
             backup_mode=BACKUP_EXACT,
             n_eval_trials=50,
             log_proj_qstar=False,
             target_mode='tq',
             optimizer='adam',
             smooth_target_tau=1.0,
             **kwargs):
    self.env = env
    self.network = network
    self.discount = discount
    self.ent_wt = ent_wt
    self.n_eval_trials = n_eval_trials
    self.lr = lr
    self.max_q = 1e5
    self.target_mode = target_mode
    self.backup_mode = backup_mode
    if backup_mode == BACKUP_EXACT:
        self.q_format = MULTIPLE_HEADS
    else:
        self.q_format = FLAT
    self.min_project_steps = min_project_steps
    self.max_project_steps = max_project_steps
    self.stop_modes = stop_modes
    self.n_steps = n_steps

    if optimizer == 'adam':
        self.qnet_optimizer = torch.optim.Adam(network.parameters(), lr=lr)
    elif optimizer == 'gd':
        self.qnet_optimizer = torch.optim.SGD(network.parameters(), lr=lr)
    else:
        raise ValueError("Unknown optimizer: %s" % optimizer)

    self.all_states = ptu.tensor(np.arange(env.num_states), dtype=torch.int64)

    # Ground-truth Q* via (soft) Q-iteration on the tabular env.
    with log_utils.timer('ground_truth_q'):
        self.ground_truth_q = q_iteration_cy.softq_iteration(
            self.env, num_itrs=max(self.env.num_states * 2, 1000),
            discount=self.discount, ent_wt=self.ent_wt)
    self.ground_truth_q_torch = ptu.tensor(self.ground_truth_q)

    # Per-state mask: states whose ground-truth Q-values are all zero get weight 0.
    self.valid_weights = np.sum(self.ground_truth_q, axis=1)
    self.valid_weights[self.valid_weights != 0] = 1.0

    # Baseline returns for normalization: expert (Q*) and random (zero Q).
    self.current_q = self.ground_truth_q
    returns = self.eval_policy(render=False, n_rollouts=self.n_eval_trials * 5)
    self.expert_returns = returns

    self.current_q = np.zeros_like(self.ground_truth_q)
    returns = self.eval_policy(render=False, n_rollouts=self.n_eval_trials * 5)
    self.random_returns = returns

    self.normalize_returns = lambda x: (x - self.random_returns) / (
        self.expert_returns - self.random_returns)

    self.current_q = ptu.to_numpy(self.network(self.all_states))

    # compute Proj(Q*)
    self.log_proj_qstar = log_proj_qstar
    self.ground_truth_q_proj = np.zeros_like(self.ground_truth_q)
    if log_proj_qstar:
        with log_utils.timer('proj_qstar'):
            self.ground_truth_q_proj = compute_exact_projection(
                self.ground_truth_q, network, self.all_states,
                weights=self.valid_weights, lr=lr)
        diff = weighted_q_diff(self.ground_truth_q, self.ground_truth_q_proj,
                               self.valid_weights)
        self.qstar_abs_diff = np.abs(diff)

    self.smooth_target_tau = smooth_target_tau
    self.smooth_previous_target = np.zeros_like(self.ground_truth_q)
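# Illustrative sketch of the projection step compute_exact_projection is
# assumed to perform (the real helper lives elsewhere in the repo): regress the
# network's Q-values onto a fixed target table under the per-state weights.
def _exact_projection_sketch(target_q, network, all_states, weights, lr,
                             n_steps=1000):
    optimizer = torch.optim.Adam(network.parameters(), lr=lr)
    target = ptu.tensor(target_q)
    w = ptu.tensor(weights).unsqueeze(1)  # [num_states, 1] validity mask
    for _ in range(n_steps):
        q_pred = network(all_states)
        loss = (w * (q_pred - target) ** 2).mean()  # weighted squared error
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return ptu.to_numpy(network(all_states))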
def update(self, step=-1):
    start_time = time.time()

    # backup
    with log_utils.timer('compute_backup'):
        self.all_target_q_np = q_iteration_cy.softq_iteration(
            self.env, num_itrs=self.n_steps, warmstart_q=self.current_q,
            discount=self.discount, ent_wt=self.ent_wt)

    # smooth
    if self.smooth_target_tau < 1.0:
        self.all_target_q_np = self.smooth_target_tau * self.all_target_q_np + (
            1 - self.smooth_target_tau) * self.current_q
    self.all_target_q = ptu.tensor(self.all_target_q_np)

    # project
    with log_utils.timer('pre_project'):
        self.pre_project()
    stopped_mode, critic_loss, k = self.project()
    if isinstance(stopped_mode, stopping.ValidationLoss):
        self.current_q = ptu.to_numpy(stopped_mode.best_validation_qs)
        logger.record_tabular('validation_stop_step', stopped_mode.validation_k)
    else:
        self.current_q = ptu.to_numpy(self.network(self.all_states))
    self.current_q = np.minimum(self.current_q, self.max_q)  # clip when diverging
    self.post_project()

    with log_utils.timer('eval_policy'):
        returns = self.eval_policy()

    logger.record_tabular('project_loss', ptu.to_numpy(critic_loss))
    logger.record_tabular('fit_steps', k)
    if step >= 0:
        logger.record_tabular('step', step)

    # Logging
    logger.record_tabular('fit_q_value_mean', np.mean(self.current_q))
    logger.record_tabular('target_q_value_mean', np.mean(self.all_target_q_np))
    logger.record_tabular('returns_expert', self.expert_returns)
    logger.record_tabular('returns_random', self.random_returns)
    logger.record_tabular('returns', returns)
    log_utils.record_tabular_moving('returns', returns, n=50)
    logger.record_tabular('returns_normalized', self.normalize_returns(returns))
    log_utils.record_tabular_moving('returns_normalized',
                                    self.normalize_returns(returns), n=50)

    # measure contraction errors
    diff_tq_qstar = weighted_q_diff(self.all_target_q_np, self.ground_truth_q,
                                    self.valid_weights)
    abs_diff_tq_qstar = np.abs(diff_tq_qstar)
    log_utils.record_tabular_stats('tq_q*_diff', diff_tq_qstar)
    log_utils.record_tabular_stats('tq_q*_diff_abs', abs_diff_tq_qstar)
    if self.log_proj_qstar:
        diff = weighted_q_diff(self.current_q, self.ground_truth_q_proj,
                               self.valid_weights)
        abs_diff = np.abs(diff)
        log_utils.record_tabular_stats('q*_proj_diff', diff)
        log_utils.record_tabular_stats('q*_proj_diff_abs', abs_diff)
        log_utils.record_tabular_stats('ground_truth_error', self.qstar_abs_diff)

    logger.record_tabular('iteration_time', time.time() - start_time)
    logger.dump_tabular()
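# Typical usage (sketch): the trainer class name and the env/network
# constructors below are placeholders, not the repo's actual entry point.
#
#   trainer = ExactFQITrainer(env, network, lr=1e-3, discount=0.95, ent_wt=0.1)
#   for itr in range(num_iterations):
#       trainer.update(step=itr)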