def rollout(self, root_node, stop_at_leaf=False):
    # Selection/expansion: walk down the tree via best_child() until an
    # unexpanded child (or a terminal state) is reached.
    node = root_node
    v = -1
    finish = False
    while not finish:
        if node.is_end():
            break
        v = node.best_child()
        if node.children[v] is None:
            env = MISEnv() if use_dense else MISEnv_Sparse()
            env.set_graph(node.graph)
            next_graph, r, done, info = env.step(v)
            node.children[v] = MCTSNode(next_graph, self, idx=v, parent=node)
            if stop_at_leaf:
                finish = True
        node = node.children[v]
    # Backpropagation: each step back toward the root adds one vertex to the
    # independent set, so V is incremented before updating the parent.
    V = node.state_value()
    while node is not root_node:
        V += 1
        self.update_parent(node, V)
        node = node.parent
    self.root_max = max(self.root_max, V)
    return V
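# --- Illustrative sketch (not from the original code) -----------------------
# rollout() relies on MCTSNode.best_child() to choose which vertex to expand.
# That method is defined elsewhere; the helper below is a hedged, self-contained
# example of the PUCT-style selection rule such a method commonly implements.
# All names and numbers here are hypothetical.
def _demo_puct_select():
    import numpy as np

    def puct_select(q_values, visit_counts, priors, c_puct=1.0):
        # Exploitation term Q plus an exploration bonus that favours children
        # with a high prior and few visits.
        total = visit_counts.sum()
        ucb = q_values + c_puct * priors * np.sqrt(total + 1) / (1 + visit_counts)
        return int(np.argmax(ucb))

    q = np.array([0.5, 0.2, 0.4])       # dummy mean values per child
    n_visits = np.array([10, 2, 5])     # dummy visit counts
    priors = np.array([0.2, 0.5, 0.3])  # dummy policy priors
    return puct_select(q, n_visits, priors)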
def greedy_v_search(self, graph):
    env = MISEnv() if use_dense else MISEnv_Sparse()
    env.set_graph(graph)
    reward = 0
    done = False
    while not done:
        n, _ = graph.shape
        with torch.no_grad():
            p, v = self.gnn(graph)
        action = v.detach().numpy().argmax()
        graph, reward, done, info = env.step(action)
    return reward
def policy_search(self, graph):
    env = MISEnv() if use_dense else MISEnv_Sparse()
    env.set_graph(graph)
    reward = 0
    done = False
    while not done:
        n, _ = graph.shape
        with torch.no_grad():
            p, v = self.gnn(graph)
        action = np.random.choice(n, p=p.detach().numpy())
        graph, reward, done, info = env.step(action)
    return reward
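# --- Illustrative sketch (not from the original code) -----------------------
# The two baselines above differ only in how they turn the GNN output
# (per-vertex policy p, per-vertex value v) into an action: greedy_v_search
# takes the argmax of the value head, while policy_search samples from the
# policy head. The demo below shows that difference on dummy tensors; p and v
# here are stand-ins, not real network outputs.
def _demo_action_selection():
    import numpy as np
    import torch

    p = torch.tensor([0.25, 0.5, 0.25])  # dummy policy over 3 vertices (sums to 1)
    v = torch.tensor([0.2, 0.9, 0.4])    # dummy value estimates per vertex

    greedy_action = int(v.numpy().argmax())                      # value-greedy
    sampled_action = int(np.random.choice(len(p), p=p.numpy()))  # policy sample
    return greedy_action, sampled_action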
def train(self, graph, TAU, batch_size=10, iter_p=2, stop_at_leaf=False):
    self.gnnhash.clear()
    mse = torch.nn.MSELoss()
    env = MISEnv() if use_dense else MISEnv_Sparse()
    env.set_graph(graph)
    graphs = []
    actions = []
    pis = []
    means = []
    stds = []
    done = False
    # Self-play: at every state run MCTS to get an improved policy pi,
    # sample an action from it, and store the training targets.
    while not done:
        n, _ = graph.shape
        node = MCTSNode(graph, self)
        means.append(node.reward_mean)
        stds.append(node.reward_std)
        pi = self.get_improved_pi(node, TAU, iter_p=iter_p, stop_at_leaf=stop_at_leaf)
        action = np.random.choice(n, p=pi)
        graphs.append(graph)
        actions.append(action)
        pis.append(pi)
        graph, reward, done, info = env.step(action)

    # Mini-batch updates over the collected episode in shuffled order.
    T = len(graphs)
    idxs = list(range(T))
    np.random.shuffle(idxs)
    i = 0
    while i < T:
        size = min(batch_size, T - i)
        self.optimizer.zero_grad()
        loss = torch.zeros(1)
        for j in range(i, i + size):
            idx = idxs[j]
            Timer.start('gnn')
            p, v = self.gnn(graphs[idx], True)
            Timer.end('gnn')
            n, _ = graphs[idx].shape
            # Normalize the reward-to-go (T - idx) with the node's mean and std.
            z = torch.tensor(((T - idx) - means[idx]) / stds[idx], dtype=torch.float32)
            # Value MSE plus policy cross-entropy. Keep the computation graph
            # intact so gradients reach the GNN; wrapping the terms in
            # torch.tensor(..., requires_grad=True) would detach them.
            loss = loss + mse(z, v[actions[idx]]) - \
                (torch.tensor(pis[idx], dtype=torch.float32) * torch.log(p + EPS)).sum()
        loss = loss / size
        loss.backward()
        self.optimizer.step()
        i += size
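# --- Illustrative sketch (not from the original code) -----------------------
# Each term of the training loss above is an AlphaZero-style combination of
#   value loss:  MSE(z, v[a])           with z the normalized reward-to-go
#   policy loss: -sum_i pi_i log(p_i)   with pi the MCTS-improved policy.
# The demo below computes that combined loss on dummy tensors so the gradient
# flow is easy to inspect; all values are made up.
def _demo_alphazero_loss():
    import torch

    eps = 1e-8
    p = torch.tensor([0.25, 0.5, 0.25], requires_grad=True)  # dummy policy head
    v = torch.tensor([0.1, 0.4, 0.2], requires_grad=True)    # dummy value head
    pi = torch.tensor([0.1, 0.7, 0.2])                       # dummy improved policy
    z = torch.tensor(0.8)                                     # dummy normalized target
    action = 1

    value_loss = torch.nn.functional.mse_loss(v[action], z)
    policy_loss = -(pi * torch.log(p + eps)).sum()
    loss = value_loss + policy_loss
    loss.backward()  # gradients reach both heads
    return loss.item(), p.grad, v.grad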
def best_search1(self, graph, TAU=0.1, iter_p=1):
    self.gnnhash.clear()
    env = MISEnv() if use_dense else MISEnv_Sparse()
    env.set_graph(graph)
    ma = 0
    reward = 0
    done = False
    while not done:
        n, _ = graph.shape
        node = MCTSNode(graph, self)
        pi = self.get_improved_pi(node, TAU, iter_p=iter_p)
        ma = max(ma, self.root_max + reward)
        action = np.random.choice(n, p=pi)
        graph, reward, done, info = env.step(action)
    return ma, reward
class Trainer:
    def __init__(self, policy, test_graphs=[]):
        self.env = MISEnv() if use_dense else MISEnv_Sparse()
        self.policy = policy
        self.optimizer = torch.optim.Adam(self.policy.model.parameters(), lr=0.01)
        self.test_graphs = test_graphs
        self.rewards = []

    def train(self, adj, iter=10, batch=10, print_log=True):
        self.env.set_graph(adj)
        reward_sum = 0
        for epoch in range(iter):
            self.optimizer.zero_grad()
            rewards = torch.empty(batch)
            log_probs = torch.zeros(batch)
            # Sample `batch` episodes, accumulating log-probabilities per episode.
            for n in range(batch):
                graph = self.env.reset()
                done = False
                while not done:
                    action, prob = self.policy.act(graph)
                    log_probs[n] += torch.log(prob)
                    graph, reward, done, info = self.env.step(action)
                rewards[n] = reward
            if print_log:
                print(rewards)
            reward_sum += rewards.detach().numpy().sum()
            reward_mean = reward_sum / ((epoch + 1) * batch)
            # REINFORCE with a running-mean baseline.
            loss = -((rewards - reward_mean) * log_probs).mean()
            loss.backward()
            self.optimizer.step()

    def solution(self, adj):
        g = adj
        self.env.set_graph(g)
        while True:
            v, prob = self.policy.act(g)
            g, r, finish, info = self.env.step(v)
            if finish:
                break
        return r
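# --- Illustrative sketch (not from the original code) -----------------------
# Trainer.train is a REINFORCE update with a running-mean baseline:
#   loss = -mean((R - baseline) * sum_t log pi(a_t | s_t)).
# The demo below reproduces that loss on dummy per-episode returns and
# accumulated log-probabilities; the numbers are made up.
def _demo_reinforce_loss():
    import torch

    log_probs = torch.tensor([-2.3, -1.7, -2.9], requires_grad=True)  # sum of log pi per episode
    rewards = torch.tensor([5.0, 7.0, 4.0])                           # episode returns
    baseline = rewards.mean()                                         # running mean in Trainer.train

    loss = -((rewards - baseline) * log_probs).mean()
    loss.backward()
    return loss.item(), log_probs.grad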