Code Example #1
def compute_exact_projection(target_q,
                             network,
                             states,
                             weights,
                             N=10000,
                             robust=True,
                             **optimizer_args):
    optimizer = torch.optim.Adam(network.parameters(), **optimizer_args)
    pt_target_q = ptu.tensor(target_q)
    for k in range(N):
        q_values = network(states)
        if robust:
            # reweight each state by its total absolute error across actions,
            # normalized so the weights have mean 1
            s_weights = torch.sum(torch.abs(q_values - pt_target_q),
                                  dim=1).detach()
            s_weights = s_weights / torch.sum(s_weights) * len(s_weights)
        else:
            s_weights = ptu.tensor(weights)
        critic_loss = torch.mean(s_weights * torch.mean(
            (q_values - pt_target_q)**2, dim=1))

        network.zero_grad()
        critic_loss.backward()
        optimizer.step()
        if k % 1000 == 0:
            logger.log('Itr %d exact projection loss: %f' %
                       (k, ptu.to_numpy(critic_loss)))
    proj_q = ptu.to_numpy(network(states))
    network.reset_weights()

    logger.log('Exact projection abs diff: %f' %
               (np.mean(np.abs(weighted_q_diff(target_q, proj_q, weights)))))
    return proj_q
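
The projection step above is just weighted regression of the network's outputs onto a fixed target Q-table. A minimal self-contained sketch of the same idea in plain PyTorch, with a throwaway embedding network (all names and sizes below are illustrative assumptions, not part of the original code):

import torch
import torch.nn as nn

# Toy problem sizes, assumed for illustration only.
num_states, num_actions = 16, 4
states = torch.arange(num_states)
target_q = torch.randn(num_states, num_actions)  # stand-in for a real target Q
weights = torch.ones(num_states)                 # uniform per-state weights

net = nn.Sequential(nn.Embedding(num_states, 32), nn.ReLU(),
                    nn.Linear(32, num_actions))
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

for k in range(2000):
    q = net(states)
    # Weighted squared error: average over actions per state, then over states.
    loss = torch.mean(weights * torch.mean((q - target_q) ** 2, dim=1))
    opt.zero_grad()
    loss.backward()
    opt.step()

proj_q = net(states).detach()  # Proj(target_q) under this function class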
Code Example #2
    def update(self, q_network=None, fqi=None, all_target_q=None, **kwargs):
        # Stopping-mode hook: called once per projection step with the kwargs
        # assembled in project() (Code Example #6); keeps the Q-values from
        # the step with the lowest validation loss.
        # compute oracle loss
        q_values = q_network(fqi.all_states).detach()
        weights = ptu.tensor(fqi.validation_sa_weights)
        oracle_loss = ptu.to_numpy(
            torch.sum(weights * ((q_values - all_target_q)**2)))
        prev_oracle_loss = self.validation_loss

        self.validation_k_counter += 1
        if oracle_loss < prev_oracle_loss:
            self.best_validation_qs = q_values
            self.validation_loss = oracle_loss
            self.validation_k = self.validation_k_counter - 1
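
Stripped of the project-specific helpers, the best-so-far bookkeeping in this stopping mode reduces to a pattern like the following sketch (class and attribute names here are hypothetical):

class BestValidationTracker:
    # Remember the Q-values from the step with the lowest validation loss
    # (hypothetical standalone version of the logic above).
    def __init__(self):
        self.best_loss = float('inf')
        self.best_qs = None
        self.best_step = -1
        self._step = 0

    def update(self, q_values, validation_loss):
        if validation_loss < self.best_loss:
            self.best_loss = validation_loss
            self.best_qs = q_values
            self.best_step = self._step
        self._step += 1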
Code Example #3
    def pre_project(self):
        if self.sampling_policy == 'adversarial':
            q_vals = ptu.to_numpy(
                self.evaluate_qvalues(np.arange(0, self.env.num_states),
                                      None,
                                      mode=fqi.MULTIPLE_HEADS))
            errors = np.abs(q_vals - self.all_target_q_np)**0.5
            # pick adversarial distribution - reward is bellman error
            adversarial_qs = q_iteration.softq_iteration_custom_reward(
                self.env,
                reward=errors,
                num_itrs=self.time_limit,
                discount=self.discount,
                ent_wt=self.ent_wt,
                atol=1e-5)
            self.adversarial_qs = adversarial_qs
        self.batch_s, self.batch_a, self.batch_ns, self.batch_r = (
            self.collect_samples())
        self._total_samples += len(self.batch_s)
        logger.record_tabular('total_samples', self._total_samples)
Code Example #4
    def __init__(self,
                 env,
                 network,
                 min_project_steps=5,
                 max_project_steps=50,
                 lr=1e-3,
                 discount=0.99,
                 n_steps=1,
                 ent_wt=1.0,
                 stop_modes=tuple(),
                 backup_mode=BACKUP_EXACT,
                 n_eval_trials=50,
                 log_proj_qstar=False,
                 target_mode='tq',
                 optimizer='adam',
                 smooth_target_tau=1.0,
                 **kwargs):
        self.env = env
        self.network = network
        self.discount = discount
        self.ent_wt = ent_wt
        self.n_eval_trials = n_eval_trials
        self.lr = lr
        self.max_q = 1e5

        self.target_mode = target_mode

        self.backup_mode = backup_mode
        if backup_mode == BACKUP_EXACT:
            self.q_format = MULTIPLE_HEADS
        else:
            self.q_format = FLAT

        self.min_project_steps = min_project_steps
        self.max_project_steps = max_project_steps
        self.stop_modes = stop_modes
        self.n_steps = n_steps

        if optimizer == 'adam':
            self.qnet_optimizer = torch.optim.Adam(network.parameters(), lr=lr)
        elif optimizer == 'gd':
            self.qnet_optimizer = torch.optim.SGD(network.parameters(), lr=lr)
        else:
            raise ValueError("Unknown optimizer: %s" % optimizer)
        self.all_states = ptu.tensor(np.arange(env.num_states),
                                     dtype=torch.int64)

        with log_utils.timer('ground_truth_q'):
            self.ground_truth_q = q_iteration_cy.softq_iteration(
                self.env,
                num_itrs=max(self.env.num_states * 2, 1000),
                discount=self.discount,
                ent_wt=self.ent_wt)
            self.ground_truth_q_torch = ptu.tensor(self.ground_truth_q)
        # Binary state mask: 1 for states whose ground-truth Q row has a
        # nonzero sum, 0 otherwise (used to weight evaluation metrics).
        self.valid_weights = np.sum(self.ground_truth_q, axis=1)
        self.valid_weights[self.valid_weights != 0] = 1.0
        self.current_q = self.ground_truth_q
        returns = self.eval_policy(render=False,
                                   n_rollouts=self.n_eval_trials * 5)
        self.expert_returns = returns
        self.current_q = np.zeros_like(self.ground_truth_q)
        returns = self.eval_policy(render=False,
                                   n_rollouts=self.n_eval_trials * 5)
        self.random_returns = returns
        self.normalize_returns = lambda x: (x - self.random_returns) / (
            self.expert_returns - self.random_returns)

        self.current_q = ptu.to_numpy(self.network(self.all_states))
        # compute Proj(Q*)
        self.log_proj_qstar = log_proj_qstar
        self.ground_truth_q_proj = np.zeros_like(self.ground_truth_q)
        if log_proj_qstar:
            with log_utils.timer('proj_qstar'):
                self.ground_truth_q_proj = compute_exact_projection(
                    self.ground_truth_q,
                    network,
                    self.all_states,
                    weights=self.valid_weights,
                    lr=lr)
            diff = weighted_q_diff(self.ground_truth_q,
                                   self.ground_truth_q_proj,
                                   self.valid_weights)
            self.qstar_abs_diff = np.abs(diff)

        self.smooth_target_tau = smooth_target_tau
        self.smooth_previous_target = np.zeros_like(self.ground_truth_q)
Code Example #5
    def update(self, step=-1):
        start_time = time.time()
        # backup
        with log_utils.timer('compute_backup'):
            self.all_target_q_np = q_iteration_cy.softq_iteration(
                self.env,
                num_itrs=self.n_steps,
                warmstart_q=self.current_q,
                discount=self.discount,
                ent_wt=self.ent_wt)
            # smooth
            if self.smooth_target_tau < 1.0:
                self.all_target_q_np = self.smooth_target_tau * self.all_target_q_np + (
                    1 - self.smooth_target_tau) * self.current_q
            self.all_target_q = ptu.tensor(self.all_target_q_np)

        # project
        with log_utils.timer('pre_project'):
            self.pre_project()

        stopped_mode, critic_loss, k = self.project()

        if isinstance(stopped_mode, stopping.ValidationLoss):
            self.current_q = ptu.to_numpy(stopped_mode.best_validation_qs)
            logger.record_tabular('validation_stop_step',
                                  stopped_mode.validation_k)
        else:
            self.current_q = ptu.to_numpy(self.network(self.all_states))
        self.current_q = np.minimum(self.current_q,
                                    self.max_q)  # clip when diverging
        self.post_project()
        with log_utils.timer('eval_policy'):
            returns = self.eval_policy()

        logger.record_tabular('project_loss', ptu.to_numpy(critic_loss))
        logger.record_tabular('fit_steps', k)
        if step >= 0:
            logger.record_tabular('step', step)

        # Logging
        logger.record_tabular('fit_q_value_mean', np.mean(self.current_q))
        logger.record_tabular('target_q_value_mean',
                              np.mean(self.all_target_q_np))
        logger.record_tabular('returns_expert', self.expert_returns)
        logger.record_tabular('returns_random', self.random_returns)
        logger.record_tabular('returns', returns)
        log_utils.record_tabular_moving('returns', returns, n=50)
        logger.record_tabular('returns_normalized',
                              self.normalize_returns(returns))
        log_utils.record_tabular_moving('returns_normalized',
                                        self.normalize_returns(returns),
                                        n=50)

        # measure contraction errors
        diff_tq_qstar = weighted_q_diff(self.all_target_q_np,
                                        self.ground_truth_q,
                                        self.valid_weights)
        abs_diff_tq_qstar = np.abs(diff_tq_qstar)
        log_utils.record_tabular_stats('tq_q*_diff', diff_tq_qstar)
        log_utils.record_tabular_stats('tq_q*_diff_abs', abs_diff_tq_qstar)

        if self.log_proj_qstar:
            diff = weighted_q_diff(self.current_q, self.ground_truth_q_proj,
                                   self.valid_weights)
            abs_diff = np.abs(diff)
            log_utils.record_tabular_stats('q*_proj_diff', diff)
            log_utils.record_tabular_stats('q*_proj_diff_abs', abs_diff)
            log_utils.record_tabular_stats('ground_truth_error',
                                           self.qstar_abs_diff)

        logger.record_tabular('iteration_time', time.time() - start_time)

        logger.dump_tabular()
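
The target smoothing in the backup step above is a simple convex combination of the fresh backup and the current estimate; a quick numeric sketch (made-up numbers):

import numpy as np

tau = 0.5                               # assumed smooth_target_tau
current_q = np.array([[1.0, 2.0]])
backup_q = np.array([[3.0, 6.0]])       # stand-in for the soft Q-iteration backup
target_q = tau * backup_q + (1.0 - tau) * current_q
# target_q == [[2.0, 4.0]]: halfway between the current Q and the backup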
Code Example #6
    def project(self, network=None, optimizer=None, sampler=None):
        if network is None:
            network = self.network
        if optimizer is None:
            optimizer = self.qnet_optimizer
        if sampler is None:
            sampler = self.get_sample_states

        k = 0
        stopped_mode = None
        for stop_mode in self.stop_modes:
            stop_mode.reset()
        with log_utils.timer('projection') as timer:
            for k in range(self.max_project_steps):
                with timer.subtimer('compute_samples_weights'):
                    sample_s, sample_a, sample_ns, sample_r, weights = sampler(
                        itr=k)
                sample_s, sample_a = ptu.all_tensor([sample_s, sample_a],
                                                    dtype=torch.int64)
                weights, = ptu.all_tensor([weights])

                with timer.subtimer('eval_target'):
                    target_q = self.evaluate_target(sample_s, sample_a,
                                                    sample_ns,
                                                    sample_r).detach()
                with timer.subtimer('eval_q'):
                    q_values = self.evaluate_qvalues(sample_s,
                                                     sample_a,
                                                     network=network)

                if self.q_format == MULTIPLE_HEADS:
                    if len(weights.shape) == 2:
                        critic_loss = torch.mean(weights *
                                                 (q_values - target_q)**2)
                    else:
                        critic_loss = torch.mean(weights * torch.mean(
                            (q_values - target_q)**2, dim=1))
                else:
                    critic_loss = torch.mean(weights *
                                             (q_values - target_q)**2)

                with timer.subtimer('backprop'):
                    network.zero_grad()
                    critic_loss.backward()
                    optimizer.step()

                stop_args = dict(critic_loss=ptu.to_numpy(critic_loss),
                                 q_network=network,
                                 all_target_q=self.all_target_q,
                                 fqi=self,
                                 discount=self.discount,
                                 ent_wt=self.ent_wt)
                for stop_mode in self.stop_modes:
                    stop_mode.update(**stop_args)
                if k >= self.min_project_steps:
                    stopped = False
                    for stop_mode in self.stop_modes:
                        if stop_mode.check():
                            logger.log('Early stopping via %s.' % stop_mode)
                            stopped = True
                            stopped_mode = stop_mode
                            break
                    if stopped:
                        break
        return stopped_mode, critic_loss, k
Code Example #7
def compute_weights(self, samples, itr=0, clip_min=1e-6, clip_max=100.0):
    if self.wscheme == 'robust_prioritized':
        q_vals = self.evaluate_qvalues(samples[0], samples[1]).detach()
        target_qs = self.evaluate_target(samples[0], samples[1], samples[2],
                                         samples[3]).detach()
        error = torch.abs(q_vals - target_qs)
        weights = ptu.to_numpy(error)
    elif self.wscheme == 'robust_adversarial':
        # solve for max_pi [bellman error]
        # compute bellman errors
        q_vals = ptu.to_numpy(
            self.evaluate_qvalues(np.arange(0, self.env.num_states),
                                  None,
                                  mode=fqi.MULTIPLE_HEADS))
        errors = np.abs(q_vals - self.all_target_q_np)
        # pick adversarial distribution - reward is bellman error
        adversarial_qs = q_iteration.softq_iteration_custom_reward(
            self.env,
            reward=errors,
            num_itrs=self.time_limit,
            discount=self.discount,
            ent_wt=self.ent_wt,
            warmstart_q=self.warmstart_adversarial_q,
            atol=1e-5)
        self.warmstart_adversarial_q = adversarial_qs
        visit_sa = q_iteration_py.compute_visitation(
            self.env,
            adversarial_qs,
            ent_wt=self.ent_wt,
            discount=self.discount,
            env_time_limit=self.time_limit)
        weights = visit_sa[samples[0], samples[1]]
    elif self.wscheme == 'robust_adversarial_fast':
        if itr % 10 == 0:
            # solve for max_pi [bellman error]
            # compute bellman errors
            q_vals = ptu.to_numpy(
                self.evaluate_qvalues(np.arange(0, self.env.num_states),
                                      None,
                                      mode=fqi.MULTIPLE_HEADS))
            errors = np.abs(q_vals - self.all_target_q_np)
            # pick adversarial distribution - reward is bellman error
            adversarial_qs = q_iteration.softq_iteration_custom_reward(
                self.env,
                reward=errors,
                num_itrs=self.time_limit,
                discount=self.discount,
                ent_wt=self.ent_wt,
                warmstart_q=self.warmstart_adversarial_q,
                atol=1e-5)
            self.warmstart_adversarial_q = adversarial_qs
            self.adv_visit_sa = q_iteration_py.compute_visitation(
                self.env,
                adversarial_qs,
                ent_wt=self.ent_wt,
                discount=self.discount,
                env_time_limit=self.time_limit)
        weights = self.adv_visit_sa[samples[0], samples[1]]
    else:
        weights = self.sa_weights[samples[0], samples[1]]
    weights = (weights / np.sum(weights))  # normalize
    weights = np.minimum(weights, clip_max)
    weights = np.maximum(weights, clip_min)
    weights = (weights / np.sum(weights))  # normalize
    return weights
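
The closing normalize, clip, renormalize sequence keeps every sample's probability strictly positive; on a toy weight vector (illustrative numbers only):

import numpy as np

weights = np.array([1e-9, 0.5, 0.5])     # one sample with near-zero weight
clip_min, clip_max = 1e-6, 100.0

weights = weights / np.sum(weights)      # normalize
weights = np.minimum(weights, clip_max)  # cap very large weights
weights = np.maximum(weights, clip_min)  # floor tiny weights at clip_min
weights = weights / np.sum(weights)      # renormalize after clipping
# the near-zero entry ends up at roughly 1e-6 instead of vanishing entirely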
Code Example #8
    def get_sample_states(self, itr=0):
        if itr % 5 == 0:  # compute weights
            weights = None
            if self.wscheme == 'uniform':
                weights = np.ones((self.env.num_states, self.env.num_actions))
            elif self.wscheme == 'buffer_infinite':
                weights = self.buffer_sa
            elif self.wscheme == 'buffer10':
                weights = self.buffer_sa
            elif self.wscheme == 'pi*':
                weights = self.visit_sa
            elif self.wscheme == 'pi*proj':
                assert self.log_proj_qstar
                weights = self.opt_proj_visit_sa
            elif self.wscheme == 'random':
                weights = self.pi_visit_sa
            elif self.wscheme == 'pi':
                weights = self.pi_visit_sa
            elif self.wscheme == 'online':
                q_vals = ptu.to_numpy(
                    self.evaluate_qvalues(np.arange(0, self.env.num_states),
                                          None))
                visit_sa = q_iteration_py.compute_visitation(
                    self.env,
                    q_vals,
                    ent_wt=self.ent_wt,
                    discount=self.discount,
                    env_time_limit=self.time_limit)
                weights = visit_sa
            elif self.wscheme == 'robust_prioritized':
                q_vals = ptu.to_numpy(
                    self.evaluate_qvalues(np.arange(0, self.env.num_states),
                                          None))
                errors = np.abs(q_vals - self.all_target_q_np)
                weights = errors
            elif self.wscheme == 'robust_adversarial':
                # solve for max_pi [bellman error]
                # compute bellman errors
                q_vals = ptu.to_numpy(
                    self.evaluate_qvalues(np.arange(0, self.env.num_states),
                                          None))
                errors = np.abs(q_vals - self.all_target_q_np)
                # pick adversarial distribution - reward is bellman error
                adversarial_qs = q_iteration.softq_iteration_custom_reward(
                    self.env,
                    reward=errors,
                    num_itrs=self.time_limit,
                    discount=self.discount,
                    ent_wt=self.ent_wt,
                    warmstart_q=self.warmstart_adversarial_q,
                    atol=1e-5)
                self.warmstart_adversarial_q = adversarial_qs
                visit_sa = q_iteration_py.compute_visitation(
                    self.env,
                    adversarial_qs,
                    ent_wt=self.ent_wt,
                    discount=self.discount,
                    env_time_limit=self.time_limit)
                weights = visit_sa
            else:
                raise ValueError("Unknown weighting scheme: %s" % self.wscheme)

            if self.weight_states_only:
                weights = np.sum(weights, axis=1)
                weights = np.repeat(weights[:, np.newaxis],
                                    self.env.num_actions,
                                    axis=-1)
            self.weights = (weights / np.sum(weights))  # normalize
        if itr == 0:
            entropy = -np.sum(self.weights * np.log(self.weights + 1e-6))
            logger.record_tabular('weight_entropy', entropy)
            unif = np.ones_like(self.weights) / float(self.weights.size)
            max_entropy = -np.sum(unif * np.log(unif))
            logger.record_tabular('weight_entropy_normalized',
                                  entropy / max_entropy)
        return (np.arange(0, self.env.num_states), None, None, None,
                self.weights)
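
The entropy logging above can be reproduced in isolation; the normalized value approaches 1.0 for uniform weights and drops as the sampling distribution concentrates (numbers below are illustrative):

import numpy as np

weights = np.array([[0.7, 0.1], [0.1, 0.1]])   # assumed (num_states x num_actions) weights
entropy = -np.sum(weights * np.log(weights + 1e-6))

unif = np.ones_like(weights) / float(weights.size)
max_entropy = -np.sum(unif * np.log(unif))     # = log(number of state-action pairs)

normalized_entropy = entropy / max_entropy     # 1.0 would mean a uniform distribution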