def test_autoregressive_connect_graph_dynamic():
    """Visual smoke test for ``autoregressive_connect_graph_dynamic``.

    Builds a batch of two edgeless graphs (6 nodes and 0 nodes, so the
    empty-graph case is exercised too), connects it with
    ``autoregressive_connect_graph_dynamic`` allowing self edges, and draws
    the resulting sender -> receiver pairs with networkx for inspection.
    """
    import matplotlib.pyplot as plt
    import networkx as nx

    graphs = GraphsTuple(nodes=tf.range(6),
                         n_node=tf.constant([6, 0]),
                         n_edge=tf.constant([0, 0]),
                         edges=None,
                         receivers=None,
                         senders=None,
                         globals=None)
    graphs = autoregressive_connect_graph_dynamic(graphs, exclude_self_edges=False)

    G = nx.MultiDiGraph()
    for sender, receiver in zip(graphs.senders.numpy(), graphs.receivers.numpy()):
        G.add_edge(sender, receiver)
    nx.drawing.draw_circular(
        G,
        with_labels=True,
        node_color=(0, 0, 0),
        font_color=(1, 1, 1),
        font_size=25,
        node_size=1000,
        arrowsize=30,
    )
    plt.show()
def test_batch_reshape():
    """Round-trip check for ``graph_batch_reshape`` / ``graph_unbatch_reshape``.

    Four copies of one graph with random connectivity are concatenated. After
    batch-reshaping, the first two batch entries must share node/edge features,
    and their senders/receivers must differ by the per-graph node count.
    Unbatching must reproduce every non-None field of the concatenated input.
    """
    num_nodes, num_edges = 3, 40
    single = GraphsTuple(
        nodes=tf.reshape(tf.range(num_nodes * 2), (num_nodes, 2)),
        edges=tf.reshape(tf.range(num_edges * 5), (num_edges, 5)),
        senders=tf.random.uniform((num_edges, ), minval=0, maxval=num_nodes, dtype=tf.int32),
        receivers=tf.random.uniform((num_edges, ), minval=0, maxval=num_nodes, dtype=tf.int32),
        n_node=tf.constant([num_nodes]),
        n_edge=tf.constant([num_edges]),
        globals=None,
    )
    graphs = utils_tf.concat([single] * 4, axis=0)
    batched = graph_batch_reshape(graphs)

    def all_true(tensor):
        return tf.reduce_all(tensor).numpy()

    offset = graphs.n_node[0]
    assert all_true(batched.nodes[0] == batched.nodes[1])
    assert all_true(batched.edges[0] == batched.edges[1])
    assert all_true(batched.senders[0] + offset == batched.senders[1])
    assert all_true(batched.receivers[0] + offset == batched.receivers[1])

    unbatched = graph_unbatch_reshape(batched)
    for original_field, roundtrip_field in zip(graphs, unbatched):
        if original_field is None:
            continue
        assert all_true(original_field == roundtrip_field)
def _build(self, graph, crossing_steps):
    """Append one "component" node per iteration and return those nodes.

    After ``self.node_block(graph)``, each of ``self.num_components``
    iterations does the following:

    * runs ``self.edge_block`` and ``self.global_block``;
    * appends the resulting globals to the node set as a new node;
    * rewires the graph as a star that links every node (including the new
      one) to and from the new node.

    Args:
        graph: GraphsTuple input. NOTE(review): only one graph per batch is
            supported, because node counts are read at index 0.
        crossing_steps: unused; kept for interface compatibility.

    Returns:
        The nodes appended by the loop, i.e. the rows of the final node
        tensor after the first ``graph.n_node[0]`` rows.
    """
    del crossing_steps  # unused
    # Scalar count of the nodes present before any component is appended.
    n_non_components = graph.n_node[0]
    latent = self.node_block(graph)
    for _ in range(self.num_components):
        latent = self.edge_block(latent)
        latent = self.global_block(latent)
        component = latent.globals
        new_nodes = tf.concat([latent.nodes, component], axis=0)
        # tf.range / tf.fill need a scalar, not the [1]-shaped n_node vector.
        n_nodes = latent.n_node[0]
        # The appended node has index n_nodes. The senders are
        # [0, 1, ..., n_nodes, n_nodes, ..., n_nodes]. Reversing them gives
        # the receivers, so every node i gets edges i -> new and new -> i.
        new_senders = tf.concat(
            [tf.range(n_nodes + 1), tf.fill([n_nodes + 1], n_nodes)], axis=0)
        new_receivers = tf.reverse(new_senders, axis=[0])
        # Edge count must match the number of senders/receivers built above.
        n_edges = 2 * (n_nodes + 1)
        # NOTE(review): edges stay rank-1 zeros as before; edge blocks usually
        # expect [n_edge, edge_dim] — confirm against self.edge_block.
        new_edges = tf.zeros([n_edges], dtype=tf.float32)
        latent = GraphsTuple(nodes=new_nodes,
                             edges=new_edges,
                             globals=None,
                             senders=new_senders,
                             receivers=new_receivers,
                             n_node=latent.n_node + 1,
                             n_edge=2 * (latent.n_node + 1))
    gaussian_components = latent.nodes[n_non_components:]
    return gaussian_components
def nearest_neighbours_connected_graph(virtual_positions, k):
    """Build a k-nearest-neighbour graph over a point cloud.

    Every point sends one edge to each of its ``k`` nearest neighbours,
    excluding itself, as found by a KD-tree query.

    Args:
        virtual_positions: array-like of shape [N, D] with point positions.
            The positions also become the node features.
        k: number of neighbours per point.

    Returns:
        GraphsTuple holding a single graph:

        * nodes: the positions as float32;
        * edges: zero edge features of shape [n_edge, 1];
        * globals: zeros of shape [1];
        * senders / receivers: int32 indices.

        When fewer than ``k + 1`` points exist, points get fewer than ``k``
        edges.
    """
    positions = np.asarray(virtual_positions)
    num_points = positions.shape[0]
    kdtree = cKDTree(positions)
    # Ask for k + 1 neighbours because the first hit (column 0) is the point
    # itself; it is dropped below. NOTE(review): exact duplicate points can
    # make column 0 a different point.
    _, idx = kdtree.query(positions, k=k + 1)
    # For k + 1 == 1 scipy returns a 1-D array; normalise to [N, k + 1].
    idx = np.reshape(idx, (num_points, k + 1))
    receivers = idx[:, 1:]  # N,k
    senders = np.tile(np.arange(num_points)[:, None], [1, k])  # N,k
    # cKDTree marks missing neighbours (k + 1 > N) with index == N; drop
    # them. Boolean masking flattens in the same row-major order as flatten().
    valid = receivers < num_points
    receivers = receivers[valid]
    senders = senders[valid]

    graph_nodes = tf.convert_to_tensor(virtual_positions, tf.float32)
    graph_nodes.set_shape([None, positions.shape[1]])
    receivers = tf.convert_to_tensor(receivers, tf.int32)
    receivers.set_shape([None])
    senders = tf.convert_to_tensor(senders, tf.int32)
    senders.set_shape([None])
    n_node = tf.shape(graph_nodes)[0:1]
    n_edge = tf.shape(senders)[0:1]

    graph_data_dict = dict(nodes=graph_nodes,
                           edges=tf.zeros((n_edge[0], 1)),
                           globals=tf.zeros([1]),
                           receivers=receivers,
                           senders=senders,
                           n_node=n_node,
                           n_edge=n_edge)
    return GraphsTuple(**graph_data_dict)
def _init_call_func(self, observations, training=False):
    """Graph nets implementation.

    Converts ``observations`` into a dense GraphsTuple via
    ``GraphObserver.graph``. It then applies every block in
    ``self._graph_blocks`` in order, recording each intermediate graph in
    ``self._latent_trace``. Finally it returns the last node features
    reshaped to [batch_size, -1, self._embedding_size].

    ``training`` is accepted but not used here.
    """
    node_vals, edge_indices, node_to_graph, edge_vals = GraphObserver.graph(
        observations=observations, graph_dims=self._graph_dims, dense=True)
    batch_size = tf.shape(observations)[0]

    # Third output of unique_with_counts: number of nodes per graph id.
    _, _, nodes_per_graph = tf.unique_with_counts(node_to_graph)
    # With dense=True, each graph is assumed to have n_node ** 2 edges.
    edges_per_graph = tf.math.square(nodes_per_graph)

    senders = tf.cast(edge_indices[:, 0], tf.int32)
    receivers = tf.cast(edge_indices[:, 1], tf.int32)
    input_graph = GraphsTuple(
        nodes=tf.cast(node_vals, tf.float32),
        edges=tf.cast(edge_vals, tf.float32),
        globals=tf.tile([[0.]], [batch_size, 1]),
        receivers=receivers,
        senders=senders,
        n_node=nodes_per_graph,
        n_edge=edges_per_graph)

    trace = []
    self._latent_trace = trace
    latent = input_graph
    for block in self._graph_blocks:
        latent = block(latent)
        trace.append(latent)

    return tf.reshape(latent.nodes, [batch_size, -1, self._embedding_size])
def build_dataset(data_dir):
    """
    Build data set from a directory of tfrecords.

    Records are decoded with ``decode_examples_old`` (node_shape=(11,),
    image_shape=(256, 256, 1)) into (graph_data_dict, img, spsh, proj).
    Two streams are mixed with ``sample_from_datasets``:

    * mismatched pairs (label 0): graphs and images are shuffled
      independently and zipped. Pairs whose ``spsh`` and ``proj`` both
      coincide are filtered out.
    * matched pairs (label 1): the original aligned (graph, image).

    Args:
        data_dir: str, directory containing *.tfrecords files.

    Returns:
        tf.data.Dataset of (GraphsTuple, image, int32 label).
    """
    tfrecords = glob.glob(os.path.join(data_dir, '*.tfrecords'))
    dataset = tf.data.TFRecordDataset(tfrecords).map(
        partial(decode_examples_old, node_shape=(11,), image_shape=(256, 256, 1)))
    # (graph, image, spsh, proj)

    _graphs = dataset.map(
        lambda graph_data_dict, img, spsh, proj: (graph_data_dict, spsh, proj)).shuffle(buffer_size=50)
    _images = dataset.map(
        lambda graph_data_dict, img, spsh, proj: (img, spsh, proj)).shuffle(buffer_size=50)
    shuffled_dataset = tf.data.Dataset.zip((_graphs, _images))
    # ((graph_data_dict, spsh1, proj1), (img, spsh2, proj2)) -> (graph, img, same?)
    # tf.logical_and rather than Python `and`: the operands are symbolic
    # tensors inside the traced map function and cannot be used as Python bools.
    shuffled_dataset = shuffled_dataset.map(
        lambda ds1, ds2: (ds1[0], ds2[0],
                          tf.logical_and(ds1[1] == ds2[1], ds1[2] == ds2[2])))
    # Keep only mismatched pairs; their label casts to 0.
    shuffled_dataset = shuffled_dataset.filter(lambda graph_data_dict, img, c: ~c)
    shuffled_dataset = shuffled_dataset.map(
        lambda graph_data_dict, img, c: (graph_data_dict, img, tf.cast(c, tf.int32)))

    matched_dataset = dataset.map(
        lambda graph_data_dict, img, spsh, proj:
        (graph_data_dict, img, tf.constant(1, dtype=tf.int32)))  # (graph, img, yes)

    dataset = tf.data.experimental.sample_from_datasets([shuffled_dataset, matched_dataset])
    dataset = dataset.map(
        lambda graph_data_dict, img, c: (GraphsTuple(**graph_data_dict), img, c))
    return dataset
def build_dataset(data_dir):
    """
    Build data set from a directory of tfrecords.

    Each record is decoded by ``decode_examples`` (node_shape=(4,),
    image_shape=(18, 18, 1)) into (graph_data_dict, image, idx). The output
    randomly interleaves two kinds of examples:

    * mismatched (graph, image) pairs, built from independently shuffled
      streams whose idx differ (label 0);
    * the original aligned pairs (label 1).

    Args:
        data_dir: str, path to *.tfrecords

    Returns:
        Dataset of (GraphsTuple, image, int32 label).
    """
    record_files = glob.glob(os.path.join(data_dir, '*.tfrecords'))
    decoder = partial(decode_examples, node_shape=(4, ), image_shape=(18, 18, 1))
    decoded = tf.data.TFRecordDataset(record_files).map(decoder)  # (graph_data_dict, image, idx)

    shuffled_graphs = decoded.map(lambda g, img, idx: (g, idx)).shuffle(buffer_size=50)
    shuffled_images = decoded.map(lambda g, img, idx: (img, idx)).shuffle(buffer_size=50)

    # ((graph, idx1), (img, idx2)) -> (graph, img, idx1 == idx2)
    pairs = tf.data.Dataset.zip((shuffled_graphs, shuffled_images)).map(
        lambda graph_part, image_part: (graph_part[0], image_part[0], graph_part[1] == image_part[1]))
    negatives = pairs.filter(lambda g, img, same: ~same).map(
        lambda g, img, same: (g, img, tf.cast(same, tf.int32)))
    positives = decoded.map(
        lambda g, img, idx: (g, img, tf.constant(1, dtype=tf.int32)))

    mixed = tf.data.experimental.sample_from_datasets([negatives, positives])
    return mixed.map(
        lambda g, img, label: (GraphsTuple(globals=None, edges=None, **g), img, label))
def construct_input_graph(self, input_sequence, N2):
    """Embed a token sequence and connect it with a mixed 2D/3D edge rule.

    ``input_sequence`` holds G sequences of N token ids. Each sequence becomes
    one graph whose nodes are token embeddings plus positional encodings. The
    batch is flattened with ``graph_unbatch_reshape`` and connected by
    ``connect_graph_dynamic``:

    * among the first ``N2 + 1`` nodes, every pair is connected except
      i -> i + 1, so the model cannot simply copy the next token;
    * any node with index <= j sends to a node j >= ``N2 + 1``
      (auto-regressive, self-loops included).
    """
    G, N = get_shape(input_sequence)
    # num_samples*batch, 1 + H2*W2 + 1 + H3*W3*D3, embedding_dim
    token_embeddings = tf.nn.embedding_lookup(self.embeddings, input_sequence)
    self.initialize_positional_encodings(token_embeddings)

    n_node = tf.fill([G], N)
    batched_graphs = GraphsTuple(nodes=token_embeddings + self.positional_encodings,
                                 edges=None,
                                 senders=None,
                                 receivers=None,
                                 globals=None,
                                 n_node=n_node,
                                 n_edge=tf.zeros_like(n_node))
    # [n_graphs * (num_input + num_output), embedding_size]
    flat_graphs = graph_unbatch_reshape(batched_graphs)

    boundary = N2 + 1

    def edge_connect_rule(sender, receiver):
        in_2d_block = (sender < boundary) & (receiver < boundary) & (sender + 1 != receiver)
        causal_3d = (sender <= receiver) & (receiver >= boundary)
        return in_2d_block | causal_3d

    return connect_graph_dynamic(flat_graphs, edge_connect_rule)
def test_kgcn_runs(self):
    """Smoke test: a KGCN built from attribute embedders runs on a tiny graph.

    The graph has 3 nodes, 2 edges (0->1, 1->2) and one 5-dim global. The
    single attribute embedder returns zeros of shape (3, 6) for types 0, 1, 2.
    """
    tf.enable_eager_execution()

    def as_tensor(values, dtype):
        return tf.convert_to_tensor(np.array(values, dtype=dtype))

    graph = GraphsTuple(
        nodes=as_tensor([[1, 2, 0], [1, 0, 0], [1, 1, 0]], np.float32),
        edges=as_tensor([[1, 0, 0], [1, 0, 0]], np.float32),
        globals=as_tensor([[0, 0, 0, 0, 0]], np.float32),
        receivers=as_tensor([1, 2], np.int32),
        senders=as_tensor([0, 1], np.int32),
        n_node=as_tensor([3], np.int32),
        n_edge=as_tensor([2], np.int32))

    def make_zero_embedder():
        return lambda x: tf.constant(np.zeros((3, 6), dtype=np.float32))

    attr_embedders = {make_zero_embedder: [0, 1, 2]}
    kgcn = KGCN(3, 2, 5, 6, attr_embedders, edge_output_size=3, node_output_size=3)
    kgcn(graph, 2)
def test_kgcn_runs(self):
    """Smoke test: a KGCN with thing/role embedders runs on a tiny graph.

    The graph has 3 nodes, 2 edges (0->1, 1->2) and one 5-dim global.
    """
    tf.enable_eager_execution()

    def as_tensor(values, dtype):
        return tf.convert_to_tensor(np.array(values, dtype=dtype))

    graph = GraphsTuple(
        nodes=as_tensor([[1, 2, 0], [1, 0, 0], [1, 1, 0]], np.float32),
        edges=as_tensor([[1, 0, 0], [1, 0, 0]], np.float32),
        globals=as_tensor([[0, 0, 0, 0, 0]], np.float32),
        receivers=as_tensor([1, 2], np.int32),
        senders=as_tensor([0, 1], np.int32),
        n_node=as_tensor([3], np.int32),
        n_edge=as_tensor([2], np.int32))

    categorical = {
        'a': ['a1', 'a2', 'a3'],
        'b': ['b1', 'b2', 'b3'],
    }
    thing_embedder = ThingEmbedder(node_types=['a', 'b', 'c'],
                                   type_embedding_dim=5,
                                   attr_embedding_dim=6,
                                   categorical_attributes=categorical,
                                   continuous_attributes={'c': (0, 1)})
    role_embedder = RoleEmbedder(num_edge_types=2, type_embedding_dim=5)

    kgcn = KGCN(thing_embedder, role_embedder, edge_output_size=3, node_output_size=3)
    kgcn(graph, 2)
def test_plot_is_created(self):
    """plot_predictions writes an image file for a small parentship graph."""
    import tempfile

    num_processing_steps_ge = 6

    graph = nx.MultiDiGraph(name=0)

    existing = dict(solution=0)
    to_infer = dict(solution=2)
    candidate = dict(solution=1)

    # people
    graph.add_node(0, type='person', **existing)
    graph.add_node(1, type='person', **candidate)

    # parentships
    graph.add_node(2, type='parentship', **to_infer)
    graph.add_edge(2, 0, type='parent', **to_infer)
    graph.add_edge(2, 1, type='child', **candidate)

    graph_tuple_target = GraphsTuple(nodes=np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]),
                                     edges=np.array([[0., 0., 1.], [0., 1., 0.]]),
                                     receivers=np.array([1, 2], dtype=np.int32),
                                     senders=np.array([0, 1], dtype=np.int32),
                                     globals=np.array([[0., 0., 0., 0., 0.]], dtype=np.float32),
                                     n_node=np.array([3], dtype=np.int32),
                                     n_edge=np.array([2], dtype=np.int32))

    graph_tuple_output = GraphsTuple(nodes=np.array([[1., 0., 0.], [1., 1., 0.], [1., 0., 1.]]),
                                     edges=np.array([[1., 0., 0.], [1., 1., 0.]]),
                                     receivers=np.array([1, 2], dtype=np.int32),
                                     senders=np.array([0, 1], dtype=np.int32),
                                     globals=np.array([[0., 0., 0., 0., 0.]], dtype=np.float32),
                                     n_node=np.array([3], dtype=np.int32),
                                     n_edge=np.array([2], dtype=np.int32))

    # One output per processing step, so the two stay in sync.
    test_values = {"target": graph_tuple_target,
                   "outputs": [graph_tuple_output for _ in range(num_processing_steps_ge)]}

    # A temporary directory avoids timestamp characters (':' and ' ') that are
    # invalid in Windows filenames, and leaves no artefact behind.
    with tempfile.TemporaryDirectory() as tmp_dir:
        filename = os.path.join(tmp_dir, 'graph.png')
        plot_predictions([graph], test_values, num_processing_steps_ge, output_file=filename)
        self.assertTrue(os.path.isfile(filename))
def build_dataset(tfrecords):
    """Build a contrastive (GraphsTuple, image, label) dataset from tfrecords.

    Each record decodes to (graph_data_dict, img, cluster_idx, projection_idx,
    vprime). The pair key ``26 * cluster_idx + projection_idx`` decides
    whether a graph and an image belong together. This presumes
    projection_idx < 26 — verify against the data generator.

    The result randomly interleaves two kinds of examples:

    * mismatched pairs (label 0): graphs and images are shuffled
      independently and zipped; pairs with equal keys are dropped;
    * matched pairs (label 1): the original aligned examples.
    """
    decoded = tf.data.TFRecordDataset(tfrecords).map(
        partial(decode_examples,
                node_shape=(10, ),
                edge_shape=(2, ),
                image_shape=(1000, 1000, 1)))

    def pair_key(cluster_idx, projection_idx):
        return 26 * cluster_idx + projection_idx

    # Shuffle (graph, key) pairs and (image, key) pairs independently.
    graph_keys = decoded.map(
        lambda graph_data_dict, img, cluster_idx, projection_idx, vprime:
        (graph_data_dict, pair_key(cluster_idx, projection_idx))).shuffle(buffer_size=260)
    image_keys = decoded.map(
        lambda graph_data_dict, img, cluster_idx, projection_idx, vprime:
        (img, pair_key(cluster_idx, projection_idx))).shuffle(buffer_size=260)

    # ((graph, key1), (img, key2)) -> (graph, img, key1 == key2)
    zipped = tf.data.Dataset.zip((graph_keys, image_keys)).map(
        lambda graph_part, image_part: (graph_part[0], image_part[0], graph_part[1] == image_part[1]))

    # Keep mismatched pairs only; casting the False flag gives label 0.
    negatives = zipped.filter(lambda graph_data_dict, img, same: ~same).map(
        lambda graph_data_dict, img, same:
        (GraphsTuple(**graph_data_dict), img, tf.cast(same, tf.int32)))

    # The unshuffled stream is always a correct match: label 1.
    positives = decoded.map(
        lambda graph_data_dict, img, cluster_idx, projection_idx, vprime:
        (GraphsTuple(**graph_data_dict), img, tf.constant(1, dtype=tf.int32)))

    # Sample from either the correct or the incorrect combinations.
    return tf.data.experimental.sample_from_datasets([negatives, positives])
def destandardize_graphs_tuple(self, graphs: GraphsTuple) -> GraphsTuple:
    """Undo standardisation of the globals, nodes and edges of ``graphs``.

    Each field of the input is passed through ``self._destandardize`` with
    its own statistics: global_*, nodes_* and edges_*.
    """
    restored_globals = self._destandardize(graphs.globals, mean=self.global_mean, std=self.global_std)
    result = graphs.replace(globals=restored_globals)
    restored_nodes = self._destandardize(graphs.nodes, mean=self.nodes_mean, std=self.nodes_std)
    result = result.replace(nodes=restored_nodes)
    restored_edges = self._destandardize(graphs.edges, mean=self.edges_mean, std=self.edges_std)
    result = result.replace(edges=restored_edges)
    return result
def _build(self, graph: GraphsTuple, num_processing_steps):
    """Run ``num_processing_steps`` rounds of message passing.

    ``self._first_block`` is applied once first (its original comment says it
    gives the graph edges and globals derived from the nodes). Each step then:

    * runs ``self._message_passing_graph``;
    * adds the globals produced by ``self._global_block`` to the current
      globals.

    Returns:
        List with the graph after each step.
    """
    graph = self._first_block(graph)
    history = []
    for _step in range(num_processing_steps):
        graph = self._message_passing_graph(graph)
        residual_globals = self._global_block(graph).globals + graph.globals
        graph = graph._replace(globals=residual_globals)
        history.append(graph)
    return history
def _get_graph_tuple(input_tensors: List[tf.Tensor]) -> GraphsTuple:
    """Pack a list of tensors into a single-graph GraphsTuple.

    Args:
        input_tensors: [nodes, edges, senders, receivers, globals]. The
            globals tensor is unbatched; a leading axis of size 1 is added.

    Returns:
        GraphsTuple with one graph. n_node / n_edge are [1]-shaped int32
        tensors holding the leading dimension of nodes / edges.
    """
    nodes = input_tensors[0]
    edges = input_tensors[1]
    return GraphsTuple(
        nodes=nodes,
        edges=edges,
        senders=input_tensors[2],
        receivers=input_tensors[3],
        globals=tf.expand_dims(input_tensors[4], axis=0),
        # Dynamic shapes: the static ``.shape[0]`` is None when the leading
        # dimension is unknown at trace time, which cannot be converted.
        n_node=tf.shape(nodes)[0:1],
        n_edge=tf.shape(edges)[0:1],
    )
def _build(self, graph: GraphsTuple, positions: GraphsTuple, num_processing_steps):
    """Message passing with position features re-injected at every step.

    ``self._first_block`` is applied once first (its original comment says it
    gives the graph edges and globals derived from the nodes). Each step then:

    * concatenates ``positions.nodes`` in front of the current node features
      along the last axis;
    * runs ``self._message_passing_graph``;
    * collects ``self._property_block`` of the result.

    Returns:
        List with one ``self._property_block`` output per step.
    """
    graph = self._first_block(graph)
    outputs = []
    for _step in range(num_processing_steps):
        with_positions = tf.concat([positions.nodes, graph.nodes], axis=-1)
        graph = self._message_passing_graph(graph._replace(nodes=with_positions))
        outputs.append(self._property_block(graph))
    return outputs
def _build(self, batch, *args, **kwargs):
    """Score how well a cluster graph matches an image.

    The steps are:

    1. Encode the graph with ``self.epd_graph``.
    2. Encode the image with the (frozen) ``self.auto_encoder.encoder``.
    3. Turn every encoded pixel into a node of a fully connected graph and
       encode it with ``self.epd_image``.
    4. Compare the two global vectors with ``self.compare`` in both
       concatenation orders.

    Image summaries are written along the way.

    Args:
        batch: (graph, img, c); ``c`` is ignored.

    Returns:
        compare([g, i]) + compare([i, g]), where g and i are the globals of the
        encoded graph and of the encoded image graph.
    """
    (graph, img, c) = batch
    del c

    def _normalize(x):
        # Min-max rescale to [0, 1] for summaries. divide_no_nan maps a
        # constant tensor (max == min) to zeros instead of NaN.
        x_min = tf.reduce_min(x)
        return tf.math.divide_no_nan(x - x_min, tf.reduce_max(x) - x_min)

    # The encoded cluster graph has globals to compare with the encoded image graph.
    encoded_graph = self.epd_graph(graph, self._core_steps)

    # Add a batch dimension (tf.summary.image expects a rank-4 tensor).
    img = img[None, ...]
    tf.summary.image('img_before_cnn', _normalize(img), step=self.step)
    img = self.auto_encoder.encoder(img)

    # Keep the autoencoder from learning. AttributeError covers an encoder or
    # decoder that lacks these attributes; anything else now surfaces instead
    # of being swallowed by a bare except.
    try:
        for variable in self.auto_encoder.encoder.trainable_variables:
            variable._trainable = False
        for variable in self.auto_encoder.decoder.trainable_variables:
            variable._trainable = False
    except AttributeError:
        pass

    # [1, w, h, c] -> [c, w, h, 1]: one summary image per channel.
    tf.summary.image('img_after_autoencoder',
                     tf.transpose(_normalize(img), [3, 1, 2, 0]),
                     step=self.step)
    decoded_img = self.auto_encoder.decoder(img)
    tf.summary.image('decoded_img', _normalize(decoded_img), step=self.step)

    # 1, w,h,c -> w*h, c: one node per encoded pixel, channels as features.
    nodes = tf.reshape(img, (-1, self.image_feature_size))
    img_graph = GraphsTuple(nodes=nodes,
                            edges=None,
                            globals=None,
                            receivers=None,
                            senders=None,
                            n_node=tf.shape(nodes)[0:1],
                            n_edge=tf.constant([0]))
    connected_graph = fully_connect_graph_dynamic(img_graph)
    encoded_img = self.epd_image(connected_graph, 1)

    # Score both concatenation orders so the result does not depend on order.
    distance = self.compare(tf.concat([encoded_graph.globals, encoded_img.globals], axis=1)) + self.compare(
        tf.concat([encoded_img.globals, encoded_graph.globals], axis=1))
    return distance
def _build_dataset(data_dir):
    """Load (GraphsTuple, image) pairs from ``data_dir/*.tfrecords``.

    Records are decoded with ``decode_examples_old`` (node_shape=(11,),
    image_shape=(256, 256, 1)). The spsh / proj fields are dropped.
    """
    decode = partial(decode_examples_old, node_shape=(11, ), image_shape=(256, 256, 1))

    def to_graph_and_image(graph_data_dict, img, spsh, proj):
        return GraphsTuple(**graph_data_dict), img

    record_paths = glob.glob(os.path.join(data_dir, '*.tfrecords'))
    return tf.data.TFRecordDataset(record_paths).map(decode).map(to_graph_and_image)
def build_dataset(data_dir, batch_size):
    """Load single-element tuples ``(GraphsTuple,)`` from ``data_dir/*.tfrecords``.

    Node columns 0-2 are kept as-is; columns 3 and beyond are replaced by
    their natural log. Globals are set to zeros([1]) and edges to None.

    Args:
        data_dir: directory containing *.tfrecords files.
        batch_size: currently unused (no batching is applied).
    """
    record_paths = glob.glob(os.path.join(data_dir, '*.tfrecords'))

    def decode(record_bytes):
        return decode_examples(record_bytes, node_shape=[4])

    def to_graphs_tuple(graph_data_dict, img, c):
        return GraphsTuple(globals=tf.zeros([1]), edges=None, **graph_data_dict)

    def log_transform_tail(graph):
        head = graph.nodes[:, :3]
        tail = tf.math.log(graph.nodes[:, 3:])
        return graph._replace(nodes=tf.concat([head, tail], axis=1))

    dataset = tf.data.TFRecordDataset(record_paths).map(decode)
    dataset = dataset.map(to_graphs_tuple)
    dataset = dataset.map(log_transform_tail)
    return dataset.map(lambda graph: (graph, ))
def to_placeholders(self, batch_size=None):
    """Creates a placeholder to be fed into a graph_net.

    Placeholder dtypes are float64 for nodes/edges/globals and int32 for the
    index fields. ``batch_size`` fills the leading dimension of every field.

    The returned placeholders carry two extra attributes:

    * ``make_feed_dict(val)``: ``val`` is either a GraphsTuple or an
      iterable of GraphsTuples. In the latter case every graph of every
      tuple is merged into one GraphsTuple before building the feed dict.
    * ``name``: a descriptive string.
    """
    # pylint: disable=protected-access
    placeholders = utils_tf._build_placeholders_from_specs(
        dtypes=GraphsTuple(
            nodes=tf.float64,
            edges=tf.float64,
            receivers=tf.int32,
            senders=tf.int32,
            globals=tf.float64,
            n_node=tf.int32,
            n_edge=tf.int32,
        ),
        shapes=GraphsTuple(
            nodes=[batch_size, self.node_dim],
            edges=[batch_size, self.edge_dim],
            receivers=[batch_size],
            senders=[batch_size],
            globals=[batch_size, self.global_dim],
            n_node=[batch_size],
            n_edge=[batch_size],
        ),
    )

    def make_feed_dict(val):
        if isinstance(val, GraphsTuple):
            graphs_tuple = val
        else:
            data_dicts = []
            for item in val:
                # Keep every graph in each tuple. Taking only [0] silently
                # dropped graphs from multi-graph tuples.
                data_dicts.extend(utils_np.graphs_tuple_to_data_dicts(item))
            graphs_tuple = utils_np.data_dicts_to_graphs_tuple(data_dicts)
        return utils_tf.get_feed_dict(placeholders, graphs_tuple)

    placeholders.make_feed_dict = make_feed_dict
    placeholders.name = "Graph observation placeholder"
    return placeholders
def build_dataset(tfrecords_dirs, batch_size, type='train'):
    """
    Build a shuffled, batched (graph, smoothed image) dataset from tfrecords.

    Files are collected from ``<dir>/<type>/*.tfrecords`` for every
    directory; the file order is shuffled with ``random.shuffle``. Each record
    is decoded with ``decode_examples``. Only node columns 0-2 and 6-7 are
    kept, and the image is passed through ``gaussian_filter2d``.

    Args:
        tfrecords_dirs: iterable of root directories.
        batch_size: examples per batch. NOTE(review): ``Dataset.batch``
            stacks every GraphsTuple field along a new leading axis, so all
            graphs in a batch must have identical shapes.
        type: split sub-directory name, e.g. 'train'. (It shadows the builtin
            but is kept for caller compatibility.)

    Returns:
        Dataset obj of batched (GraphsTuple, image).
    """
    tfrecords = []
    for tfrecords_dir in tfrecords_dirs:
        tfrecords += glob.glob(os.path.join(tfrecords_dir, type, '*.tfrecords'))
    random.shuffle(tfrecords)

    print(f'Number of {type} tfrecord files : {len(tfrecords)}')

    def select_features(graph_data_dict, img, cluster_idx, projection_idx, vprime):
        # Build the GraphsTuple once, then keep node columns 0-2 and 6-7.
        graph = GraphsTuple(**graph_data_dict)
        nodes = tf.concat([graph.nodes[:, :3], graph.nodes[:, 6:8]], axis=-1)
        return graph.replace(nodes=nodes), gaussian_filter2d(img)

    dataset = tf.data.TFRecordDataset(tfrecords).map(
        partial(decode_examples,
                node_shape=(10, ),
                edge_shape=(2, ),
                image_shape=(256, 256, 1)))  # (graph, image, cluster_idx, projection_idx, vprime)
    dataset = dataset.map(select_features).shuffle(buffer_size=52).batch(batch_size=batch_size)
    return dataset
def test_compute_accuracy_is_as_expected(self):
    """compute_accuracy on a 3-node / 2-edge graph gives 2/5 correct, 0 solved.

    By argmax, target and output agree on 1 of 3 nodes and 1 of 2 edges.
    """
    shared_fields = dict(
        globals=None,
        senders=np.array([0, 1]),
        receivers=np.array([1, 2]),
        n_node=np.array([3]),
        n_edge=np.array([2]),
    )
    target = GraphsTuple(
        nodes=np.array([[1, 0], [1, 0], [0, 1]], dtype=np.float32),
        edges=np.array([[0, 1], [1, 0]], dtype=np.float32),
        **shared_fields)
    output = GraphsTuple(
        nodes=np.array([[0, 1], [1, 0], [1, 0]], dtype=np.float32),
        edges=np.array([[1, 0], [1, 0]], dtype=np.float32),
        **shared_fields)

    correct, solved = compute_accuracy(target, output)

    self.assertEqual(2 / 5, correct)
    self.assertEqual(0, solved)
def connect_graph_dynamic(graph: GraphsTuple, is_edge_func, name="connect_graph_dynamic"):
    """
    Connects a graph using a boolean edge mask to create edges.

    For each graph i in the batch, ``_create_functional_connect_edges_dynamic``
    is called with ``graph.n_node[i]`` and ``is_edge_func`` inside a
    ``tf.while_loop``. The per-graph senders/receivers are concatenated and
    shifted by ``utils_tf._compute_stacked_offsets(graph.n_node, n_edge)``.
    This presumably converts per-graph node indices into indices of the
    batched node tensor.

    Args:
        graph: GraphsTuple. Its edge fields are checked by
            ``utils_tf._validate_edge_fields_are_all_none``.
        is_edge_func: callable(sender: int, receiver: int) -> bool, should broadcast
        name: name scope for the created ops.

    Returns:
        ``graph`` with new senders, receivers and n_edge, each wrapped in
        ``tf.stop_gradient``. The ``edges`` field is not set here.
    """
    utils_tf._validate_edge_fields_are_all_none(graph)

    with tf.name_scope(name):

        def body(i, senders, receivers, n_edge):
            # Build the edges of graph i and store them at slot i of each TensorArray.
            edges = _create_functional_connect_edges_dynamic(graph.n_node[i], is_edge_func)
            # edges = create_edges_func(graph.n_node[i])
            return (i + 1, senders.write(i, edges['senders']), receivers.write(i, edges['receivers']),
                    n_edge.write(i, edges['n_edge']))

        num_graphs = utils_tf.get_num_graphs(graph)

        loop_condition = lambda i, *_: tf.less(i, num_graphs)

        # infer_shape=False lets the elements written per graph have different lengths.
        initial_loop_vars = [0] + [
            tf.TensorArray(dtype=tf.int32, size=num_graphs, infer_shape=False)
            for _ in range(3)  # senders, receivers, n_edge
        ]
        _, senders_array, receivers_array, n_edge_array = tf.while_loop(loop_condition, body, initial_loop_vars)

        n_edge = n_edge_array.concat()
        offsets = utils_tf._compute_stacked_offsets(graph.n_node, n_edge)
        senders = senders_array.concat() + offsets
        receivers = receivers_array.concat() + offsets
        senders.set_shape(offsets.shape)
        receivers.set_shape(offsets.shape)

        # Static-shape hints only (no runtime effect): senders/receivers are rank 1.
        receivers.set_shape([None])
        senders.set_shape([None])

        # Rebinds num_graphs to the *static* graph count (may be None).
        num_graphs = graph.n_node.get_shape().as_list()[0]
        n_edge.set_shape([num_graphs])

        return graph.replace(senders=tf.stop_gradient(senders),
                             receivers=tf.stop_gradient(receivers),
                             n_edge=tf.stop_gradient(n_edge))
def sample_decoder(self, positions, logits, temperature):
    """Sample tokens from ``logits`` and decode them into fields at ``positions``.

    The steps are:

    1. Draw one relaxed one-hot sample from
       ``RelaxedOneHotCategorical(temperature, logits)``.
    2. Mix the embedding table with it.
    3. Fully connect the resulting token nodes and decode them with
       ``self.decoder``.
    4. Reconstruct the fields from the decoded Gaussian components.

    Returns:
        Output of ``reconstruct_fields_from_gaussians(gaussian_tokens, positions)``.
    """
    token_distribution = tfp.distributions.RelaxedOneHotCategorical(temperature, logits=logits)
    token_samples_onehot = token_distribution.sample((1,), name='token_samples')
    token_sample_onehot = token_samples_onehot[0]  # [n_node, num_embedding]
    token_sample = tf.matmul(token_sample_onehot, self.embeddings)  # [n_node, embedding_dim]
    # [1]-shaped dynamic node count. Building it as tf.constant([tensor])
    # fails under tf.function, where tf.shape returns a symbolic tensor.
    n_node = tf.shape(token_sample)[0:1]
    latent_graph = GraphsTuple(nodes=token_sample,
                               edges=None,
                               globals=tf.constant([0.], dtype=tf.float32),
                               senders=None,
                               receivers=None,
                               n_node=n_node,
                               n_edge=tf.constant([0], dtype=tf.int32))  # [n_node, embedding_dim]
    latent_graph = fully_connect_graph_dynamic(latent_graph)
    gaussian_tokens = self.decoder(latent_graph)  # nodes=[num_gaussian_components, component_dim]
    reconstructed_fields = reconstruct_fields_from_gaussians(gaussian_tokens, positions)
    return reconstructed_fields
def build_dataset(data_dir, batch_size):
    """Load batched (GraphsTuple, image) pairs from ``data_dir/*.tfrecords``.

    The pipeline is:

    1. Decode each record with ``decode_examples`` (node_shape=(11,),
       image_shape=(256, 256, 1), k=6).
    2. Keep only (graph_data_dict, image).
    3. Batch ``batch_size`` examples.
    4. Fix the batched data dict with ``batch_graph_data_dict``.
    5. Flatten the batch with ``graph_unbatch_reshape``.

    Returns:
        Dataset of (GraphsTuple, image batch).
    """
    tfrecords = glob.glob(os.path.join(data_dir, '*.tfrecords'))
    dataset = tf.data.TFRecordDataset(tfrecords).map(
        partial(decode_examples, node_shape=(11,), image_shape=(256, 256, 1), k=6))
    # (graph, image, spsh, proj, e) -> (graph, image)
    dataset = dataset.map(lambda graph_data_dict, img, spsh, proj, e: (graph_data_dict, img))
    # Dataset.batch returns a new dataset; it must be reassigned to take effect.
    dataset = dataset.batch(batch_size)
    # batch fixing mechanism
    dataset = dataset.map(lambda data_dict, image: (batch_graph_data_dict(data_dict), image))
    dataset = dataset.map(lambda data_dict, image: (
        GraphsTuple(**data_dict, edges=None, receivers=None, senders=None, globals=None), image))
    dataset = dataset.map(lambda batched_graphs, image: (graph_unbatch_reshape(batched_graphs), image))
    return dataset
def _single_decode(token_sample_onehot):
    """Decode one relaxed one-hot token sample and score it.

    This is a closure: it reads ``self``, ``encoded_graph``, ``graph``,
    ``logits`` and ``log_norm`` from the enclosing scope.

    Args:
        token_sample_onehot: [n_node, num_embedding] one-hot (or relaxed
            one-hot) token weights.

    Returns:
        log_likelihood: scalar
        kl_term: scalar
    """
    # z ~ q(z|x): mix the embedding table with the one-hot weights.
    embedded_tokens = tf.matmul(token_sample_onehot, self.embeddings)  # [n_node, embedding_dim]
    latent_graph = GraphsTuple(nodes=embedded_tokens,
                               edges=None,
                               globals=tf.constant([0.], dtype=tf.float32),
                               senders=None,
                               receivers=None,
                               n_node=encoded_graph.n_node,
                               n_edge=tf.constant([0], dtype=tf.int32))
    latent_graph = fully_connect_graph_dynamic(latent_graph)
    gaussian_tokens = self.decoder(latent_graph)  # nodes=[num_gaussian_components, component_dim]
    _, log_likelihood = gaussian_loss_function(gaussian_tokens.nodes, graph)

    # Per-node sum of the logits weighted by the sampled one-hot vector.
    selected_logit_sum = tf.math.reduce_sum(token_sample_onehot * logits, axis=1)  # [n_node]
    num_embedding = self.num_embedding
    log_num_embedding = tf.math.log(tf.cast(num_embedding, tf.float32))
    per_node_kl = selected_logit_sum - num_embedding * log_norm + num_embedding * log_num_embedding
    kl_term = self.beta * tf.reduce_mean(per_node_kl)
    return log_likelihood, kl_term
def decode_examples(record_bytes, node_shape=None, edge_shape=None, image_shape=None):
    """
    Decodes raw bytes as returned from tf.data.TFRecordDataset([example_path]) into a GraphTuple and image

    Args:
        record_bytes: raw bytes of one serialized example.
        node_shape: per-node feature shape if known (leading node axis excluded).
        edge_shape: per-edge feature shape if known (leading edge axis excluded).
        image_shape: shape of image if known (passed to set_shape even when None).

    Returns: (GraphTuple, image). Senders/receivers are int64; globals is None;
        n_node / n_edge are the leading dimensions of nodes / edges.
    """
    schema = dict(
        image=tf.io.FixedLenFeature([], dtype=tf.string),
        **feature_to_graph_tuple('graph')
    )
    parsed = tf.io.parse_single_example(record_bytes, schema)

    def parse(key, dtype):
        return tf.io.parse_tensor(parsed[key], dtype)

    image = parse('image', tf.float32)
    image.set_shape(image_shape)

    graph_nodes = parse('graph_nodes', tf.float32)
    if node_shape is not None:
        graph_nodes.set_shape([None, *node_shape])

    graph_edges = parse('graph_edges', tf.float32)
    if edge_shape is not None:
        graph_edges.set_shape([None, *edge_shape])

    receivers = parse('graph_receivers', tf.int64)
    senders = parse('graph_senders', tf.int64)
    for index_tensor in (receivers, senders):
        index_tensor.set_shape([None])

    graph = GraphsTuple(nodes=graph_nodes,
                        edges=graph_edges,
                        globals=None,
                        receivers=receivers,
                        senders=senders,
                        n_node=tf.shape(graph_nodes)[0:1],
                        n_edge=tf.shape(graph_edges)[0:1])
    return (graph, image)
def apply_random_rotation(graph: graphs.GraphsTuple) -> graphs.GraphsTuple:
    """Returns randomly rotated graph representation.

    The rotation is an element of O(3) with rotation angles multiple of pi/2.
    It is built from a random permutation of the three axes combined with an
    independent random sign flip per axis. One transform is shared by all
    edges.

    This function assumes that the relative particle distances are stored in
    the edge features, as [n_edges, 3].

    Args:
        graph: The graphs tuple as defined in `graph_nets.graphs`.

    Returns:
        ``graph`` with transformed edges; all other fields unchanged.
    """
    # Transposes edge features so that the axes are in the first dimension.
    # Outputs a tensor of shape [3, n_edges].
    xyz = tf.transpose(graph.edges)

    # Random pi/2 rotation(s): permute the axes.
    permutation = tf.random.shuffle(tf.constant([0, 1, 2], dtype=tf.int32))
    xyz = tf.gather(xyz, permutation)

    # Random reflections: a sign of +1 or -1 per axis.
    # tf.random.uniform works in TF 1.x and 2.x; tf.random_uniform is TF1-only.
    symmetry = tf.random.uniform([3], minval=0, maxval=2, dtype=tf.int32)
    symmetry = 1 - 2 * tf.cast(tf.reshape(symmetry, [3, 1]), tf.float32)
    xyz = xyz * symmetry

    edges = tf.transpose(xyz)

    return graph.replace(edges=edges)
def _build(self, batch, *args, **kwargs):
    """Score a (graph, image) pair.

    The steps are:

    1. Encode the graph with ``self.encoder_graph``.
    2. Run the image through ``self.image_cnn``, writing one summary per
       output channel.
    3. Turn each CNN pixel into a node of a fully connected graph and encode
       it with ``self.encoder_image``.
    4. Compare the concatenated globals with ``self.compare``.

    Args:
        batch: (graph, img, label); the label is ignored.

    Returns:
        ``self.compare`` output for [graph globals, image globals]  # [1]
    """
    graph, img, _label = batch

    encoded_graph = self.encoder_graph(graph)

    batched_img = img[None, ...]
    tf.summary.image('img_before_cnn', batched_img, step=self.step)
    features = self.image_cnn(batched_img)
    for ch in range(features.shape[-1]):
        tf.summary.image(f'img_after_cnn[{ch}]', features[..., ch:ch + 1], step=self.step)

    # 1, w,h,c -> w*h, c
    nodes = tf.reshape(features, (-1, self.image_feature_size))
    img_graph = GraphsTuple(nodes=nodes,
                            edges=None,
                            globals=None,
                            receivers=None,
                            senders=None,
                            n_node=tf.shape(nodes)[0:1],
                            n_edge=tf.constant([0]))
    encoded_img = self.encoder_image(fully_connect_graph_dynamic(img_graph))

    pair = tf.concat([encoded_graph.globals, encoded_img.globals], axis=1)
    return self.compare(pair)  # [1]
def build_dataset(tfrecords, batch_size):
    """
    Build a shuffled dataset of GraphsTuples from tfrecords.

    Each record is decoded with ``decode_examples`` (node_shape=(10,),
    edge_shape=(2,), image_shape=(1024, 1024, 1)). Only the graph is kept;
    the image and indices are dropped.

    Args:
        tfrecords: file name(s) accepted by tf.data.TFRecordDataset.
        batch_size: currently unused (no batching is applied).

    Returns:
        Dataset obj of GraphsTuple.
    """
    decoder = partial(decode_examples,
                      node_shape=(10, ),
                      edge_shape=(2, ),
                      image_shape=(1024, 1024, 1))
    decoded = tf.data.TFRecordDataset(tfrecords).map(decoder)
    # (graph_data_dict, img, cluster_idx, projection_idx, vprime) -> graph_data_dict
    graph_dicts = decoded.map(
        lambda graph_data_dict, img, cluster_idx, projection_idx, vprime: graph_data_dict
    ).shuffle(buffer_size=50)
    return graph_dicts.map(lambda graph_data_dict: GraphsTuple(**graph_data_dict))