Example #3
# Imports for these snippets; the remaining helpers (Dataset,
# normalize_feature, preprocess, load_victim_model, accuracy, StaticGraph,
# GraphNormTool, weights_init, node_greedy_actions, and the parsed `args`)
# are assumed to come from the surrounding DeepRobust-style project.
import networkx as nx
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter


def init_setup():
    data = Dataset(root='/tmp/', name=args.dataset, setting='gcn')

    data.features = normalize_feature(data.features)
    adj, features, labels = data.adj, data.features, data.labels

    # nx.from_scipy_sparse_matrix was removed in networkx 3.0;
    # on newer versions use nx.from_scipy_sparse_array instead.
    StaticGraph.graph = nx.from_scipy_sparse_matrix(adj)
    dict_of_lists = nx.to_dict_of_lists(StaticGraph.graph)

    idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
    device = torch.device('cuda' if args.ctx == 'gpu' else 'cpu')

    # black box setting
    adj, features, labels = preprocess(adj,
                                       features,
                                       labels,
                                       preprocess_adj=False,
                                       sparse=True,
                                       device=device)
    victim_model = load_victim_model(data,
                                     device=device,
                                     file_path=args.saved_model)
    setattr(victim_model, 'norm_tool',
            GraphNormTool(normalize=True, gm='gcn', device=device))
    output = victim_model.predict(features, adj)
    loss_test = F.nll_loss(output[idx_test], labels[idx_test])
    acc_test = accuracy(output[idx_test], labels[idx_test])
    print("Test set results:", "loss= {:.4f}".format(loss_test.item()),
          "accuracy= {:.4f}".format(acc_test.item()))

    return features, labels, idx_val, idx_test, victim_model, dict_of_lists, adj
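For reference, the evaluation at the end of init_setup() is just NLL loss on
log-probabilities plus argmax accuracy. A minimal, self-contained sketch with
random toy data; toy_accuracy is a hypothetical stand-in for the project's
accuracy helper.

import torch
import torch.nn.functional as F

def toy_accuracy(output, labels):
    # fraction of rows whose argmax matches the label
    return (output.argmax(dim=1) == labels).float().mean()

output = F.log_softmax(torch.randn(10, 3), dim=1)  # log-probs, like predict()
labels = torch.randint(0, 3, (10,))
loss_test = F.nll_loss(output, labels)
acc_test = toy_accuracy(output, labels)
print("loss= {:.4f}".format(loss_test.item()),
      "accuracy= {:.4f}".format(acc_test.item()))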
class QNetNode(nn.Module):
    def __init__(self,
                 node_features,
                 node_labels,
                 list_action_space,
                 bilin_q=1,
                 embed_dim=64,
                 mlp_hidden=64,
                 max_lv=1,
                 gm='mean_field',
                 device='cpu'):
        '''
        bilin_q: whether to use a bilinear Q function
        mlp_hidden: MLP hidden layer size
        max_lv: max rounds of message passing
        '''
        super(QNetNode, self).__init__()
        self.node_features = node_features
        self.node_labels = node_labels
        self.list_action_space = list_action_space
        self.total_nodes = len(list_action_space)

        self.bilin_q = bilin_q
        self.embed_dim = embed_dim
        self.mlp_hidden = mlp_hidden
        self.max_lv = max_lv
        self.gm = gm

        if bilin_q:
            last_wout = embed_dim
        else:
            last_wout = 1
            self.bias_target = Parameter(torch.Tensor(1, embed_dim))

        if mlp_hidden:
            self.linear_1 = nn.Linear(embed_dim * 2, mlp_hidden)
            self.linear_out = nn.Linear(mlp_hidden, last_wout)
        else:
            self.linear_out = nn.Linear(embed_dim * 2, last_wout)

        self.w_n2l = Parameter(torch.Tensor(node_features.size()[1],
                                            embed_dim))
        self.bias_n2l = Parameter(torch.Tensor(embed_dim))
        self.bias_picked = Parameter(torch.Tensor(1, embed_dim))
        self.conv_params = nn.Linear(embed_dim, embed_dim)
        self.norm_tool = GraphNormTool(normalize=True,
                                       gm=self.gm,
                                       device=device)
        weights_init(self)

    def make_spmat(self, n_rows, n_cols, row_idx, col_idx):
        # Build an (n_rows x n_cols) sparse matrix with a single 1 at
        # (row_idx, col_idx), used as a one-hot selector in spmm.
        idxes = torch.LongTensor([[row_idx], [col_idx]])
        values = torch.ones(1)

        # torch.sparse.FloatTensor is deprecated in favor of
        # torch.sparse_coo_tensor; both build the same COO tensor here.
        sp = torch.sparse_coo_tensor(idxes, values,
                                     torch.Size([n_rows, n_cols]))
        if next(self.parameters()).is_cuda:
            sp = sp.cuda()
        return sp

    def forward(self,
                time_t,
                states,
                actions,
                greedy_acts=False,
                is_inference=False):

        if self.node_features.data.is_sparse:
            input_node_linear = torch.spmm(self.node_features, self.w_n2l)
        else:
            input_node_linear = torch.mm(self.node_features, self.w_n2l)

        input_node_linear += self.bias_n2l

        # TODO the number of target nodes is batch_size; it actually parallelizes
        target_nodes, batch_graph, picked_nodes = zip(*states)

        list_pred = []
        prefix_sum = []
        for i in range(len(batch_graph)):
            region = self.list_action_space[target_nodes[i]]

            node_embed = input_node_linear.clone()
            if picked_nodes is not None and picked_nodes[i] is not None:
                with torch.set_grad_enabled(mode=not is_inference):
                    picked_sp = self.make_spmat(self.total_nodes, 1,
                                                picked_nodes[i], 0)
                    node_embed += torch.spmm(picked_sp, self.bias_picked)
                    region = self.list_action_space[picked_nodes[i]]

            if not self.bilin_q:
                with torch.set_grad_enabled(mode=not is_inference):
                    # with torch.no_grad():
                    target_sp = self.make_spmat(self.total_nodes, 1,
                                                target_nodes[i], 0)
                    node_embed += torch.spmm(target_sp, self.bias_target)

            with torch.set_grad_enabled(mode=not is_inference):
                device = self.node_features.device
                adj = self.norm_tool.norm_extra(
                    batch_graph[i].get_extra_adj(device))

                lv = 0
                input_message = node_embed

                node_embed = F.relu(input_message)
                while lv < self.max_lv:
                    n2npool = torch.spmm(adj, node_embed)
                    node_linear = self.conv_params(n2npool)
                    merged_linear = node_linear + input_message
                    node_embed = F.relu(merged_linear)
                    lv += 1

                target_embed = node_embed[target_nodes[i], :].view(-1, 1)
                if region is not None:
                    node_embed = node_embed[region]

                graph_embed = torch.mean(node_embed, dim=0, keepdim=True)

                if actions is None:
                    graph_embed = graph_embed.repeat(node_embed.size()[0], 1)
                else:
                    if region is not None:
                        act_idx = region.index(actions[i])
                    else:
                        act_idx = actions[i]
                    node_embed = node_embed[act_idx, :].view(1, -1)

                embed_s_a = torch.cat((node_embed, graph_embed), dim=1)
                if self.mlp_hidden:
                    embed_s_a = F.relu(self.linear_1(embed_s_a))
                raw_pred = self.linear_out(embed_s_a)

                if self.bilin_q:
                    raw_pred = torch.mm(raw_pred, target_embed)
                list_pred.append(raw_pred)

        if greedy_acts:
            actions, _ = node_greedy_actions(target_nodes, picked_nodes,
                                             list_pred, self)

        return actions, list_pred
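When bilin_q is set, forward() scores each candidate action by a dot product
between its state-action embedding and the target node's embedding. A minimal
shape sketch with toy dimensions, independent of the project code:

import torch

embed_dim, n_candidates = 64, 5
state_action = torch.randn(n_candidates, embed_dim)  # linear_out output per candidate
target_embed = torch.randn(embed_dim, 1)             # target node embedding, as a column
q_values = torch.mm(state_action, target_embed)      # (n_candidates, 1) Q-values
print(q_values.shape)                                # torch.Size([5, 1])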
class QNetNode(nn.Module):
    def __init__(self,
                 node_features,
                 node_labels,
                 list_action_space,
                 n_injected,
                 bilin_q=1,
                 embed_dim=64,
                 mlp_hidden=64,
                 max_lv=1,
                 gm='mean_field',
                 device='cpu'):
        '''
        bilin_q: whether to use a bilinear Q function
        mlp_hidden: MLP hidden layer size
        max_lv: max rounds of message passing
        '''
        super(QNetNode, self).__init__()
        self.node_features = node_features
        self.identity = torch.eye(node_labels.max() + 1).to(node_labels.device)
        # self.node_labels = self.to_onehot(node_labels)
        self.n_injected = n_injected

        self.list_action_space = list_action_space
        self.total_nodes = len(list_action_space)

        self.bilin_q = bilin_q
        self.embed_dim = embed_dim
        self.mlp_hidden = mlp_hidden
        self.max_lv = max_lv
        self.gm = gm

        if mlp_hidden:
            self.linear_1 = nn.Linear(embed_dim * 3, mlp_hidden)
            self.linear_out = nn.Linear(mlp_hidden, 1)
        else:
            self.linear_out = nn.Linear(embed_dim * 3, 1)

        self.w_n2l = Parameter(torch.Tensor(node_features.size()[1],
                                            embed_dim))
        self.bias_n2l = Parameter(torch.Tensor(embed_dim))

        # self.bias_picked = Parameter(torch.Tensor(1, embed_dim))
        self.conv_params = nn.Linear(embed_dim, embed_dim)
        self.norm_tool = GraphNormTool(normalize=True,
                                       gm=self.gm,
                                       device=device)
        weights_init(self)

        # Note: these label-encoder layers are created after weights_init(self)
        # runs, so they keep PyTorch's default initialization.
        input_dim = (node_labels.max() + 1) * self.n_injected
        self.label_encoder_1 = nn.Linear(input_dim, mlp_hidden)
        self.label_encoder_2 = nn.Linear(mlp_hidden, embed_dim)
        self.device = self.node_features.device

    def to_onehot(self, labels):
        # map integer labels to one-hot rows via the cached identity matrix
        return self.identity[labels].view(-1, self.identity.shape[1])

    def get_label_embedding(self, labels):
        # int to one hot
        onehot = self.to_onehot(labels).view(1, -1)

        x = F.relu(self.label_encoder_1(onehot))
        x = F.relu(self.label_encoder_2(x))
        return x

    def get_action_label_encoding(self, label):
        # one-hot encode the label, then zero-pad it up to embed_dim
        # (assumes the number of classes is at most embed_dim)
        onehot = self.to_onehot(label)
        zeros = torch.zeros(
            (onehot.shape[0],
             self.embed_dim - onehot.shape[1])).to(onehot.device)
        return torch.cat((onehot, zeros), dim=1)

    def get_graph_embedding(self, adj):
        if self.node_features.data.is_sparse:
            node_embed = torch.spmm(self.node_features, self.w_n2l)
        else:
            node_embed = torch.mm(self.node_features, self.w_n2l)

        node_embed += self.bias_n2l

        input_message = node_embed
        node_embed = F.relu(input_message)

        for i in range(self.max_lv):
            n2npool = torch.spmm(adj, node_embed)
            node_linear = self.conv_params(n2npool)
            merged_linear = node_linear + input_message
            node_embed = F.relu(merged_linear)

        graph_embed = torch.mean(node_embed, dim=0, keepdim=True)
        return graph_embed, node_embed

    def make_spmat(self, n_rows, n_cols, row_idx, col_idx):
        # Build an (n_rows x n_cols) sparse matrix with a single 1 at
        # (row_idx, col_idx), used as a one-hot selector in spmm.
        idxes = torch.LongTensor([[row_idx], [col_idx]])
        values = torch.ones(1)

        # torch.sparse.FloatTensor is deprecated in favor of
        # torch.sparse_coo_tensor; both build the same COO tensor here.
        sp = torch.sparse_coo_tensor(idxes, values,
                                     torch.Size([n_rows, n_cols]))
        if next(self.parameters()).is_cuda:
            sp = sp.cuda()
        return sp

    def forward(self,
                time_t,
                states,
                actions,
                greedy_acts=False,
                is_inference=False):

        preds = torch.zeros(len(states)).to(self.device)

        batch_graph, modified_labels = zip(*states)
        greedy_actions = []
        with torch.set_grad_enabled(mode=not is_inference):

            for i in range(len(batch_graph)):
                if batch_graph[i] is None:
                    continue
                adj = self.norm_tool.norm_extra(batch_graph[i].get_extra_adj(
                    self.device))
                # get graph representation
                graph_embed, node_embed = self.get_graph_embedding(adj)

                # get label representation
                label_embed = self.get_label_embedding(modified_labels[i])

                # get action representation
                if time_t != 2:
                    action_embed = node_embed[actions[i]].view(
                        -1, self.embed_dim)
                else:
                    action_embed = self.get_action_label_encoding(actions[i])

                # concat them and send it to neural network
                embed_s = torch.cat((graph_embed, label_embed), dim=1)
                embed_s = embed_s.repeat(len(action_embed), 1)
                embed_s_a = torch.cat((embed_s, action_embed), dim=1)

                if self.mlp_hidden:
                    embed_s_a = F.relu(self.linear_1(embed_s_a))

                raw_pred = self.linear_out(embed_s_a)

                if greedy_acts:
                    action_id = raw_pred.argmax(0)
                    raw_pred = raw_pred.max()
                    greedy_actions.append(actions[i][action_id])
                else:
                    raw_pred = raw_pred.max()
                # list_pred.append(raw_pred)
                preds[i] += raw_pred

        return greedy_actions, preds
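The core of both variants is the short message-passing loop in
get_graph_embedding(); a minimal, self-contained sketch on a toy 3-node path
graph with random weights, mirroring the loop's structure:

import torch
import torch.nn as nn
import torch.nn.functional as F

embed_dim, max_lv = 8, 1
adj = torch.tensor([[0., 1., 0.],
                    [1., 0., 1.],
                    [0., 1., 0.]]).to_sparse()      # toy adjacency; normalization skipped
conv_params = nn.Linear(embed_dim, embed_dim)

input_message = torch.randn(3, embed_dim)           # node features after w_n2l
node_embed = F.relu(input_message)
for _ in range(max_lv):
    n2npool = torch.spmm(adj, node_embed)           # aggregate neighbor embeddings
    node_embed = F.relu(conv_params(n2npool) + input_message)

graph_embed = node_embed.mean(dim=0, keepdim=True)  # (1, embed_dim)
print(graph_embed.shape)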