import itertools
import typing as tp

import h5py
import numpy as np
import tensorflow as tf

# Repo-internal helpers (SemiSupervisedSingle, StalePropagation, gather_gather,
# preprocess_weights, stfu, gu, ...) are assumed to be imported elsewhere.


def get_data(ids):
    if ids is None:
        return None
    # `inputs`, `data` and `n` are captured from the enclosing scope.
    example = inputs, data.labels, preprocess_weights(ids, n, normalize=True)
    return (example,)
def build_stale_dataset(
    stale_model: tf.keras.Model,
    data: SemiSupervisedSingle,
    cache: h5py.Group,
    batch_size: int,
    *,
    use_dense_adjacency: bool = False,
) -> tf.data.Dataset:
    """Build a dataset pairing node batches with cached stale propagations."""

    def get_key(name):
        assert name.startswith("stale_")
        return name[len("stale_"):]

    # One cache key per StalePropagation layer, derived from the layer name.
    keys = tuple(
        get_key(layer.name)
        for layer in stale_model.layers
        if isinstance(layer, StalePropagation))

    def load_prop(key, ids):
        group = cache[key]
        # TODO: make these slices rather than gathers
        return np.array(group["x0"][ids]), np.array(group["y0"][ids])

    def load_stale_values(ids):
        # Flatten to (x0, y0, x0, y0, ...) so tf.numpy_function can return it.
        return tuple(itertools.chain(*(load_prop(key, ids) for key in keys)))

    def map_fn(ids):
        # TODO: parallelize loading?
        ids = tf.sort(ids)  # h5py fancy indexing requires increasing indices
        dtypes = itertools.chain(
            *((cache[k]["x0"].dtype, cache[k]["y0"].dtype) for k in keys))
        stale_values = tf.numpy_function(
            load_stale_values,
            [ids],
            [tf.dtypes.as_dtype(d) for d in dtypes],
        )
        # Undo the interleaving from load_stale_values.
        x0s = stale_values[::2]
        y0s = stale_values[1::2]
        stale_values = {}
        for key, x0, y0 in zip(keys, x0s, y0s):
            # numpy_function loses static shapes; restore the feature dims.
            x0.set_shape((None, cache[key]["x0"].shape[1]))
            y0.set_shape((None, cache[key]["y0"].shape[1]))
            stale_values[key] = {"x0": x0, "y0": y0}
        features = tf.gather(data.node_features, ids, axis=0)
        adj = gather_gather(data.adjacency, ids)
        if use_dense_adjacency:
            adj = tf.sparse.to_dense(adj)
        inputs = (features, adj), stale_values
        labels = tf.gather(data.labels, ids, axis=0)
        batch_weights = tf.gather(weights, ids, axis=0)
        return inputs, labels, batch_weights

    num_nodes = data.adjacency.shape[0]
    weights = preprocess_weights(data.train_ids, num_nodes, normalize=True)
    ids_ds = tf.data.Dataset.range(num_nodes).shuffle(num_nodes).batch(
        batch_size)
    dataset = ids_ds.map(map_fn)
    tf.nest.assert_same_structure(dataset.element_spec[0], stale_model.input)
    return dataset.prefetch(tf.data.AUTOTUNE)
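# A minimal usage sketch, not part of the source: the cache layout (one group
# per key, each holding "x0"/"y0" datasets) follows build_stale_dataset above,
# but `cache_path`, the epoch count and the fit() call are assumptions.
def train_with_stale_cache(stale_model, data, cache_path, batch_size=1024):
    with h5py.File(cache_path, "r") as cache:
        dataset = build_stale_dataset(
            stale_model, data, cache, batch_size, use_dense_adjacency=False)
        stale_model.fit(dataset, epochs=10)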
def val_dataset(split_ids):
    weights = preprocess_weights(split_ids, data.adjacency.shape[0])

    def map_fn(mapping, ids):
        batch_labels = tf.gather(data.labels, ids, axis=0)
        batch_weights = tf.gather(weights, ids, axis=0)
        return mapping, batch_labels, batch_weights

    return output_dataset.map(map_fn)
def get_dataset(ids: tp.Optional[tf.Tensor]):
    if ids is None:
        return None
    weights = preprocess_weights(ids, num_nodes, normalize=True)

    def map_fn(ids):
        labels = tf.gather(data.labels, ids, axis=0)
        w = tf.gather(weights, ids, axis=0)
        if isinstance(features, tf.SparseTensor):
            X = stfu.gather(features, ids, axis=0)
        else:
            X = tf.gather(features, ids, axis=0)
        # Row-softmax of negated pairwise effective resistance: each row of T
        # becomes a transition distribution favouring low-resistance pairs.
        T = gu.get_pairwise_effective_resistance(Z, ids)
        T = tf.math.softmax(-1 * T, axis=1)
        # T = tf.where(T == 0, tf.zeros_like(T), 1 / T)  # resistance to conductance
        # T = transformed(T, transition_transform)
        # T = tf.eye(tf.shape(T)[0])
        inputs = T, X
        return inputs, labels, w

    return (tf.data.Dataset.range(num_nodes).shuffle(
        num_nodes, seed=dataset_seed).batch(batch_size).map(map_fn).prefetch(
            tf.data.AUTOTUNE))
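# For intuition, a self-contained sketch (the resistance values are made up)
# of the transform above: softmax over negated effective resistance turns
# well-connected (low-resistance) pairs into high transition weights.
def _resistance_softmax_demo():
    R = tf.constant([[0.0, 1.0, 4.0],
                     [1.0, 0.0, 2.0],
                     [4.0, 2.0, 0.0]])  # toy symmetric resistance matrix
    T = tf.math.softmax(-R, axis=1)  # rows sum to 1; weight decays with R
    print(T.numpy())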
def get_split(ids: tf.Tensor):
    weights = preprocess_weights(ids, num_nodes, normalize=True)
    example = (inputs, data.labels, weights)
    return (example,)
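# preprocess_weights is used throughout but not defined in this section.
# A plausible sketch under the assumption that it scatters unit weights onto
# the split ids over n nodes; the normalization choice here is a guess.
def preprocess_weights(ids: tf.Tensor, n: int, normalize: bool = False):
    indices = tf.expand_dims(tf.cast(ids, tf.int32), axis=1)
    weights = tf.scatter_nd(indices, tf.ones_like(ids, dtype=tf.float32), (n,))
    if normalize:
        weights = weights / tf.reduce_sum(weights)  # assumed: sum to one
    return weights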