Example #1
    def _batch_act_train(self, batch_obs):
        assert self.training
        b_state = self.batch_states(batch_obs, self.device, self.phi)

        if self.obs_normalizer:
            b_state = self.obs_normalizer(b_state, update=False)

        num_envs = len(batch_obs)
        if self.batch_last_episode is None:
            self._initialize_batch_variables(num_envs)
        assert len(self.batch_last_episode) == num_envs
        assert len(self.batch_last_state) == num_envs
        assert len(self.batch_last_action) == num_envs

        # action_distrib will be recomputed when computing gradients
        with torch.no_grad(), pfrl.utils.evaluating(self.model):
            if self.recurrent:
                assert self.train_prev_recurrent_states is None
                self.train_prev_recurrent_states = self.train_recurrent_states
                (
                    (action_distrib, batch_value),
                    self.train_recurrent_states,
                ) = one_step_forward(self.model, b_state,
                                     self.train_prev_recurrent_states)
            else:
                action_distrib, batch_value = self.model(b_state)
            batch_action = action_distrib.sample().cpu().numpy()
            self.entropy_record.extend(action_distrib.entropy().cpu().numpy())
            self.value_record.extend(batch_value.cpu().numpy())

        self.batch_last_state = list(batch_obs)
        self.batch_last_action = list(batch_action)

        return batch_action
Example #2
    def _act_train(self, obs):

        batch_obs = self.batch_states([obs], self.device, self.phi)
        if self.recurrent:
            action_distrib, self.train_recurrent_states = one_step_forward(
                self.model, batch_obs, self.train_recurrent_states
            )
        else:
            action_distrib = self.model(batch_obs)
        batch_action = action_distrib.sample()

        # Save values used to compute losses
        self.log_prob_sequences[-1].append(action_distrib.log_prob(batch_action))
        self.entropy_sequences[-1].append(action_distrib.entropy())

        action = batch_action.cpu().numpy()[0]

        self.logger.debug("t:%s a:%s", self.t, action)

        # Update stats
        self.average_entropy += (1 - self.average_entropy_decay) * (
            float(action_distrib.entropy()) - self.average_entropy
        )

        return action
Example #3
File: acer.py Project: imatge-upc/PiCoEDL
    def update_on_policy(self, statevar):
        assert self.t_start < self.t

        if not self.disable_online_update:
            if statevar is None:
                R = 0
            else:
                with torch.no_grad():
                    if self.recurrent:
                        (_, _,
                         v), _ = one_step_forward(self.model, statevar,
                                                  self.train_recurrent_states)
                    else:
                        _, _, v = self.model(statevar)
                R = float(v)
            self.update(
                t_start=self.t_start,
                t_stop=self.t,
                R=R,
                actions=self.past_actions,
                rewards=self.past_rewards,
                values=self.past_values,
                action_values=self.past_action_values,
                action_distribs=self.past_action_distrib,
                action_distribs_mu=None,
                avg_action_distribs=self.past_avg_action_distrib,
            )

        self.init_history_data_for_online_update()
        self.train_recurrent_states = detach_recurrent_state(
            self.train_recurrent_states)
Example #4
File: acer.py Project: pfnet/pfrl
    def _act_train(self, obs):

        statevar = batch_states([obs], self.device, self.phi)

        if self.recurrent:
            (
                (action_distrib, action_value, v),
                self.train_recurrent_states,
            ) = one_step_forward(self.model, statevar, self.train_recurrent_states)
        else:
            action_distrib, action_value, v = self.model(statevar)
        self.past_action_values[self.t] = action_value
        action = action_distrib.sample()[0]

        # Save values for a later update
        self.past_values[self.t] = v
        self.past_action_distrib[self.t] = action_distrib
        with torch.no_grad():
            if self.recurrent:
                (
                    (avg_action_distrib, _, _),
                    self.shared_recurrent_states,
                ) = one_step_forward(
                    self.shared_average_model,
                    statevar,
                    self.shared_recurrent_states,
                )
            else:
                avg_action_distrib, _, _ = self.shared_average_model(statevar)
        self.past_avg_action_distrib[self.t] = avg_action_distrib

        self.past_actions[self.t] = action

        # Update stats
        self.average_value += (1 - self.average_value_decay) * (
            float(v) - self.average_value
        )
        self.average_entropy += (1 - self.average_entropy_decay) * (
            float(action_distrib.entropy()) - self.average_entropy
        )

        self.last_state = obs
        self.last_action = action.numpy()
        self.last_action_distrib = deepcopy_distribution(action_distrib)

        return self.last_action
Example #5
File: dqn.py Project: xylee95/pfrl
    def _evaluate_model_and_update_recurrent_states(
        self, batch_obs: Sequence[Any]
    ) -> ActionValue:
        batch_xs = self.batch_states(batch_obs, self.device, self.phi)
        if self.recurrent:
            if self.training:
                # Remember the recurrent state from before this step so it can
                # be stored together with the resulting transition.
                self.train_prev_recurrent_states = self.train_recurrent_states
                batch_av, self.train_recurrent_states = one_step_forward(
                    self.model, batch_xs, self.train_recurrent_states
                )
            else:
                batch_av, self.test_recurrent_states = one_step_forward(
                    self.model, batch_xs, self.test_recurrent_states
                )
        else:
            batch_av = self.model(batch_xs)
        return batch_av
Example #6
File: a3c.py Project: xylee95/pfrl
    def _act_eval(self, obs):
        # Use the process-local model for acting
        with torch.no_grad(), pfrl.utils.evaluating(self.model):
            statevar = self.batch_states([obs], self.device, self.phi)
            if self.recurrent:
                (pout, _), self.test_recurrent_states = one_step_forward(
                    self.model, statevar, self.test_recurrent_states
                )
            else:
                pout, _ = self.model(statevar)
            if self.act_deterministically:
                return mode_of_distribution(pout).cpu().numpy()[0]
            else:
                return pout.sample().cpu().numpy()[0]
Example #7
    def _act_eval(self, obs):
        with torch.no_grad():
            batch_obs = self.batch_states([obs], self.device, self.phi)
            if self.recurrent:
                action_distrib, self.test_recurrent_states = one_step_forward(
                    self.model, batch_obs, self.test_recurrent_states
                )
            else:
                action_distrib = self.model(batch_obs)
            if self.act_deterministically:
                return mode_of_distribution(action_distrib).cpu().numpy()[0]
            else:
                return action_distrib.sample().cpu().numpy()[0]
Example #8
    def _act_eval(self, obs):
        # Use the process-local model for acting
        with torch.no_grad():
            statevar = batch_states([obs], self.device, self.phi)
            if self.recurrent:
                (
                    (action_distrib, _, _),
                    self.test_recurrent_states,
                ) = one_step_forward(self.model, statevar, self.test_recurrent_states)
            else:
                action_distrib, _, _ = self.model(statevar)
            if self.act_deterministically:
                return mode_of_distribution(action_distrib).numpy()[0]
            else:
                return action_distrib.sample().numpy()[0]
Example #9
    def _batch_act_eval(self, batch_obs):
        assert not self.training
        b_state = self.batch_states(batch_obs, self.device, self.phi)

        if self.obs_normalizer:
            b_state = self.obs_normalizer(b_state, update=False)

        with torch.no_grad(), pfrl.utils.evaluating(self.model):
            if self.recurrent:
                (action_distrib, _), self.test_recurrent_states = one_step_forward(
                    self.model, b_state, self.test_recurrent_states
                )
            else:
                action_distrib, _ = self.model(b_state)
            if self.act_deterministically:
                action = mode_of_distribution(action_distrib).cpu().numpy()
            else:
                action = action_distrib.sample().cpu().numpy()

        return action
Example #10
File: a3c.py Project: xylee95/pfrl
    def _act_train(self, obs):

        self.past_obs[self.t] = obs

        with torch.no_grad():
            statevar = self.batch_states([obs], self.device, self.phi)
            if self.recurrent:
                self.past_recurrent_state[self.t] = self.train_recurrent_states
                (pout, vout), self.train_recurrent_states = one_step_forward(
                    self.model, statevar, self.train_recurrent_states)
            else:
                pout, vout = self.model(statevar)
            # Do not backprop through sampled actions
            action = pout.sample()
            self.past_action[self.t] = action[0].detach()
            action = action.cpu().numpy()[0]

        # Update stats
        self.average_value += (1 - self.average_value_decay) * (
            float(vout) - self.average_value)
        self.average_entropy += (1 - self.average_entropy_decay) * (
            float(pout.entropy()) - self.average_entropy)

        return action
Example #11
File: a3c.py Project: xylee95/pfrl
    def update(self, statevar):
        assert self.t_start < self.t

        n = self.t - self.t_start

        self.assert_shared_memory()

        if statevar is None:
            R = 0
        else:
            with torch.no_grad(), pfrl.utils.evaluating(self.model):
                if self.recurrent:
                    (_,
                     vout), _ = one_step_forward(self.model, statevar,
                                                 self.train_recurrent_states)
                else:
                    _, vout = self.model(statevar)
            R = float(vout)

        pi_loss_factor = self.pi_loss_coef
        v_loss_factor = self.v_loss_coef

        # Normalize the loss of sequences truncated by terminal states
        if self.keep_loss_scale_same and self.t - self.t_start < self.t_max:
            factor = self.t_max / (self.t - self.t_start)
            pi_loss_factor *= factor
            v_loss_factor *= factor

        if self.normalize_grad_by_t_max:
            pi_loss_factor /= self.t - self.t_start
            v_loss_factor /= self.t - self.t_start

        # Batch re-compute for efficient backprop
        batch_obs = self.batch_states(
            [self.past_obs[i] for i in range(self.t_start, self.t)],
            self.device,
            self.phi,
        )
        if self.recurrent:
            (batch_distrib, batch_v), _ = pack_and_forward(
                self.model,
                [batch_obs],
                self.past_recurrent_state[self.t_start],
            )
        else:
            batch_distrib, batch_v = self.model(batch_obs)
        batch_action = torch.stack(
            [self.past_action[i] for i in range(self.t_start, self.t)])
        batch_log_prob = batch_distrib.log_prob(batch_action)
        batch_entropy = batch_distrib.entropy()
        rev_returns = []
        for i in reversed(range(self.t_start, self.t)):
            R *= self.gamma
            R += self.past_rewards[i]
            rev_returns.append(R)
        batch_return = torch.as_tensor(list(reversed(rev_returns)),
                                       dtype=torch.float)
        batch_adv = batch_return - batch_v.detach().squeeze(-1)
        assert batch_log_prob.shape == (n, )
        assert batch_adv.shape == (n, )
        assert batch_entropy.shape == (n, )
        pi_loss = torch.sum(-batch_adv * batch_log_prob -
                            self.beta * batch_entropy,
                            dim=0)
        assert batch_v.shape == (n, 1)
        assert batch_return.shape == (n, )
        v_loss = F.mse_loss(batch_v, batch_return[..., None],
                            reduction="sum") / 2

        if pi_loss_factor != 1.0:
            pi_loss *= pi_loss_factor

        if v_loss_factor != 1.0:
            v_loss *= v_loss_factor

        if self.process_idx == 0:
            logger.debug("pi_loss:%s v_loss:%s", pi_loss, v_loss)

        total_loss = torch.squeeze(pi_loss) + torch.squeeze(v_loss)

        # Compute gradients using thread-specific model
        self.model.zero_grad()
        total_loss.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.model.parameters(), self.max_grad_norm)
        # Copy the gradients to the globally shared model
        copy_param.copy_grad(target_link=self.shared_model,
                             source_link=self.model)
        # Update the globally shared model
        self.optimizer.step()
        if self.process_idx == 0:
            logger.debug("update")

        self.sync_parameters()

        self.past_obs = {}
        self.past_action = {}
        self.past_rewards = {}
        self.past_recurrent_state = {}

        self.t_start = self.t
Example #12
    def update_from_replay(self):

        if self.replay_buffer is None:
            return

        if len(self.replay_buffer) < self.replay_start_size:
            return

        episode = self.replay_buffer.sample_episodes(1, self.t_max)[0]

        model_recurrent_state = None
        shared_recurrent_state = None
        rewards = {}
        actions = {}
        action_distribs = {}
        action_distribs_mu = {}
        avg_action_distribs = {}
        action_values = {}
        values = {}
        for t, transition in enumerate(episode):
            bs = batch_states([transition["state"]], self.device, self.phi)
            if self.recurrent:
                (
                    (action_distrib, action_value, v),
                    model_recurrent_state,
                ) = one_step_forward(self.model, bs, model_recurrent_state)
            else:
                action_distrib, action_value, v = self.model(bs)
            with torch.no_grad():
                if self.recurrent:
                    (
                        (avg_action_distrib, _, _),
                        shared_recurrent_state,
                    ) = one_step_forward(
                        self.shared_average_model,
                        bs,
                        shared_recurrent_state,
                    )
                else:
                    avg_action_distrib, _, _ = self.shared_average_model(bs)
            actions[t] = transition["action"]
            values[t] = v
            action_distribs[t] = action_distrib
            avg_action_distribs[t] = avg_action_distrib
            rewards[t] = transition["reward"]
            action_distribs_mu[t] = transition["mu"]
            action_values[t] = action_value
        last_transition = episode[-1]
        if last_transition["is_state_terminal"]:
            R = 0
        else:
            with torch.no_grad():
                last_s = batch_states([last_transition["next_state"]],
                                      self.device, self.phi)
                if self.recurrent:
                    (_, _,
                     last_v), _ = one_step_forward(self.model, last_s,
                                                   model_recurrent_state)
                else:
                    _, _, last_v = self.model(last_s)
            R = float(last_v)
        return self.update(
            R=R,
            t_start=0,
            t_stop=len(episode),
            rewards=rewards,
            actions=actions,
            values=values,
            action_distribs=action_distribs,
            action_distribs_mu=action_distribs_mu,
            avg_action_distribs=avg_action_distribs,
            action_values=action_values,
        )
Example #13
def get_and_concat_rs_forward():
    # Run one step, extract the recurrent state of each batch entry,
    # concatenate them back into a batch state, and use it for the next step.
    _, rs = one_step_forward(par, x_t0, None)
    rs0 = get_recurrent_state_at(rs, 0, detach=True)
    rs1 = get_recurrent_state_at(rs, 1, detach=True)
    concat_rs = concatenate_recurrent_states([rs0, rs1])
    return one_step_forward(par, x_t1, concat_rs)
Example #14
def mask01_forward_twice():
    # Reset (mask) the recurrent state of batch entries 0 and 1 before
    # running the second step.
    _, rs = one_step_forward(par, x_t0, None)
    rs = mask_recurrent_state_at(rs, [0, 1])
    return one_step_forward(par, x_t1, rs)
Example #15
def no_mask_forward_twice():
    # Carry the recurrent state from the first step directly into the second.
    _, rs = one_step_forward(par, x_t0, None)
    return one_step_forward(par, x_t1, rs)
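
The snippets above all call one_step_forward from inside agent methods, where the model and recurrent state are class attributes. For orientation, here is a minimal, self-contained sketch of the same call in isolation, assuming a recurrent Q-function built with pfrl.nn.RecurrentSequential and a DiscreteActionValueHead output (neither appears in the examples above); the observation size, hidden size, action count, and batch size are arbitrary placeholders.

import torch
import torch.nn as nn

from pfrl.nn import RecurrentSequential
from pfrl.q_functions import DiscreteActionValueHead
from pfrl.utils.recurrent import one_step_forward

# Placeholder sizes, for illustration only.
obs_size, hidden_size, n_actions = 4, 32, 2

# RecurrentSequential treats the nn.LSTM layer as the recurrent part and
# threads its hidden state through one_step_forward.
q_func = RecurrentSequential(
    nn.LSTM(input_size=obs_size, hidden_size=hidden_size),
    nn.Linear(hidden_size, n_actions),
    DiscreteActionValueHead(),
)

batch_obs = torch.zeros(3, obs_size)  # a batch of three observations

with torch.no_grad():
    # First step: there is no previous recurrent state, so pass None.
    q_values, recurrent_state = one_step_forward(q_func, batch_obs, None)
    # Later steps reuse the recurrent state returned by the previous call.
    q_values, recurrent_state = one_step_forward(q_func, batch_obs, recurrent_state)

print(q_values.greedy_actions)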