Example #1
def get_data(ids):
    # `inputs` and `data` come from the enclosing scope; `n` is
    # presumably the total number of nodes.
    if ids is None:
        return None
    example = inputs, data.labels, preprocess_weights(ids, n, normalize=True)
    return (example,)
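
None of these snippets define preprocess_weights itself. Judging from the call sites (a tensor of split ids, the total node count, an optional normalize flag, and a result later gathered by node id), a minimal sketch could look like the following; the exact implementation is an assumption.

import tensorflow as tf

def preprocess_weights(ids, num_nodes, normalize=False):
    # Hypothetical sketch: scatter unit weight onto each node in the
    # split so unlabelled nodes contribute nothing to the loss.
    ids = tf.cast(ids, tf.int64)
    weights = tf.scatter_nd(
        tf.expand_dims(ids, axis=1),
        tf.ones_like(ids, dtype=tf.float32),
        shape=(num_nodes,),
    )
    if normalize:
        # Rescale so the weights across the split sum to one.
        weights = weights / tf.reduce_sum(weights)
    return weights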
Example #2
import itertools

import h5py
import numpy as np
import tensorflow as tf

# SemiSupervisedSingle, StalePropagation, gather_gather and
# preprocess_weights are defined elsewhere in the project.


def build_stale_dataset(
    stale_model: tf.keras.Model,
    data: SemiSupervisedSingle,
    cache: h5py.Group,
    batch_size: int,
    *,
    use_dense_adjacency: bool = False,
) -> tf.data.Dataset:
    def get_key(name):
        assert name.startswith("stale_")
        return name[len("stale_"):]

    keys = tuple(
        get_key(layer.name) for layer in stale_model.layers
        if isinstance(layer, StalePropagation))

    def load_prop(key, ids):
        group = cache[key]
        # TODO: make these slices rather than gathers
        return np.array(group["x0"][ids]), np.array(group["y0"][ids])

    def load_stale_values(ids):
        return tuple(itertools.chain(*(load_prop(key, ids) for key in keys)))

    def map_fn(ids):
        # TODO: parallelize loading?
        ids = tf.sort(ids)
        dtypes = itertools.chain(*((cache[k]["x0"].dtype, cache[k]["y0"].dtype)
                                   for k in keys))
        stale_values = tf.numpy_function(
            load_stale_values,
            [ids],
            [tf.dtypes.as_dtype(d) for d in dtypes],
        )
        # numpy_function returns a flat list: even entries are x0s,
        # odd entries are y0s, interleaved per key.
        x0s = stale_values[::2]
        y0s = stale_values[1::2]
        stale_values = {}

        for key, x0, y0 in zip(keys, x0s, y0s):
            # numpy_function erases static shapes; restore the feature dims.
            x0.set_shape((None, cache[key]["x0"].shape[1]))
            y0.set_shape((None, cache[key]["y0"].shape[1]))
            stale_values[key] = {"x0": x0, "y0": y0}

        features = tf.gather(data.node_features, ids, axis=0)
        adj = gather_gather(data.adjacency, ids)
        if use_dense_adjacency:
            adj = tf.sparse.to_dense(adj)
        inputs = (features, adj), stale_values
        labels = tf.gather(data.labels, ids, axis=0)
        batch_weights = tf.gather(weights, ids, axis=0)
        return inputs, labels, batch_weights

    num_nodes = data.adjacency.shape[0]
    weights = preprocess_weights(data.train_ids, num_nodes, normalize=True)
    ids_ds = tf.data.Dataset.range(num_nodes).shuffle(num_nodes).batch(
        batch_size)
    dataset = ids_ds.map(map_fn)
    tf.nest.assert_same_structure(dataset.element_spec[0], stale_model.input)
    return dataset.prefetch(tf.data.AUTOTUNE)
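
The HDF5 layout that build_stale_dataset expects is never shown. From load_prop and the set_shape calls, each StalePropagation layer named stale_<key> reads 2-D x0 and y0 datasets, indexed by node id, from cache[<key>]. A hypothetical cache built to match (file name and sizes made up):

import h5py
import numpy as np

num_nodes, feature_dim, num_classes = 2708, 64, 7  # made-up sizes
with h5py.File("stale_cache.h5", "w") as fp:
    # One group per StalePropagation layer; a layer named
    # "stale_prop" reads from the group "prop".
    group = fp.create_group("prop")
    group.create_dataset(
        "x0", data=np.zeros((num_nodes, feature_dim), np.float32))
    group.create_dataset(
        "y0", data=np.zeros((num_nodes, num_classes), np.float32))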
Example #3
        def val_dataset(split_ids):
            # `output_dataset` (yielding (mapping, ids) pairs) and `data`
            # come from the enclosing scope; note `normalize` is left at
            # its default here.
            weights = preprocess_weights(split_ids, data.adjacency.shape[0])

            def map_fn(mapping, ids):
                batch_labels = tf.gather(data.labels, ids, axis=0)
                batch_weights = tf.gather(weights, ids, axis=0)
                return mapping, batch_labels, batch_weights

            return output_dataset.map(map_fn)
Example #4
    def get_dataset(ids: tp.Optional[tf.Tensor]):
        # `tp` is the `typing` module; `features`, `Z`, `num_nodes`,
        # `batch_size` and `dataset_seed` come from the enclosing scope,
        # and `stfu`/`gu` appear to be project-local helper modules.
        if ids is None:
            return None
        weights = preprocess_weights(ids, num_nodes, normalize=True)

        def map_fn(ids):
            labels = tf.gather(data.labels, ids, axis=0)
            w = tf.gather(weights, ids, axis=0)
            if isinstance(features, tf.SparseTensor):
                X = stfu.gather(features, ids, axis=0)
            else:
                X = tf.gather(features, ids, axis=0)
            T = gu.get_pairwise_effective_resistance(Z, ids)
            T = tf.math.softmax(-1 * T, axis=1)
            # T = tf.where(T == 0, tf.zeros_like(T), 1 / T)  # resistance to conductance
            # T = transformed(T, transition_transform)
            # T = tf.eye(tf.shape(T)[0])
            inputs = T, X
            return inputs, labels, w

        return (
            tf.data.Dataset.range(num_nodes)
            .shuffle(num_nodes, seed=dataset_seed)
            .batch(batch_size)
            .map(map_fn)
            .prefetch(tf.data.AUTOTUNE)
        )
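
The softmax(-T) in map_fn turns each batch's pairwise effective resistances into a row-stochastic transition matrix: the lower the resistance between two nodes, the larger the weight. A toy illustration with made-up values:

import tensorflow as tf

# Made-up pairwise effective resistances for a batch of 3 nodes.
T = tf.constant([[0.0, 1.0, 4.0],
                 [1.0, 0.0, 2.0],
                 [4.0, 2.0, 0.0]])
T = tf.math.softmax(-T, axis=1)  # low resistance -> high weight
print(tf.reduce_sum(T, axis=1))  # every row sums to 1.0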
Example #5
def get_split(ids: tf.Tensor):
    # `inputs`, `data` and `num_nodes` come from the enclosing scope.
    weights = preprocess_weights(ids, num_nodes, normalize=True)
    example = (inputs, data.labels, weights)
    return (example,)
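
All five helpers yield (inputs, labels, weights) triples. tf.keras treats the third element of a dataset batch as per-example sample weights, so these masked (and optionally normalized) weights are what keep unlabelled nodes out of the loss. A hypothetical training call, assuming `model`, `train_dataset` and `val_dataset` already exist:

import tensorflow as tf

model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    weighted_metrics=["accuracy"],
)
model.fit(train_dataset, validation_data=val_dataset, epochs=10)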