# Example #1
# 0
# Labels for the validation-accuracy figure (its plt.plot call is presumably
# just above this excerpt).
plt.xlabel('epochs')
plt.ylabel('Validation Accuracy')
plt.show()

# Training-set curves, one figure per metric, each plotted against `a`.
training_curves = (
    (loss_append, 'Training Loss'),
    (train_acc_append, 'Training Accuracy'),
)
for curve, y_label in training_curves:
    plt.plot(a, curve)
    plt.xlabel('epochs')
    plt.ylabel(y_label)
    plt.show()

# Evaluate the trained network on the test set: count how many predicted
# classes (index of the largest output along dim 1) match the labels.
net.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in testloader:
        inputs = inputs.to(device)
        labels = labels.long().to(device)
        outputs = net(inputs)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Report the number of images actually evaluated rather than a hard-coded
# count, and avoid dividing by zero when the loader yields nothing.
if total:
    print('Accuracy of the network on the %d test images: %d %%' %
          (total, 100 * correct / total))
else:
    print('Accuracy of the network: no test images were evaluated')
class DQN:
    """Deep Q-Network agent with a replay memory and a separate target network.

    ``eval_net`` is the network being trained.  ``target_net`` is a copy of it
    that supplies the bootstrap targets in :meth:`learn` and is re-synchronised
    by :meth:`target_update`.  ``action_net`` (set by :meth:`load_net`) is the
    model used for greedy play in :meth:`act`.
    """

    def __init__(self,
                 memory_size=50000,
                 batch_size=128,
                 gamma=0.99,
                 lr=1e-3,
                 n_step=500000):
        """Create the replay memory, both networks and the optimizer.

        Args:
            memory_size: size argument passed to ``ReplayMemory``.
            batch_size: number of transitions sampled per :meth:`learn` call.
            gamma: discount factor applied to the target network's max Q-value.
            lr: Adam learning rate for ``eval_net``.
            n_step: accepted for interface compatibility; not stored or used.
        """
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.gamma = gamma

        # memory
        self.memory_size = memory_size
        self.Memory = ReplayMemory(self.memory_size)
        self.batch_size = batch_size

        # network
        self.target_net = Net().to(self.device)
        self.eval_net = Net().to(self.device)
        self.target_update()  # start both networks from identical weights
        self.target_net.eval()

        # optim
        self.optimizer = optim.Adam(self.eval_net.parameters(), lr=lr)

    def select_action(self, state, eps):
        """Epsilon-greedy action selection.

        With probability ``eps`` a uniformly random action in ``[0, 9)`` is
        chosen; otherwise ``eval_net.act(state)`` picks the action.

        Returns:
            ``(action, is_random)``: ``is_random`` is True for the exploratory
            branch, whose action is a ``(1, 1)`` long tensor on ``self.device``.
        """
        if random.random() > eps:
            return self.eval_net.act(state), False
        random_action = torch.tensor(
            [[random.randrange(0, 9)]],
            device=self.device,
            dtype=torch.long,
        )
        return random_action, True

    def select_dummy_action(self, state):
        """Sample a cell index in ``[0, 9)`` weighted by channel 0 of the board.

        ``state`` is reshaped to ``(3, 3, 3)`` and the values of the last axis'
        channel 0 are used as sampling weights (presumably 1 for an empty cell
        and 0 otherwise, i.e. a uniform pick among open cells -- TODO confirm
        against the environment's encoding).

        Raises:
            ValueError: if channel 0 sums to zero, so nothing can be sampled.
        """
        state = state.reshape(3, 3, 3)
        open_spots = state[:, :, 0].reshape(-1)
        total = open_spots.sum()
        if total == 0:
            raise ValueError("select_dummy_action: no open spots on the board")
        return np.random.choice(np.arange(9), p=open_spots / total)

    def target_update(self):
        """Copy ``eval_net``'s weights into ``target_net``."""
        self.target_net.load_state_dict(self.eval_net.state_dict())

    def learn(self):
        """Run one optimisation step on a batch sampled from replay memory.

        Does nothing until the memory holds at least ``batch_size``
        transitions.  The loss is the Huber (smooth L1, delta=1) distance
        between ``Q(s, a)`` from ``eval_net`` and
        ``r + gamma * max_a' Q_target(s', a')``, where the bootstrap term is 0
        for terminal transitions (``next_state is None``).  Gradients are
        clamped element-wise to ``[-1, 1]`` before the Adam step.
        """
        if len(self.Memory) < self.batch_size:
            return

        # random batch sampling
        transitions = self.Memory.sampling(self.batch_size)
        batch = Transition(*zip(*transitions))

        non_final_mask = torch.tensor(
            [s is not None for s in batch.next_state],
            device=self.device,
            dtype=torch.bool,
        )
        state_batch = torch.cat(batch.state).to(self.device)
        action_batch = torch.cat(batch.action).to(self.device)
        reward_batch = torch.cat(batch.reward).to(self.device)

        # Q(s, a) for the actions that were actually taken.
        q_sa = self.eval_net(state_batch).gather(1, action_batch)

        # max_a' Q_target(s', a'); left at 0 for terminal transitions.  The
        # guard avoids torch.cat([]) when every sampled transition is terminal,
        # and no_grad skips building a graph for the untrained target network.
        next_q = torch.zeros(len(transitions), device=self.device)
        if non_final_mask.any():
            non_final_next_states = torch.cat(
                [s for s in batch.next_state if s is not None]).to(self.device)
            with torch.no_grad():
                next_q[non_final_mask] = self.target_net(
                    non_final_next_states).max(1)[0]

        # Q_target = R + gamma * max Q(s')
        q_target = reward_batch + (next_q * self.gamma)

        loss = F.smooth_l1_loss(q_sa, q_target.unsqueeze(1))

        # Optimize the model; clip_grad_value_ skips parameters without a grad.
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_value_(self.eval_net.parameters(), 1)
        self.optimizer.step()

    def load_net(self, name):
        """Load a whole pickled network from file ``name`` as ``action_net``.

        ``map_location="cpu"`` lets checkpoints saved on a GPU load on a
        CPU-only machine.  The model is only used for inference in :meth:`act`,
        so it is switched to eval mode.
        NOTE(review): ``torch.load`` unpickles arbitrary objects -- only load
        files you trust.
        """
        self.action_net = torch.load(name, map_location="cpu").cpu()
        self.action_net.eval()

    def load_weight(self, name):
        """Load a state dict from file ``name`` into ``eval_net``, then move it to CPU.

        NOTE(review): ``self.device`` and ``target_net`` are not changed, so on
        a GPU machine a later :meth:`learn` call would mix devices.
        """
        state_dict = torch.load(name, map_location="cpu")
        self.eval_net.load_state_dict(state_dict)
        self.eval_net = self.eval_net.cpu()

    def act(self, state):
        """Return the index of the highest-scoring legal move for ``state``.

        ``state`` must hold 27 values (it is reshaped to ``(3, 3, 3)``).  A cell
        counts as legal when channel 0 is the largest of its three channels.
        Assumes ``action_net``'s last output dimension holds the 9 move scores.
        """
        with torch.no_grad():
            scores = self.action_net(state)
            probs = F.softmax(scores, dim=-1).cpu().numpy()
            board = state.cpu().numpy().reshape(3, 3, 3)
            valid_moves = board.argmax(axis=2).reshape(-1) == 0
            # Softmax outputs are non-negative, so zeroing illegal cells keeps
            # them from winning the argmax unless every legal cell is also 0.
            return (valid_moves * probs).argmax()