def test_visitations(self):
  """Tests compute_visitation on a 3-state cliffwalk for several time limits."""
  env = tabular_env.CliffwalkEnv(num_states=3, transition_noise=0.00)
  params = {
      'num_itrs': 50,
      'ent_wt': 0.0,
      'discount': 0.99,
  }
  qvals_py = q_iteration_py.softq_iteration(env, **params)

  visitations = q_iteration_py.compute_visitation(
      env, qvals_py, ent_wt=0.0, env_time_limit=1)
  s_visitations = np.sum(visitations, axis=1)
  tru_visits = np.array([1, 0, 0])
  self.assertTrue(np.allclose(tru_visits, s_visitations))

  visitations = q_iteration_py.compute_visitation(
      env, qvals_py, ent_wt=0.0, env_time_limit=3)
  s_visitations = np.sum(visitations, axis=1)
  tru_visits = np.array([1, 1, 1]) / 3.0
  self.assertTrue(np.allclose(tru_visits, s_visitations))

  visitations = q_iteration_py.compute_visitation(
      env, qvals_py, ent_wt=0.0, env_time_limit=5)
  s_visitations = np.sum(visitations, axis=1)
  tru_visits = np.array([2, 2, 1]) / 5.0
  self.assertTrue(np.allclose(tru_visits, s_visitations))
def pre_project(self):
  """Refreshes the visitation estimates used for weighting before projection."""
  if self.wscheme == 'pi':
    # Visitation under the policy induced by the current Q-values.
    self.pi_visit_sa = q_iteration_py.compute_visitation(
        self.env, self.current_q, ent_wt=self.ent_wt, discount=self.discount,
        env_time_limit=self.time_limit)
  elif self.wscheme == 'random':
    # Visitation under the policy induced by all-zero Q-values.
    self.pi_visit_sa = q_iteration_py.compute_visitation(
        self.env, np.zeros_like(self.current_q), ent_wt=self.ent_wt,
        discount=self.discount, env_time_limit=self.time_limit)
  elif self.wscheme == 'buffer_infinite':
    # Running mean of the per-iteration visitations seen so far.
    pi_visit_sa = q_iteration_py.compute_visitation(
        self.env, self.current_q, ent_wt=self.ent_wt, discount=self.discount,
        env_time_limit=self.time_limit)
    self.buffer_n += 1
    self.buffer_sa *= (self.buffer_n - 1) / self.buffer_n
    self.buffer_sa += pi_visit_sa / self.buffer_n
  elif self.wscheme == 'buffer10':
    # Mean of the visitations from the last 10 iterations.
    pi_visit_sa = q_iteration_py.compute_visitation(
        self.env, self.current_q, ent_wt=self.ent_wt, discount=self.discount,
        env_time_limit=self.time_limit)
    self.buffer10.append(pi_visit_sa)
    self.buffer10 = self.buffer10[-10:]
    self.buffer_sa = np.mean(self.buffer10, axis=0)
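# Illustrative sketch (not part of the original module): the 'buffer_infinite'
# branch above keeps a running mean of the per-iteration visitation arrays
# without storing them all. A minimal standalone check of that incremental-mean
# update (the `_demo_` helper name is hypothetical), assuming only numpy:
def _demo_incremental_mean():
  import numpy as np
  rng = np.random.RandomState(0)
  visits = [rng.rand(3, 2) for _ in range(5)]  # fake per-iteration visitations
  buffer_sa = np.zeros((3, 2))
  buffer_n = 0
  for pi_visit_sa in visits:
    buffer_n += 1
    buffer_sa *= (buffer_n - 1) / buffer_n
    buffer_sa += pi_visit_sa / buffer_n
  # The running average matches the batch mean of everything seen so far.
  assert np.allclose(buffer_sa, np.mean(visits, axis=0))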
def pre_project(self):
  super(WeightedSamplingFQI, self).pre_project()
  self.sample_visit_sa = q_iteration_py.compute_visitation(
      self.env, self.sampling_q, ent_wt=self.ent_wt, discount=self.discount,
      env_time_limit=self.time_limit)
  self.sa_weights = compute_sa_weights(self, self.wscheme, self.sample_visit_sa)
  self.validation_sa_weights = self.sa_weights * self.sample_visit_sa
def __init__(self, env, network, weighting_scheme='none', **kwargs):
  super(WeightedSamplingFQI, self).__init__(env, network, **kwargs)
  self.wscheme = weighting_scheme
  self.vfn = q_iteration_py.logsumexp(self.ground_truth_q, alpha=self.ent_wt)
  # Visitation under the optimal policy, used by the 'pi*' weighting scheme.
  self.optimal_visit_sa = q_iteration_py.compute_visitation(
      self.env, self.ground_truth_q, ent_wt=self.ent_wt, discount=self.discount,
      env_time_limit=self.time_limit)
  self.warmstart_adversarial_q = np.zeros_like(self.ground_truth_q)
def __init__(self, env, network, weighting_scheme='uniform',
             weight_states_only=False, time_limit=100, **kwargs):
  super(WeightedExactFQI, self).__init__(env, network, **kwargs)
  assert self.backup_mode == fqi.BACKUP_EXACT
  self.wscheme = weighting_scheme
  self.time_limit = time_limit
  self.weight_states_only = weight_states_only
  self.visit_sa = q_iteration_py.compute_visitation(
      self.env, self.ground_truth_q, ent_wt=self.ent_wt, discount=self.discount,
      env_time_limit=self.time_limit)
  self.warmstart_adversarial_q = self.ground_truth_q[:, :]
  self.opt_proj_visit_sa = q_iteration_py.compute_visitation(
      self.env, self.ground_truth_q_proj, ent_wt=self.ent_wt,
      discount=self.discount, env_time_limit=self.time_limit)
  self.buffer_sa = np.zeros_like(self.ground_truth_q)
  self.buffer_n = 0
  self.buffer10 = []
  self.prev_q_target = np.zeros_like(self.ground_truth_q)
  self.prev_q_value = np.zeros_like(self.ground_truth_q)
  self.prev_loss = 0
  self.prev_weights = np.zeros_like(self.ground_truth_q)
def pre_project(self):
  super(WeightedBufferFQI, self).pre_project()
  self.sa_weights = sampling_fqi.compute_sa_weights(
      self, self.wscheme, self.replay_buffer.probs_sa())
  # validation loss
  self.validation_sa_weights = self.sa_weights * self.replay_buffer.probs_sa()
  self.buffer_validation_sa_weights = self.replay_buffer.probs_sa()
  sample_visit_sa = q_iteration.compute_visitation(
      self.env, self.sampling_q, ent_wt=self.ent_wt, discount=self.discount,
      env_time_limit=self.time_limit)
  self.onpolicy_validation_sa_weights = sample_visit_sa
def compute_sa_weights(fqi, wscheme, sample_visit_sa):
  """Computes per-(state, action) importance weights for the given scheme."""
  if wscheme == 'uniform':
    # Reweight toward a uniform target distribution.
    sa_weights = 1.0 / (sample_visit_sa + 1e-6)
  elif wscheme == 'pi*':
    # Reweight toward the optimal policy's visitation distribution.
    sa_weights = fqi.optimal_visit_sa / (sample_visit_sa + 1e-6)
  elif wscheme == 'pi':
    # Reweight toward the current policy's visitation distribution.
    current_visit_sa = q_iteration_py.compute_visitation(
        fqi.env, fqi.current_q, ent_wt=fqi.ent_wt, discount=fqi.discount,
        env_time_limit=fqi.time_limit)
    sa_weights = current_visit_sa / (sample_visit_sa + 1e-6)
  elif wscheme in [
      'robust_prioritized', 'robust_adversarial', 'robust_adversarial_fast',
      'none'
  ]:
    sa_weights = np.ones_like(fqi.current_q)
  else:
    raise ValueError('Unknown weighting scheme: %s' % wscheme)
  return sa_weights
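# Illustrative sketch (not part of the original module): the schemes above are
# importance ratios w(s, a) = d_target(s, a) / d_sample(s, a), with a small
# constant in the denominator for numerical stability. A minimal numpy example
# (the array values and the `_demo_` helper name are hypothetical):
def _demo_importance_ratio():
  import numpy as np
  d_sample = np.array([[0.5, 0.1], [0.2, 0.2]])      # sampling visitation
  d_target = np.array([[0.25, 0.25], [0.25, 0.25]])  # uniform target
  weights = d_target / (d_sample + 1e-6)
  # Pairs that are under-sampled relative to the target get larger weights.
  assert weights[0, 1] > weights[0, 0]
  return weights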
def compute_weights(self, samples, itr=0, clip_min=1e-6, clip_max=100.0):
  if self.wscheme == 'robust_prioritized':
    # Weight each sample by its absolute Bellman error.
    q_vals = self.evaluate_qvalues(samples[0], samples[1]).detach()
    target_qs = self.evaluate_target(samples[0], samples[1], samples[2],
                                     samples[3]).detach()
    error = torch.abs(q_vals - target_qs)
    weights = ptu.to_numpy(error)
  elif self.wscheme == 'robust_adversarial':
    # solve for max_pi [bellman error]
    # compute bellman errors
    q_vals = ptu.to_numpy(
        self.evaluate_qvalues(np.arange(0, self.env.num_states), None,
                              mode=fqi.MULTIPLE_HEADS))
    errors = np.abs(q_vals - self.all_target_q_np)
    # pick adversarial distribution - reward is bellman error
    adversarial_qs = q_iteration.softq_iteration_custom_reward(
        self.env, reward=errors, num_itrs=self.time_limit,
        discount=self.discount, ent_wt=self.ent_wt,
        warmstart_q=self.warmstart_adversarial_q, atol=1e-5)
    self.warmstart_adversarial_q = adversarial_qs
    visit_sa = q_iteration_py.compute_visitation(
        self.env, adversarial_qs, ent_wt=self.ent_wt, discount=self.discount,
        env_time_limit=self.time_limit)
    weights = visit_sa[samples[0], samples[1]]
  elif self.wscheme == 'robust_adversarial_fast':
    # Same as 'robust_adversarial', but only recompute the adversarial
    # visitation every 10 iterations and reuse the cached one otherwise.
    if itr % 10 == 0:
      # solve for max_pi [bellman error]
      # compute bellman errors
      q_vals = ptu.to_numpy(
          self.evaluate_qvalues(np.arange(0, self.env.num_states), None,
                                mode=fqi.MULTIPLE_HEADS))
      errors = np.abs(q_vals - self.all_target_q_np)
      # pick adversarial distribution - reward is bellman error
      adversarial_qs = q_iteration.softq_iteration_custom_reward(
          self.env, reward=errors, num_itrs=self.time_limit,
          discount=self.discount, ent_wt=self.ent_wt,
          warmstart_q=self.warmstart_adversarial_q, atol=1e-5)
      self.warmstart_adversarial_q = adversarial_qs
      self.adv_visit_sa = q_iteration_py.compute_visitation(
          self.env, adversarial_qs, ent_wt=self.ent_wt,
          discount=self.discount, env_time_limit=self.time_limit)
    weights = self.adv_visit_sa[samples[0], samples[1]]
  else:
    weights = self.sa_weights[samples[0], samples[1]]
  weights = weights / np.sum(weights)  # normalize
  weights = np.minimum(weights, clip_max)
  weights = np.maximum(weights, clip_min)
  weights = weights / np.sum(weights)  # normalize
  return weights
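# Illustrative sketch (not part of the original module): compute_weights above
# normalizes the raw weights to a distribution, clips them into
# [clip_min, clip_max], and renormalizes. A minimal standalone example of that
# post-processing, using a small clip_max so the cap is visible (the values
# and the `_demo_` helper name are hypothetical):
def _demo_normalize_and_clip(clip_min=1e-6, clip_max=0.5):
  import numpy as np
  weights = np.array([10.0, 1.0, 0.0, 0.0])
  weights = weights / np.sum(weights)      # -> [0.909, 0.091, 0.0, 0.0]
  weights = np.minimum(weights, clip_max)  # cap the dominant sample
  weights = np.maximum(weights, clip_min)  # keep zero-weight samples positive
  weights = weights / np.sum(weights)      # renormalize after clipping
  return weights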
def get_sample_states(self, itr=0):
  if itr % 5 == 0:
    # compute weights
    weights = None
    if self.wscheme == 'uniform':
      weights = np.ones((self.env.num_states, self.env.num_actions))
    elif self.wscheme == 'buffer_infinite':
      weights = self.buffer_sa
    elif self.wscheme == 'buffer10':
      weights = self.buffer_sa
    elif self.wscheme == 'pi*':
      weights = self.visit_sa
    elif self.wscheme == 'pi*proj':
      assert self.log_proj_qstar
      weights = self.opt_proj_visit_sa
    elif self.wscheme == 'random':
      weights = self.pi_visit_sa
    elif self.wscheme == 'pi':
      weights = self.pi_visit_sa
    elif self.wscheme == 'online':
      q_vals = ptu.to_numpy(
          self.evaluate_qvalues(np.arange(0, self.env.num_states), None))
      visit_sa = q_iteration_py.compute_visitation(
          self.env, q_vals, ent_wt=self.ent_wt, discount=self.discount,
          env_time_limit=self.time_limit)
      weights = visit_sa
    elif self.wscheme == 'robust_prioritized':
      q_vals = ptu.to_numpy(
          self.evaluate_qvalues(np.arange(0, self.env.num_states), None))
      errors = np.abs(q_vals - self.all_target_q_np)
      weights = errors
    elif self.wscheme == 'robust_adversarial':
      # solve for max_pi [bellman error]
      # compute bellman errors
      q_vals = ptu.to_numpy(
          self.evaluate_qvalues(np.arange(0, self.env.num_states), None))
      errors = np.abs(q_vals - self.all_target_q_np)
      # pick adversarial distribution - reward is bellman error
      adversarial_qs = q_iteration.softq_iteration_custom_reward(
          self.env, reward=errors, num_itrs=self.time_limit,
          discount=self.discount, ent_wt=self.ent_wt,
          warmstart_q=self.warmstart_adversarial_q, atol=1e-5)
      self.warmstart_adversarial_q = adversarial_qs
      visit_sa = q_iteration_py.compute_visitation(
          self.env, adversarial_qs, ent_wt=self.ent_wt,
          discount=self.discount, env_time_limit=self.time_limit)
      weights = visit_sa
    else:
      raise ValueError("Unknown weighting scheme: %s" % self.wscheme)

    if self.weight_states_only:
      # Collapse to per-state weights, spread uniformly over actions.
      weights = np.sum(weights, axis=1)
      weights = np.repeat(weights[:, np.newaxis], self.env.num_actions,
                          axis=-1)
    self.weights = weights / np.sum(weights)  # normalize

  if itr == 0:
    entropy = -np.sum(self.weights * np.log(self.weights + 1e-6))
    logger.record_tabular('weight_entropy', entropy)
    unif = np.ones_like(self.weights) / float(self.weights.size)
    max_entropy = -np.sum(unif * np.log(unif))
    logger.record_tabular('weight_entropy_normalized', entropy / max_entropy)
  return np.arange(0, self.env.num_states), None, None, None, self.weights
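# Illustrative sketch (not part of the original module): the normalized weight
# entropy logged above is H(w) / H(uniform), so it is close to 1.0 when the
# sampling weights are uniform and approaches 0 as they concentrate on a few
# (state, action) pairs. A minimal standalone check, with hypothetical values
# and a hypothetical `_demo_` helper name:
def _demo_weight_entropy():
  import numpy as np
  uniform = np.full((4, 2), 1.0 / 8)
  peaked = np.full((4, 2), 1e-6)
  peaked[0, 0] = 1.0 - 7e-6
  max_entropy = -np.sum(uniform * np.log(uniform))
  for w in (uniform, peaked):
    entropy = -np.sum(w * np.log(w + 1e-6))
    print(entropy / max_entropy)  # ~1.0 for uniform, near 0.0 for peaked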