def __call__(
        self,
        adj_matrices: tf.SparseTensor,
        node_features: tf.Tensor,
        graph_sizes: tf.Tensor,
        mode: tf.estimator.ModeKeys = tf.estimator.ModeKeys.TRAIN
) -> tf.Tensor:
    """Zero-pad the node features to the encoder width, then run the parent encoder.

    On first use the encoder is built from the static size of axis 2 of
    `node_features` and the static size of axis 1 of `adj_matrices`.

    Args:
        adj_matrices: Sparse adjacency structure, forwarded unchanged to the
            parent `__call__`.
        node_features: Node features with the feature dimension on axis 2
            (the code indexes `shape[2]` and pads along axis 2).
        graph_sizes: Forwarded unchanged to the parent `__call__`.
        mode: Run mode, forwarded unchanged to the parent `__call__`.

    Returns:
        Whatever the parent class' `__call__` returns for the (possibly
        padded) inputs.
    """
    if not self.built:
        self.build(node_features.shape[2].value,
                   adj_matrices.get_shape()[1].value)

    # Pad features if needed
    if self.initial_node_features_size < self.node_features_size:
        pad_size = self.node_features_size - self.initial_node_features_size
        # Append `pad_size` zeros at the end of the feature axis only.
        # tf.pad produces zeros of the input's own dtype, so this also works
        # for features that are not float32.
        node_features = tf.pad(
            node_features, [[0, 0], [0, 0], [0, pad_size]])

    return super().__call__(adj_matrices, node_features, graph_sizes, mode)
def __call__(self,
             adj_matrices: tf.SparseTensor,
             node_features: tf.Tensor,  # Shape: [ batch_size, V, D ]
             graph_sizes: tf.Tensor,
             primary_paths: tf.Tensor,
             primary_path_lengths: tf.Tensor,
             mode: tf.estimator.ModeKeys = tf.estimator.ModeKeys.TRAIN) -> tf.Tensor:
    """Encode nodes with an RNN over a primary path, optionally combined with a graph encoder.

    The features of the nodes listed in `primary_paths` are gathered and run
    through the configured RNN. The per-step RNN outputs are scattered back
    to node positions. Unless `self.ignore_graph_encoder` is set, they are
    then merged with the original features, passed through
    `self.base_graph_encoder`, and mapped to the final output.

    Args:
        adj_matrices: Sparse adjacency structure, forwarded to the base graph
            encoder.
        node_features: Node features of shape [batch_size, V, D].
        graph_sizes: Forwarded to the base graph encoder.
        primary_paths: Per-graph node indices forming the path. It is
            multiplied by an int64 mask below, so it is expected to be int64.
        primary_path_lengths: Number of valid entries in each path.
        mode: Run mode, forwarded to the base graph encoder.

    Returns:
        A pair `(node_outputs, state)`. With `self.ignore_graph_encoder` set
        this is the scattered RNN representations and the raw RNN state.
        Otherwise it is the output of `self.output_map` and a state with the
        same nested structure as the RNN state, computed from the RNN state
        and the graph state.

    Raises:
        ValueError: If `self.encoder_type` is neither "bidirectional_rnn"
            nor "rnn".
    """
    if not self.built:
        self.build(
            node_features.shape[2].value,
            adj_matrices.get_shape()[1].value)

    nest = tf.contrib.framework.nest

    # gather representations for the nodes in the path and run the RNN on this path
    primary_path_features = batch_gather(node_features, primary_paths)

    if self.encoder_type == "bidirectional_rnn":
        rnn_path_representations, rnn_state = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=self.fwd_cell,
            cell_bw=self.bwd_cell,
            inputs=primary_path_features,
            sequence_length=primary_path_lengths,
            dtype=tf.float32,
            swap_memory=True)
        rnn_path_representations = tf.concat(rnn_path_representations, axis=-1)

        # concat fwd and bwd representations in all substructures of the state
        f_rnn_state_fwd = nest.flatten(rnn_state[0])
        f_rnn_state_bwd = nest.flatten(rnn_state[1])
        f_rnn_state = [
            tf.concat([t1, t2], axis=-1)
            for t1, t2 in zip(f_rnn_state_fwd, f_rnn_state_bwd)
        ]
        rnn_state = nest.pack_sequence_as(rnn_state[0], f_rnn_state)
    elif self.encoder_type == "rnn":
        rnn_path_representations, rnn_state = tf.nn.dynamic_rnn(
            cell=self.rnn_cell,
            inputs=primary_path_features,
            sequence_length=primary_path_lengths,
            dtype=tf.float32,
            swap_memory=True)
    else:
        # Fail with a clear message instead of an UnboundLocalError further down.
        raise ValueError(
            "Unsupported encoder_type: %r (expected 'bidirectional_rnn' or 'rnn')"
            % (self.encoder_type,))

    batch_size = tf.shape(node_features, out_type=tf.int64)[0]
    max_num_nodes = tf.shape(node_features, out_type=tf.int64)[1]

    # shift indices by 1 and mask padding indices to zero
    # this ensures that scatter_nd won't use a padding rnn representation over
    # the actual representation for a node with the same index as the padding value
    # by forcing scatter_nd to write padding representations into "dummy" vectors
    shifted_paths = primary_paths + 1
    shifted_paths = shifted_paths * tf.sequence_mask(primary_path_lengths, dtype=tf.int64)

    # NOTE(review): stack_indices is defined elsewhere; presumably it pairs
    # every path entry with its batch index so that the reshape yields
    # (batch, node) index pairs -- confirm against its definition.
    rnn_representations = tf.scatter_nd(
        indices=tf.reshape(stack_indices(shifted_paths, axis=0), (-1, 2)),
        updates=tf.reshape(rnn_path_representations, (-1, self.num_units)),
        shape=tf.stack([batch_size, max_num_nodes + 1, self.num_units], axis=0))

    # remove dummy vectors
    rnn_representations = rnn_representations[:, 1:, :]

    if self.ignore_graph_encoder:
        return rnn_representations, rnn_state

    node_representations, graph_state = self.base_graph_encoder(
        adj_matrices=adj_matrices,
        node_features=self.merge_layer(
            tf.concat([rnn_representations, node_features], axis=-1)),
        graph_sizes=graph_sizes,
        mode=mode)

    output = self.output_map(
        tf.concat([rnn_representations, node_representations], axis=-1))

    # flatten states (ie LSTM/multi-layer tuples) and calculate state size
    flatten_rnn_state_l = nest.flatten(rnn_state)
    flatten_rnn_state = tf.concat(flatten_rnn_state_l, axis=1)
    state_sizes = [
        state.get_shape().as_list()[-1] for state in flatten_rnn_state_l
    ]
    total_state_size = sum(state_sizes)

    # concat graph state to this and linear map back to flatten size
    self.state_map = tf.layers.Dense(
        name="state_map",
        units=total_state_size,
        use_bias=False,
        kernel_initializer=eye_glorot)
    flatten_state = self.state_map(
        tf.concat([flatten_rnn_state, graph_state], axis=-1))

    # unflatten: split back into per-component pieces and restore the nesting
    flatten_state = tf.split(flatten_state, state_sizes, axis=1)
    state = nest.pack_sequence_as(rnn_state, flatten_state)

    return output, state
def __call__(
        self,
        adj_matrices: tf.SparseTensor,
        node_features: tf.Tensor,  # Shape: [ batch_size, V, D ]
        graph_sizes: tf.Tensor,
        mode: tf.estimator.ModeKeys = tf.estimator.ModeKeys.TRAIN
) -> tf.Tensor:
    """
    Encode graphs given by a (sparse) adjacency matrix and their initial
    node features, returning the encoding of all graph nodes together with a
    per-graph state.

    Args:
        adj_matrices: SparseTensor of dense shape
            [BatchSize, NumEdgeTypes, MaxNumNodes, MaxNumNodes] representing
            edges in graph. adj_matrices[g, e, v, u] == 1 means that in graph
            g, there is an edge of type e between v and u.
        node_features: Tensor of shape
            [BatchSize, MaxNumNodes, NodeFeatureDimension], representing
            initial node features. node_features[g, v, :] are the features of
            node v in graph g.
        graph_sizes: Tensor of shape [BatchSize] representing the number of
            used nodes in the batched and padded graphs. graph_sizes[g] is
            the number of nodes in graph g.
        mode: Flag indicating run mode. Forwarded to `self._propagate`.

    Returns:
        A pair `(node_representations, graph_state)`:
        node_representations has shape [BatchSize, MaxNumNodes, H], where H
        is the last dimension of the `_propagate` output; it is returned as
        produced by `_propagate` (no masking is applied to it here).
        graph_state has shape [BatchSize, H] and pools over the non-padding
        nodes only: a gated sum if `self.gated_state` is set, otherwise the
        mean.
    """
    if not self.built:
        self.build(node_features_size=node_features.shape[2].value,
                   num_edge_types=adj_matrices.get_shape()[1].value)

    if self.create_bwd_edges:
        adj_matrices = self._create_backward_edges(adj_matrices)

    # We only care about the edge indices, as adj_matrices is only an indicator
    # matrix with values 1 or not-present (i.e., an adjacency list):
    # Shape: [ num of edges (not edge types) ~ E, 4 ]
    adj_list = tf.cast(adj_matrices.indices, tf.int32)

    max_num_vertices = tf.shape(node_features, out_type=tf.int32)[1]
    total_edges = tf.shape(adj_list, out_type=tf.int32)[0]

    # Calculate offsets for flattening the adj matrices, as we are merging all graphs into one big graph.
    # Nodes in first graph are range(0,MaxNumNodes) and edges are shifted by [0,0],
    # nodes in second graph are range(MaxNumNodes,2*MaxNumNodes) and edges are
    # shifted by [MaxNumNodes,MaxNumNodes], etc.
    graph_ids_per_edge = adj_list[:, 0]
    node_id_offsets_per_edge = (
        tf.expand_dims(graph_ids_per_edge, axis=-1) * max_num_vertices)
    edge_shifts_per_edge = tf.tile(node_id_offsets_per_edge, multiples=(1, 2))
    offsets_per_edge = tf.concat(
        [
            # we don't need to shift the edge type
            tf.zeros(shape=(total_edges, 1), dtype=tf.int32),
            edge_shifts_per_edge
        ],
        axis=1)

    # Flatten both adj matrices and node features. For the adjacency list, we strip out the graph id
    # and instead shift the node IDs in edges.
    flattened_adj_list = offsets_per_edge + adj_list[:, 1:]
    flattened_node_features = tf.reshape(
        node_features, shape=(-1, self.node_features_size))

    # propagate on this big graph and unflatten representations
    flattened_node_repr = self._propagate(
        flattened_adj_list, flattened_node_features, mode)
    node_representations = tf.reshape(
        flattened_node_repr,
        shape=(-1, max_num_vertices, flattened_node_repr.shape[-1]))

    # mask for padding nodes
    # maxlen is pinned to MaxNumNodes: without it the mask is only
    # max(graph_sizes) long and fails to broadcast against the node tensors
    # whenever the batch is padded beyond its largest graph.
    graph_mask = tf.expand_dims(
        tf.sequence_mask(
            graph_sizes, maxlen=max_num_vertices, dtype=tf.float32),
        -1)

    if self.gated_state:
        gate_layer = tf.layers.Dense(
            1, activation=tf.nn.sigmoid, name="node_gate_layer")
        output_layer = tf.layers.Dense(
            node_representations.shape[-1], name="node_output_layer")

        # calculate weighted, node-level outputs
        node_all_repr = tf.concat(
            [node_features, node_representations], axis=-1)
        graph_state = gate_layer(node_all_repr) * output_layer(
            node_representations)
        graph_state = tf.reduce_sum(graph_state * graph_mask, axis=1)
    else:
        # mean over the real (non-padding) nodes of each graph
        graph_state = tf.reduce_sum(node_representations * graph_mask, axis=1)
        # Clamp the divisor to 1 so a graph with zero nodes yields a zero
        # state instead of NaN (0 / 0).
        safe_sizes = tf.maximum(graph_sizes, 1)
        graph_state /= tf.cast(tf.expand_dims(safe_sizes, 1), tf.float32)

    return node_representations, graph_state