Example #1
    def handle_batch(self, batch: Dict[str, torch.Tensor]):
        # model train/valid step: unpack the transition batch
        states, actions, rewards, next_states, dones = (
            batch["state"].squeeze_(1).to(torch.float32),
            batch["action"].to(torch.float32),
            batch["reward"].to(torch.float32),
            batch["next_state"].squeeze_(1).to(torch.float32),
            batch["done"].to(torch.bool),
        )

        # get actions for the current state
        pred_actions = self.actor(states)
        # get q-values for the actions in current states
        pred_critic_states = torch.cat([states, pred_actions], 1)
        # use q-values to train the actor model
        policy_loss = (-self.critic(pred_critic_states)).mean()

        with torch.no_grad():
            # get possible actions for the next states
            next_state_actions = self.target_actor(next_states)
            # get possible q-values for the next actions
            next_critic_states = torch.cat([next_states, next_state_actions],
                                           1)
            next_state_values = self.target_critic(next_critic_states).squeeze()
            next_state_values[dones] = 0.0

        # compute Bellman's equation value
        target_state_values = next_state_values * self.gamma + rewards
        # compute predicted values
        critic_states = torch.cat([states, actions], 1)
        state_values = self.critic(critic_states).squeeze()

        # train the critic model
        value_loss = self.criterion(state_values, target_state_values.detach())

        self.batch_metrics.update({
            "critic_loss": value_loss,
            "actor_loss": policy_loss
        })
        for key in ["critic_loss", "actor_loss"]:
            self.meters[key].update(self.batch_metrics[key].item(),
                                    self.batch_size)

        if self.is_train_loader:
            self.actor.zero_grad()
            self.actor_optimizer.zero_grad()
            policy_loss.backward()
            self.actor_optimizer.step()

            self.critic.zero_grad()
            self.critic_optimizer.zero_grad()
            value_loss.backward()
            self.critic_optimizer.step()

            if self.global_batch_step % self.tau_period == 0:
                soft_update(self.target_actor, self.actor, self.tau)
                soft_update(self.target_critic, self.critic, self.tau)
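The examples on this page call a soft_update helper that is not included in the excerpts. A minimal sketch of such a Polyak-averaging update, assuming both networks share the same parameter layout (the actual helper may differ), could look like this:

import torch


def soft_update(target: torch.nn.Module, source: torch.nn.Module, tau: float) -> None:
    # blend source weights into the target: target <- (1 - tau) * target + tau * source
    with torch.no_grad():
        for target_param, source_param in zip(target.parameters(), source.parameters()):
            target_param.mul_(1.0 - tau).add_(source_param, alpha=tau)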
Example #2
    def on_stage_start(self, runner: dl.IRunner):
        super().on_stage_start(runner)
        # pick the DDPG networks out of the model dictionary
        self.actor = self.model[self.actor_key]
        self.critic = self.model[self.critic_key]
        self.target_actor = self.model[self.target_actor_key]
        self.target_critic = self.model[self.target_critic_key]
        # hard-copy weights into the target networks (tau=1.0)
        soft_update(self.target_actor, self.actor, 1.0)
        soft_update(self.target_critic, self.critic, 1.0)
        # pick the matching optimizers out of the optimizer dictionary
        self.actor_optimizer = self.optimizer[self.actor_optimizer_key]
        self.critic_optimizer = self.optimizer[self.critic_optimizer_key]
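The *_key attributes assume the runner was configured with dictionaries of models and optimizers. A hypothetical setup (dimensions, key names, and architectures are illustrative assumptions, not taken from the original code) might look like this:

import copy

import torch

obs_dim, act_dim = 8, 2  # hypothetical environment dimensions
actor = torch.nn.Sequential(
    torch.nn.Linear(obs_dim, 64), torch.nn.ReLU(), torch.nn.Linear(64, act_dim), torch.nn.Tanh()
)
critic = torch.nn.Sequential(
    torch.nn.Linear(obs_dim + act_dim, 64), torch.nn.ReLU(), torch.nn.Linear(64, 1)
)

models = {
    "actor": actor,
    "critic": critic,
    "target_actor": copy.deepcopy(actor),
    "target_critic": copy.deepcopy(critic),
}
optimizers = {
    "actor": torch.optim.Adam(actor.parameters(), lr=1e-3),
    "critic": torch.optim.Adam(critic.parameters(), lr=1e-3),
}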
Example #3
    def handle_batch(self, batch: Dict[str, torch.Tensor]):
        # model train/valid step
        states, actions, rewards, next_states, dones = (
            batch["state"].squeeze_(1).to(torch.float32),
            batch["action"].to(torch.int64),
            batch["reward"].to(torch.float32),
            batch["next_state"].squeeze_(1).to(torch.float32),
            batch["done"].to(torch.bool),
        )

        # get q-values for all actions in current states
        state_qvalues = self.origin_network(states)
        # select q-values for chosen actions
        state_action_qvalues = state_qvalues.gather(
            1, actions.unsqueeze(-1)).squeeze(-1)

        # compute q-values for all actions in the next states and take
        # V*(next_states) = max_a' Q_target(s', a');
        # for terminal states we use the simplified formula
        # Q(s,a) = r(s,a), since s' does not exist
        with torch.no_grad():
            next_state_qvalues = self.target_network(next_states)
            next_state_values = next_state_qvalues.max(1)[0]
            next_state_values[dones] = 0.0
            next_state_values = next_state_values.detach()

        # compute "target q-values" for loss,
        # it's what's inside square parentheses in the above formula.
        target_state_action_qvalues = next_state_values * self.gamma + rewards

        # mean squared error loss to minimize
        loss = self.criterion(state_action_qvalues,
                              target_state_action_qvalues.detach())
        self.batch_metrics.update({"loss": loss})
        for key in ["loss"]:
            self.meters[key].update(self.batch_metrics[key].item(),
                                    self.batch_size)

        if self.is_train_loader:
            self.engine.backward(loss)
            self.optimizer.step()
            self.optimizer.zero_grad()

            if self.batch_step % self.tau_period == 0:
                soft_update(self.target_network, self.origin_network, self.tau)
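The gather call in this example picks, for each transition in the batch, the Q-value of the action that was actually taken. A small standalone check of that indexing pattern:

import torch

# 3 transitions, 4 discrete actions: each row holds Q(s, a) for one state
state_qvalues = torch.tensor([[1.0, 2.0, 3.0, 4.0],
                              [5.0, 6.0, 7.0, 8.0],
                              [9.0, 1.0, 2.0, 3.0]])
actions = torch.tensor([2, 0, 3])  # index of the action taken in each state

# same indexing pattern as in handle_batch
state_action_qvalues = state_qvalues.gather(1, actions.unsqueeze(-1)).squeeze(-1)
print(state_action_qvalues)  # tensor([3., 5., 3.])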
Example #4
    def on_stage_start(self, runner: dl.IRunner):
        super().on_stage_start(runner)
        # pick the online and target Q-networks out of the model dictionary
        self.origin_network = self.model[self.origin_key]
        self.target_network = self.model[self.target_key]
        # hard-copy weights into the target network (tau=1.0)
        soft_update(self.target_network, self.origin_network, 1.0)
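Calling soft_update with tau=1.0 here amounts to a hard copy, so the target network starts out identical to the online network. Using the soft_update sketch from earlier (an assumption about the helper's behaviour), that can be checked on toy modules:

import torch

origin = torch.nn.Linear(4, 2)  # hypothetical toy networks, for illustration only
target = torch.nn.Linear(4, 2)

soft_update(target, origin, 1.0)  # tau=1.0 reduces the blend to target <- origin

for target_param, origin_param in zip(target.parameters(), origin.parameters()):
    assert torch.equal(target_param, origin_param)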