Example 1
    def test_visitations(self):
        env = tabular_env.CliffwalkEnv(num_states=3, transition_noise=0.00)
        params = {
            'num_itrs': 50,
            'ent_wt': 0.0,
            'discount': 0.99,
        }
        qvals_py = q_iteration_py.softq_iteration(env, **params)

        visitations = q_iteration_py.compute_visitation(env,
                                                        qvals_py,
                                                        ent_wt=0.0,
                                                        env_time_limit=1)
        s_visitations = np.sum(visitations, axis=1)
        tru_visits = np.array([1, 0, 0])
        self.assertTrue(np.allclose(tru_visits, s_visitations))

        visitations = q_iteration_py.compute_visitation(env,
                                                        qvals_py,
                                                        ent_wt=0.0,
                                                        env_time_limit=3)
        s_visitations = np.sum(visitations, axis=1)
        tru_visits = np.array([1, 1, 1]) / 3.0
        self.assertTrue(np.allclose(tru_visits, s_visitations))

        visitations = q_iteration_py.compute_visitation(env,
                                                        qvals_py,
                                                        ent_wt=0.0,
                                                        env_time_limit=5)
        s_visitations = np.sum(visitations, axis=1)
        tru_visits = np.array([2, 2, 1]) / 5.0
        self.assertTrue(np.allclose(tru_visits, s_visitations))
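The test above checks that compute_visitation returns a per-step average occupancy: summed over actions it gives [1, 0, 0] for a one-step horizon, [1, 1, 1]/3 for three steps, and [2, 2, 1]/5 for five. A minimal sketch of that quantity, using a hypothetical visitation_sketch helper (not the library implementation) and an assumed toy 3-state chain that loops back to the start (the real CliffwalkEnv dynamics may differ):

import numpy as np

def visitation_sketch(transition, policy, init_state, time_limit):
    # transition: (S, A, S) probabilities, policy: (S, A); returns the
    # state-action occupancy averaged over `time_limit` steps.
    num_states, num_actions, _ = transition.shape
    d_s = np.zeros(num_states)
    d_s[init_state] = 1.0                        # state marginal at t = 0
    total_sa = np.zeros((num_states, num_actions))
    for _ in range(time_limit):
        d_sa = d_s[:, None] * policy             # joint state-action marginal
        total_sa += d_sa
        d_s = np.einsum('sa,sap->p', d_sa, transition)   # advance one step
    return total_sa / time_limit

# Toy chain: one action, states 0 -> 1 -> 2 -> 0, so a 5-step rollout from
# state 0 visits states 0, 1, 2, 0, 1 -- matching the [2, 2, 1] / 5 case.
P = np.zeros((3, 1, 3))
P[0, 0, 1] = P[1, 0, 2] = P[2, 0, 0] = 1.0
pi = np.ones((3, 1))
print(visitation_sketch(P, pi, init_state=0, time_limit=5).sum(axis=1))
# -> [0.4 0.4 0.2]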
Example 2
 def pre_project(self):
     if self.wscheme == 'pi':
         self.pi_visit_sa = q_iteration_py.compute_visitation(
             self.env,
             self.current_q,
             ent_wt=self.ent_wt,
             discount=self.discount,
             env_time_limit=self.time_limit)
     elif self.wscheme == 'random':
         self.pi_visit_sa = q_iteration_py.compute_visitation(
             self.env,
             np.zeros_like(self.current_q),
             ent_wt=self.ent_wt,
             discount=self.discount,
             env_time_limit=self.time_limit)
     elif self.wscheme == 'buffer_infinite':
         pi_visit_sa = q_iteration_py.compute_visitation(
             self.env,
             self.current_q,
             ent_wt=self.ent_wt,
             discount=self.discount,
             env_time_limit=self.time_limit)
         self.buffer_n += 1
         self.buffer_sa *= (self.buffer_n - 1) / self.buffer_n
         self.buffer_sa += pi_visit_sa / self.buffer_n
     elif self.wscheme == 'buffer10':
         pi_visit_sa = q_iteration_py.compute_visitation(
             self.env,
             self.current_q,
             ent_wt=self.ent_wt,
             discount=self.discount,
             env_time_limit=self.time_limit)
         self.buffer10.append(pi_visit_sa)
         self.buffer10 = self.buffer10[-10:]
         self.buffer_sa = np.mean(self.buffer10, axis=0)
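The 'buffer_infinite' branch above maintains a running mean over every visitation matrix seen so far. A small self-contained check of that update rule, using made-up visitation vectors:

import numpy as np

# Running-mean update, as in the 'buffer_infinite' branch, on toy vectors.
buffer_sa, buffer_n, history = np.zeros(2), 0, []
for visit in [np.array([1.0, 0.0]), np.array([0.0, 1.0]), np.array([1.0, 1.0])]:
    buffer_n += 1
    buffer_sa *= (buffer_n - 1) / buffer_n
    buffer_sa += visit / buffer_n
    history.append(visit)
print(np.allclose(buffer_sa, np.mean(history, axis=0)))   # True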
Example 3
 def pre_project(self):
     super(WeightedSamplingFQI, self).pre_project()
     self.sample_visit_sa = q_iteration_py.compute_visitation(
         self.env,
         self.sampling_q,
         ent_wt=self.ent_wt,
         discount=self.discount,
         env_time_limit=self.time_limit)
     self.sa_weights = compute_sa_weights(self, self.wscheme,
                                          self.sample_visit_sa)
     self.validation_sa_weights = self.sa_weights * self.sample_visit_sa
Example 4
 def __init__(self, env, network, weighting_scheme='none', **kwargs):
     super(WeightedSamplingFQI, self).__init__(env, network, **kwargs)
     self.wscheme = weighting_scheme
     self.vfn = q_iteration_py.logsumexp(self.ground_truth_q,
                                         alpha=self.ent_wt)
     self.optimal_visit_sa = q_iteration_py.compute_visitation(
         self.env,
         self.ground_truth_q,
         ent_wt=self.ent_wt,
         discount=self.discount,
         env_time_limit=self.time_limit)
     self.warmstart_adversarial_q = np.zeros_like(self.ground_truth_q)
Example 5
    def __init__(self,
                 env,
                 network,
                 weighting_scheme='uniform',
                 weight_states_only=False,
                 time_limit=100,
                 **kwargs):
        super(WeightedExactFQI, self).__init__(env, network, **kwargs)
        assert self.backup_mode == fqi.BACKUP_EXACT
        self.wscheme = weighting_scheme
        self.time_limit = time_limit
        self.weight_states_only = weight_states_only

        self.visit_sa = q_iteration_py.compute_visitation(
            self.env,
            self.ground_truth_q,
            ent_wt=self.ent_wt,
            discount=self.discount,
            env_time_limit=self.time_limit)
        self.warmstart_adversarial_q = self.ground_truth_q[:, :]

        self.opt_proj_visit_sa = q_iteration_py.compute_visitation(
            self.env,
            self.ground_truth_q_proj,
            ent_wt=self.ent_wt,
            discount=self.discount,
            env_time_limit=self.time_limit)

        self.buffer_sa = np.zeros_like(self.ground_truth_q)
        self.buffer_n = 0

        self.buffer10 = []

        self.prev_q_target = np.zeros_like(self.ground_truth_q)
        self.prev_q_value = np.zeros_like(self.ground_truth_q)
        self.prev_loss = 0
        self.prev_weights = np.zeros_like(self.ground_truth_q)
Example 6
    def pre_project(self):
        super(WeightedBufferFQI, self).pre_project()
        self.sa_weights = sampling_fqi.compute_sa_weights(
            self, self.wscheme, self.replay_buffer.probs_sa())
        self.validation_sa_weights = self.sa_weights * self.replay_buffer.probs_sa()

        # validation loss
        self.buffer_validation_sa_weights = self.replay_buffer.probs_sa()
        sample_visit_sa = q_iteration.compute_visitation(
            self.env,
            self.sampling_q,
            ent_wt=self.ent_wt,
            discount=self.discount,
            env_time_limit=self.time_limit)
        self.onpolicy_validation_sa_weights = sample_visit_sa
Example 7
def compute_sa_weights(fqi, wscheme, sample_visit_sa):
    if wscheme == 'uniform':
        sa_weights = 1.0 / (sample_visit_sa + 1e-6)
    elif wscheme == 'pi*':
        sa_weights = fqi.optimal_visit_sa / (sample_visit_sa + 1e-6)
    elif wscheme == 'pi':
        current_visit_sa = q_iteration_py.compute_visitation(
            fqi.env,
            fqi.current_q,
            ent_wt=fqi.ent_wt,
            discount=fqi.discount,
            env_time_limit=fqi.time_limit)
        sa_weights = current_visit_sa / (sample_visit_sa + 1e-6)
    elif wscheme in [
            'robust_prioritized', 'robust_adversarial',
            'robust_adversarial_fast', 'none'
    ]:
        sa_weights = np.ones_like(fqi.current_q)
    else:
        raise ValueError('Unknown weighting scheme: %s' % wscheme)
    return sa_weights
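With the 'uniform' scheme, dividing by the sampling visitation makes every visited state-action pair contribute roughly equally once the weights multiply the samples; a quick illustration with assumed numbers:

import numpy as np

sample_visit_sa = np.array([[0.5, 0.3],
                            [0.15, 0.05]])        # hypothetical visitation
sa_weights = 1.0 / (sample_visit_sa + 1e-6)       # 'uniform' scheme
print(np.round(sa_weights * sample_visit_sa, 3))  # ~1 for every visited pair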
Example 8
def compute_weights(self, samples, itr=0, clip_min=1e-6, clip_max=100.0):
    if self.wscheme == 'robust_prioritized':
        q_vals = self.evaluate_qvalues(samples[0], samples[1]).detach()
        target_qs = self.evaluate_target(samples[0], samples[1], samples[2],
                                         samples[3]).detach()
        error = torch.abs(q_vals - target_qs)
        weights = ptu.to_numpy(error)
    elif self.wscheme == 'robust_adversarial':
        # solve for max_pi [bellman error]
        # compute bellman errors
        q_vals = ptu.to_numpy(
            self.evaluate_qvalues(np.arange(0, self.env.num_states),
                                  None,
                                  mode=fqi.MULTIPLE_HEADS))
        errors = np.abs(q_vals - self.all_target_q_np)
        # pick adversarial distribution - reward is bellman error
        adversarial_qs = q_iteration.softq_iteration_custom_reward(
            self.env,
            reward=errors,
            num_itrs=self.time_limit,
            discount=self.discount,
            ent_wt=self.ent_wt,
            warmstart_q=self.warmstart_adversarial_q,
            atol=1e-5)
        self.warmstart_adversarial_q = adversarial_qs
        visit_sa = q_iteration_py.compute_visitation(
            self.env,
            adversarial_qs,
            ent_wt=self.ent_wt,
            discount=self.discount,
            env_time_limit=self.time_limit)
        weights = visit_sa[samples[0], samples[1]]
    elif self.wscheme == 'robust_adversarial_fast':
        if itr % 10 == 0:
            # solve for max_pi [bellman error]
            # compute bellman errors
            q_vals = ptu.to_numpy(
                self.evaluate_qvalues(np.arange(0, self.env.num_states),
                                      None,
                                      mode=fqi.MULTIPLE_HEADS))
            errors = np.abs(q_vals - self.all_target_q_np)
            # pick adversarial distribution - reward is bellman error
            adversarial_qs = q_iteration.softq_iteration_custom_reward(
                self.env,
                reward=errors,
                num_itrs=self.time_limit,
                discount=self.discount,
                ent_wt=self.ent_wt,
                warmstart_q=self.warmstart_adversarial_q,
                atol=1e-5)
            self.warmstart_adversarial_q = adversarial_qs
            self.adv_visit_sa = q_iteration_py.compute_visitation(
                self.env,
                adversarial_qs,
                ent_wt=self.ent_wt,
                discount=self.discount,
                env_time_limit=self.time_limit)
        weights = self.adv_visit_sa[samples[0], samples[1]]
    else:
        weights = self.sa_weights[samples[0], samples[1]]
    weights = (weights / np.sum(weights))  # normalize
    weights = np.minimum(weights, clip_max)
    weights = np.maximum(weights, clip_min)
    weights = (weights / np.sum(weights))  # normalize
    return weights
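The normalize / clip / renormalize tail is worth noting: after the first normalization no weight exceeds 1, so with the default clip_max=100 only the clip_min floor can bind, keeping every sample's weight away from zero. A toy example with made-up raw weights:

import numpy as np

raw = np.array([1e-9, 1.0, 1e4])     # hypothetical unnormalized weights
w = raw / np.sum(raw)
w = np.minimum(w, 100.0)             # cannot bind after normalization
w = np.maximum(w, 1e-6)              # floors the near-zero weight
w = w / np.sum(w)
print(w)                             # ~[1e-6, 1e-4, 0.9999]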
Example 9
    def get_sample_states(self, itr=0):
        if itr % 5 == 0:  # compute weights
            weights = None
            if self.wscheme == 'uniform':
                weights = np.ones((self.env.num_states, self.env.num_actions))
            elif self.wscheme == 'buffer_infinite':
                weights = self.buffer_sa
            elif self.wscheme == 'buffer10':
                weights = self.buffer_sa
            elif self.wscheme == 'pi*':
                weights = self.visit_sa
            elif self.wscheme == 'pi*proj':
                assert self.log_proj_qstar
                weights = self.opt_proj_visit_sa
            elif self.wscheme == 'random':
                weights = self.pi_visit_sa
            elif self.wscheme == 'pi':
                weights = self.pi_visit_sa
            elif self.wscheme == 'online':
                q_vals = ptu.to_numpy(
                    self.evaluate_qvalues(np.arange(0, self.env.num_states),
                                          None))
                visit_sa = q_iteration_py.compute_visitation(
                    self.env,
                    q_vals,
                    ent_wt=self.ent_wt,
                    discount=self.discount,
                    env_time_limit=self.time_limit)
                weights = visit_sa
            elif self.wscheme == 'robust_prioritized':
                q_vals = ptu.to_numpy(
                    self.evaluate_qvalues(np.arange(0, self.env.num_states),
                                          None))
                errors = np.abs(q_vals - self.all_target_q_np)
                weights = errors
            elif self.wscheme == 'robust_adversarial':
                # solve for max_pi [bellman error]
                # compute bellman errors
                q_vals = ptu.to_numpy(
                    self.evaluate_qvalues(np.arange(0, self.env.num_states),
                                          None))
                errors = np.abs(q_vals - self.all_target_q_np)
                # pick adversarial distribution - reward is bellman error
                adversarial_qs = q_iteration.softq_iteration_custom_reward(
                    self.env,
                    reward=errors,
                    num_itrs=self.time_limit,
                    discount=self.discount,
                    ent_wt=self.ent_wt,
                    warmstart_q=self.warmstart_adversarial_q,
                    atol=1e-5)
                self.warmstart_adversarial_q = adversarial_qs
                visit_sa = q_iteration_py.compute_visitation(
                    self.env,
                    adversarial_qs,
                    ent_wt=self.ent_wt,
                    discount=self.discount,
                    env_time_limit=self.time_limit)
                weights = visit_sa
            else:
                raise ValueError("Unknown weighting scheme: %s" % self.wscheme)

            if self.weight_states_only:
                weights = np.sum(weights, axis=1)
                weights = np.repeat(weights[:, np.newaxis],
                                    self.env.num_actions,
                                    axis=-1)
            self.weights = (weights / np.sum(weights))  # normalize
        if itr == 0:
            entropy = -np.sum(self.weights * np.log(self.weights + 1e-6))
            logger.record_tabular('weight_entropy', entropy)
            unif = np.ones_like(self.weights) / float(self.weights.size)
            max_entropy = -np.sum(unif * np.log(unif))
            logger.record_tabular('weight_entropy_normalized',
                                  entropy / max_entropy)
        return np.arange(0,
                         self.env.num_states), None, None, None, self.weights
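When weight_states_only is set, the weights are marginalized over actions and the state marginal is broadcast back across the action axis, so every action in a state gets the same weight; a toy illustration with assumed numbers:

import numpy as np

weights = np.array([[0.7, 0.1],
                    [0.1, 0.1]])                  # hypothetical (S, A) weights
state_w = np.sum(weights, axis=1)                 # marginal over states
weights = np.repeat(state_w[:, np.newaxis], 2, axis=-1)
print(weights / np.sum(weights))                  # [[0.4, 0.4], [0.1, 0.1]]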