# Example #1
0
def main(argv):
    """Train a sampled-subgraph GCN on synthetic data and report test metrics.

    Generates ``FLAGS.n_nodes`` points in ``FLAGS.n_clusters`` clusters with
    ``line_gaussians``, builds a kNN graph from the clean points, and trains
    ``multilayer_gcn`` for ``FLAGS.n_epochs`` rounds. Each round fits on one
    subgraph batch drawn by ``random_batch`` from the graph and the dirty
    features. Afterwards the held-out (test) nodes are passed through
    ``make_batch`` and the model. The NMI and accuracy of the predicted
    clusters against the true labels are then printed.

    Args:
      argv: Positional command-line arguments. Only the program name
        (``argv[0]``) is accepted.

    Raises:
      app.UsageError: If any extra positional argument is given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    print('Bröther may i have some self-lööps')
    n_nodes = FLAGS.n_nodes
    n_clusters = FLAGS.n_clusters
    train_size = FLAGS.train_size
    batch_size = FLAGS.batch_size
    data_clean, data_dirty, labels = line_gaussians(n_nodes, n_clusters)
    graph_clean = construct_knn_graph(data_clean)
    n_neighbors = [15, 10]  # TODO(tsitsulin): move to FLAGS.
    # Rows per sampled subgraph: 1 + 15 + 15 * 10 = 166 for the counts above.
    # Presumably one seed node plus its sampled 1- and 2-hop neighbors --
    # confirm against random_batch / make_batch.
    total_matrix_size = 1 + np.cumprod(n_neighbors).sum()

    # `np.bool` was deprecated in NumPy 1.20 and removed in 1.24. The builtin
    # `bool` is the supported spelling and yields the same dtype.
    train_mask = np.zeros(n_nodes, dtype=bool)
    # Choose int(n_nodes * train_size) distinct nodes for training.
    train_mask[np.random.choice(n_nodes,
                                int(n_nodes * train_size),
                                replace=False)] = True
    test_mask = ~train_mask
    print(f'Data shape: {data_clean.shape}, graph shape: {graph_clean.shape}')
    print(f'Train size: {train_mask.sum()}, test size: {test_mask.sum()}')

    # Per-sample inputs: a (total_matrix_size, 2) feature matrix and a
    # (total_matrix_size, total_matrix_size) subgraph adjacency.
    input_features = tf.keras.layers.Input(shape=(
        total_matrix_size,
        2,
    ))
    input_graph = tf.keras.layers.Input((
        total_matrix_size,
        total_matrix_size,
    ))

    output = multilayer_gcn([input_features, input_graph],
                            [64, 32, n_clusters])
    # Only row 0 of each subgraph's output is used as that sample's logits.
    # Presumably row 0 is the seed node -- confirm against random_batch.
    model = tf.keras.Model(inputs=[input_features, input_graph],
                           outputs=output[:, 0, :])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(FLAGS.learning_rate),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'])
    # NOTE(review): train_mask is not passed to random_batch. Training
    # batches may therefore contain test nodes (and their labels). Confirm
    # whether sampling should be restricted to the training nodes.
    for _epoch in range(FLAGS.n_epochs):
        subgraph_mat, features_mat, node_ids, _ = random_batch(
            graph_clean, data_dirty, batch_size, n_neighbors)
        model.fit([features_mat, subgraph_mat],
                  labels[node_ids],
                  batch_size=batch_size,
                  shuffle=False)

    test_ids = np.flatnonzero(test_mask)
    subgraph_mat, features_mat, _ = make_batch(graph_clean, data_dirty,
                                               test_ids, n_neighbors)
    # Predicted cluster = argmax over the n_clusters logits. The comparison
    # with labels[test_mask] below assumes make_batch keeps test_ids order.
    clusters = model([features_mat, subgraph_mat]).numpy().argmax(axis=1)
    print(
        'NMI:',
        normalized_mutual_info_score(labels[test_mask],
                                     clusters,
                                     average_method='arithmetic'))
    print('Accuracy:', accuracy_score(labels[test_mask], clusters))
def main(argv):
    """Train a full-batch GCN on synthetic data and report test metrics.

    Generates ``FLAGS.n_nodes`` points in ``FLAGS.n_clusters`` clusters with
    ``line_gaussians`` and builds a dense kNN adjacency from the clean points.
    It then trains ``multilayer_gcn`` on all nodes at once, with test nodes
    given zero sample weight. Finally it prints the NMI and accuracy of the
    predicted clusters on the test nodes.

    Args:
      argv: Positional command-line arguments. Only the program name
        (``argv[0]``) is accepted.

    Raises:
      app.UsageError: If any extra positional argument is given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    print('Bröther may i have some self-lööps')
    n_nodes = FLAGS.n_nodes
    n_clusters = FLAGS.n_clusters
    train_size = FLAGS.train_size
    data_clean, data_dirty, labels = line_gaussians(n_nodes, n_clusters)
    # Dense (n_nodes, n_nodes) adjacency. `.toarray()` yields a plain ndarray
    # directly instead of detouring through np.matrix via `.todense().A1`.
    # Assumes construct_knn_graph returns a scipy sparse matrix, as the
    # original `.todense()` call implies.
    graph_clean = construct_knn_graph(data_clean).toarray().reshape(
        n_nodes, n_nodes)

    # `np.bool` was deprecated in NumPy 1.20 and removed in 1.24. The builtin
    # `bool` is the supported spelling and yields the same dtype.
    train_mask = np.zeros(n_nodes, dtype=bool)
    # Choose int(n_nodes * train_size) distinct nodes for training.
    train_mask[np.random.choice(n_nodes,
                                int(n_nodes * train_size),
                                replace=False)] = True
    test_mask = ~train_mask
    print(f'Data shape: {data_clean.shape}, graph shape: {graph_clean.shape}')
    print(f'Train size: {train_mask.sum()}, test size: {test_mask.sum()}')

    # One row per node: 2 feature values, plus that node's adjacency row.
    input_features = tf.keras.layers.Input(shape=(2, ))
    input_graph = tf.keras.layers.Input((n_nodes, ))

    output = multilayer_gcn([input_features, input_graph],
                            [64, 32, n_clusters])
    model = tf.keras.Model(inputs=[input_features, input_graph],
                           outputs=output)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(FLAGS.learning_rate),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'])
    # Full-batch training: batch_size=n_nodes puts every row in one batch,
    # and shuffle=False preserves row order. This is presumably needed so the
    # (n_nodes, n_nodes) adjacency lines up with the feature rows. Explicit
    # 1.0/0.0 weights keep test nodes out of the loss.
    train_weights = train_mask.astype(np.float32)
    for _epoch in range(FLAGS.n_epochs):
        model.fit([data_dirty, graph_clean],
                  labels,
                  batch_size=n_nodes,
                  shuffle=False,
                  sample_weight=train_weights)
    # Predict on all nodes (the graph needs every row), then keep test rows.
    clusters = model([data_dirty,
                      graph_clean]).numpy().argmax(axis=1)[test_mask]
    print(
        'NMI:',
        normalized_mutual_info_score(labels[test_mask],
                                     clusters,
                                     average_method='arithmetic'))
    print('Accuracy:', accuracy_score(labels[test_mask], clusters))