def visualize_trained_network_results(data_dict, z_dim=100):
    """Save evaluation figures for restored CVAE models, one cell type at a time.

    For every selected cell type the function restores the model stored at
    ``../models/CVAE/<name>/<cell_type>/<z_dim>/cvae``, encodes the complete
    data set with the true condition labels and with all-ones ("fake")
    labels, predicts the source-condition cells of that cell type, and writes
    regression, UMAP and violin figures below ``../results/CVAE/...``.

    NOTE(review): this function name is defined again further down in the
    visible source; if all definitions live in one module, the last one
    shadows this one.

    Args:
        data_dict: Mapping read with the keys ``name``, ``source_key``,
            ``target_key``, ``cell_type`` (name of the cell-type column in
            ``obs``) and ``spec_cell_types`` (optional list restricting the
            cell types that are processed).
        z_dim: Latent size; forwarded to the model constructor and used in
            the model and result paths.

    Returns:
        None. All results are written to disk as figures.
    """
    import matplotlib as mpl

    plt.close("all")
    data_name = data_dict.get('name', None)
    source_key = data_dict.get('source_key', None)
    target_key = data_dict.get('target_key', None)
    cell_type_key = data_dict.get("cell_type", None)

    data = sc.read(f"../data/{data_name}/train_{data_name}.h5ad")

    # The original test was `spec_cell_type is not []`, an identity check
    # against a brand-new list that is always true, so a missing
    # "spec_cell_types" entry replaced the cell types with None and the loop
    # below raised. A truthiness test falls back to every cell type instead.
    spec_cell_type = data_dict.get("spec_cell_types", None)
    if spec_cell_type:
        cell_types = spec_cell_type
    else:
        cell_types = data.obs[cell_type_key].unique().tolist()

    color = ['condition', cell_type_key]

    for cell_type in cell_types:
        path_to_save = f"../results/CVAE/{data_name}/{cell_type}/{z_dim}/{source_key} to {target_key}/Visualizations/"
        os.makedirs(path_to_save, exist_ok=True)
        sc.settings.figdir = os.path.abspath(path_to_save)

        # Everything except the target-condition cells of this cell type.
        held_out = (data.obs['condition'] == target_key) & (data.obs[cell_type_key] == cell_type)
        train_data = data.copy()[~held_out]
        cell_type_adata = data[data.obs[cell_type_key] == cell_type]

        network = trvae.CVAE(
            x_dimension=data.shape[1],
            z_dimension=z_dim,
            model_path=f"../models/CVAE/{data_name}/{cell_type}/{z_dim}/cvae")
        network.restore_model()

        # `.toarray()` instead of `.A`: the attribute shortcut does not exist
        # on SciPy sparse arrays in recent releases.
        if sparse.issparse(data.X):
            data.X = data.X.toarray()

        feed_data = data.X
        train_labels, _ = trvae.label_encoder(data)
        fake_labels = np.ones(train_labels.shape)

        # Raw embeddings keyed by the suffix used in the figure file names.
        embeddings = {
            "latent_with_true_labels": network.to_latent(feed_data, train_labels),
            "latent_with_fake_labels": network.to_latent(feed_data, fake_labels),
            "mmd_latent_with_true_labels": network.to_mmd_layer(feed_data, train_labels),
            "mmd_latent_with_fake_labels": network.to_mmd_layer(feed_data, fake_labels),
        }

        cell_type_ctrl = cell_type_adata.copy()[cell_type_adata.obs['condition'] == source_key]
        print(cell_type_ctrl.shape, cell_type_adata.shape)

        pred_celltypes = network.predict(
            cell_type_ctrl, labels=np.ones((cell_type_ctrl.shape[0], 1)))
        pred_adata = anndata.AnnData(X=pred_celltypes)
        pred_adata.obs['condition'] = ['predicted'] * pred_adata.shape[0]
        pred_adata.var = cell_type_adata.var

        # The ranking call was identical in both branches; only the way the
        # gene lists are assembled depends on the data set.
        sc.tl.rank_genes_groups(cell_type_adata,
                                groupby="condition",
                                n_genes=100,
                                method="wilcoxon")
        ranked_names = cell_type_adata.uns["rank_genes_groups"]["names"]
        if data_name == "pbmc":
            top_100_genes = ranked_names[target_key].tolist()
            gene_list = top_100_genes[:10]
        else:
            top_50_down_genes = ranked_names[source_key].tolist()
            top_50_up_genes = ranked_names[target_key].tolist()
            top_100_genes = top_50_up_genes + top_50_down_genes
            gene_list = top_50_down_genes[:5] + top_50_up_genes[:5]

        cell_type_adata = cell_type_adata.concatenate(pred_adata)

        # Mean and variance regression plots share every argument except the
        # plotting function and the output file name.
        regression_plots = (
            ("mean", trvae.plotting.reg_mean_plot),
            ("var", trvae.plotting.reg_var_plot),
        )
        for stat, plot_fn in regression_plots:
            plot_fn(
                cell_type_adata,
                top_100_genes=top_100_genes,
                gene_list=gene_list,
                condition_key='condition',
                axis_keys={"x": 'predicted', 'y': target_key},
                labels={'x': 'pred stim', 'y': 'real stim'},
                legend=False,
                fontsize=20,
                textsize=14,
                title=cell_type,
                path_to_save=os.path.join(
                    path_to_save,
                    f'rcvae_reg_{stat}_{data_name}_{cell_type}.pdf'))

        # Reset matplotlib settings before the scanpy figures are drawn.
        mpl.rcParams.update(mpl.rcParamsDefault)

        umap_inputs = [("train_data", train_data)]
        for suffix, array in embeddings.items():
            adata = sc.AnnData(X=array)
            adata.obs['condition'] = data.obs['condition'].values
            adata.obs[cell_type_key] = data.obs[cell_type_key].values
            umap_inputs.append((suffix, adata))

        for suffix, adata in umap_inputs:
            sc.pp.neighbors(adata)
            sc.tl.umap(adata)
            sc.pl.umap(adata,
                       color=color,
                       save=f"_{data_name}_{cell_type}_{suffix}",
                       show=False)

        sc.pl.violin(cell_type_adata,
                     keys=top_100_genes[0],
                     groupby='condition',
                     save=f"_{data_name}_{cell_type}_{top_100_genes[0]}",
                     show=False)

        plt.close("all")
def visualize_trained_network_results_multimodal(data_dict, z_dim=100):
    """Save UMAP figures of the data and of a restored trVAE's embeddings.

    The model stored under ``../models/RCVAE/<name>/<z_dim>/`` is restored,
    the complete data set is encoded into the latent layer and the MMD layer
    (with true and with "fake" labels), and a UMAP coloured by ``condition``
    is saved for the input data and for each of the four embeddings below
    ``../results/RCVAE/...``.

    Args:
        data_dict: Mapping read with the keys ``name``, ``source_key`` and
            ``target_key``; the latter two are only used in the result path.
        z_dim: Latent size; forwarded to the model constructor and used in
            the model and result paths.

    Returns:
        None. All results are written to disk as figures.
    """
    import matplotlib as mpl

    plt.close("all")
    data_name = data_dict.get('name', None)
    source_key = data_dict.get('source_key', None)
    target_key = data_dict.get('target_key', None)

    data = sc.read(f"../data/{data_name}/train_{data_name}.h5ad")

    path_to_save = f"../results/RCVAE/{data_name}/{z_dim}/{source_key} to {target_key}/Visualizations/"
    os.makedirs(path_to_save, exist_ok=True)
    sc.settings.figdir = os.path.abspath(path_to_save)

    network = trvae.trVAE(
        x_dimension=data.shape[1],
        z_dimension=z_dim,
        model_path=f"../models/RCVAE/{data_name}/{z_dim}/",
    )
    network.restore_model()

    # `.toarray()` instead of `.A`: the attribute shortcut does not exist on
    # SciPy sparse arrays in recent releases.
    if sparse.issparse(data.X):
        data.X = data.X.toarray()

    feed_data = data.X
    train_labels, _ = trvae.label_encoder(data)
    fake_labels = np.ones(train_labels.shape)

    # NOTE(review): `to_mmd_layer` receives the network itself as first
    # positional argument, exactly as in the original call; confirm this
    # matches the signature of the installed trvae version.
    # Each entry pairs the exact `save` suffix with the raw embedding; the
    # double underscore in the fake-label latent name is kept as it was.
    embeddings = [
        (f"_{data_name}_latent_with_true_labels",
         network.to_latent(feed_data, train_labels)),
        (f"_{data_name}__latent_with_fake_labels",
         network.to_latent(feed_data, fake_labels)),
        (f"_{data_name}_mmd_latent_with_true_labels",
         network.to_mmd_layer(network, feed_data, train_labels, feed_fake=False)),
        (f"_{data_name}_mmd_latent_with_fake_labels",
         network.to_mmd_layer(network, feed_data, train_labels, feed_fake=True)),
    ]

    mpl.rcParams.update(mpl.rcParamsDefault)

    # Only the condition annotation is attached; this function reads no
    # cell-type column (the corresponding assignments were commented out).
    color = ['condition']
    umap_inputs = [(f'_{data_name}_train_data', data)]
    for save_name, array in embeddings:
        adata = sc.AnnData(X=array)
        adata.obs['condition'] = data.obs['condition'].values
        umap_inputs.append((save_name, adata))

    for save_name, adata in umap_inputs:
        sc.pp.neighbors(adata)
        sc.tl.umap(adata)
        sc.pl.umap(adata, color=color, save=save_name, show=False)

    plt.close("all")
def visualize_trained_network_results(data_dict, z_dim=100, mmd_dimension=128, loss_fn='mse'):
    """Save evaluation figures for restored trVAEMulti models, per cell type.

    For every selected cell type the model stored under
    ``../models/RCVAEMulti/<name>/<cell_type>/<z_dim>/`` is restored, the
    complete data set is encoded into the latent and MMD layers (with the
    true labels and with every condition index as "fake" label), the
    perturbations listed in ``data_dict["perturbation"]`` are predicted and
    plotted through ``visualize_multi_perturbation_between``, and UMAP and
    violin figures are saved below ``../results/RCVAEMulti/...``.

    NOTE(review): this function name is defined again further down in the
    visible source; if all definitions live in one module, the last one
    shadows this one.

    Args:
        data_dict: Mapping read with the keys ``name`` (required),
            ``source_conditions``, ``target_conditions``, ``cell_type``,
            ``need_merge``, ``label_encoder``, ``condition`` (defaults to
            ``'condition'``), ``spec_cell_types`` and ``perturbation`` (an
            iterable of ``(source, dest, name, source_label, target_label)``
            tuples).
        z_dim: Latent size; forwarded to the model and used in paths.
        mmd_dimension: Forwarded to the model constructor.
        loss_fn: Forwarded to the model constructor; any value other than
            ``'mse'`` also routes the data through ``normalize_hvg`` first.

    Returns:
        None. Figures are saved, and the predictions are written to
        ``../data/reconstructed/RCVAEMulti/<name>.h5ad`` when there are any.
    """
    import matplotlib as mpl

    plt.close("all")
    data_name = data_dict['name']
    source_keys = data_dict.get("source_conditions")
    target_keys = data_dict.get("target_conditions")
    cell_type_key = data_dict.get("cell_type", None)
    need_merge = data_dict.get('need_merge', False)
    label_encoder = data_dict.get('label_encoder', None)
    condition_key = data_dict.get('condition', 'condition')

    if need_merge:
        data, _ = merge_data(data_dict)
    else:
        data = sc.read(f"../data/{data_name}/{data_name}.h5ad")

    if loss_fn != 'mse':
        data = normalize_hvg(data,
                             filter_min_counts=False,
                             normalize_input=False,
                             logtrans_input=True)

    cell_types = data.obs[cell_type_key].unique().tolist()
    spec_cell_type = data_dict.get("spec_cell_types", None)
    if spec_cell_type:
        cell_types = spec_cell_type

    color = [condition_key, cell_type_key]

    def _annotated(array):
        # Wrap a raw embedding and attach condition / cell-type annotations.
        adata = sc.AnnData(X=array)
        adata.obs[condition_key] = data.obs[condition_key].values
        adata.obs[cell_type_key] = data.obs[cell_type_key].values
        return adata

    def _save_umap(adata, save_name):
        # Neighbour graph, UMAP embedding and saved figure for one AnnData.
        sc.pp.neighbors(adata)
        sc.tl.umap(adata)
        sc.pl.umap(adata,
                   color=color,
                   save=save_name,
                   show=False,
                   wspace=0.15,
                   frameon=False)

    for cell_type in cell_types:
        path_to_save = f"../results/RCVAEMulti/{data_name}/{cell_type}/{z_dim}/Visualizations/"
        os.makedirs(path_to_save, exist_ok=True)
        sc.settings.figdir = os.path.abspath(path_to_save)

        # Everything except the target-condition cells of this cell type.
        held_out = (data.obs[condition_key].isin(target_keys)) & (data.obs[cell_type_key] == cell_type)
        train_data = data.copy()[~held_out]

        cell_type_adata = data.copy()[data.obs[cell_type_key] == cell_type]
        n_conditions = len(train_data.obs[condition_key].unique().tolist())

        network = trvae.trVAEMulti(
            x_dimension=data.shape[1],
            z_dimension=z_dim,
            loss_fn=loss_fn,
            n_conditions=n_conditions,
            mmd_dimension=mmd_dimension,
            model_path=f"../models/RCVAEMulti/{data_name}/{cell_type}/{z_dim}/",
        )
        network.restore_model()

        # `.toarray()` instead of `.A`: the attribute shortcut does not exist
        # on SciPy sparse arrays in recent releases.
        if sparse.issparse(data.X):
            data.X = data.X.toarray()

        feed_data = data
        train_labels, _ = trvae.label_encoder(data, label_encoder, condition_key)

        # One constant label array per condition index. `zeros + i` is kept
        # (rather than `np.full`) so the arrays stay floating point.
        fake_labels = [
            np.zeros(train_labels.shape) + i for i in range(n_conditions)
        ]

        latent_with_true_labels = network.to_z_latent(feed_data.X, train_labels)
        latent_with_fake_labels = [
            network.to_z_latent(feed_data.X, labels) for labels in fake_labels
        ]
        mmd_latent_with_true_labels = network.to_mmd_layer(
            feed_data, train_labels, feed_fake=0)
        mmd_latent_with_fake_labels = [
            network.to_mmd_layer(feed_data, train_labels, feed_fake=i)
            for i in range(n_conditions)
        ]

        if data_name in ['pancreas', 'nmuil_count']:
            # No differential-expression ranking for these data sets.
            gene_list = None
            top_100_genes = None
        else:
            sc.tl.rank_genes_groups(cell_type_adata,
                                    groupby=condition_key,
                                    n_genes=100,
                                    method="wilcoxon")
            ranked_names = cell_type_adata.uns["rank_genes_groups"]["names"]
            if data_name in ["pbmc", 'endo_norm']:
                top_100_genes = ranked_names[target_keys[-1]].tolist()
                gene_list = top_100_genes[:10]
            else:
                top_50_down_genes = ranked_names[source_keys[0]].tolist()
                top_50_up_genes = ranked_names[target_keys[-1]].tolist()
                top_100_genes = top_50_up_genes + top_50_down_genes
                gene_list = top_50_down_genes[:5] + top_50_up_genes[:5]

        perturbation_list = data_dict.get("perturbation") or []
        pred_adatas = None
        for source, dest, name, source_label, target_label in perturbation_list:
            print(source, dest, name)
            pred_adata = visualize_multi_perturbation_between(
                network,
                cell_type_adata,
                pred_adatas,
                source_condition=source,
                target_condition=dest,
                name=name,
                source_label=source_label,
                target_label=target_label,
                cell_type=cell_type,
                data_name=data_name,
                top_100_genes=top_100_genes,
                gene_list=gene_list,
                path_to_save=path_to_save,
                condition_key=condition_key)
            if pred_adatas is None:
                pred_adatas = pred_adata
            else:
                pred_adatas = pred_adatas.concatenate(pred_adata)

        # Without any perturbation there is nothing to write; the original
        # called `write_h5ad` on None in that case and raised.
        if pred_adatas is not None:
            pred_adatas.write_h5ad(
                filename=f"../data/reconstructed/RCVAEMulti/{data_name}.h5ad")

        # Reset matplotlib settings before the scanpy figures are drawn.
        mpl.rcParams.update(mpl.rcParamsDefault)

        latent_with_true_labels = _annotated(latent_with_true_labels)
        for i in range(n_conditions):
            _save_umap(_annotated(latent_with_fake_labels[i]),
                       f"_{data_name}_{cell_type}_latent_with_fake_labels_{i}")

        mmd_latent_with_true_labels = _annotated(mmd_latent_with_true_labels)
        for i in range(n_conditions):
            # NOTE(review): unlike its siblings this file name carries neither
            # the cell type nor an "mmd" marker; kept byte-for-byte so the
            # existing output names do not change.
            _save_umap(_annotated(mmd_latent_with_fake_labels[i]),
                       f"_{data_name}_latent_with_fake_labels_{i}")

        _save_umap(train_data, f'_{data_name}_{cell_type}_train_data')
        _save_umap(latent_with_true_labels,
                   f"_{data_name}_{cell_type}_latent_with_true_labels")
        _save_umap(mmd_latent_with_true_labels,
                   f"_{data_name}_{cell_type}_mmd_latent_with_true_labels")

        # Violin plots need both ranked genes and at least one prediction.
        if gene_list is not None and pred_adatas is not None:
            for target_condition in target_keys:
                is_target = pred_adatas.obs[condition_key].str.endswith(target_condition)
                violin_adata = cell_type_adata.concatenate(pred_adatas[is_target])
                for gene in gene_list[:3]:
                    sc.pl.violin(
                        violin_adata,
                        keys=gene,
                        groupby=condition_key,
                        save=f"_{data_name}_{cell_type}_{gene}_{target_condition}.pdf",
                        show=False,
                        wspace=0.2,
                        rotation=90,
                        frameon=False)

        plt.close("all")
def visualize_batch_correction(data_dict, z_dim=100, mmd_dimension=128):
    """Save UMAPs of the data and of the trVAEMulti latent / MMD embeddings.

    For every selected cell type the model stored under
    ``../models/RCVAEMulti/<name>/<cell_type>/<z_dim>/`` is restored, the
    complete data set is encoded with its true labels into the MMD layer and
    the latent layer, and UMAPs coloured by the condition column are saved
    below ``../results/RCVAEMulti/...``. Both embeddings are also written to
    ``../data/mmd.h5ad`` and ``../data/latent.h5ad``.

    NOTE(review): the two ``.h5ad`` paths do not depend on the cell type, so
    each iteration overwrites the files of the previous one.

    Args:
        data_dict: Mapping read with the keys ``name`` (required),
            ``source_conditions`` (its length is passed as ``n_conditions``),
            ``target_conditions``, ``cell_type``, ``need_merge``,
            ``label_encoder``, ``condition`` (defaults to ``'condition'``)
            and ``spec_cell_types``.
        z_dim: Latent size; forwarded to the model and used in paths.
        mmd_dimension: Forwarded to the model constructor.

    Returns:
        None. Results are written to disk.
    """
    import matplotlib as mpl

    plt.close("all")
    data_name = data_dict['name']
    source_keys = data_dict.get("source_conditions")
    target_keys = data_dict.get("target_conditions")
    cell_type_key = data_dict.get("cell_type", None)
    need_merge = data_dict.get('need_merge', False)
    label_encoder = data_dict.get('label_encoder', None)
    condition_key = data_dict.get('condition', 'condition')

    if need_merge:
        data, _ = merge_data(data_dict)
    else:
        data = sc.read(f"../data/{data_name}/train_{data_name}.h5ad")

    cell_types = data.obs[cell_type_key].unique().tolist()
    spec_cell_type = data_dict.get("spec_cell_types", None)
    if spec_cell_type:
        cell_types = spec_cell_type

    color = [condition_key]

    for cell_type in cell_types:
        path_to_save = f"../results/RCVAEMulti/{data_name}/{cell_type}/{z_dim}/Visualizations/"
        os.makedirs(path_to_save, exist_ok=True)
        sc.settings.figdir = os.path.abspath(path_to_save)

        # Everything except the target-condition cells of this cell type.
        # (The original also built an unused per-cell-type view here.)
        held_out = (data.obs[condition_key].isin(target_keys)) & (data.obs[cell_type_key] == cell_type)
        train_data = data.copy()[~held_out]

        network = trvae.trVAEMulti(
            x_dimension=data.shape[1],
            z_dimension=z_dim,
            n_conditions=len(source_keys),
            mmd_dimension=mmd_dimension,
            model_path=f"../models/RCVAEMulti/{data_name}/{cell_type}/{z_dim}/",
        )
        network.restore_model()

        # `.toarray()` instead of `.A`: the attribute shortcut does not exist
        # on SciPy sparse arrays in recent releases.
        if sparse.issparse(data.X):
            data.X = data.X.toarray()

        feed_data = data.X
        train_labels, _ = trvae.label_encoder(data, label_encoder, condition_key)

        # NOTE(review): `to_mmd_layer` receives the network itself as first
        # positional argument, exactly as in the original call; confirm this
        # matches the signature of the installed trvae version.
        mmd_array = network.to_mmd_layer(network, feed_data, train_labels, feed_fake=0)
        latent_array = network.to_z_latent(feed_data, train_labels)

        mpl.rcParams.update(mpl.rcParamsDefault)

        annotated = []
        for array in (mmd_array, latent_array):
            adata = sc.AnnData(X=array)
            adata.obs[condition_key] = data.obs[condition_key].values
            adata.obs[cell_type_key] = data.obs[cell_type_key].values
            annotated.append(adata)
        mmd_latent_with_true_labels, latent_with_true_labels = annotated

        umap_inputs = (
            (train_data, f'_{data_name}_{cell_type}_train_data'),
            (mmd_latent_with_true_labels,
             f"_{data_name}_{cell_type}_mmd_latent_with_true_labels"),
            (latent_with_true_labels,
             f"_{data_name}_{cell_type}_latent_with_true_labels"),
        )
        for adata, save_name in umap_inputs:
            sc.pp.neighbors(adata)
            sc.tl.umap(adata)
            sc.pl.umap(adata,
                       color=color,
                       save=save_name,
                       show=False,
                       wspace=0.15,
                       frameon=False)

        # Disabled experiment kept from the original source (highlights the
        # current cell type, split by first target condition vs. the rest):
        # mmd_latent_with_true_labels.obs['mmd'] = 'others'
        # mmd_latent_with_true_labels.obs['mmd'] = mmd_latent_with_true_labels.obs.mmd.astype(str)
        # mmd_latent_with_true_labels.obs['mmd'].cat.add_categories([f'alpha-{target_keys[0]}'], inplace=True)
        # mmd_latent_with_true_labels.obs['mmd'].cat.add_categories([f'alpha-others'], inplace=True)
        # print(mmd_latent_with_true_labels.obs['mmd'].cat.categories)
        # mmd_latent_with_true_labels.obs.loc[((mmd_latent_with_true_labels.obs[condition_key] == target_keys[0]) &
        #                                      mmd_latent_with_true_labels.obs[
        #                                          cell_type_key] == cell_type), 'mmd'] = f'alpha-{target_keys[0]}'
        # mmd_latent_with_true_labels.obs.loc[((mmd_latent_with_true_labels.obs[condition_key] != target_keys[0]) &
        #                                      mmd_latent_with_true_labels.obs[
        #                                          cell_type_key] == cell_type), 'mmd'] = f'alpha-others'
        #
        # sc.pl.umap(mmd_latent_with_true_labels, color='mmd',
        #            save=f"_{data_name}_{cell_type}_mmd_latent_with_true_labels_cell_comparison",
        #            show=False,
        #            wspace=0.15,
        #            frameon=False)

        mmd_latent_with_true_labels.write_h5ad('../data/mmd.h5ad')
        latent_with_true_labels.write_h5ad('../data/latent.h5ad')
def visualize_trained_network_results(data_dict, z_dim=100, arch_style=1, preprocess=True, max_size=80000):
    """Save UMAP figures of image data and of a restored DCtrVAE's embeddings.

    The images are loaded (through ``trvae.prepare_and_load_celeba`` for
    ``"celeba"``, otherwise from ``../data/<name>/<name>.h5ad``), the model
    stored under ``../models/RCCVAE/...`` is restored, the images are encoded
    into the latent and MMD layers with true and "fake" labels, and a UMAP is
    saved for the input data and for each embedding below
    ``../results/RCCVAE/.../UMAPs/``.

    Args:
        data_dict: Mapping read with the keys ``name``, ``source_key``,
            ``target_key``, ``width``, ``height``, ``n_channels``,
            ``train_digits``, ``test_digits``, ``attribute`` and, for
            ``"celeba"``, ``gender``.
        z_dim: Latent size; forwarded to the model and used in paths.
        arch_style: Forwarded to the model constructor and used in paths.
        preprocess: When true, pixel values are divided by 255 in place
            before encoding; the flag is also part of the paths.
        max_size: Passed as ``max_n_images`` to the CelebA loader.

    Returns:
        None. All results are written to disk as figures.
    """
    plt.close("all")
    data_name = data_dict.get('name', None)
    source_key = data_dict.get('source_key', None)
    target_key = data_dict.get('target_key', None)
    img_width = data_dict.get('width', None)
    img_height = data_dict.get('height', None)
    n_channels = data_dict.get('n_channels', None)
    train_digits = data_dict.get('train_digits', None)
    test_digits = data_dict.get('test_digits', None)
    attribute = data_dict.get('attribute', None)

    path_to_save = f"../results/RCCVAE/{data_name}-{img_width}x{img_height}-{preprocess}/{arch_style}-{z_dim}/{source_key} to {target_key}/UMAPs/"
    os.makedirs(path_to_save, exist_ok=True)
    sc.settings.figdir = os.path.abspath(path_to_save)

    if data_name == "celeba":
        gender = data_dict.get('gender', None)
        data = trvae.prepare_and_load_celeba(
            file_path="../data/celeba/img_align_celeba.zip",
            attr_path="../data/celeba/list_attr_celeba.txt",
            landmark_path="../data/celeba/list_landmarks_align_celeba.txt",
            gender=gender,
            attribute=attribute,
            max_n_images=max_size,
            img_width=img_width,
            img_height=img_height,
            restore=True,
            save=False)

        # `.toarray()` instead of `.A`: the attribute shortcut does not exist
        # on SciPy sparse arrays in recent releases.
        if sparse.issparse(data.X):
            data.X = data.X.toarray()

        train_images = data.X
        train_data = anndata.AnnData(X=data)

        # Replace the numeric +1 / -1 codes with readable category names.
        train_data.obs['condition'] = data.obs['condition'].values
        train_data.obs.loc[train_data.obs['condition'] == 1, 'condition'] = f'with {attribute}'
        train_data.obs.loc[train_data.obs['condition'] == -1, 'condition'] = f'without {attribute}'

        train_data.obs['labels'] = data.obs['labels'].values
        train_data.obs.loc[train_data.obs['labels'] == 1, 'labels'] = 'Male'
        train_data.obs.loc[train_data.obs['labels'] == -1, 'labels'] = 'Female'
    else:
        train_data = sc.read(f"../data/{data_name}/{data_name}.h5ad")
        train_images = np.reshape(train_data.X,
                                  (-1, img_width, img_height, n_channels))

    # Both branches scaled the images the same way; the division stays in
    # place, as in the original, so any array sharing memory is scaled too.
    if preprocess:
        train_images /= 255.0

    train_labels, _ = trvae.label_encoder(train_data)
    fake_labels = np.ones(train_labels.shape)

    network = trvae.DCtrVAE(
        x_dimension=(img_width, img_height, n_channels),
        z_dimension=z_dim,
        arch_style=arch_style,
        model_path=f"../models/RCCVAE/{data_name}-{img_width}x{img_height}-{preprocess}/{arch_style}-{z_dim}/",
    )
    network.restore_model()

    train_data_feed = np.reshape(train_images,
                                 (-1, img_width, img_height, n_channels))

    latent_with_true_labels = network.to_z_latent(train_data_feed, train_labels)
    latent_with_fake_labels = network.to_z_latent(train_data_feed, fake_labels)
    # NOTE(review): `to_mmd_layer` receives the network itself as first
    # positional argument, exactly as in the original call; confirm this
    # matches the signature of the installed trvae version.
    mmd_latent_with_true_labels = network.to_mmd_layer(
        network, train_data_feed, train_labels, feed_fake=False)
    mmd_latent_with_fake_labels = network.to_mmd_layer(
        network, train_data_feed, train_labels, feed_fake=True)

    latent_with_true_labels = sc.AnnData(X=latent_with_true_labels)
    latent_with_true_labels.obs['condition'] = pd.Categorical(
        train_data.obs['condition'].values)

    latent_with_fake_labels = sc.AnnData(X=latent_with_fake_labels)
    latent_with_fake_labels.obs['condition'] = pd.Categorical(
        train_data.obs['condition'].values)

    mmd_latent_with_true_labels = sc.AnnData(X=mmd_latent_with_true_labels)
    mmd_latent_with_true_labels.obs['condition'] = train_data.obs['condition'].values

    mmd_latent_with_fake_labels = sc.AnnData(X=mmd_latent_with_fake_labels)
    mmd_latent_with_fake_labels.obs['condition'] = train_data.obs['condition'].values

    # (save-name suffix, embedding) in the order the figures are produced.
    embeddings = (
        ("latent_with_true_labels", latent_with_true_labels),
        ("latent_with_fake_labels", latent_with_fake_labels),
        ("mmd_latent_with_true_labels", mmd_latent_with_true_labels),
        ("mmd_latent_with_fake_labels", mmd_latent_with_fake_labels),
    )

    if "mnist" in data_name or data_name == "celeba":
        for _, adata in embeddings:
            adata.obs['labels'] = pd.Categorical(train_data.obs['labels'].values)
        color = ['condition', 'labels']
    else:
        color = ['condition']

    if train_digits is not None:
        # Only target-condition samples of the test digits count as held
        # out; every other listed combination is tagged as training data.
        type_rules = (
            (source_key, train_digits, 'training'),
            (source_key, test_digits, 'training'),
            (target_key, train_digits, 'training'),
            (target_key, test_digits, 'heldout'),
        )
        for condition, digits, tag in type_rules:
            selected = (train_data.obs['condition'] == condition) & (train_data.obs['labels'].isin(digits))
            train_data.obs.loc[selected, 'type'] = tag

    sc.pp.neighbors(train_data)
    sc.tl.umap(train_data)
    sc.pl.umap(train_data,
               color=color,
               save=f'_{data_name}_train_data.png',
               show=False,
               wspace=0.5)

    if train_digits is not None:
        # The embedding computed just above is reused; the original ran
        # `sc.tl.umap` a second time on the unchanged neighbour graph.
        sc.pl.umap(train_data,
                   color=['type', 'labels'],
                   save=f'_{data_name}_data_type.png',
                   show=False)

    for suffix, adata in embeddings:
        sc.pp.neighbors(adata)
        sc.tl.umap(adata)
        sc.pl.umap(adata,
                   color=color,
                   save=f"_{data_name}_{suffix}.png",
                   wspace=0.5,
                   show=False)

    plt.close("all")