Example #1
def test_batch_reshape():
    data = dict(nodes=tf.reshape(tf.range(3 * 2), (3, 2)),
                edges=tf.reshape(tf.range(40 * 5), (40, 5)),
                senders=tf.random.uniform((40, ),
                                          minval=0,
                                          maxval=3,
                                          dtype=tf.int32),
                receivers=tf.random.uniform((40, ),
                                            minval=0,
                                            maxval=3,
                                            dtype=tf.int32),
                n_node=tf.constant([3]),
                n_edge=tf.constant([40]),
                globals=None)
    graph = GraphsTuple(**data)
    graphs = utils_tf.concat([graph] * 4, axis=0)
    batched_graphs = graph_batch_reshape(graphs)
    assert tf.reduce_all(
        batched_graphs.nodes[0] == batched_graphs.nodes[1]).numpy()
    assert tf.reduce_all(
        batched_graphs.edges[0] == batched_graphs.edges[1]).numpy()
    assert tf.reduce_all(
        batched_graphs.senders[0] +
        graphs.n_node[0] == batched_graphs.senders[1]).numpy()
    assert tf.reduce_all(
        batched_graphs.receivers[0] +
        graphs.n_node[0] == batched_graphs.receivers[1]).numpy()

    # print(batched_graphs)
    unbatched_graphs = graph_unbatch_reshape(batched_graphs)
    for (t1, t2) in zip(graphs, unbatched_graphs):
        if t1 is not None:
            assert tf.reduce_all(t1 == t2).numpy()
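graph_batch_reshape and graph_unbatch_reshape are the project's own helpers; the sketch below is only my illustration of the round trip the test exercises (assuming every graph in the batch has the same number of nodes): batching stacks per-graph node features into a leading graph axis, and unbatching flattens them back. The test above additionally checks that the batched senders/receivers of consecutive graphs differ by the per-graph node count.

import tensorflow as tf

def batch_nodes(nodes, n_graphs, n_node_per_graph):
    # [n_graphs * n_node_per_graph, F] -> [n_graphs, n_node_per_graph, F]
    return tf.reshape(nodes, [n_graphs, n_node_per_graph, nodes.shape[-1]])

def unbatch_nodes(batched_nodes):
    # [n_graphs, n_node_per_graph, F] -> [n_graphs * n_node_per_graph, F]
    return tf.reshape(batched_nodes, [-1, batched_nodes.shape[-1]])

nodes = tf.reshape(tf.range(4 * 3 * 2), (4 * 3, 2))  # 4 graphs, 3 nodes each
round_trip = unbatch_nodes(batch_nodes(nodes, n_graphs=4, n_node_per_graph=3))
assert tf.reduce_all(round_trip == nodes).numpy()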
Example #2
    def construct_input_graph(self, input_sequence, N2):
        G, N = get_shape(input_sequence)
        # num_samples*batch, 1 + H2*W2 + 1 + H3*W3*D3, embedding_dim
        input_tokens = tf.nn.embedding_lookup(self.embeddings, input_sequence)
        self.initialize_positional_encodings(input_tokens)
        nodes = input_tokens + self.positional_encodings
        n_node = tf.fill([G], N)
        n_edge = tf.zeros_like(n_node)
        data_dict = dict(nodes=nodes, edges=None, senders=None, receivers=None, globals=None,
                         n_node=n_node,
                         n_edge=n_edge)
        concat_graphs = GraphsTuple(**data_dict)
        concat_graphs = graph_unbatch_reshape(concat_graphs)  # [n_graphs * (num_input + num_output), embedding_size]
        # nodes, senders, receivers, globals
        def edge_connect_rule(sender, receiver):
            # . a . b -> a . b .
            # Fully connect the 2d tokens, but exclude each sender's immediate
            # right neighbour so the model can't simply learn to copy it.
            complete_2d = (sender < N2 + 1) & (receiver < N2 + 1) & (sender + 1 != receiver)
            # Auto-regressive edges (with self loops) whose receivers are 3d tokens.
            auto_regressive_3d = (sender <= receiver) & (receiver >= N2 + 1)
            return complete_2d | auto_regressive_3d

        # nodes, senders, receivers, globals
        concat_graphs = connect_graph_dynamic(concat_graphs, edge_connect_rule)
        return concat_graphs
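connect_graph_dynamic is also a project helper; the sketch below is only my illustration of how a predicate like edge_connect_rule could be applied, evaluating it densely over every ordered (sender, receiver) index pair to obtain an adjacency mask and the corresponding edge lists. N2 here is a toy value, not one taken from the project.

import tensorflow as tf

def dense_edges_from_rule(num_nodes, rule):
    # Evaluate rule(sender, receiver) for every ordered pair of node indices.
    senders, receivers = tf.meshgrid(tf.range(num_nodes), tf.range(num_nodes), indexing='ij')
    mask = rule(senders, receivers)            # [num_nodes, num_nodes] boolean mask
    pairs = tf.cast(tf.where(mask), tf.int32)  # [n_edge, 2] (sender, receiver) pairs
    return pairs[:, 0], pairs[:, 1]

N2 = 2  # toy value: nodes 0..N2 are "2d" tokens, the rest are "3d" tokens
def edge_connect_rule(sender, receiver):
    complete_2d = (sender < N2 + 1) & (receiver < N2 + 1) & (sender + 1 != receiver)
    auto_regressive_3d = (sender <= receiver) & (receiver >= N2 + 1)
    return complete_2d | auto_regressive_3d

senders, receivers = dense_edges_from_rule(5, edge_connect_rule)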
Example #3
def build_dataset(data_dir, batch_size):
    tfrecords = glob.glob(os.path.join(data_dir, '*.tfrecords'))

    dataset = tf.data.TFRecordDataset(tfrecords).map(partial(decode_examples,
                                                             node_shape=(11,),
                                                             image_shape=(256, 256, 1),
                                                             k=6))  # (graph, image, spsh, proj)

    dataset = dataset.map(lambda graph_data_dict, img, spsh, proj, e: (graph_data_dict, img))

    dataset = dataset.batch(batch_size)
    #batch fixing mechanism
    dataset = dataset.map(lambda data_dict, image: (batch_graph_data_dict(data_dict), image))
    dataset = dataset.map(lambda data_dict, image: (GraphsTuple(**data_dict,
                                                                edges=None, receivers=None, senders=None, globals=None), image))
    dataset = dataset.map(lambda batched_graphs, image: (graph_unbatch_reshape(batched_graphs), image))
    # dataset = dataset.cache()
    return dataset
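A quick usage sketch for build_dataset; the directory path and batch size here are placeholders, not values from the project:

dataset = build_dataset('/path/to/tfrecords_dir', batch_size=4)  # hypothetical path
for graphs_tuple, image in dataset.take(1):
    # graphs_tuple is the unbatch-reshaped GraphsTuple; image is the batched image tensor.
    print(graphs_tuple.n_node, image.shape)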
Example #4
    def _build(self, graphs, temperature):
        """
        Adds another set of nodes to each graph. Autoregressively links all nodes in a graph.

        Args:
            graphs: batched GraphsTuple, node_shape = [n_graphs * num_input, input_embedding_size]
            temperature: scalar > 0
        Returns:
            #todo: give shapes to returns
            token_nodes
            kl_div
            token_3d_samples_onehot
            basis_weights (only when do_basis_weight is True)
            logits_3d

        """
        # give graphs edges and new node dimension (linear transformation)
        graphs = self.projection_node_block(
            graphs)  # nodes = [n_graphs * num_input, embedding_size]
        batched_graphs = graph_batch_reshape(
            graphs)  # nodes = [n_graphs, num_input, embedding_size]
        [n_graphs, n_node_per_graph_before_concat,
         _] = get_shape(batched_graphs.nodes)

        concat_nodes = tf.concat(
            [
                batched_graphs.nodes,
                tf.tile(self.starting_node_variable[None, :], [n_graphs, 1, 1])
            ],
            axis=-2)  # [n_graphs, num_input + num_output, embedding_size]
        batched_graphs = batched_graphs.replace(
            nodes=concat_nodes,
            globals=tf.tile(self.starting_global_variable[None, :],
                            [n_graphs, 1]),
            n_node=tf.fill([n_graphs],
                           n_node_per_graph_before_concat + self.num_output))
        concat_graphs = graph_unbatch_reshape(
            batched_graphs
        )  # [n_graphs * (num_input + num_output), embedding_size]

        # nodes, senders, receivers, globals
        concat_graphs = autoregressive_connect_graph_dynamic(
            concat_graphs)  # exclude self edges because the 3d tokens are originally placeholders?

        # todo: this only works if exclude_self_edges=False
        n_edge = n_graphs * (
            (n_node_per_graph_before_concat + self.num_output) *
            (n_node_per_graph_before_concat + self.num_output - 1) // 2 +
            (n_node_per_graph_before_concat + self.num_output))
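        # Worked check of the edge count (illustration, assuming the
        # autoregressive connection keeps self edges): with
        # N = num_input + num_output nodes per graph there are N*(N-1)/2
        # "strictly earlier" edges plus N self loops, e.g. N = 4 gives
        # 4*3/2 + 4 = 10 edges per graph.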

        latent_graphs = concat_graphs.replace(edges=tf.tile(
            tf.constant(self.edge_size * [0.])[None, :], [n_edge, 1]))
        latent_graphs.receivers.set_shape([n_edge])
        latent_graphs.senders.set_shape([n_edge])

        def _core(output_token_idx, latent_graphs, prev_kl_term,
                  prev_token_3d_samples_onehot, prev_logits_3d):
            batched_latent_graphs = graph_batch_reshape(latent_graphs)
            batched_input_nodes = batched_latent_graphs.nodes  # [n_graphs, num_input + num_output, embedding_size]

            # todo: use self-attention
            latent_graphs = self.selfattention_core(latent_graphs)
            latent_graphs = self.edge_block(latent_graphs)
            latent_graphs = self.node_block(latent_graphs)

            batched_latent_graphs = graph_batch_reshape(latent_graphs)

            token_3d_logits = batched_latent_graphs.nodes[:, -self.num_output:, :]  # [n_graphs, num_output, num_embedding]

            token_3d_logits -= tf.reduce_mean(token_3d_logits,
                                              axis=-1,
                                              keepdims=True)
            token_3d_logits /= tf.math.reduce_std(token_3d_logits,
                                                  axis=-1,
                                                  keepdims=True)
            reduce_logsumexp = tf.math.reduce_logsumexp(
                token_3d_logits, axis=-1)  # [n_graphs, num_output]
            reduce_logsumexp = tf.tile(
                reduce_logsumexp[..., None],
                [1, 1, self.num_embedding])  # [n_graphs, num_output, num_embedding]
            token_3d_logits -= reduce_logsumexp
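            # After subtracting the logsumexp, token_3d_logits are normalised
            # log-probabilities: exp(token_3d_logits) sums to 1 over the
            # embedding axis.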

            token_distribution = tfp.distributions.RelaxedOneHotCategorical(
                temperature, logits=token_3d_logits)
            token_3d_samples_onehot = token_distribution.sample(
                (1,), name='token_samples')  # [1, n_graphs, num_output, num_embedding]
            token_3d_samples_onehot = token_3d_samples_onehot[0]  # [n_graphs, num_output, num_embedding]
            # token_3d_samples_max_index = tf.math.argmax(token_3d_logits, axis=-1, output_type=tf.int32)
            # token_3d_samples_onehot = tf.cast(tf.tile(tf.range(self.num_embedding)[None, None, :], [n_graphs, self.num_output, 1]) ==
            #                                   token_3d_samples_max_index[:,:,None], tf.float32)  # [n_graphs, num_output, num_embedding]
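            # RelaxedOneHotCategorical (Gumbel-softmax) yields approximately
            # one-hot samples that remain differentiable w.r.t. the logits;
            # lower temperatures push the samples closer to hard one-hots.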
            token_3d_samples = tf.einsum(
                'goe,ed->god', token_3d_samples_onehot,
                self.embeddings)  # [n_graphs, num_output, embedding_dim]
            _mask = tf.range(self.num_output) == output_token_idx  # [num_output]
            mask = tf.concat(
                [tf.zeros(n_node_per_graph_before_concat, dtype=tf.bool), _mask],
                axis=0)  # [num_input + num_output]
            mask = tf.tile(
                mask[None, :, None],
                [n_graphs, 1, self.embedding_size])  # [n_graphs, num_input + num_output, embedding_size]

            kl_term = tf.reduce_sum(
                (token_3d_samples_onehot * token_3d_logits),
                axis=-1)  # [n_graphs, num_output]
            kl_term = tf.reduce_sum(tf.cast(_mask, tf.float32) * kl_term,
                                    axis=-1)  # [n_graphs]
            kl_term += prev_kl_term

            # n_graphs, n_node+num_output, embedding_size
            output_nodes = tf.where(
                mask,
                tf.concat(
                    [tf.zeros([n_graphs, n_node_per_graph_before_concat, self.embedding_size]),
                     token_3d_samples],
                    axis=1),
                batched_input_nodes)
            batched_latent_graphs = batched_latent_graphs.replace(
                nodes=output_nodes)
            latent_graphs = graph_unbatch_reshape(batched_latent_graphs)

            return (output_token_idx + 1, latent_graphs, kl_term,
                    token_3d_samples_onehot, token_3d_logits)

        _, latent_graphs, kl_div, token_3d_samples_onehot, logits_3d = tf.while_loop(
            cond=lambda output_token_idx, *_: output_token_idx < self.num_output,
            body=_core,
            loop_vars=(tf.constant([0]), latent_graphs,
                       tf.zeros((n_graphs,), dtype=tf.float32),
                       tf.zeros((n_graphs, self.num_output, self.num_embedding),
                                dtype=tf.float32),
                       tf.zeros((n_graphs, self.num_output, self.num_embedding),
                                dtype=tf.float32)))

        latent_graphs = graph_batch_reshape(latent_graphs)
        token_nodes = latent_graphs.nodes[:, -self.num_output:, :]

        if self.do_basis_weight:
            # compute weights for how much each basis function will contribute, forcing later ones to contribute less.
            #todo: use self-attention
            basis_weight_graphs = self.selfattention_weights(latent_graphs)
            basis_weight_graphs = self.basis_weight_node_block(
                self.basis_weight_edge_block(basis_weight_graphs))
            basis_weight_graphs = graph_batch_reshape(basis_weight_graphs)
            #[n_graphs, num_output]
            basis_weights = basis_weight_graphs.nodes[:, -self.num_output:, 0]
            #make the weights shrink with increasing component
            basis_weights = tf.math.cumprod(tf.nn.sigmoid(basis_weights),
                                            axis=-1)
            return token_nodes, kl_div, token_3d_samples_onehot, basis_weights, logits_3d
        else:
            return token_nodes, kl_div, token_3d_samples_onehot, logits_3d
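The cumulative product of sigmoids used for basis_weights makes the weights non-increasing, because every additional factor lies in (0, 1). A small standalone illustration (values approximate):

import tensorflow as tf

raw = tf.constant([[2.0, 0.5, -1.0, 3.0]])               # [n_graphs=1, num_output=4]
weights = tf.math.cumprod(tf.nn.sigmoid(raw), axis=-1)   # ~[[0.88, 0.55, 0.15, 0.14]]
# Each weight is the previous one multiplied by a value in (0, 1), so later
# basis functions can never contribute more than earlier ones.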