Example #1
import numpy as np
from scipy import stats


def cartesian_gaussian():
    """Generate a 3D Cartesian Gaussian with random mean and covariance."""
    rotation = stats.special_ortho_group(3).rvs()
    eigenvalues = stats.gamma(2).rvs(3)
    cov = rotation @ np.diag(eigenvalues) @ rotation.T
    mean = stats.norm.rvs(size=3)
    return stats.multivariate_normal(mean, cov)
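
A minimal usage sketch (not from the source; the variable names are illustrative): the frozen multivariate_normal returned above supports both sampling and density evaluation.

g = cartesian_gaussian()
samples = g.rvs(size=100)      # (100, 3) random draws
density = g.pdf(np.zeros(3))   # density evaluated at the origin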
Example #2
    def test_frozen_matrix(self):
        dim = 7
        frozen = special_ortho_group(dim)

        rvs1 = frozen.rvs(random_state=1234)
        rvs2 = special_ortho_group.rvs(dim, random_state=1234)

        assert_equal(rvs1, rvs2)
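
A quick property check, assuming only NumPy and SciPy (not part of the original test): every sample from special_ortho_group is an orthogonal matrix with determinant +1.

import numpy as np
from scipy.stats import special_ortho_group

R = special_ortho_group.rvs(3, random_state=0)
assert np.allclose(R @ R.T, np.eye(3))    # R is orthogonal
assert np.isclose(np.linalg.det(R), 1.0)  # proper rotation, det = +1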
Example #3
    def build_moment_index(self, name, mesh, grid_size):
        print(f'Building HuIndex for {name}')
        n_samples = 1000
        for i in range(n_samples):
            print(f'{i}/{n_samples}')
            R = special_ortho_group(3).rvs()

            mask, offset, vt = self.mesh_to_mask(mesh, R, grid_size)
            hu = cv2.HuMoments(cv2.moments(np.float32(mask)))

            lower, upper = hu - abs(hu * 0.3), hu + abs(hu * 0.3)
            self.Index.insert(i, np.concatenate((lower, upper), axis=0), obj=(R, offset, vt))
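
A standalone sketch of the Hu-moment step above, with a hypothetical toy mask standing in for the rasterized mesh (assumes only OpenCV and NumPy):

import cv2
import numpy as np

mask = np.zeros((64, 64), dtype=np.float32)  # toy binary mask
mask[16:48, 24:40] = 1.0

hu = cv2.HuMoments(cv2.moments(mask))                  # (7, 1) Hu moment invariants
lower, upper = hu - abs(hu * 0.3), hu + abs(hu * 0.3)  # the same +/-30% index box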
Example #4
def pair_model(model, FLAGS, node_embed):
    # Generate a dataset where two atoms are very close to each other and everything else is very far
    # Indices for atoms
    atom_names = ["X", "C", "N", "O", "S"]
    residue_names = [
        "ALA",
        "ARG",
        "ASN",
        "ASP",
        "CYS",
        "GLU",
        "GLN",
        "GLY",
        "HIS",
        "ILE",
        "LEU",
        "LYS",
        "MET",
        "PHE",
        "PRO",
        "SER",
        "THR",
        "TRP",
        "TYR",
        "VAL",
    ]

    energies_output_dict = {}

    def make_key(n_rotations, residue_name1, residue_name2, atom_name1,
                 atom_name2):
        return f"{n_rotations}_{residue_name1}_{residue_name2}_{atom_name1}_{atom_name2}"

    # Unpack the node embed and save a copy
    node_embed = node_embed[0]
    node_embed_orig = node_embed.clone()

    # Try different combinations
    for n_rotations in [5]:
        # Rotations
        so3 = special_ortho_group(3)
        rot_matrix_neg = so3.rvs(n_rotations)  # (n_rotations, 3, 3) random rotations to average over

        residue_names_proc = ["ALA", "TYR", "LEU"]
        atom_names_proc = ["C", "N", "O"]
        for residue_name1, residue_name2 in itertools.product(
                residue_names_proc, repeat=2):
            for atom_name1, atom_name2 in itertools.product(atom_names_proc,
                                                            repeat=2):
                eps = []
                energies = []

                residue_index1 = residue_names.index(residue_name1)
                residue_index2 = residue_names.index(residue_name2)
                atom_index1 = atom_names.index(atom_name1)
                atom_index2 = atom_names.index(atom_name2)

                for i in np.linspace(0.1, 1.0, 100):
                    node_embed = node_embed_orig.clone()
                    node_embed[-2, -3:] = torch.Tensor([1.0, 0.5, 0.5])
                    node_embed[-1, -3:] = torch.Tensor([1.0 + i, 0.5, 0.5])
                    node_embed[-1, 0] = residue_index1
                    node_embed[-2, 0] = residue_index2
                    node_embed[-1, 1] = atom_index1
                    node_embed[-2, 1] = atom_index2
                    node_embed[-1, 2] = 6  # res_counter
                    node_embed[-2, 2] = 6  # res_counter

                    node_embed = np.tile(node_embed[None, :, :],
                                         (n_rotations, 1, 1))
                    node_embed[:, :, -3:] = np.matmul(node_embed[:, :, -3:],
                                                      rot_matrix_neg)
                    node_embed_feed = torch.Tensor(node_embed).cuda()
                    node_embed_feed[:, :, -3:] = node_embed_feed[:, :, -3:] - node_embed_feed[:, :, -3:].mean(dim=1, keepdim=True)
                    energy = model.forward(node_embed_feed)
                    energy = energy.mean()

                    eps.append(i * 10)
                    energies.append(energy.item())

                key = make_key(n_rotations, residue_name1, residue_name2,
                               atom_name1, atom_name2)
                energies_output_dict[key] = (eps, energies)

                # Optionally make plots here -- potentially add conditions to avoid making too many
                plt.plot(eps, energies)
                plt.xlabel("Atom Distance")
                plt.ylabel("Energy")
                plt.title(
                    f"{n_rotations} rots: {atom_name1}, {atom_name2} distance in {residue_name1}/{residue_name2}"
                )
                plt.savefig(
                    f"distance_plots/{n_rotations}_{atom_name1}_{atom_name2}_in_{residue_name1}_{residue_name2}_distance.png"
                )
                plt.clf()

    # After all combinations, dump the collected energies
    output_path = osp.join(FLAGS.outdir, "atom_distances.p")
    with open(output_path, "wb") as f:
        pickle.dump(energies_output_dict, f)
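
The core trick in pair_model is approximating rotation invariance by averaging the model's energy over random global rotations. A self-contained NumPy sketch of that pattern (energy_fn is a hypothetical stand-in for the model):

import numpy as np
from scipy.stats import special_ortho_group

def rotation_averaged_energy(coords, energy_fn, n_rotations=5, seed=0):
    """Average a scalar energy over random global rotations of coords with shape (N, 3)."""
    rots = special_ortho_group.rvs(3, size=n_rotations, random_state=seed)  # (n, 3, 3)
    rotated = coords[None, :, :] @ rots              # (n, N, 3), one copy per rotation
    rotated -= rotated.mean(axis=1, keepdims=True)   # center each copy, as pair_model does
    return np.mean([energy_fn(x) for x in rotated])

# e.g. rotation_averaged_energy(np.random.randn(12, 3), lambda x: float((x ** 2).sum()))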
Example #5
def new_model(model, FLAGS, node_embed):
    BATCH_SIZE = 120
    pdb_name = FLAGS.pdb_name  # e.g. '6mdw.0'
    pickle_file = f"/private/home/yilundu/dataset/mmcif/mmCIF/{pdb_name[1:3]}/{pdb_name}.p"
    (node_embed, ) = pickle.load(open(pickle_file, "rb"))
    par, child, pos, pos_exist, res, chis_valid = parse_dense_format(
        node_embed)
    angles = compute_dihedral(par, child, pos, pos_exist)

    chis_target_initial = angles[:, 4:8].copy()  # dihedrals: backbone (:4), sidechain (4:8)

    NUM_RES = len(res)
    all_energies = np.empty((NUM_RES, 4, 360))  # 4 is the number of possible chi angles

    surface_core_type = []
    for idx in range(NUM_RES):
        dist = np.sqrt(np.square(pos[idx:idx + 1, 2] - pos[:, 2]).sum(axis=1))
        neighbors = (dist < 10).sum()
        if neighbors >= 24:
            surface_core_type.append("core")
        elif neighbors <= 16:
            surface_core_type.append("surface")
        else:
            surface_core_type.append("unlabeled")

    for idx in tqdm(range(NUM_RES)):
        for chi_num in range(4):
            if not chis_valid[idx, chi_num]:
                continue

            # init_angle = chis_target[idx, chi_num]
            for angle_deltas in batch(range(-180, 180, 3), BATCH_SIZE):
                pre_rot_node_embed_short = []
                for angle_delta in angle_deltas:
                    chis_target = chis_target_initial.copy()  # local copy per angle_delta

                    # Shift this specific chi angle by angle_delta; pos is then
                    # rotated to match chis_target below
                    chis_target[idx, chi_num] += angle_delta

                    # pos_new is n residues x 20 atoms x 3 (xyz)
                    pos_new = rotate_dihedral_fast(angles, par, child, pos,
                                                   pos_exist, chis_target,
                                                   chis_valid, idx)
                    node_neg_embed = reencode_dense_format(
                        node_embed, pos_new, pos_exist)

                    # sort the atoms by how far away they are
                    # sort key is the first atom on the sidechain
                    pos_chosen = pos_new[idx, 4]
                    close_idx = np.argsort(
                        np.square(node_neg_embed[:, -3:] -
                                  pos_chosen).sum(axis=1))

                    # Grab the 64 closest atoms
                    node_embed_short = node_neg_embed[
                        close_idx[:FLAGS.max_size]].copy()

                    # Normalize each coordinate of node_embed to have x, y, z coordinate to be equal 0
                    node_embed_short[:,
                                     -3:] = node_embed_short[:, -3:] - np.mean(
                                         node_embed_short[:, -3:], axis=0)
                    node_embed_short = torch.from_numpy(
                        node_embed_short).float().cuda()
                    pre_rot_node_embed_short.append(
                        node_embed_short.unsqueeze(0))
                pre_rot_node_embed_short = torch.stack(
                    pre_rot_node_embed_short)

                # Now rotate all elements
                n_rotations = 100
                so3 = special_ortho_group(3)
                rot_matrix = so3.rvs(n_rotations)  # n x 3 x 3
                node_embed_short = pre_rot_node_embed_short.repeat(
                    1, n_rotations, 1, 1)
                rot_matrix = torch.from_numpy(rot_matrix).float().cuda()
                node_embed_short[:, :, :, -3:] = torch.matmul(
                    node_embed_short[:, :, :, -3:],
                    rot_matrix)  # (batch_size, n_rotations, 64, 20)

                # Compute the energies for the n_rotations * batch_size copies of this 64-atom window:
                # flatten the first two dimensions, then pull them apart afterwards.
                node_embed_short = node_embed_short.reshape(
                    node_embed_short.shape[0] * node_embed_short.shape[1],
                    *node_embed_short.shape[2:],
                )
                energies = model.forward(node_embed_short)  # (12000, 1)

                # Unflatten: one row per batch entry, one column per rotation
                energies = energies.reshape(BATCH_SIZE, -1)  # (BATCH_SIZE, n_rotations)

                # Average the energy across the n_rotations, keeping batch entries separate
                energies = energies.mean(1)  # (BATCH_SIZE,)

                # Save the result; negative deltas index from the end of the 360-slot axis
                all_energies[idx, chi_num, angle_deltas] = energies.cpu().numpy()

    # Can use these for processing later.
    # Normalize by how many times each chi angle occurs
    avg_chi_angle_energy = (all_energies * chis_valid[:NUM_RES, :4, None]).sum(0) / np.expand_dims(chis_valid[:NUM_RES, :4].sum(0), 1)
    output = {
        "all_energies": all_energies,
        "chis_valid": chis_valid,
        "chis_target_initial": chis_target_initial,
        "avg_chi_angle_energy":
        avg_chi_angle_energy,  # make four plots from this (4, 360),
        "res": res,
        "surface_core_type": surface_core_type,
    }
    # Dump the output
    if not osp.exists(FLAGS.outdir):
        os.makedirs(FLAGS.outdir)
    output_path = osp.join(FLAGS.outdir, f"{pdb_name}_rot_energies.p")
    with open(output_path, "wb") as f:
        pickle.dump(output, f)
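
The final normalization above is a masked mean over residues. A tiny NumPy sketch with hypothetical shapes (3 residues, 4 chi angles, 360 angle deltas):

import numpy as np

all_energies = np.random.randn(3, 4, 360)
chis_valid = np.array([[1, 1, 1, 1],
                       [1, 1, 0, 0],
                       [1, 0, 0, 0]], dtype=float)

# Sum only over residues where the chi angle exists, then divide by the count
avg = (all_energies * chis_valid[:, :, None]).sum(0) / chis_valid.sum(0)[:, None]  # (4, 360)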
Example #6
def make_tsne(model, FLAGS, node_embed):
    """
    grab representations for each of the residues in a pdb
    """
    pdb_name = FLAGS.pdb_name
    pickle_file = MMCIF_PATH + f"/mmCIF/{pdb_name[1:3]}/{pdb_name}.p"
    (node_embed, ) = pickle.load(open(pickle_file, "rb"))
    par, child, pos, pos_exist, res, chis_valid = parse_dense_format(
        node_embed)
    angles = compute_dihedral(par, child, pos, pos_exist)

    all_hiddens = []
    all_energies = []

    n_rotations = 2
    so3 = special_ortho_group(3)
    rot_matrix = so3.rvs(n_rotations)
    rot_matrix = torch.from_numpy(rot_matrix).float().cuda()

    for idx in range(len(res)):
        # sort the atoms by how far away they are
        # sort key is the first atom on the sidechain
        pos_chosen = pos[idx, 4]
        close_idx = np.argsort(
            np.square(node_embed[:, -3:] - pos_chosen).sum(axis=1))

        # Grab the 64 closest atoms
        node_embed_short = node_embed[close_idx[:FLAGS.max_size]].copy()

        # Normalize each coordinate of node_embed to have x, y, z coordinate to be equal 0
        node_embed_short[:, -3:] = node_embed_short[:, -3:] - np.mean(
            node_embed_short[:, -3:], axis=0)
        node_embed_short = torch.from_numpy(node_embed_short).float().cuda()

        node_embed_short = node_embed_short[None, :, :].repeat(
            n_rotations, 1, 1)
        node_embed_short[:, :, -3:] = torch.matmul(node_embed_short[:, :, -3:],
                                                   rot_matrix)

        # All n_rotations copies of this 64-atom window are fed to the model
        # at once, so no flattening is needed here.

        energies, hidden = model.forward(node_embed_short, return_hidden=True)  # (n_rotations, 1)

        # all_hiddens.append(hidden.mean(0)) # mean over the rotations
        all_hiddens.append(hidden[0])  # take first rotation
        all_energies.append(energies[0])

    surface_core_type = []
    for idx in range(len(res)):
        # >16 c-beta neighbors within 10A of its own c-beta (or c-alpha for gly).
        hacked_pos = np.copy(pos)
        swap_hacked_pos = np.swapaxes(hacked_pos, 0, 1)  # (20, 59, 3)
        idxs_to_change = swap_hacked_pos[4] == [0, 0, 0]  # (59, 3)
        swap_hacked_pos[4][idxs_to_change] = swap_hacked_pos[3][idxs_to_change]
        hacked_pos_final = np.swapaxes(swap_hacked_pos, 0, 1)

        dist = np.sqrt(
            np.square(hacked_pos_final[idx:idx + 1, 4] -
                      hacked_pos_final[:, 4]).sum(axis=1))
        neighbors = (dist < 10).sum()

        if neighbors >= 24:
            surface_core_type.append("core")
        elif neighbors <= 16:
            surface_core_type.append("surface")
        else:
            surface_core_type.append("unlabeled")

    output = {
        "res": res,
        "surface_core_type": surface_core_type,
        "all_hiddens": torch.stack(all_hiddens).cpu().numpy(),
        "all_energies": torch.stack(all_energies).cpu().numpy(),
    }
    # Dump the output
    if not osp.exists(FLAGS.outdir):
        os.makedirs(FLAGS.outdir)
    output_path = osp.join(FLAGS.outdir, f"{pdb_name}_representations.p")
    with open(output_path, "wb") as f:
        pickle.dump(output, f)
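
A possible downstream step, not in the source: run t-SNE on the dumped hidden vectors (assumes scikit-learn is available; the pickle path is hypothetical).

import pickle
from sklearn.manifold import TSNE

with open("outdir/1abc.0_representations.p", "rb") as f:  # hypothetical path
    rep = pickle.load(f)

# Small perplexity, since a single PDB yields relatively few residues
emb_2d = TSNE(n_components=2, perplexity=5).fit_transform(rep["all_hiddens"])  # (n_residues, 2)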
Example #7
def rotamer_trials(model, FLAGS, test_dataset):
    test_files = test_dataset.files
    random.shuffle(test_files)
    db = load_rotamor_library()
    so3 = special_ortho_group(3)

    node_embed_evals = []
    nminibatch = 4

    if FLAGS.ensemble > 1:
        models = model

    # The three different sampling methods are weighted_gauss, gmm, rosetta
    rotamer_scores_total = []
    surface_scores_total = []
    buried_scores_total = []
    amino_recovery_total = {}
    for k, v in kvs.items():
        amino_recovery_total[k.lower()] = []

    counter = 0
    rotations = FLAGS.rotations

    for test_file in tqdm(test_files):
        (node_embed, ) = pickle.load(open(test_file, "rb"))
        if node_embed is None:
            continue

        node_embed_original = node_embed
        par, child, pos, pos_exist, res, chis_valid = parse_dense_format(node_embed)
        angles = compute_dihedral(par, child, pos, pos_exist)

        amino_recovery = {}
        for k, v in kvs.items():
            amino_recovery[k.lower()] = []

        rotamer_scores = []
        surface_scores = []
        buried_scores = []
        types = []

        gt_chis = []
        node_embed_evals = []
        neg_chis = []
        valid_chi_idxs = []
        res_names = []

        neg_sample = FLAGS.neg_sample

        n_amino = pos.shape[0]
        for idx in range(1, n_amino - 1):
            res_name = res[idx]
            if res_name == "gly" or res_name == "ala":
                continue

            res_names.append(res_name)

            gt_chis.append(angles[idx, 4:8])
            valid_chi_idxs.append(chis_valid[idx, :4])

            hacked_pos = np.copy(pos)
            swap_hacked_pos = np.swapaxes(hacked_pos, 0, 1)  # (20, 59, 3)
            idxs_to_change = swap_hacked_pos[4] == [0, 0, 0]  # (59, 3)
            swap_hacked_pos[4][idxs_to_change] = swap_hacked_pos[3][
                idxs_to_change]
            hacked_pos_final = np.swapaxes(swap_hacked_pos, 0, 1)

            neighbors = np.linalg.norm(
                pos[idx:idx + 1, 4] - hacked_pos_final[:, 4], axis=1) < 10
            neighbors = neighbors.astype(np.int32).sum()

            if neighbors >= 24:
                types.append("buried")
            elif neighbors < 16:
                types.append("surface")
            else:
                types.append("neutral")

            if neighbors >= 24:
                thresh = 0.98
            else:
                thresh = 0.95

            if FLAGS.sample_mode == "weighted_gauss":
                chis_list = interpolated_sample_normal(db,
                                                       angles[idx, 1],
                                                       angles[idx, 2],
                                                       res[idx],
                                                       neg_sample,
                                                       uniform=False)
            elif FLAGS.sample_mode == "gmm":
                chis_list = mixture_sample_normal(db,
                                                  angles[idx, 1],
                                                  angles[idx, 2],
                                                  res[idx],
                                                  neg_sample,
                                                  uniform=False)
            elif FLAGS.sample_mode == "rosetta":
                chis_list = exhaustive_sample(db,
                                              angles[idx, 1],
                                              angles[idx, 2],
                                              res[idx],
                                              tresh=thresh)

            neg_chis.append(chis_list)

            node_neg_embeds = []
            length_chis = len(chis_list)
            for i in range(neg_sample):
                chis_target = angles[:, 4:8].copy()

                if i >= len(chis_list):
                    # Fewer candidate chis than neg_sample: recycle earlier samples
                    node_neg_embeds.append(node_neg_embeds[i % length_chis])
                    neg_chis[-1].append(chis_list[i % length_chis])
                    continue

                chis = chis_list[i]

                chis_target[idx] = (
                    chis * chis_valid[idx, :4] +
                    (1 - chis_valid[idx, :4]) * chis_target[idx])
                pos_new = rotate_dihedral_fast(angles, par, child, pos,
                                               pos_exist, chis_target,
                                               chis_valid, idx)

                node_neg_embed = reencode_dense_format(node_embed, pos_new,
                                                       pos_exist)
                node_neg_embeds.append(node_neg_embed)

            node_neg_embeds = np.array(node_neg_embeds)
            dist = np.square(node_neg_embeds[:, :, -3:] -
                             pos[idx:idx + 1, 4:5, :]).sum(axis=2)
            close_idx = np.argsort(dist)
            # Keep the 64 atoms closest to the sidechain anchor
            node_neg_embeds = np.take_along_axis(node_neg_embeds,
                                                 close_idx[:, :64, None],
                                                 axis=1)
            node_neg_embeds[:, :, -3:] = node_neg_embeds[:, :, -3:] / 10.0
            node_neg_embeds[:, :, -3:] = node_neg_embeds[:, :, -3:] - np.mean(
                node_neg_embeds[:, :, -3:], axis=1, keepdims=True)

            node_embed_evals.append(node_neg_embeds)

            if len(node_embed_evals) == nminibatch or idx == (n_amino - 2):
                n_entries = len(node_embed_evals)
                node_embed_evals = np.concatenate(node_embed_evals)
                s = node_embed_evals.shape

                # For sample rotations per batch
                node_embed_evals = np.tile(node_embed_evals[:, None, :, :],
                                           (1, rotations, 1, 1))
                rot_matrix = so3.rvs(rotations)

                if rotations == 1:
                    rot_matrix = rot_matrix[None, :, :]

                node_embed_evals[:, :, :, -3:] = np.matmul(
                    node_embed_evals[:, :, :, -3:], rot_matrix[None, :, :, :])
                node_embed_evals = node_embed_evals.reshape((-1, *s[1:]))

                node_embed_feed = torch.from_numpy(
                    node_embed_evals).float().cuda()

                with torch.no_grad():
                    energy = 0
                    if FLAGS.ensemble > 1:
                        for model in models:
                            energy_tmp = model.forward(node_embed_feed)
                            energy = energy + energy_tmp
                    else:
                        energy = model.forward(node_embed_feed)

                energy = energy.view(n_entries, -1, rotations).mean(dim=2)
                select_idx = torch.argmin(energy, dim=1).cpu().numpy()

                for i in range(n_entries):
                    select_idx_i = select_idx[i]
                    valid_chi_idx = valid_chi_idxs[i]
                    rotamer_score, _ = compute_rotamer_score_planar(
                        gt_chis[i], neg_chis[i][select_idx_i],
                        valid_chi_idx[:4], res_names[i])
                    rotamer_scores.append(rotamer_score)

                    amino_recovery[str(res_names[i])] = amino_recovery[str(
                        res_names[i])] + [rotamer_score]

                    if types[i] == "buried":
                        buried_scores.append(rotamer_score)
                    elif types[i] == "surface":
                        surface_scores.append(rotamer_score)

                gt_chis = []
                node_embed_evals = []
                neg_chis = []
                valid_chi_idxs = []
                res_names = []
                types = []

            counter += 1

        rotamer_scores_total.append(np.mean(rotamer_scores))

        if len(buried_scores) > 0:
            buried_scores_total.append(np.mean(buried_scores))
        if len(surface_scores) > 0:
            surface_scores_total.append(np.mean(surface_scores))

        for k, v in amino_recovery.items():
            if len(v) > 0:
                amino_recovery_total[k].append(np.mean(v))

        print(
            "Obtained a rotamer recovery score of ",
            np.mean(rotamer_scores_total),
            np.std(rotamer_scores_total) / len(rotamer_scores_total)**0.5,
        )
        print(
            "Obtained a buried recovery score of ",
            np.mean(buried_scores_total),
            np.std(buried_scores_total) / len(buried_scores_total)**0.5,
        )
        print(
            "Obtained a surface recovery score of ",
            np.mean(surface_scores_total),
            np.std(surface_scores_total) / len(surface_scores_total)**0.5,
        )
        for k, v in amino_recovery_total.items():
            print(
                "per amino acid recovery of {} score of ".format(k),
                np.mean(v),
                np.std(v) / len(v)**0.5,
            )
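
The printed statistics are the mean and the standard error of the mean. A small helper sketch (not in the source) that matches the arithmetic in the prints above:

import numpy as np

def mean_and_stderr(scores):
    """Mean and standard error of the mean of a list of scores."""
    scores = np.asarray(scores, dtype=float)
    return scores.mean(), scores.std() / len(scores) ** 0.5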
Example #8
    def __init__(
        self,
        FLAGS,
        mmcif_path="./mmcif",
        split="train",
        rank_idx=0,
        world_size=1,
        uniform=True,
        weighted_gauss=False,
        gmm=False,
        chi_mean=False,
        valid=False,
    ):
        files = []
        dirs = os.listdir(osp.join(mmcif_path, "mmCIF"))

        self.split = split
        self.so3 = special_ortho_group(3)
        self.chi_mean = chi_mean
        self.weighted_gauss = weighted_gauss
        self.gmm = gmm
        self.uniform = uniform

        # Filter out proteins in test dataset
        for d in tqdm(dirs):
            directory = osp.join(mmcif_path, "mmCIF", d)
            d_files = os.listdir(directory)
            files_tmp = [osp.join(directory, d_file) for d_file in d_files if ".p" in d_file]

            for f in files_tmp:
                name = f.split("/")[-1]
                name = name.split(".")[0]
                if name in test_rotamers and self.split == "test":
                    files.append(f)
                elif name not in test_rotamers and self.split in ["train", "val"]:
                    files.append(f)

        self.files = files

        if split in ["train", "val"]:
            duplicate_seqs = set()

            # Remove proteins in the train dataset that are too similar to the test dataset
            with open(osp.join(mmcif_path, "duplicate_sequences.txt"), "r") as f:
                for line in f:
                    duplicate_seqs.add(line.strip())

            fids = set()

            # Remove low resolution proteins
            with open(
                osp.join(mmcif_path, "cullpdb_pc90_res1.8_R0.25_d190807_chains14857"), "r"
            ) as f:
                i = 0
                for line in f:
                    if i != 0:
                        fid = line.split()[0]
                        if fid not in duplicate_seqs:
                            fids.add(fid)

                    i += 1

            files_new = []

            alphabet = [chr(letter) for letter in range(65, 91)]  # chain letters 'A'..'Z'

            for f in files:
                tup = (f.split("/")[-1]).split(".")

                if int(tup[1]) >= len(alphabet):
                    continue

                seq_id = tup[0].upper() + alphabet[int(tup[1])]

                if seq_id in fids:
                    files_new.append(f)

            self.files = files_new
        elif split == "test":
            fids = set()

            # Remove low resolution proteins
            with open(
                osp.join(mmcif_path, "cullpdb_pc90_res1.8_R0.25_d190807_chains14857"), "r"
            ) as f:
                i = 0
                for line in f:
                    if i != 0:
                        fid = line.split()[0]
                        fids.add(fid)

                    i += 1

            files_new = []

            alphabet = [chr(letter) for letter in range(65, 91)]  # chain letters 'A'..'Z'

            for f in files:
                tup = (f.split("/")[-1]).split(".")

                if int(tup[1]) >= len(alphabet):
                    continue

                seq_id = tup[0].upper() + alphabet[int(tup[1])]

                if seq_id in fids:
                    files_new.append(f)

            self.files = files_new

        chunksize = len(self.files) // world_size

        n = len(self.files)

        # Set up a validation dataset: hold out the last 5% of files
        if split == "train":
            self.files = self.files[: int(0.95 * n)]
        elif split == "val":
            self.files = self.files[int(0.95 * n) :]

        self.FLAGS = FLAGS
        self.db = load_rotamor_library()
        print(f"Loaded {len(self.files)} files for {split} dataset split")

        self.split = split
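
For reference, the sequence-id construction used in the split filtering above, traced on a hypothetical file name:

alphabet = [chr(c) for c in range(65, 91)]       # 'A'..'Z'; chain index -> letter
fname = "/mmcif/mmCIF/ab/1abc.0.p"               # hypothetical path
tup = fname.split("/")[-1].split(".")            # ['1abc', '0', 'p']
seq_id = tup[0].upper() + alphabet[int(tup[1])]  # '1ABCA'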