Example #1
    def rollout(self, root_node, stop_at_leaf=False):
        # walk down the tree from root_node, expanding at most one new child
        # along the way; optionally stop as soon as a new leaf has been created
        node = root_node
        v = -1
        finish = False
        while not finish:
            if node.is_end():
                break
            # choose the next vertex according to the tree's selection rule
            v = node.best_child()
            if node.children[v] is None:
                # expand: apply action v in a fresh environment seeded with
                # this node's graph and attach the resulting state as a child
                env = MISEnv() if use_dense else MISEnv_Sparse()
                env.set_graph(node.graph)
                next_graph, r, done, info = env.step(v)
                node.children[v] = MCTSNode(next_graph,
                                            self,
                                            idx=v,
                                            parent=node)
                if stop_at_leaf:
                    finish = True
            node = node.children[v]

        # backpropagate: bootstrap from the reached node's value estimate,
        # adding 1 for every edge on the way back up to the root
        V = node.state_value()
        while node is not root_node:
            V += 1
            self.update_parent(node, V)
            node = node.parent
        self.root_max = max(self.root_max, V)
        return V
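
The rollout above relies on an MCTSNode class that is not shown in these examples. Below is a minimal sketch of the interface it assumes: the constructor signature matches the MCTSNode(next_graph, self, idx=v, parent=node) call above, the reward_mean/reward_std attributes are used by the train() example further below, and the selection and value rules are placeholders rather than the project's real logic.

import numpy as np

class MCTSNode:
    # sketch only: the real class would keep visit statistics and use a
    # UCT/PUCT-style selection rule backed by the GNN
    def __init__(self, graph, mcts, idx=None, parent=None):
        self.graph = graph              # adjacency matrix of the remaining subgraph
        self.mcts = mcts                # owning search object
        self.idx = idx                  # action taken at the parent to reach this node
        self.parent = parent
        n = graph.shape[0]
        self.children = [None] * n      # one slot per selectable vertex
        self.visit_count = np.zeros(n)
        self.reward_mean = 0.0          # statistics used to normalize returns in train()
        self.reward_std = 1.0

    def is_end(self):
        # terminal when no vertex is left to pick
        return self.graph.shape[0] == 0

    def best_child(self):
        # placeholder: least-visited child; the real code would add an
        # exploration bonus derived from the policy prior
        return int(np.argmin(self.visit_count))

    def state_value(self):
        # placeholder: the real code would query the GNN value head
        return 0.0
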
Example #2
    def __init__(self, policy, test_graphs=[]):
        self.env = MISEnv() if use_dense else MISEnv_Sparse()
        self.policy = policy
        self.optimizer = torch.optim.Adam(self.policy.model.parameters(),
                                          lr=0.01)
        self.test_graphs = test_graphs
        self.rewards = []
Example #3
    def greedy_v_search(self, graph):
        env = MISEnv() if use_dense else MISEnv_Sparse()
        env.set_graph(graph)

        reward = 0
        done = False
        while not done:
            n, _ = graph.shape
            with torch.no_grad():
                p, v = self.gnn(graph)
            # greedy: pick the vertex with the highest predicted value
            action = v.detach().numpy().argmax()
            graph, reward, done, info = env.step(action)
        return reward
Example #4
    def policy_search(self, graph):
        env = MISEnv() if use_dense else MISEnv_Sparse()
        env.set_graph(graph)

        reward = 0
        done = False
        while not done:
            n, _ = graph.shape
            with torch.no_grad():
                p, v = self.gnn(graph)
            # stochastic: sample a vertex from the policy head's distribution
            action = np.random.choice(n, p=p.detach().numpy())
            graph, reward, done, info = env.step(action)
        return reward
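
greedy_v_search is deterministic (it always takes the argmax of the value head), while policy_search samples from the policy head, so repeated calls can return different independent sets. A small best-of-k wrapper is sketched below; the function is not part of the project, and it assumes that policy_search leaves the caller's graph unchanged.

def best_of_k_policy_search(searcher, graph, k=10):
    # hypothetical helper: run the stochastic rollout k times and keep the
    # largest reward found
    best = 0
    for _ in range(k):
        best = max(best, searcher.policy_search(graph))
    return best
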
Example #5
    def train(self, graph, TAU, batch_size=10, iter_p=2, stop_at_leaf=False):
        self.gnnhash.clear()
        mse = torch.nn.MSELoss()
        env = MISEnv() if use_dense else MISEnv_Sparse()
        env.set_graph(graph)

        graphs = []
        actions = []
        pis = []
        means = []
        stds = []
        done = False
        # self-play phase: at every state run MCTS to get an improved policy pi,
        # then sample the next vertex from it
        while not done:
            n, _ = graph.shape
            node = MCTSNode(graph, self)
            means.append(node.reward_mean)
            stds.append(node.reward_std)
            pi = self.get_improved_pi(node,
                                      TAU,
                                      iter_p=iter_p,
                                      stop_at_leaf=stop_at_leaf)
            action = np.random.choice(n, p=pi)
            graphs.append(graph)
            actions.append(action)
            pis.append(pi)
            graph, reward, done, info = env.step(action)

        T = len(graphs)
        idxs = [i for i in range(T)]
        np.random.shuffle(idxs)
        i = 0
        while i < T:
            size = min(batch_size, T - i)
            self.optimizer.zero_grad()
            loss = torch.tensor([0], dtype=torch.float32)
            for j in range(i, i + size):
                idx = idxs[j]
                Timer.start('gnn')
                p, v = self.gnn(graphs[idx], True)
                Timer.end('gnn')
                n, _ = graphs[idx].shape
                # normalize the return with the reward statistics of the MCTS
                # root built at step idx; the raw return from step idx is the
                # number of remaining moves, T - idx
                z = torch.tensor(((T - idx) - means[idx]) / stds[idx],
                                 dtype=torch.float32)
                # value loss: MSE between z and the value predicted for the chosen
                # vertex; policy loss: cross-entropy between the MCTS policy pi and
                # the network policy p. Accumulate the terms directly so gradients
                # reach the GNN (wrapping them in a new tensor would detach them).
                loss = loss + mse(v[actions[idx]], z) \
                    - (torch.tensor(pis[idx], dtype=torch.float32)
                       * torch.log(p + EPS)).sum()
            loss /= size
            loss.backward()
            self.optimizer.step()
            i += size
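
The per-sample loss used in train() is the AlphaZero-style combination of a value term and a policy term: squared error between the normalized return z and the value predicted for the chosen vertex, plus the cross-entropy between the MCTS-improved policy pi and the network policy p. A standalone sketch of that single-sample term is below; the shapes of p and v are inferred from how they are indexed above, and EPS is assumed to be a small constant defined elsewhere in the project.

import torch

EPS = 1e-9  # assumed value

def single_sample_loss(p, v, pi, action, z):
    # p:  policy over the n remaining vertices, shape (n,)
    # v:  per-vertex value predictions, shape (n,)
    # pi: MCTS-improved policy, shape (n,)
    # action: index of the vertex that was actually chosen
    # z:  normalized return observed from this state (scalar tensor)
    value_loss = (v[action] - z) ** 2
    policy_loss = -(pi * torch.log(p + EPS)).sum()
    return value_loss + policy_loss
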
Example #6
    def best_search1(self, graph, TAU=0.1, iter_p=1):
        self.gnnhash.clear()
        env = MISEnv() if use_dense else MISEnv_Sparse()
        env.set_graph(graph)

        ma = 0
        reward = 0
        done = False
        while not done:
            n, _ = graph.shape
            node = MCTSNode(graph, self)
            pi = self.get_improved_pi(node, TAU, iter_p=iter_p)
            ma = max(ma, self.root_max + reward)
            action = np.random.choice(n, p=pi)
            graph, reward, done, info = env.step(action)
        return ma, reward
Example #7
class Trainer:
    def __init__(self, policy, test_graphs=[]):
        self.env = MISEnv() if use_dense else MISEnv_Sparse()
        self.policy = policy
        self.optimizer = torch.optim.Adam(self.policy.model.parameters(),
                                          lr=0.01)
        self.test_graphs = test_graphs
        self.rewards = []

    def train(self, adj, iter=10, batch=10, print_log=True):
        self.env.set_graph(adj)
        reward_sum = 0
        for epoch in range(iter):
            self.optimizer.zero_grad()
            rewards = torch.empty(batch)    # episode returns for this batch
            log_probs = torch.zeros(batch)  # summed log-probabilities per episode
            for n in range(batch):
                graph = self.env.reset()
                done = False
                while not done:
                    action, prob = self.policy.act(graph)
                    log_probs[n] += torch.log(prob)
                    graph, reward, done, info = self.env.step(action)

                rewards[n] = reward
            if print_log:
                print(rewards)
            reward_sum += rewards.detach().numpy().sum()
            reward_mean = reward_sum / ((epoch + 1) * batch)

            # REINFORCE: advantage is the reward minus the running-mean baseline
            loss = -((rewards - reward_mean) * log_probs).mean()
            loss.backward()
            self.optimizer.step()

    def solution(self, adj):
        # roll out the learned policy on adj until termination and return the
        # final reward reported by the environment
        g = adj
        self.env.set_graph(g)
        while True:
            v, prob = self.policy.act(g)
            g, r, finish, info = self.env.step(v)
            if finish:
                break
        return r
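
A minimal usage sketch for the Trainer class follows. The Policy object with a .model attribute and an .act(graph) method, the random adjacency matrix, and the module-level use_dense flag are all assumptions based on how the code above uses them.

import numpy as np

# hypothetical setup: a random undirected graph as a dense 0/1 adjacency matrix
n = 20
adj = np.triu((np.random.rand(n, n) < 0.2).astype(np.float32), 1)
adj = adj + adj.T

policy = Policy()        # assumed: wraps an nn.Module in .model and exposes .act(graph)
trainer = Trainer(policy)
trainer.train(adj, iter=10, batch=10, print_log=False)
print("final reward (independent set size):", trainer.solution(adj))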