Example #1
    def _update_vf_once_recurrent(self, episodes):

        # Sort episodes desc by length for pack_sequence
        episodes = sorted(episodes, key=len, reverse=True)

        flat_transitions = flatten_sequences_time_first(episodes)

        # Prepare data for a recurrent model
        seqs_states = []
        for ep in episodes:
            states = self.batch_states(
                [transition["state"] for transition in ep], self.device,
                self.phi)
            if self.obs_normalizer:
                states = self.obs_normalizer(states, update=False)
            seqs_states.append(states)

        flat_vs_teacher = torch.as_tensor(
            [[transition["v_teacher"]] for transition in flat_transitions],
            device=self.device,
            dtype=torch.float,
        )

        with torch.no_grad():
            vf_rs = concatenate_recurrent_states(
                _collect_first_recurrent_states_of_vf(episodes))

        flat_vs_pred, _ = pack_and_forward(self.vf, seqs_states, vf_rs)

        vf_loss = F.mse_loss(flat_vs_pred, flat_vs_teacher)
        self.vf.zero_grad()
        vf_loss.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.vf.parameters(), self.max_grad_norm)
        self.vf_optimizer.step()
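
The value-function update above is a plain supervised regression onto the precomputed "v_teacher" targets. Below is a minimal, self-contained sketch of the same pattern in plain PyTorch (the linear value net, optimizer, and tensor shapes are illustrative stand-ins, not pfrl's actual objects): regress predictions onto fixed targets with MSE, clip the global L2 gradient norm, then step the optimizer.

import torch
import torch.nn.functional as F

value_net = torch.nn.Linear(4, 1)                      # stand-in for self.vf
optimizer = torch.optim.Adam(value_net.parameters(), lr=3e-4)
max_grad_norm = 0.5

states = torch.randn(16, 4)                            # flattened transitions (toy data)
flat_vs_teacher = torch.randn(16, 1)                   # e.g. lambda-return targets

flat_vs_pred = value_net(states)
vf_loss = F.mse_loss(flat_vs_pred, flat_vs_teacher)
value_net.zero_grad()
vf_loss.backward()
torch.nn.utils.clip_grad_norm_(value_net.parameters(), max_grad_norm)
optimizer.step()
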
Example #2
    def _update_recurrent(self, dataset):
        """Update both the policy and the value function."""

        flat_dataset = flatten_sequences_time_first(dataset)
        if self.obs_normalizer:
            self._update_obs_normalizer(flat_dataset)

        self._update_policy_recurrent(dataset)
        self._update_vf_recurrent(dataset)
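
Before the policy and value function are updated, self._update_obs_normalizer(flat_dataset) refreshes the running observation statistics. The sketch below (plain PyTorch, not pfrl's actual normalizer API) shows what such a normalizer typically does: maintain a running mean and variance over observations and whiten inputs with them, passing update=False when the statistics should be left untouched, as in the snippets here.

import torch

class RunningObsNormalizer:
    """Toy running mean/std normalizer (illustrative, not pfrl's implementation)."""

    def __init__(self, size, eps=1e-2):
        self.mean = torch.zeros(size)
        self.var = torch.ones(size)
        self.count = eps

    def __call__(self, x, update=True):
        if update:
            # Parallel (Chan et al.) update of the running mean and variance
            n = x.shape[0]
            batch_mean = x.mean(dim=0)
            batch_var = x.var(dim=0, unbiased=False)
            delta = batch_mean - self.mean
            total = self.count + n
            self.mean = self.mean + delta * n / total
            self.var = (self.var * self.count + batch_var * n
                        + delta ** 2 * self.count * n / total) / total
            self.count = total
        return (x - self.mean) / torch.sqrt(self.var + 1e-8)

normalizer = RunningObsNormalizer(4)
normalizer(torch.randn(32, 4), update=True)                # update statistics
normalized = normalizer(torch.randn(8, 4), update=False)   # normalize only
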
Example #3
File: ppo.py  Project: imatge-upc/PiCoEDL
def _add_log_prob_and_value_to_episodes_recurrent(
    episodes,
    model,
    phi,
    batch_states,
    obs_normalizer,
    device,
):
    # Sort desc by lengths so that pack_sequence does not change the order
    episodes = sorted(episodes, key=len, reverse=True)

    # Prepare data for a recurrent model
    seqs_states = []
    seqs_next_states = []
    for ep in episodes:
        states = batch_states([transition["state"] for transition in ep],
                              device, phi)
        next_states = batch_states(
            [transition["next_state"] for transition in ep], device, phi)
        if obs_normalizer:
            states = obs_normalizer(states, update=False)
            next_states = obs_normalizer(next_states, update=False)
        seqs_states.append(states)
        seqs_next_states.append(next_states)

    flat_transitions = flatten_sequences_time_first(episodes)

    # Predict values using a recurrent model
    with torch.no_grad(), pfrl.utils.evaluating(model):
        rs = concatenate_recurrent_states(
            [ep[0]["recurrent_state"] for ep in episodes])
        next_rs = concatenate_recurrent_states(
            [ep[0]["next_recurrent_state"] for ep in episodes])
        assert (rs is None) or (next_rs is None) or (len(rs) == len(next_rs))

        (flat_distribs, flat_vs), _ = pack_and_forward(model, seqs_states, rs)
        (_, flat_next_vs), _ = pack_and_forward(model, seqs_next_states,
                                                next_rs)

        flat_actions = torch.tensor([b["action"] for b in flat_transitions],
                                    device=device)
        flat_log_probs = flat_distribs.log_prob(flat_actions).cpu().numpy()
        flat_vs = flat_vs.cpu().numpy()
        flat_next_vs = flat_next_vs.cpu().numpy()

    # Add predicted values to transitions
    for transition, log_prob, v, next_v in zip(flat_transitions,
                                               flat_log_probs, flat_vs,
                                               flat_next_vs):
        transition["log_prob"] = float(log_prob)
        transition["v_pred"] = float(v)
        transition["next_v_pred"] = float(next_v)
Example #4
    def _update_if_dataset_is_ready(self):
        dataset_size = (
            sum(len(episode) for episode in self.memory)
            + len(self.last_episode)
            + (
                0
                if self.batch_last_episode is None
                else sum(len(episode) for episode in self.batch_last_episode)
            )
        )
        if dataset_size >= self.update_interval:
            self._flush_last_episode()
            if self.recurrent:
                dataset = _make_dataset_recurrent(
                    episodes=self.memory,
                    model=self.model,
                    phi=self.phi,
                    batch_states=self.batch_states,
                    obs_normalizer=self.obs_normalizer,
                    gamma=self.gamma,
                    lambd=self.lambd,
                    max_recurrent_sequence_len=self.max_recurrent_sequence_len,
                    device=self.device,
                )
                self._update_recurrent(dataset)
            else:
                dataset = _make_dataset(
                    episodes=self.memory,
                    model=self.model,
                    phi=self.phi,
                    batch_states=self.batch_states,
                    obs_normalizer=self.obs_normalizer,
                    gamma=self.gamma,
                    lambd=self.lambd,
                    device=self.device,
                )
                assert len(dataset) == dataset_size
                self._update(dataset)
            self.explained_variance = _compute_explained_variance(
                flatten_sequences_time_first(self.memory)
            )
            self.memory = []
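
After each update the agent records self.explained_variance over the flattened memory. As a rough guide, here is a sketch of the conventional explained-variance metric that a helper like _compute_explained_variance typically implements (this is an assumption about its definition, not verified project source; the "v_pred"/"v_teacher" keys match those used elsewhere in these examples):

import numpy as np

def explained_variance(flat_transitions):
    """EV = 1 - Var(v_teacher - v_pred) / Var(v_teacher); 1.0 means a perfect fit."""
    y_pred = np.asarray([t["v_pred"] for t in flat_transitions])
    y_true = np.asarray([t["v_teacher"] for t in flat_transitions])
    var_true = np.var(y_true)
    return float("nan") if var_true == 0 else 1.0 - np.var(y_true - y_pred) / var_true
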
Example #5
    def _line_search(self, full_step, dataset, advs, action_distrib_old, gain):
        """Do line search for a safe step size."""
        policy_params = list(self.policy.parameters())
        policy_params_sizes = [param.numel() for param in policy_params]
        policy_params_shapes = [param.shape for param in policy_params]
        step_size = 1.0
        flat_params = _flatten_and_concat_variables(policy_params).detach()

        if self.recurrent:
            seqs_states = []
            for ep in dataset:
                states = self.batch_states(
                    [transition["state"] for transition in ep], self.device, self.phi
                )
                if self.obs_normalizer:
                    states = self.obs_normalizer(states, update=False)
                seqs_states.append(states)
            with torch.no_grad(), pfrl.utils.evaluating(self.model):
                policy_rs = concatenate_recurrent_states(
                    _collect_first_recurrent_states_of_policy(dataset)
                )

            def evaluate_current_policy():
                distrib, _ = pack_and_forward(self.policy, seqs_states, policy_rs)
                return distrib

        else:
            states = self.batch_states(
                [transition["state"] for transition in dataset], self.device, self.phi
            )
            if self.obs_normalizer:
                states = self.obs_normalizer(states, update=False)

            def evaluate_current_policy():
                return self.policy(states)

        flat_transitions = (
            flatten_sequences_time_first(dataset) if self.recurrent else dataset
        )
        actions = torch.tensor(
            [transition["action"] for transition in flat_transitions],
            device=self.device,
        )
        log_prob_old = torch.tensor(
            [transition["log_prob"] for transition in flat_transitions],
            device=self.device,
            dtype=torch.float,
        )

        for i in range(self.line_search_max_backtrack + 1):
            self.logger.info("Line search iteration: %s step size: %s", i, step_size)
            new_flat_params = flat_params + step_size * full_step
            new_params = _split_and_reshape_to_ndarrays(
                new_flat_params,
                sizes=policy_params_sizes,
                shapes=policy_params_shapes,
            )
            _replace_params_data(policy_params, new_params)
            with torch.no_grad(), pfrl.utils.evaluating(self.policy):
                new_action_distrib = evaluate_current_policy()
                new_gain = self._compute_gain(
                    log_prob=new_action_distrib.log_prob(actions),
                    log_prob_old=log_prob_old,
                    entropy=new_action_distrib.entropy(),
                    advs=advs,
                )
                new_kl = torch.mean(
                    torch.distributions.kl_divergence(
                        action_distrib_old, new_action_distrib
                    )
                )

            improve = float(new_gain) - float(gain)
            self.logger.info("Surrogate objective improve: %s", improve)
            self.logger.info("KL divergence: %s", float(new_kl))
            if not torch.isfinite(new_gain):
                self.logger.info("Surrogate objective is not finite. Bakctracking...")
            elif not torch.isfinite(new_kl):
                self.logger.info("KL divergence is not finite. Bakctracking...")
            elif improve < 0:
                self.logger.info("Surrogate objective didn't improve. Bakctracking...")
            elif float(new_kl) > self.max_kl:
                self.logger.info("KL divergence exceeds max_kl. Bakctracking...")
            else:
                self.kl_record.append(float(new_kl))
                self.policy_step_size_record.append(step_size)
                break
            step_size *= 0.5
        else:
            self.logger.info(
                "Line search coundn't find a good step size. The policy was not"
                " updated."
            )
            self.policy_step_size_record.append(0.0)
            _replace_params_data(
                policy_params,
                _split_and_reshape_to_ndarrays(
                    flat_params, sizes=policy_params_sizes, shapes=policy_params_shapes
                ),
            )
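
Stripped of the recurrent-versus-flat plumbing, the loop above is a standard backtracking line search: start from the full natural-gradient step, halve the step size until a candidate both improves the surrogate objective and keeps the KL divergence within max_kl, and restore the old parameters if no candidate qualifies. A minimal generic sketch (hypothetical helper, not pfrl's API):

import torch

def backtracking_line_search(theta, full_step, surrogate, kl, max_kl,
                             max_backtrack=10):
    """theta: flat parameter vector; surrogate/kl: callables on a flat vector."""
    base_gain = surrogate(theta)
    step_size = 1.0
    for _ in range(max_backtrack + 1):
        candidate = theta + step_size * full_step
        gain, div = surrogate(candidate), kl(candidate)
        if (torch.isfinite(gain) and torch.isfinite(div)
                and gain > base_gain and div <= max_kl):
            return candidate, step_size      # accepted step
        step_size *= 0.5
    return theta, 0.0                        # no acceptable step; keep old parameters
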
Example #6
    def _update_policy_recurrent(self, dataset):
        """Update the policy using a given dataset.

        The policy is updated via CG and line search.
        """

        # Sort episodes desc by length for pack_sequence
        dataset = sorted(dataset, key=len, reverse=True)

        flat_transitions = flatten_sequences_time_first(dataset)

        # Prepare data for a recurrent model
        seqs_states = []
        for ep in dataset:
            states = self.batch_states(
                [transition["state"] for transition in ep],
                self.device,
                self.phi,
            )
            if self.obs_normalizer:
                states = self.obs_normalizer(states, update=False)
            seqs_states.append(states)

        flat_actions = torch.as_tensor(
            [transition["action"] for transition in flat_transitions],
            device=self.device,
        )
        flat_advs = torch.as_tensor(
            [transition["adv"] for transition in flat_transitions],
            device=self.device,
            dtype=torch.float,
        )

        if self.standardize_advantages:
            std_advs, mean_advs = torch.std_mean(flat_advs, unbiased=False)
            flat_advs = (flat_advs - mean_advs) / (std_advs + 1e-8)

        with torch.no_grad():
            policy_rs = concatenate_recurrent_states(
                _collect_first_recurrent_states_of_policy(dataset)
            )

        flat_distribs, _ = pack_and_forward(self.policy, seqs_states, policy_rs)

        log_prob_old = torch.tensor(
            [transition["log_prob"] for transition in flat_transitions],
            device=self.device,
            dtype=torch.float,
        )

        gain = self._compute_gain(
            log_prob=flat_distribs.log_prob(flat_actions),
            log_prob_old=log_prob_old,
            entropy=flat_distribs.entropy(),
            advs=flat_advs,
        )

        # Distribution to compute KL div against
        with torch.no_grad():
            # torch.distributions.Distribution cannot be deepcopied
            action_distrib_old, _ = pack_and_forward(
                self.policy, seqs_states, policy_rs
            )

        full_step = self._compute_kl_constrained_step(
            action_distrib=flat_distribs,
            action_distrib_old=action_distrib_old,
            gain=gain,
        )

        self._line_search(
            full_step=full_step,
            dataset=dataset,
            advs=flat_advs,
            action_distrib_old=action_distrib_old,
            gain=gain,
        )
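
self._compute_gain evaluates the surrogate objective that both the KL-constrained step and the line search try to increase. A sketch of the standard TRPO surrogate it is expected to correspond to (an assumption about its exact form; entropy_coef is a hypothetical bonus coefficient):

import torch

def surrogate_gain(log_prob, log_prob_old, entropy, advs, entropy_coef=0.0):
    # E[ exp(log pi(a|s) - log pi_old(a|s)) * A ] plus an optional entropy bonus
    prob_ratio = torch.exp(log_prob - log_prob_old)
    return torch.mean(prob_ratio * advs) + entropy_coef * torch.mean(entropy)
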
Example #7
File: replay_buffer.py  Project: pfnet/pfrl
def batch_recurrent_experiences(experiences,
                                device,
                                phi,
                                gamma,
                                batch_states=batch_states):
    """Batch experiences for recurrent model updates.

    Args:
        experiences: list of episodes. Each episode is a list
            containing between 1 and n dicts, each containing:
              - state (object): State
              - action (object): Action
              - reward (float): Reward
              - is_state_terminal (bool): True iff next state is terminal
              - next_state (object): Next state
            The list must be sorted in descending order of length so that
            it can later be packed by `torch.nn.utils.rnn.pack_sequence`.
        device: GPU or CPU device on which the tensors should be placed
        phi: Preprocessing function
        gamma: Discount factor
        batch_states: Function that converts a list of states to a batch
    Returns:
        dict of batched transitions
    """
    assert _is_sorted_desc_by_lengths(experiences)
    flat_transitions = flatten_sequences_time_first(experiences)
    batch_exp = {
        "state": [
            batch_states([transition["state"]
                          for transition in ep], device, phi)
            for ep in experiences
        ],
        "action":
        torch.as_tensor(
            [transition["action"] for transition in flat_transitions],
            device=device),
        "reward":
        torch.as_tensor(
            [transition["reward"] for transition in flat_transitions],
            dtype=torch.float,
            device=device,
        ),
        "next_state": [
            batch_states([transition["next_state"]
                          for transition in ep], device, phi)
            for ep in experiences
        ],
        "is_state_terminal":
        torch.as_tensor(
            [
                transition["is_state_terminal"]
                for transition in flat_transitions
            ],
            dtype=torch.float,
            device=device,
        ),
        "discount":
        torch.full((len(flat_transitions), ),
                   gamma,
                   dtype=torch.float,
                   device=device),
        "recurrent_state":
        recurrent_state_from_numpy(
            concatenate_recurrent_states(
                [ep[0]["recurrent_state"] for ep in experiences]),
            device,
        ),
        "next_recurrent_state":
        recurrent_state_from_numpy(
            concatenate_recurrent_states(
                [ep[0]["next_recurrent_state"] for ep in experiences]),
            device,
        ),
    }
    # Batch next actions only when all the transitions have them
    if all(transition["next_action"] is not None
           for transition in flat_transitions):
        batch_exp["next_action"] = torch.as_tensor(
            [transition["next_action"] for transition in flat_transitions],
            device=device,
        )
    return batch_exp
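
The leading assertion encodes the precondition from the docstring: episodes must already be sorted by descending length, because torch.nn.utils.rnn.pack_sequence (with its default enforce_sorted=True) rejects unsorted input when the batched states are packed later. A minimal illustration with toy data:

import torch
from torch.nn.utils.rnn import pack_sequence

episodes = [[{"state": 0}] * 2, [{"state": 1}] * 5, [{"state": 2}] * 3]
episodes = sorted(episodes, key=len, reverse=True)            # lengths: 5, 3, 2
assert all(len(a) >= len(b) for a, b in zip(episodes, episodes[1:]))

seqs_states = [torch.randn(len(ep), 4) for ep in episodes]    # toy state batches
packed = pack_sequence(seqs_states)                           # ok: sorted desc
# pack_sequence(list(reversed(seqs_states))) would raise a RuntimeError
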
Example #8
File: ppo.py  Project: imatge-upc/PiCoEDL
    def _update_once_recurrent(self, episodes, mean_advs, std_advs):

        assert std_advs is None or std_advs > 0

        device = self.device

        # Sort desc by lengths so that pack_sequence does not change the order
        episodes = sorted(episodes, key=len, reverse=True)

        flat_transitions = flatten_sequences_time_first(episodes)

        # Prepare data for a recurrent model
        seqs_states = []
        for ep in episodes:
            states = self.batch_states(
                [transition["state"] for transition in ep],
                self.device,
                self.phi,
            )
            if self.obs_normalizer:
                states = self.obs_normalizer(states, update=False)
            seqs_states.append(states)

        flat_actions = torch.tensor(
            [transition["action"] for transition in flat_transitions],
            device=device,
        )
        flat_advs = torch.tensor(
            [transition["adv"] for transition in flat_transitions],
            dtype=torch.float,
            device=device,
        )
        if self.standardize_advantages:
            flat_advs = (flat_advs - mean_advs) / (std_advs + 1e-8)
        flat_log_probs_old = torch.tensor(
            [transition["log_prob"] for transition in flat_transitions],
            dtype=torch.float,
            device=device,
        )
        flat_vs_pred_old = torch.tensor(
            [[transition["v_pred"]] for transition in flat_transitions],
            dtype=torch.float,
            device=device,
        )
        flat_vs_teacher = torch.tensor(
            [[transition["v_teacher"]] for transition in flat_transitions],
            dtype=torch.float,
            device=device,
        )

        with torch.no_grad(), pfrl.utils.evaluating(self.model):
            rs = concatenate_recurrent_states(
                [ep[0]["recurrent_state"] for ep in episodes])

        (flat_distribs,
         flat_vs_pred), _ = pack_and_forward(self.model, seqs_states, rs)
        flat_log_probs = flat_distribs.log_prob(flat_actions)
        flat_entropy = flat_distribs.entropy()

        self.model.zero_grad()
        loss = self._lossfun(
            entropy=flat_entropy,
            vs_pred=flat_vs_pred,
            log_probs=flat_log_probs,
            vs_pred_old=flat_vs_pred_old,
            log_probs_old=flat_log_probs_old,
            advs=flat_advs,
            vs_teacher=flat_vs_teacher,
        )
        loss.backward()
        if self.max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           self.max_grad_norm)
        self.optimizer.step()
        self.n_updates += 1
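
self._lossfun combines the usual PPO terms computed from the tensors prepared above. A sketch of the standard clipped objective it is expected to implement (an assumption; clip_eps, value_func_coef, and entropy_coef are the usual PPO hyperparameters, not necessarily this project's exact attribute names, and value-function clipping is omitted):

import torch
import torch.nn.functional as F

def ppo_loss(log_probs, log_probs_old, advs, vs_pred, vs_teacher, entropy,
             clip_eps=0.2, value_func_coef=1.0, entropy_coef=0.01):
    ratio = torch.exp(log_probs - log_probs_old)
    policy_loss = -torch.mean(
        torch.min(ratio * advs,
                  torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advs))
    value_loss = F.mse_loss(vs_pred, vs_teacher)
    entropy_loss = -torch.mean(entropy)
    return policy_loss + value_func_coef * value_loss + entropy_coef * entropy_loss
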