Example #1
def create_partition(edge, node_feature, p=50, q=1):
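    # p: number of partitions; q: number of clusters to use in each batch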
    # Pooling
    edge_pool = torch.zeros(edge[0].shape)
    for adj in edge:
        edge_pool[adj.nonzero()] = 1
    # Partitioning
    idx_nodes = np.array([n for n in range(edge_pool.shape[0])],
                         dtype=np.int32)
    part_adj, parts = partition_graph(sp.csr_matrix(edge_pool), idx_nodes, p)

    # Creating subgraph from randomly chosen clusters
    batch = []
    random.shuffle(parts)
    for idx in range(q):
        while len(parts[idx]) == 0:  # Ignore empty clusters
            idx += 1
        for node in parts[idx]:
            batch.append(node)
    new_edge = torch.zeros((edge.shape[0], len(batch), len(batch)))
    new_node_feature = torch.zeros(
        (node_feature.shape[0], len(batch), node_feature.shape[2]))
    for i in range(edge.shape[0]):
        new_edge[i] = edge[i][batch][:, batch]
        new_node_feature[i] = node_feature[i][batch]
    edge = new_edge
    node_feature = new_node_feature
    return edge, node_feature
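A minimal, self-contained sketch of the subgraph-extraction step above, using hypothetical shapes and a hand-picked cluster in place of the METIS-backed partition_graph result; it only illustrates pooling the temporal adjacency stack and the fancy-indexing slice edge[i][batch][:, batch].

import torch

# Hypothetical sizes: 3 time steps, 6 nodes, 4 feature channels.
T, N, C = 3, 6, 4
edge = (torch.rand(T, N, N) > 0.7).float()
node_feature = torch.rand(T, N, C)

# Pool the temporal stack: an edge exists if it appears at any time step.
edge_pool = (edge.sum(dim=0) > 0).float()

# Pretend the partitioner returned this cluster of node indices.
batch = [0, 2, 5]

# Slice the chosen nodes out of every time step (same indexing as above).
new_edge = torch.stack([edge[t][batch][:, batch] for t in range(T)])
new_node_feature = torch.stack([node_feature[t][batch] for t in range(T)])
print(new_edge.shape, new_node_feature.shape)  # (3, 3, 3) and (3, 3, 4)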
Example #2
def preprocess_val_test_afm(adj,
                            features,
                            features_idx,
                            features_val,
                            y_val,
                            val_mask,
                            y_test,
                            test_mask,
                            visible_data,
                            num_clusters,
                            diag_lambda=-1):
    """Do graph partitioning and preprocessing for SGD training. Patition validation and test set in the same time"""

    # Do graph partitioning
    part_adj, parts = partition_utils.partition_graph(adj, visible_data,
                                                      num_clusters)
    if diag_lambda == -1:
        part_adj = normalize_adj(part_adj)
    elif diag_lambda == -2:
        part_adj = sym_normalize_adj(part_adj)
    elif diag_lambda == 0 and FLAGS.model == 'gat_nfm':
        part_adj = part_adj
    else:
        part_adj = normalize_adj_diag_enhance(part_adj, diag_lambda)
    parts = [np.array(pt) for pt in parts]

    # TODO: feature_idx / feature_val depend only on the validation and test sets themselves; there is no need to include the training set
    features_val_batches = [[], [],
                            []]  # [features_sp, features_idx, features_val]
    support_batches = []
    y_val_batches = []
    val_mask_batches = []
    y_test_batches = []
    test_mask_batches = []
    total_nnz = 0
    for pt in parts:
        features_val_batches[0].append(sparse_to_tuple(
            features[pt, :]))  # features_sp
        features_val_batches[1].append(features_idx[pt, :])
        features_val_batches[2].append(features_val[pt, :])

        now_part = part_adj[pt, :][:, pt]
        total_nnz += now_part.count_nonzero()
        support_batches.append(sparse_to_tuple(now_part))
        y_val_batches.append(y_val[pt, :])
        y_test_batches.append(y_test[pt, :])

        val_pt = []
        test_pt = []
        for newidx, idx in enumerate(pt):
            if val_mask[idx]:
                val_pt.append(newidx)
            if test_mask[idx]:
                test_pt.append(newidx)
        val_mask_batches.append(sample_mask(val_pt, len(pt)))
        test_mask_batches.append(sample_mask(test_pt, len(pt)))
    features_test_batches = features_val_batches
    return (parts, features_val_batches, features_test_batches,
            support_batches, y_val_batches, y_test_batches, val_mask_batches,
            test_mask_batches)
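A self-contained sketch of the mask-remapping step above, under the assumption that sample_mask simply builds a boolean mask of the given length; global validation indices inside a cluster are converted to positions local to that cluster.

import numpy as np

def sample_mask_sketch(idx, n):
    # Assumed behaviour of the repo's sample_mask: boolean mask of length n.
    mask = np.zeros(n, dtype=bool)
    mask[idx] = True
    return mask

# Hypothetical cluster of global node ids and a global validation mask.
pt = np.array([7, 2, 9, 4])
val_mask = np.zeros(12, dtype=bool)
val_mask[[2, 9]] = True

# Positions within the cluster that belong to the validation set.
val_pt = [newidx for newidx, idx in enumerate(pt) if val_mask[idx]]
print(sample_mask_sketch(val_pt, len(pt)))  # [False  True  True False]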
Example #3
 def __init__(self, adj_matrix, train_nodes, num_clusters):
     assert(adj_matrix.diagonal().sum() == 0)  # make sure the diagonal is zero
     # make sure is symmetric
     assert((adj_matrix != adj_matrix.T).nnz == 0)
     self.adj_matrix = adj_matrix
     self.lap_matrix = normalize(adj_matrix+sp.eye(adj_matrix.shape[0]))
     self.train_nodes = train_nodes
     self.num_clusters = num_clusters
     self.parts = partition_graph(
         adj_matrix, train_nodes, num_clusters)
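The lap_matrix above row-normalizes the self-loop-augmented adjacency. A hedged scipy sketch of that construction, assuming normalize performs l1 row normalization so that A_hat = D^{-1}(A + I):

import numpy as np
import scipy.sparse as sp

def row_normalize_with_self_loops(adj):
    # Add self-loops, then divide each row by its degree.
    adj = adj + sp.eye(adj.shape[0], format='csr')
    rowsum = np.asarray(adj.sum(axis=1)).flatten()
    d_inv = sp.diags(1.0 / np.maximum(rowsum, 1e-12))
    return d_inv.dot(adj)

adj = sp.csr_matrix(np.array([[0, 1, 0],
                              [1, 0, 1],
                              [0, 1, 0]], dtype=float))
print(row_normalize_with_self_loops(adj).toarray())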
Example #4
def preprocess_train_afm(adj,
                         features,
                         features_idx,
                         features_val,
                         y_train,
                         train_mask,
                         visible_data,
                         num_clusters,
                         diag_lambda=-1,
                         sparse_input=False):
    """Do graph partitioning and preprocessing for SGD training. Patition train dataset."""
    part_adj, parts = partition_utils.partition_graph(adj, visible_data,
                                                      num_clusters)
    if diag_lambda == -1:
        part_adj = normalize_adj(part_adj)
    elif diag_lambda == -2:
        part_adj = sym_normalize_adj(part_adj)
    elif diag_lambda == 0 and FLAGS.model == 'gat_nfm':
        part_adj = unnormlize_adj(part_adj)
    else:
        part_adj = normalize_adj_diag_enhance(part_adj, diag_lambda)
    parts = [np.array(pt) for pt in parts]

    features_batches = [[], [], []]
    support_batches = []
    y_train_batches = []
    train_mask_batches = []
    total_nnz = 0
    for pt in parts:
        if sparse_input:
            features_batches[0].append(sparse_to_tuple(
                features[pt, :]))  # features_sp
        else:
            features_batches[0].append(features[pt, :])  # dense features also go in slot 0
        features_batches[1].append(features_idx[pt, :])
        features_batches[2].append(features_val[pt, :])
        now_part = part_adj[pt, :][:, pt]
        total_nnz += now_part.count_nonzero()
        support_batches.append(sparse_to_tuple(now_part))
        y_train_batches.append(y_train[pt, :])

        train_pt = []
        for newidx, idx in enumerate(pt):
            if train_mask[idx]:
                train_pt.append(newidx)
        train_mask_batches.append(sample_mask(train_pt, len(pt)))
    return (parts, features_batches, support_batches, y_train_batches,
            train_mask_batches)
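A hedged sketch of the diag_lambda > 0 branch shared by these preprocessing helpers, assuming normalize_adj_diag_enhance follows the Cluster-GCN scheme A' = D^{-1}(A + I) followed by A'' = A' + lambda * diag(A'):

import numpy as np
import scipy.sparse as sp

def normalize_adj_diag_enhance_sketch(adj, diag_lambda):
    # Row-normalize the self-loop-augmented adjacency, then boost its diagonal.
    adj = adj + sp.eye(adj.shape[0], format='csr')
    rowsum = np.asarray(adj.sum(axis=1)).flatten()
    adj = sp.diags(1.0 / np.maximum(rowsum, 1e-12)).dot(adj)
    return adj + diag_lambda * sp.diags(adj.diagonal())

adj = sp.csr_matrix((np.ones(4), ([0, 1, 1, 2], [1, 0, 2, 1])), shape=(3, 3))
print(normalize_adj_diag_enhance_sketch(adj, diag_lambda=1.0).toarray())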
Example #5
def preprocess(adj,
               features,
               y_train,
               train_mask,
               visible_data,
               num_clusters,
               diag_lambda=-1,
               label_cluster=None):
  """Do graph partitioning and preprocessing for SGD training."""

  # Do graph partitioning
  part_adj, parts = partition_utils.partition_graph(adj, visible_data,
                                                    num_clusters, label_cluster)
  if diag_lambda == -1:
    part_adj = normalize_adj(part_adj)
  else:
    part_adj = normalize_adj_diag_enhance(part_adj, diag_lambda)
  parts = [np.array(pt) for pt in parts]

  features_batches = []
  support_batches = []
  y_train_batches = []
  train_mask_batches = []
  total_nnz = 0
  part_cluster = []
  for pt in parts:
    features_batches.append(features[pt, :])
    now_part = part_adj[pt, :][:, pt]
    part_cluster.append(now_part)
    total_nnz += now_part.count_nonzero()
    support_batches.append(sparse_to_tuple(now_part))
    y_train_batches.append(y_train[pt, :])

    train_pt = []
    for newidx, idx in enumerate(pt):
      if train_mask[idx]:
        train_pt.append(newidx)
    train_mask_batches.append(sample_mask(train_pt, len(pt)))
  return (parts, features_batches, support_batches, y_train_batches,
          train_mask_batches, part_cluster)
Example #6
def main(unused_argv):
  """Main function for running experiments."""
  # Load data
  (train_adj, full_adj, train_feats, test_feats, y_train, y_val, y_test,
   train_mask, val_mask, test_mask, _, val_data, test_data, num_data,
   visible_data) = load_data(FLAGS.data_prefix, FLAGS.dataset, FLAGS.precalc)

  # Partition graph and do preprocessing
  if FLAGS.bsize > 1:
    _, parts = partition_utils.partition_graph(train_adj, visible_data,
                                               FLAGS.num_clusters)
    parts = [np.array(pt) for pt in parts]
  else:
    (parts, features_batches, support_batches, y_train_batches,
     train_mask_batches) = utils.preprocess(train_adj, train_feats, y_train,
                                            train_mask, visible_data,
                                            FLAGS.num_clusters,
                                            FLAGS.diag_lambda)

  (_, val_features_batches, val_support_batches, y_val_batches,
   val_mask_batches) = utils.preprocess(full_adj, test_feats, y_val, val_mask,
                                        np.arange(num_data),
                                        FLAGS.num_clusters_val,
                                        FLAGS.diag_lambda)

  (_, test_features_batches, test_support_batches, y_test_batches,
   test_mask_batches) = utils.preprocess(full_adj, test_feats, y_test,
                                         test_mask, np.arange(num_data),
                                         FLAGS.num_clusters_test,
                                         FLAGS.diag_lambda)
  idx_parts = list(range(len(parts)))

  # Some preprocessing
  model_func = models.GCN

  # Define placeholders
  placeholders = {
      'support':
          tf.sparse_placeholder(tf.float32),
      'features':
          tf.placeholder(tf.float32),
      'labels':
          tf.placeholder(tf.float32, shape=(None, y_train.shape[1])),
      'labels_mask':
          tf.placeholder(tf.int32),
      'dropout':
          tf.placeholder_with_default(0., shape=()),
      'num_features_nonzero':
          tf.placeholder(tf.int32)  # helper variable for sparse dropout
  }

  # Create model
  model = model_func(
      placeholders,
      input_dim=test_feats.shape[1],
      logging=True,
      multilabel=FLAGS.multilabel,
      norm=FLAGS.layernorm,
      precalc=FLAGS.precalc,
      num_layers=FLAGS.num_layers)

  # Initialize session
  sess = tf.Session()
  tf.set_random_seed(seed)

  # Init variables
  sess.run(tf.global_variables_initializer())
  saver = tf.train.Saver()
  cost_val = []
  total_training_time = 0.0
  # Train model
  for epoch in range(FLAGS.epochs):
    t = time.time()
    np.random.shuffle(idx_parts)
    if FLAGS.bsize > 1:
      (features_batches, support_batches, y_train_batches,
       train_mask_batches) = utils.preprocess_multicluster(
           train_adj, parts, train_feats, y_train, train_mask,
           FLAGS.num_clusters, FLAGS.bsize, FLAGS.diag_lambda)
      for pid in range(len(features_batches)):
        # Use preprocessed batch data
        features_b = features_batches[pid]
        support_b = support_batches[pid]
        y_train_b = y_train_batches[pid]
        train_mask_b = train_mask_batches[pid]
        # Construct feed dictionary
        feed_dict = utils.construct_feed_dict(features_b, support_b, y_train_b,
                                              train_mask_b, placeholders)
        feed_dict.update({placeholders['dropout']: FLAGS.dropout})
        # Training step
        outs = sess.run([model.opt_op, model.loss, model.accuracy],
                        feed_dict=feed_dict)
    else:
      np.random.shuffle(idx_parts)
      for pid in idx_parts:
        # Use preprocessed batch data
        features_b = features_batches[pid]
        support_b = support_batches[pid]
        y_train_b = y_train_batches[pid]
        train_mask_b = train_mask_batches[pid]
        # Construct feed dictionary
        feed_dict = utils.construct_feed_dict(features_b, support_b, y_train_b,
                                              train_mask_b, placeholders)
        feed_dict.update({placeholders['dropout']: FLAGS.dropout})
        # Training step
        outs = sess.run([model.opt_op, model.loss, model.accuracy],
                        feed_dict=feed_dict)

    total_training_time += time.time() - t
    print_str = 'Epoch: %04d ' % (epoch + 1) + 'training time: {:.5f} '.format(
        total_training_time) + 'train_acc= {:.5f} '.format(outs[2])

    # Validation
    if FLAGS.validation:
      cost, acc, micro, macro = evaluate(sess, model, val_features_batches,
                                         val_support_batches, y_val_batches,
                                         val_mask_batches, val_data,
                                         placeholders)
      cost_val.append(cost)
      print_str += 'val_acc= {:.5f} '.format(
          acc) + 'mi F1= {:.5f} ma F1= {:.5f} '.format(micro, macro)

    tf.logging.info(print_str)

    if epoch > FLAGS.early_stopping and cost_val[-1] > np.mean(
        cost_val[-(FLAGS.early_stopping + 1):-1]):
      tf.logging.info('Early stopping...')
      break

  tf.logging.info('Optimization Finished!')

  # Save model
  saver.save(sess, FLAGS.save_name)

  # Load model (using CPU for inference)
  with tf.device('/cpu:0'):
    sess_cpu = tf.Session(config=tf.ConfigProto(device_count={'GPU': 0}))
    sess_cpu.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.restore(sess_cpu, FLAGS.save_name)
    # Testing
    test_cost, test_acc, micro, macro = evaluate(
        sess_cpu, model, test_features_batches, test_support_batches,
        y_test_batches, test_mask_batches, test_data, placeholders)
    print_str = 'Test set results: ' + 'cost= {:.5f} '.format(
        test_cost) + 'accuracy= {:.5f} '.format(
            test_acc) + 'mi F1= {:.5f} ma F1= {:.5f}'.format(micro, macro)
    tf.logging.info(print_str)
Example #7
    def test2(self):
        self.restore_model(self.resume_iters)

        self.G.eval()

        self.iteration = self.test_len - self.num_clips + 1
        self.graph_size = self.test_size
        self.model_save_dir = os.path.join(self.checkpoint_dir, self.model_dir)
        self.Th_error = np.load(self.model_save_dir + '_threshold.npy')

        abnormal = torch.zeros(self.num_clips, self.graph_size)

        abnormal = abnormal.to(self.device)

        tp = 0.
        tn = 0.
        fp = 0.
        fn = 0.

        for idx in range(self.iteration):
            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #
            node_feature = torch.zeros(
                (self.num_clips, self.graph_size, self.graph_ch),
                dtype=torch.float)
            edge = torch.zeros(
                (self.num_clips, self.graph_size, self.graph_size),
                dtype=torch.float)

            for d in range(self.num_clips):
                node_path = self.dataset_name + '/node/node' + str(
                    idx + d + self.train_len + 1) + '.npy'
                edge_path = self.dataset_name + '/graph/graph' + str(
                    idx + d + self.train_len + 1) + '.npy'
                dict_path = self.dataset_name + '/dict/node_dict' + str(
                    idx + d + self.train_len +
                    1) + '.npy' if self.dataset_name != 'DBLP5' else None
                node_feature[d], edge[d], dic, _ = load_graph(
                    node_path, edge_path, dict_path)

            help = torch.eye(edge.shape[1], dtype=torch.float)

            node_exist = torch.sum(torch.mul(help, edge),
                                   dim=-1)  # whether or not the node exists
            edge = torch.mul(1. - help, edge)

            p = 50  # number of partitions
            q = 1  # number of clusters to use in each batch
            edge_pool = torch.zeros(edge[0].shape)
            # print(edge.shape)
            # print(edge_pool.shape)
            # print(node_feature.shape)
            for adj in edge:
                # print(adj.shape)
                edge_pool[adj.nonzero()] = 1
            train_data = np.array([n for n in range(edge_pool.shape[0])],
                                  dtype=np.int32)
            part_adj, parts = partition_utils.partition_graph(
                sp.csr_matrix(edge_pool), train_data, p)
            # print(type(part_adj),type(parts))
            # print(part_adj)
            # print(parts)
            # print(len(parts))
            # for part in parts:
            #     print(len(part))
            # print(part_adj.shape)
            batch = []
            random.shuffle(parts)
            for idx in range(q):
                while len(parts[idx]) == 0:
                    idx += 1
                for node in parts[idx]:
                    batch.append(node)
            # print(batch)
            # print(len(batch))
            # my_part = edge_pool[batch][:,batch]
            # print(my_part.shape)
            new_edge = torch.zeros((self.num_clips, len(batch), len(batch)))
            new_node_feature = torch.zeros(
                (self.num_clips, len(batch), self.graph_ch))
            for i in range(self.num_clips):
                new_edge[i] = edge[i][batch][:, batch]
                new_node_feature[i] = node_feature[i][batch]
            # edge = my_part
            # print(new_edge.shape)
            edge = new_edge
            node_feature = new_node_feature

            edge = edge.to(self.device)
            node_feature = node_feature.to(self.device)

            # =================================================================================== #
            #                             2. Train the Auto-encoder                              #
            # =================================================================================== #
            recon_a, recon_x, node_embedding = self.G(node_feature, edge)

            if self.dataset_name == 'reddit_data':
                recon_a = self.egde_weight(recon_a)

            a_score = self.loss_function(recon_a, edge, graph=False)
            x_score = self.loss_function(recon_x, node_feature, graph=False)

            Anomaly_score = (self.ax_w * a_score + (1 - self.ax_w) * x_score)

            record1 = (Anomaly_score > self.Th_error[0]
                       ).float()  # == abnormal[idx:idx + self.num_clips]
            # record2 = (Anomaly_score > self.Th_error[1]).float() == abnormal[idx:idx + self.num_clips]

            for t in range(record1.shape[0]):
                for n in range(record1.shape[1]):
                    if record1[t, n] == abnormal[t, n]:
                        if record1[t, n] == 0:
                            tp += 1.
                        else:
                            tn += 1.
                    else:
                        if record1[t, n] == 0:
                            fp += 1.
                        else:
                            fn += 1.

            anomaly_cpu = Anomaly_score.detach().cpu().numpy()
            anomaly_max = np.max(anomaly_cpu)
            max_indicate = np.where(anomaly_cpu == anomaly_max)
            if node_exist[max_indicate[0], max_indicate[1]]:
                print('idx={}'.format(idx))
                print(anomaly_max)
                print(max_indicate)
                print('\n')
            else:
                print('Not exists')

            torch.cuda.empty_cache()

        acc = (tp + tn) / (tp + tn + fp + fn)
        recall = tp / (tp + fn)
        prec = tp / (tp + fp)
        f1 = 2 * (recall * prec) / (recall + prec)
        print('\n')
        print(acc)
        print(recall)
        print(prec)
        print(f1)
        f = open("dict.txt", "w")
        f.write(str(dic))
        f.close()
Example #8
    def test(self):
        my_start_time = time.time()  # Adrian
        self.restore_model(self.resume_iters)

        self.G.eval()
        self.set_requires_grad(self.G, False)

        #self.device = torch.device('cpu')
        #self.G.to(self.device)

        self.iteration = self.test_len - self.num_clips + 1
        self.graph_size = self.test_size
        self.model_save_dir = os.path.join(self.checkpoint_dir, self.model_dir)

        with torch.no_grad():
            self.Th_error = np.load(self.model_save_dir + '_threshold.npy')

            tp = 0.
            tn = 0.
            fp = 0.
            fn = 0.

            for idx in range(self.iteration):
                # =================================================================================== #
                #                             1. Preprocess input data                                #
                # =================================================================================== #
                node_feature = torch.zeros(
                    (self.num_clips, self.graph_size, self.graph_ch),
                    dtype=torch.float)
                edge = torch.zeros(
                    (self.num_clips, self.graph_size, self.graph_size),
                    dtype=torch.float)
                abnormal = torch.zeros((self.num_clips, self.graph_size),
                                       dtype=torch.float)

                for d in range(self.num_clips):
                    node_path = self.dataset_name + '/node/testnode' + str(
                        idx + d + 1) + '.npy'
                    edge_path = self.dataset_name + '/graph/testgraph' + str(
                        idx + d + 1) + '.npy'
                    ab_path = self.dataset_name + '/abnormal/abnormal' + str(
                        idx + d + 1) + '.npy'
                    node_feature[d], edge[d], _, abnormal[d] = load_graph(
                        node_path, edge_path, abnormal_path=ab_path)

                help = torch.eye(edge.shape[1], dtype=torch.float)

                node_exist = torch.sum(torch.mul(
                    help, edge), dim=-1)  # whether or not the node exists
                edge = torch.mul(1. - help, edge)

                p = 50  # number of partitions
                q = 1  # number of clusters to use in each batch
                edge_pool = torch.zeros(edge[0].shape)
                # print(edge.shape)
                # print(edge_pool.shape)
                # print(node_feature.shape)
                for adj in edge:
                    # print(adj.shape)
                    edge_pool[adj.nonzero()] = 1
                train_data = np.array([n for n in range(edge_pool.shape[0])],
                                      dtype=np.int32)
                part_adj, parts = partition_utils.partition_graph(
                    sp.csr_matrix(edge_pool), train_data, p)
                # print(type(part_adj),type(parts))
                # print(part_adj)
                # print(parts)
                # print(len(parts))
                # for part in parts:
                #     print(len(part))
                # print(part_adj.shape)
                batch = []
                random.shuffle(parts)
                for idx in range(q):
                    while len(parts[idx]) == 0:
                        idx += 1
                    for node in parts[idx]:
                        batch.append(node)
                # print(batch)
                # print(len(batch))
                # my_part = edge_pool[batch][:,batch]
                # print(my_part.shape)
                new_edge = torch.zeros(
                    (self.num_clips, len(batch), len(batch)))
                new_node_feature = torch.zeros(
                    (self.num_clips, len(batch), self.graph_ch))
                for i in range(self.num_clips):
                    new_edge[i] = edge[i][batch][:, batch]
                    new_node_feature[i] = node_feature[i][batch]
                # edge = my_part
                # print(new_edge.shape)
                edge = new_edge
                node_feature = new_node_feature

                edge = edge.to(self.device)
                node_feature = node_feature.to(self.device)

                # =================================================================================== #
                #                             2. Train the Auto-encoder                              #
                # =================================================================================== #
                recon_a, recon_x, node_embedding = self.G(node_feature, edge)

                if self.dataset_name == 'reddit_data':
                    recon_a = self.egde_weight(recon_a)

                a_score = self.loss_function(recon_a, edge, graph=False)
                x_score = self.loss_function(recon_x,
                                             node_feature,
                                             graph=False)

                Anomaly_score = (self.ax_w * a_score +
                                 (1 - self.ax_w) * x_score)

                _, indicates = torch.topk(Anomaly_score.flatten().cpu(),
                                          k=10,
                                          dim=-1)
                record1 = torch.zeros_like(Anomaly_score,
                                           dtype=torch.float).cpu()
                for ind in indicates:
                    # t = ind // self.graph_size
                    # n = ind % self.graph_size
                    t = ind // len(batch)  # Adrian
                    n = ind % len(batch)  # Adrian
                    record1[t, n] = 1.
                # print(len(batch), Anomaly_score.shape, record1.shape)

                #record1 = (((Anomaly_score - Anomaly_score.mean()) / Anomaly_score.std()) > 1).float().cpu()

                #record1 = (Anomaly_score > self.Th_error[0]).float().cpu() * node_exist # == abnormal[idx:idx + self.num_clips]
                #record2 = (Anomaly_score > self.Th_error[1]).float() == abnormal[idx:idx + self.num_clips]

                for t in range(record1.shape[0]):
                    for n in range(record1.shape[1]):
                        if record1[t, n] == abnormal[t, n]:
                            if record1[t, n] == 0:
                                tn += 1.
                            else:
                                tp += 1.
                        else:
                            if record1[t, n] == 0:
                                fn += 1.
                            else:
                                fp += 1.

                # for item in record1[1]:
                #     t = item // self.graph_size
                #     n = item % self.graph_size
                #     if abnormal[t, n] == 0:
                #         fp += 1.
                #     else:
                #         tp += 1.

                # arecord += torch.sum(record1.float())
                # brecord += torch.sum(record2.float())
                #hist, bins = np.histogram(torch.flatten(Anomaly_score.cpu()),
                #bins=[0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1])
                #print(hist)
                #print(bins)

                self.reset_grad()
                del recon_a, recon_x, Anomaly_score, a_score, x_score, record1, abnormal
                del node_feature, edge
                torch.cuda.empty_cache()
        confusion_matrix = np.array([[tp, fp], [fn, tn]])
        print(confusion_matrix)

        acc = (tp + tn) / (tp + tn + fp + fn)
        recall = tp / (tp + fn)
        prec = tp / (tp + fp)
        f1 = 2 * (recall * prec) / (recall + prec + 1e-6)
        tnr = tn / (tn + fp)
        print(acc)
        print(recall)
        print(prec)
        print(f1)
        print(tnr)

        # p = tp/(tp+fp)
        # ground = 1200 if self.dataset_name=="reddit_data" else 200
        # recall = tp/ground
        # print(p)
        # print(recall)

        # Adrian
        my_et = time.time() - my_start_time
        my_et = str(datetime.timedelta(seconds=my_et))[:-7]
        print("Test time:", my_et)
Example #9
    def train(self):
        # my_size = 1000
        # my_data = np.random.randint(2, size=(my_size,my_size))
        # # my_data = np.concatenate((np.zeros((int(my_size * .8),my_size)),np.ones((int(my_size * .2),my_size))), axis=0)
        # my_data = torch.from_numpy(my_data)
        # my_other_data = np.random.randint(2, size=(my_size,self.graph_ch))
        # # my_other_data = np.zeros((my_size,self.graph_ch))
        # my_other_data = torch.from_numpy(my_other_data)
        # my_edge = torch.zeros((self.num_clips, my_size, my_size), dtype=torch.float)
        # my_node_feature = torch.zeros((self.num_clips, my_size, self.graph_ch), dtype=torch.float)
        my_start_time = time.time()  # Adrian
        start_iters = self.resume_iters if not self.new_start else 0
        self.restore_model(self.resume_iters)

        self.iteration = self.train_len - self.num_clips + 1
        self.graph_size = self.train_size

        start_epoch = (int)(start_iters / self.iteration)
        start_batch_id = start_iters - start_epoch * self.iteration

        # loop for epoch
        start_time = time.time()
        lr = self.init_lr

        self.set_requires_grad([self.G], True)

        self.G.train()

        for epoch in range(start_epoch, self.epoch):
            if self.decay_flag and epoch > self.decay_epoch:
                lr = self.init_lr * (self.epoch - epoch) / (
                    self.epoch - self.decay_epoch)  # linear decay
                self.update_lr(lr)

            for idx in range(start_batch_id, self.iteration):
                # =================================================================================== #
                #                             1. Preprocess input data                                #
                # =================================================================================== #
                node_feature = torch.zeros(
                    (self.num_clips, self.graph_size, self.graph_ch),
                    dtype=torch.float)
                edge = torch.zeros(
                    (self.num_clips, self.graph_size, self.graph_size),
                    dtype=torch.float)

                for d in range(self.num_clips):
                    node_path = self.dataset_name + '/node/node' + str(
                        idx + d + 1) + '.npy'
                    edge_path = self.dataset_name + '/graph/graph' + str(
                        idx + d + 1) + '.npy'
                    dict_path = self.dataset_name + '/dict/node_dict' + str(
                        idx + d +
                        1) + '.npy' if self.dataset_name != 'DBLP5' else None
                    node_feature[d], edge[d], _, _ = load_graph(
                        node_path, edge_path, dict_path)

                # edge = my_edge
                # node_feature = my_node_feature

                help = torch.eye(edge.shape[1], dtype=torch.float)

                node_exist = torch.sum(torch.mul(
                    help, edge), dim=-1)  # whether or not the node exists
                edge = torch.mul(1. - help, edge)

                p = 50  # number of partitions
                q = 1  # number of clusters to use in each batch
                edge_pool = torch.zeros(edge[0].shape)
                # print(edge.shape)
                # print(edge_pool.shape)
                # print(node_feature.shape)
                for adj in edge:
                    # print(adj.shape)
                    edge_pool[adj.nonzero()] = 1
                train_data = np.array([n for n in range(edge_pool.shape[0])],
                                      dtype=np.int32)
                part_adj, parts = partition_utils.partition_graph(
                    sp.csr_matrix(edge_pool), train_data, p)
                # print(type(part_adj),type(parts))
                # print(part_adj)
                # print(parts)
                # print(len(parts))
                # for part in parts:
                #     print(len(part))
                # print(part_adj.shape)
                batch = []
                random.shuffle(parts)
                for idx in range(q):
                    while len(parts[idx]) == 0:
                        idx += 1
                    for node in parts[idx]:
                        batch.append(node)
                # print(batch)
                # print(len(batch))
                # my_part = edge_pool[batch][:,batch]
                # print(my_part.shape)
                new_edge = torch.zeros(
                    (self.num_clips, len(batch), len(batch)))
                new_node_feature = torch.zeros(
                    (self.num_clips, len(batch), self.graph_ch))
                for i in range(self.num_clips):
                    new_edge[i] = edge[i][batch][:, batch]
                    new_node_feature[i] = node_feature[i][batch]
                # edge = my_part
                # print(new_edge.shape)
                edge = new_edge
                node_feature = new_node_feature

                edge = edge.to(self.device)
                node_feature = node_feature.to(self.device)

                loss = {}

                # =================================================================================== #
                #                             2. Train the Auto-encoder                              #
                # =================================================================================== #
                node_feature = F.dropout(node_feature, self.denoising)
                edge = F.dropout(edge, self.denoising)
                recon_a, recon_x, _ = self.G(node_feature, edge)

                if self.dataset_name == 'reddit_data':
                    recon_a = self.egde_weight(recon_a)

                self.recon_a_error = self.loss_function(recon_a, edge)
                self.recon_x_error = self.loss_function(recon_x, node_feature)

                self.Reconstruction_error = (
                    self.ax_w * self.recon_a_error +
                    (1 - self.ax_w) * self.recon_x_error)

                # Logging.
                loss['Edge_reconstruction_error'] = self.recon_a_error.item()
                loss['feature_reconstruction_error'] = self.recon_x_error.item(
                )
                # loss['G/loss_cycle'] = self.cycle_loss.item()
                loss['Reconstruction_error'] = self.Reconstruction_error.item()

                del recon_a
                del recon_x
                torch.cuda.empty_cache()

                self.reset_grad()
                self.Reconstruction_error.backward()
                self.g_optimizer.step()

                # =================================================================================== #
                #                                 4. Miscellaneous                                    #
                # =================================================================================== #
                start_iters += 1

                # Print out training information.
                if idx % self.print_freq == 0:
                    et = time.time() - start_time
                    et = str(datetime.timedelta(seconds=et))[:-7]
                    log = "Elapsed [{}], Epoch [{}/{}], Iteration [{}/{}]".format(
                        et, epoch + 1, self.epoch, idx + 1, self.iteration)
                    for tag, value in loss.items():
                        if 'error' in tag:  # != 'G/lable' and tag !='O/lable':
                            log += ", {}: {:.4f}".format(tag, value)
                            if self.use_tensorboard:
                                self.logger.scalar_summary(
                                    tag, value, start_iters)
                    print(log)

                torch.cuda.empty_cache()

                # Save model checkpoints.
                if (idx + 1) % self.save_freq == 0:
                    self.save(self.checkpoint_dir, start_iters)
                    torch.cuda.empty_cache()

            # After an epoch, start_batch_id is set to zero
            # non-zero value is only for the first epoch after loading pre-trained model
            start_batch_id = 0

            # save model for final step
            self.save(self.checkpoint_dir, start_iters)

            torch.cuda.empty_cache()

        # calculate the threshold
        self.thresold()

        # Adrian
        my_et = time.time() - my_start_time
        my_et = str(datetime.timedelta(seconds=my_et))[:-7]
        print("Train time:", my_et)
Example #10
def main(unused_argv):
    """Main function for running experiments."""
    # Load data
    (train_adj, full_adj, train_feats, test_feats, y_train, y_val, y_test,
     train_mask, val_mask, test_mask, _, val_data, test_data, num_data,
     visible_data) = load_data(FLAGS.data_prefix, FLAGS.dataset, FLAGS.precalc,
                               FLAGS.dataset, FLAGS.graph_dir)

    # Partition graph and do preprocessing
    if FLAGS.bsize > 1:
        parts_file = FLAGS.dataset + "-parts-txt"
        parts_pickle_file = FLAGS.dataset + "-parts-pickle" if (
            FLAGS.dataset != 'None') else FLAGS.custom_data + "-parts-pickle"
        print("parts_pickle_file", parts_pickle_file)
        if False and os.path.exists(parts_pickle_file):
            f = open(parts_pickle_file, 'rb')
            parts = pickle.load(f)
            f.close()
        else:
            _, parts = partition_utils.partition_graph(train_adj, visible_data,
                                                       FLAGS.num_clusters)
            #f = open(parts_pickle_file, 'wb')
            #pickle.dump(parts, f)
            #f.close()

        if not os.path.exists(parts_file):
            with open(parts_file, 'w') as f:
                s = "{"
                part_i = 0
                for part in parts:
                    s += '"%d"' % part_i + ':' + str(part) + "\n"
                    if part_i != len(parts) - 1:
                        s += ','
                    part_i += 1
                f.write(s + "}")

        parts = [np.array(pt) for pt in parts]

    else:
        (parts, features_batches, support_batches, y_train_batches,
         train_mask_batches) = utils.preprocess(train_adj, train_feats,
                                                y_train, train_mask,
                                                visible_data,
                                                FLAGS.num_clusters,
                                                FLAGS.diag_lambda)

    # (_, val_features_batches, val_support_batches, y_val_batches,
    #  val_mask_batches) = utils.preprocess(full_adj, test_feats, y_val, val_mask,
    #                                       np.arange(num_data),
    #                                       FLAGS.num_clusters_val,
    #                                       FLAGS.diag_lambda)

    # (_, test_features_batches, test_support_batches, y_test_batches,
    #  test_mask_batches) = utils.preprocess(full_adj, test_feats, y_test,
    #                                        test_mask, np.arange(num_data),
    #                                        FLAGS.num_clusters_test,
    #                                        FLAGS.diag_lambda)
    idx_parts = list(range(len(parts)))

    # Some preprocessing
    model_func = models.GCN

    # Define placeholders
    placeholders = {
        'support': tf.sparse_placeholder(tf.float32),
        'features': tf.placeholder(tf.float32),
        'labels': tf.placeholder(tf.float32, shape=(None, y_train.shape[1])),
        'labels_mask': tf.placeholder(tf.int32),
        'dropout': tf.placeholder_with_default(0., shape=()),
        'num_features_nonzero':
        tf.placeholder(tf.int32)  # helper variable for sparse dropout
    }

    # Create model
    model = model_func(placeholders,
                       input_dim=test_feats.shape[1],
                       logging=True,
                       multilabel=FLAGS.multilabel,
                       norm=FLAGS.layernorm,
                       precalc=FLAGS.precalc,
                       num_layers=FLAGS.num_layers)

    # Initialize session
    sess = tf.Session()
    tf.set_random_seed(seed)

    # Init variables
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    cost_val = []
    total_training_time = 0.0
    sampling_time = 0
    training_time = 0
    extraction_time = 0
    # Train model
    for epoch in range(FLAGS.epochs):
        t = time.time()
        np.random.shuffle(idx_parts)
        if FLAGS.bsize > 1:
            t0 = time.time()
            (features_batches, support_batches, y_train_batches,
             train_mask_batches, t) = utils.preprocess_multicluster(
                 train_adj, parts, train_feats, y_train, train_mask,
                 FLAGS.num_clusters, FLAGS.bsize, FLAGS.diag_lambda)
            t1 = time.time()
            extraction_time += t
            sampling_time += t1 - t0
            for pid in range(len(features_batches)):
                # Use preprocessed batch data
                features_b = features_batches[pid]
                support_b = support_batches[pid]
                y_train_b = y_train_batches[pid]
                train_mask_b = train_mask_batches[pid]
                # Construct feed dictionary
                feed_dict = utils.construct_feed_dict(features_b, support_b,
                                                      y_train_b, train_mask_b,
                                                      placeholders)
                feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                t0 = time.time()
                # Training step
                outs = sess.run([model.opt_op, model.loss, model.accuracy],
                                feed_dict=feed_dict)
                t1 = time.time()
                training_time += t1 - t0
        else:
            np.random.shuffle(idx_parts)
            for pid in idx_parts:
                # Use preprocessed batch data
                features_b = features_batches[pid]
                support_b = support_batches[pid]
                y_train_b = y_train_batches[pid]
                train_mask_b = train_mask_batches[pid]
                # Construct feed dictionary
                feed_dict = utils.construct_feed_dict(features_b, support_b,
                                                      y_train_b, train_mask_b,
                                                      placeholders)
                feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                # Training step
                outs = sess.run([model.opt_op, model.loss, model.accuracy],
                                feed_dict=feed_dict)

        total_training_time += time.time() - t

        print_str = 'Epoch: %04d ' % (
            epoch + 1) + 'training time: {:.5f} '.format(
                total_training_time) + 'train_acc= {:.5f} '.format(outs[2])
    print("sampling_time (clustergcn)", sampling_time)
    print("training_time:", training_time)

    #print(sampling_time,"Total sampling time")
    #print(training_time,"Total training time")
    #print(extraction_time,"Total extraction time")
    #   return
    #   # Validation
    #   if FLAGS.validation:
    #     cost, acc, micro, macro = evaluate(sess, model, val_features_batches,
    #                                        val_support_batches, y_val_batches,
    #                                        val_mask_batches, val_data,
    #                                        placeholders)
    #     cost_val.append(cost)
    #     print_str += 'val_acc= {:.5f} '.format(
    #         acc) + 'mi F1= {:.5f} ma F1= {:.5f} '.format(micro, macro)

    #   tf.logging.info(print_str)

    #   if epoch > FLAGS.early_stopping and cost_val[-1] > np.mean(
    #       cost_val[-(FLAGS.early_stopping + 1):-1]):
    #     tf.logging.info('Early stopping...')
    #     break

    # tf.logging.info('Optimization Finished!')
    return
    # Save model
    saver.save(sess, FLAGS.save_name)

    # Load model (using CPU for inference)
    with tf.device('/cpu:0'):
        sess_cpu = tf.Session(config=tf.ConfigProto(device_count={'GPU': 0}))
        sess_cpu.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess_cpu, FLAGS.save_name)
        # Testing
        test_cost, test_acc, micro, macro = evaluate(
            sess_cpu, model, test_features_batches, test_support_batches,
            y_test_batches, test_mask_batches, test_data, placeholders)
        print_str = 'Test set results: ' + 'cost= {:.5f} '.format(
            test_cost) + 'accuracy= {:.5f} '.format(
                test_acc) + 'mi F1= {:.5f} ma F1= {:.5f}'.format(micro, macro)
        tf.logging.info(print_str)
Example #11
def main(unused_argv):
    """Main function for running experiments."""
    # Load data
    utils.tab_printer(FLAGS.flag_values_dict())
    (full_adj, feats, y_train, y_val, y_test, train_mask, val_mask, test_mask,
     train_data, val_data, test_data,
     num_data) = utils.load_ne_data_transductive_sparse(
         FLAGS.data_prefix, FLAGS.dataset, FLAGS.precalc,
         list(map(float, FLAGS.split)))

    # Partition graph and do preprocessing
    if FLAGS.bsize > 1:  # multi cluster per epoch
        _, parts = partition_utils.partition_graph(full_adj,
                                                   np.arange(num_data),
                                                   FLAGS.num_clusters)

        parts = [np.array(pt) for pt in parts]
    else:
        (parts, features_batches, support_batches, y_train_batches,
         train_mask_batches) = utils.preprocess(full_adj,
                                                feats,
                                                y_train,
                                                train_mask,
                                                np.arange(num_data),
                                                FLAGS.num_clusters,
                                                FLAGS.diag_lambda,
                                                sparse_input=True)
    # validation & test sets are partitioned at the same time
    (_, val_features_batches, test_features_batches, val_support_batches,
     y_val_batches, y_test_batches,
     val_mask_batches, test_mask_batches) = utils.preprocess_val_test(
         full_adj, feats, y_val, val_mask, y_test, test_mask,
         np.arange(num_data), FLAGS.num_clusters_val, FLAGS.diag_lambda)

    # (_, val_features_batches, val_support_batches, y_val_batches,
    #  val_mask_batches) = utils.preprocess(full_adj, feats, y_val, val_mask,
    #                                       np.arange(num_data),
    #                                       FLAGS.num_clusters_val,
    #                                       FLAGS.diag_lambda)
    # # test set
    # (_, test_features_batches, test_support_batches, y_test_batches,
    #  test_mask_batches) = utils.preprocess(full_adj, feats, y_test,
    #                                        test_mask, np.arange(num_data),
    #                                        FLAGS.num_clusters_test,
    #                                        FLAGS.diag_lambda)
    idx_parts = list(range(len(parts)))

    # Define placeholders
    placeholders = {
        'support': tf.sparse_placeholder(tf.float32),
        # 'features':
        #     tf.placeholder(tf.float32),
        'features': tf.sparse_placeholder(tf.float32),
        'labels': tf.placeholder(tf.float32, shape=(None, y_train.shape[1])),
        'labels_mask': tf.placeholder(tf.int32),
        'dropout': tf.placeholder_with_default(0., shape=()),
        'fm_dropout': tf.placeholder_with_default(0., shape=()),
        'gat_dropout': tf.placeholder_with_default(0.,
                                                   shape=()),  # gat attn drop
        'num_features_nonzero':
        tf.placeholder(tf.int32)  # helper variable for sparse dropout
    }

    # Create model
    if FLAGS.model == 'gcn':
        model = models.GCN(placeholders,
                           input_dim=feats.shape[1],
                           logging=True,
                           multilabel=FLAGS.multilabel,
                           norm=FLAGS.layernorm,
                           precalc=FLAGS.precalc,
                           num_layers=FLAGS.num_layers,
                           residual=False,
                           sparse_inputs=True)
    elif FLAGS.model == 'gcn_nfm':
        model = models.GCN_NFM(placeholders,
                               input_dim=feats.shape[1],
                               logging=True,
                               multilabel=FLAGS.multilabel,
                               norm=FLAGS.layernorm,
                               precalc=FLAGS.precalc,
                               num_layers=FLAGS.num_layers,
                               residual=False,
                               sparse_inputs=True)
    elif FLAGS.model == 'gat_nfm':
        gat_layers = list(map(int, FLAGS.gat_layers))
        model = models.GAT_NFM(placeholders,
                               input_dim=feats.shape[1],
                               logging=True,
                               multilabel=FLAGS.multilabel,
                               norm=FLAGS.layernorm,
                               precalc=FLAGS.precalc,
                               num_layers=FLAGS.num_layers,
                               residual=False,
                               sparse_inputs=True,
                               gat_layers=gat_layers)
    else:
        raise ValueError(str(FLAGS.model))

    # Initialize session
    sess = tf.Session()

    # Init variables
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    cost_val = []
    acc_val = []
    total_training_time = 0.0
    # Train model
    for epoch in range(FLAGS.epochs):
        t = time.time()
        np.random.shuffle(idx_parts)
        if FLAGS.bsize > 1:
            (features_batches, support_batches, y_train_batches,
             train_mask_batches) = utils.preprocess_multicluster(
                 full_adj, parts, feats, y_train, train_mask,
                 FLAGS.num_clusters, FLAGS.bsize, FLAGS.diag_lambda, True)
            for pid in range(len(features_batches)):
                # Use preprocessed batch data
                features_b = features_batches[pid]
                support_b = support_batches[pid]
                y_train_b = y_train_batches[pid]
                train_mask_b = train_mask_batches[pid]
                # Construct feed dictionary
                feed_dict = utils.construct_feed_dict(features_b, support_b,
                                                      y_train_b, train_mask_b,
                                                      placeholders)
                feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                feed_dict.update(
                    {placeholders['fm_dropout']: FLAGS.fm_dropout})
                feed_dict.update(
                    {placeholders['gat_dropout']: FLAGS.gat_dropout})
                # Training step
                outs = sess.run([model.opt_op, model.loss, model.accuracy],
                                feed_dict=feed_dict)
                # debug
                outs = sess.run([model.opt_op, model.loss, model.accuracy],
                                feed_dict=feed_dict)
        else:
            np.random.shuffle(idx_parts)
            for pid in idx_parts:
                # Use preprocessed batch data
                features_b = features_batches[pid]
                support_b = support_batches[pid]
                y_train_b = y_train_batches[pid]
                train_mask_b = train_mask_batches[pid]
                # Construct feed dictionary
                feed_dict = utils.construct_feed_dict(features_b, support_b,
                                                      y_train_b, train_mask_b,
                                                      placeholders)
                feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                feed_dict.update(
                    {placeholders['fm_dropout']: FLAGS.fm_dropout})
                feed_dict.update(
                    {placeholders['gat_dropout']: FLAGS.gat_dropout})
                # Training step
                outs = sess.run([model.opt_op, model.loss, model.accuracy],
                                feed_dict=feed_dict)

        total_training_time += time.time() - t
        print_str = 'Epoch: %04d ' % (
            epoch + 1) + 'training time: {:.5f} '.format(
                total_training_time) + 'train_acc= {:.5f} '.format(outs[2])

        # Validation
        ## TODO: merge validation into the training procedure
        if FLAGS.validation:
            cost, acc, micro, macro = evaluate(sess, model,
                                               val_features_batches,
                                               val_support_batches,
                                               y_val_batches, val_mask_batches,
                                               val_data, placeholders)
            cost_val.append(cost)
            acc_val.append(acc)
            print_str += 'val_acc= {:.5f} '.format(
                acc) + 'mi F1= {:.5f} ma F1= {:.5f} '.format(micro, macro)

        # tf.logging.info(print_str)
        print(print_str)

        if epoch > FLAGS.early_stopping and cost_val[-1] > np.mean(
                cost_val[-(FLAGS.early_stopping + 1):-1]):
            tf.logging.info('Early stopping...')
            break

        ### use acc early stopping, lower performance than using loss
        # if epoch > FLAGS.early_stopping and acc_val[-1] < np.mean(
        #     acc_val[-(FLAGS.early_stopping + 1):-1]):
        #   tf.logging.info('Early stopping...')
        #   break

    tf.logging.info('Optimization Finished!')

    # Save model
    saver.save(sess, FLAGS.save_name)

    # Load model (using CPU for inference)
    with tf.device('/cpu:0'):
        sess_cpu = tf.Session(config=tf.ConfigProto(device_count={'GPU': 0}))
        sess_cpu.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess_cpu, FLAGS.save_name)
        # Testing
        test_cost, test_acc, micro, macro = evaluate(
            sess_cpu, model, test_features_batches, val_support_batches,
            y_test_batches, test_mask_batches, test_data, placeholders)
        print_str = 'Test set results: ' + 'cost= {:.5f} '.format(
            test_cost) + 'accuracy= {:.5f} '.format(
                test_acc) + 'mi F1= {:.5f} ma F1= {:.5f}'.format(micro, macro)
        tf.logging.info(print_str)
Example #12
def main():
    print("Program start, environment initializing ...")
    torch.autograd.set_detect_anomaly(True)
    args = parameter_parser()
    utils.print2file(str(args), args.logDir, True)

    if args.device >= 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.device)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    pic = {}

    # check if pickles, otherwise load data
    # pickle_name = args.data_prefix+args.dataset+"-"+str(args.bsize)+"-"+str(args.num_clusters)+"_main"+".pickle"
    # if os.path.isfile(pickle_name):
    #     print("Loading Pickle.")
    #     load_time = time.time()
    #     pic = pickle.load(open(pickle_name, "rb"))
    #     print("Loading Done. " + str(time.time()-load_time) + " seconds.")
    # else:
    if True:
        print("Data Pre-processing")
        # Load data
        (pic["train_adj"], full_adj, pic["train_feats"], pic["test_feats"],
         pic["y_train"], y_val, y_test, pic["train_mask"], pic["val_mask"],
         test_mask, _, pic["val_data"], pic["test_data"], num_data,
         visible_data) = utils.load_data(args.data_prefix,
                                         args.dataset,
                                         args.precalc,
                                         amazon=True)

        print("Partition graph and do preprocessing")
        if args.bsize > 1:
            _, pic["parts"] = partition_utils.partition_graph(
                pic["train_adj"], visible_data, args.num_clusters)
            pic["parts"] = [np.array(pt) for pt in pic["parts"]]

            (pic["features_batches"], pic["support_batches"],
             pic["y_train_batches"],
             pic["train_mask_batches"]) = utils.preprocess_multicluster_v2(
                 pic["train_adj"], pic["parts"], pic["train_feats"],
                 pic["y_train"], pic["train_mask"], args.num_clusters,
                 args.bsize, args.diag_lambda)

        else:
            (pic["parts"], pic["features_batches"], pic["support_batches"],
             pic["y_train_batches"],
             pic["train_mask_batches"]) = utils.preprocess(
                 pic["train_adj"], pic["train_feats"], pic["y_train"],
                 pic["train_mask"], visible_data, args.num_clusters,
                 args.diag_lambda)

        (_, pic["val_features_batches"], pic["val_support_batches"],
         pic["y_val_batches"], pic["val_mask_batches"]) = utils.preprocess(
             full_adj, pic["test_feats"], y_val, pic["val_mask"],
             np.arange(num_data), args.num_clusters_val, args.diag_lambda)

        (_, pic["test_features_batches"], pic["test_support_batches"],
         pic["y_test_batches"], pic["test_mask_batches"]) = utils.preprocess(
             full_adj, pic["test_feats"], y_test, test_mask,
             np.arange(num_data), args.num_clusters_test, args.diag_lambda)

        # pickle.dump(pic, open(pickle_name, "wb"))

    idx_parts = list(range(len(pic["parts"])))
    print("Preparing model ...")
    model = StackedGCN(args,
                       pic["test_feats"].shape[1],
                       pic["y_train"].shape[1],
                       precalc=args.precalc,
                       num_layers=args.num_layers,
                       norm=args.layernorm)

    w_server = model.cpu().state_dict()

    print("Start training ...")
    model_saved = "./model/" + args.dataset + "-" + args.logDir[6:-4] + ".pt"

    try:
        best_val_acc = 0  # track the best validation accuracy across all epochs
        for epoch in range(args.epochs):
            # Training process
            w_locals, loss_locals, epoch_acc = [], [], []
            all_time = []

            for pid in range(len(pic["features_batches"])):
                # for pid in range(10):
                # Use preprocessed batch data
                package = {
                    "features": pic["features_batches"][pid],
                    "support": pic["support_batches"][pid],
                    "y_train": pic["y_train_batches"][pid],
                    "train_mask": pic["train_mask_batches"][pid]
                }

                model.load_state_dict(w_server)
                out_dict = slave_run_train(model, args, package, pid)

                w_locals.append(copy.deepcopy(out_dict['params']))
                loss_locals.append(copy.deepcopy(out_dict['loss']))
                all_time.append(out_dict["time"])
                epoch_acc.append(out_dict["acc"])

            # update global weights
            a_start_time = time.time()
            if args.agg == 'avg':
                w_server = average_agg(w_locals, args.dp)
            elif args.agg == 'att':
                w_server = weighted_agg(w_locals,
                                        w_server,
                                        args.epsilon,
                                        args.ord,
                                        dp=args.dp)
            else:
                exit('Unrecognized aggregation')

            model.load_state_dict(w_server)
            # agg_time = time.time() - a_start_time
            # print(str(sum(all_time)/len(all_time) + agg_time))
            print2file(
                'Epoch: ' + str(epoch) + ' Average Train acc: ' +
                str(sum(epoch_acc) / len(epoch_acc)), args.logDir, True)

            if epoch % args.val_freq == 0:
                val_cost, val_acc, val_micro, val_macro = evaluate(
                    model,
                    args,
                    pic["val_features_batches"],
                    pic["val_support_batches"],
                    pic["y_val_batches"],
                    pic["val_mask_batches"],
                    pic["val_data"],
                    pid="validation")

                log_str = 'Validation set results: ' + 'cost= {:.5f} '.format(
                    val_cost) + 'accuracy= {:.5f} '.format(
                        val_acc) + 'mi F1= {:.5f} ma F1= {:.5f}'.format(
                            val_micro, val_macro)
                print2file(log_str, args.logDir, True)

                if val_acc > best_val_acc:
                    best_val_acc = val_acc
                    torch.save(model.state_dict(), model_saved)
                    print2file(
                        "Best val_acc: " + str(best_val_acc) +
                        " with epoch: " + str(epoch), args.logDir, True)

        torch.save(
            model.state_dict(),
            "./model/" + args.dataset + "-" + args.logDir[6:-4] + "Done.pt")
        print2file("Training Done. Model Saved.", args.logDir, True)
        # Test Model
        # Perform two test, one with last model, another with best val_acc model
        # 1)
        test_cost, test_acc, micro, macro = evaluate(
            model,
            args,
            pic["test_features_batches"],
            pic["test_support_batches"],
            pic["y_test_batches"],
            pic["test_mask_batches"],
            pic["test_data"],
            pid="Final test")

        log_str = 'Test set results: ' + 'cost= {:.5f} '.format(
            test_cost) + 'accuracy= {:.5f} '.format(
                test_acc) + 'mi F1= {:.5f} ma F1= {:.5f}'.format(micro, macro)
        print2file(log_str, args.logDir, True)

        # 2)
        test_model = StackedGCN(args,
                                pic["test_feats"].shape[1],
                                pic["y_train"].shape[1],
                                precalc=args.precalc,
                                num_layers=args.num_layers,
                                norm=args.layernorm)
        test_model.load_state_dict(torch.load(model_saved))
        test_model.eval()
        test_cost, test_acc, micro, macro = evaluate(
            test_model,
            args,
            pic["test_features_batches"],
            pic["test_support_batches"],
            pic["y_test_batches"],
            pic["test_mask_batches"],
            pic["test_data"],
            pid="Best test")

        log_str = 'Test set results: ' + 'cost= {:.5f} '.format(
            test_cost) + 'accuracy= {:.5f} '.format(
                test_acc) + 'mi F1= {:.5f} ma F1= {:.5f}'.format(micro, macro)
        print2file(log_str, args.logDir, True)

    except KeyboardInterrupt:
        print("==" * 20)
        print("Existing from training earlier than the plan.")

    print("End..so far so good.")