def test_batch_reshape():
    data = dict(nodes=tf.reshape(tf.range(3 * 2), (3, 2)),
                edges=tf.reshape(tf.range(40 * 5), (40, 5)),
                senders=tf.random.uniform((40,), minval=0, maxval=3, dtype=tf.int32),
                receivers=tf.random.uniform((40,), minval=0, maxval=3, dtype=tf.int32),
                n_node=tf.constant([3]),
                n_edge=tf.constant([40]),
                globals=None)
    graph = GraphsTuple(**data)
    graphs = utils_tf.concat([graph] * 4, axis=0)
    batched_graphs = graph_batch_reshape(graphs)

    # Every graph in the batch is identical, so per-graph node/edge slices must match.
    assert tf.reduce_all(batched_graphs.nodes[0] == batched_graphs.nodes[1]).numpy()
    assert tf.reduce_all(batched_graphs.edges[0] == batched_graphs.edges[1]).numpy()
    # Sender/receiver indices of graph i+1 are offset by the node count of graph i.
    assert tf.reduce_all(
        batched_graphs.senders[0] + graphs.n_node[0] == batched_graphs.senders[1]).numpy()
    assert tf.reduce_all(
        batched_graphs.receivers[0] + graphs.n_node[0] == batched_graphs.receivers[1]).numpy()

    # Round trip: unbatching the batched graphs recovers the original tensors.
    unbatched_graphs = graph_unbatch_reshape(batched_graphs)
    for (t1, t2) in zip(graphs, unbatched_graphs):
        if t1 is not None:
            assert tf.reduce_all(t1 == t2).numpy()
def construct_input_graph(self, input_sequence, N2):
    # input_sequence: [num_samples*batch, 1 + H2*W2 + 1 + H3*W3*D3]
    G, N = get_shape(input_sequence)
    input_tokens = tf.nn.embedding_lookup(self.embeddings, input_sequence)  # [G, N, embedding_dim]
    self.initialize_positional_encodings(input_tokens)
    nodes = input_tokens + self.positional_encodings
    n_node = tf.fill([G], N)
    n_edge = tf.zeros_like(n_node)
    data_dict = dict(nodes=nodes, edges=None, senders=None, receivers=None,
                     globals=None, n_node=n_node, n_edge=n_edge)
    concat_graphs = GraphsTuple(**data_dict)
    concat_graphs = graph_unbatch_reshape(concat_graphs)  # [n_graphs * (num_input + num_output), embedding_size]

    # nodes, senders, receivers, globals
    def edge_connect_rule(sender, receiver):
        # . a . b -> a . b .
        # Fully connect the 2d tokens, excluding edges to the immediately following
        # token so the model doesn't learn to copy.
        complete_2d = (sender < N2 + 1) & (receiver < N2 + 1) & (sender + 1 != receiver)
        # Auto-regressive connectivity (excluding 2d) with self-loops for the 3d tokens.
        auto_regressive_3d = (sender <= receiver) & (receiver >= N2 + 1)
        return complete_2d | auto_regressive_3d

    # nodes, senders, receivers, globals
    concat_graphs = connect_graph_dynamic(concat_graphs, edge_connect_rule)
    return concat_graphs
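# A minimal sketch of the connectivity rule above, evaluated on a toy example. The
# helper name, N2=2, and num_nodes=6 are made-up values for illustration; it assumes
# only `tf` (TensorFlow) as imported elsewhere in this module.
def _sketch_edge_connect_rule(N2=2, num_nodes=6):
    # Build a dense [num_nodes, num_nodes] boolean adjacency (rows = senders,
    # columns = receivers) using the same rule as edge_connect_rule: positions
    # 0..N2 form the fully connected "2d" block (minus the copy edge), and the
    # remaining positions are connected auto-regressively with self-loops.
    idx = tf.range(num_nodes)
    sender, receiver = tf.meshgrid(idx, idx, indexing='ij')
    complete_2d = (sender < N2 + 1) & (receiver < N2 + 1) & (sender + 1 != receiver)
    auto_regressive_3d = (sender <= receiver) & (receiver >= N2 + 1)
    return complete_2d | auto_regressive_3d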
def build_dataset(data_dir, batch_size):
    tfrecords = glob.glob(os.path.join(data_dir, '*.tfrecords'))
    dataset = tf.data.TFRecordDataset(tfrecords).map(
        partial(decode_examples,
                node_shape=(11,),
                image_shape=(256, 256, 1),
                k=6))  # (graph_data_dict, image, spsh, proj, e)
    dataset = dataset.map(lambda graph_data_dict, img, spsh, proj, e: (graph_data_dict, img))
    dataset = dataset.batch(batch_size)

    # batch fixing mechanism: merge the per-example graph dicts into one GraphsTuple
    dataset = dataset.map(lambda data_dict, image: (batch_graph_data_dict(data_dict), image))
    dataset = dataset.map(lambda data_dict, image: (GraphsTuple(**data_dict,
                                                                edges=None,
                                                                receivers=None,
                                                                senders=None,
                                                                globals=None), image))
    dataset = dataset.map(lambda batched_graphs, image: (graph_unbatch_reshape(batched_graphs), image))
    # dataset = dataset.cache()
    return dataset
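# A minimal usage sketch (not part of the original pipeline). The helper name, the
# data_dir default, and the batch size are placeholders; it assumes TFRecords
# compatible with decode_examples exist under data_dir.
def _example_build_dataset_usage(data_dir='/path/to/tfrecords'):
    dataset = build_dataset(data_dir, batch_size=4)
    for graphs, images in dataset.take(1):
        # Expected roughly: graphs.nodes [batch * n_node_per_graph, 11],
        # images [batch, 256, 256, 1].
        print(graphs.nodes.shape, images.shape)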
def _build(self, graphs, temperature):
    """
    Adds another set of nodes to each graph. Autoregressively links all nodes in a graph.

    Args:
        graphs: batched GraphsTuple, node_shape = [n_graphs * num_input, input_embedding_size]
        temperature: scalar > 0

    Returns:
        #todo: give shapes to returns
        token_node
        kl_div
        token_3d_samples_onehot
        basis_weights
    """
    # give graphs edges and new node dimension (linear transformation)
    graphs = self.projection_node_block(graphs)  # nodes = [n_graphs * num_input, embedding_size]
    batched_graphs = graph_batch_reshape(graphs)  # nodes = [n_graphs, num_input, embedding_size]
    [n_graphs, n_node_per_graph_before_concat, _] = get_shape(batched_graphs.nodes)

    # append the learnable starting nodes for the output tokens
    concat_nodes = tf.concat(
        [batched_graphs.nodes,
         tf.tile(self.starting_node_variable[None, :], [n_graphs, 1, 1])],
        axis=-2)  # [n_graphs, num_input + num_output, embedding_size]
    batched_graphs = batched_graphs.replace(
        nodes=concat_nodes,
        globals=tf.tile(self.starting_global_variable[None, :], [n_graphs, 1]),
        n_node=tf.fill([n_graphs], n_node_per_graph_before_concat + self.num_output))
    concat_graphs = graph_unbatch_reshape(batched_graphs)  # [n_graphs * (num_input + num_output), embedding_size]

    # nodes, senders, receivers, globals
    concat_graphs = autoregressive_connect_graph_dynamic(
        concat_graphs)  # exclude self edges because 3d tokens originally placeholder?

    # todo: this only works if exclude_self_edges=False
    n_edge = n_graphs * (
        (n_node_per_graph_before_concat + self.num_output) *
        (n_node_per_graph_before_concat + self.num_output - 1) // 2 +
        (n_node_per_graph_before_concat + self.num_output))
    latent_graphs = concat_graphs.replace(
        edges=tf.tile(tf.constant(self.edge_size * [0.])[None, :], [n_edge, 1]))
    latent_graphs.receivers.set_shape([n_edge])
    latent_graphs.senders.set_shape([n_edge])

    def _core(output_token_idx, latent_graphs, prev_kl_term,
              prev_token_3d_samples_onehot, prev_logits_3d):
        batched_latent_graphs = graph_batch_reshape(latent_graphs)
        batched_input_nodes = batched_latent_graphs.nodes  # [n_graphs, num_input + num_output, embedding_size]

        # todo: use self-attention
        latent_graphs = self.selfattention_core(latent_graphs)
        latent_graphs = self.edge_block(latent_graphs)
        latent_graphs = self.node_block(latent_graphs)

        batched_latent_graphs = graph_batch_reshape(latent_graphs)
        token_3d_logits = batched_latent_graphs.nodes[:, -self.num_output:, :]  # [n_graphs, num_output, num_embedding]

        # standardise the logits, then subtract logsumexp so they are normalised log-probabilities
        token_3d_logits -= tf.reduce_mean(token_3d_logits, axis=-1, keepdims=True)
        token_3d_logits /= tf.math.reduce_std(token_3d_logits, axis=-1, keepdims=True)
        reduce_logsumexp = tf.math.reduce_logsumexp(token_3d_logits, axis=-1)  # [n_graphs, num_output]
        reduce_logsumexp = tf.tile(reduce_logsumexp[..., None],
                                   [1, 1, self.num_embedding])  # [n_graphs, num_output, num_embedding]
        token_3d_logits -= reduce_logsumexp

        token_distribution = tfp.distributions.RelaxedOneHotCategorical(temperature,
                                                                         logits=token_3d_logits)
        token_3d_samples_onehot = token_distribution.sample(
            (1,), name='token_samples')  # [1, n_graphs, num_output, num_embedding]
        token_3d_samples_onehot = token_3d_samples_onehot[0]  # [n_graphs, num_output, num_embedding]
        # token_3d_samples_max_index = tf.math.argmax(token_3d_logits, axis=-1, output_type=tf.int32)
        # token_3d_samples_onehot = tf.cast(tf.tile(tf.range(self.num_embedding)[None, None, :], [n_graphs, self.num_output, 1]) ==
        #                                   token_3d_samples_max_index[:, :, None], tf.float32)  # [n_graphs, num_output, num_embedding]
        token_3d_samples = tf.einsum('goe,ed->god', token_3d_samples_onehot,
                                     self.embeddings)  # [n_graphs, num_output, embedding_dim]

        # only the node for the current output token is updated on this iteration
        _mask = tf.range(self.num_output) == output_token_idx  # [num_output]
        mask = tf.concat([tf.zeros(n_node_per_graph_before_concat, dtype=tf.bool), _mask],
                         axis=0)  # [num_input + num_output]
        mask = tf.tile(mask[None, :, None],
                       [n_graphs, 1, self.embedding_size])  # [n_graphs, num_input + num_output, embedding_size]

        kl_term = tf.reduce_sum(token_3d_samples_onehot * token_3d_logits,
                                axis=-1)  # [n_graphs, num_output]
        kl_term = tf.reduce_sum(tf.cast(_mask, tf.float32) * kl_term, axis=-1)  # [n_graphs]
        kl_term += prev_kl_term

        # [n_graphs, num_input + num_output, embedding_size]
        output_nodes = tf.where(
            mask,
            tf.concat([tf.zeros([n_graphs, n_node_per_graph_before_concat, self.embedding_size]),
                       token_3d_samples],
                      axis=1),
            batched_input_nodes)
        batched_latent_graphs = batched_latent_graphs.replace(nodes=output_nodes)
        latent_graphs = graph_unbatch_reshape(batched_latent_graphs)

        return (output_token_idx + 1, latent_graphs, kl_term,
                token_3d_samples_onehot, token_3d_logits)

    _, latent_graphs, kl_div, token_3d_samples_onehot, logits_3d = tf.while_loop(
        cond=lambda output_token_idx, state, _, __, ___: output_token_idx < self.num_output,
        body=lambda output_token_idx, state, prev_kl_term, prev_token_3d_samples_onehot, prev_logits_3d:
        _core(output_token_idx, state, prev_kl_term, prev_token_3d_samples_onehot, prev_logits_3d),
        loop_vars=(tf.constant(0),  # scalar counter so the loop predicate is a scalar boolean
                   latent_graphs,
                   tf.zeros((n_graphs,), dtype=tf.float32),
                   tf.zeros((n_graphs, self.num_output, self.num_embedding), dtype=tf.float32),
                   tf.zeros((n_graphs, self.num_output, self.num_embedding), dtype=tf.float32)))

    latent_graphs = graph_batch_reshape(latent_graphs)
    token_nodes = latent_graphs.nodes[:, -self.num_output:, :]

    if self.do_basis_weight:
        # compute weights for how much each basis function will contribute,
        # forcing later ones to contribute less.
        # todo: use self-attention
        basis_weight_graphs = self.selfattention_weights(latent_graphs)
        basis_weight_graphs = self.basis_weight_node_block(
            self.basis_weight_edge_block(basis_weight_graphs))
        basis_weight_graphs = graph_batch_reshape(basis_weight_graphs)
        # [n_graphs, num_output]
        basis_weights = basis_weight_graphs.nodes[:, -self.num_output:, 0]
        # make the weights shrink with increasing component
        basis_weights = tf.math.cumprod(tf.nn.sigmoid(basis_weights), axis=-1)
        return token_nodes, kl_div, token_3d_samples_onehot, basis_weights, logits_3d
    else:
        return token_nodes, kl_div, token_3d_samples_onehot, logits_3d
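# A minimal calling sketch (not part of the original code). The helper name and the
# temperature value are placeholders; it assumes `model` is an instance of the module
# defined above (with __call__ dispatching to _build and do_basis_weight=True) and that
# `input_graphs` is a batched GraphsTuple with nodes [n_graphs * num_input, input_embedding_size].
def _example_decoder_call(model, input_graphs):
    temperature = tf.constant(10., dtype=tf.float32)  # arbitrary example temperature
    token_nodes, kl_div, token_3d_samples_onehot, basis_weights, logits_3d = model(
        input_graphs, temperature)
    return token_nodes, kl_div, basis_weights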