def __init__(self, node_features, node_labels, list_action_space, n_injected,
             bilin_q=1, embed_dim=64, mlp_hidden=64, max_lv=1, gm='mean_field',
             device='cpu'):
    '''
    bilin_q: bilinear q or not (only stored on the instance here)
    mlp_hidden: mlp hidden layer size
    max_lv: max rounds of message passing
    n_injected: multiplier on the label-encoder input size, which is
        (number of classes) * n_injected
    '''
    super(QNetNode, self).__init__()
    self.node_features = node_features
    # Convert to a plain int so torch.eye / nn.Linear receive real integer
    # sizes rather than 0-d tensors (nn.Linear would otherwise store a tensor
    # as its in_features).
    n_classes = int(node_labels.max()) + 1
    # Rows of this identity matrix serve as one-hot label encodings.
    self.identity = torch.eye(n_classes, device=node_labels.device)
    self.n_injected = n_injected

    self.list_action_space = list_action_space
    self.total_nodes = len(list_action_space)

    self.bilin_q = bilin_q
    self.embed_dim = embed_dim
    self.mlp_hidden = mlp_hidden
    self.max_lv = max_lv
    self.gm = gm

    # The scoring head consumes three embed_dim-wide vectors concatenated.
    if mlp_hidden:
        self.linear_1 = nn.Linear(embed_dim * 3, mlp_hidden)
        self.linear_out = nn.Linear(mlp_hidden, 1)
    else:
        self.linear_out = nn.Linear(embed_dim * 3, 1)

    self.w_n2l = Parameter(torch.Tensor(node_features.size()[1], embed_dim))
    self.bias_n2l = Parameter(torch.Tensor(embed_dim))

    self.conv_params = nn.Linear(embed_dim, embed_dim)
    self.norm_tool = GraphNormTool(normalize=True, gm=self.gm, device=device)
    # weights_init runs before the label encoders below are created, so those
    # keep nn.Linear's default initialization.
    weights_init(self)

    input_dim = n_classes * self.n_injected
    self.label_encoder_1 = nn.Linear(input_dim, mlp_hidden)
    self.label_encoder_2 = nn.Linear(mlp_hidden, embed_dim)
    # Device used at forward time; taken from node_features, not from `device`.
    self.device = self.node_features.device
def __init__(self, node_features, node_labels, list_action_space, bilin_q=1,
             embed_dim=64, mlp_hidden=64, max_lv=1, gm='mean_field', device='cpu'):
    """Build the Q-network layers.

    bilin_q: if truthy, the output head emits an ``embed_dim``-wide vector
        (bilinear q); otherwise it emits a single value.
    mlp_hidden: hidden layer size of the output MLP (falsy disables it)
    max_lv: max rounds of message passing
    """
    super(QNetNode, self).__init__()
    self.node_features = node_features
    self.node_labels = node_labels
    self.list_action_space = list_action_space
    self.total_nodes = len(list_action_space)

    self.bilin_q = bilin_q
    self.embed_dim = embed_dim
    self.mlp_hidden = mlp_hidden
    self.max_lv = max_lv
    self.gm = gm

    out_width = embed_dim if bilin_q else 1
    # The head input is two embed_dim-wide vectors concatenated.
    pair_width = 2 * embed_dim

    # Keep the registration order below unchanged: it fixes the order of
    # self.parameters() and the order in which nn.Linear layers draw from the RNG.
    self.bias_target = Parameter(torch.Tensor(1, embed_dim))
    if mlp_hidden:
        self.linear_1 = nn.Linear(pair_width, mlp_hidden)
        head_in = mlp_hidden
    else:
        head_in = pair_width
    self.linear_out = nn.Linear(head_in, out_width)

    feat_width = node_features.size(1)
    self.w_n2l = Parameter(torch.Tensor(feat_width, embed_dim))
    self.bias_n2l = Parameter(torch.Tensor(embed_dim))
    self.bias_picked = Parameter(torch.Tensor(1, embed_dim))
    self.conv_params = nn.Linear(embed_dim, embed_dim)
    self.norm_tool = GraphNormTool(normalize=True, gm=self.gm, device=device)
    weights_init(self)
def init_setup():
    """Load the dataset and victim model, then print the victim's test metrics.

    Reads the module-level ``args`` (``dataset``, ``ctx``, ``saved_model``) and
    sets ``StaticGraph.graph`` as a side effect.

    Returns:
        tuple: ``(features, labels, idx_val, idx_test, victim_model,
        dict_of_lists, adj)``. ``features``, ``labels`` and ``adj`` are the
        outputs of ``preprocess``; ``dict_of_lists`` is the adjacency-list view
        of ``StaticGraph.graph``.
    """
    data = Dataset(root='/tmp/', name=args.dataset, setting='gcn')
    data.features = normalize_feature(data.features)
    adj, features, labels = data.adj, data.features, data.labels

    # networkx 3.0 removed from_scipy_sparse_matrix; from_scipy_sparse_array
    # (networkx >= 2.7) is its replacement. Fall back for older releases.
    from_sparse = (getattr(nx, 'from_scipy_sparse_array', None)
                   or nx.from_scipy_sparse_matrix)
    StaticGraph.graph = from_sparse(adj)
    dict_of_lists = nx.to_dict_of_lists(StaticGraph.graph)

    idx_val, idx_test = data.idx_val, data.idx_test
    device = torch.device('cuda') if args.ctx == 'gpu' else 'cpu'

    # black box setting
    adj, features, labels = preprocess(adj, features, labels, preprocess_adj=False,
                                       sparse=True, device=device)
    victim_model = load_victim_model(data, device=device, file_path=args.saved_model)
    setattr(victim_model, 'norm_tool',
            GraphNormTool(normalize=True, gm='gcn', device=device))

    output = victim_model.predict(features, adj)
    loss_test = F.nll_loss(output[idx_test], labels[idx_test])
    acc_test = accuracy(output[idx_test], labels[idx_test])
    print("Test set results:",
          "loss= {:.4f}".format(loss_test.item()),
          "accuracy= {:.4f}".format(acc_test.item()))

    return features, labels, idx_val, idx_test, victim_model, dict_of_lists, adj
class QNetNode(nn.Module):
    """Q-network scoring candidate nodes for a batch of graph states.

    Each state is ``(target_node, graph, picked_node)``. Node features are
    projected, refined by ``max_lv`` rounds of message passing over the state's
    graph, restricted to the candidate region from ``list_action_space``, and
    mean-pooled. Each candidate (or the given action) is then scored from its
    embedding concatenated with the pooled embedding. With ``bilin_q`` the
    score is a product with the target node's embedding.
    """

    def __init__(self, node_features, node_labels, list_action_space, bilin_q=1,
                 embed_dim=64, mlp_hidden=64, max_lv=1, gm='mean_field', device='cpu'):
        '''
        bilin_q: bilinear q or not
        mlp_hidden: mlp hidden layer size
        max_lv: max rounds of message passing
        '''
        super(QNetNode, self).__init__()
        self.node_features = node_features
        self.node_labels = node_labels
        self.list_action_space = list_action_space
        self.total_nodes = len(list_action_space)

        self.bilin_q = bilin_q
        self.embed_dim = embed_dim
        self.mlp_hidden = mlp_hidden
        self.max_lv = max_lv
        self.gm = gm

        # The bilinear head emits an embed_dim vector that forward() multiplies
        # with the target node's embedding; otherwise it emits a scalar directly.
        last_wout = embed_dim if bilin_q else 1
        self.bias_target = Parameter(torch.Tensor(1, embed_dim))
        if mlp_hidden:
            self.linear_1 = nn.Linear(embed_dim * 2, mlp_hidden)
            self.linear_out = nn.Linear(mlp_hidden, last_wout)
        else:
            self.linear_out = nn.Linear(embed_dim * 2, last_wout)

        self.w_n2l = Parameter(torch.Tensor(node_features.size()[1], embed_dim))
        self.bias_n2l = Parameter(torch.Tensor(embed_dim))
        self.bias_picked = Parameter(torch.Tensor(1, embed_dim))
        self.conv_params = nn.Linear(embed_dim, embed_dim)
        self.norm_tool = GraphNormTool(normalize=True, gm=self.gm, device=device)
        weights_init(self)

    def make_spmat(self, n_rows, n_cols, row_idx, col_idx):
        """Return an ``n_rows x n_cols`` sparse float matrix with a single 1.0
        at ``(row_idx, col_idx)``, on the same device as the parameters.

        Uses ``torch.sparse_coo_tensor`` instead of the deprecated
        ``torch.sparse.FloatTensor``. The device is taken from the parameters
        rather than using ``.cuda()``, which targets the current CUDA device.
        """
        device = next(self.parameters()).device
        idxes = torch.tensor([[row_idx], [col_idx]], dtype=torch.long, device=device)
        values = torch.ones(1, device=device)
        return torch.sparse_coo_tensor(idxes, values, (n_rows, n_cols))

    def forward(self, time_t, states, actions, greedy_acts=False, is_inference=False):
        """Compute Q-value predictions for a batch of states.

        Args:
            time_t: not used by this method; kept for interface compatibility.
            states: iterable of ``(target_node, graph, picked_node)`` tuples;
                ``picked_node`` may be None.
            actions: None to score every candidate node; otherwise one node
                index per state, scored alone.
            greedy_acts: if True, ``actions`` is replaced by the result of
                ``node_greedy_actions`` over the predictions.
            is_inference: when True, autograd is disabled for the per-state
                computation.

        Returns:
            (actions, list_pred): the (possibly greedy) actions and a list with
            one prediction tensor per state.
        """
        if self.node_features.data.is_sparse:
            input_node_linear = torch.spmm(self.node_features, self.w_n2l)
        else:
            input_node_linear = torch.mm(self.node_features, self.w_n2l)
        input_node_linear += self.bias_n2l

        # TODO the number of target nodes is batch_size, it actually parallelizes
        target_nodes, batch_graph, picked_nodes = zip(*states)

        list_pred = []
        for i in range(len(batch_graph)):
            # Candidate nodes for this state; None means all nodes are candidates.
            region = self.list_action_space[target_nodes[i]]
            node_embed = input_node_linear.clone()

            if picked_nodes[i] is not None:
                with torch.set_grad_enabled(mode=not is_inference):
                    # Add bias_picked to the picked node's row and switch the
                    # candidates to that node's action space.
                    picked_sp = self.make_spmat(self.total_nodes, 1, picked_nodes[i], 0)
                    node_embed += torch.spmm(picked_sp, self.bias_picked)
                    region = self.list_action_space[picked_nodes[i]]

            if not self.bilin_q:
                with torch.set_grad_enabled(mode=not is_inference):
                    # Without the bilinear head, the target node's row gets bias_target.
                    target_sp = self.make_spmat(self.total_nodes, 1, target_nodes[i], 0)
                    node_embed += torch.spmm(target_sp, self.bias_target)

            with torch.set_grad_enabled(mode=not is_inference):
                device = self.node_features.device
                adj = self.norm_tool.norm_extra(batch_graph[i].get_extra_adj(device))

                input_message = node_embed
                node_embed = F.relu(input_message)
                for _ in range(self.max_lv):
                    n2npool = torch.spmm(adj, node_embed)
                    node_linear = self.conv_params(n2npool)
                    # Residual connection back to the initial node messages.
                    node_embed = F.relu(node_linear + input_message)

                target_embed = node_embed[target_nodes[i], :].view(-1, 1)
                if region is not None:
                    node_embed = node_embed[region]
                graph_embed = torch.mean(node_embed, dim=0, keepdim=True)

                if actions is None:
                    # Score every candidate: pair each row with the pooled embedding.
                    graph_embed = graph_embed.repeat(node_embed.size()[0], 1)
                else:
                    act_idx = region.index(actions[i]) if region is not None else actions[i]
                    node_embed = node_embed[act_idx, :].view(1, -1)

                embed_s_a = torch.cat((node_embed, graph_embed), dim=1)
                if self.mlp_hidden:
                    embed_s_a = F.relu(self.linear_1(embed_s_a))
                raw_pred = self.linear_out(embed_s_a)

                if self.bilin_q:
                    raw_pred = torch.mm(raw_pred, target_embed)
                list_pred.append(raw_pred)

        if greedy_acts:
            actions, _ = node_greedy_actions(target_nodes, picked_nodes, list_pred, self)

        return actions, list_pred
class QNetNode(nn.Module):
    """Q-network scoring candidate actions from graph, label and action embeddings.

    ``forward`` concatenates a mean-pooled graph embedding, an embedding of the
    state's modified labels, and an embedding of each candidate action. It maps
    the result to one value per candidate and keeps the maximum per state.
    """

    def __init__(self, node_features, node_labels, list_action_space, n_injected,
                 bilin_q=1, embed_dim=64, mlp_hidden=64, max_lv=1, gm='mean_field',
                 device='cpu'):
        '''
        bilin_q: bilinear q or not (only stored on the instance here)
        mlp_hidden: mlp hidden layer size
        max_lv: max rounds of message passing
        n_injected: multiplier on the label-encoder input size, which is
            (number of classes) * n_injected
        '''
        super(QNetNode, self).__init__()
        self.node_features = node_features
        # Convert to a plain int so torch.eye / nn.Linear receive real integer
        # sizes rather than 0-d tensors.
        n_classes = int(node_labels.max()) + 1
        # Rows of this identity matrix serve as one-hot label encodings.
        self.identity = torch.eye(n_classes, device=node_labels.device)
        self.n_injected = n_injected

        self.list_action_space = list_action_space
        self.total_nodes = len(list_action_space)

        self.bilin_q = bilin_q
        self.embed_dim = embed_dim
        self.mlp_hidden = mlp_hidden
        self.max_lv = max_lv
        self.gm = gm

        # Head input: graph, label and action embeddings, each embed_dim wide.
        if mlp_hidden:
            self.linear_1 = nn.Linear(embed_dim * 3, mlp_hidden)
            self.linear_out = nn.Linear(mlp_hidden, 1)
        else:
            self.linear_out = nn.Linear(embed_dim * 3, 1)

        self.w_n2l = Parameter(torch.Tensor(node_features.size()[1], embed_dim))
        self.bias_n2l = Parameter(torch.Tensor(embed_dim))

        self.conv_params = nn.Linear(embed_dim, embed_dim)
        self.norm_tool = GraphNormTool(normalize=True, gm=self.gm, device=device)
        # weights_init runs before the label encoders below are created, so those
        # keep nn.Linear's default initialization.
        weights_init(self)

        input_dim = n_classes * self.n_injected
        self.label_encoder_1 = nn.Linear(input_dim, mlp_hidden)
        self.label_encoder_2 = nn.Linear(mlp_hidden, embed_dim)
        # Device used at forward time; taken from node_features, not from `device`.
        self.device = self.node_features.device

    def to_onehot(self, labels):
        """Return one one-hot row per label, taken from ``self.identity``."""
        return self.identity[labels].view(-1, self.identity.shape[1])

    def get_label_embedding(self, labels):
        """Encode ``labels`` as a single ``(1, embed_dim)`` vector.

        The one-hot rows are flattened into one row before the two-layer
        encoder. Their total width must therefore equal the encoder's input size
        ``n_classes * n_injected``, which presumably means one label per
        injected node (TODO: confirm).
        """
        onehot = self.to_onehot(labels).view(1, -1)
        hidden = F.relu(self.label_encoder_1(onehot))
        return F.relu(self.label_encoder_2(hidden))

    def get_action_label_encoding(self, label):
        """One-hot encode ``label`` and zero-pad each row to ``embed_dim`` columns.

        This gives label actions the same width as node-embedding actions.
        Assumes the number of classes is at most ``embed_dim``.
        """
        onehot = self.to_onehot(label)
        pad = torch.zeros((onehot.shape[0], self.embed_dim - onehot.shape[1]),
                          device=onehot.device)
        return torch.cat((onehot, pad), dim=1)

    def get_graph_embedding(self, adj):
        """Run ``max_lv`` rounds of message passing over ``adj``.

        Returns:
            (graph_embed, node_embed): the ``(1, embed_dim)`` mean over all
            nodes, and the per-node embeddings.
        """
        if self.node_features.data.is_sparse:
            node_embed = torch.spmm(self.node_features, self.w_n2l)
        else:
            node_embed = torch.mm(self.node_features, self.w_n2l)
        node_embed += self.bias_n2l

        input_message = node_embed
        node_embed = F.relu(input_message)
        for _ in range(self.max_lv):
            n2npool = torch.spmm(adj, node_embed)
            node_linear = self.conv_params(n2npool)
            # Residual connection back to the initial node messages.
            node_embed = F.relu(node_linear + input_message)

        graph_embed = torch.mean(node_embed, dim=0, keepdim=True)
        return graph_embed, node_embed

    def make_spmat(self, n_rows, n_cols, row_idx, col_idx):
        """Return an ``n_rows x n_cols`` sparse float matrix with a single 1.0
        at ``(row_idx, col_idx)``, on the same device as the parameters.

        Uses ``torch.sparse_coo_tensor`` instead of the deprecated
        ``torch.sparse.FloatTensor``. The device is taken from the parameters
        rather than using ``.cuda()``, which targets the current CUDA device.
        """
        device = next(self.parameters()).device
        idxes = torch.tensor([[row_idx], [col_idx]], dtype=torch.long, device=device)
        values = torch.ones(1, device=device)
        return torch.sparse_coo_tensor(idxes, values, (n_rows, n_cols))

    def forward(self, time_t, states, actions, greedy_acts=False, is_inference=False):
        """Score candidate actions for a batch of states.

        Args:
            time_t: step index. When it equals 2, ``actions[i]`` are labels
                (one-hot encoded and zero-padded). Otherwise they index rows of
                the node embeddings.
            states: iterable of ``(graph, modified_labels)`` pairs. States whose
                graph is None are skipped and keep a prediction of 0.
            actions: per-state candidate actions.
            greedy_acts: if True, also collect the arg-max candidate per state.
            is_inference: disables autograd when True.

        Returns:
            (greedy_actions, preds): ``preds[i]`` is the maximum score over the
            candidates of state ``i``. ``greedy_actions`` holds the arg-max
            candidate of each non-skipped state (empty unless ``greedy_acts``).
        """
        preds = torch.zeros(len(states), device=self.device)
        batch_graph, modified_labels = zip(*states)
        greedy_actions = []
        with torch.set_grad_enabled(mode=not is_inference):
            for i in range(len(batch_graph)):
                # NOTE(review): skipped states append nothing to greedy_actions,
                # so its indices stop lining up with `states` whenever a graph
                # is None. Confirm callers account for this.
                if batch_graph[i] is None:
                    continue
                adj = self.norm_tool.norm_extra(
                    batch_graph[i].get_extra_adj(self.device))

                graph_embed, node_embed = self.get_graph_embedding(adj)
                label_embed = self.get_label_embedding(modified_labels[i])
                if time_t != 2:
                    action_embed = node_embed[actions[i]].view(-1, self.embed_dim)
                else:
                    action_embed = self.get_action_label_encoding(actions[i])

                # Pair the shared state embedding with every candidate action.
                embed_s = torch.cat((graph_embed, label_embed), dim=1)
                embed_s = embed_s.repeat(len(action_embed), 1)
                embed_s_a = torch.cat((embed_s, action_embed), dim=1)
                if self.mlp_hidden:
                    embed_s_a = F.relu(self.linear_1(embed_s_a))
                raw_pred = self.linear_out(embed_s_a)

                if greedy_acts:
                    action_id = raw_pred.argmax(0)
                    greedy_actions.append(actions[i][action_id])
                preds[i] += raw_pred.max()
        return greedy_actions, preds