def statement_encoder(
    populated_proto_db: unlabelled_graph_database.Database,
    populated_graph_db: graph_tuple_database.Database,
    cache_size: int,
):
    """Test fixture that builds a StatementEncoder over the populated databases.

    Args:
      populated_proto_db: The populated unlabelled (proto) graph database.
      populated_graph_db: The populated graph tuple database.
      cache_size: Forwarded to the encoder as its third positional argument.

    Returns:
      A graph2seq.StatementEncoder instance.
    """
    # NOTE(review): the encoder receives these positionally (graph db, proto db,
    # cache_size). Elsewhere in this file StatementEncoder is called with the
    # keywords graph_db/proto_db/max_encoded_length/max_nodes -- confirm that
    # the third positional parameter really is the cache size.
    positional_args = (populated_graph_db, populated_proto_db, cache_size)
    return graph2seq.StatementEncoder(*positional_args)
def GetEncoder(self) -> graph2seq.EncoderBase:
    """Build the statement-level encoder for this model's graph database.

    Returns:
      A graph2seq.StatementEncoder built from this model's graph database and
      proto database, with max_encoded_length set to
      self.padded_sequence_length and max_nodes set to
      self.padded_node_sequence_length.

    Raises:
      app.UsageError: Unless the graph database has a truthy (non-zero)
        node_y_dimensionality, a node_x_dimensionality of exactly 2, and a
        graph_y_dimensionality of 0.
    """
    db = self.graph_db
    # Checks short-circuit in the same order as before: node_y, node_x, graph_y.
    unsupported = (
        not db.node_y_dimensionality
        or db.node_x_dimensionality != 2
        or db.graph_y_dimensionality != 0
    )
    if unsupported:
        raise app.UsageError(
            f"Unsupported graph dimensionalities: {self.graph_db}"
        )

    return graph2seq.StatementEncoder(
        graph_db=db,
        proto_db=self._proto_db,
        max_encoded_length=self.padded_sequence_length,
        max_nodes=self.padded_node_sequence_length,
    )