def test_train_val_test_split_counts(
  populated_db: ir_database.Database,
  train_val_test_splitter: split.Splitter,
):
  """Check that the train/val/test splitter yields exactly three splits."""
  split_count = len(train_val_test_splitter.Split(populated_db))
  assert split_count == 3
def ApplySplit(
  ir_db: ir_database.Database,
  proto_db: unlabelled_graph_database.Database,
  splitter: ir_split.Splitter,
):
  """Partition the IR database and record the partition on the graph database.

  Every program graph first has its `split` column reset to NULL. The
  splitter is then run over the IR database, and each group of IR IDs it
  returns is stamped with that group's position in the returned sequence.

  Args:
    ir_db: The IR database handed to the splitter.
    proto_db: The unlabelled graph database whose `split` column is rewritten.
    splitter: The object whose `Split()` method produces the groups of IR IDs.
  """
  graph_table = unlabelled_graph_database.ProgramGraph

  # Start from a clean slate so that graphs absent from the new split do not
  # keep a stale assignment.
  with prof.Profile(f"Unset splits on {proto_db.proto_count} protos"):
    clear_all = sql.update(graph_table).values(split=None)
    proto_db.engine.execute(clear_all)

  # One UPDATE per split: the enumerate() position becomes the split value of
  # every graph whose IR ID is in the corresponding group.
  for split, ir_ids in enumerate(splitter.Split(ir_db)):
    with prof.Profile(
      f"Set {split} split on {humanize.Plural(len(ir_ids), 'IR ID')}"
    ):
      in_this_split = graph_table.ir_id.in_(ir_ids)
      assign = sql.update(graph_table).where(in_this_split).values(split=split)
      proto_db.engine.execute(assign)
def test_unique_irs(
  populated_db: ir_database.Database, splitter: split.Splitter
):
  """Check that no IR ID appears more than once across all splits."""
  merged_ids = np.concatenate(splitter.Split(populated_db))
  distinct_ids = set(merged_ids)
  assert len(distinct_ids) == len(merged_ids)