def test_autoregressive_connect_graph_dynamic():
    """Visual smoke test for ``autoregressive_connect_graph_dynamic``.

    Builds a two-graph batch (6 nodes and 0 nodes, no edges), connects it
    autoregressively while keeping self-edges, and draws the resulting edges
    with networkx so they can be inspected by eye. Nothing is asserted.
    """
    graphs = GraphsTuple(nodes=tf.range(6),
                         n_node=tf.constant([6, 0]),
                         n_edge=tf.constant([0, 0]),
                         edges=None,
                         receivers=None,
                         senders=None,
                         globals=None)
    graphs = autoregressive_connect_graph_dynamic(graphs,
                                                  exclude_self_edges=False)

    # Plotting dependencies are only needed by this visual test.
    import networkx as nx
    import pylab as plt

    # MultiDiGraph keeps every (sender, receiver) pair, including self-loops.
    G = nx.MultiDiGraph()
    for sender, receiver in zip(graphs.senders.numpy(),
                                graphs.receivers.numpy()):
        G.add_edge(sender, receiver)
    nx.drawing.draw_circular(
        G,
        with_labels=True,
        node_color=(0, 0, 0),
        font_color=(1, 1, 1),
        font_size=25,
        node_size=1000,
        arrowsize=30,
    )
    plt.show()
Exemple #2
0
def test_batch_reshape():
    """Round-trip a batch of identical graphs through batch/unbatch reshape.

    Four copies of one 3-node / 40-edge graph are concatenated. After
    ``graph_batch_reshape`` consecutive batch entries must share node and edge
    features, with senders/receivers shifted by the per-graph node count.
    ``graph_unbatch_reshape`` must then reproduce every non-None field.
    """
    num_nodes, num_edges = 3, 40
    edge_index_kwargs = dict(minval=0, maxval=num_nodes, dtype=tf.int32)
    single_graph = GraphsTuple(
        nodes=tf.reshape(tf.range(num_nodes * 2), (num_nodes, 2)),
        edges=tf.reshape(tf.range(num_edges * 5), (num_edges, 5)),
        senders=tf.random.uniform((num_edges, ), **edge_index_kwargs),
        receivers=tf.random.uniform((num_edges, ), **edge_index_kwargs),
        n_node=tf.constant([num_nodes]),
        n_edge=tf.constant([num_edges]),
        globals=None)
    graphs = utils_tf.concat([single_graph] * 4, axis=0)
    batched = graph_batch_reshape(graphs)
    node_offset = graphs.n_node[0]

    def _all_equal(lhs, rhs):
        return tf.reduce_all(lhs == rhs).numpy()

    assert _all_equal(batched.nodes[0], batched.nodes[1])
    assert _all_equal(batched.edges[0], batched.edges[1])
    # The second graph's indices are the first graph's shifted by its node count.
    assert _all_equal(batched.senders[0] + node_offset, batched.senders[1])
    assert _all_equal(batched.receivers[0] + node_offset, batched.receivers[1])

    unbatched = graph_unbatch_reshape(batched)
    for original_field, restored_field in zip(graphs, unbatched):
        if original_field is None:
            continue
        assert _all_equal(original_field, restored_field)
Exemple #3
0
    def _build(self, graph, crossing_steps):
        """Append ``self.num_components`` component nodes derived from globals.

        After an initial ``self.node_block`` pass, each iteration runs
        ``self.edge_block`` and ``self.global_block``, appends the resulting
        globals to the node set and rebuilds senders/receivers/edges for the
        enlarged graph.

        Args:
            graph: input GraphsTuple.
            crossing_steps: not used by this implementation.

        Returns:
            ``latent.nodes[graph.n_node:]`` of the final latent graph, i.e. the
            node rows appended after the original nodes.
        """
        n_non_components = graph.n_node
        latent = self.node_block(graph)

        for _ in range(self.num_components):
            latent = self.edge_block(latent)
            latent = self.global_block(latent)

            # The freshly computed globals become one extra node.
            component = latent.globals
            new_nodes = tf.concat([latent.nodes, component], axis=0)
            n_nodes = latent.n_node
            # NOTE(review): GraphsTuple.n_node is normally a per-graph vector,
            # but tf.range / tf.fill(value=...) below presumably need a scalar
            # count; this looks like it only works for a single graph - confirm.
            new_senders = tf.concat([
                tf.range(n_nodes + 1),
                tf.fill(dims=(n_nodes + 1), value=n_nodes)
            ], axis=0)  # tf.concat has no default axis; it must be given.
            # tf.reverse expects a 1-D list of axes, not a scalar.
            new_receivers = tf.reverse(new_senders, axis=[0])
            # NOTE(review): senders/receivers hold 2 * (n_nodes + 1) entries
            # while edges/n_edge are sized (n_nodes + 1) ** 2 - confirm which
            # edge count is intended.
            new_edges = tf.fill(dims=((n_nodes + 1)**2), value=0.)
            n_edges = (n_nodes + 1)**2

            latent = GraphsTuple(nodes=new_nodes,
                                 edges=new_edges,
                                 globals=None,
                                 senders=new_senders,
                                 receivers=new_receivers,
                                 n_node=n_nodes + 1,
                                 n_edge=n_edges)

        gaussian_components = latent.nodes[n_non_components:]

        return gaussian_components
Exemple #4
0
def nearest_neighbours_connected_graph(virtual_positions, k):
    """Build a single-graph GraphsTuple linking each point to its k nearest neighbours.

    Each point is the sender of one edge to each of its ``k`` nearest other
    points (the closest hit, itself, is skipped). Node features are the
    positions themselves; edges and globals are zero-filled.

    Args:
        virtual_positions: array with one row per point and 3 columns (the
            node tensor's static shape is set to [None, 3]).
        k: number of neighbours to connect per point.

    Returns:
        GraphsTuple with float32 nodes, int32 senders/receivers, edges of
        shape [n_edge, 1], globals of shape [1], and one-element
        n_node / n_edge vectors.
    """
    n_points = virtual_positions.shape[0]
    kdtree = cKDTree(virtual_positions)
    # Query k + 1 neighbours because the nearest hit is the point itself.
    _, idx = kdtree.query(virtual_positions, k=k + 1)
    # cKDTree squeezes the last axis when k + 1 == 1; normalise to [N, k + 1].
    idx = np.reshape(idx, (n_points, k + 1))
    receivers = idx[:, 1:]  # N, k
    senders = np.tile(np.arange(n_points)[:, None], [1, k])  # N, k

    # With fewer than k + 1 points, cKDTree reports missing neighbours with
    # index ``n_points``; drop them instead of emitting out-of-range edges.
    # Boolean masking flattens in row-major order, like ``flatten()``.
    valid = receivers < n_points
    receivers = receivers[valid]
    senders = senders[valid]

    graph_nodes = tf.convert_to_tensor(virtual_positions, tf.float32)
    graph_nodes.set_shape([None, 3])
    receivers = tf.convert_to_tensor(receivers, tf.int32)
    receivers.set_shape([None])
    senders = tf.convert_to_tensor(senders, tf.int32)
    senders.set_shape([None])
    n_node = tf.shape(graph_nodes)[0:1]
    n_edge = tf.shape(senders)[0:1]

    graph_data_dict = dict(nodes=graph_nodes,
                           edges=tf.zeros((n_edge[0], 1)),
                           globals=tf.zeros([1]),
                           receivers=receivers,
                           senders=senders,
                           n_node=n_node,
                           n_edge=n_edge)

    return GraphsTuple(**graph_data_dict)
  def _init_call_func(self, observations, training=False):
    """Graph nets implementation.

    Converts ``observations`` to a GraphsTuple with ``GraphObserver.graph``
    (dense mode), runs it through every block in ``self._graph_blocks`` in
    order, and records each intermediate graph in ``self._latent_trace``.
    ``training`` is not read.

    Returns:
      Final node features reshaped to [batch_size, -1, self._embedding_size].
    """
    node_vals, edge_indices, node_to_graph, edge_vals = GraphObserver.graph(
      observations=observations,
      graph_dims=self._graph_dims,
      dense=True)
    batch_size = tf.shape(observations)[0]
    # Node count per graph, ordered by first appearance in node_to_graph.
    nodes_per_graph = tf.unique_with_counts(node_to_graph)[2]
    # Dense mode is assumed to produce n**2 edges per graph - TODO confirm.
    edges_per_graph = tf.math.square(nodes_per_graph)

    graph = GraphsTuple(
      nodes=tf.cast(node_vals, tf.float32),
      edges=tf.cast(edge_vals, tf.float32),
      globals=tf.tile([[0.]], [batch_size, 1]),
      receivers=tf.cast(edge_indices[:, 1], tf.int32),
      senders=tf.cast(edge_indices[:, 0], tf.int32),
      n_node=nodes_per_graph,
      n_edge=edges_per_graph)

    self._latent_trace = []
    for block in self._graph_blocks:
      graph = block(graph)
      self._latent_trace.append(graph)
    return tf.reshape(graph.nodes, [batch_size, -1, self._embedding_size])
Exemple #6
0
def build_dataset(data_dir):
    """
    Build a graph/image matching dataset from a directory of tfrecords.

    Each decoded example is (graph_data_dict, img, spsh, proj). Graphs and
    images are shuffled independently and re-zipped. Re-zipped pairs whose
    (spsh, proj) keys differ are kept as negatives with label 0. The original
    aligned pairs are positives with label 1. Both streams are sampled
    together.

    Args:
        data_dir: str, directory containing *.tfrecords files.

    Returns: Dataset of (GraphsTuple, img, int32 label).
    """
    tfrecords = glob.glob(os.path.join(data_dir, '*.tfrecords'))
    dataset = tf.data.TFRecordDataset(tfrecords).map(partial(decode_examples_old,
                                                             node_shape=(11,),
                                                             image_shape=(256, 256, 1)))  # (graph, image, spsh, proj)
    _graphs = dataset.map(lambda graph_data_dict, img, spsh, proj: (graph_data_dict, spsh, proj)).shuffle(buffer_size=50)
    _images = dataset.map(lambda graph_data_dict, img, spsh, proj: (img, spsh, proj)).shuffle(buffer_size=50)
    shuffled_dataset = tf.data.Dataset.zip((_graphs, _images))  # ((graph, spsh1, proj1), (img, spsh2, proj2))
    # Python ``and`` cannot combine symbolic tensors inside Dataset.map (it
    # needs a Python bool); use the element-wise TF op instead.
    shuffled_dataset = shuffled_dataset.map(
        lambda ds1, ds2: (ds1[0], ds2[0],
                          tf.logical_and(ds1[1] == ds2[1],
                                         ds1[2] == ds2[2])))  # (graph, img, yes/no)
    # Keep only mismatched pairs as negatives.
    shuffled_dataset = shuffled_dataset.filter(lambda graph_data_dict, img, c: ~c)
    shuffled_dataset = shuffled_dataset.map(lambda graph_data_dict, img, c: (graph_data_dict, img, tf.cast(c, tf.int32)))
    nonshuffeled_dataset = dataset.map(
        lambda graph_data_dict, img, spsh, proj: (graph_data_dict, img, tf.constant(1, dtype=tf.int32)))  # (graph, img, yes)
    dataset = tf.data.experimental.sample_from_datasets([shuffled_dataset, nonshuffeled_dataset])
    dataset = dataset.map(lambda graph_data_dict, img, c: (GraphsTuple(**graph_data_dict), img, c))

    return dataset
Exemple #7
0
def build_dataset(data_dir):
    """
    Build a graph/image matching dataset from a directory of tfrecords.

    Decoded examples are (graph_data_dict, img, idx). Negatives are made by
    shuffling graphs and images independently and keeping the re-zipped pairs
    whose indices differ; they get label 0. Positives are the original aligned
    pairs with label 1. Both streams are sampled together.

    Args:
        data_dir: str, path to *.tfrecords

    Returns: Dataset obj.
    """
    record_files = glob.glob(os.path.join(data_dir, '*.tfrecords'))
    decode = partial(decode_examples, node_shape=(4, ), image_shape=(18, 18, 1))
    examples = tf.data.TFRecordDataset(record_files).map(decode)  # (graph_data_dict, image, idx)

    graphs_and_idx = examples.map(
        lambda graph_data_dict, img, idx: (graph_data_dict, idx)).shuffle(buffer_size=50)
    images_and_idx = examples.map(
        lambda graph_data_dict, img, idx: (img, idx)).shuffle(buffer_size=50)

    def _pair_with_match_flag(graph_part, image_part):
        graph_data_dict, graph_idx = graph_part
        img, image_idx = image_part
        return graph_data_dict, img, graph_idx == image_idx

    mismatched = (tf.data.Dataset.zip((graphs_and_idx, images_and_idx))
                  .map(_pair_with_match_flag)
                  .filter(lambda graph, img, c: ~c)
                  .map(lambda graph_data_dict, img, c:
                       (graph_data_dict, img, tf.cast(c, tf.int32))))
    matched = examples.map(
        lambda graph_data_dict, img, idx:
        (graph_data_dict, img, tf.constant(1, dtype=tf.int32)))

    mixed = tf.data.experimental.sample_from_datasets([mismatched, matched])
    return mixed.map(lambda graph_data_dict, img, c: (GraphsTuple(
        globals=None, edges=None, **graph_data_dict), img, c))
Exemple #8
0
    def construct_input_graph(self, input_sequence, N2):
        """Embed a token sequence and connect it into a graph.

        Tokens are embedded, summed with positional encodings, wrapped as
        ``G`` edgeless graphs of ``N`` nodes each (from ``get_shape``),
        flattened with ``graph_unbatch_reshape`` and then connected by
        ``connect_graph_dynamic`` with the rule defined below.

        Args:
            input_sequence: token ids; ``get_shape`` yields (G, N) from it.
            N2: the connection rule treats node indices below ``N2 + 1`` as
                the 2D block and the rest as the 3D block.

        Returns:
            Connected GraphsTuple.
        """
        num_graphs, nodes_per_graph = get_shape(input_sequence)
        # num_samples*batch, 1 + H2*W2 + 1 + H3*W3*D3, embedding_dim
        token_embeddings = tf.nn.embedding_lookup(self.embeddings, input_sequence)
        self.initialize_positional_encodings(token_embeddings)
        n_node = tf.fill([num_graphs], nodes_per_graph)
        unconnected = GraphsTuple(
            nodes=token_embeddings + self.positional_encodings,
            edges=None,
            senders=None,
            receivers=None,
            globals=None,
            n_node=n_node,
            n_edge=tf.zeros_like(n_node))
        # [n_graphs * (num_input + num_output), embedding_size]
        unconnected = graph_unbatch_reshape(unconnected)

        split = N2 + 1  # indices below this form the 2D block

        def edge_connect_rule(sender, receiver):
            # 2D block: fully connected except sender -> sender + 1, so the
            # model cannot simply learn to copy its right neighbour.
            both_in_2d = (sender < split) & (receiver < split)
            complete_2d = both_in_2d & (sender + 1 != receiver)
            # 3D receivers: auto-regressive (sender <= receiver), self-loops kept.
            auto_regressive_3d = (sender <= receiver) & (receiver >= split)
            return complete_2d | auto_regressive_3d

        return connect_graph_dynamic(unconnected, edge_connect_rule)
Exemple #9
0
    def test_kgcn_runs(self):
        """Smoke test: a KGCN built from positional sizes runs on a 3-node graph."""
        tf.enable_eager_execution()

        def _as_tensor(values, dtype):
            return tf.convert_to_tensor(np.array(values, dtype=dtype))

        graph = GraphsTuple(
            nodes=_as_tensor([[1, 2, 0], [1, 0, 0], [1, 1, 0]], np.float32),
            edges=_as_tensor([[1, 0, 0], [1, 0, 0]], np.float32),
            globals=_as_tensor([[0, 0, 0, 0, 0]], np.float32),
            receivers=_as_tensor([1, 2], np.int32),
            senders=_as_tensor([0, 1], np.int32),
            n_node=_as_tensor([3], np.int32),
            n_edge=_as_tensor([2], np.int32))

        # Key: factory of an embedder that ignores its input and returns zeros.
        attr_embedders = {
            lambda: lambda x: tf.constant(np.zeros((3, 6), dtype=np.float32)): [0, 1, 2],
        }
        kgcn = KGCN(3, 2, 5, 6, attr_embedders,
                    edge_output_size=3, node_output_size=3)

        kgcn(graph, 2)
Exemple #10
0
    def test_kgcn_runs(self):
        """Smoke test: a KGCN built from thing/role embedders runs on a 3-node graph."""
        tf.enable_eager_execution()

        def _as_tensor(values, dtype):
            return tf.convert_to_tensor(np.array(values, dtype=dtype))

        graph = GraphsTuple(
            nodes=_as_tensor([[1, 2, 0], [1, 0, 0], [1, 1, 0]], np.float32),
            edges=_as_tensor([[1, 0, 0], [1, 0, 0]], np.float32),
            globals=_as_tensor([[0, 0, 0, 0, 0]], np.float32),
            receivers=_as_tensor([1, 2], np.int32),
            senders=_as_tensor([0, 1], np.int32),
            n_node=_as_tensor([3], np.int32),
            n_edge=_as_tensor([2], np.int32))

        categorical = {
            'a': ['a1', 'a2', 'a3'],
            'b': ['b1', 'b2', 'b3'],
        }
        thing_embedder = ThingEmbedder(
            node_types=['a', 'b', 'c'],
            type_embedding_dim=5,
            attr_embedding_dim=6,
            categorical_attributes=categorical,
            continuous_attributes={'c': (0, 1)})
        role_embedder = RoleEmbedder(num_edge_types=2, type_embedding_dim=5)

        kgcn = KGCN(thing_embedder, role_embedder,
                    edge_output_size=3, node_output_size=3)

        kgcn(graph, 2)
Exemple #11
0
    def test_plot_is_created(self):
        """plot_predictions writes an image file for a small parentship graph."""
        num_processing_steps_ge = 6

        graph = nx.MultiDiGraph(name=0)

        existing = dict(solution=0)
        to_infer = dict(solution=2)
        candidate = dict(solution=1)

        # people
        graph.add_node(0, type='person', **existing)
        graph.add_node(1, type='person', **candidate)

        # parentships
        graph.add_node(2, type='parentship', **to_infer)
        graph.add_edge(2, 0, type='parent', **to_infer)
        graph.add_edge(2, 1, type='child', **candidate)

        graph_tuple_target = GraphsTuple(nodes=np.array([[1., 0., 0.],
                                                         [0., 1., 0.],
                                                         [0., 0., 1.]]),
                                         edges=np.array([[0., 0., 1.],
                                                         [0., 1., 0.]]),
                                         receivers=np.array([1, 2], dtype=np.int32),
                                         senders=np.array([0, 1], dtype=np.int32),
                                         globals=np.array([[0., 0., 0., 0., 0.]], dtype=np.float32),
                                         n_node=np.array([3], dtype=np.int32),
                                         n_edge=np.array([2], dtype=np.int32))

        graph_tuple_output = GraphsTuple(nodes=np.array([[1., 0., 0.],
                                                         [1., 1., 0.],
                                                         [1., 0., 1.]]),
                                         edges=np.array([[1., 0., 0.],
                                                         [1., 1., 0.]]),
                                         receivers=np.array([1, 2], dtype=np.int32),
                                         senders=np.array([0, 1], dtype=np.int32),
                                         globals=np.array([[0., 0., 0., 0., 0.]], dtype=np.float32),
                                         n_node=np.array([3], dtype=np.int32),
                                         n_edge=np.array([2], dtype=np.int32))

        # One output per processing step, kept in sync with the step count.
        test_values = {"target": graph_tuple_target,
                       "outputs": [graph_tuple_output for _ in range(num_processing_steps_ge)]}

        # str(datetime.now()) contains spaces and colons, which are not valid
        # in Windows file names; use a filesystem-safe timestamp instead.
        timestamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S-%f')
        filename = f'./graph_{timestamp}.png'

        def _remove_plot():
            if os.path.isfile(filename):
                os.remove(filename)

        # Don't leave generated images behind in the working directory.
        self.addCleanup(_remove_plot)

        plot_predictions([graph], test_values, num_processing_steps_ge, output_file=filename)

        self.assertTrue(os.path.isfile(filename))
Exemple #12
0
def build_dataset(tfrecords):
    """Build a graph/image matching dataset from tfrecord files.

    Examples decode to (graph_data_dict, img, cluster_idx, projection_idx,
    vprime), and each is identified by ``26 * cluster_idx + projection_idx``.
    Graphs and images are shuffled independently and re-zipped. Re-zipped
    pairs whose identifiers differ become negatives with label 0. The
    original aligned pairs become positives with label 1.

    Args:
        tfrecords: tfrecord file path(s) passed to ``tf.data.TFRecordDataset``.

    Returns:
        Dataset of (GraphsTuple, img, int32 label) sampled from both streams.
    """
    decoded = tf.data.TFRecordDataset(tfrecords).map(
        partial(decode_examples,
                node_shape=(10, ),
                edge_shape=(2, ),
                image_shape=(1000, 1000, 1)))  # (graph, image, idx)

    def _example_id(cluster_idx, projection_idx):
        # NOTE(review): presumably at most 26 projections per cluster so these
        # ids are unique - confirm against the data.
        return 26 * cluster_idx + projection_idx

    shuffled_graphs = decoded.map(
        lambda graph_data_dict, img, cluster_idx, projection_idx, vprime:
        (graph_data_dict, _example_id(cluster_idx, projection_idx))).shuffle(buffer_size=260)
    shuffled_images = decoded.map(
        lambda graph_data_dict, img, cluster_idx, projection_idx, vprime:
        (img, _example_id(cluster_idx, projection_idx))).shuffle(buffer_size=260)

    def _label_pair(graph_part, image_part):
        # ((graph, id1), (img, id2)) -> (graph, img, id1 == id2)
        return graph_part[0], image_part[0], graph_part[1] == image_part[1]

    # Most re-zipped pairs mismatch; keep only those as negatives (label 0).
    negatives = (tf.data.Dataset.zip((shuffled_graphs, shuffled_images))
                 .map(_label_pair)
                 .filter(lambda graph_data_dict, img, c: ~c)
                 .map(lambda graph_data_dict, img, c:
                      (GraphsTuple(**graph_data_dict), img, tf.cast(c, tf.int32))))

    # Aligned pairs are all positives (label 1).
    positives = decoded.map(
        lambda graph_data_dict, img, cluster_idx, projection_idx, vprime:
        (GraphsTuple(**graph_data_dict), img, tf.constant(1, dtype=tf.int32)))

    return tf.data.experimental.sample_from_datasets([negatives, positives])
Exemple #13
0
    def destandardize_graphs_tuple(self, graphs: GraphsTuple) -> GraphsTuple:
        """Undo standardisation of the globals, nodes and edges of ``graphs``.

        Each field goes through ``self._destandardize`` with its stored
        mean/std. All other fields are carried over unchanged.
        """
        restored_globals = self._destandardize(
            graphs.globals, mean=self.global_mean, std=self.global_std)
        restored_nodes = self._destandardize(
            graphs.nodes, mean=self.nodes_mean, std=self.nodes_std)
        restored_edges = self._destandardize(
            graphs.edges, mean=self.edges_mean, std=self.edges_std)
        return graphs.replace(globals=restored_globals,
                              nodes=restored_nodes,
                              edges=restored_edges)
Exemple #14
0
 def _build(self, graph: GraphsTuple, num_processing_steps):
     """Encode once, then iterate message passing with residual global updates.

     Returns:
         List holding the graph after each of the ``num_processing_steps`` steps.
     """
     # give edges and globals to graph from nodes
     graph = self._first_block(graph)
     step_outputs = []
     for _ in range(num_processing_steps):
         graph = self._message_passing_graph(graph)
         # Residual update: the global block's output is added to the old globals.
         global_update = self._global_block(graph).globals
         graph = graph._replace(globals=global_update + graph.globals)
         step_outputs.append(graph)
     return step_outputs
 def _get_graph_tuple(input_tensors: List[tf.Tensor]) -> GraphsTuple:
     """Pack positional tensors into a single-graph GraphsTuple.

     Args:
         input_tensors: sequence whose first five entries are nodes, edges,
             senders, receivers and globals, in that order.

     Returns:
         GraphsTuple holding one graph. ``globals`` gains a leading axis of
         size 1; ``n_node`` / ``n_edge`` are one-element int32 vectors holding
         the leading dimension of nodes / edges.
     """
     nodes = input_tensors[0]
     edges = input_tensors[1]
     return GraphsTuple(
         nodes=nodes,
         edges=edges,
         senders=input_tensors[2],
         receivers=input_tensors[3],
         globals=tf.expand_dims(input_tensors[4], axis=0),
         # tf.shape is resolved at run time, so this also works when the static
         # leading dimension is unknown (``.shape[0]`` would be None there).
         n_node=tf.shape(nodes)[:1],
         n_edge=tf.shape(edges)[:1],
     )
Exemple #16
0
 def _build(self, graph: GraphsTuple, positions: GraphsTuple,
            num_processing_steps):
     """Encode, then run message passing with position features re-attached each step.

     Returns:
         List of ``self._property_block`` outputs, one per processing step.
     """
     # give edges and globals to graph from nodes
     graph = self._first_block(graph)
     predictions = []
     for _ in range(num_processing_steps):
         # Prepend the position node features to the current node features.
         with_positions = tf.concat([positions.nodes, graph.nodes], axis=-1)
         graph = self._message_passing_graph(graph._replace(nodes=with_positions))
         predictions.append(self._property_block(graph))
     return predictions
Exemple #17
0
    def _build(self, batch, *args, **kwargs):
        """Score how well a cluster graph matches an image.

        The graph is encoded by ``self.epd_graph``. The image goes through the
        auto-encoder's encoder, is turned into a fully connected graph with
        one node per encoded pixel, and is encoded by ``self.epd_image``. The
        two global vectors are scored by ``self.compare`` in both
        concatenation orders and the scores are summed. Image summaries are
        written along the way.

        Args:
            batch: tuple (graph, img, c); the label ``c`` is ignored.

        Returns:
            Sum of ``self.compare`` over both orderings of the two encodings.
        """
        (graph, img, c) = batch
        del c

        def _min_max_normalize(tensor):
            # divide_no_nan returns 0 instead of NaN when the tensor is
            # constant (max == min), which would otherwise poison summaries.
            low = tf.reduce_min(tensor)
            return tf.math.divide_no_nan(tensor - low, tf.reduce_max(tensor) - low)

        # The encoded cluster graph has globals which can be compared against the encoded image graph
        encoded_graph = self.epd_graph(graph, self._core_steps)

        # Add an extra dimension to the image (tf.summary expects a Tensor of rank 4)
        img = img[None, ...]
        tf.summary.image('img_before_cnn', _min_max_normalize(img), step=self.step)

        img = self.auto_encoder.encoder(img)

        # Prevent the autoencoder from learning. Best effort: skipped if the
        # auto-encoder does not expose these attributes.
        try:
            for variable in self.auto_encoder.encoder.trainable_variables:
                variable._trainable = False
            for variable in self.auto_encoder.decoder.trainable_variables:
                variable._trainable = False
        except AttributeError:
            pass

        # Move channels to the batch axis so each channel is logged as its own image.
        img_after_autoencoder = _min_max_normalize(img)
        tf.summary.image('img_after_autoencoder', tf.transpose(img_after_autoencoder, [3, 1, 2, 0]), step=self.step)

        decoded_img = _min_max_normalize(self.auto_encoder.decoder(img))
        tf.summary.image('decoded_img', decoded_img, step=self.step)

        # Reshape the encoded image so it can be used for the nodes
        # 1, w, h, c -> w*h, c
        nodes = tf.reshape(img, (-1, self.image_feature_size))

        # Create a graph that has a node for every encoded pixel. The features of each node
        # are the channels of the corresponding pixel. Then connect each node with every other
        # node.
        img_graph = GraphsTuple(nodes=nodes,
                                edges=None,
                                globals=None,
                                receivers=None,
                                senders=None,
                                n_node=tf.shape(nodes)[0:1],
                                n_edge=tf.constant([0]))
        connected_graph = fully_connect_graph_dynamic(img_graph)

        # The encoded image graph has globals which can be compared against the encoded cluster graph
        encoded_img = self.epd_image(connected_graph, 1)

        # Compare the globals from the encoded cluster graph and encoded image graph
        # to estimate the similarity between the input graph and input image
        distance = self.compare(tf.concat([encoded_graph.globals, encoded_img.globals], axis=1)) + self.compare(
            tf.concat([encoded_img.globals, encoded_graph.globals], axis=1))

        return distance
Exemple #18
0
def _build_dataset(data_dir):
    """Decode every *.tfrecords file in ``data_dir`` into (GraphsTuple, image) pairs.

    The spsh and proj fields produced by ``decode_examples_old`` are dropped.
    """
    record_files = glob.glob(os.path.join(data_dir, '*.tfrecords'))
    decoder = partial(decode_examples_old, node_shape=(11, ), image_shape=(256, 256, 1))

    def _to_graph_and_image(graph_data_dict, img, spsh, proj):
        return GraphsTuple(**graph_data_dict), img

    return tf.data.TFRecordDataset(record_files).map(decoder).map(_to_graph_and_image)
Exemple #19
0
def build_dataset(data_dir, batch_size):
    """Load graphs from *.tfrecords in ``data_dir`` with log-scaled trailing node features.

    Node columns 0-2 are kept as is and columns 3 onward are replaced by their
    natural log. Each element is a 1-tuple ``(GraphsTuple,)``. ``batch_size``
    is currently not used.
    """
    record_files = glob.glob(os.path.join(data_dir, '*.tfrecords'))

    def _decode(record_bytes):
        return decode_examples(record_bytes, node_shape=[4])

    def _to_graph(graph_data_dict, img, c):
        return GraphsTuple(globals=tf.zeros([1]), edges=None, **graph_data_dict)

    def _log_scale_tail(graph):
        # NOTE(review): log of non-positive values gives -inf/NaN; presumably
        # these features are always > 0 - confirm.
        head = graph.nodes[:, :3]
        tail = tf.math.log(graph.nodes[:, 3:])
        return graph._replace(nodes=tf.concat([head, tail], axis=1))

    dataset = tf.data.TFRecordDataset(record_files).map(_decode)
    dataset = dataset.map(_to_graph).map(_log_scale_tail)
    return dataset.map(lambda graph: (graph, ))
    def to_placeholders(self, batch_size=None):
        """Creates a placeholder to be fed into a graph_net

        Feature fields are float64 and index/count fields are int32. Every
        field's leading dimension is ``batch_size`` (None means unknown).
        Feature widths come from ``node_dim``, ``edge_dim`` and
        ``global_dim``.

        The returned placeholders also carry two attributes:
        ``make_feed_dict(val)``, which accepts a GraphsTuple or an iterable of
        GraphsTuples (only the first graph of each is used), and ``name``.
        """
        float_type, index_type = tf.float64, tf.int32
        dtypes = GraphsTuple(
            nodes=float_type,
            edges=float_type,
            receivers=index_type,
            senders=index_type,
            globals=float_type,
            n_node=index_type,
            n_edge=index_type,
        )
        shapes = GraphsTuple(
            nodes=[batch_size, self.node_dim],
            edges=[batch_size, self.edge_dim],
            receivers=[batch_size],
            senders=[batch_size],
            globals=[batch_size, self.global_dim],
            n_node=[batch_size],
            n_edge=[batch_size],
        )
        # pylint: disable=protected-access
        placeholders = utils_tf._build_placeholders_from_specs(dtypes=dtypes, shapes=shapes)

        def make_feed_dict(val):
            if not isinstance(val, GraphsTuple):
                # Merge the first graph of each GraphsTuple into one GraphsTuple.
                first_graph_dicts = [
                    utils_np.graphs_tuple_to_data_dicts(item)[0] for item in val
                ]
                val = utils_np.data_dicts_to_graphs_tuple(first_graph_dicts)
            return utils_tf.get_feed_dict(placeholders, val)

        placeholders.make_feed_dict = make_feed_dict
        placeholders.name = "Graph observation placeholder"
        return placeholders
Exemple #21
0
def build_dataset(tfrecords_dirs, batch_size, type='train'):
    """
    Build a batched (GraphsTuple, image) dataset from several tfrecord directories.

    For every directory in ``tfrecords_dirs`` the files matching
    ``<dir>/<type>/*.tfrecords`` are collected. The combined file list is
    shuffled in place with ``random.shuffle``, so file order is not
    reproducible unless the ``random`` module is seeded.

    Args:
        tfrecords_dirs: iterable of str, root directories each containing a
            ``type`` sub-directory of *.tfrecords files.
        batch_size: int, number of examples per batch.
        type: str, split sub-directory name (e.g. 'train'). It shadows the
            builtin but is kept for backward compatibility with callers.

    Returns: tf.data.Dataset of batched (GraphsTuple, image) pairs. Node
        features are reduced to columns 0-2 and 6-7, and images are passed
        through ``gaussian_filter2d``.
    """
    tfrecords = []
    for tfrecords_dir in tfrecords_dirs:
        tfrecords += glob.glob(os.path.join(tfrecords_dir, type, '*.tfrecords'))

    random.shuffle(tfrecords)

    print(f'Number of {type} tfrecord files : {len(tfrecords)}')

    dataset = tf.data.TFRecordDataset(tfrecords).map(
        partial(decode_examples,
                node_shape=(10, ),
                edge_shape=(2, ),
                image_shape=(256, 256, 1)))  # (graph, image, cluster_idx, projection_idx, vprime)

    def _select_features(graph_data_dict, img, cluster_idx, projection_idx, vprime):
        # Build the GraphsTuple once and keep node columns 0-2 and 6-7 only.
        graph = GraphsTuple(**graph_data_dict)
        nodes = tf.concat([graph.nodes[:, :3], graph.nodes[:, 6:8]], axis=-1)
        return graph.replace(nodes=nodes), gaussian_filter2d(img)

    return dataset.map(_select_features).shuffle(buffer_size=52).batch(batch_size=batch_size)
Exemple #22
0
    def test_compute_accuracy_is_as_expected(self):
        """compute_accuracy on a 3-node / 2-edge graph: 2 of 5 correct, not solved."""
        # Fields shared by target and output (same array objects in both).
        shared_fields = dict(
            globals=None,
            senders=np.array([0, 1]),
            receivers=np.array([1, 2]),
            n_node=np.array([3]),
            n_edge=np.array([2]),
        )

        target = GraphsTuple(
            nodes=np.array([[1, 0], [1, 0], [0, 1]], dtype=np.float32),
            edges=np.array([[0, 1], [1, 0]], dtype=np.float32),
            **shared_fields)
        output = GraphsTuple(
            nodes=np.array([[0, 1], [1, 0], [1, 0]], dtype=np.float32),
            edges=np.array([[1, 0], [1, 0]], dtype=np.float32),
            **shared_fields)

        correct, solved = compute_accuracy(target, output)

        self.assertEqual(2 / 5, correct)
        self.assertEqual(0, solved)
Exemple #23
0
def connect_graph_dynamic(graph: GraphsTuple, is_edge_func, name="connect_graph_dynamic"):
    """
    Connects a graph using a boolean edge mask to create edges.

    For each graph ``i`` in the batch, a ``tf.while_loop`` calls
    ``_create_functional_connect_edges_dynamic(graph.n_node[i], is_edge_func)``
    and collects that graph's senders, receivers and edge count. The
    per-graph indices are then shifted by stacked offsets so that they
    address the concatenated node array.

    Args:
        graph: GraphsTuple that must pass
            ``utils_tf._validate_edge_fields_are_all_none``.
        is_edge_func: callable(sender: int, receiver: int) -> bool, should broadcast
        name: name scope for the created ops.

    Returns:
        ``graph`` with senders, receivers and n_edge replaced (each wrapped in
        ``tf.stop_gradient``).
    """
    utils_tf._validate_edge_fields_are_all_none(graph)

    with tf.name_scope(name):
        def _connect_one_graph(i, senders_acc, receivers_acc, n_edge_acc):
            per_graph = _create_functional_connect_edges_dynamic(graph.n_node[i], is_edge_func)
            return (i + 1,
                    senders_acc.write(i, per_graph['senders']),
                    receivers_acc.write(i, per_graph['receivers']),
                    n_edge_acc.write(i, per_graph['n_edge']))

        num_graphs = utils_tf.get_num_graphs(graph)

        def _has_more_graphs(i, *_):
            return tf.less(i, num_graphs)

        # One variable-shape TensorArray each for senders, receivers and n_edge.
        accumulators = [
            tf.TensorArray(dtype=tf.int32, size=num_graphs, infer_shape=False)
            for _ in range(3)
        ]
        _, senders_array, receivers_array, n_edge_array = tf.while_loop(
            _has_more_graphs, _connect_one_graph, [0] + accumulators)

        n_edge = n_edge_array.concat()
        # Shift each graph's local node indices by that graph's first node index.
        offsets = utils_tf._compute_stacked_offsets(graph.n_node, n_edge)
        senders = senders_array.concat() + offsets
        receivers = receivers_array.concat() + offsets
        for indices in (senders, receivers):
            indices.set_shape(offsets.shape)
            indices.set_shape([None])

        static_num_graphs = graph.n_node.get_shape().as_list()[0]
        n_edge.set_shape([static_num_graphs])

        return graph.replace(senders=tf.stop_gradient(senders),
                             receivers=tf.stop_gradient(receivers),
                             n_edge=tf.stop_gradient(n_edge))
Exemple #24
0
 def sample_decoder(self, positions, logits, temperature):
     """Sample tokens from a relaxed categorical and decode them into fields.

     One relaxed one-hot sample is drawn from
     ``RelaxedOneHotCategorical(temperature, logits=logits)`` and mapped through
     ``self.embeddings``. The result is wrapped as a single fully connected
     graph, decoded into Gaussian tokens by ``self.decoder``, and evaluated at
     ``positions`` with ``reconstruct_fields_from_gaussians``.

     Returns:
         Output of ``reconstruct_fields_from_gaussians``.
     """
     token_distribution = tfp.distributions.RelaxedOneHotCategorical(temperature, logits=logits)
     token_samples_onehot = token_distribution.sample((1,),
                                                      name='token_samples')
     token_sample_onehot = token_samples_onehot[0]  # [n_node, num_embedding]
     token_sample = tf.matmul(token_sample_onehot, self.embeddings)  # [n_node, embedding_dim]
     latent_graph = GraphsTuple(nodes=token_sample,
                                edges=None,
                                globals=tf.constant([0.], dtype=tf.float32),
                                senders=None,
                                receivers=None,
                                # tf.constant cannot wrap a symbolic tensor (it
                                # fails in graph mode / tf.function); slicing the
                                # dynamic shape gives the same int32 [n_node] vector.
                                n_node=tf.shape(token_sample)[0:1],
                                n_edge=tf.constant([0], dtype=tf.int32))  # [n_node, embedding_dim]
     latent_graph = fully_connect_graph_dynamic(latent_graph)
     gaussian_tokens = self.decoder(latent_graph)  # nodes=[num_gaussian_components, component_dim]
     reconstructed_fields = reconstruct_fields_from_gaussians(gaussian_tokens, positions)
     return reconstructed_fields
Exemple #25
0
def build_dataset(data_dir, batch_size):
    """
    Build a batched dataset of (graphs, image) pairs from the *.tfrecords files in a directory.

    Args:
        data_dir: str, directory whose top-level files matching '*.tfrecords' are read.
        batch_size: int, number of examples grouped by tf.data.Dataset.batch before
            the graph batch dimensions are fixed up.

    Returns: tf.data.Dataset yielding (graphs, image), where graphs is the output of
        graph_unbatch_reshape applied to the GraphsTuple built from each batch.
    """
    tfrecords = glob.glob(os.path.join(data_dir, '*.tfrecords'))

    # decode_examples yields five elements: (graph_data_dict, image, spsh, proj, e).
    dataset = tf.data.TFRecordDataset(tfrecords).map(partial(decode_examples,
                                                             node_shape=(11,),
                                                             image_shape=(256, 256, 1),
                                                             k=6))

    # Keep only the graph data dict and the image.
    dataset = dataset.map(lambda graph_data_dict, img, spsh, proj, e: (graph_data_dict, img))

    # Dataset.batch returns a NEW dataset. The original code discarded it, so
    # batching never happened and batch_size was silently ignored.
    # NOTE(review): batch() requires every example to have identical tensor
    # shapes (e.g. the same node count); confirm this holds for the input data.
    dataset = dataset.batch(batch_size)
    # Batch fixing mechanism: batch_graph_data_dict is presumably meant to fold the
    # stacked per-example fields into one graph dict (per its name and its use
    # right after batching) — confirm against its definition.
    dataset = dataset.map(lambda data_dict, image: (batch_graph_data_dict(data_dict), image))
    dataset = dataset.map(lambda data_dict, image: (GraphsTuple(**data_dict,
                                                                edges=None, receivers=None, senders=None, globals=None), image))
    dataset = dataset.map(lambda batched_graphs, image: (graph_unbatch_reshape(batched_graphs), image))
    # dataset = dataset.cache()
    return dataset
        def _single_decode(token_sample_onehot):
            """
            Decode one relaxed one-hot token sample and score it.

            This is a closure: it reads `self`, `encoded_graph`, `graph`,
            `logits` and `log_norm` from the enclosing scope.

            Args:
                token_sample_onehot: [n_node, num_embedding] relaxed one-hot
                    token sample, z ~ q(z|x).

            Returns:
                log_likelihood: scalar, the second output of
                    gaussian_loss_function(gaussian_tokens.nodes, graph).
                kl_term: scalar, self.beta times the mean over nodes of the
                    per-node KL estimate computed below.
            """
            token_sample = tf.matmul(
                token_sample_onehot,
                self.embeddings)  # [n_node, embedding_dim]  # = z ~ q(z|x)
            # Edgeless graph whose nodes are the embedded tokens
            # ([n_node, embedding_dim]); connectivity is presumably supplied by
            # fully_connect_graph_dynamic below.
            latent_graph = GraphsTuple(
                nodes=token_sample,
                edges=None,
                globals=tf.constant([0.], dtype=tf.float32),
                senders=None,
                receivers=None,
                n_node=encoded_graph.n_node,
                n_edge=tf.constant([0],
                                   dtype=tf.int32))
            latent_graph = fully_connect_graph_dynamic(latent_graph)
            gaussian_tokens = self.decoder(
                latent_graph)  # nodes=[num_gaussian_components, component_dim]
            _, log_likelihood = gaussian_loss_function(gaussian_tokens.nodes,
                                                       graph)
            # Per-node sum over embeddings of sample * logits:
            # [n_node, num_embeddings] * [n_node, num_embeddings] -> [n_node]
            sum_selected_logits = tf.math.reduce_sum(token_sample_onehot *
                                                     logits,
                                                     axis=1)  # [n_node]
            # NOTE(review): for a single categorical sample with a uniform prior
            # over num_embedding tokens, log q(z) - log p(z) would read
            # sum_selected_logits - log_norm + log(num_embedding), if log_norm is
            # the per-node log-partition of `logits` (not visible here). Confirm
            # the extra num_embedding factors below against the intended derivation.
            kl_term = sum_selected_logits - self.num_embedding * log_norm + \
                      self.num_embedding * tf.math.log(tf.cast(self.num_embedding, tf.float32))
            # Average over nodes and weight by beta.
            kl_term = self.beta * tf.reduce_mean(kl_term)
            return log_likelihood, kl_term
Exemple #27
0
def decode_examples(record_bytes, node_shape=None, edge_shape=None, image_shape=None):
    """
    Turn one serialised TFRecord example into a (GraphsTuple, image) pair.

    Args:
        record_bytes: one raw record, as yielded by
            tf.data.TFRecordDataset([example_path]).
        node_shape: optional per-node feature shape; when given, the node
            tensor's static shape is set to [None, *node_shape].
        edge_shape: optional per-edge feature shape; when given, the edge
            tensor's static shape is set to [None, *edge_shape].
        image_shape: static shape applied to the decoded image via set_shape
            (None leaves it unconstrained).

    Returns: (GraphTuple, image)
        The GraphsTuple holds one graph: n_node / n_edge are length-1 tensors
        taken from the leading dimension of nodes / edges, and globals is None.
    """
    # Schema: the serialised image plus the graph fields described by
    # feature_to_graph_tuple('graph').
    schema = dict(image=tf.io.FixedLenFeature([], dtype=tf.string),
                  **feature_to_graph_tuple('graph'))
    parsed = tf.io.parse_single_example(record_bytes, schema)

    def _tensor(key, dtype):
        # Each field stores a serialised tensor; parse it back with the expected dtype.
        return tf.io.parse_tensor(parsed[key], dtype)

    image = _tensor('image', tf.float32)
    image.set_shape(image_shape)

    nodes = _tensor('graph_nodes', tf.float32)
    if node_shape is not None:
        nodes.set_shape([None, *node_shape])

    edges = _tensor('graph_edges', tf.float32)
    if edge_shape is not None:
        edges.set_shape([None, *edge_shape])

    receivers = _tensor('graph_receivers', tf.int64)
    senders = _tensor('graph_senders', tf.int64)
    for index_tensor in (receivers, senders):
        index_tensor.set_shape([None])

    graph = GraphsTuple(nodes=nodes,
                        edges=edges,
                        globals=None,
                        receivers=receivers,
                        senders=senders,
                        n_node=tf.shape(nodes)[0:1],
                        n_edge=tf.shape(edges)[0:1])
    return graph, image
Exemple #28
0
def apply_random_rotation(graph: graphs.GraphsTuple) -> graphs.GraphsTuple:
    """Returns the graph with a random O(3) transformation applied to its edges.

    The transformation is a random permutation of the three coordinate axes
    combined with an independent random sign flip (reflection) of each axis,
    i.e. a rotation by multiples of pi/2 possibly composed with a reflection.
    This function assumes that the relative particle distances are stored in
    the edge features as a [n_edges, 3] tensor. Only the first three feature
    columns are gathered, so any extra columns would be dropped.

    Args:
      graph: The graphs tuple as defined in `graph_nets.graphs`.

    Returns:
      `graph` with `edges` replaced by the transformed edge features; all other
      fields are unchanged.
    """
    # Put the coordinate axes in the first dimension: [3, n_edges].
    xyz = tf.transpose(graph.edges)
    # Random permutation of the x, y, z axes.
    permutation = tf.random.shuffle(tf.constant([0, 1, 2], dtype=tf.int32))
    xyz = tf.gather(xyz, permutation)
    # Random reflections: each axis is independently multiplied by +1 or -1.
    # tf.random.uniform (not the TF1-only tf.random_uniform) so this also runs
    # under TF2, matching the tf.random.shuffle call above.
    symmetry = tf.random.uniform([3], minval=0, maxval=2, dtype=tf.int32)
    # Cast to the edge dtype so that non-float32 (e.g. float64) edges also work.
    symmetry = 1 - 2 * tf.cast(tf.reshape(symmetry, [3, 1]), xyz.dtype)
    xyz = xyz * symmetry
    edges = tf.transpose(xyz)
    return graph.replace(edges=edges)
Exemple #29
0
 def _build(self, batch, *args, **kwargs):
     """
     Compare an encoded graph with an encoded image.

     Args:
         batch: 3-tuple (graph, img, c); the third element is ignored.
         *args, **kwargs: unused.

     Returns:
         self.compare applied to the graph and image globals concatenated
         along axis 1 (annotated as [1] in the original code).
     """
     graph, img, _ = batch
     encoded_graph = self.encoder_graph(graph)

     # Add a leading batch axis for the summary and the CNN.
     img_batched = img[None, ...]
     tf.summary.image('img_before_cnn', img_batched, step=self.step)
     feature_map = self.image_cnn(img_batched)

     # One single-channel summary image per CNN output channel.
     channel_count = feature_map.shape[-1]
     for idx in range(channel_count):
         tf.summary.image(f'img_after_cnn[{idx}]',
                          feature_map[..., idx:idx + 1],
                          step=self.step)

     # Flatten the feature map into one node per pixel: [1, w, h, c] -> [w*h, c].
     pixel_nodes = tf.reshape(feature_map, (-1, self.image_feature_size))
     img_graph = GraphsTuple(nodes=pixel_nodes,
                             edges=None,
                             globals=None,
                             receivers=None,
                             senders=None,
                             n_node=tf.shape(pixel_nodes)[0:1],
                             n_edge=tf.constant([0]))
     encoded_img = self.encoder_image(fully_connect_graph_dynamic(img_graph))

     joint_globals = tf.concat([encoded_graph.globals, encoded_img.globals],
                               axis=1)
     return self.compare(joint_globals)  # [1]
Exemple #30
0
def build_dataset(tfrecords, batch_size):
    """
    Build a shuffled dataset of GraphsTuples from TFRecord files.

    Args:
        tfrecords: filename(s) passed directly to tf.data.TFRecordDataset,
            e.g. a list of paths to *.tfrecords files.
        batch_size: currently unused, because graph batching is disabled (see
            the TODO below). Kept for interface compatibility.

    Returns: tf.data.Dataset yielding one GraphsTuple per example. The image and
        index fields produced by decode_examples are dropped.
    """
    # decode_examples yields five elements:
    # (graph_data_dict, img, cluster_idx, projection_idx, vprime).
    dataset = tf.data.TFRecordDataset(tfrecords).map(
        partial(decode_examples,
                node_shape=(10, ),
                edge_shape=(2, ),
                image_shape=(1024, 1024, 1)))

    # Keep only the graph data dict, then shuffle within a 50-element buffer.
    dataset = dataset.map(lambda graph_data_dict, img, cluster_idx,
                          projection_idx, vprime: graph_data_dict).shuffle(
                              buffer_size=50)
    dataset = dataset.map(
        lambda graph_data_dict: GraphsTuple(**graph_data_dict))

    # TODO: batching is disabled, so batch_size has no effect. To re-enable:
    # dataset = batch_dataset_set_graph_tuples(all_graphs_same_size=True, dataset=dataset, batch_size=batch_size)

    return dataset