Ejemplo n.º 1
0
    def destandardize_graphs_tuple(self, graphs: GraphsTuple) -> GraphsTuple:
        """Undo feature standardization on the globals, nodes and edges.

        Each feature field is passed through ``self._destandardize`` with its own
        statistics: ``global_mean``/``global_std`` for globals,
        ``nodes_mean``/``nodes_std`` for nodes and ``edges_mean``/``edges_std``
        for edges. Only those three fields are replaced; every other field of
        ``graphs`` is carried over unchanged.

        Args:
            graphs: GraphsTuple whose ``globals``, ``nodes`` and ``edges`` hold
                standardized values.

        Returns:
            A new GraphsTuple with the three feature fields destandardized.
        """
        raw_globals = self._destandardize(
            graphs.globals, mean=self.global_mean, std=self.global_std)
        result = graphs.replace(globals=raw_globals)

        raw_nodes = self._destandardize(
            graphs.nodes, mean=self.nodes_mean, std=self.nodes_std)
        result = result.replace(nodes=raw_nodes)

        raw_edges = self._destandardize(
            graphs.edges, mean=self.edges_mean, std=self.edges_std)
        return result.replace(edges=raw_edges)
Ejemplo n.º 2
0
def connect_graph_dynamic(graph: GraphsTuple, is_edge_func, name="connect_graph_dynamic"):
    """
    Connects every graph in a batch using a pairwise edge predicate.

    For each graph ``i`` in the batch, ``_create_functional_connect_edges_dynamic``
    is called with that graph's node count and ``is_edge_func``. It yields
    per-graph ``senders``, ``receivers`` and ``n_edge``. The per-graph results
    are concatenated, and the node indices are shifted by the stacked offsets
    from ``utils_tf._compute_stacked_offsets`` so that they index into the
    batched ``nodes``.

    Args:
        graph: GraphsTuple whose edge fields must all be None (checked by
            ``utils_tf._validate_edge_fields_are_all_none``).
        is_edge_func: callable(sender: int, receiver: int) -> bool, should broadcast
        name: name scope for the created ops.

    Returns:
        ``graph`` with ``senders``, ``receivers`` and ``n_edge`` filled in. The
        new fields are wrapped in ``tf.stop_gradient``.
    """
    utils_tf._validate_edge_fields_are_all_none(graph)

    with tf.name_scope(name):
        def body(i, senders, receivers, n_edge):
            # Connectivity of graph ``i``, in that graph's local node indexing.
            edges = _create_functional_connect_edges_dynamic(graph.n_node[i], is_edge_func)
            return (i + 1, senders.write(i, edges['senders']),
                    receivers.write(i, edges['receivers']),
                    n_edge.write(i, edges['n_edge']))

        num_graphs = utils_tf.get_num_graphs(graph)
        loop_condition = lambda i, *_: tf.less(i, num_graphs)
        # Every per-graph entry must be a 1-D vector: the concatenated results
        # are shaped like the 1-D offsets, and n_edge gets shape [num_graphs].
        # Declaring element_shape=[None] lets ``concat`` return an empty vector
        # for a batch with zero graphs. Without it, concat fails on the unknown
        # element shape of a zero-size TensorArray.
        initial_loop_vars = [0] + [
            tf.TensorArray(dtype=tf.int32, size=num_graphs, infer_shape=False,
                           element_shape=tf.TensorShape([None]))
            for _ in range(3)  # senders, receivers, n_edge
        ]
        _, senders_array, receivers_array, n_edge_array = tf.while_loop(loop_condition, body, initial_loop_vars)

        n_edge = n_edge_array.concat()
        # Shift each graph's local node indices into batch-wide indices.
        # The offsets are presumably the cumulative node count of the
        # preceding graphs, repeated once per edge.
        offsets = utils_tf._compute_stacked_offsets(graph.n_node, n_edge)
        senders = senders_array.concat() + offsets
        receivers = receivers_array.concat() + offsets
        senders.set_shape(offsets.shape)
        receivers.set_shape(offsets.shape)

        receivers.set_shape([None])
        senders.set_shape([None])

        # Static batch size (may be None), distinct from the dynamic
        # ``num_graphs`` tensor used by the loop above.
        static_num_graphs = graph.n_node.get_shape().as_list()[0]
        n_edge.set_shape([static_num_graphs])

        return graph.replace(senders=tf.stop_gradient(senders),
                             receivers=tf.stop_gradient(receivers),
                             n_edge=tf.stop_gradient(n_edge))
Ejemplo n.º 3
0
def apply_random_rotation(graph: graphs.GraphsTuple) -> graphs.GraphsTuple:
    """Returns randomly rotated graph representation.

    The rotation is an element of O(3) with rotation angles multiple of pi/2.
    It is built from a random permutation of the three coordinate axes,
    followed by an independent random sign flip (+1 or -1) of each axis.
    This function assumes that the relative particle distances are stored in
    the edge features.

    Args:
        graph: The graphs tuple as defined in `graph_nets.graphs`. Assumes
            ``graph.edges`` has exactly three columns (x, y, z) -- TODO
            confirm. Only rows 0-2 of the transposed edges are gathered, so
            any extra columns would be dropped.

    Returns:
        ``graph`` with ``edges`` replaced by the transformed edge vectors. All
        other fields are unchanged.
    """
    # Transposes edge features, so that the axes are in the first dimension:
    # [n_edges, 3] -> [3, n_edges].
    xyz = tf.transpose(graph.edges)
    # Random pi/2 rotation(s): permute the coordinate axes.
    permutation = tf.random.shuffle(tf.constant([0, 1, 2], dtype=tf.int32))
    xyz = tf.gather(xyz, permutation)
    # Random reflections: one 0/1 draw per axis, mapped to a sign of +1/-1.
    # tf.random.uniform (unlike the TF1-only tf.random_uniform) lives in the
    # same tf.random namespace as tf.random.shuffle above and works in TF 2.x.
    flips = tf.random.uniform([3], minval=0, maxval=2, dtype=tf.int32)
    # Cast to the edge dtype so that float64 edges do not hit a dtype mismatch.
    signs = 1 - 2 * tf.cast(tf.reshape(flips, [3, 1]), xyz.dtype)
    xyz = xyz * signs
    edges = tf.transpose(xyz)
    return graph.replace(edges=edges)
Ejemplo n.º 4
0
    def _build(self, graph):
        """Encode ``graph`` into ``self.num_output`` token nodes by message passing.

        Steps, as implemented below:
          1. Give every input edge the learned ``intra_graph_edge_variable`` and
             project the nodes with ``projection_node_block``.
          2. Build a fully connected token graph of ``self.num_output`` nodes.
             Its nodes start from ``empty_node_variable`` and its edges from
             ``intra_token_graph_edge_variable``.
          3. Merge both into one graph. Add bidirectional edges between a random
             subset of input nodes (``inter_graph_connect_prob`` of them) and
             randomly chosen token nodes.
          4. Run ``self.crossing_steps`` rounds of edge/node/global blocks, with
             a residual connection on the nodes.
          5. Keep only the token nodes and apply ``output_projection_node_block``.

        Args:
            graph: GraphsTuple. Presumably a single graph, since only
                ``graph.n_edge[0]`` sizes the edge features -- TODO confirm.

        Returns:
            Output of ``self.output_projection_node_block`` applied to a graph
            that holds only the ``num_output`` token nodes, with no edges,
            senders, receivers or globals, ``n_node=[num_output]`` and
            ``n_edge=[0]``.
        """
        # Give the graph's edges a learned feature vector and map the nodes to a
        # new node dimension (linear transformation).
        graph = graph.replace(edges=tf.tile(self.intra_graph_edge_variable[None, :], [graph.n_edge[0], 1]))
        graph = self.projection_node_block(graph)  # [n_nodes, node_size]
        n_node = tf.shape(graph.nodes)[0]

        # Create fully connected output token nodes.
        token_start_nodes = tf.tile(self.empty_node_variable[None, :], [self.num_output, 1])
        token_graph = GraphsTuple(nodes=token_start_nodes,
                                  edges=None,
                                  globals=tf.constant([0.], dtype=tf.float32),
                                  senders=None,
                                  receivers=None,
                                  n_node=tf.constant([self.num_output], dtype=tf.int32),
                                  n_edge=tf.constant([0], dtype=tf.int32))
        token_graph = fully_connect_graph_dynamic(token_graph)
        token_graph = token_graph.replace(
            edges=tf.tile(self.intra_token_graph_edge_variable[None, :], [token_graph.n_edge[0], 1]))

        # Merge into a single graph. Nodes [0, n_node) come from the input graph
        # and nodes [n_node, n_node + num_output) are the tokens.
        concat_graph = concat([graph, token_graph], axis=0)  # n_node = [n_nodes, n_tokens]
        concat_graph = concat_graph.replace(n_node=tf.reduce_sum(concat_graph.n_node, keepdims=True),
                                            n_edge=tf.reduce_sum(concat_graph.n_edge,
                                                                 keepdims=True))  # n_node=[n_nodes+n_tokens]

        # Add random edges between the input graph and the token graph.
        # top_k over i.i.d. Gumbel noise picks ``n_connect_edges`` distinct input
        # nodes uniformly at random.
        gumbel = -tf.math.log(-tf.math.log(tf.random.uniform((n_node,))))
        n_connect_edges = tf.cast(
            tf.multiply(tf.constant([self.inter_graph_connect_prob]), tf.cast(n_node, tf.float32)),
            tf.int32)  # shape [1]
        _, graph_senders = tf.nn.top_k(gumbel, n_connect_edges[0])
        # Pair each selected input node with a random token node (with replacement).
        token_graph_receivers = n_node + tf.random.uniform(shape=n_connect_edges, minval=0, maxval=self.num_output,
                                                           dtype=tf.int32)
        # Add the links in both directions.
        senders = tf.concat([concat_graph.senders, graph_senders, token_graph_receivers], axis=0)
        receivers = tf.concat([concat_graph.receivers, token_graph_receivers, graph_senders], axis=0)
        inter_edges = tf.tile(self.inter_graph_edge_variable[None, :],
                              tf.concat([2 * n_connect_edges, tf.constant([1], dtype=tf.int32)], axis=0))
        edges = tf.concat([concat_graph.edges, inter_edges], axis=0)
        concat_graph = concat_graph.replace(senders=senders, receivers=receivers, edges=edges,
                                            # Both operands have shape [1], so n_edge keeps the
                                            # per-graph vector shape used by n_node.
                                            n_edge=concat_graph.n_edge + 2 * n_connect_edges,
                                            globals=self.starting_global_variable[None, :])

        latent_graph = concat_graph
        # Presumably the theoretical crossing time for information through the graph.
        for _ in range(self.crossing_steps):
            input_nodes = latent_graph.nodes
            latent_graph = self.edge_block(latent_graph)
            latent_graph = self.node_block(latent_graph)
            latent_graph = self.global_block(latent_graph)
            latent_graph = latent_graph.replace(nodes=latent_graph.nodes + input_nodes)  # residual connections

        # Keep only the token nodes and drop all connectivity before the output projection.
        latent_graph = latent_graph.replace(nodes=latent_graph.nodes[n_node:],
                                            edges=None,
                                            receivers=None,
                                            senders=None,
                                            globals=None,
                                            n_node=tf.constant([self.num_output], dtype=tf.int32),
                                            n_edge=tf.constant([0], dtype=tf.int32))
        return self.output_projection_node_block(latent_graph)
Ejemplo n.º 5
0
    def _build(self, graph, crossing_steps):
        """Run message passing over ``graph`` merged with ``self.num_output`` token nodes.

        Steps, as implemented below:
          1. Give every input edge the learned ``intra_graph_edge_variable`` and
             project the nodes with ``projection_node_block``.
          2. Build a fully connected token graph (``fully_connect_graph_static``)
             of ``self.num_output`` nodes. Its nodes start from
             ``empty_node_variable`` and its edges from
             ``intra_token_graph_edge_variable``.
          3. Merge both into a single graph. Add bidirectional edges between a
             random subset of input nodes (``inter_graph_connect_prob`` of them)
             and randomly chosen token nodes.
          4. Run ``self.crossing_steps`` rounds of edge/node/global blocks, with
             a residual connection on the nodes.

        Args:
            graph: GraphsTuple. Presumably a single graph, since only
                ``graph.n_edge[0]`` sizes the edge features -- TODO confirm.
            crossing_steps: NOTE(review): currently unused. The loop runs
                ``self.crossing_steps`` rounds; confirm which one is intended.

        Returns:
            Output of ``self.output_projection_node_block`` applied to the merged
            graph. It still contains both the input-graph nodes and the token
            nodes.
        """
        n_edge = graph.n_edge[0]
        graph = graph.replace(edges=tf.tile(
            self.intra_graph_edge_variable[None, :], [n_edge, 1]))
        graph = self.projection_node_block(graph)  # [n_nodes, node_size]
        n_node = tf.shape(graph.nodes)[0]

        # Create fully connected output token nodes.
        token_start_nodes = tf.tile(self.empty_node_variable[None, :],
                                    [self.num_output, 1])
        token_graph = GraphsTuple(nodes=token_start_nodes,
                                  edges=None,
                                  globals=None,
                                  senders=None,
                                  receivers=None,
                                  n_node=[self.num_output],
                                  n_edge=None)
        token_graph = fully_connect_graph_static(token_graph)
        n_edge = token_graph.n_edge[0]
        token_graph = token_graph.replace(edges=tf.tile(
            self.intra_token_graph_edge_variable[None, :], [n_edge, 1]))

        # Merge into a single graph. Sum the per-graph node AND edge counts so
        # that n_node = [n_nodes + n_tokens] and n_edge covers the edges of both
        # parts.
        concat_graph = concat([graph, token_graph],
                              axis=0)  # n_node = [n_nodes, n_tokens]
        concat_graph = concat_graph.replace(
            n_node=tf.reduce_sum(concat_graph.n_node, keepdims=True),
            n_edge=tf.reduce_sum(concat_graph.n_edge, keepdims=True))

        # Add random edges between the input graph and the token graph.
        # top_k over i.i.d. Gumbel noise picks distinct input nodes at random.
        gumbel = -tf.math.log(-tf.math.log(tf.random.uniform((n_node, ))))
        # Cast n_node to float before scaling. A Python float cannot be
        # multiplied directly with an int32 tensor.
        n_connect_edges = tf.cast(
            tf.cast(n_node, tf.float32) * self.inter_graph_connect_prob,
            tf.int32)
        _, graph_senders = tf.nn.top_k(gumbel, n_connect_edges)
        # Pair each selected input node with a random token node (with replacement).
        token_graph_receivers = n_node + tf.random.uniform(
            (n_connect_edges, ),
            minval=0,
            maxval=self.num_output,
            dtype=tf.int32)
        # Add the links in both directions.
        senders = tf.concat(
            [concat_graph.senders, graph_senders, token_graph_receivers],
            axis=0)
        receivers = tf.concat(
            [concat_graph.receivers, token_graph_receivers, graph_senders],
            axis=0)
        inter_edges = tf.tile(self.inter_graph_edge_variable[None, :],
                              [2 * n_connect_edges, 1])
        edges = tf.concat([concat_graph.edges, inter_edges], axis=0)
        concat_graph = concat_graph.replace(
            senders=senders,
            receivers=receivers,
            edges=edges,
            # concat_graph.n_edge has shape [1] here, so the result does too.
            n_edge=concat_graph.n_edge + 2 * n_connect_edges,
            globals=self.starting_global_variable[None, :])

        latent_graph = concat_graph
        # Presumably the theoretical crossing time for information through the graph.
        for _ in range(self.crossing_steps):
            input_nodes = latent_graph.nodes
            latent_graph = self.edge_block(latent_graph)
            latent_graph = self.node_block(latent_graph)
            latent_graph = self.global_block(latent_graph)
            latent_graph = latent_graph.replace(
                nodes=latent_graph.nodes + input_nodes)  # residual connections

        output_graph = self.output_projection_node_block(latent_graph)

        return output_graph