Example #1
def main():
    
    # Load data
    start = time.time()
    N, _adj, _feats, _labels, train_adj, train_feats, train_nodes, val_nodes, test_nodes, y_train, y_val, y_test, val_mask, test_mask = utils.load_data(args.dataset)
    print('Loaded data in {:.2f} seconds!'.format(time.time() - start))
    
    # Prepare Train Data
    start = time.time()
    _, parts = utils.partition_graph(train_adj, train_nodes, args.num_clusters_train)
    parts = [np.array(pt) for pt in parts]
    train_features, train_support, y_train = utils.preprocess_multicluster(train_adj, parts, train_feats, y_train, args.num_clusters_train, args.batch_size)    
    print('Train Data pre-processed in {:.2f} seconds!'.format(time.time() - start))
    
    # Prepare Test Data
    if args.test == 1:
        # NOTE: evaluation below uses the validation labels/mask as the
        # "test" set; y_test and test_mask are overwritten here.
        y_test, test_mask = y_val, val_mask
        start = time.time()
        _, test_features, test_support, y_test, test_mask = utils.preprocess(_adj, _feats, y_test, np.arange(N), args.num_clusters_test, test_mask) 
        print('Test Data pre-processed in {:.2f} seconds!'.format(time.time() - start))
    
    # Batch indices (reshuffled every epoch)
    batch_idxs = list(range(len(train_features)))
    
    # Model (_in/_out are assumed module-level feature and class dimensions)
    model = GCN(fan_in=_in, fan_out=_out, layers=args.layers, dropout=args.dropout, normalize=True, bias=False).float()
    model.cuda()

    # Loss Function
    criterion = torch.nn.CrossEntropyLoss()
    
    # Optimization Algorithm
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
        
    # Learning Rate Schedule    
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr, steps_per_epoch=int(args.num_clusters_train/args.batch_size), epochs=args.epochs+1, anneal_strategy='linear')
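    # Note: OneCycleLR fixes total_steps = steps_per_epoch * epochs up front;
    # if the actual number of batches per epoch differs from
    # num_clusters_train / batch_size, scheduler.step() raises a ValueError
    # once that budget is exhausted.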
    model.train()

    
    # Train
    for epoch in range(args.epochs + 1):
        np.random.shuffle(batch_idxs)
        avg_loss = 0
        start = time.time()
        for batch in batch_idxs:
            loss = train(model, criterion, optimizer, train_features[batch], train_support[batch], y_train[batch], dataset=args.dataset)
            if args.lr_scheduler == 1:
                scheduler.step()
            avg_loss += loss.item()
        
        # Write Train stats to tensorboard
        writer.add_scalar('time/train', time.time() - start, epoch)
        writer.add_scalar('loss/train', avg_loss/len(train_features), epoch)
        
    if args.test == 1:    
        # Test on cpu
        f1 = test(model.cpu(), test_features, test_support, y_test, test_mask, device='cpu')
        print('f1: {:.4f}'.format(f1))
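
The train() helper called in the loop above is not shown in this example. Below is a minimal sketch of one plausible implementation, assuming the model takes (features, support), that support is already a torch sparse tensor on the model's device, and that labels are class indices; the name, signature, and body are assumptions, not the original helper.

import torch

def train(model, criterion, optimizer, features, support, labels, dataset=None):
    # One gradient step on a single cluster batch (hypothetical sketch).
    device = next(model.parameters()).device
    features = torch.as_tensor(features, dtype=torch.float32, device=device)
    labels = torch.as_tensor(labels, dtype=torch.long, device=device)
    optimizer.zero_grad()
    logits = model(features, support)  # support: normalized sub-adjacency
    loss = criterion(logits, labels)
    loss.backward()
    optimizer.step()
    return loss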
Example #2
def main(unused_argv):
  """Main function for running experiments."""
  # Load data
  (train_adj, full_adj, train_feats, test_feats, y_train, y_val, y_test,
   train_mask, val_mask, test_mask, _, val_data, test_data, num_data,
   visible_data) = load_data(FLAGS.data_prefix, FLAGS.dataset, FLAGS.precalc)

  # Partition graph and do preprocessing
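  # With bsize > 1, multi-cluster batches are re-sampled inside the training
  # loop every epoch; with bsize == 1, fixed one-cluster batches are built
  # once here.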
  if FLAGS.bsize > 1:
    _, parts = partition_utils.partition_graph(train_adj, visible_data,
                                               FLAGS.num_clusters)
    parts = [np.array(pt) for pt in parts]
  else:
    (parts, features_batches, support_batches, y_train_batches,
     train_mask_batches) = utils.preprocess(train_adj, train_feats, y_train,
                                            train_mask, visible_data,
                                            FLAGS.num_clusters,
                                            FLAGS.diag_lambda)

  (_, val_features_batches, val_support_batches, y_val_batches,
   val_mask_batches) = utils.preprocess(full_adj, test_feats, y_val, val_mask,
                                        np.arange(num_data),
                                        FLAGS.num_clusters_val,
                                        FLAGS.diag_lambda)

  (_, test_features_batches, test_support_batches, y_test_batches,
   test_mask_batches) = utils.preprocess(full_adj, test_feats, y_test,
                                         test_mask, np.arange(num_data),
                                         FLAGS.num_clusters_test,
                                         FLAGS.diag_lambda)
  idx_parts = list(range(len(parts)))

  # Select the model class
  model_func = models.GCN

  # Define placeholders
  placeholders = {
      'support':
          tf.sparse_placeholder(tf.float32),
      'features':
          tf.placeholder(tf.float32),
      'labels':
          tf.placeholder(tf.float32, shape=(None, y_train.shape[1])),
      'labels_mask':
          tf.placeholder(tf.int32),
      'dropout':
          tf.placeholder_with_default(0., shape=()),
      'num_features_nonzero':
          tf.placeholder(tf.int32)  # helper variable for sparse dropout
  }

  # Create model (set the graph-level seed before building the graph so it
  # actually affects op seeds; `seed` is assumed to be defined at module level)
  tf.set_random_seed(seed)
  model = model_func(
      placeholders,
      input_dim=test_feats.shape[1],
      logging=True,
      multilabel=FLAGS.multilabel,
      norm=FLAGS.layernorm,
      precalc=FLAGS.precalc,
      num_layers=FLAGS.num_layers)

  # Initialize session
  sess = tf.Session()

  # Init variables
  sess.run(tf.global_variables_initializer())
  saver = tf.train.Saver()
  cost_val = []
  total_training_time = 0.0
  # Train model
  for epoch in range(FLAGS.epochs):
    t = time.time()
    np.random.shuffle(idx_parts)
    if FLAGS.bsize > 1:
      (features_batches, support_batches, y_train_batches,
       train_mask_batches) = utils.preprocess_multicluster(
           train_adj, parts, train_feats, y_train, train_mask,
           FLAGS.num_clusters, FLAGS.bsize, FLAGS.diag_lambda)
      for pid in range(len(features_batches)):
        # Use preprocessed batch data
        features_b = features_batches[pid]
        support_b = support_batches[pid]
        y_train_b = y_train_batches[pid]
        train_mask_b = train_mask_batches[pid]
        # Construct feed dictionary
        feed_dict = utils.construct_feed_dict(features_b, support_b, y_train_b,
                                              train_mask_b, placeholders)
        feed_dict.update({placeholders['dropout']: FLAGS.dropout})
        # Training step
        outs = sess.run([model.opt_op, model.loss, model.accuracy],
                        feed_dict=feed_dict)
    else:
      np.random.shuffle(idx_parts)
      for pid in idx_parts:
        # Use preprocessed batch data
        features_b = features_batches[pid]
        support_b = support_batches[pid]
        y_train_b = y_train_batches[pid]
        train_mask_b = train_mask_batches[pid]
        # Construct feed dictionary
        feed_dict = utils.construct_feed_dict(features_b, support_b, y_train_b,
                                              train_mask_b, placeholders)
        feed_dict.update({placeholders['dropout']: FLAGS.dropout})
        # Training step
        outs = sess.run([model.opt_op, model.loss, model.accuracy],
                        feed_dict=feed_dict)

    total_training_time += time.time() - t
    print_str = 'Epoch: %04d ' % (epoch + 1) + 'training time: {:.5f} '.format(
        total_training_time) + 'train_acc= {:.5f} '.format(outs[2])

    # Validation
    if FLAGS.validation:
      cost, acc, micro, macro = evaluate(sess, model, val_features_batches,
                                         val_support_batches, y_val_batches,
                                         val_mask_batches, val_data,
                                         placeholders)
      cost_val.append(cost)
      print_str += 'val_acc= {:.5f} '.format(
          acc) + 'mi F1= {:.5f} ma F1= {:.5f} '.format(micro, macro)

    tf.logging.info(print_str)

    if FLAGS.validation and epoch > FLAGS.early_stopping and cost_val[-1] > np.mean(
        cost_val[-(FLAGS.early_stopping + 1):-1]):
      tf.logging.info('Early stopping...')
      break

  tf.logging.info('Optimization Finished!')

  # Save model
  saver.save(sess, FLAGS.save_name)

  # Load model (using CPU for inference)
  with tf.device('/cpu:0'):
    sess_cpu = tf.Session(config=tf.ConfigProto(device_count={'GPU': 0}))
    sess_cpu.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.restore(sess_cpu, FLAGS.save_name)
    # Testing
    test_cost, test_acc, micro, macro = evaluate(
        sess_cpu, model, test_features_batches, test_support_batches,
        y_test_batches, test_mask_batches, test_data, placeholders)
    print_str = 'Test set results: ' + 'cost= {:.5f} '.format(
        test_cost) + 'accuracy= {:.5f} '.format(
            test_acc) + 'mi F1= {:.5f} ma F1= {:.5f}'.format(micro, macro)
    tf.logging.info(print_str)
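
The evaluate() helper used for validation and testing above is not shown. The sketch below is a hedged reconstruction: it assumes model.outputs holds per-node logits and that labels are one-hot; the attribute names and the aggregation are assumptions, not the original code.

import numpy as np
from sklearn import metrics

def evaluate(sess, model, features_batches, support_batches, y_batches,
             mask_batches, eval_data, placeholders):
  """Average loss/accuracy over batches; micro/macro F1 on masked nodes."""
  total_cost = total_acc = 0.0
  y_true, y_pred = [], []
  for i in range(len(features_batches)):
    feed_dict = utils.construct_feed_dict(features_batches[i],
                                          support_batches[i], y_batches[i],
                                          mask_batches[i], placeholders)
    cost, acc, logits = sess.run([model.loss, model.accuracy, model.outputs],
                                 feed_dict=feed_dict)
    total_cost += cost
    total_acc += acc
    mask = np.asarray(mask_batches[i], dtype=bool)
    y_true.append(np.argmax(y_batches[i][mask], axis=1))
    y_pred.append(np.argmax(logits[mask], axis=1))
  y_true, y_pred = np.concatenate(y_true), np.concatenate(y_pred)
  micro = metrics.f1_score(y_true, y_pred, average='micro')
  macro = metrics.f1_score(y_true, y_pred, average='macro')
  n = len(features_batches)
  return total_cost / n, total_acc / n, micro, macro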
Example #3
def main():
    # Make dir
    temp = "./tmp"
    os.makedirs(temp, exist_ok=True)
    # Load data
    start = time.time()
    (train_adj, full_adj, train_feats, test_feats, y_train, y_val, y_test,
     train_mask, val_mask, test_mask, _, val_nodes, test_nodes,
     num_data, visible_data) = utils.load_data(args.dataset)
    print('Loaded data in {:.2f} seconds!'.format(time.time() - start))

    # Prepare Train Data
    if args.batch_size > 1:
        start = time.time()
        _, parts = utils.partition_graph(train_adj, visible_data, args.num_clusters_train)
        print('Partition graph in {:.2f} seconds!'.format(time.time() - start))
        parts = [np.array(pt) for pt in parts]
    else:
        start = time.time()
        (parts, features_batches, support_batches, y_train_batches, train_mask_batches) = utils.preprocess(
            train_adj, train_feats, y_train, train_mask, visible_data, args.num_clusters_train, diag_lambda=args.diag_lambda)
        print('Train data pre-processed in {:.2f} seconds!'.format(time.time() - start))

    # Prepare Validation Data
    start = time.time()
    (_, val_features_batches, val_support_batches, y_val_batches, val_mask_batches) = utils.preprocess(
        full_adj, test_feats, y_val, val_mask, np.arange(num_data), args.num_clusters_val, diag_lambda=args.diag_lambda)
    print('Validation data pre-processed in {:.2f} seconds!'.format(time.time() - start))

    # Prepare Test Data
    start = time.time()
    (_, test_features_batches, test_support_batches, y_test_batches, test_mask_batches) = utils.preprocess(
        full_adj, test_feats, y_test, test_mask, np.arange(num_data), args.num_clusters_test, diag_lambda=args.diag_lambda)
    print('Test data pre-processed in {:.2f} seconds!'.format(time.time() - start))

    idx_parts = list(range(len(parts)))

    # Model (_in/_out are assumed module-level feature and class dimensions)
    model = GCN(
        fan_in=_in, fan_out=_out, layers=args.layers, dropout=args.dropout, normalize=True, bias=False, precalc=True).float()
    model.to(torch.device('cuda'))
    print(model)

    # Loss Function
    if args.multilabel:
        criterion = torch.nn.BCEWithLogitsLoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    # Optimization Algorithm
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Learning Rate Schedule
    # scheduler = torch.optim.lr_scheduler.OneCycleLR(
    #     optimizer, max_lr=args.lr, steps_per_epoch=int(args.num_clusters_train/args.batch_size), epochs=args.epochs+1,
    #     anneal_strategy='linear')

    pbar = tqdm.tqdm(total=args.epochs + 1, dynamic_ncols=True)
    for epoch in range(args.epochs + 1):
        # Train
        np.random.shuffle(idx_parts)
        start = time.time()
        avg_loss = 0
        total_correct = 0
        n_nodes = 0
        if args.batch_size > 1:
            (features_batches, support_batches, y_train_batches, train_mask_batches) = utils.preprocess_multicluster(
                train_adj, parts, train_feats, y_train, train_mask,
                args.num_clusters_train, args.batch_size, args.diag_lambda)
            for pid in range(len(features_batches)):
                # Use preprocessed batch data
                features_b = features_batches[pid]
                support_b = support_batches[pid]
                y_train_b = y_train_batches[pid]
                train_mask_b = train_mask_batches[pid]
                loss, pred, labels = train(
                    model.train(), criterion, optimizer,
                    features_b, support_b, y_train_b, train_mask_b, torch.device('cuda'))
                avg_loss += loss.item()
                n_nodes += pred.squeeze().numel()
                total_correct += torch.eq(pred.squeeze(), labels.squeeze()).sum().item()
        else:
            np.random.shuffle(idx_parts)
            for pid in idx_parts:
                # use preprocessed batch data
                features_b = features_batches[pid]
                support_b = support_batches[pid]
                y_train_b = y_train_batches[pid]
                train_mask_b = train_mask_batches[pid]
                loss, pred, labels = train(
                    model.train(), criterion, optimizer,
                    features_b, support_b, y_train_b, train_mask_b, torch.device('cuda'))
                avg_loss += loss.item()
                n_nodes += pred.squeeze().numel()
                total_correct += torch.eq(pred.squeeze(), labels.squeeze()).sum().item()
        train_acc = total_correct / n_nodes
        # Write Train stats to tensorboard
        writer.add_scalar('time/train', time.time() - start, epoch)
        writer.add_scalar('loss/train', avg_loss/len(features_batches), epoch)
        writer.add_scalar('acc/train', train_acc, epoch)

        # Validation
        cost, acc, micro, macro = evaluate(
            model.eval(), criterion, val_features_batches, val_support_batches, y_val_batches, val_mask_batches,
            val_nodes, torch.device("cuda"))

        # Write Valid stats to tensorboard
        writer.add_scalar('acc/valid', acc, epoch)
        writer.add_scalar('mi_F1/valid', micro, epoch)
        writer.add_scalar('ma_F1/valid', macro, epoch)
        writer.add_scalar('loss/valid', cost, epoch)
        pbar.set_postfix({"t": avg_loss / len(features_batches), "t_acc": train_acc, "v": cost, "v_acc": acc})
        pbar.update()
    pbar.close()

    # Test
    if args.test == 1:
        # Test on cpu
        cost, acc, micro, macro = test(
            model.eval(), criterion, test_features_batches, test_support_batches, y_test_batches, test_mask_batches,
            torch.device("cpu"))
        writer.add_scalar('acc/test', acc, epoch)
        writer.add_scalar('mi_F1/test', micro, epoch)
        writer.add_scalar('ma_F1/test', macro, epoch)
        writer.add_scalar('loss/test', cost, epoch)
        print('test: acc: {:.4f}'.format(acc))
        print('test: mi_f1: {:.4f}, ma_f1: {:.4f}'.format(micro, macro))
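
utils.partition_graph above wraps a METIS call. The following condensed sketch shows what it plausibly does, assuming the `metis` Python bindings and a scipy.sparse adjacency; the return convention (membership, parts) is inferred from the call sites and is an assumption.

import metis
import numpy as np

def partition_graph(adj, idx_nodes, num_clusters):
    # Restrict the adjacency to the visible nodes, hand METIS an adjacency
    # list, then regroup global node ids by their assigned part.
    sub_adj = adj[idx_nodes, :][:, idx_nodes].tocsr()
    adj_list = [sub_adj.indices[sub_adj.indptr[i]:sub_adj.indptr[i + 1]].tolist()
                for i in range(sub_adj.shape[0])]
    _, membership = metis.part_graph(adj_list, num_clusters)
    parts = [[] for _ in range(num_clusters)]
    for local_id, part_id in enumerate(membership):
        parts[part_id].append(idx_nodes[local_id])
    return membership, parts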
Example #4
def main(unused_argv):
    """Main function for running experiments."""
    # Load data
    (train_adj, full_adj, train_feats, test_feats, y_train, y_val, y_test,
     train_mask, val_mask, test_mask, _, val_data, test_data, num_data,
     visible_data) = load_data(FLAGS.data_prefix, FLAGS.dataset, FLAGS.precalc,
                               FLAGS.dataset, FLAGS.graph_dir)

    # Partition graph and do preprocessing
    if FLAGS.bsize > 1:
        parts_file = FLAGS.dataset + "-parts-txt"
        parts_pickle_file = FLAGS.dataset + "-parts-pickle" if (
            FLAGS.dataset != 'None') else FLAGS.custom_data + "-parts-pickle"
        print("parts_pickle_file", parts_pickle_file)
        # Loading cached partitions from parts_pickle_file is disabled here;
        # the graph is always re-partitioned. Re-enable by loading/dumping
        # `parts` with pickle as sketched in the commented lines below.
        _, parts = partition_utils.partition_graph(train_adj, visible_data,
                                                   FLAGS.num_clusters)
        # f = open(parts_pickle_file, 'wb')
        # pickle.dump(parts, f)
        # f.close()

        if not os.path.exists(parts_file):
            with open(parts_file, 'w') as f:
                s = "{"
                part_i = 0
                for part in parts:
                    s += '"%d"' % part_i + ':' + str(part) + "\n"
                    if part_i != len(parts) - 1:
                        s += ','
                    part_i += 1
                f.write(s + "}")

        parts = [np.array(pt) for pt in parts]

    else:
        (parts, features_batches, support_batches, y_train_batches,
         train_mask_batches) = utils.preprocess(train_adj, train_feats,
                                                y_train, train_mask,
                                                visible_data,
                                                FLAGS.num_clusters,
                                                FLAGS.diag_lambda)

    # (_, val_features_batches, val_support_batches, y_val_batches,
    #  val_mask_batches) = utils.preprocess(full_adj, test_feats, y_val, val_mask,
    #                                       np.arange(num_data),
    #                                       FLAGS.num_clusters_val,
    #                                       FLAGS.diag_lambda)

    # (_, test_features_batches, test_support_batches, y_test_batches,
    #  test_mask_batches) = utils.preprocess(full_adj, test_feats, y_test,
    #                                        test_mask, np.arange(num_data),
    #                                        FLAGS.num_clusters_test,
    #                                        FLAGS.diag_lambda)
    idx_parts = list(range(len(parts)))

    # Select the model class
    model_func = models.GCN

    # Define placeholders
    placeholders = {
        'support': tf.sparse_placeholder(tf.float32),
        'features': tf.placeholder(tf.float32),
        'labels': tf.placeholder(tf.float32, shape=(None, y_train.shape[1])),
        'labels_mask': tf.placeholder(tf.int32),
        'dropout': tf.placeholder_with_default(0., shape=()),
        'num_features_nonzero':
        tf.placeholder(tf.int32)  # helper variable for sparse dropout
    }

    # Create model (set the graph-level seed before building the graph so it
    # actually affects op seeds; `seed` is assumed to be defined at module
    # level)
    tf.set_random_seed(seed)
    model = model_func(placeholders,
                       input_dim=test_feats.shape[1],
                       logging=True,
                       multilabel=FLAGS.multilabel,
                       norm=FLAGS.layernorm,
                       precalc=FLAGS.precalc,
                       num_layers=FLAGS.num_layers)

    # Initialize session
    sess = tf.Session()

    # Init variables
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    cost_val = []
    total_training_time = 0.0
    sampling_time = 0
    training_time = 0
    extraction_time = 0
    # Train model
    for epoch in range(FLAGS.epochs):
        t = time.time()
        np.random.shuffle(idx_parts)
        if FLAGS.bsize > 1:
            t0 = time.time()
            (features_batches, support_batches, y_train_batches,
             train_mask_batches, t_extract) = utils.preprocess_multicluster(
                 train_adj, parts, train_feats, y_train, train_mask,
                 FLAGS.num_clusters, FLAGS.bsize, FLAGS.diag_lambda)
            t1 = time.time()
            extraction_time += t_extract  # don't clobber the epoch timer `t`
            sampling_time += t1 - t0
            for pid in range(len(features_batches)):
                # Use preprocessed batch data
                features_b = features_batches[pid]
                support_b = support_batches[pid]
                y_train_b = y_train_batches[pid]
                train_mask_b = train_mask_batches[pid]
                # Construct feed dictionary
                feed_dict = utils.construct_feed_dict(features_b, support_b,
                                                      y_train_b, train_mask_b,
                                                      placeholders)
                feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                t0 = time.time()
                # Training step
                outs = sess.run([model.opt_op, model.loss, model.accuracy],
                                feed_dict=feed_dict)
                t1 = time.time()
                training_time += t1 - t0
        else:
            np.random.shuffle(idx_parts)
            for pid in idx_parts:
                # Use preprocessed batch data
                features_b = features_batches[pid]
                support_b = support_batches[pid]
                y_train_b = y_train_batches[pid]
                train_mask_b = train_mask_batches[pid]
                # Construct feed dictionary
                feed_dict = utils.construct_feed_dict(features_b, support_b,
                                                      y_train_b, train_mask_b,
                                                      placeholders)
                feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                # Training step
                outs = sess.run([model.opt_op, model.loss, model.accuracy],
                                feed_dict=feed_dict)

        total_training_time += time.time() - t

        print_str = 'Epoch: %04d ' % (
            epoch + 1) + 'training time: {:.5f} '.format(
                total_training_time) + 'train_acc= {:.5f} '.format(outs[2])
        print(print_str)
    print("sampling_time (clustergcn)", sampling_time)
    print("training_time:", training_time)

    #print(sampling_time,"Total sampling time")
    #print(training_time,"Total training time")
    #print(extraction_time,"Total extraction time")
    #   return
    #   # Validation
    #   if FLAGS.validation:
    #     cost, acc, micro, macro = evaluate(sess, model, val_features_batches,
    #                                        val_support_batches, y_val_batches,
    #                                        val_mask_batches, val_data,
    #                                        placeholders)
    #     cost_val.append(cost)
    #     print_str += 'val_acc= {:.5f} '.format(
    #         acc) + 'mi F1= {:.5f} ma F1= {:.5f} '.format(micro, macro)

    #   tf.logging.info(print_str)

    #   if epoch > FLAGS.early_stopping and cost_val[-1] > np.mean(
    #       cost_val[-(FLAGS.early_stopping + 1):-1]):
    #     tf.logging.info('Early stopping...')
    #     break

    # tf.logging.info('Optimization Finished!')
    return
    # Save model
    saver.save(sess, FLAGS.save_name)

    # Load model (using CPU for inference)
    with tf.device('/cpu:0'):
        sess_cpu = tf.Session(config=tf.ConfigProto(device_count={'GPU': 0}))
        sess_cpu.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess_cpu, FLAGS.save_name)
        # Testing
        test_cost, test_acc, micro, macro = evaluate(
            sess_cpu, model, test_features_batches, test_support_batches,
            y_test_batches, test_mask_batches, test_data, placeholders)
        print_str = 'Test set results: ' + 'cost= {:.5f} '.format(
            test_cost) + 'accuracy= {:.5f} '.format(
                test_acc) + 'mi F1= {:.5f} ma F1= {:.5f}'.format(micro, macro)
        tf.logging.info(print_str)
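
Every TF variant above funnels batches through utils.construct_feed_dict. Below is a minimal sketch under the placeholder scheme defined earlier; the num_features_nonzero convention (shape of the values array in a sparse (coords, values, shape) feature tuple) is an assumption carried over from sparse-input GCN code.

def construct_feed_dict(features, support, labels, labels_mask, placeholders):
    # Map one preprocessed batch onto the TF1 placeholders (hypothetical sketch).
    return {
        placeholders['features']: features,
        placeholders['support']: support,
        placeholders['labels']: labels,
        placeholders['labels_mask']: labels_mask,
        # Only meaningful for sparse feature tuples; harmless otherwise.
        placeholders['num_features_nonzero']: features[1].shape,
    }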
Example #5
def main(unused_argv):
    """Main function for running experiments."""
    # Load data
    utils.tab_printer(FLAGS.flag_values_dict())
    (full_adj, feats, y_train, y_val, y_test, train_mask, val_mask, test_mask,
     train_data, val_data, test_data,
     num_data) = utils.load_ne_data_transductive_sparse(
         FLAGS.data_prefix, FLAGS.dataset, FLAGS.precalc,
         list(map(float, FLAGS.split)))

    # Partition graph and do preprocessing
    if FLAGS.bsize > 1:  # multi cluster per epoch
        _, parts = partition_utils.partition_graph(full_adj,
                                                   np.arange(num_data),
                                                   FLAGS.num_clusters)

        parts = [np.array(pt) for pt in parts]
    else:
        (parts, features_batches, support_batches, y_train_batches,
         train_mask_batches) = utils.preprocess(full_adj,
                                                feats,
                                                y_train,
                                                train_mask,
                                                np.arange(num_data),
                                                FLAGS.num_clusters,
                                                FLAGS.diag_lambda,
                                                sparse_input=True)
    # Validation & test data are preprocessed together in a single pass
    (_, val_features_batches, test_features_batches, val_support_batches,
     y_val_batches, y_test_batches,
     val_mask_batches, test_mask_batches) = utils.preprocess_val_test(
         full_adj, feats, y_val, val_mask, y_test, test_mask,
         np.arange(num_data), FLAGS.num_clusters_val, FLAGS.diag_lambda)

    # (_, val_features_batches, val_support_batches, y_val_batches,
    #  val_mask_batches) = utils.preprocess(full_adj, feats, y_val, val_mask,
    #                                       np.arange(num_data),
    #                                       FLAGS.num_clusters_val,
    #                                       FLAGS.diag_lambda)
    # # test set
    # (_, test_features_batches, test_support_batches, y_test_batches,
    #  test_mask_batches) = utils.preprocess(full_adj, feats, y_test,
    #                                        test_mask, np.arange(num_data),
    #                                        FLAGS.num_clusters_test,
    #                                        FLAGS.diag_lambda)
    idx_parts = list(range(len(parts)))

    # Define placeholders
    placeholders = {
        'support': tf.sparse_placeholder(tf.float32),
        # 'features':
        #     tf.placeholder(tf.float32),
        'features': tf.sparse_placeholder(tf.float32),
        'labels': tf.placeholder(tf.float32, shape=(None, y_train.shape[1])),
        'labels_mask': tf.placeholder(tf.int32),
        'dropout': tf.placeholder_with_default(0., shape=()),
        'fm_dropout': tf.placeholder_with_default(0., shape=()),
        'gat_dropout': tf.placeholder_with_default(0.,
                                                   shape=()),  # gat attn drop
        'num_features_nonzero':
        tf.placeholder(tf.int32)  # helper variable for sparse dropout
    }

    # Create model
    if FLAGS.model == 'gcn':
        model = models.GCN(placeholders,
                           input_dim=feats.shape[1],
                           logging=True,
                           multilabel=FLAGS.multilabel,
                           norm=FLAGS.layernorm,
                           precalc=FLAGS.precalc,
                           num_layers=FLAGS.num_layers,
                           residual=False,
                           sparse_inputs=True)
    elif FLAGS.model == 'gcn_nfm':
        model = models.GCN_NFM(placeholders,
                               input_dim=feats.shape[1],
                               logging=True,
                               multilabel=FLAGS.multilabel,
                               norm=FLAGS.layernorm,
                               precalc=FLAGS.precalc,
                               num_layers=FLAGS.num_layers,
                               residual=False,
                               sparse_inputs=True)
    elif FLAGS.model == 'gat_nfm':
        gat_layers = list(map(int, FLAGS.gat_layers))
        model = models.GAT_NFM(placeholders,
                               input_dim=feats.shape[1],
                               logging=True,
                               multilabel=FLAGS.multilabel,
                               norm=FLAGS.layernorm,
                               precalc=FLAGS.precalc,
                               num_layers=FLAGS.num_layers,
                               residual=False,
                               sparse_inputs=True,
                               gat_layers=gat_layers)
    else:
        raise ValueError(str(FLAGS.model))

    # Initialize session
    sess = tf.Session()

    # Init variables
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    cost_val = []
    acc_val = []
    total_training_time = 0.0
    # Train model
    for epoch in range(FLAGS.epochs):
        t = time.time()
        np.random.shuffle(idx_parts)
        if FLAGS.bsize > 1:
            (features_batches, support_batches, y_train_batches,
             train_mask_batches) = utils.preprocess_multicluster(
                 full_adj, parts, feats, y_train, train_mask,
                 FLAGS.num_clusters, FLAGS.bsize, FLAGS.diag_lambda, True)
            for pid in range(len(features_batches)):
                # Use preprocessed batch data
                features_b = features_batches[pid]
                support_b = support_batches[pid]
                y_train_b = y_train_batches[pid]
                train_mask_b = train_mask_batches[pid]
                # Construct feed dictionary
                feed_dict = utils.construct_feed_dict(features_b, support_b,
                                                      y_train_b, train_mask_b,
                                                      placeholders)
                feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                feed_dict.update(
                    {placeholders['fm_dropout']: FLAGS.fm_dropout})
                feed_dict.update(
                    {placeholders['gat_dropout']: FLAGS.gat_dropout})
                # Training step
                outs = sess.run([model.opt_op, model.loss, model.accuracy],
                                feed_dict=feed_dict)
        else:
            np.random.shuffle(idx_parts)
            for pid in idx_parts:
                # Use preprocessed batch data
                features_b = features_batches[pid]
                support_b = support_batches[pid]
                y_train_b = y_train_batches[pid]
                train_mask_b = train_mask_batches[pid]
                # Construct feed dictionary
                feed_dict = utils.construct_feed_dict(features_b, support_b,
                                                      y_train_b, train_mask_b,
                                                      placeholders)
                feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                feed_dict.update(
                    {placeholders['fm_dropout']: FLAGS.fm_dropout})
                feed_dict.update(
                    {placeholders['gat_dropout']: FLAGS.gat_dropout})
                # Training step
                outs = sess.run([model.opt_op, model.loss, model.accuracy],
                                feed_dict=feed_dict)

        total_training_time += time.time() - t
        print_str = 'Epoch: %04d ' % (
            epoch + 1) + 'training time: {:.5f} '.format(
                total_training_time) + 'train_acc= {:.5f} '.format(outs[2])

        # Validation
        ## TODO: merge validation into the training procedure
        if FLAGS.validation:
            cost, acc, micro, macro = evaluate(sess, model,
                                               val_features_batches,
                                               val_support_batches,
                                               y_val_batches, val_mask_batches,
                                               val_data, placeholders)
            cost_val.append(cost)
            acc_val.append(acc)
            print_str += 'val_acc= {:.5f} '.format(
                acc) + 'mi F1= {:.5f} ma F1= {:.5f} '.format(micro, macro)

        # tf.logging.info(print_str)
        print(print_str)

        if FLAGS.validation and epoch > FLAGS.early_stopping and cost_val[-1] > np.mean(
                cost_val[-(FLAGS.early_stopping + 1):-1]):
            tf.logging.info('Early stopping...')
            break

        ### Accuracy-based early stopping (performed worse than loss-based):
        # if epoch > FLAGS.early_stopping and acc_val[-1] < np.mean(
        #     acc_val[-(FLAGS.early_stopping + 1):-1]):
        #   tf.logging.info('Early stopping...')
        #   break

    tf.logging.info('Optimization Finished!')

    # Save model
    saver.save(sess, FLAGS.save_name)

    # Load model (using CPU for inference)
    with tf.device('/cpu:0'):
        sess_cpu = tf.Session(config=tf.ConfigProto(device_count={'GPU': 0}))
        sess_cpu.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess_cpu, FLAGS.save_name)
        # Testing (val and test share the same support batches, as returned
        # by preprocess_val_test over the full graph)
        test_cost, test_acc, micro, macro = evaluate(
            sess_cpu, model, test_features_batches, val_support_batches,
            y_test_batches, test_mask_batches, test_data, placeholders)
        print_str = 'Test set results: ' + 'cost= {:.5f} '.format(
            test_cost) + 'accuracy= {:.5f} '.format(
                test_acc) + 'mi F1= {:.5f} ma F1= {:.5f}'.format(micro, macro)
        tf.logging.info(print_str)
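
The batching step shared by all five variants is preprocess_multicluster. Below is a hedged sketch of its core loop, assuming scipy.sparse inputs and an assumed normalize_adj_diag_enhance helper for the diagonal-enhanced normalization (A_norm + diag_lambda * diag(A_norm)) used by Cluster-GCN; Example #4's variant additionally returns an extraction-time measurement, and Example #1's takes no mask.

import numpy as np

def preprocess_multicluster(adj, parts, features, labels, mask,
                            num_clusters, block_size, diag_lambda):
    # Randomly group `block_size` METIS clusters per batch, slice out the
    # induced subgraph, and normalize its adjacency (hypothetical sketch).
    features_batches, support_batches, y_batches, mask_batches = [], [], [], []
    np.random.shuffle(parts)
    for start in range(0, num_clusters, block_size):
        nodes = np.concatenate(parts[start:start + block_size])
        sub_adj = adj[nodes, :][:, nodes]
        sub_adj = normalize_adj_diag_enhance(sub_adj, diag_lambda)  # assumed helper
        features_batches.append(features[nodes])
        y_batches.append(labels[nodes])
        mask_batches.append(mask[nodes])
        support_batches.append(sub_adj)
    return features_batches, support_batches, y_batches, mask_batches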