# Assumed imports: the standard-library and third-party ones are certain; the
# graph helpers (load_npz_to_sparse_graph, load_kipf_data, normalize_graph,
# scipy_to_tf, gcn_modularity, GCN, deep_graph_infomax, conductance,
# modularity, precision, recall, format_filename) are assumed to come from
# this repository's own modules.
import os

from absl import app
from absl import flags
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import normalized_mutual_info_score
import tensorflow as tf

FLAGS = flags.FLAGS


def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    print('Starting', format_filename())
    if FLAGS.load_strategy == 'schur':
        adjacency, features, labels, label_mask = load_npz_to_sparse_graph(
            FLAGS.graph_path)
    elif FLAGS.load_strategy == 'kipf':
        adjacency, features, labels, label_mask = load_kipf_data(
            *os.path.split(FLAGS.graph_path))
    else:
        raise ValueError(f'Unknown loading strategy {FLAGS.load_strategy!r}.')
    n_nodes = adjacency.shape[0]
    feature_size = features.shape[1]
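    # Parse an architecture flag such as '[64_32]' into per-layer widths; the
    # final layer width equals the number of clusters.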
    architecture = [int(x) for x in FLAGS.architecture.strip('[]').split('_')]
    architecture.append(FLAGS.n_clusters)
    graph_clean = scipy_to_tf(adjacency)
    graph_clean_normalized = scipy_to_tf(
        normalize_graph(adjacency.copy(), normalized=True))

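    # Three inputs: node features, the normalized adjacency used for the graph
    # convolutions, and the raw adjacency (presumably consumed by the
    # modularity-based clustering losses).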
    input_features = tf.keras.layers.Input(shape=(feature_size,))
    input_graph = tf.keras.layers.Input((n_nodes,), sparse=True)
    input_adjacency = tf.keras.layers.Input((n_nodes,), sparse=True)

    model = gcn_modularity(
        [input_features, input_graph, input_adjacency],
        architecture,
        dropout_rate=FLAGS.dropout_rate,
        orthogonality_regularization=FLAGS.orthogonality_regularization,
        cluster_size_regularization=FLAGS.cluster_size_regularization)

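    # The clustering objective and the orthogonality / cluster-size
    # regularizers are attached to the model via add_loss, so the training
    # loss is simply the sum of model.losses.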
    def grad(model, inputs):
        with tf.GradientTape() as tape:
            _ = model(inputs, training=True)
            loss_value = sum(model.losses)
        return model.losses, tape.gradient(loss_value,
                                           model.trainable_variables)

    optimizer = tf.keras.optimizers.Adam(FLAGS.learning_rate)
    model.compile(optimizer, None)

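    # Custom training loop: full-batch updates on the entire graph each epoch.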
    for epoch in range(FLAGS.n_epochs):
        loss_values, grads = grad(
            model, [features, graph_clean_normalized, graph_clean])
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        print(f'epoch {epoch}, losses: ' + ' '.join(
            [f'{loss_value.numpy():.4f}' for loss_value in loss_values]))

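    # Inference: take the soft cluster assignments and harden them per node
    # with argmax.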
    _, assignments = model([features, graph_clean_normalized, graph_clean],
                           training=False)
    assignments = assignments.numpy()
    clusters = assignments.argmax(axis=1)
    print('Conductance:', conductance(adjacency, clusters))
    print('Modularity:', modularity(adjacency, clusters))
    print(
        'NMI:',
        normalized_mutual_info_score(labels,
                                     clusters[label_mask],
                                     average_method='arithmetic'))
    print('Precision:', precision(labels, clusters[label_mask]))
    print('Recall:', recall(labels, clusters[label_mask]))
    with open(format_filename(), 'w') as out_file:
        print('Conductance:', conductance(adjacency, clusters), file=out_file)
        print('Modularity:', modularity(adjacency, clusters), file=out_file)
        print('NMI:',
              normalized_mutual_info_score(labels,
                                           clusters[label_mask],
                                           average_method='arithmetic'),
              file=out_file)
        print('Precision:',
              precision(labels, clusters[label_mask]),
              file=out_file)
        print('Recall:', recall(labels, clusters[label_mask]), file=out_file)
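

# -----------------------------------------------------------------------------
# Variant: a Deep Graph Infomax (DGI) baseline, presumably from a separate
# training script. If both functions live in one module, this second `main`
# shadows the first; rename one of them in that case.
# -----------------------------------------------------------------------------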
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  print('Starting', format_filename())
  if FLAGS.load_strategy == 'schur':
    adjacency, features, labels, label_mask = load_npz_to_sparse_graph(
        FLAGS.graph_path)
  elif FLAGS.load_strategy == 'kipf':
    adjacency, features, labels, label_mask = load_kipf_data(
        *os.path.split(FLAGS.graph_path))
  else:
    raise ValueError(f'Unknown loading strategy {FLAGS.load_strategy!r}.')
  n_nodes = adjacency.shape[0]
  feature_size = features.shape[1]
  architecture = [int(x) for x in FLAGS.architecture.strip('[]').split('_')]
  graph_clean_normalized = scipy_to_tf(
      normalize_graph(adjacency.copy(), normalized=True))

  input_features = tf.keras.layers.Input(shape=(feature_size,))
  input_features_corrupted = tf.keras.layers.Input(shape=(feature_size,))
  input_graph = tf.keras.layers.Input((n_nodes,), sparse=True)

  # Stack one 512-unit GCN layer per entry of --architecture (the flag sets
  # the encoder depth only; its per-layer widths are not used here).
  encoder = [GCN(512) for _ in architecture]

  model = deep_graph_infomax(
      [input_features, input_features_corrupted, input_graph], encoder)

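  # DGI trains a discriminator to tell true (node, summary) pairs from
  # corrupted ones, using binary cross-entropy on the logits below.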
  def loss(model, x, y, training):
    # The discriminator logits are the model's second output. `loss_object`
    # is a closure reference resolved at call time, after its definition below.
    _, y_ = model(x, training=training)
    return loss_object(y_true=y, y_pred=y_)

  def grad(model, inputs, targets):
    with tf.GradientTape() as tape:
      loss_value = loss(model, inputs, targets, training=True)
      for loss_internal in model.losses:
        loss_value += loss_internal
    return loss_value, tape.gradient(loss_value, model.trainable_variables)

  loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
  optimizer = tf.keras.optimizers.Adam(FLAGS.learning_rate)

  # Early stopping: stop once the loss has not improved for `patience` epochs.
  patience = 20
  best_loss = float('inf')
  patience_counter = 0

  # Binary pseudolabels for the DGI discriminator, one per node for each of
  # the two views; they are constant, so they are built once before the loop.
  pseudolabels = tf.concat([tf.zeros([n_nodes, 1]), tf.ones([n_nodes, 1])], 0)
  for epoch in range(FLAGS.n_epochs):
    # Corrupt the features by shuffling rows, breaking the node-feature
    # correspondence while keeping the marginal feature distribution.
    features_corr = features.copy()
    np.random.shuffle(features_corr)
    loss_value, grads = grad(model,
                             [features, features_corr, graph_clean_normalized],
                             pseudolabels)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    loss_value = loss_value.numpy()
    print(epoch, loss_value)
    if loss_value > best_loss:
      patience_counter += 1
      if patience_counter == patience:
        break
    else:
      best_loss = loss_value
      patience_counter = 0
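  # DGI yields node embeddings, not clusters; cluster the learned
  # representations with k-means.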
  representations = model([features, features, graph_clean_normalized],
                          training=False)[0].numpy()
  clf = KMeans(n_clusters=FLAGS.n_clusters)
  clf.fit(representations)
  clusters = clf.labels_
  print('Conductance:', conductance(adjacency, clusters))
  print('Modularity:', modularity(adjacency, clusters))
  print(
      'NMI:',
      normalized_mutual_info_score(
          labels, clusters[label_mask], average_method='arithmetic'))
  print('Precision:', precision(labels, clusters[label_mask]))
  print('Recall:', recall(labels, clusters[label_mask]))
  with open(format_filename(), 'w') as out_file:
    print('Conductance:', conductance(adjacency, clusters), file=out_file)
    print('Modularity:', modularity(adjacency, clusters), file=out_file)
    print(
        'NMI:',
        normalized_mutual_info_score(
            labels, clusters[label_mask], average_method='arithmetic'),
        file=out_file)
    print('Precision:', precision(labels, clusters[label_mask]), file=out_file)
    print('Recall:', recall(labels, clusters[label_mask]), file=out_file)
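

# Assumed absl entry point (matches the app.UsageError usage above).
if __name__ == '__main__':
  app.run(main)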