Example #1
def _add_log_prob_and_value_to_episodes(
    episodes,
    model,
    phi,
    batch_states,
    obs_normalizer,
    device,
):

    dataset = list(itertools.chain.from_iterable(episodes))

    # Compute v_pred and next_v_pred
    states = batch_states([b["state"] for b in dataset], device, phi)
    next_states = batch_states([b["next_state"] for b in dataset], device, phi)

    if obs_normalizer:
        states = obs_normalizer(states, update=False)
        next_states = obs_normalizer(next_states, update=False)

    with torch.no_grad(), pfrl.utils.evaluating(model):
        distribs, vs_pred = model(states)
        _, next_vs_pred = model(next_states)

        actions = torch.tensor([b["action"] for b in dataset], device=device)
        log_probs = distribs.log_prob(actions).cpu().numpy()
        vs_pred = vs_pred.cpu().numpy().ravel()
        next_vs_pred = next_vs_pred.cpu().numpy().ravel()

    for transition, log_prob, v_pred, next_v_pred in zip(
            dataset, log_probs, vs_pred, next_vs_pred):
        transition["log_prob"] = log_prob
        transition["v_pred"] = v_pred
        transition["next_v_pred"] = next_v_pred
Example #2
def batch_experiences(experiences, device, phi, gamma, batch_states=batch_states):
    """Takes a batch of k experiences each of which contains j

    consecutive transitions and vectorizes them, where j is between 1 and n.

    Args:
        experiences: list of experiences. Each experience is a list
            containing between 1 and n dicts, each containing:
              - state (object): State
              - action (object): Action
              - reward (float): Reward
              - is_state_terminal (bool): True iff next state is terminal
              - next_state (object): Next state
        device: GPU or CPU device the tensors should be placed on
        phi: Preprocessing function
        gamma: Discount factor
        batch_states: Function that converts a list of states to a batch
    Returns:
        dict of batched transitions
    """

    batch_exp = {
        "state": batch_states([elem[0]["state"] for elem in experiences], device, phi),
        "action": torch.as_tensor(
            [elem[0]["action"] for elem in experiences], device=device
        ),
        "reward": torch.as_tensor(
            [
                sum((gamma ** i) * exp[i]["reward"] for i in range(len(exp)))
                for exp in experiences
            ],
            dtype=torch.float32,
            device=device,
        ),
        "next_state": batch_states(
            [elem[-1]["next_state"] for elem in experiences], device, phi
        ),
        "is_state_terminal": torch.as_tensor(
            [
                any(transition["is_state_terminal"] for transition in exp)
                for exp in experiences
            ],
            dtype=torch.float32,
            device=device,
        ),
        "discount": torch.as_tensor(
            [(gamma ** len(elem)) for elem in experiences],
            dtype=torch.float32,
            device=device,
        ),
    }
    if all(elem[-1]["next_action"] is not None for elem in experiences):
        batch_exp["next_action"] = torch.as_tensor(
            [elem[-1]["next_action"] for elem in experiences], device=device
        )
    return batch_exp
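A hedged usage sketch with a hypothetical two-step experience, showing how per-experience rewards are discounted and summed. It assumes the function above is defined with `pfrl.utils.batch_states.batch_states` already imported (it is used as a default argument), and note that every transition needs a `next_action` key, possibly `None`:

import torch
from pfrl.utils.batch_states import batch_states

phi = lambda s: torch.as_tensor(s, dtype=torch.float32)
experiences = [[
    {"state": [0.0, 0.0], "action": 0, "reward": 1.0, "is_state_terminal": False,
     "next_state": [0.1, 0.0], "next_action": None},
    {"state": [0.1, 0.0], "action": 1, "reward": 2.0, "is_state_terminal": False,
     "next_state": [0.2, 0.0], "next_action": None},
]]
batch = batch_experiences(experiences, torch.device("cpu"), phi, gamma=0.9)
print(batch["reward"])    # 1.0 + 0.9 * 2.0 = 2.8
print(batch["discount"])  # 0.9 ** 2 = 0.81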
Example #3
def _add_log_prob_and_value_to_episodes_recurrent(
    episodes,
    model,
    phi,
    batch_states,
    obs_normalizer,
    device,
):
    # Sort desc by lengths so that pack_sequence does not change the order
    episodes = sorted(episodes, key=len, reverse=True)

    # Prepare data for a recurrent model
    seqs_states = []
    seqs_next_states = []
    for ep in episodes:
        states = batch_states([transition["state"] for transition in ep],
                              device, phi)
        next_states = batch_states(
            [transition["next_state"] for transition in ep], device, phi)
        if obs_normalizer:
            states = obs_normalizer(states, update=False)
            next_states = obs_normalizer(next_states, update=False)
        seqs_states.append(states)
        seqs_next_states.append(next_states)

    flat_transitions = flatten_sequences_time_first(episodes)

    # Predict values using a recurrent model
    with torch.no_grad(), pfrl.utils.evaluating(model):
        rs = concatenate_recurrent_states(
            [ep[0]["recurrent_state"] for ep in episodes])
        next_rs = concatenate_recurrent_states(
            [ep[0]["next_recurrent_state"] for ep in episodes])
        assert (rs is None) or (next_rs is None) or (len(rs) == len(next_rs))

        (flat_distribs, flat_vs), _ = pack_and_forward(model, seqs_states, rs)
        (_, flat_next_vs), _ = pack_and_forward(model, seqs_next_states,
                                                next_rs)

        flat_actions = torch.tensor([b["action"] for b in flat_transitions],
                                    device=device)
        flat_log_probs = flat_distribs.log_prob(flat_actions).cpu().numpy()
        flat_vs = flat_vs.cpu().numpy()
        flat_next_vs = flat_next_vs.cpu().numpy()

    # Add predicted values to transitions
    for transition, log_prob, v, next_v in zip(flat_transitions,
                                               flat_log_probs, flat_vs,
                                               flat_next_vs):
        transition["log_prob"] = float(log_prob)
        transition["v_pred"] = float(v)
        transition["next_v_pred"] = float(next_v)
Example #4
    def _update_vf(self, dataset):
        """Update the value function using a given dataset.

        The value function is updated via SGD to minimize TD(lambda) errors.
        """

        assert "state" in dataset[0]
        assert "v_teacher" in dataset[0]

        for batch in _yield_minibatches(
            dataset, minibatch_size=self.vf_batch_size, num_epochs=self.vf_epochs
        ):
            states = batch_states([b["state"] for b in batch], self.device, self.phi)
            if self.obs_normalizer:
                states = self.obs_normalizer(states, update=False)
            vs_teacher = torch.as_tensor(
                [b["v_teacher"] for b in batch],
                device=self.device,
                dtype=torch.float,
            )
            vs_pred = self.vf(states)
            vf_loss = F.mse_loss(vs_pred, vs_teacher[..., None])
            self.vf.zero_grad()
            vf_loss.backward()
            if self.max_grad_norm is not None:
                clip_l2_grad_norm_(self.vf.parameters(), self.max_grad_norm)
            self.vf_optimizer.step()
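`_yield_minibatches` is a helper not shown on this page. A plausible minimal sketch of what it does (shuffle the dataset and yield fixed-size minibatches for a number of epochs) follows; the real PFRL helper may differ in how it shuffles and handles leftover samples:

import random

def _yield_minibatches(dataset, minibatch_size, num_epochs):
    """Yield shuffled minibatches covering the dataset num_epochs times."""
    for _ in range(num_epochs):
        indices = list(range(len(dataset)))
        random.shuffle(indices)
        for start in range(0, len(indices), minibatch_size):
            yield [dataset[i] for i in indices[start:start + minibatch_size]]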
Example #5
    def _update_policy(self, dataset):
        """Update the policy using a given dataset.

        The policy is updated via CG and line search.
        """

        assert "state" in dataset[0]
        assert "action" in dataset[0]
        assert "adv" in dataset[0]

        # Use full-batch
        states = batch_states([b["state"] for b in dataset], self.device,
                              self.phi)
        if self.obs_normalizer:
            states = self.obs_normalizer(states, update=False)
        actions = torch.as_tensor([b["action"] for b in dataset],
                                  device=self.device)
        advs = torch.as_tensor([b["adv"] for b in dataset],
                               device=self.device,
                               dtype=torch.float)
        if self.standardize_advantages:
            std_advs, mean_advs = torch.std_mean(advs, unbiased=False)
            advs = (advs - mean_advs) / (std_advs + 1e-8)

        # Recompute action distributions for batch backprop
        action_distrib = self.policy(states)

        log_prob_old = torch.as_tensor(
            [transition["log_prob"] for transition in dataset],
            device=self.device,
            dtype=torch.float,
        )

        gain = self._compute_gain(
            log_prob=action_distrib.log_prob(actions),
            log_prob_old=log_prob_old,
            entropy=action_distrib.entropy(),
            advs=advs,
        )

        # Distribution to compute KL div against
        with torch.no_grad():
            # torch.distributions.Distribution cannot be deepcopied
            action_distrib_old = self.policy(states)

        full_step = self._compute_kl_constrained_step(
            action_distrib=action_distrib,
            action_distrib_old=action_distrib_old,
            gain=gain,
        )

        self._line_search(
            full_step=full_step,
            dataset=dataset,
            advs=advs,
            action_distrib_old=action_distrib_old,
            gain=gain,
        )
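`_compute_gain` is defined elsewhere in the agent. A hedged sketch of TRPO's usual surrogate objective with an entropy bonus is shown below; the `entropy_coef` argument is an assumption, not necessarily the agent's attribute name:

import torch

def _compute_gain(log_prob, log_prob_old, entropy, advs, entropy_coef=0.0):
    """Importance-weighted advantage (policy surrogate) plus entropy bonus."""
    prob_ratio = torch.exp(log_prob - log_prob_old)
    return torch.mean(prob_ratio * advs) + entropy_coef * torch.mean(entropy)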
Example #6
    def _act_eval(self, obs):
        # Use the process-local model for acting
        with torch.no_grad():
            statevar = batch_states([obs], self.device, self.phi)
            if self.recurrent:
                (
                    (action_distrib, _, _),
                    self.test_recurrent_states,
                ) = one_step_forward(
                    self.model, statevar, self.test_recurrent_states
                )
            else:
                action_distrib, _, _ = self.model(statevar)
            if self.act_deterministically:
                return mode_of_distribution(action_distrib).numpy()[0]
            else:
                return action_distrib.sample().numpy()[0]
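A small illustration of the deterministic-versus-stochastic choice above, using a toy `Categorical` policy output instead of the agent's real model; it assumes `mode_of_distribution` is importable from `pfrl.utils.mode_of_distribution`:

import torch
from pfrl.utils.mode_of_distribution import mode_of_distribution

action_distrib = torch.distributions.Categorical(logits=torch.tensor([[2.0, 0.5, -1.0]]))
greedy = mode_of_distribution(action_distrib).numpy()[0]  # argmax action, here 0
sampled = action_distrib.sample().numpy()[0]              # stochastic action
print(greedy, sampled)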
Example #7
def batch_trajectory(trajectory, device, phi, batch_states=batch_states):
    batch_tr = {
        'state':
        batch_states([elem['state'] for elem in trajectory], device, phi),
        'action':
        np.asarray([elem['action'] for elem in trajectory], dtype=np.int32),
        'reward':
        np.asarray([elem['reward'] for elem in trajectory], dtype=np.float32),
        'is_state_terminal':
        np.asarray([elem['is_state_terminal'] for elem in trajectory],
                   dtype=np.float32),
        # 'feature': [elem['feature'].cpu().detach().clone().numpy() for elem in trajectory]
        'feature':
        np.asarray([elem['feature'] for elem in trajectory], dtype=np.float32),
    }

    return batch_tr
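An illustrative call with a single hypothetical transition; the `feature` entry stands for whatever per-step feature vector the surrounding project stores. It assumes `pfrl.utils.batch_states.batch_states` is imported before the function above is defined:

import numpy as np
import torch
from pfrl.utils.batch_states import batch_states

trajectory = [{
    "state": np.zeros(4, dtype=np.float32),
    "action": 0,
    "reward": 1.0,
    "is_state_terminal": False,
    "feature": np.zeros(8, dtype=np.float32),
}]
batch = batch_trajectory(trajectory, torch.device("cpu"), phi=lambda s: torch.as_tensor(s))
print(batch["state"].shape, batch["action"].dtype)  # torch.Size([1, 4]) int32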
Example #8
File: acer.py Project: pfnet/pfrl
    def _act_train(self, obs):

        statevar = batch_states([obs], self.device, self.phi)

        if self.recurrent:
            (
                (action_distrib, action_value, v),
                self.train_recurrent_states,
            ) = one_step_forward(self.model, statevar, self.train_recurrent_states)
        else:
            action_distrib, action_value, v = self.model(statevar)
        self.past_action_values[self.t] = action_value
        action = action_distrib.sample()[0]

        # Save values for a later update
        self.past_values[self.t] = v
        self.past_action_distrib[self.t] = action_distrib
        with torch.no_grad():
            if self.recurrent:
                (
                    (avg_action_distrib, _, _),
                    self.shared_recurrent_states,
                ) = one_step_forward(
                    self.shared_average_model,
                    statevar,
                    self.shared_recurrent_states,
                )
            else:
                avg_action_distrib, _, _ = self.shared_average_model(statevar)
        self.past_avg_action_distrib[self.t] = avg_action_distrib

        self.past_actions[self.t] = action

        # Update stats
        self.average_value += (1 - self.average_value_decay) * (
            float(v) - self.average_value
        )
        self.average_entropy += (1 - self.average_entropy_decay) * (
            float(action_distrib.entropy()) - self.average_entropy
        )

        self.last_state = obs
        self.last_action = action.numpy()
        self.last_action_distrib = deepcopy_distribution(action_distrib)

        return self.last_action
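The running statistics above are exponential moving averages; a minimal standalone illustration of the same update rule:

def ema_update(average, new_value, decay):
    """Move `average` toward `new_value` by a factor of (1 - decay)."""
    return average + (1 - decay) * (new_value - average)

avg = 0.0
for v in [1.0, 1.0, 1.0]:
    avg = ema_update(avg, v, decay=0.9)
print(avg)  # 0.271: the average creeps toward 1.0 at rate 0.1 per step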
Example #9
    def _observe_train(self, state, reward, done, reset):
        assert self.last_state is not None
        assert self.last_action is not None

        # Add a transition to the replay buffer
        if self.replay_buffer is not None:
            self.replay_buffer.append(
                state=self.last_state,
                action=self.last_action,
                reward=reward,
                next_state=state,
                is_state_terminal=done,
                mu=self.last_action_distrib,
            )
            if done or reset:
                self.replay_buffer.stop_current_episode()

        self.t += 1
        self.past_rewards[self.t - 1] = reward
        if self.process_idx == 0:
            self.logger.debug(
                "t:%s r:%s a:%s",
                self.t,
                reward,
                self.last_action,
            )

        if self.t - self.t_start == self.t_max or done or reset:
            if done:
                statevar = None
            else:
                statevar = batch_states([state], self.device, self.phi)
            self.update_on_policy(statevar)
            for _ in range(self.n_times_replay):
                self.update_from_replay()
        if done or reset:
            self.train_recurrent_states = None
            self.shared_recurrent_states = None

        self.last_state = None
        self.last_action = None
        self.last_action_distrib = None
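A self-contained sketch of the replay-buffer calls used above, with `pfrl.replay_buffers.EpisodicReplayBuffer` and dummy transitions; the `mu` entry mirrors the behavior distribution the agent stores, and extra keyword arguments are kept in the transition dicts:

import numpy as np
from pfrl.replay_buffers import EpisodicReplayBuffer

rbuf = EpisodicReplayBuffer(10 ** 4)  # capacity
for t in range(3):
    rbuf.append(
        state=np.zeros(4, dtype=np.float32),
        action=0,
        reward=1.0,
        next_state=np.zeros(4, dtype=np.float32),
        is_state_terminal=(t == 2),
        mu=None,
    )
rbuf.stop_current_episode()
episode = rbuf.sample_episodes(1, 10)[0]  # one episode, truncated to <= 10 steps
print(len(episode), episode[0]["reward"])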
Example #10
    def _update_obs_normalizer(self, dataset):
        assert self.obs_normalizer
        states = batch_states([b["state"] for b in dataset], self.device, self.phi)
        self.obs_normalizer.experience(states)
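The normalizer contract assumed throughout these examples (an `experience()` call that updates running statistics, plus a forward call with `update=False` for pure normalization) is provided by `pfrl.nn.EmpiricalNormalization`. A minimal sketch, with the observation size 4 chosen arbitrarily:

import torch
from pfrl.nn import EmpiricalNormalization

obs_normalizer = EmpiricalNormalization(4)
states = torch.randn(16, 4)
obs_normalizer.experience(states)                   # update running mean/std
normalized = obs_normalizer(states, update=False)   # normalize without updating
print(normalized.shape)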
Example #11
    def update_from_replay(self):

        if self.replay_buffer is None:
            return

        if len(self.replay_buffer) < self.replay_start_size:
            return

        episode = self.replay_buffer.sample_episodes(1, self.t_max)[0]

        model_recurrent_state = None
        shared_recurrent_state = None
        rewards = {}
        actions = {}
        action_distribs = {}
        action_distribs_mu = {}
        avg_action_distribs = {}
        action_values = {}
        values = {}
        for t, transition in enumerate(episode):
            bs = batch_states([transition["state"]], self.device, self.phi)
            if self.recurrent:
                (
                    (action_distrib, action_value, v),
                    model_recurrent_state,
                ) = one_step_forward(self.model, bs, model_recurrent_state)
            else:
                action_distrib, action_value, v = self.model(bs)
            with torch.no_grad():
                if self.recurrent:
                    (
                        (avg_action_distrib, _, _),
                        shared_recurrent_state,
                    ) = one_step_forward(
                        self.shared_average_model,
                        bs,
                        shared_recurrent_state,
                    )
                else:
                    avg_action_distrib, _, _ = self.shared_average_model(bs)
            actions[t] = transition["action"]
            values[t] = v
            action_distribs[t] = action_distrib
            avg_action_distribs[t] = avg_action_distrib
            rewards[t] = transition["reward"]
            action_distribs_mu[t] = transition["mu"]
            action_values[t] = action_value
        last_transition = episode[-1]
        if last_transition["is_state_terminal"]:
            R = 0
        else:
            with torch.no_grad():
                last_s = batch_states([last_transition["next_state"]],
                                      self.device, self.phi)
                if self.recurrent:
                    (_, _,
                     last_v), _ = one_step_forward(self.model, last_s,
                                                   model_recurrent_state)
                else:
                    _, _, last_v = self.model(last_s)
            R = float(last_v)
        return self.update(
            R=R,
            t_start=0,
            t_stop=len(episode),
            rewards=rewards,
            actions=actions,
            values=values,
            action_distribs=action_distribs,
            action_distribs_mu=action_distribs_mu,
            avg_action_distribs=avg_action_distribs,
            action_values=action_values,
        )
Example #12
def batch_recurrent_experiences(experiences,
                                device,
                                phi,
                                gamma,
                                batch_states=batch_states):
    """Batch experiences for recurrent model updates.

    Args:
        experiences: list of episodes. Each episode is a list
            containing between 1 and n dicts, each containing:
              - state (object): State
              - action (object): Action
              - reward (float): Reward
              - is_state_terminal (bool): True iff next state is terminal
              - next_state (object): Next state
            The list must be sorted in descending order of length so that it
            can be packed by `torch.nn.utils.rnn.pack_sequence` later.
        device: GPU or CPU device the tensors should be placed on
        phi: Preprocessing function
        gamma: Discount factor
        batch_states: Function that converts a list of states to a batch
    Returns:
        dict of batched transitions
    """
    assert _is_sorted_desc_by_lengths(experiences)
    flat_transitions = flatten_sequences_time_first(experiences)
    batch_exp = {
        "state": [
            batch_states([transition["state"]
                          for transition in ep], device, phi)
            for ep in experiences
        ],
        "action":
        torch.as_tensor(
            [transition["action"] for transition in flat_transitions],
            device=device),
        "reward":
        torch.as_tensor(
            [transition["reward"] for transition in flat_transitions],
            dtype=torch.float,
            device=device,
        ),
        "next_state": [
            batch_states([transition["next_state"]
                          for transition in ep], device, phi)
            for ep in experiences
        ],
        "is_state_terminal":
        torch.as_tensor(
            [
                transition["is_state_terminal"]
                for transition in flat_transitions
            ],
            dtype=torch.float,
            device=device,
        ),
        "discount":
        torch.full((len(flat_transitions), ),
                   gamma,
                   dtype=torch.float,
                   device=device),
        "recurrent_state":
        recurrent_state_from_numpy(
            concatenate_recurrent_states(
                [ep[0]["recurrent_state"] for ep in experiences]),
            device,
        ),
        "next_recurrent_state":
        recurrent_state_from_numpy(
            concatenate_recurrent_states(
                [ep[0]["next_recurrent_state"] for ep in experiences]),
            device,
        ),
    }
    # Batch next actions only when all the transitions have them
    if all(transition["next_action"] is not None
           for transition in flat_transitions):
        batch_exp["next_action"] = torch.as_tensor(
            [transition["next_action"] for transition in flat_transitions],
            device=device,
        )
    return batch_exp
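`_is_sorted_desc_by_lengths` is a small helper not shown on this page; a plausible minimal implementation (hypothetical, for completeness):

def _is_sorted_desc_by_lengths(experiences):
    """True iff the episodes are ordered from longest to shortest."""
    return all(len(a) >= len(b) for a, b in zip(experiences, experiences[1:]))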