Example #1
    def _compute_target_values(self, exp_batch):

        batch_next_state = exp_batch["next_state"]

        with evaluating(self.model):
            if self.recurrent:
                next_qout, _ = pack_and_forward(
                    self.model,
                    batch_next_state,
                    exp_batch["next_recurrent_state"],
                )
            else:
                next_qout = self.model(batch_next_state)

        if self.recurrent:
            target_next_qout, _ = pack_and_forward(
                self.target_model,
                batch_next_state,
                exp_batch["next_recurrent_state"],
            )
        else:
            target_next_qout = self.target_model(batch_next_state)

        next_q_max = target_next_qout.evaluate_actions(
            next_qout.greedy_actions)

        batch_rewards = exp_batch["reward"]
        batch_terminal = exp_batch["is_state_terminal"]
        discount = exp_batch["discount"]

        return batch_rewards + discount * (1.0 - batch_terminal) * next_q_max
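Every example on this page forwards variable-length sequences through a recurrent model with pack_and_forward. As a rough orientation, here is a minimal sketch of what such a helper could look like, built on torch.nn.utils.rnn.pack_sequence; this is an illustration, not PFRL's actual implementation, and it assumes the model accepts a PackedSequence plus a recurrent state and returns a PackedSequence plus the updated state.

import torch
from torch.nn.utils.rnn import pack_sequence

def pack_and_forward_sketch(model, sequences, recurrent_state):
    """Pack variable-length sequences (sorted by length, descending),
    run the recurrent model once over the whole batch, and return the
    flat per-step outputs together with the updated recurrent state."""
    packed = pack_sequence(sequences)              # requires descending lengths
    packed_out, new_state = model(packed, recurrent_state)
    return packed_out.data, new_state              # .data is flattened, time-major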
Example #2
    def _compute_y_and_t(self, exp_batch):

        batch_state = exp_batch["state"]
        batch_size = len(exp_batch["reward"])

        if self.recurrent:
            qout, _ = pack_and_forward(
                self.model,
                batch_state,
                exp_batch["recurrent_state"],
            )
        else:
            qout = self.model(batch_state)

        batch_actions = exp_batch["action"]
        batch_q = qout.evaluate_actions(batch_actions)

        # Compute target values
        batch_next_state = exp_batch["next_state"]
        with torch.no_grad():
            if self.recurrent:
                target_qout, _ = pack_and_forward(
                    self.target_model,
                    batch_state,
                    exp_batch["recurrent_state"],
                )
                target_next_qout, _ = pack_and_forward(
                    self.target_model,
                    batch_next_state,
                    exp_batch["next_recurrent_state"],
                )
            else:
                target_qout = self.target_model(batch_state)
                target_next_qout = self.target_model(batch_next_state)

            next_q_max = torch.reshape(target_next_qout.max, (batch_size, ))

            batch_rewards = exp_batch["reward"]
            batch_terminal = exp_batch["is_state_terminal"]

            # T Q: Bellman operator
            t_q = (batch_rewards + exp_batch["discount"] *
                   (1.0 - batch_terminal) * next_q_max)

            # T_PAL Q: persistent advantage learning operator
            cur_advantage = torch.reshape(
                target_qout.compute_advantage(batch_actions), (batch_size, ))
            next_advantage = torch.reshape(
                target_next_qout.compute_advantage(batch_actions),
                (batch_size, ))
            tpal_q = t_q + self.alpha * torch.max(cur_advantage,
                                                  next_advantage)

        return batch_q, tpal_q
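For intuition, the PAL target computed above can be evaluated on made-up tensors; the values and alpha below are arbitrary, only the formula matches the code.

import torch

rewards = torch.tensor([1.0, 0.0])
terminal = torch.tensor([0.0, 1.0])
discount = torch.tensor([0.99, 0.99])
next_q_max = torch.tensor([2.0, 3.0])
cur_advantage = torch.tensor([-0.5, -0.2])
next_advantage = torch.tensor([-0.3, -0.4])
alpha = 0.9

t_q = rewards + discount * (1.0 - terminal) * next_q_max
tpal_q = t_q + alpha * torch.max(cur_advantage, next_advantage)
print(tpal_q)  # tensor([ 2.7100, -0.1800])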
Example #3
def _add_log_prob_and_value_to_episodes_recurrent(
    episodes,
    model,
    phi,
    batch_states,
    obs_normalizer,
    device,
):
    # Sort desc by lengths so that pack_sequence does not change the order
    episodes = sorted(episodes, key=len, reverse=True)

    # Prepare data for a recurrent model
    seqs_states = []
    seqs_next_states = []
    for ep in episodes:
        states = batch_states([transition["state"] for transition in ep],
                              device, phi)
        next_states = batch_states(
            [transition["next_state"] for transition in ep], device, phi)
        if obs_normalizer:
            states = obs_normalizer(states, update=False)
            next_states = obs_normalizer(next_states, update=False)
        seqs_states.append(states)
        seqs_next_states.append(next_states)

    flat_transitions = flatten_sequences_time_first(episodes)

    # Predict values using a recurrent model
    with torch.no_grad(), pfrl.utils.evaluating(model):
        rs = concatenate_recurrent_states(
            [ep[0]["recurrent_state"] for ep in episodes])
        next_rs = concatenate_recurrent_states(
            [ep[0]["next_recurrent_state"] for ep in episodes])
        assert (rs is None) or (next_rs is None) or (len(rs) == len(next_rs))

        (flat_distribs, flat_vs), _ = pack_and_forward(model, seqs_states, rs)
        (_, flat_next_vs), _ = pack_and_forward(model, seqs_next_states,
                                                next_rs)

        flat_actions = torch.tensor([b["action"] for b in flat_transitions],
                                    device=device)
        flat_log_probs = flat_distribs.log_prob(flat_actions).cpu().numpy()
        flat_vs = flat_vs.cpu().numpy()
        flat_next_vs = flat_next_vs.cpu().numpy()

    # Add predicted values to transitions
    for transition, log_prob, v, next_v in zip(flat_transitions,
                                               flat_log_probs, flat_vs,
                                               flat_next_vs):
        transition["log_prob"] = float(log_prob)
        transition["v_pred"] = float(v)
        transition["next_v_pred"] = float(next_v)
Example #4
    def _compute_y_and_t(self, exp_batch):
        """Compute a batch of predicted/target return distributions."""

        batch_size = exp_batch["reward"].shape[0]

        # Compute Q-values for current states
        batch_state = exp_batch["state"]

        # (batch_size, n_actions, n_atoms)
        if self.recurrent:
            qout, _ = pack_and_forward(self.model, batch_state,
                                       exp_batch["recurrent_state"])
        else:
            qout = self.model(batch_state)
        n_atoms = qout.z_values.size()[0]

        batch_actions = exp_batch["action"]
        batch_q = qout.evaluate_actions_as_distribution(batch_actions)
        assert batch_q.shape == (batch_size, n_atoms)

        with torch.no_grad():
            batch_q_target = self._compute_target_values(exp_batch)
            assert batch_q_target.shape == (batch_size, n_atoms)

        batch_q_scalars = qout.evaluate_actions(batch_actions)
        self.q_record.extend(batch_q_scalars.detach().cpu().numpy().ravel())

        return batch_q, batch_q_target
Example #5
    def _update_vf_once_recurrent(self, episodes):

        # Sort episodes desc by length for pack_sequence
        episodes = sorted(episodes, key=len, reverse=True)

        flat_transitions = flatten_sequences_time_first(episodes)

        # Prepare data for a recurrent model
        seqs_states = []
        for ep in episodes:
            states = self.batch_states(
                [transition["state"] for transition in ep], self.device,
                self.phi)
            if self.obs_normalizer:
                states = self.obs_normalizer(states, update=False)
            seqs_states.append(states)

        flat_vs_teacher = torch.as_tensor(
            [[transition["v_teacher"]] for transition in flat_transitions],
            device=self.device,
            dtype=torch.float,
        )

        with torch.no_grad():
            vf_rs = concatenate_recurrent_states(
                _collect_first_recurrent_states_of_vf(episodes))

        flat_vs_pred, _ = pack_and_forward(self.vf, seqs_states, vf_rs)

        vf_loss = F.mse_loss(flat_vs_pred, flat_vs_teacher)
        self.vf.zero_grad()
        vf_loss.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.vf.parameters(), self.max_grad_norm)
        self.vf_optimizer.step()
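The same update pattern (forward pass, MSE loss, zero_grad, backward, optional gradient clipping, optimizer step) can be reproduced on a dummy feed-forward value function; the names and shapes below are illustrative, not PFRL's classes.

import torch
import torch.nn as nn
import torch.nn.functional as F

vf = nn.Sequential(nn.Linear(4, 64), nn.Tanh(), nn.Linear(64, 1))
vf_optimizer = torch.optim.Adam(vf.parameters(), lr=3e-4)
max_grad_norm = 0.5

flat_states = torch.randn(32, 4)        # stands in for the flattened states
flat_vs_teacher = torch.randn(32, 1)    # stands in for the v_teacher targets

flat_vs_pred = vf(flat_states)
vf_loss = F.mse_loss(flat_vs_pred, flat_vs_teacher)
vf.zero_grad()
vf_loss.backward()
if max_grad_norm is not None:
    # torch's built-in clipper plays the role of clip_l2_grad_norm_ here
    torch.nn.utils.clip_grad_norm_(vf.parameters(), max_grad_norm)
vf_optimizer.step()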
Example #6
    def _compute_target_values(self, exp_batch):
        """Compute a batch of target return distributions."""

        batch_next_state = exp_batch["next_state"]
        if self.recurrent:
            target_next_qout, _ = pack_and_forward(
                self.target_model,
                batch_next_state,
                exp_batch["next_recurrent_state"],
            )
        else:
            target_next_qout = self.target_model(batch_next_state)

        batch_rewards = exp_batch["reward"]
        batch_terminal = exp_batch["is_state_terminal"]

        batch_size = exp_batch["reward"].shape[0]
        z_values = target_next_qout.z_values
        n_atoms = z_values.size()[0]

        # next_q_max: (batch_size, n_atoms)
        next_q_max = target_next_qout.max_as_distribution.detach()
        assert next_q_max.shape == (batch_size, n_atoms), next_q_max.shape

        # Tz: (batch_size, n_atoms)
        Tz = (batch_rewards[..., None] + (1.0 - batch_terminal[..., None]) *
              torch.unsqueeze(exp_batch["discount"], 1) * z_values[None])
        return _apply_categorical_projection(Tz, next_q_max, z_values)
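A toy shape check of the Tz broadcast above (batch_size=2, n_atoms=3), with made-up values: the per-transition scalars gain a trailing atom axis while z_values gains a leading batch axis.

import torch

batch_rewards = torch.tensor([1.0, 0.5])
batch_terminal = torch.tensor([0.0, 1.0])
discount = torch.tensor([0.99, 0.99])
z_values = torch.tensor([-1.0, 0.0, 1.0])

Tz = (batch_rewards[..., None] + (1.0 - batch_terminal[..., None]) *
      torch.unsqueeze(discount, 1) * z_values[None])
print(Tz.shape)  # torch.Size([2, 3])
print(Tz[1])     # terminal transition: every atom collapses to the reward, 0.5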
Example #7
    def _compute_y_and_taus(self, exp_batch):
        """Compute a batch of predicted return distributions.

        Returns:
            torch.Tensor: Predicted return distributions.
                (batch_size, N).
        """

        batch_size = exp_batch["reward"].shape[0]

        # Compute Q-values for current states
        batch_state = exp_batch["state"]

        # (batch_size, n_actions, n_atoms)
        if self.recurrent:
            tau2av, _ = pack_and_forward(
                self.model,
                batch_state,
                exp_batch["recurrent_state"],
            )
        else:
            tau2av = self.model(batch_state)
        taus = torch.rand(
            batch_size,
            self.quantile_thresholds_N,
            device=self.device,
            dtype=torch.float,
        )
        av = tau2av(taus)
        batch_actions = exp_batch["action"]
        y = av.evaluate_actions_as_quantiles(batch_actions)

        self.q_record.extend(av.q_values.detach().cpu().numpy().ravel())

        return y, taus
Example #8
    def _compute_y_and_t(self, exp_batch):

        batch_state = exp_batch["state"]
        batch_size = len(exp_batch["reward"])

        if self.recurrent:
            qout, _ = pack_and_forward(
                self.model,
                batch_state,
                exp_batch["recurrent_state"],
            )
        else:
            qout = self.model(batch_state)

        batch_actions = exp_batch["action"]
        # Q(s_t,a_t)
        batch_q = qout.evaluate_actions(batch_actions).reshape((batch_size, 1))

        with torch.no_grad():
            # Compute target values
            if self.recurrent:
                target_qout, _ = pack_and_forward(
                    self.target_model,
                    batch_state,
                    exp_batch["recurrent_state"],
                )
            else:
                target_qout = self.target_model(batch_state)

            # Q'(s_t,a_t)
            target_q = target_qout.evaluate_actions(batch_actions).reshape(
                (batch_size, 1))

            # LQ'(s_t,a)
            target_q_expect = self._l_operator(target_qout).reshape(
                (batch_size, 1))

            # r + g * LQ'(s_{t+1},a)
            batch_q_target = self._compute_target_values(exp_batch).reshape(
                (batch_size, 1))

            # Q'(s_t,a_t) + r + g * LQ'(s_{t+1},a) - LQ'(s_t,a)
            t = target_q + batch_q_target - target_q_expect

        return batch_q, t
Example #9
    def _compute_target_values(self, exp_batch):
        """Compute a batch of target return distributions."""

        batch_next_state = exp_batch["next_state"]
        batch_rewards = exp_batch["reward"]
        batch_terminal = exp_batch["is_state_terminal"]

        with pfrl.utils.evaluating(self.target_model), pfrl.utils.evaluating(
                self.model):
            if self.recurrent:
                target_next_qout, _ = pack_and_forward(
                    self.target_model,
                    batch_next_state,
                    exp_batch["next_recurrent_state"],
                )
                next_qout, _ = pack_and_forward(
                    self.model,
                    batch_next_state,
                    exp_batch["next_recurrent_state"],
                )
            else:
                target_next_qout = self.target_model(batch_next_state)
                next_qout = self.model(batch_next_state)

        batch_size = batch_rewards.shape[0]
        z_values = target_next_qout.z_values
        n_atoms = z_values.numel()

        # next_q_max: (batch_size, n_atoms)
        next_q_max = target_next_qout.evaluate_actions_as_distribution(
            next_qout.greedy_actions.detach())
        assert next_q_max.shape == (batch_size, n_atoms), next_q_max.shape

        # Tz: (batch_size, n_atoms)
        Tz = (batch_rewards[..., None] + (1.0 - batch_terminal[..., None]) *
              exp_batch["discount"][..., None] * z_values[None])
        # Tz = (
        #     batch_rewards.squeeze(dim=-1)
        #     + (1.0 - batch_terminal.unsqueeze(dim=-1))
        #     * exp_batch["discount"].unsqueeze(dim=-1)
        #     * z_values.unsqueeze(dim=0)
        # )

        return _apply_categorical_projection(Tz, next_q_max, z_values)
Example #10
    def _compute_target_values(self, exp_batch):
        """Compute a batch of target return distributions.

        Returns:
            torch.Tensor: (batch_size, N_prime).
        """
        batch_next_state = exp_batch["next_state"]
        batch_size = len(exp_batch["reward"])
        taus_tilde = torch.rand(
            batch_size,
            self.quantile_thresholds_K,
            device=self.device,
            dtype=torch.float,
        )

        if self.recurrent:
            target_next_tau2av, _ = pack_and_forward(
                self.target_model,
                batch_next_state,
                exp_batch["next_recurrent_state"],
            )
        else:
            target_next_tau2av = self.target_model(batch_next_state)
        greedy_actions = target_next_tau2av(taus_tilde).greedy_actions
        taus_prime = torch.rand(
            batch_size,
            self.quantile_thresholds_N_prime,
            device=self.device,
            dtype=torch.float,
        )
        target_next_maxz = target_next_tau2av(
            taus_prime).evaluate_actions_as_quantiles(greedy_actions)

        batch_rewards = exp_batch["reward"]
        batch_terminal = exp_batch["is_state_terminal"]
        batch_discount = exp_batch["discount"]
        assert batch_rewards.shape == (batch_size, )
        assert batch_terminal.shape == (batch_size, )
        assert batch_discount.shape == (batch_size, )
        batch_rewards = batch_rewards.unsqueeze(-1)
        batch_terminal = batch_terminal.unsqueeze(-1)
        batch_discount = batch_discount.unsqueeze(-1)

        return (batch_rewards + batch_discount *
                (1.0 - batch_terminal) * target_next_maxz)
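The unsqueeze calls at the end exist only so the per-transition scalars broadcast against the (batch_size, N_prime) quantile matrix; a toy check with made-up numbers:

import torch

batch_size, n_prime = 2, 4
rewards = torch.tensor([1.0, 0.0]).unsqueeze(-1)       # (2, 1)
terminal = torch.tensor([0.0, 1.0]).unsqueeze(-1)      # (2, 1)
discount = torch.tensor([0.99, 0.99]).unsqueeze(-1)    # (2, 1)
target_next_maxz = torch.ones(batch_size, n_prime)     # stands in for the target quantiles

target = rewards + discount * (1.0 - terminal) * target_next_maxz
print(target.shape)  # torch.Size([2, 4])
print(target)        # first row 1.99 everywhere, second row 0 (terminal)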
Example #11
    def _compute_target_values(self, exp_batch):

        batch_next_state = exp_batch["next_state"]

        if self.recurrent:
            target_next_qout, _ = pack_and_forward(
                self.target_model, batch_next_state, exp_batch["next_recurrent_state"],
            )
        else:
            target_next_qout = self.target_model(batch_next_state)
        next_q_expect = self._l_operator(target_next_qout)

        batch_rewards = exp_batch["reward"]
        batch_terminal = exp_batch["is_state_terminal"]

        return (
            batch_rewards + exp_batch["discount"] * (1 - batch_terminal) * next_q_expect
        )
Example #12
    def _compute_target_values(self, exp_batch):
        batch_next_state = exp_batch["next_state"]

        if self.recurrent:
            target_next_qout, _ = pack_and_forward(
                self.target_model,
                batch_next_state,
                exp_batch["next_recurrent_state"],
            )
        else:
            target_next_qout = self.target_model(batch_next_state)
        next_q_max = target_next_qout.max

        batch_terminal = exp_batch["is_state_terminal"]
        discount = exp_batch["discount"]
        batch_rewards = exp_batch["reward"]

        return batch_rewards + discount * (1.0 - batch_terminal) * next_q_max
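With concrete numbers, the returned expression is the familiar one-step TD target r + gamma * (1 - done) * max_a Q'(s', a); the values below are made up for illustration.

import torch

batch_rewards = torch.tensor([1.0, 0.0, -1.0])
batch_terminal = torch.tensor([0.0, 1.0, 0.0])
discount = torch.tensor([0.99, 0.99, 0.99])
next_q_max = torch.tensor([2.0, 5.0, 0.5])

target = batch_rewards + discount * (1.0 - batch_terminal) * next_q_max
print(target)  # tensor([ 2.9800,  0.0000, -0.5050])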
Example #13
    def _compute_y_and_t(self, exp_batch):
        batch_size = exp_batch["reward"].shape[0]

        # Compute Q-values for current states
        batch_state = exp_batch["state"]

        if self.recurrent:
            qout, _ = pack_and_forward(self.model, batch_state,
                                       exp_batch["recurrent_state"])
        else:
            qout = self.model(batch_state)

        batch_actions = exp_batch["action"]
        batch_q = torch.reshape(qout.evaluate_actions(batch_actions),
                                (batch_size, 1))

        with torch.no_grad():
            batch_q_target = torch.reshape(
                self._compute_target_values(exp_batch), (batch_size, 1))

        return batch_q, batch_q_target
Example #14
    def evaluate_current_policy():
        distrib, _ = pack_and_forward(self.policy, seqs_states, policy_rs)
        return distrib
Example #15
    def _update_policy_recurrent(self, dataset):
        """Update the policy using a given dataset.

        The policy is updated via CG and line search.
        """

        # Sort episodes desc by length for pack_sequence
        dataset = sorted(dataset, key=len, reverse=True)

        flat_transitions = flatten_sequences_time_first(dataset)

        # Prepare data for a recurrent model
        seqs_states = []
        for ep in dataset:
            states = self.batch_states(
                [transition["state"] for transition in ep],
                self.device,
                self.phi,
            )
            if self.obs_normalizer:
                states = self.obs_normalizer(states, update=False)
            seqs_states.append(states)

        flat_actions = torch.as_tensor(
            [transition["action"] for transition in flat_transitions],
            device=self.device,
        )
        flat_advs = torch.as_tensor(
            [transition["adv"] for transition in flat_transitions],
            device=self.device,
            dtype=torch.float,
        )

        if self.standardize_advantages:
            std_advs, mean_advs = torch.std_mean(flat_advs, unbiased=False)
            flat_advs = (flat_advs - mean_advs) / (std_advs + 1e-8)

        with torch.no_grad():
            policy_rs = concatenate_recurrent_states(
                _collect_first_recurrent_states_of_policy(dataset)
            )

        flat_distribs, _ = pack_and_forward(self.policy, seqs_states, policy_rs)

        log_prob_old = torch.tensor(
            [transition["log_prob"] for transition in flat_transitions],
            device=self.device,
            dtype=torch.float,
        )

        gain = self._compute_gain(
            log_prob=flat_distribs.log_prob(flat_actions),
            log_prob_old=log_prob_old,
            entropy=flat_distribs.entropy(),
            advs=flat_advs,
        )

        # Distribution to compute KL div against
        with torch.no_grad():
            # torch.distributions.Distribution cannot be deepcopied
            action_distrib_old, _ = pack_and_forward(
                self.policy, seqs_states, policy_rs
            )

        full_step = self._compute_kl_constrained_step(
            action_distrib=flat_distribs,
            action_distrib_old=action_distrib_old,
            gain=gain,
        )

        self._line_search(
            full_step=full_step,
            dataset=dataset,
            advs=flat_advs,
            action_distrib_old=action_distrib_old,
            gain=gain,
        )
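The advantage-standardization step above can be checked in isolation on dummy data; after the transform the batch has (approximately) zero mean and unit standard deviation.

import torch

flat_advs = torch.randn(128)
std_advs, mean_advs = torch.std_mean(flat_advs, unbiased=False)
flat_advs = (flat_advs - mean_advs) / (std_advs + 1e-8)
print(float(flat_advs.mean()), float(flat_advs.std(unbiased=False)))  # ~0.0, ~1.0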
Example #16
    def update(self, statevar):
        assert self.t_start < self.t

        n = self.t - self.t_start

        self.assert_shared_memory()

        if statevar is None:
            R = 0
        else:
            with torch.no_grad(), pfrl.utils.evaluating(self.model):
                if self.recurrent:
                    (_, vout), _ = one_step_forward(
                        self.model, statevar, self.train_recurrent_states
                    )
                else:
                    _, vout = self.model(statevar)
            R = float(vout)

        pi_loss_factor = self.pi_loss_coef
        v_loss_factor = self.v_loss_coef

        # Normalize the loss of sequences truncated by terminal states
        if self.keep_loss_scale_same and self.t - self.t_start < self.t_max:
            factor = self.t_max / (self.t - self.t_start)
            pi_loss_factor *= factor
            v_loss_factor *= factor

        if self.normalize_grad_by_t_max:
            pi_loss_factor /= self.t - self.t_start
            v_loss_factor /= self.t - self.t_start

        # Batch re-compute for efficient backprop
        batch_obs = self.batch_states(
            [self.past_obs[i] for i in range(self.t_start, self.t)],
            self.device,
            self.phi,
        )
        if self.recurrent:
            (batch_distrib, batch_v), _ = pack_and_forward(
                self.model,
                [batch_obs],
                self.past_recurrent_state[self.t_start],
            )
        else:
            batch_distrib, batch_v = self.model(batch_obs)
        batch_action = torch.stack(
            [self.past_action[i] for i in range(self.t_start, self.t)])
        batch_log_prob = batch_distrib.log_prob(batch_action)
        batch_entropy = batch_distrib.entropy()
        rev_returns = []
        for i in reversed(range(self.t_start, self.t)):
            R *= self.gamma
            R += self.past_rewards[i]
            rev_returns.append(R)
        batch_return = torch.as_tensor(list(reversed(rev_returns)),
                                       dtype=torch.float)
        batch_adv = batch_return - batch_v.detach().squeeze(-1)
        assert batch_log_prob.shape == (n, )
        assert batch_adv.shape == (n, )
        assert batch_entropy.shape == (n, )
        pi_loss = torch.sum(-batch_adv * batch_log_prob -
                            self.beta * batch_entropy,
                            dim=0)
        assert batch_v.shape == (n, 1)
        assert batch_return.shape == (n, )
        v_loss = F.mse_loss(batch_v, batch_return[..., None],
                            reduction="sum") / 2

        if pi_loss_factor != 1.0:
            pi_loss *= pi_loss_factor

        if v_loss_factor != 1.0:
            v_loss *= v_loss_factor

        if self.process_idx == 0:
            logger.debug("pi_loss:%s v_loss:%s", pi_loss, v_loss)

        total_loss = torch.squeeze(pi_loss) + torch.squeeze(v_loss)

        # Compute gradients using thread-specific model
        self.model.zero_grad()
        total_loss.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.model.parameters(), self.max_grad_norm)
        # Copy the gradients to the globally shared model
        copy_param.copy_grad(target_link=self.shared_model,
                             source_link=self.model)
        # Update the globally shared model
        self.optimizer.step()
        if self.process_idx == 0:
            logger.debug("update")

        self.sync_parameters()

        self.past_obs = {}
        self.past_action = {}
        self.past_rewards = {}
        self.past_recurrent_state = {}

        self.t_start = self.t
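The reversed loop above computes discounted returns via R <- r_t + gamma * R, starting from the bootstrap value; a tiny standalone version with made-up rewards:

gamma = 0.9
past_rewards = {0: 1.0, 1: 0.0, 2: 2.0}   # keyed by time step, as in the code above
R = 0.0                                   # bootstrap value (0 when the episode ended)
rev_returns = []
for i in reversed(range(0, 3)):
    R *= gamma
    R += past_rewards[i]
    rev_returns.append(R)
print([round(r, 4) for r in reversed(rev_returns)])  # [2.62, 1.8, 2.0]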
Example #17
    def _update_once_recurrent(self, episodes, mean_advs, std_advs):

        assert std_advs is None or std_advs > 0

        device = self.device

        # Sort desc by lengths so that pack_sequence does not change the order
        episodes = sorted(episodes, key=len, reverse=True)

        flat_transitions = flatten_sequences_time_first(episodes)

        # Prepare data for a recurrent model
        seqs_states = []
        for ep in episodes:
            states = self.batch_states(
                [transition["state"] for transition in ep],
                self.device,
                self.phi,
            )
            if self.obs_normalizer:
                states = self.obs_normalizer(states, update=False)
            seqs_states.append(states)

        flat_actions = torch.tensor(
            [transition["action"] for transition in flat_transitions],
            device=device,
        )
        flat_advs = torch.tensor(
            [transition["adv"] for transition in flat_transitions],
            dtype=torch.float,
            device=device,
        )
        if self.standardize_advantages:
            flat_advs = (flat_advs - mean_advs) / (std_advs + 1e-8)
        flat_log_probs_old = torch.tensor(
            [transition["log_prob"] for transition in flat_transitions],
            dtype=torch.float,
            device=device,
        )
        flat_vs_pred_old = torch.tensor(
            [[transition["v_pred"]] for transition in flat_transitions],
            dtype=torch.float,
            device=device,
        )
        flat_vs_teacher = torch.tensor(
            [[transition["v_teacher"]] for transition in flat_transitions],
            dtype=torch.float,
            device=device,
        )

        with torch.no_grad(), pfrl.utils.evaluating(self.model):
            rs = concatenate_recurrent_states(
                [ep[0]["recurrent_state"] for ep in episodes])

        (flat_distribs, flat_vs_pred), _ = pack_and_forward(
            self.model, seqs_states, rs
        )
        flat_log_probs = flat_distribs.log_prob(flat_actions)
        flat_entropy = flat_distribs.entropy()

        self.model.zero_grad()
        loss = self._lossfun(
            entropy=flat_entropy,
            vs_pred=flat_vs_pred,
            log_probs=flat_log_probs,
            vs_pred_old=flat_vs_pred_old,
            log_probs_old=flat_log_probs_old,
            advs=flat_advs,
            vs_teacher=flat_vs_teacher,
        )
        loss.backward()
        if self.max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           self.max_grad_norm)
        self.optimizer.step()
        self.n_updates += 1
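For context, self._lossfun above is the PPO objective; a hedged sketch of a clipped-surrogate loss of that kind is shown below. The function name, clip epsilon, and coefficients are illustrative, not PFRL's exact implementation.

import torch
import torch.nn.functional as F

def ppo_loss_sketch(log_probs, log_probs_old, advs, vs_pred, vs_teacher,
                    entropy, clip_eps=0.2, value_coef=1.0, entropy_coef=0.01):
    # Clipped policy surrogate
    ratio = torch.exp(log_probs - log_probs_old)
    surr1 = ratio * advs
    surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advs
    policy_loss = -torch.mean(torch.min(surr1, surr2))
    # Value regression and entropy bonus
    value_loss = F.mse_loss(vs_pred, vs_teacher)
    return policy_loss + value_coef * value_loss - entropy_coef * entropy.mean()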