Example #1
    def train(self, batch=None, update_target=False):
        self.local.train()

        states, actions, rewards, next_states, terminals = batch

        # Move the sampled transitions onto the device as tensors.
        states = to_tensor(states)
        actions = to_tensor(actions)
        rewards = to_tensor(rewards)
        next_states = to_tensor(next_states)
        terminals = to_tensor(terminals)

        batch_indices = range_tensor(states.size(0))

        # TD target: r + gamma * max_a Q_target(s', a), zeroed for terminal states.
        q_next = self.target(next_states).detach()
        q_next = q_next.max(1)[0]
        q_next = self.gamma * q_next * (1 - terminals)
        q_next.add_(rewards)

        # Q-values of the actions actually taken in the batch.
        q = self.local(states)
        q = q[batch_indices, actions.long()]

        loss = self.loss(q, q_next)
        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.local.parameters(), 1)  # clip gradient norm to 1
        self.optimizer.step()

        if update_target:
            self.softupdate()

        return loss.item()
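
These snippets rely on small tensor helpers (to_tensor, to_numpy, range_tensor) that are not shown. The sketch below is an assumption of what they might look like, using PyTorch and NumPy with a module-level device; none of it is code from the source.

import numpy as np
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def to_tensor(x):
    # Convert a numpy array (or list) to a float32 tensor on the active device.
    return torch.as_tensor(np.asarray(x), dtype=torch.float32, device=device)

def to_numpy(t):
    # Detach from the autograd graph and move back to host memory as a numpy array.
    return t.detach().cpu().numpy()

def range_tensor(n):
    # Index tensor [0, 1, ..., n-1], used to gather Q-values of the taken actions.
    return torch.arange(n, dtype=torch.long, device=device)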
Example #2
    def train(self, batch=None, update_target=True):
        self.actor_local.train()

        states, actions, rewards, next_states, terminals = batch

        states = to_tensor(states)
        actions = to_tensor(actions)
        rewards = to_tensor(rewards).unsqueeze(-1)
        next_states = to_tensor(next_states)
        terminals = to_tensor(terminals).unsqueeze(-1)

        # Bootstrapped target: r + gamma * Q_target(s', actor_target(s')), zeroed at terminal states.
        a_next = self.actor_target(next_states)
        q_next = self.critic_target(next_states, a_next)
        q_next = self.gamma * q_next * (1 - terminals)
        q_next.add_(rewards)
        q_next = q_next.detach()

        # Critic update: regress Q(s, a) toward the bootstrapped target.
        q = self.critic_local(states, actions)
        critic_loss = self.critic_loss(q, q_next)
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Actor update: maximise the critic's estimate of the actor's own actions.
        a = self.actor_local(states)
        actor_loss = -self.critic_local(states, a).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        if update_target:
            self.softupdate('actor')
            self.softupdate('critic')

        return [actor_loss.item(), critic_loss.item()]
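
Both train methods finish with self.softupdate, whose implementation is not shown. A common choice is Polyak averaging of the target parameters toward the local ones; the sketch below is an assumption in that style (the tau default and the getattr-based lookup are illustrative, not taken from the source), written as a method of the same class with torch assumed imported at module level.

    def softupdate(self, which='actor', tau=None):
        # Polyak averaging: target <- tau * local + (1 - tau) * target.
        tau = tau if tau is not None else getattr(self, 'tau', 1e-3)
        # Resolve e.g. actor_local/actor_target, falling back to local/target for the DQN agent.
        local = getattr(self, f'{which}_local', getattr(self, 'local', None))
        target = getattr(self, f'{which}_target', getattr(self, 'target', None))
        with torch.no_grad():
            for t_param, l_param in zip(target.parameters(), local.parameters()):
                t_param.mul_(1.0 - tau).add_(tau * l_param)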
Example #3
    def get_action(self, state, policy):
        if policy.use_network():
            # Pick the move predicted by the network and decode its label into a chess move.
            action = self.model.predict(to_tensor(state))
            action = self.move_labels[int(action)]
            action = Move.from_uci(action)
        else:
            legal_moves = list(self.environment.generate_legal_moves())
            if len(legal_moves) == 0:
                # No legal moves left: block on a key press in the OpenCV window.
                import cv2
                cv2.waitKey()

            action = random.choice(legal_moves)

        if not self.environment.is_legal(action):
            # Fall back to a random legal move when the network suggests an illegal one.
            action = random.choice(list(self.environment.generate_legal_moves()))
            # action = MCTSGameController().get_next_move(self.environment, time_allowed=1)

        return self.move_labels.index(action.uci())
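
The policy object only needs to expose use_network(), which decides whether to trust the model or fall back to a random legal move. One plausible reading is an epsilon-greedy style schedule; the sketch below is an assumption along those lines (the class name and decay parameters are illustrative, not from the source).

import random

class EpsilonPolicy:
    def __init__(self, epsilon=1.0, epsilon_min=0.05, decay=0.995):
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.decay = decay

    def use_network(self):
        # With probability (1 - epsilon) trust the network; decay epsilon over time.
        use_net = random.random() > self.epsilon
        self.epsilon = max(self.epsilon_min, self.epsilon * self.decay)
        return use_net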
Example #4
    def predict(self, state=None):
        # Deterministic continuous action from the local actor network (evaluation mode).
        self.actor_local.eval()
        return to_numpy(self.actor_local(to_tensor(state))).flatten()
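
During DDPG-style training, the deterministic action returned here is usually perturbed with exploration noise before being sent to the environment. The sketch below uses Gaussian noise; the noise scale and clipping bounds are assumptions, not taken from the source.

import numpy as np

def act_with_noise(agent, state, noise_std=0.1, low=-1.0, high=1.0):
    # Query the deterministic policy, then add Gaussian exploration noise and clip to the action range.
    action = agent.predict(state)
    action = action + np.random.normal(0.0, noise_std, size=action.shape)
    return np.clip(action, low, high)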
Example #5
    def predict(self, state=None):
        # Greedy action: index of the largest Q-value from the local network.
        self.local.eval()
        return np.argmax(to_numpy(self.local(to_tensor(state))).flatten())
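
Putting the DQN pieces together, a typical loop samples an action with predict, stores the transition, and periodically calls train. The sketch below assumes a classic Gym-style env whose step() returns (next_state, reward, done, info) and a replay buffer exposing add(), sample(), and __len__(); both are assumptions, not part of the snippets above.

def run_episode(agent, env, buffer, batch_size=64, update_every=4):
    state, total_reward, done, step = env.reset(), 0.0, False, 0
    while not done:
        action = agent.predict(state)                        # greedy action from the local network
        next_state, reward, done, _ = env.step(action)
        buffer.add(state, action, reward, next_state, done)  # store the transition
        if len(buffer) >= batch_size:
            # One gradient step on a sampled batch; soft-update the target every few steps.
            agent.train(batch=buffer.sample(batch_size),
                        update_target=(step % update_every == 0))
        state, total_reward, step = next_state, total_reward + reward, step + 1
    return total_reward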