Пример #1
0
    def rollout(self, root_node, stop_at_leaf=False):
        """Run one simulation from ``root_node`` and back its value up the tree.

        The descent repeatedly follows ``best_child()`` until a node reports
        ``is_end()``. When the selected child slot is still ``None`` it is
        filled with a new ``MCTSNode`` built from the graph returned by
        stepping a fresh environment from the current node's graph.

        Args:
            root_node: Node the simulation starts from; back-propagation
                stops when it is reached again.
            stop_at_leaf: If true, the descent ends right after the first
                newly created child has been entered.

        Returns:
            The value after back-propagation: ``state_value()`` of the last
            node visited, increased by one for every edge walked back to
            ``root_node``.
        """
        current = root_node
        while not current.is_end():
            choice = current.best_child()
            expanded = current.children[choice] is None
            if expanded:
                env = MISEnv() if use_dense else MISEnv_Sparse()
                env.set_graph(current.graph)
                child_graph, _reward, _done, _info = env.step(choice)
                current.children[choice] = MCTSNode(
                    child_graph, self, idx=choice, parent=current)
            current = current.children[choice]
            if expanded and stop_at_leaf:
                break

        # Walk back towards the root: the value grows by one per edge and is
        # reported to update_parent at every node below the root.
        value = current.state_value()
        while current is not root_node:
            value += 1
            self.update_parent(current, value)
            current = current.parent
        self.root_max = max(self.root_max, value)
        return value
Пример #2
0
    def policy_search(self, graph):
        """Play one episode, sampling every action from the network's policy.

        At each step ``self.gnn`` is evaluated on the current graph with
        gradient tracking disabled, and an action index in ``range(n)`` is
        drawn from the returned probabilities, where ``n`` is the first
        entry of ``graph.shape``. The environment is then stepped with that
        action.

        Args:
            graph: Initial graph handed to the environment; it must expose a
                two-element ``shape``.

        Returns:
            The reward reported by the environment step that ended the
            episode.
        """
        env = MISEnv() if use_dense else MISEnv_Sparse()
        env.set_graph(graph)

        # At least one step is always taken; the loop exits on the first
        # step whose `done` flag is truthy.
        while True:
            num_choices, _ = graph.shape
            with torch.no_grad():
                probs, _values = self.gnn(graph)
            weights = probs.detach().numpy()
            chosen = np.random.choice(num_choices, p=weights)
            graph, reward, finished, _info = env.step(chosen)
            if finished:
                return reward
Пример #3
0
    def greedy_v_search(self, graph):
        """Play one episode, always taking the action with the largest value.

        At each step ``self.gnn`` is evaluated on the current graph with
        gradient tracking disabled, and the index of the maximum entry of
        its second output is passed to the environment.

        Args:
            graph: Initial graph handed to the environment; it must expose a
                two-element ``shape``.

        Returns:
            The reward reported by the environment step that ended the
            episode.
        """
        env = MISEnv() if use_dense else MISEnv_Sparse()
        env.set_graph(graph)

        while True:
            # The sizes are not used here; the unpack is kept so a graph
            # without a two-element shape still fails at this point.
            _rows, _cols = graph.shape
            with torch.no_grad():
                _policy, values = self.gnn(graph)
            best = values.detach().numpy().argmax()
            graph, reward, finished, _info = env.step(best)
            if finished:
                return reward
Пример #4
0
    def train(self, graph, TAU, batch_size=10, iter_p=2, stop_at_leaf=False):
        """Play one episode on ``graph`` and fit the network on its steps.

        Phase 1 (data collection): starting from ``graph``, a fresh
        ``MCTSNode`` is built for every state, ``get_improved_pi`` produces
        a distribution over actions, one action is sampled from it and the
        environment is stepped. The visited graph, the sampled action, the
        distribution and the node's ``reward_mean`` / ``reward_std`` are
        recorded for every step.

        Phase 2 (optimisation): the recorded steps are shuffled and consumed
        in mini-batches. For each step the loss is the squared error between
        the normalised remaining episode length and the network value of the
        action taken, plus the cross-entropy between the recorded
        distribution and the network policy. Each mini-batch loss is
        averaged, back-propagated, and followed by one optimizer step.

        Args:
            graph: Initial graph; must expose a two-element ``shape``.
            TAU: Forwarded unchanged to ``get_improved_pi``.
            batch_size: Maximum number of recorded steps per optimizer step.
            iter_p: Forwarded to ``get_improved_pi``.
            stop_at_leaf: Forwarded to ``get_improved_pi``.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            # A non-positive batch size would never advance through the data.
            raise ValueError("batch_size must be at least 1")

        self.gnnhash.clear()
        mse = torch.nn.MSELoss()
        env = MISEnv() if use_dense else MISEnv_Sparse()
        env.set_graph(graph)

        graphs = []
        actions = []
        pis = []
        means = []
        stds = []
        done = False
        while not done:
            n, _ = graph.shape
            node = MCTSNode(graph, self)
            means.append(node.reward_mean)
            stds.append(node.reward_std)
            pi = self.get_improved_pi(node,
                                      TAU,
                                      iter_p=iter_p,
                                      stop_at_leaf=stop_at_leaf)
            action = np.random.choice(n, p=pi)
            graphs.append(graph)
            actions.append(action)
            pis.append(pi)
            graph, _reward, done, _info = env.step(action)

        T = len(graphs)
        idxs = list(range(T))
        np.random.shuffle(idxs)
        for start in range(0, T, batch_size):
            batch = idxs[start:start + batch_size]
            self.optimizer.zero_grad()
            loss = torch.zeros(1, dtype=torch.float32)
            for idx in batch:
                Timer.start('gnn')
                p, v = self.gnn(graphs[idx], True)
                Timer.end('gnn')
                # Normalise the target (steps remaining from this state)
                # with the mean/std recorded for the state.
                # NOTE(review): assumes stds[idx] is non-zero - confirm
                # against MCTSNode.reward_std.
                z = torch.tensor(((T - idx) - means[idx]) / stds[idx],
                                 dtype=torch.float32)
                target_pi = torch.tensor(pis[idx], dtype=torch.float32)
                # MSELoss takes (prediction, target); the value is the same
                # either way round.
                value_loss = mse(v[actions[idx]], z)
                policy_loss = -(target_pi * torch.log(p + EPS)).sum()
                # Accumulate the terms themselves. Re-wrapping them in
                # torch.tensor(...) would create a new leaf tensor detached
                # from the autograd graph, so backward() would never reach
                # the network parameters.
                loss = loss + value_loss + policy_loss
            loss = loss / len(batch)
            loss.backward()
            self.optimizer.step()
Пример #5
0
    def best_search1(self, graph, TAU=0.1, iter_p=1):
        """Play one episode guided by ``get_improved_pi`` and track the best bound.

        For every state a fresh ``MCTSNode`` is created and
        ``get_improved_pi`` is run on it. Afterwards ``self.root_max`` plus
        the reward of the previous environment step (0 before the first
        step) is compared with the running maximum. An action is then
        sampled from the returned distribution and the environment is
        stepped.

        Args:
            graph: Initial graph; must expose a two-element ``shape``.
            TAU: Forwarded unchanged to ``get_improved_pi``.
            iter_p: Forwarded to ``get_improved_pi``.

        Returns:
            A tuple ``(ma, reward)``: the running maximum described above
            and the reward of the step that ended the episode.
        """
        self.gnnhash.clear()
        env = MISEnv() if use_dense else MISEnv_Sparse()
        env.set_graph(graph)

        best_seen = 0
        reward = 0
        while True:
            num_choices, _ = graph.shape
            root = MCTSNode(graph, self)
            improved_pi = self.get_improved_pi(root, TAU, iter_p=iter_p)
            # root_max is read only after the search on this root has run.
            candidate = self.root_max + reward
            best_seen = max(best_seen, candidate)
            chosen = np.random.choice(num_choices, p=improved_pi)
            graph, reward, finished, _info = env.step(chosen)
            if finished:
                break
        return best_seen, reward