Example #1
    def __call__(
        self,
        adj_matrices: tf.SparseTensor,
        node_features: tf.Tensor,
        graph_sizes: tf.Tensor,
        mode: tf.estimator.ModeKeys = tf.estimator.ModeKeys.TRAIN
    ) -> tf.Tensor:
        if not self.built:
            self.build(node_features.shape[2].value,
                       adj_matrices.get_shape()[1].value)

        # Pad features if needed
        if self.initial_node_features_size < self.node_features_size:
            pad_size = self.node_features_size - self.initial_node_features_size
            padding_shape = tf.concat(
                [tf.shape(node_features)[:2], (pad_size,)], axis=0)
            padding = tf.zeros(padding_shape, dtype=tf.float32)
            node_features = tf.concat((node_features, padding), axis=2)

        return super().__call__(adj_matrices, node_features, graph_sizes, mode)
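The zero-padding step above can be exercised on its own. A minimal, self-contained sketch (hypothetical sizes; tf is TensorFlow 1.x, matching the code above):

    import tensorflow as tf

    # Hypothetical sizes: 2 graphs, 4 nodes each, 5 initial features padded up to 8.
    initial_size, target_size = 5, 8
    node_features = tf.ones([2, 4, initial_size])

    pad_size = target_size - initial_size
    padding = tf.zeros(
        tf.concat([tf.shape(node_features)[:2], (pad_size,)], axis=0),
        dtype=tf.float32)
    node_features = tf.concat((node_features, padding), axis=2)  # Shape: [2, 4, 8]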
Example #2
    def __call__(self,
                 adj_matrices: tf.SparseTensor,
                 node_features: tf.Tensor,  # Shape: [ batch_size, V, D ]
                 graph_sizes: tf.Tensor,
                 primary_paths: tf.Tensor,
                 primary_path_lengths: tf.Tensor,
                 mode: tf.estimator.ModeKeys = tf.estimator.ModeKeys.TRAIN) -> tf.Tensor:

        if not self.built:
            self.build(
                node_features.shape[2].value,
                adj_matrices.get_shape()[1].value)

        # gather the features of the nodes on the primary path and run the RNN over this path
        primary_path_features = batch_gather(node_features, primary_paths)
        if self.encoder_type == "bidirectional_rnn":
            rnn_path_representations, rnn_state = tf.nn.bidirectional_dynamic_rnn(
                cell_fw=self.fwd_cell,
                cell_bw=self.bwd_cell,
                inputs=primary_path_features,
                sequence_length=primary_path_lengths,
                dtype=tf.float32,
                swap_memory=True)
            rnn_path_representations = tf.concat(rnn_path_representations, axis=-1)

            # concat fwd and bwd representations in all substructures of the state
            f_rnn_state_fwd = tf.contrib.framework.nest.flatten(rnn_state[0])
            f_rnn_state_bwd = tf.contrib.framework.nest.flatten(rnn_state[1])
            f_rnn_state = [tf.concat([t1, t2], axis=-1)
                           for t1, t2 in zip(f_rnn_state_fwd, f_rnn_state_bwd)]

            rnn_state = tf.contrib.framework.nest.pack_sequence_as(rnn_state[0], f_rnn_state)

        elif self.encoder_type == "rnn":
            rnn_path_representations, rnn_state = tf.nn.dynamic_rnn(
                cell=self.rnn_cell,
                inputs=primary_path_features,
                sequence_length=primary_path_lengths,
                dtype=tf.float32,
                swap_memory=True)

        batch_size = tf.shape(node_features, out_type=tf.int64)[0]
        max_num_nodes = tf.shape(node_features, out_type=tf.int64)[1]

        # shift indices by 1 and mask padding indices to zero
        # this ensures that scatter_nd won't use a padding rnn representation over
        # the actual representation for a node with the same index as the padding value
        # by forcing scatter_nd to write padding representations into "dummy" vectors
        shifted_paths = primary_paths + 1
        shifted_paths = shifted_paths * tf.sequence_mask(primary_path_lengths, dtype=tf.int64)
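        # Illustration with hypothetical values: for primary_paths = [[2, 4, 1], [3, 0, 0]]
        # and primary_path_lengths = [3, 1], the shifted and masked paths are
        # [[3, 5, 2], [4, 0, 0]]; the padded positions of the second path thus write into
        # row 0 (the dummy row stripped out below) instead of overwriting node 0.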
        rnn_representations = tf.scatter_nd(
            indices=tf.reshape(stack_indices(shifted_paths, axis=0), (-1, 2)),
            updates=tf.reshape(rnn_path_representations, (-1, self.num_units)),
            shape=tf.stack([batch_size, max_num_nodes + 1, self.num_units], axis=0))

        # remove dummy vectors
        rnn_representations = rnn_representations[:, 1:, :]

        if self.ignore_graph_encoder:
            return rnn_representations, rnn_state

        node_representations, graph_state = self.base_graph_encoder(
            adj_matrices=adj_matrices,
            node_features=self.merge_layer(
                tf.concat([rnn_representations, node_features], axis=-1)),
            graph_sizes=graph_sizes,
            mode=mode)

        output = self.output_map(tf.concat([rnn_representations, node_representations], axis=-1))

        # flatten states (i.e. LSTM/multi-layer tuples) and calculate the total state size
        flatten_rnn_state_l = tf.contrib.framework.nest.flatten(rnn_state)
        flatten_rnn_state = tf.concat(flatten_rnn_state_l, axis=1)
        state_sizes = []
        for state in flatten_rnn_state_l:
            state_sizes.append(state.get_shape().as_list()[-1])
        total_state_size = sum(state_sizes)

        # concatenate the graph state and linearly map back to the flattened state size
        self.state_map = tf.layers.Dense(
            name="state_map",
            units=total_state_size,
            use_bias=False,
            kernel_initializer=eye_glorot)
        flatten_state = self.state_map(tf.concat([flatten_rnn_state, graph_state], axis=-1))

        # unflatten back into the original nested state structure
        flatten_state = tf.split(flatten_state, state_sizes, axis=1)
        state = tf.contrib.framework.nest.pack_sequence_as(rnn_state, flatten_state)
        return output, state
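The flatten/concat/pack pattern used above to merge the forward and backward RNN states works for any nested state structure. A minimal sketch with a hypothetical (c, h) state tuple standing in for real LSTM cell outputs:

    import tensorflow as tf

    # Hypothetical (c, h) states for the forward and backward directions.
    state_fwd = (tf.ones([2, 8]), tf.ones([2, 8]) * 2.0)
    state_bwd = (tf.ones([2, 8]) * 3.0, tf.ones([2, 8]) * 4.0)

    flat_fwd = tf.contrib.framework.nest.flatten(state_fwd)
    flat_bwd = tf.contrib.framework.nest.flatten(state_bwd)
    flat_merged = [tf.concat([f, b], axis=-1) for f, b in zip(flat_fwd, flat_bwd)]

    # Repack into the same (c, h) structure, now with doubled state size: [2, 16] each.
    merged_state = tf.contrib.framework.nest.pack_sequence_as(state_fwd, flat_merged)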
Example #3
    def __call__(
        self,
        adj_matrices: tf.SparseTensor,
        node_features: tf.Tensor,  # Shape: [ batch_size, V, D ]
        graph_sizes: tf.Tensor,
        mode: tf.estimator.ModeKeys = tf.estimator.ModeKeys.TRAIN
    ) -> tf.Tensor:
        """
        Encode graphs given by a (sparse) adjacency matrix and their initial node features,
        returning the encoding of all graph nodes.

        Args:
            adj_matrices: SparseTensor of dense shape
              [BatchSize, NumEdgeTypes, MaxNumNodes, MaxNumNodes] representing edges in graph.
              adj_matrices[g, e, v, u] == 1 means that in graph g, there is an edge of
              type e between v and u.
            node_features: Tensor of shape [BatchSize, MaxNumNodes, NodeFeatureDimension],
              representing initial node features. node_features[g, v, :] are the features of
              node v in graph g.
            graph_sizes: Tensor of shape [BatchSize] representing the number of used nodes in
              the batched and padded graphs. graph_sizes[g] is the number of nodes in graph g.
            mode: Flag indicating run mode. [Unused]

        Returns:
            A pair (node_representations, graph_state). node_representations is a Tensor of
              shape [BatchSize, MaxNumNodes, NodeFeatureDimension]; representations for padding
              nodes are zero vectors. graph_state is a per-graph summary Tensor of shape
              [BatchSize, NodeFeatureDimension].
        """
        if not self.built:
            self.build(node_features_size=node_features.shape[2].value,
                       num_edge_types=adj_matrices.get_shape()[1].value)

        if self.create_bwd_edges:
            adj_matrices = self._create_backward_edges(adj_matrices)

        # We only care about the edge indices, as adj_matrices is only an indicator
        # matrix with values 1 or not-present (i.e., an adjacency list):
        # Shape: [ num of edges (not edge types) ~ E, 4 ]
        adj_list = tf.cast(adj_matrices.indices, tf.int32)

        max_num_vertices = tf.shape(node_features, out_type=tf.int32)[1]
        total_edges = tf.shape(adj_list, out_type=tf.int32)[0]

        # Calculate offsets for flattening the adj matrices, as we are merging all graphs into one big graph.
        # Nodes in first graph are range(0,MaxNumNodes) and edges are shifted by [0,0],
        # nodes in second graph are range(MaxNumNodes,2*MaxNumNodes) and edges are
        # shifted by [MaxNumNodes,MaxNumNodes], etc.
        graph_ids_per_edge = adj_list[:, 0]
        node_id_offsets_per_edge = tf.expand_dims(graph_ids_per_edge,
                                                  axis=-1) * max_num_vertices
        edge_shifts_per_edge = tf.tile(node_id_offsets_per_edge,
                                       multiples=(1, 2))
        offsets_per_edge = tf.concat(
            [
                tf.zeros(
                    shape=(total_edges, 1),
                    dtype=tf.int32),  # we don't need to shift the edge type
                edge_shifts_per_edge
            ],
            axis=1)

        # Flatten both adj matrices and node features. For the adjacency list, we strip out the graph id
        # and instead shift the node IDs in edges.
        flattened_adj_list = offsets_per_edge + adj_list[:, 1:]
        flattened_node_features = tf.reshape(node_features,
                                             shape=(-1,
                                                    self.node_features_size))

        # propagate on this big graph and unflatten representations
        flattened_node_repr = self._propagate(flattened_adj_list,
                                              flattened_node_features, mode)
        node_representations = tf.reshape(
            flattened_node_repr,
            shape=(-1, max_num_vertices, flattened_node_repr.shape[-1]))

        # mask for padding nodes
        graph_mask = tf.expand_dims(
            tf.sequence_mask(graph_sizes, dtype=tf.float32), -1)
        if self.gated_state:
            gate_layer = tf.layers.Dense(1,
                                         activation=tf.nn.sigmoid,
                                         name="node_gate_layer")

            output_layer = tf.layers.Dense(node_representations.shape[-1],
                                           name="node_output_layer")

            # calculate weighted, node-level outputs
            node_all_repr = tf.concat([node_features, node_representations],
                                      axis=-1)
            graph_state = gate_layer(node_all_repr) * output_layer(
                node_representations)
            graph_state = tf.reduce_sum(graph_state * graph_mask, axis=1)

        else:
            graph_state = tf.reduce_sum(node_representations * graph_mask,
                                        axis=1)
            graph_state /= tf.cast(tf.expand_dims(graph_sizes, 1), tf.float32)

        return node_representations, graph_state
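The node-ID offsetting used above to merge all padded graphs into one large disjoint graph can be traced on a small example. A minimal, self-contained sketch with hypothetical edges given as rows of [graph, edge_type, v, u]:

    import tensorflow as tf

    max_num_vertices = 4
    adj_list = tf.constant([[0, 0, 1, 2],   # graph 0: edge of type 0 from node 1 to node 2
                            [0, 1, 2, 3],   # graph 0: edge of type 1 from node 2 to node 3
                            [1, 0, 0, 1]],  # graph 1: edge of type 0 from node 0 to node 1
                           dtype=tf.int32)
    total_edges = tf.shape(adj_list, out_type=tf.int32)[0]

    graph_ids_per_edge = adj_list[:, 0]
    node_id_offsets_per_edge = tf.expand_dims(graph_ids_per_edge, axis=-1) * max_num_vertices
    edge_shifts_per_edge = tf.tile(node_id_offsets_per_edge, multiples=(1, 2))
    offsets_per_edge = tf.concat(
        [tf.zeros(shape=(total_edges, 1), dtype=tf.int32),  # edge type stays unshifted
         edge_shifts_per_edge],
        axis=1)

    # Drops the graph id and shifts node ids into the merged graph:
    # [[0, 1, 2], [1, 2, 3], [0, 4, 5]] -- graph 1's edge (0, 1) becomes (4, 5).
    flattened_adj_list = offsets_per_edge + adj_list[:, 1:]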