Example #1
0
def preprocess_single(
    data: SemiSupervisedSingle,
    *,
    num_eigs: int = 6,
    features_transform: tp.Union[TensorTransform,
                                 tp.Iterable[TensorTransform]] = (),
    eigenvalue_transform: tp.Union[TensorTransform,
                                   tp.Iterable[TensorTransform]] = (),
    largest_component_only: bool = False,
    normalized: bool = True,
    use_ritz_vectors: bool = False,
) -> DataSplit:
    """Build a spectral-embedding `DataSplit` for one semi-supervised graph.

    A partial eigen-decomposition of the graph Laplacian of `data.adjacency`
    is computed. Every split shares the same inputs
    `(node_features, eigenvalues, eigenvectors)` and the full `data.labels`.
    Splits differ only in the weights produced by `preprocess_weights`.

    Args:
        data: graph, node features, labels and train/validation/test ids.
        num_eigs: number of eigenpairs (or Ritz pairs) to compute.
        features_transform: transform(s) applied to `data.node_features`.
        eigenvalue_transform: transform(s) applied to the eigenvalues.
        largest_component_only: if True, first restrict `data` to its largest
            component (computed with `directed=False`).
        normalized: use `normalized_laplacian` rather than `laplacian`.
        use_ritz_vectors: use `ritz_embedding` (started from a ones vector)
            instead of `eigsh_lap`.

    Returns:
        `DataSplit` whose entries are `(example,)` tuples with
        `example = (inputs, labels, weights)`. A split whose ids are `None`
        is `None`.
    """
    if largest_component_only:
        data = get_largest_component(data, directed=False)
    num_nodes = data.adjacency.shape[0]
    L = (normalized_laplacian if normalized else laplacian)(data.adjacency)
    if use_ritz_vectors:
        w, v = ritz_embedding(L, tf.ones((num_nodes, )), num_eigs)
    else:
        w, v = eigsh_lap(L, k=num_eigs)  # pylint: disable=unpacking-non-sequence
    w = transformed(w, eigenvalue_transform)

    features = transformed(data.node_features, features_transform)
    # Inputs are identical for every split; only the loss weights differ.
    inputs = features, w, v

    def get_split(ids: tp.Optional[tf.Tensor]):
        # A missing split (e.g. no validation ids) maps to None, matching the
        # other preprocessing helpers instead of failing in preprocess_weights.
        if ids is None:
            return None
        weights = preprocess_weights(ids, num_nodes, normalize=True)
        example = (inputs, data.labels, weights)
        return (example, )

    return DataSplit(
        get_split(data.train_ids),
        get_split(data.validation_ids),
        get_split(data.test_ids),
    )
Example #2
0
def get_spectral_split(
    data: AutoencoderData,
    spectral_size: int,
    which: str = "LM",
    adjacency_transform: tp.Optional[tp.Callable] = None,
    features_transform: tp.Optional[tp.Callable] = None,
) -> DataSplit:
    """Build a `DataSplit` whose inputs are spectral node embeddings.

    `eigsh` computes `spectral_size` eigenvectors of the (optionally
    transformed) adjacency. Eigenvalues are discarded. If node features are
    present, they are densified (when sparse) and appended as extra columns.

    Args:
        data: autoencoder data with adjacency, features, labels and weights.
        spectral_size: number of eigenvectors to compute.
        which: eigenvalue selection passed through to `eigsh`.
        adjacency_transform: optional transform applied to `data.adjacency`.
        features_transform: optional transform applied to `data.features`.

    Returns:
        `DataSplit` of `(example,)` tuples, where `example = (v, labels, weights)`.
        Training uses `data.train_labels`; validation and test use
        `data.true_labels`. A split whose weights are `None` is `None`.
    """
    adjacency = transformed(data.adjacency, adjacency_transform)
    # Only the eigenvectors are used as embeddings.
    _, v = eigsh(adjacency, spectral_size, which=which)
    features = transformed(data.features, features_transform)
    if features is not None:
        if isinstance(features, tf.SparseTensor):
            features = tf.sparse.to_dense(features)
        v = tf.concat((v, features), 1)

    def get_examples(labels, weights) -> tp.Optional[tp.Iterable]:
        # Missing splits map to None (consistent with get_spectral_split_v3).
        if weights is None:
            return None
        example = v, labels, weights
        return (example, )

    return DataSplit(
        get_examples(data.train_labels, data.train_weights),
        get_examples(data.true_labels, data.validation_weights),
        get_examples(data.true_labels, data.test_weights),
    )
Example #3
0
def get_prefactorized_data(
    data: SemiSupervisedSingle,
    adjacency_transform: tp.Union[SparseTensorTransform,
                                  tp.Sequence[SparseTensorTransform]] = (),
    features_transform: tp.Union[TensorTransform,
                                 tp.Sequence[TensorTransform]] = (),
    factoriser: tp.Callable = get_eigen_factorization,
    normalize: bool = False,
    largest_component_only: bool = False,
) -> DataSplit:
    """Build a `DataSplit` whose inputs are node features and a factor `V`.

    `factoriser` maps the transformed adjacency `A` to a factor `V`. This
    code assumes `V` has shape `[n, k]` and that `V @ V^T` stands in for `A`
    (TODO confirm against the factoriser).

    When `normalize` is True, `d = V (V^T 1)` (the row sums of `V V^T`) is
    computed without materializing `V V^T`. Each row `i` of `V` is then scaled
    by `|d_i|^{-1/2}`, so the implied matrix becomes
    `D^{-1/2} V V^T D^{-1/2}`.

    Args:
        data: graph, node features, labels and train/validation/test ids.
        adjacency_transform: transform(s) applied to `data.adjacency`.
        features_transform: transform(s) applied to `data.node_features`.
        factoriser: callable returning the factor `V` of the adjacency.
        normalize: apply the symmetric degree normalization described above.
        largest_component_only: if True, first restrict `data` to its largest
            component (computed with `directed=False`).

    Returns:
        `DataSplit` of `(example,)` tuples with
        `example = ((features, V), labels, weights)`. A split whose ids are
        `None` is `None`.

    Raises:
        ValueError: if the number of nodes is not statically known.
    """
    if largest_component_only:
        data = get_largest_component(data, directed=False)
    features = transformed(data.node_features, features_transform)
    A = transformed(data.adjacency, adjacency_transform)
    n = A.shape[0]
    if n is None:
        raise ValueError("Number of nodes must be statically known.")

    V = factoriser(A)
    if normalize:
        ones = tf.ones((n, ), dtype=V.dtype)
        # d = V (V^T 1). Both products are matrix-vector, so `matvec` is
        # needed: `tf.matmul` rejects a rank-1 operand.
        d = tf.linalg.matvec(V, tf.linalg.matvec(V, ones, transpose_a=True))
        # Scale rows (not columns): [n, 1] broadcasts against V's [n, k].
        V = V * tf.expand_dims(tf.math.rsqrt(tf.abs(d)), -1)
    inputs = (features, V)

    def get_data(ids):
        # Missing splits map to None.
        if ids is None:
            return None
        weights = preprocess_weights(ids, n, normalize=True)
        example = inputs, data.labels, weights
        return (example, )

    return DataSplit(get_data(data.train_ids), get_data(data.validation_ids),
                     get_data(data.test_ids))
Example #4
0
def preprocess_single(data: SemiSupervisedSingle,
                      batch_size: int,
                      features_transform=(),
                      transition_transform=(),
                      jl_factor: float = 4.0,
                      z_seed: int = 0,
                      dataset_seed: int = 0,
                      **cg_kwargs) -> DataSplit:
    """Build batched `tf.data` splits driven by effective-resistance transitions.

    An approximate effective-resistance embedding `Z` is computed once from
    `data.adjacency`. For each batch of node ids, the model inputs are:

    - `T`: a row-wise softmax of the negated pairwise effective resistances
      within the batch.
    - `X`: the gathered node features (sparse or dense).

    Every split iterates over all `num_nodes` nodes, shuffled with
    `dataset_seed`. Split membership enters only through the weights from
    `preprocess_weights(ids, ...)`; these are presumably zero outside `ids`
    (TODO confirm).

    Args:
        data: graph (sparse adjacency, since `.dense_shape` is read), node
            features, labels and train/validation/test ids.
        batch_size: number of nodes per batch.
        features_transform: transform(s) applied to `data.node_features`.
        transition_transform: accepted for interface compatibility but
            currently NOT applied to `T`.
        jl_factor: passed to `gu.approx_effective_resistance_z`.
        z_seed: rng seed for `gu.approx_effective_resistance_z`.
        dataset_seed: seed for the per-split node shuffle.
        **cg_kwargs: extra arguments for `gu.approx_effective_resistance_z`.

    Returns:
        `DataSplit` of `tf.data.Dataset`s yielding `((T, X), labels, weights)`.
        A split whose ids are `None` is `None`.
    """
    # NOTE: gu.effective_resistance_z(data.adjacency, **cg_kwargs) was
    # previously tried here as a non-approximate alternative.
    Z = gu.approx_effective_resistance_z(data.adjacency,
                                         jl_factor=jl_factor,
                                         rng=z_seed,
                                         **cg_kwargs)
    num_nodes = data.adjacency.dense_shape[0]
    features = transformed(data.node_features, features_transform)

    def get_dataset(ids: tp.Optional[tf.Tensor]):
        if ids is None:
            return None
        weights = preprocess_weights(ids, num_nodes, normalize=True)

        def map_fn(batch_ids):
            labels = tf.gather(data.labels, batch_ids, axis=0)
            w = tf.gather(weights, batch_ids, axis=0)
            if isinstance(features, tf.SparseTensor):
                X = stfu.gather(features, batch_ids, axis=0)
            else:
                X = tf.gather(features, batch_ids, axis=0)
            T = gu.get_pairwise_effective_resistance(Z, batch_ids)
            # Lower resistance -> larger transition weight, normalized per row.
            T = tf.math.softmax(-T, axis=1)
            inputs = T, X
            return inputs, labels, w

        return (tf.data.Dataset.range(num_nodes)
                .shuffle(num_nodes, seed=dataset_seed)
                .batch(batch_size)
                .map(map_fn)
                .prefetch(tf.data.AUTOTUNE))

    return DataSplit(
        get_dataset(data.train_ids),
        get_dataset(data.validation_ids),
        get_dataset(data.test_ids),
    )
Example #5
0
def get_spectral_split_v3(
    data: AutoencoderData,
    spectral_size: int,
    batch_size: int,
    shuffle_buffer: int = 256,
    prefetch_buffer: tp.Optional[int] = -1,
    adjacency_transform: tp.Optional[tp.Callable] = None,
    which: str = "LM",
) -> DataSplit:
    """Build pairwise-batched spectral datasets for link prediction.

    Eigenvectors `v` of the (optionally transformed) adjacency are computed
    with `eigsh`. Nodes are padded up to a multiple of `batch_size` using a
    dummy index `num_nodes`, then split into `num_parts` groups. Each dataset
    element corresponds to one (row group, column group) pair, giving
    `num_parts**2` elements per epoch. Each element is:

    - `features`: `[batch_size, 2, spectral_size]` embeddings of the zipped
      row/col ids.
    - `labels_` and `weights_`: gathered for the full
      `batch_size x batch_size` grid of (row, col) pairs at flat index
      `ravel_multi_index((row, col), (num_nodes, num_nodes))`. Pairs that
      involve a padded node are redirected to index `num_nodes**2`.

    Args:
        data: autoencoder data with adjacency, labels and split weights.
        spectral_size: number of eigenvectors to compute.
        batch_size: number of nodes per row/column group.
        shuffle_buffer: buffer size for shuffling dataset elements in the
            training split; 0 disables this buffer shuffle only.
        prefetch_buffer: prefetch size; falsy disables prefetching.
        adjacency_transform: optional transform applied to `data.adjacency`.
        which: eigenvalue selection passed through to `eigsh`.

    Returns:
        `DataSplit` of datasets; a split whose weights are `None` is `None`.
        Only the training split is shuffled. The validation slot is built from
        `test_weights` and the test slot from `validation_weights`
        (deliberate HACK).
    """
    adjacency = transformed(data.adjacency, adjacency_transform)
    # Only the eigenvectors are used as embeddings.
    _, v = eigsh(adjacency, spectral_size, which=which)
    num_nodes = v.shape[0]
    # Round num_nodes up to a multiple of batch_size.
    rem = num_nodes % batch_size
    padding = batch_size - rem if rem else 0
    num_nodes_ceil = num_nodes + padding
    assert num_nodes_ceil % batch_size == 0, (num_nodes_ceil, batch_size)
    num_parts = num_nodes_ceil // batch_size

    # Extra zero row so that the dummy index `num_nodes` gathers safely.
    v = tf.pad(v, [[0, 1], [0, 0]])  # pylint: disable=no-value-for-parameter

    def get_examples(labels, weights, shuffle: bool = True):
        if weights is None:
            return None
        # Pairs touching a padded node are redirected to flat index
        # num_nodes**2 below. Padding makes weights[num_nodes**2] == 0,
        # assuming weights is flat with length num_nodes**2 (TODO confirm).
        weights = tf.pad(weights, [[0, 1]])  # pylint: disable=no-value-for-parameter
        labels = tf.pad(  # pylint: disable=no-value-for-parameter
            labels, [[0, 1], [0, 0]])

        def generator_fn():
            row_parts = tf.concat(
                (
                    tf.range(num_nodes, dtype=tf.int64),
                    tf.fill((padding, ), tf.cast(num_nodes, tf.int64)),
                ),
                0,
            )
            col_parts = row_parts

            if shuffle:
                row_parts = tf.random.shuffle(row_parts)
                col_parts = tf.random.shuffle(col_parts)
            row_parts = tf.reshape(row_parts, (num_parts, batch_size))
            col_parts = tf.reshape(col_parts, (num_parts, batch_size))

            # Every (row group, col group) product is visited exactly once,
            # so each real pair appears once per epoch with or without shuffle.
            for i in range(num_parts):
                for j in range(num_parts):
                    yield tf.stack((row_parts[i], col_parts[j]), axis=1)

        dataset = tf.data.Dataset.from_generator(
            generator_fn,
            output_signature=tf.TensorSpec((batch_size, 2), dtype=tf.int64),
        )
        dataset = dataset.apply(
            tf.data.experimental.assert_cardinality(num_parts**2))
        if shuffle and shuffle_buffer:
            dataset = dataset.shuffle(shuffle_buffer)

        def map_fun(indices):
            row_ids, col_ids = tf.unstack(indices, axis=1)
            indices_2d = tf.stack(tf.meshgrid(row_ids, col_ids, indexing="ij"),
                                  axis=-1)
            valid = tf.reduce_all(indices_2d < num_nodes, axis=-1)
            indices_1d = ravel_multi_index(
                indices_2d,
                tf.convert_to_tensor((num_nodes, num_nodes), tf.int64),
                axis=-1,
            )
            indices_1d = tf.where(valid, indices_1d,
                                  num_nodes**2 * tf.ones_like(indices_1d))
            indices_1d = tf.reshape(indices_1d, (-1, ))  # [batch_size ** 2]
            labels_ = tf.gather(labels, indices_1d, axis=0)
            weights_ = tf.gather(weights, indices_1d, axis=0)
            features = tf.gather(v, indices,
                                 axis=0)  # [batch_size, 2, spectral_dim]

            # Adds a trailing unit axis. NOTE(review): labels are padded as
            # rank 2 above, so this is presumably [batch_size**2, C, 1], not
            # [batch_size**2, 1]; confirm the intended label shape.
            labels_ = tf.expand_dims(labels_, -1)
            return features, labels_, weights_

        dataset = dataset.map(map_fun)
        if prefetch_buffer:
            dataset = dataset.prefetch(prefetch_buffer)
        return dataset

    return DataSplit(
        get_examples(data.train_labels, data.train_weights, shuffle=True),
        # HACK: reversed validation/test splits.
        # Evaluation splits need no shuffling: the same pairs are covered.
        get_examples(data.true_labels, data.test_weights, shuffle=False),
        get_examples(data.true_labels, data.validation_weights, shuffle=False),
    )