Exemplo n.º 1
0
def cluster_datasets_rsa(args) -> "pd.DataFrame":
    """Cluster PanDDA datasets by electron density after real-space sampling.

    Pipeline:

    1. Load the datasets found under ``args.data_dirs`` and drop some via
       ``truncate_dataset_on_resolution``.
    2. Load every xmap with ``Loader(min_res=get_min_res(dataset))``.
    3. Align each model onto the first dataset's model.
    4. Resample each xmap with the resulting rotation/translation onto a grid
       derived from the reference cell.
    5. Rescale the maps, image-align them to the first sampled map, then
       embed and cluster them.
    6. Write one mean map per cluster and ``labelled_embedding.csv`` to
       ``args.out_dir``.

    Args:
        args: Parsed command-line namespace. Reads ``processor`` ("joblib"
            or "map"), ``n_procs`` (joblib only), ``data_dirs`` and
            ``out_dir``.

    Returns:
        The cluster dataframe produced by ``get_cluster_df``.

    Raises:
        ValueError: If ``args.processor`` is not recognised, or if no
            datasets remain after resolution truncation.
    """
    # Get executor
    if args.processor == "joblib":
        set_loky_pickler("pickle")
        executor = ExecutorJoblib(int(args.n_procs))
    elif args.processor == "map":
        executor = ExecutorMap()
    else:
        # Previously fell through and failed later with UnboundLocalError.
        raise ValueError(
            "Unknown processor {!r}; expected 'joblib' or 'map'".format(
                args.processor))
    # Serial executor used for the model alignments and the mean maps.
    executor_serial = ExecutorMap()

    print("defined executor")

    # Get Dataset
    print("Loading datasets")
    dataset = MultiCrystalDataset.mcd_from_pandda_input_dir(
        p.Path(args.data_dirs))
    print("got datasets")

    print("\tBefore tuncation on res there are {} datasets".format(
        len(dataset.datasets)))
    dataset = truncate_dataset_on_resolution(dataset)
    print("\tAfter truncationt on res here are {} datasets".format(
        len(dataset.datasets)))
    if not dataset.datasets:
        raise ValueError("No datasets remain after resolution truncation")

    # Resolution handed to the loader for every xmap.
    min_res = get_min_res(dataset)

    # Load xmaps
    print("Getting exmpas")
    xmaps = mapdict(
        Loader(min_res=min_res),
        dataset.datasets,
        executor,
    )

    # The first dataset defines the reference frame.
    reference_dtag = next(iter(dataset.datasets))
    reference_model = dataset.datasets[reference_dtag].structure.structure

    # Align every model onto the reference model.
    structure_aligners = {
        dtag: StructureAligner(reference_model, d.structure.structure)
        for dtag, d in dataset.datasets.items()
    }
    alignments = mapdict(
        wrap_call,
        structure_aligners,
        executor_serial,
    )

    # Sample the xmaps onto a common grid. grid_params are the reference cell
    # edge lengths truncated to int; they are passed to every Sampler and
    # reused below as the edges of the output cell.
    reference_cell = xmaps[reference_dtag].xmap.cell
    grid_params = (
        int(reference_cell.a),
        int(reference_cell.b),
        int(reference_cell.c),
    )
    print("Grid params are: {}".format(grid_params))
    samplers = {}
    for dtag, xmap in xmaps.items():
        rtop = alignments[dtag]
        samplers[dtag] = Sampler(
            xmap,
            grid_params,
            rtop[1],  # mobile -> reference translation
            rtop[0],  # mobile -> reference rotation
        )

    nxmaps = mapdict(
        wrap_call,
        samplers,
        executor,
    )

    # Convert nxmaps to np
    print("Converting xmaps to np")
    xmaps_np = mapdict(
        nxmap_to_numpy,
        nxmaps,
        executor,
    )

    # Rescale
    xmaps_np = mapdict(
        Rescaler(),
        xmaps_np,
        executor,
    )
    postcondtion_scaling(xmaps_np)

    # Align every sampled map onto the first one.
    print("aligning xmaps")
    static_image = next(iter(xmaps_np.values()))
    image_aligners = {
        dtag: ImageAligner(static_image, xmap_np)
        for dtag, xmap_np in xmaps_np.items()
    }
    xmaps_aligned = mapdict(
        wrap_call,
        image_aligners,
        executor,
    )

    # Embed the xmaps into a latent space.
    # NOTE(review): the embedding uses the rescaled but *unaligned* maps,
    # while the mean maps below are built from ``xmaps_aligned``. Confirm
    # that this is intended.
    print("Dimension reducing")
    xmap_embedding = embed_xmaps(xmaps_np.values())

    # Cluster the embedding
    print("clustering")
    clustered_xmaps = cluster_embedding(xmap_embedding)

    # Make dataframe with cluster and position
    print("getting cluster dataframe")
    cluster_df = get_cluster_df(
        xmaps_np,
        xmap_embedding,
        clustered_xmaps,
    )
    cluster_df_summary(cluster_df)

    # Get clusters
    print("associating xmaps with clusters")
    map_clusters = get_map_clusters(
        cluster_df,
        xmaps_aligned,
    )

    # Make mean maps
    mean_maps_np = mapdict(
        make_mean_map,
        map_clusters,
        executor_serial,
    )

    # Output the mean maps in a cell whose edges are grid_params and whose
    # angles are all pi / 2.
    print("Outputting mean maps")
    cell = clipper_python.Cell(
        clipper_python.Cell_descr(
            grid_params[0],
            grid_params[1],
            grid_params[2],
            np.pi / 2,
            np.pi / 2,
            np.pi / 2,
        ))
    out_dir = p.Path(args.out_dir)
    for cluster_num, mean_map_np in mean_maps_np.items():
        output_mean_nxmap(
            mean_map_np,
            cell,
            out_dir / "{}.ccp4".format(cluster_num),
            grid_params,
        )

    # Output the csv
    print("Outputting csv")
    output_labeled_embedding_csv(
        cluster_df,
        str(out_dir / "labelled_embedding.csv"),
    )

    return cluster_df
Exemplo n.º 2
0
def cluster_datasets_tm(args) -> None:
    """Centre the first dataset's model, align every model onto it, save them.

    This is a work-in-progress variant of ``cluster_datasets``. It:

    - loads the datasets under ``args.data_dirs`` and truncates them on
      resolution;
    - translates the first dataset's model so that its mean atom coordinate
      is at the origin, writing ``out_before.pdb`` and ``out.pdb`` to the
      working directory;
    - aligns every model onto it with ``align``;
    - writes each aligned model to ``args.out_dir / dtag``.

    It then terminates the process with ``sys.exit()``. The map-clustering
    code after that call is unreachable.

    Args:
        args: Parsed command-line namespace. Reads ``processor``,
            ``data_dirs`` and ``out_dir``.

    Raises:
        ValueError: If ``args.processor`` is not "joblib" or "map".
        SystemExit: Always, once the aligned structures have been written.
    """
    import functools
    import sys

    # Get executor
    if args.processor == "joblib":
        set_loky_pickler("pickle")
        # NOTE(review): worker count is hard-coded here, whereas
        # cluster_datasets_rsa reads args.n_procs.
        executor = ExecutorJoblib(20)
    elif args.processor == "map":
        executor = ExecutorMap()
    else:
        raise ValueError(
            "Unknown processor {!r}; expected 'joblib' or 'map'".format(
                args.processor))
    executor_map = ExecutorMap()

    print("defined executor")

    # Get Dataset
    print("Loading datasets")
    dataset = MultiCrystalDataset.mcd_from_pandda_input_dir(
        p.Path(args.data_dirs))
    print("got datasets")

    print("\tBefore tuncation on res there are {} datasets".format(
        len(dataset.datasets)))
    dataset = truncate_dataset_on_resolution(dataset)
    print("\tAfter truncationt on res here are {} datasets".format(
        len(dataset.datasets)))

    # The first dataset's model is the reference. It is the same object as
    # that dataset's structure, so the centring below modifies it in place.
    reference_model = next(iter(dataset.datasets.values())).structure.structure
    io = PDBIO()
    io.set_structure(reference_model)
    io.save('out_before.pdb')
    atoms_list = np.array(
        [atom.get_coord() for atom in reference_model.get_atoms()])
    print(atoms_list)
    mean_coords = np.mean(atoms_list, axis=0)
    print(mean_coords)
    # Identity rotation plus a translation by -mean: moves the model's mean
    # atom coordinate to the origin.
    rotation = np.eye(3)
    translation = np.array(mean_coords, 'f')
    for atom in reference_model.get_atoms():
        atom.transform(rotation, -translation)

    io = PDBIO()
    io.set_structure(reference_model)
    io.save('out.pdb')

    # Get model alignments. functools.partial binds each dataset's structure
    # now. The previous lambdas closed over the comprehension variable ``d``,
    # so when called they all aligned the *last* dataset.
    aligners = {
        dtag: functools.partial(
            align,
            ref_structure=reference_model,
            mobile_structure=d.structure.structure,
        )
        for dtag, d in dataset.datasets.items()
    }
    print(aligners)
    alignments = mapdict(
        wrap_call,
        aligners,
        executor_map,
    )

    # Apply an alignment to a structure in place and save it as PDB.
    def align_structures(alignment, structure, output_path):
        alignment.apply(structure)
        io = PDBIO()
        io.set_structure(structure)
        print("Saving to {}".format(output_path))
        io.save(str(output_path))
        return structure

    # Bind dtag, structure and path eagerly, for the same late-binding reason
    # as above.
    out_dir = p.Path(args.out_dir)
    structure_aligners = {
        dtag: functools.partial(
            align_structures,
            alignments[dtag],
            d.structure.structure,
            out_dir / dtag,
        )
        for dtag, d in dataset.datasets.items()
    }
    print(structure_aligners)
    print("aligning and outputting structures")
    mapdict(
        wrap_call,
        structure_aligners,
        executor_map,
    )

    # ``exit()`` is only injected by the ``site`` module; sys.exit() is always
    # available.
    sys.exit()

    # NOTE(review): everything below is unreachable. If it is re-enabled it
    # fails immediately, because ``xmaps`` is never defined in this function
    # (its loader is commented out upstream). Finish it or delete it.

    # Convert xmaps to np
    print("Converting xmaps to np")
    static_structure = dataset.datasets[list(
        dataset.datasets.keys())[0]].structure.structure
    aligners = {}
    for dtag, xmap in xmaps.items():
        aligners[dtag] = StructureAligner(
            static_structure,
            dataset.datasets[dtag].structure.structure,
            xmap,
        )

    xmaps_aligned = mapdict(
        wrap_call,
        aligners,
        executor,
    )

    # Rescale
    rescaler = Rescaler()
    xmaps_np = mapdict(
        rescaler,
        xmaps_aligned,
        executor,
    )
    postcondtion_scaling(xmaps_np)

    # Align xmaps
    print("aligning xmaps")

    xmaps_np = filter_on_grid(xmaps_np)

    # Embed the xmaps into a latent space
    print("Dimension reducing")
    xmap_embedding = embed_xmaps(xmaps_np.values())

    # Cluster the embedding
    print("clustering")
    clustered_xmaps = cluster_embedding(xmap_embedding)

    # Make dataframe with cluster and position
    print("getting cluster dataframe")
    cluster_df = get_cluster_df(
        xmaps_np,
        xmap_embedding,
        clustered_xmaps,
    )
    cluster_df_summary(cluster_df)

    # Get clusters
    print("associating xmaps with clusters")
    map_clusters = get_map_clusters(
        cluster_df,
        xmaps_aligned,
    )

    #  Make mean maps
    mean_maps_np = mapdict(
        make_mean_map,
        map_clusters,
        ExecutorMap(),
    )

    # Output the mean maps
    print("Outputting mean maps")
    template_map = xmaps[list(xmaps.keys())[0]]
    for cluster_num, mean_map_np_list in mean_maps_np.items():
        output_mean_map(
            template_map,
            mean_map_np_list,
            p.Path(args.out_dir) / "{}.ccp4".format(cluster_num),
        )

    # Ouptut the csv
    print("Outputting csv")
    output_labeled_embedding_csv(
        cluster_df,
        str(p.Path(args.out_dir) / "labelled_embedding.csv"),
    )

    return cluster_df
Exemplo n.º 3
0
def cluster_datasets(args) -> "pd.DataFrame":
    """Cluster PanDDA datasets by electron density aligned via their models.

    Pipeline:

    1. Load the datasets under ``args.data_dirs`` and truncate them on
       resolution.
    2. Load every xmap on the grid sampling of the first dataset (built at
       ``get_min_res(dataset)``).
    3. Align each xmap with ``StructureAligner`` against the first dataset's
       model.
    4. Rescale and grid-filter the maps, then embed and cluster them.
    5. Write one mean map per cluster and ``labelled_embedding.csv`` to
       ``args.out_dir``.

    Args:
        args: Parsed command-line namespace. Reads ``processor`` ("joblib"
            or "map"), ``data_dirs`` and ``out_dir``.

    Returns:
        The cluster dataframe produced by ``get_cluster_df``.

    Raises:
        ValueError: If ``args.processor`` is not recognised, or if no
            datasets remain after resolution truncation.
    """
    # Get executor
    if args.processor == "joblib":
        set_loky_pickler("pickle")
        # NOTE(review): worker count is hard-coded here, whereas
        # cluster_datasets_rsa reads args.n_procs.
        executor = ExecutorJoblib(20)
    elif args.processor == "map":
        executor = ExecutorMap()
    else:
        # Previously fell through and failed later with UnboundLocalError.
        raise ValueError(
            "Unknown processor {!r}; expected 'joblib' or 'map'".format(
                args.processor))

    print("defined executor")

    # Get Dataset
    print("Loading datasets")
    dataset = MultiCrystalDataset.mcd_from_pandda_input_dir(
        p.Path(args.data_dirs))
    print("got datasets")

    print("\tBefore tuncation on res there are {} datasets".format(
        len(dataset.datasets)))
    dataset = truncate_dataset_on_resolution(dataset)
    print("\tAfter truncationt on res here are {} datasets".format(
        len(dataset.datasets)))
    if not dataset.datasets:
        raise ValueError("No datasets remain after resolution truncation")

    # Resolution used to build the shared grid sampling.
    min_res = get_min_res(dataset)

    # Load every xmap on the first dataset's grid sampling at min_res
    # (presumably so that all maps share grid dimensions).
    print("Getting exmpas")
    reference_dataset = next(iter(dataset.datasets.values()))
    reference_grid = clipper_python.Grid_sampling(
        reference_dataset.reflections.hkl_info.spacegroup,
        reference_dataset.reflections.hkl_info.cell,
        Resolution(min_res),
    )
    grid_params = (
        reference_grid.nu,
        reference_grid.nv,
        reference_grid.nw,
    )
    xmaps = mapdict(
        Loader(
            min_res=None,
            grid_params=grid_params,
        ),
        dataset.datasets,
        executor,
    )

    # Align each xmap using its model against the reference model.
    print("Converting xmaps to np")
    static_structure = reference_dataset.structure.structure
    aligners = {
        dtag: StructureAligner(
            static_structure,
            dataset.datasets[dtag].structure.structure,
            xmap,
        )
        for dtag, xmap in xmaps.items()
    }
    xmaps_aligned = mapdict(
        wrap_call,
        aligners,
        executor,
    )

    # Rescale
    xmaps_np = mapdict(
        Rescaler(),
        xmaps_aligned,
        executor,
    )
    postcondtion_scaling(xmaps_np)

    print("aligning xmaps")
    xmaps_np = filter_on_grid(xmaps_np)

    # Embed the xmaps into a latent space
    print("Dimension reducing")
    xmap_embedding = embed_xmaps(xmaps_np.values())

    # Cluster the embedding
    print("clustering")
    clustered_xmaps = cluster_embedding(xmap_embedding)

    # Make dataframe with cluster and position
    print("getting cluster dataframe")
    cluster_df = get_cluster_df(
        xmaps_np,
        xmap_embedding,
        clustered_xmaps,
    )
    cluster_df_summary(cluster_df)

    # Get clusters. Mean maps are built from the aligned (not rescaled) maps.
    print("associating xmaps with clusters")
    map_clusters = get_map_clusters(
        cluster_df,
        xmaps_aligned,
    )

    # Make mean maps serially
    mean_maps_np = mapdict(
        make_mean_map,
        map_clusters,
        ExecutorMap(),
    )

    # Output the mean maps, using the first loaded xmap as the template.
    print("Outputting mean maps")
    template_map = next(iter(xmaps.values()))
    out_dir = p.Path(args.out_dir)
    for cluster_num, mean_map_np_list in mean_maps_np.items():
        output_mean_map(
            template_map,
            mean_map_np_list,
            out_dir / "{}.ccp4".format(cluster_num),
        )

    # Output the csv
    print("Outputting csv")
    output_labeled_embedding_csv(
        cluster_df,
        str(out_dir / "labelled_embedding.csv"),
    )

    return cluster_df
Exemplo n.º 4
0
def cluster_datasets_luigi(ccp4_map_paths, out_dir, processor,
                           n_procs) -> pd.DataFrame:
    """Cluster pre-computed CCP4 maps and write per-cluster mean maps.

    The maps are loaded, cropped to their common minimum extent, embedded
    and clustered. Per-cluster exemplars (the member with the lowest outlier
    score) are reported. Mean maps and ``labelled_embedding.csv`` are written
    to ``out_dir``.

    Args:
        ccp4_map_paths: Mapping of dtag -> dict whose ``"path"`` entry is
            that dataset's CCP4 map.
        out_dir: Output directory for ``<cluster>.ccp4`` mean maps and the
            labelled-embedding CSV.
        processor: "joblib" (parallel) or "map" (serial).
        n_procs: Worker count for the joblib executor.

    Returns:
        The cluster dataframe produced by ``get_cluster_df``.

    Raises:
        ValueError: If ``ccp4_map_paths`` is empty or ``processor`` is not
            recognised.
    """
    if not ccp4_map_paths:
        raise ValueError("No CCP4 maps given to cluster")

    # Get executor
    if processor == "joblib":
        set_loky_pickler("pickle")
        executor = ExecutorJoblib(int(n_procs))
    elif processor == "map":
        executor = ExecutorMap()
    else:
        # Previously fell through and failed later with UnboundLocalError.
        raise ValueError(
            "Unknown processor {!r}; expected 'joblib' or 'map'".format(
                processor))

    # Load xmaps to np
    print("Getting exmpas")
    loaders = {
        dtag: LoaderCCP4(map_info["path"])
        for dtag, map_info in ccp4_map_paths.items()
    }
    xmaps_np = mapdict(
        wrap_call,
        loaders,
        executor,
    )

    # Crop every map to the smallest extent along each axis, presumably so
    # that embed_xmaps receives equally shaped arrays.
    dims = np.vstack([xmap.shape for xmap in xmaps_np.values()])
    print(dims)
    truncation_shape = np.min(dims, axis=0)
    print("The truncation shape is: {}".format(truncation_shape))
    crop = tuple(slice(0, int(extent)) for extent in truncation_shape)
    xmaps_np = {dtag: xmap_np[crop] for dtag, xmap_np in xmaps_np.items()}

    # Embed the xmaps into a latent space
    print("Dimension reducing")
    xmap_embedding = embed_xmaps(xmaps_np.values())

    # Cluster the embedding
    print("clustering")
    clustering = cluster_embedding(xmap_embedding)

    exemplars = clustering.exemplars_
    clustered_xmaps = clustering.labels_
    outlier_scores = clustering.outlier_scores_
    print("Exemplars: {}".format(exemplars))
    print("Outlier scores: {}".format(outlier_scores))
    print("Labels: {}".format(clustered_xmaps))

    # For every cluster, pick the member with the lowest outlier score.
    # Label -1 is skipped (presumably the noise label of the clusterer).
    dtags = np.array(list(xmaps_np.keys()))
    cluster_exemplars = {}
    for cluster_num in np.unique(clustered_xmaps):
        if cluster_num == -1:
            continue
        in_cluster = clustered_xmaps == cluster_num
        cluster_dtags = dtags[in_cluster]
        cluster_outlier_scores = outlier_scores[in_cluster]
        cluster_exemplars[cluster_num] = cluster_dtags[np.argmin(
            cluster_outlier_scores)]
        print("Outlier scores dict: {}".format(
            dict(zip(list(cluster_dtags), list(cluster_outlier_scores)))))

    print("Cluster exemplars: {}".format(cluster_exemplars))

    # Make dataframe with cluster and position
    print("getting cluster dataframe")
    cluster_df = get_cluster_df(
        xmaps_np,
        xmap_embedding,
        clustered_xmaps,
    )
    cluster_df_summary(cluster_df)

    # Get clusters
    print("associating xmaps with clusters")
    map_clusters = get_map_clusters(
        cluster_df,
        xmaps_np,
    )

    # Make mean maps serially
    mean_maps_np = mapdict(
        make_mean_map,
        map_clusters,
        ExecutorMap(),
    )

    # Output the mean maps, using the first input map as the template.
    print("Outputting mean maps")
    out_path = p.Path(out_dir)
    template_nxmap = load_ccp4_map(next(iter(ccp4_map_paths.values()))["path"])
    for cluster_num, mean_map_np in mean_maps_np.items():
        dataset_clustering.xmap_utils.save_nxmap_from_template(
            template_nxmap,
            mean_map_np,
            out_path / "{}.ccp4".format(cluster_num),
        )

    # Output the csv
    print("Outputting csv")
    output_labeled_embedding_csv(
        cluster_df,
        str(out_path / "labelled_embedding.csv"),
    )

    print("Cluster DF is: \n{}".format(cluster_df))

    return cluster_df