Esempio n. 1
0
def get_chembl(n_mols=None, as_mols=True, option='', max_size=1000):
    """
    Return a list of ChEMBL molecules (or SMILES strings).
    NOTE: this function should be located
    in the same directory as data files.

    Arguments:
        n_mols {int or None} -- read only the first ``n_mols`` lines of
            ChEMBL.txt (fewer if the file is shorter); None reads all lines.
        as_mols {bool} -- wrap each SMILES into ``Molecule`` if True,
            otherwise return the stripped SMILES strings.
        option {str} -- '' (no filter), 'small_qed' (QED < 0.6) or
            'large_qed' (QED >= 0.6). Only applied when subsampling happens;
            an unknown option raises ValueError in that case.
        max_size {int or None} -- if fewer than ``max_size`` molecules were
            read (or max_size is None / -1), everything is returned unfiltered;
            otherwise exactly ``max_size`` molecules are drawn without
            replacement with a fixed seed (42) and then filtered by ``option``.
    """
    from itertools import islice

    path = os.path.join(__location__, "ChEMBL.txt")
    with open(path, "r") as f:
        # islice stops at EOF, whereas repeated readline() keeps returning ''
        # (which would become empty, invalid molecules).
        lines = f if n_mols is None else islice(f, n_mols)
        res = [line.strip() for line in lines]
    mols = [Molecule(smile) for smile in res] if as_mols else res
    if max_size is None or max_size == -1 or len(mols) < max_size:
        return mols

    gen = np.random.RandomState(42)
    # Draw indices rather than passing the objects to numpy: the permutation
    # is the same as gen.choice(mols, ...), but the list elements are never
    # coerced into a numpy array (strings would come back as np.str_).
    picked = gen.choice(len(mols), max_size, replace=False)
    mols = [mols[i] for i in picked]
    if option == '':
        return mols
    if option in ('small_qed', 'large_qed'):
        qed_func = get_objective_by_name("qed")

        def qed(item):
            # The objective is applied to Molecule objects; wrap raw SMILES.
            return qed_func(item if as_mols else Molecule(item))

        if option == 'small_qed':
            return [mol for mol in mols if qed(mol) < 0.6]
        return [mol for mol in mols if qed(mol) >= 0.6]
    raise ValueError(f"Dataset filter {option} not supported.")
def compute_sa_score_datasets():
    """
    Print mean and standard deviation of the "sascore" objective over
    50-molecule samples of ChEMBL and of ZINC250k.
    """
    sa_of = get_objective_by_name("sascore")
    # Loaders are wrapped in lambdas so each dataset is only looked up and
    # loaded after the previous one has been reported.
    reports = (
        ("ChEMBL: {:.3f} +- std {:.3f}", lambda: get_chembl(max_size=50)),
        ("ZINC: {:.3f} +- std {:.3f}", lambda: get_zinc250(max_size=50)),
    )
    for template, load in reports:
        scores = list(map(sa_of, load()))
        print(template.format(np.mean(scores), np.std(scores)))
def parse_min_synthesizability(exp_path):
    """
    Read the minimum SA score along the synthesis path from an experiment log.

    Scans ``<exp_path>/exp_log`` for lines containing
    'Minimum synthesis score over the path' and parses the last
    whitespace-separated token of the LAST such line as a float.

    Arguments:
        exp_path {str} -- experiment directory containing ``exp_log``.
    Returns:
        float or None -- the parsed score, or None if no such line exists.
    """
    sa_score = None
    with open(os.path.join(exp_path, 'exp_log'), 'r') as f:
        for line in f:
            if 'Minimum synthesis score over the path' in line:
                sa_score = float(line.split()[-1])
    # Returned as-is: a legitimate score of 0.0 must not be turned into None.
    return sa_score
Esempio n. 4
0
def get_chembl(option='', max_size=1000, as_mols=True):
    """
    Return list of Molecules (or SMILES strings if ``as_mols`` is False).
    NOTE: this function should be located
    in the same directory as data files.

    Arguments:
        option {str} -- either empty or of format '{small,large}_{objective name}'.
            'small' keeps molecules below get_threshold(name, mode='low'),
            'large' keeps those at or above get_threshold(name, mode='high').
            The filter is only applied when subsampling happens (see max_size).
        max_size {int or None} -- number of molecules to sample. If None or -1,
            or if the dataset has at most ``max_size`` entries, all molecules
            are returned unfiltered; else a subset is sampled without
            replacement. The sampler is seeded with a fixed seed (42), so the
            subset is always the same.
        as_mols {bool} -- whether to wrap SMILES into the Molecule class
    Raises:
        ValueError -- if ``option`` is not recognised (only checked when sampling).
    """
    path = os.path.join(__location__, "ChEMBL.txt")
    with open(path, "r") as f:
        mols = [line.strip() for line in f]
    if as_mols:
        mols = [Molecule(smile) for smile in mols]

    # None means "all" as documented above; -1 is kept for compatibility.
    if max_size is None or max_size == -1:
        max_size = len(mols)
    if len(mols) <= max_size:
        return mols

    # TODO: this logic is off, if filtering afterwards,
    # we get less than max_size molecules in the end.
    # Fix this if needed.
    gen = np.random.RandomState(42)
    # Same permutation as gen.choice(mols, ...), but the list elements are
    # never coerced into a numpy array.
    picked = gen.choice(len(mols), max_size, replace=False)
    mols = [mols[i] for i in picked]
    if option == '':
        return mols

    # partition() keeps objective names that themselves contain '_' intact.
    size, sep, obj_name = option.partition("_")
    if sep and size in ('small', 'large'):
        obj_func = get_objective_by_name(obj_name)
        if size == 'small':
            small_thresh = get_threshold(obj_name, mode='low')
            return [mol for mol in mols if obj_func(mol) < small_thresh]
        large_thresh = get_threshold(obj_name, mode='high')
        return [mol for mol in mols if obj_func(mol) >= large_thresh]
    raise ValueError(f"Dataset filter {option} not supported.")
Esempio n. 5
0
def print_pool_statistics(dataset, seed, n=30):
    """
    Draw ``n`` molecules from ``dataset`` with a MolSampler seeded by ``seed``
    and print count, min, mean, max and standard deviation of their QED.
    """
    from mols.mol_functions import get_objective_by_name

    sampled = MolSampler(dataset, seed)(n)
    qed_of = get_objective_by_name("qed")
    values = list(map(qed_of, sampled))
    print(
        f"Properties of pool: quantity {len(sampled)}, min {np.min(values)}, avg {np.mean(values)}, max {np.max(values)}, std {np.std(values)}"
    )
def compute_synthesizability(exp_path):
    """
    Compute the SA score of the resulting molecule recorded in an experiment log.

    Scans ``<exp_path>/exp_log`` for lines containing 'Resulting molecule' and
    builds a Molecule from the third whitespace-separated token of the LAST
    such line.

    Arguments:
        exp_path {str} -- experiment directory containing ``exp_log``.
    Returns:
        The "sascore" objective of that molecule, or None if the log has no
        'Resulting molecule' line.
    """
    mol = None
    with open(os.path.join(exp_path, 'exp_log'), 'r') as f:
        for line in f:
            if 'Resulting molecule' in line:
                mol = Molecule(smiles=line.split()[2])
    # Explicit None check: do not rely on the truthiness of a Molecule object.
    if mol is None:
        return None
    sas = get_objective_by_name("sascore")
    return sas(mol)
Esempio n. 7
0
def display_dataset_statistics(dataset):
    """
    Save density histograms (200 bins) of QED and penalized LogP computed on a
    10000-molecule ChEMBL sample.

    The molecules always come from get_chembl; ``dataset`` only prefixes the
    output filenames in ./experiments/visualizations/.
    """
    sample = get_chembl(max_size=10000)

    qed_of = get_objective_by_name('qed')
    plogp_of = get_objective_by_name('plogp')

    qed_values = [qed_of(m) for m in sample]
    plogp_values = [plogp_of(m) for m in sample]

    # (title, values, first tick, headroom past the max, tick step, output path)
    panels = [
        ('Distribution of QED values in ChEMBL', qed_values, 0, 0.1, 0.1,
         f'./experiments/visualizations/{dataset}_qed_histogram.pdf'),
        ('Distribution of penalized LogP values in ChEMBL', plogp_values, -20, 1, 4,
         f'./experiments/visualizations/{dataset}_plogp_histogram.pdf'),
    ]
    for title, values, tick_start, headroom, tick_step, out_path in panels:
        plt.title(title)
        plt.hist(values, bins=200, density=True)
        plt.xticks(np.arange(tick_start, max(values) + headroom, tick_step))
        plt.savefig(out_path)
        plt.clf()
Esempio n. 8
0
def main():
    """
    Run one Chemist optimization experiment configured from the command line.

    Progress is written through a reporter backed by EXP_LOG_FILE. The optimal
    molecule is drawn to EXP_DIR/optimal_molecule.png, and its synthesis path is
    pickled to SYN_PATH_FILE.
    """
    setup_logging()
    args = parse_args()
    # Scope the log file to the run so it is flushed and closed even if the
    # run raises.
    with open(EXP_LOG_FILE, 'w') as exp_log:
        # Obtain a reporter and worker manager
        reporter = get_reporter(exp_log)
        worker_manager = SyntheticWorkerManager(num_workers=N_WORKERS,
                                                time_distro='const')

        # Problem settings
        objective_func = get_objective_by_name(args.objective)
        # check MolDomain constructor for full argument list:
        domain_config = {
            'data_source': args.dataset,
            # not specifying constraint_checker defaults to None
            'constraint_checker': 'organic',
            'sampling_seed': args.seed,
        }
        chemist_args = {
            'acq_opt_method': 'rand_explorer',
            'init_capital': args.init_pool_size,
            # e.g. 'distance_kernel_expsum', 'similarity_kernel', 'wl_kernel'
            'dom_mol_kernel_type': args.kernel,
            'acq_opt_max_evals': args.steps,
            'objective': args.objective,
            'max_pool_size': args.max_pool_size,
            'report_results_every': 1,
            'gpb_hp_tune_criterion': 'ml',
        }

        chemist = Chemist(objective_func,
                          domain_config=domain_config,
                          chemist_args=chemist_args,
                          is_mf=False,
                          worker_manager=worker_manager,
                          reporter=reporter)
        opt_val, opt_point, _history = chemist.run(args.budget)

        # convert to raw format
        raw_opt_point = chemist.get_raw_domain_point_from_processed(opt_point)
        opt_mol = raw_opt_point[0]

        # Print the optimal value and visualize the molecule and path.
        reporter.writeln(f"\nOptimum value found: {opt_val}")
        reporter.writeln(
            f"Optimum molecule: {opt_mol} with formula {opt_mol.to_formula()}")
        # Computed once; reused for both the log line and the pickle below.
        syn_path = opt_mol.get_synthesis_path()
        reporter.writeln(f"Synthesis path: {syn_path}")

        # visualize mol/synthesis path
        visualize_file = os.path.join(EXP_DIR, 'optimal_molecule.png')
        reporter.writeln(f'Optimal molecule visualized in {visualize_file}')
        visualize_mol(opt_mol, visualize_file)

        with open(SYN_PATH_FILE, 'wb') as f:
            pkl.dump(syn_path, f)
def make_pairwise(func, n_mols, to_randomize=True):
    """
    Plot pairwise molecular distance against property difference.

    For every distance matrix returned by the default OTChemDistanceComputer,
    one subplot shows, for each pair (i, j) with i <= j, the distance between
    molecules i and j against |prop_i - prop_j|.

    Arguments:
        func {str} -- 'prop' to use the properties from get_chembl_prop,
            otherwise an objective name understood by get_objective_by_name.
        n_mols {int} -- number of molecules to compare.
        to_randomize {bool} -- if True (and func != 'prop'), read 5 * n_mols
            molecules and keep a random n_mols of them (global np.random state).

    The figure is saved to VIS_DIR under a timestamped name, and the SMILES
    used are printed.
    """
    if func == 'prop':
        smile_strings, smiles_to_prop = get_chembl_prop(n_mols=n_mols)
        prop_list = [smiles_to_prop[sm] for sm in smile_strings]
    else:
        n_mols_to_get = 5 * n_mols if to_randomize else n_mols
        mols = get_chembl(n_mols=n_mols_to_get)
        np.random.shuffle(mols)
        mols = mols[:n_mols]
        smile_strings = [mol.to_smiles() for mol in mols]
        func_ = get_objective_by_name(func)
        prop_list = [func_(mol) for mol in mols]

    # Index only what was actually loaded, in case a loader returned fewer
    # than n_mols entries.
    n_found = len(smile_strings)

    dist_computer = OTChemDistanceComputer()  # <-- default computer
    dists = dist_computer(smile_strings, smile_strings)

    # At least two rows so that plt.subplots returns a 2-D array of axes,
    # which chain.from_iterable flattens.
    num_rows = max(2, int(np.ceil(dist_computer.get_num_distances() / 4.0)))
    print(num_rows)
    fig, ll_ax = plt.subplots(num_rows, 4, figsize=(15, 15))
    axes = itertools.chain.from_iterable(ll_ax)
    for ind, (ax, distmat) in enumerate(zip(axes, dists)):
        xs, ys, pairs = [], [], []
        for i in range(n_found):
            for j in range(i, n_found):
                xs.append(distmat[i, j])
                ys.append(np.abs(prop_list[i] - prop_list[j]))
                pairs.append((i, j))

        ax.set_title(f'Distance {ind}')  # TODO: parameters of distance
        if n_found > 12:
            ax.scatter(xs, ys, s=1, alpha=0.6)
        else:
            # Few molecules: label each point with its (i, j) pair and draw
            # the diagonal pairs (i, i) as '*'.
            for xval, yval, pval in zip(xs, ys, pairs):
                print(xval, yval, pval)
                if pval[0] == pval[1]:
                    ax.text(xval, yval, '*', fontsize=14)
                else:
                    ax.text(xval, yval, '(%d, %d)' % (pval[0], pval[1]))
            ax.set_xlim((0.0, max(xs) * 1.25))

    plt.savefig(os.path.join(VIS_DIR, "dist_vs_value_%d_%s_%s" % (
        n_mols, func, datetime.now().strftime('%m%d%H%M%S'))))
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print(smile_strings, len(smile_strings))
Esempio n. 10
0
def get_zinc250(option='', max_size=1000):
    """
    Return a list of ZINC250k molecules, optionally subsampled and QED-filtered.

    Arguments:
        option {str} -- '' (no filter), 'small_qed' (QED < 0.6) or
            'large_qed' (QED >= 0.6); only applied when subsampling happens.
        max_size {int or None} -- if the dataset has fewer than ``max_size``
            molecules (or max_size is None / -1), all are returned unfiltered;
            otherwise ``max_size`` molecules are drawn without replacement with
            a fixed seed (42) and then filtered by ``option``.
    Raises:
        ValueError -- for an unsupported ``option`` (only checked when sampling).
    """
    path = os.path.join(__location__, "zinc250k.csv")
    zinc_df = pd.read_csv(path)
    list_of_smiles = [smile.strip() for smile in zinc_df.smiles.values]
    # other columns are logP, qed, and sas
    # Build molecules from the SMILES just read (this previously referenced an
    # undefined name and raised NameError).
    mols = [Molecule(smile) for smile in list_of_smiles]
    if max_size is None or max_size == -1 or len(mols) < max_size:
        return mols

    gen = np.random.RandomState(42)
    # Same permutation as gen.choice(mols, ...), without coercing the
    # Molecule objects into a numpy array.
    picked = gen.choice(len(mols), max_size, replace=False)
    mols = [mols[i] for i in picked]
    if option == '':
        return mols
    elif option == 'small_qed':
        qed_func = get_objective_by_name("qed")
        return [mol for mol in mols if qed_func(mol) < 0.6]
    elif option == 'large_qed':
        qed_func = get_objective_by_name("qed")
        return [mol for mol in mols if qed_func(mol) >= 0.6]
    else:
        raise ValueError(f"Dataset filter {option} not supported.")
Esempio n. 11
0
def run_screen(init_pool_size, seed, budget, objective, dataset, iter_num):
    """
    Random-screening baseline.

    Samples ``init_pool_size`` molecules, then ``budget - init_pool_size``
    more one at a time, using a MolSampler seeded with ``seed + iter_num``.
    Tracks the best objective value seen, prints it and returns it.
    """
    score = get_objective_by_name(objective)
    mol_sampler = MolSampler(dataset, sampling_seed=seed + iter_num)
    screened = mol_sampler(init_pool_size)
    n_extra = budget - init_pool_size
    best = max([score(m) for m in screened])
    for _ in range(n_extra):
        # one more random candidate per remaining evaluation
        candidate = mol_sampler(1)[0]
        best = max(score(candidate), best)
        screened.append(candidate)
    print("Optimal value: {:.3f}".format(best))
    return best
Esempio n. 12
0
def get_zinc250(option='', max_size=1000, as_mols=True):
    """
    Return list of Molecules (or SMILES strings if ``as_mols`` is False).
    NOTE: this function should be located
    in the same directory as data files.

    Arguments:
        option {str} -- either empty or of format '{small,large}_{objective name}'.
            'small' keeps molecules below get_threshold(name, mode='low'),
            'large' keeps those at or above get_threshold(name, mode='high').
            The filter is only applied when subsampling happens (see max_size).
        max_size {int or None} -- number of molecules to sample. If None or -1,
            or if the dataset has at most ``max_size`` entries, all molecules
            are returned unfiltered; else a subset is sampled without
            replacement. The sampler is seeded with a fixed seed (42), so the
            subset is always the same.
        as_mols {bool} -- whether to wrap SMILES into the Molecule class
    Raises:
        ValueError -- if ``option`` is not recognised (only checked when sampling).
    """
    path = os.path.join(__location__, "zinc250k.csv")
    zinc_df = pd.read_csv(path)
    list_of_smiles = [smile.strip() for smile in zinc_df.smiles.values]
    # other columns are logP, qed, and sas
    if as_mols:
        mols = [Molecule(smile) for smile in list_of_smiles]
    else:
        mols = list_of_smiles

    # None means "all" as documented above; -1 is kept for compatibility.
    if max_size is None or max_size == -1:
        max_size = len(mols)
    if len(mols) <= max_size:
        return mols

    gen = np.random.RandomState(42)
    # Same permutation as gen.choice(mols, ...), but the list elements are
    # never coerced into a numpy array.
    picked = gen.choice(len(mols), max_size, replace=False)
    mols = [mols[i] for i in picked]
    if option == '':
        return mols

    # partition() keeps objective names that themselves contain '_' intact.
    size, sep, obj_name = option.partition("_")
    if sep and size in ('small', 'large'):
        obj_func = get_objective_by_name(obj_name)
        # Per-objective thresholds, as in the ChEMBL loader, rather than one
        # fixed 0.6 cut applied to every objective.
        if size == 'small':
            small_thresh = get_threshold(obj_name, mode='low')
            return [mol for mol in mols if obj_func(mol) < small_thresh]
        large_thresh = get_threshold(obj_name, mode='high')
        return [mol for mol in mols if obj_func(mol) >= large_thresh]
    raise ValueError(f"Dataset filter {option} not supported.")
Esempio n. 13
0
def explore_and_validate_synth(init_pool_size, seed, budget, objective,
                               dataset, max_pool_size, reporter):
    """
    Run a RandomExplorer optimization and report on synthesizability.

    This experiment is equivalent to unlimited-evaluation optimization.
    It compares the optimum found against the best molecules of the pool, and
    checks whether synthesizability (minimum SA score along the synthesis path)
    is improved.

    Side effects: lines written to ``reporter``; the winner's synthesis path is
    pickled to SYN_PATH_FILE; the objective trace is plotted to PLOT_FILE (EPS)
    and written as space-separated values to OPT_VALS_FILE.
    """
    score = get_objective_by_name(objective)
    mol_sampler = MolSampler(dataset, sampling_seed=seed)
    initial = mol_sampler(init_pool_size)
    # The same list object is handed to the explorer and inspected afterwards.
    explorer = RandomExplorer(score,
                              initial_pool=initial,
                              max_pool_size=max_pool_size)
    n_steps = budget - init_pool_size

    pool_scores = [score(m) for m in initial]
    reporter.writeln(
        f"Properties of pool: quantity {len(initial)}, min {np.min(pool_scores)}, avg {np.mean(pool_scores)}, max {np.max(pool_scores)}"
    )
    reporter.writeln(f"Starting {objective} optimization")

    started = time.time()
    _best_val, best_mol, history = explorer.run(n_steps)

    elapsed_min = (time.time() - started) / 60
    reporter.writeln("Finished run in {:.3f} minutes".format(elapsed_min))
    reporter.writeln(f"Is a valid molecule: {check_validity(best_mol)}")
    reporter.writeln(f"Resulting molecule: {best_mol}")
    reporter.writeln(f"Top score: {score(best_mol)}")
    reporter.writeln(
        f"Minimum synthesis score over the path: {compute_min_sa_score(best_mol)}"
    )
    with open(SYN_PATH_FILE, 'wb') as out:
        pkl.dump(best_mol.get_synthesis_path(), out)

    # Five highest-scoring pool molecules, in ascending order of score.
    for pool_mol in sorted(initial, key=score)[-5:]:
        reporter.writeln(
            f"Minimum synthesis score of optimal molecules: {compute_min_sa_score(pool_mol)}")

    trace = history['objective_vals']
    plt.title(f'Optimizing {objective} with random explorer')
    plt.plot(range(len(trace)), trace)
    plt.savefig(PLOT_FILE, format='eps', dpi=1000)
    with open(OPT_VALS_FILE, 'w') as out:
        out.write(' '.join(str(v) for v in trace))
Esempio n. 14
0
def plot_tsne(func):
    """
    Save a t-SNE embedding of 250 ChEMBL molecules colored by ``func``.

    The embedding is computed from an OT distance matrix (molecular-mass
    assignment, total-mass normalisation, bond_frac structural penalty). It is
    saved as '<title>.eps' in VIS_DIR, with spaces in the title replaced by
    underscores.

    Arguments:
        func {str} -- objective name passed to get_objective_by_name.
    """
    n_mols = 250
    mols = get_chembl(max_size=n_mols, as_mols=True)
    smile_strings = [m.smiles for m in mols]

    title = f"{func} ot-dist"
    distance_computer = OTChemDistanceComputer(
        mass_assignment_method='molecular_mass',
        normalisation_method='total_mass',
        struct_pen_method='bond_frac')
    distances_mat = distance_computer(smile_strings, smile_strings)[0]

    # title = f"{func} similarity kernel"
    # kernel = mol_kern_factory('similarity_kernel')
    # kern_mat = kernel(mols, mols)
    # distances_mat = 1/kern_mat

    # title = f"{func} fingerprint dist"
    # distances_mat = np.zeros((len(smile_strings), len(smile_strings)))
    # for i in tqdm(range(len(smile_strings))):
    #     for j in range(len(smile_strings)):
    #         distances_mat[i, j] = np.sum((mols[i].to_fingerprint(ftype='numeric') -
    #             mols[j].to_fingerprint(ftype='numeric')) ** 2 )

    # A precomputed metric needs init='random': newer scikit-learn defaults to
    # init='pca', which rejects metric='precomputed'. Older releases already
    # defaulted to 'random'.
    tsne = TSNE(metric='precomputed', init='random')
    points_to_plot = tsne.fit_transform(distances_mat)

    # Color by the property of the molecules that were embedded. get_chembl
    # samples with a fixed seed, so loading the dataset a second time would
    # only repeat work.
    func_ = get_objective_by_name(func)
    prop_list = [func_(mol) for mol in mols]

    plt.title(title, fontsize=22)
    plt.scatter(points_to_plot[:, 0],
                points_to_plot[:, 1],
                c=prop_list,
                cmap=plt.cm.Spectral,
                s=15,
                alpha=0.8)
    plt.xticks([])
    plt.yticks([])
    plt.savefig(os.path.join(VIS_DIR,
                             title.replace(" ", "_") + '.eps'),
                format='eps',
                dpi=1000)
    # Clear the figure so later plots do not draw over this one.
    plt.clf()
Esempio n. 15
0
def compute_min_sa_score(mol):
    """
    Compute the minimum SA score ("sascore" objective) over the synthesis path of ``mol``.

    ``mol.get_synthesis_path()`` is either a dict or a single value that is
    passed to ``Molecule`` and scored directly. In the dict form, each key is
    passed to ``Molecule``:
      - an entry whose value is a string is treated as a leaf and its key is
        scored;
      - an entry whose value is a dict is searched recursively.
    Keys of non-leaf entries are not scored themselves.

    Returns:
        float -- the minimum over all scored leaves (``inf`` if none is found).
    """
    sa_score = get_objective_by_name("sascore")

    def get_min_score(syn):
        res = float('inf')
        for smiles, syn_graph in syn.items():
            if isinstance(syn_graph, str):
                # Leaf: score it, but keep scanning the sibling entries too.
                # Returning here would skip them, so the result would not be
                # a true minimum.
                res = min(res, sa_score(Molecule(smiles)))
            else:
                res = min(res, get_min_score(syn_graph))
        return res

    synthesis_path = mol.get_synthesis_path()
    if isinstance(synthesis_path, dict):
        min_sa_score = get_min_score(synthesis_path)
    else:
        min_sa_score = sa_score(Molecule(synthesis_path))
    return min_sa_score
Esempio n. 16
0
def make_pairwise_kernel(kernel_name, func, mode="inverse_sim", **kwargs):
    """
    Scatter-plot a kernel-derived pairwise distance against property difference.

    Arguments:
        kernel_name {str} -- kernel name passed to mol_kern_factory.
        func {str} -- objective name passed to get_objective_by_name.
        mode {str} -- how a pairwise "distance" is derived:
            'inverse_sim'   -- 1 / k(i, j);
            'scaled_kernel' -- (1 / k(i, j)) / sqrt(k(i, i) * k(j, j));
            'fps_distance'  -- squared Euclidean distance between numeric
                               fingerprints (the kernel matrix is not used).
        **kwargs -- forwarded to mol_kern_factory.
    Raises:
        ValueError -- for an unknown ``mode``, checked before any data is loaded.

    Saves ``{kernel_name}_{func}.eps`` in VIS_DIR.
    """
    # ``mode`` is an explicit keyword argument, validated up front so that a
    # typo fails fast instead of after loading data and computing the kernel.
    known_modes = ("inverse_sim", "scaled_kernel", "fps_distance")
    if mode not in known_modes:
        raise ValueError(f"Unknown mode {mode!r}; expected one of {known_modes}.")

    n_mols = 100
    mols = get_chembl(max_size=n_mols)
    n_found = len(mols)  # may be < n_mols if the dataset is small
    func_ = get_objective_by_name(func)
    kernel = mol_kern_factory(kernel_name, **kwargs)
    kern_mat = kernel(mols, mols)
    prop_list = [func_(mol) for mol in mols]

    if mode == "fps_distance":
        # A fingerprint does not depend on the pair; compute each one once
        # instead of twice per (i, j) inside the double loop.
        fps = [mol.to_fingerprint(ftype='numeric') for mol in mols]

    xs, ys = [], []
    for i in range(n_found):
        for j in range(n_found):
            if mode == "inverse_sim":
                dist_in_dist = 1 / kern_mat[i, j]
            elif mode == "scaled_kernel":
                # NOTE(review): this is (1 / k_ij) / sqrt(k_ii * k_jj); the
                # inverse of the normalized kernel would be
                # sqrt(k_ii * k_jj) / k_ij. Confirm which one is intended.
                dist_in_dist = 1 / kern_mat[i, j]
                dist_in_dist /= np.sqrt(kern_mat[i, i] * kern_mat[j, j])
            else:  # "fps_distance"
                dist_in_dist = np.sum((fps[i] - fps[j])**2)

            xs.append(dist_in_dist)
            ys.append(np.abs(prop_list[i] - prop_list[j]))

    fig = plt.figure()  # figsize=fsize
    fig.add_subplot(1, 1, 1)
    plt.scatter(xs, ys, s=2, alpha=0.6)
    plt.xscale('log')
    # NOTE(review): hard-coded x-range, presumably tuned for one kernel.
    plt.xlim([11, 80])
    plt.xticks([])
    plt.yticks([])
    plt.savefig(os.path.join(VIS_DIR, f"{kernel_name}_{func}.eps"),
                format='eps',
                dpi=1000)
    # Close (not just clear) so repeated calls do not leak figures.
    plt.close(fig)
Esempio n. 17
0
def make_pairwise(func, as_subplots=False):
    """
    Plot pairwise OT distance against property difference for up to 100
    ChEMBL molecules, with four OTChemDistanceComputer configurations.

    Arguments:
        func {str} -- 'prop' to use the properties from get_chembl_prop,
            otherwise an objective name understood by get_objective_by_name.
        as_subplots {bool} -- if True, draw the four scatter plots in one 2x2
            figure saved as ``dist_vs_value_{func}.eps``. Otherwise save each
            one separately as ``dist_vs_value_{func}_{k}.pdf`` (k = 1..4).
            All files go to VIS_DIR.
    """
    n_mols = 100

    if func == 'prop':
        smile_strings, smiles_to_prop = get_chembl_prop(n_mols=n_mols)
        prop_list = [smiles_to_prop[sm] for sm in smile_strings]
    else:
        mols = get_chembl(max_size=n_mols)
        smile_strings = [mol.to_smiles() for mol in mols]
        func_ = get_objective_by_name(func)
        prop_list = [func_(mol) for mol in mols]
    # Iterate over what was actually loaded; the dataset may hold fewer than n_mols.
    n_found = len(smile_strings)

    dist_computers = [
        OTChemDistanceComputer(mass_assignment_method='equal',
                               normalisation_method='none',
                               struct_pen_method='bond_frac'),
        OTChemDistanceComputer(mass_assignment_method='equal',
                               normalisation_method='total_mass',
                               struct_pen_method='bond_frac'),
        OTChemDistanceComputer(mass_assignment_method='molecular_mass',
                               normalisation_method='none',
                               struct_pen_method='bond_frac'),
        OTChemDistanceComputer(mass_assignment_method='molecular_mass',
                               normalisation_method='total_mass',
                               struct_pen_method='bond_frac')
    ]
    titles = [
        'Unit weight, Unnormalized', 'Unit weight, Normalized',
        'Molecular mass weight, Unnormalized',
        'Molecular mass weight, Normalized'
    ]

    if as_subplots:
        grid_fig, ll_ax = plt.subplots(2, 2, figsize=(15, 15))
        axes = list(itertools.chain.from_iterable(ll_ax))
    else:
        # Each panel gets its own figure below; no shared grid is needed.
        axes = [None] * len(dist_computers)

    for ind, (ax, dist_computer,
              title) in enumerate(zip(axes, dist_computers, titles)):
        distmat = dist_computer(smile_strings, smile_strings)[0]
        xs, ys = [], []
        for i in range(n_found):
            for j in range(n_found):
                xs.append(distmat[i, j])
                ys.append(np.abs(prop_list[i] - prop_list[j]))

        if as_subplots:
            ax.set_title(title)
            ax.scatter(xs, ys, s=2, alpha=0.6)
            ax.set_xticks([])
            ax.set_yticks([])
        else:
            # save separately:
            fig = plt.figure()  # figsize=fsize
            panel_ax = fig.add_subplot(1, 1, 1)
            plt.title(title, fontsize=22)
            plt.scatter(xs, ys, s=2, alpha=0.6)
            plt.xscale('log')
            plt.xticks([])
            plt.yticks([])
            plt.xlim([None, 1.03 * max(xs)])
            plt.xlabel("OT-distance, log scale", fontsize=20)
            if ind == 0:
                # Only the first panel gets a y-axis label.
                # NOTE(review): the label says "SA score" whatever ``func`` is.
                plt.ylabel("Difference in SA score", fontsize=20)
            # Axes bounding box in inches, enlarged by fixed margins; this is
            # the crop region for the saved PDF.
            extent = panel_ax.get_window_extent().transformed(
                fig.dpi_scale_trans.inverted())
            extent.x0 -= 0.5
            extent.x1 += 0.1
            extent.y0 -= 0.6
            extent.y1 += 0.7
            plt.savefig(
                os.path.join(VIS_DIR, f"dist_vs_value_{func}_{ind+1}.pdf"),
                bbox_inches=extent,
                pad_inches=0)
            # Close each per-panel figure; they were previously never released.
            plt.close(fig)

    if as_subplots:
        plt.savefig(os.path.join(VIS_DIR, f"dist_vs_value_{func}.eps"),
                    format='eps',
                    dpi=1000)
        plt.close(grid_fig)
 def test_plogp(self):
     """Smoke test: evaluate the penalized LogP objective on self.mol and print it."""
     plogp_value = get_objective_by_name("plogp")(self.mol)
     print(plogp_value)
 def test_qed(self):
     """Smoke test: the QED objective evaluates on self.mol without raising."""
     get_objective_by_name("qed")(self.mol)
 def test_sas(self):
     """Smoke test: the SA-score objective evaluates on self.mol without raising."""
     get_objective_by_name("sascore")(self.mol)
Esempio n. 21
0
def make_tsne(func, as_subplots=False):
    """
    Plot TSNE embeddings colored with property
    for several distance computers.

    Arguments:
        func {str} -- 'prop' to color by properties from get_chembl_prop,
            otherwise an objective name understood by get_objective_by_name.
        as_subplots {bool} -- if True, draw the four embeddings in one 2x2
            figure saved as ``tsne_vis_{func}.eps``. Otherwise save each one
            separately as ``tsne_vis_{func}_{k}.eps`` (k = 1..4). All files go
            to VIS_DIR.
    """
    n_mols = 200

    dist_computers = [
        OTChemDistanceComputer(mass_assignment_method='equal',
                               normalisation_method='none',
                               struct_pen_method='bond_frac'),
        OTChemDistanceComputer(mass_assignment_method='equal',
                               normalisation_method='total_mass',
                               struct_pen_method='bond_frac'),
        OTChemDistanceComputer(mass_assignment_method='molecular_mass',
                               normalisation_method='none',
                               struct_pen_method='bond_frac'),
        OTChemDistanceComputer(mass_assignment_method='molecular_mass',
                               normalisation_method='total_mass',
                               struct_pen_method='bond_frac')
    ]
    titles = [
        'Equal mass assign, no norm', 'Equal mass assign, total mass norm',
        'Mol mass assign, no norm', 'Mol mass assign, total mass norm'
    ]

    # Load the molecules once, from the source that matches ``func``.
    if func == 'prop':
        smile_strings, smiles_to_prop = get_chembl_prop(n_mols=n_mols)
        prop_list = [smiles_to_prop[sm] for sm in smile_strings]
    else:
        mols = get_chembl(max_size=n_mols)
        smile_strings = [mol.to_smiles() for mol in mols]
        func_ = get_objective_by_name(func)
        prop_list = [func_(mol) for mol in mols]

    if as_subplots:
        grid_fig, ll_ax = plt.subplots(2, 2, figsize=(15, 15))
        axes = list(itertools.chain.from_iterable(ll_ax))
    else:
        # Each embedding gets its own figure below; no shared grid is needed.
        axes = [None] * len(dist_computers)

    for ind, (ax, dist_computer,
              title) in enumerate(zip(axes, dist_computers, titles)):
        distances_mat = dist_computer(smile_strings, smile_strings)[0]

        # A precomputed metric needs init='random': newer scikit-learn defaults
        # to init='pca', which rejects metric='precomputed'.
        tsne = TSNE(metric='precomputed', init='random')
        points_to_plot = tsne.fit_transform(distances_mat)
        if as_subplots:
            ax.set_title(title)
            ax.scatter(points_to_plot[:, 0],
                       points_to_plot[:, 1],
                       c=prop_list,
                       cmap=plt.cm.Spectral,
                       s=9,
                       alpha=0.8)
            ax.set_xticks([])
            ax.set_yticks([])
        else:
            # save separately:
            fig = plt.figure()  # figsize=fsize
            fig.add_subplot(1, 1, 1)
            plt.title(title)
            plt.scatter(points_to_plot[:, 0],
                        points_to_plot[:, 1],
                        c=prop_list,
                        cmap=plt.cm.Spectral,
                        s=9,
                        alpha=0.8)
            plt.xticks([])
            plt.yticks([])
            # Name files by panel index. Unless OTChemDistanceComputer defines
            # __str__, interpolating the object would put its default repr
            # (angle brackets, spaces, memory address) into the filename.
            plt.savefig(os.path.join(VIS_DIR,
                                     f'tsne_vis_{func}_{ind + 1}.eps'),
                        format='eps',
                        dpi=1000)
            plt.close(fig)

    if as_subplots:
        plt.savefig(os.path.join(VIS_DIR, f'tsne_vis_{func}.eps'),
                    format='eps',
                    dpi=1000)
        plt.close(grid_fig)