def rollout(self, root_node, stop_at_leaf=False):
    """Run one MCTS simulation from ``root_node``: descend via best_child,
    expanding at most the missing children encountered, then backpropagate
    the leaf's state value up to (but not including) the root.

    Args:
        root_node: MCTSNode to start the simulation from.
        stop_at_leaf: if True, stop descending immediately after the first
            newly expanded child.

    Returns:
        V: the backed-up value for the path; also folded into
        ``self.root_max``.
    """
    node = root_node
    v = -1
    finish = False
    while not finish:
        if node.is_end():
            break
        v = node.best_child()
        if node.children[v] is None:
            # Child not expanded yet: simulate taking action v on this
            # node's graph and create the child node for the result.
            env = MISEnv() if use_dense else MISEnv_Sparse()
            env.set_graph(node.graph)
            next_graph, r, done, info = env.step(v)
            node.children[v] = MCTSNode(next_graph, self, idx=v, parent=node)
            if stop_at_leaf:
                finish = True
        node = node.children[v]
    # backpropagate V
    V = node.state_value()
    while node is not root_node:
        # +1 per edge on the path — presumably one vertex added to the
        # independent set per step (TODO confirm against env reward).
        V += 1
        self.update_parent(node, V)
        node = node.parent
    self.root_max = max(self.root_max, V)
    return V
def policy_search(self, graph):
    """Play one full episode, sampling each action from the GNN policy
    distribution; return the final reward reported by the environment."""
    environment = MISEnv() if use_dense else MISEnv_Sparse()
    environment.set_graph(graph)
    reward, done = 0, False
    while not done:
        num_vertices = graph.shape[0]
        with torch.no_grad():
            probs, _ = self.gnn(graph)
        chosen = np.random.choice(num_vertices, p=probs.detach().numpy())
        graph, reward, done, _ = environment.step(chosen)
    return reward
def greedy_v_search(self, graph):
    """Play one full episode, greedily taking the action with the highest
    predicted state value; return the final reward."""
    environment = MISEnv() if use_dense else MISEnv_Sparse()
    environment.set_graph(graph)
    reward, done = 0, False
    while not done:
        with torch.no_grad():
            _, values = self.gnn(graph)
        chosen = values.detach().numpy().argmax()
        graph, reward, done, _ = environment.step(chosen)
    return reward
def train(self, graph, TAU, batch_size=10, iter_p=2, stop_at_leaf=False):
    """Generate one self-play episode with MCTS, then fit the GNN on it.

    Phase 1 plays the episode, recording (graph, action, improved policy)
    plus the per-state reward mean/std used to normalize the value target.
    Phase 2 shuffles the episode and runs mini-batch SGD on
    MSE(value) + cross-entropy(policy, improved pi).

    Args:
        graph: initial adjacency representation (dense or sparse).
        TAU: temperature for the improved policy.
        batch_size: mini-batch size for the SGD phase.
        iter_p: MCTS iteration multiplier passed to get_improved_pi.
        stop_at_leaf: forwarded to the MCTS rollouts.
    """
    self.gnnhash.clear()
    mse = torch.nn.MSELoss()
    env = MISEnv() if use_dense else MISEnv_Sparse()
    env.set_graph(graph)

    graphs, actions, pis, means, stds = [], [], [], [], []
    done = False
    while not done:
        n, _ = graph.shape
        node = MCTSNode(graph, self)
        means.append(node.reward_mean)
        stds.append(node.reward_std)
        pi = self.get_improved_pi(node, TAU, iter_p=iter_p, stop_at_leaf=stop_at_leaf)
        action = np.random.choice(n, p=pi)
        graphs.append(graph)
        actions.append(action)
        pis.append(pi)
        graph, reward, done, info = env.step(action)

    T = len(graphs)
    idxs = [i for i in range(T)]
    np.random.shuffle(idxs)
    i = 0
    while i < T:
        size = min(batch_size, T - i)
        self.optimizer.zero_grad()
        loss = torch.tensor(0., dtype=torch.float32)
        for j in range(i, i + size):
            idx = idxs[j]
            Timer.start('gnn')
            p, v = self.gnn(graphs[idx], True)
            Timer.end('gnn')
            # normalize z with mean, std
            z = torch.tensor(((T - idx) - means[idx]) / stds[idx], dtype=torch.float32)
            # BUG FIX: the original wrapped these terms in a fresh
            # torch.tensor(..., requires_grad=True), which detaches them
            # from the GNN's autograd graph — backward() then produced no
            # parameter gradients and optimizer.step() was a no-op.
            # Accumulate the differentiable expression directly instead.
            loss = loss + mse(z, v[actions[idx]]) \
                - (torch.tensor(pis[idx], dtype=torch.float32) * torch.log(p + EPS)).sum()
        loss = loss / size
        loss.backward()
        self.optimizer.step()
        i += size
def best_search1(self, graph, TAU=0.1, iter_p=1):
    """MCTS-guided episode; return (best value seen, final step reward).

    The best value is the running maximum of ``self.root_max`` (refreshed
    by the rollouts inside get_improved_pi) plus the reward accumulated so
    far.
    """
    self.gnnhash.clear()
    environment = MISEnv() if use_dense else MISEnv_Sparse()
    environment.set_graph(graph)
    best = 0
    reward = 0
    done = False
    while not done:
        root = MCTSNode(graph, self)
        pi = self.get_improved_pi(root, TAU, iter_p=iter_p)
        # reward still holds the value reported by the previous env.step
        # (0 before the first step), matching the original update order.
        best = max(best, self.root_max + reward)
        action = np.random.choice(graph.shape[0], p=pi)
        graph, reward, done, _ = environment.step(action)
    return best, reward