def compute_exact_projection(target_q, network, states, weights, N=10000,
                             robust=True, **optimizer_args):
    """Regress `network` onto `target_q` and return the projected Q-values.

    The network weights are reset before returning, so only the projection
    result is kept.
    """
    optimizer = torch.optim.Adam(network.parameters(), **optimizer_args)
    pt_target_q = ptu.tensor(target_q)
    for k in range(N):
        q_values = network(states)
        if robust:
            # reweight each state by its current absolute error (mean weight 1)
            s_weights = torch.sum(torch.abs(q_values - pt_target_q), dim=1).detach()
            s_weights = s_weights / torch.sum(s_weights) * len(s_weights)
        else:
            s_weights = ptu.tensor(weights)
        critic_loss = torch.mean(s_weights * torch.mean(
            (q_values - pt_target_q)**2, dim=1))
        network.zero_grad()
        critic_loss.backward()
        optimizer.step()
        if k % 1000 == 0:
            logger.log('Itr %d exact projection loss: %f' %
                       (k, ptu.to_numpy(critic_loss)))
    proj_q = ptu.to_numpy(network(states))
    network.reset_weights()
    logger.log('Exact projection abs diff: %f' %
               (np.mean(np.abs(weighted_q_diff(target_q, proj_q, weights)))))
    return proj_q

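# NOTE: `weighted_q_diff` is called above (and again in `__init__` and `update` below) but
# is not defined in this section. The helper below is only a hedged sketch of what it
# plausibly computes, inferred from its call sites (per-state weights broadcast over the
# action heads); the name `_weighted_q_diff_sketch` is hypothetical and not part of the
# original code. It relies on the module-level `np` (numpy) import used throughout.
def _weighted_q_diff_sketch(q1, q2, weights):
    """Hypothetical reference: element-wise Q difference scaled by per-state weights."""
    return weights[:, np.newaxis] * (q1 - q2)
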
def update(self, q_network=None, fqi=None, all_target_q=None, **kwargs):
    """Track the lowest oracle (validation-weighted) loss seen so far and cache
    the corresponding Q-values."""
    # compute oracle loss
    q_values = q_network(fqi.all_states).detach()
    weights = ptu.tensor(fqi.validation_sa_weights)
    oracle_loss = ptu.to_numpy(
        torch.sum(weights * ((q_values - all_target_q)**2)))
    prev_oracle_loss = self.validation_loss
    self.validation_k_counter += 1
    if oracle_loss < prev_oracle_loss:
        self.best_validation_qs = q_values
        self.validation_loss = oracle_loss
        self.validation_k = self.validation_k_counter - 1

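# Hedged sketch (not from the original code): the update() above is one method of a
# validation-based stopping criterion. project() below also calls reset() and check() on
# each stop mode, so the full interface presumably looks roughly like the skeleton here,
# with update() above as its third method. The class name, the `patience` default, and the
# comparison logic in check() are illustrative assumptions only.
class _ValidationStopSketch(object):
    def reset(self):
        self.validation_loss = float('inf')
        self.validation_k_counter = 0
        self.validation_k = 0
        self.best_validation_qs = None

    def check(self, patience=10):
        # e.g. stop once the oracle loss has not improved for `patience` steps
        return (self.validation_k_counter - 1 - self.validation_k) >= patience
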
def pre_project(self):
    if self.sampling_policy == 'adversarial':
        q_vals = ptu.to_numpy(
            self.evaluate_qvalues(np.arange(0, self.env.num_states), None,
                                  mode=fqi.MULTIPLE_HEADS))
        errors = np.abs(q_vals - self.all_target_q_np)**0.5
        # pick adversarial distribution - reward is bellman error
        adversarial_qs = q_iteration.softq_iteration_custom_reward(
            self.env, reward=errors, num_itrs=self.time_limit,
            discount=self.discount, ent_wt=self.ent_wt, atol=1e-5)
        self.adversarial_qs = adversarial_qs
    self.batch_s, self.batch_a, self.batch_ns, self.batch_r = self.collect_samples()
    self._total_samples += len(self.batch_s)
    logger.record_tabular('total_samples', self._total_samples)

def __init__(self, env, network, min_project_steps=5, max_project_steps=50,
             lr=1e-3, discount=0.99, n_steps=1, ent_wt=1.0,
             stop_modes=tuple(), backup_mode=BACKUP_EXACT, n_eval_trials=50,
             log_proj_qstar=False, target_mode='tq', optimizer='adam',
             smooth_target_tau=1.0, **kwargs):
    self.env = env
    self.network = network
    self.discount = discount
    self.ent_wt = ent_wt
    self.n_eval_trials = n_eval_trials
    self.lr = lr
    self.max_q = 1e5
    self.target_mode = target_mode
    self.backup_mode = backup_mode
    if backup_mode == BACKUP_EXACT:
        self.q_format = MULTIPLE_HEADS
    else:
        self.q_format = FLAT
    self.min_project_steps = min_project_steps
    self.max_project_steps = max_project_steps
    self.stop_modes = stop_modes
    self.n_steps = n_steps
    if optimizer == 'adam':
        self.qnet_optimizer = torch.optim.Adam(network.parameters(), lr=lr)
    elif optimizer == 'gd':
        self.qnet_optimizer = torch.optim.SGD(network.parameters(), lr=lr)
    else:
        raise ValueError("Unknown optimizer: %s" % optimizer)
    self.all_states = ptu.tensor(np.arange(env.num_states), dtype=torch.int64)

    with log_utils.timer('ground_truth_q'):
        self.ground_truth_q = q_iteration_cy.softq_iteration(
            self.env, num_itrs=max(self.env.num_states * 2, 1000),
            discount=self.discount, ent_wt=self.ent_wt)
    self.ground_truth_q_torch = ptu.tensor(self.ground_truth_q)
    self.valid_weights = np.sum(self.ground_truth_q, axis=1)
    self.valid_weights[self.valid_weights != 0] = 1.0

    # baseline returns for normalization: expert (Q*) and random (zero Q) policies
    self.current_q = self.ground_truth_q
    returns = self.eval_policy(render=False, n_rollouts=self.n_eval_trials * 5)
    self.expert_returns = returns
    self.current_q = np.zeros_like(self.ground_truth_q)
    returns = self.eval_policy(render=False, n_rollouts=self.n_eval_trials * 5)
    self.random_returns = returns
    self.normalize_returns = lambda x: (x - self.random_returns) / (
        self.expert_returns - self.random_returns)
    self.current_q = ptu.to_numpy(self.network(self.all_states))

    # compute Proj(Q*)
    self.log_proj_qstar = log_proj_qstar
    self.ground_truth_q_proj = np.zeros_like(self.ground_truth_q)
    if log_proj_qstar:
        with log_utils.timer('proj_qstar'):
            self.ground_truth_q_proj = compute_exact_projection(
                self.ground_truth_q, network, self.all_states,
                weights=self.valid_weights, lr=lr)
        diff = weighted_q_diff(self.ground_truth_q, self.ground_truth_q_proj,
                               self.valid_weights)
        self.qstar_abs_diff = np.abs(diff)

    self.smooth_target_tau = smooth_target_tau
    self.smooth_previous_target = np.zeros_like(self.ground_truth_q)

def update(self, step=-1):
    start_time = time.time()
    # backup
    with log_utils.timer('compute_backup'):
        self.all_target_q_np = q_iteration_cy.softq_iteration(
            self.env, num_itrs=self.n_steps, warmstart_q=self.current_q,
            discount=self.discount, ent_wt=self.ent_wt)
    # smooth
    if self.smooth_target_tau < 1.0:
        self.all_target_q_np = self.smooth_target_tau * self.all_target_q_np + (
            1 - self.smooth_target_tau) * self.current_q
    self.all_target_q = ptu.tensor(self.all_target_q_np)

    # project
    with log_utils.timer('pre_project'):
        self.pre_project()
    stopped_mode, critic_loss, k = self.project()
    if isinstance(stopped_mode, stopping.ValidationLoss):
        self.current_q = ptu.to_numpy(stopped_mode.best_validation_qs)
        logger.record_tabular('validation_stop_step', stopped_mode.validation_k)
    else:
        self.current_q = ptu.to_numpy(self.network(self.all_states))
    self.current_q = np.minimum(self.current_q, self.max_q)  # clip when diverging
    self.post_project()

    with log_utils.timer('eval_policy'):
        returns = self.eval_policy()
    logger.record_tabular('project_loss', ptu.to_numpy(critic_loss))
    logger.record_tabular('fit_steps', k)
    if step >= 0:
        logger.record_tabular('step', step)

    # Logging
    logger.record_tabular('fit_q_value_mean', np.mean(self.current_q))
    logger.record_tabular('target_q_value_mean', np.mean(self.all_target_q_np))
    logger.record_tabular('returns_expert', self.expert_returns)
    logger.record_tabular('returns_random', self.random_returns)
    logger.record_tabular('returns', returns)
    log_utils.record_tabular_moving('returns', returns, n=50)
    logger.record_tabular('returns_normalized', self.normalize_returns(returns))
    log_utils.record_tabular_moving('returns_normalized',
                                    self.normalize_returns(returns), n=50)

    # measure contraction errors
    diff_tq_qstar = weighted_q_diff(self.all_target_q_np, self.ground_truth_q,
                                    self.valid_weights)
    abs_diff_tq_qstar = np.abs(diff_tq_qstar)
    log_utils.record_tabular_stats('tq_q*_diff', diff_tq_qstar)
    log_utils.record_tabular_stats('tq_q*_diff_abs', abs_diff_tq_qstar)
    if self.log_proj_qstar:
        diff = weighted_q_diff(self.current_q, self.ground_truth_q_proj,
                               self.valid_weights)
        abs_diff = np.abs(diff)
        log_utils.record_tabular_stats('q*_proj_diff', diff)
        log_utils.record_tabular_stats('q*_proj_diff_abs', abs_diff)
        log_utils.record_tabular_stats('ground_truth_error', self.qstar_abs_diff)
    logger.record_tabular('iteration_time', time.time() - start_time)
    logger.dump_tabular()

def project(self, network=None, optimizer=None, sampler=None):
    if network is None:
        network = self.network
    if optimizer is None:
        optimizer = self.qnet_optimizer
    if sampler is None:
        sampler = self.get_sample_states

    k = 0
    stopped_mode = None
    for stop_mode in self.stop_modes:
        stop_mode.reset()

    with log_utils.timer('projection') as timer:
        for k in range(self.max_project_steps):
            with timer.subtimer('compute_samples_weights'):
                sample_s, sample_a, sample_ns, sample_r, weights = sampler(itr=k)
                sample_s, sample_a = ptu.all_tensor([sample_s, sample_a],
                                                    dtype=torch.int64)
                weights, = ptu.all_tensor([weights])
            with timer.subtimer('eval_target'):
                target_q = self.evaluate_target(sample_s, sample_a, sample_ns,
                                                sample_r).detach()
            with timer.subtimer('eval_q'):
                q_values = self.evaluate_qvalues(sample_s, sample_a,
                                                 network=network)
            if self.q_format == MULTIPLE_HEADS:
                if len(weights.shape) == 2:
                    # per state-action weights
                    critic_loss = torch.mean(weights * (q_values - target_q)**2)
                else:
                    # per-state weights: average the error over action heads first
                    critic_loss = torch.mean(weights * torch.mean(
                        (q_values - target_q)**2, dim=1))
            else:
                critic_loss = torch.mean(weights * (q_values - target_q)**2)
            with timer.subtimer('backprop'):
                network.zero_grad()
                critic_loss.backward()
                optimizer.step()

            stop_args = dict(critic_loss=ptu.to_numpy(critic_loss),
                             q_network=network,
                             all_target_q=self.all_target_q,
                             fqi=self,
                             discount=self.discount,
                             ent_wt=self.ent_wt)
            for stop_mode in self.stop_modes:
                stop_mode.update(**stop_args)

            if k >= self.min_project_steps:
                stopped = False
                for stop_mode in self.stop_modes:
                    if stop_mode.check():
                        logger.log('Early stopping via %s.' % stop_mode)
                        stopped = True
                        stopped_mode = stop_mode
                        break
                if stopped:
                    break
    return stopped_mode, critic_loss, k

def compute_weights(self, samples, itr=0, clip_min=1e-6, clip_max=100.0):
    if self.wscheme == 'robust_prioritized':
        q_vals = self.evaluate_qvalues(samples[0], samples[1]).detach()
        target_qs = self.evaluate_target(samples[0], samples[1], samples[2],
                                         samples[3]).detach()
        error = torch.abs(q_vals - target_qs)
        weights = ptu.to_numpy(error)
    elif self.wscheme == 'robust_adversarial':
        # solve for max_pi [bellman error]
        # compute bellman errors
        q_vals = ptu.to_numpy(
            self.evaluate_qvalues(np.arange(0, self.env.num_states), None,
                                  mode=fqi.MULTIPLE_HEADS))
        errors = np.abs(q_vals - self.all_target_q_np)
        # pick adversarial distribution - reward is bellman error
        adversarial_qs = q_iteration.softq_iteration_custom_reward(
            self.env, reward=errors, num_itrs=self.time_limit,
            discount=self.discount, ent_wt=self.ent_wt,
            warmstart_q=self.warmstart_adversarial_q, atol=1e-5)
        self.warmstart_adversarial_q = adversarial_qs
        visit_sa = q_iteration_py.compute_visitation(
            self.env, adversarial_qs, ent_wt=self.ent_wt,
            discount=self.discount, env_time_limit=self.time_limit)
        weights = visit_sa[samples[0], samples[1]]
    elif self.wscheme == 'robust_adversarial_fast':
        if itr % 10 == 0:
            # solve for max_pi [bellman error]
            # compute bellman errors
            q_vals = ptu.to_numpy(
                self.evaluate_qvalues(np.arange(0, self.env.num_states), None,
                                      mode=fqi.MULTIPLE_HEADS))
            errors = np.abs(q_vals - self.all_target_q_np)
            # pick adversarial distribution - reward is bellman error
            adversarial_qs = q_iteration.softq_iteration_custom_reward(
                self.env, reward=errors, num_itrs=self.time_limit,
                discount=self.discount, ent_wt=self.ent_wt,
                warmstart_q=self.warmstart_adversarial_q, atol=1e-5)
            self.warmstart_adversarial_q = adversarial_qs
            self.adv_visit_sa = q_iteration_py.compute_visitation(
                self.env, adversarial_qs, ent_wt=self.ent_wt,
                discount=self.discount, env_time_limit=self.time_limit)
            weights = self.adv_visit_sa[samples[0], samples[1]]
        else:
            weights = self.sa_weights[samples[0], samples[1]]
    weights = (weights / np.sum(weights))  # normalize
    weights = np.minimum(weights, clip_max)
    weights = np.maximum(weights, clip_min)
    weights = (weights / np.sum(weights))  # normalize
    return weights

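# Worked example of the clip-and-renormalize step in compute_weights() above
# (illustrative numbers only, not part of the original code):
#   raw weights [0., 1., 3.]  ->  normalize        -> [0., 0.25, 0.75]
#                             ->  clip to [1e-6, 100.0] -> [1e-6, 0.25, 0.75]
#                             ->  normalize again so the clipped weights sum to one.
# The second normalization is what keeps clipping from changing the total mass.
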
def get_sample_states(self, itr=0):
    if itr % 5 == 0:
        # compute weights
        weights = None
        if self.wscheme == 'uniform':
            weights = np.ones((self.env.num_states, self.env.num_actions))
        elif self.wscheme == 'buffer_infinite':
            weights = self.buffer_sa
        elif self.wscheme == 'buffer10':
            weights = self.buffer_sa
        elif self.wscheme == 'pi*':
            weights = self.visit_sa
        elif self.wscheme == 'pi*proj':
            assert self.log_proj_qstar
            weights = self.opt_proj_visit_sa
        elif self.wscheme == 'random':
            weights = self.pi_visit_sa
        elif self.wscheme == 'pi':
            weights = self.pi_visit_sa
        elif self.wscheme == 'online':
            q_vals = ptu.to_numpy(
                self.evaluate_qvalues(np.arange(0, self.env.num_states), None))
            visit_sa = q_iteration_py.compute_visitation(
                self.env, q_vals, ent_wt=self.ent_wt, discount=self.discount,
                env_time_limit=self.time_limit)
            weights = visit_sa
        elif self.wscheme == 'robust_prioritized':
            q_vals = ptu.to_numpy(
                self.evaluate_qvalues(np.arange(0, self.env.num_states), None))
            errors = np.abs(q_vals - self.all_target_q_np)
            weights = errors
        elif self.wscheme == 'robust_adversarial':
            # solve for max_pi [bellman error]
            # compute bellman errors
            q_vals = ptu.to_numpy(
                self.evaluate_qvalues(np.arange(0, self.env.num_states), None))
            errors = np.abs(q_vals - self.all_target_q_np)
            # pick adversarial distribution - reward is bellman error
            adversarial_qs = q_iteration.softq_iteration_custom_reward(
                self.env, reward=errors, num_itrs=self.time_limit,
                discount=self.discount, ent_wt=self.ent_wt,
                warmstart_q=self.warmstart_adversarial_q, atol=1e-5)
            self.warmstart_adversarial_q = adversarial_qs
            visit_sa = q_iteration_py.compute_visitation(
                self.env, adversarial_qs, ent_wt=self.ent_wt,
                discount=self.discount, env_time_limit=self.time_limit)
            weights = visit_sa
        else:
            raise ValueError("Unknown weighting scheme: %s" % self.wscheme)

        if self.weight_states_only:
            weights = np.sum(weights, axis=1)
            weights = np.repeat(weights[:, np.newaxis], self.env.num_actions,
                                axis=-1)
        self.weights = (weights / np.sum(weights))  # normalize

    if itr == 0:
        entropy = -np.sum(self.weights * np.log(self.weights + 1e-6))
        logger.record_tabular('weight_entropy', entropy)
        unif = np.ones_like(self.weights) / float(self.weights.size)
        max_entropy = -np.sum(unif * np.log(unif))
        logger.record_tabular('weight_entropy_normalized', entropy / max_entropy)
    return np.arange(0, self.env.num_states), None, None, None, self.weights
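
# Hedged usage sketch (hypothetical, not taken from the original repo): the methods above
# suggest an outer training loop that alternates exact backups with projection, roughly:
#
#   fqi = FQI(env, network, lr=1e-3, discount=0.99,
#             stop_modes=(stopping.ValidationLoss(),))
#   for step in range(num_iterations):
#       fqi.update(step=step)
#
# The class name `FQI` and the argument-free `stopping.ValidationLoss()` constructor are
# assumptions inferred from the isinstance check in update() and from how these methods
# reference each other.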