iterations=iterations,
                                                  random_seed=random_seed,
                                                  verbose=verbose,
                                                  exe_dir=exe_dir)

# Build the spike_info table from the t-SNE outputs stored in
# tsne_cortex_folder, using the denoised Kilosort folder.
spike_info = preproc_kilo.generate_spike_info_from_full_tsne(
    kilosort_folder=kilosort_folder_denoised, tsne_folder=tsne_cortex_folder)

# OR (instead of regenerating) load a previously computed t-SNE result from
# tsne_cortex_folder.
tsne_results = tsne_io.load_tsne_result(files_dir=tsne_cortex_folder)

# ... together with the previously generated spike_info, read back from the
# pickle at <tsne_cortex_folder>/spike_info.df. pd.read_pickle unpickles
# arbitrary objects, so only point it at files this pipeline produced.
spike_info = pd.read_pickle(join(tsne_cortex_folder, 'spike_info.df'))

# Visual check: plot the spikes in t-SNE space.
# NOTE(review): `legent_on` looks like a typo for `legend_on`; confirm against
# the signature of plot_tsne_of_spikes. The plotting module is referenced as
# `viz` here but as `vis` further down; presumably only one alias is imported.
viz.plot_tsne_of_spikes(spike_info=spike_info, legent_on=False)

# Update the original spike_info (created right after cleaning) with the
# information added by manually sorting on the t-SNE embedding.
spike_info_after_cleaning = preproc_kilo.generate_spike_info_after_cleaning(
    kilosort_folder_denoised)
# The spike_info loaded above is the one carrying the manual cortex sorting.
spike_info_cortex_sorted = spike_info
tsne_filename = join(tsne_cortex_folder, 'result.dat')
# NOTE(review): the merged table is saved under kilosort_folder, while the
# cleaned input came from kilosort_folder_denoised -- confirm this is intended.
spike_info_after_cortex_sorting_file = join(
    kilosort_folder, 'spike_info_after_cortex_sorting.df')
# Presumably merges the two tables and writes the result to save_to_file;
# verify against preproc_kilo.add_sorting_info_to_spike_info.
spike_info_after_sorting = preproc_kilo.add_sorting_info_to_spike_info(
    spike_info_after_cleaning,
    spike_info_cortex_sorted,
    tsne_filename=tsne_filename,
    save_to_file=spike_info_after_cortex_sorting_file)
# Find the MUA templates with a large spike count.
# Presumably a minimum spike count per template -- confirm against
# preproc_kilo.find_large_mua_templates.
number_of_spikes_in_large_mua_templates = 10000
large_mua_templates = preproc_kilo.find_large_mua_templates(
    kilosort_folder, number_of_spikes_in_large_mua_templates)

# Choose which of the large MUA templates (by position in the returned
# sequence) to inspect next.
mua_template_to_look = 2
mua_template = large_mua_templates[mua_template_to_look]

# Load the spike_info of the chosen MUA template from its per-template
# t-SNE results under tsne_folder. Go through the project helper rather than
# np.load: np.load rejects pickled (non-.npy/.npz) files unless
# allow_pickle=True, so loading a pickled DataFrame with it raises ValueError.
spike_info = preproc_kilo.load_spike_info_of_template(tsne_folder,
                                                      mua_template)

# NOTE(review): earlier in this script the plotting module is referenced as
# `viz`; confirm which alias (`vis` or `viz`) is actually imported.
vis.plot_tsne_of_spikes(spike_info)


def find_spikes_in_rect(spike_info, left, right, top, bottom):
    """Return the row positions of spikes lying strictly inside a t-SNE box.

    Parameters
    ----------
    spike_info : DataFrame-like
        Object whose ``'tsne_x'`` and ``'tsne_y'`` columns support
        ``.tolist()`` (presumably the pandas spike_info table).
    left, right : float
        Exclusive lower / upper bounds on ``tsne_x``.
    top, bottom : float
        Exclusive upper / lower bounds on ``tsne_y``.

    Returns
    -------
    numpy.ndarray
        Sorted integer positions (0-based row numbers, NOT index labels) of
        the rows with ``left < tsne_x < right`` and ``bottom < tsne_y < top``.
        Empty when no spike falls inside the box.
    """
    xs = np.array(spike_info['tsne_x'].tolist())
    ys = np.array(spike_info['tsne_y'].tolist())
    # One boolean mask selects the box in a single O(n) pass and needs no
    # set intersections (which sort). NaN coordinates compare False, so they
    # are excluded.
    inside = (xs > left) & (xs < right) & (ys > bottom) & (ys < top)
    return np.flatnonzero(inside)


def is_single_unit(times):