def _annotate_latent(latent, reference, condition_key, cell_type_key):
    """Wrap ``latent`` in an AnnData labelled with ``reference``'s condition and cell-type columns.

    Labels are copied positionally (``.values``), so the rows of ``latent`` must be in the same
    order as the cells of ``reference``.
    """
    annotated = sc.AnnData(X=latent)
    annotated.obs[condition_key] = reference.obs[condition_key].values
    annotated.obs[cell_type_key] = reference.obs[cell_type_key].values
    return annotated


def _plot_umap(adata, color, save):
    """Compute neighbours and UMAP on ``adata`` in place, then save the UMAP plot via scanpy."""
    sc.pp.neighbors(adata)
    sc.tl.umap(adata)
    sc.pl.umap(adata, color=color, save=save, show=False, wspace=0.15, frameon=False)


def _select_marker_genes(cell_type_adata, data_name, condition_key, source_keys, target_keys):
    """Return ``(top_100_genes, gene_list)`` used to evaluate and plot predictions.

    'pancreas' and 'nmuil_count' have no marker-gene step and yield ``(None, None)``. Every other
    dataset runs a Wilcoxon ``rank_genes_groups`` on ``cell_type_adata`` (result stored in its
    ``.uns``). 'pbmc'/'endo_norm' use only the genes ranked for the last target condition; the
    rest combine genes ranked for the first source and the last target condition.
    """
    if data_name in ['pancreas', 'nmuil_count']:
        return None, None

    sc.tl.rank_genes_groups(cell_type_adata, groupby=condition_key, n_genes=100,
                            method="wilcoxon")
    ranked = cell_type_adata.uns["rank_genes_groups"]["names"]

    if data_name in ["pbmc", 'endo_norm']:
        top_100_genes = ranked[target_keys[-1]].tolist()
        return top_100_genes, top_100_genes[:10]

    top_50_down_genes = ranked[source_keys[0]].tolist()
    top_50_up_genes = ranked[target_keys[-1]].tolist()
    top_100_genes = top_50_up_genes + top_50_down_genes
    gene_list = top_50_down_genes[:5] + top_50_up_genes[:5]
    return top_100_genes, gene_list


def visualize_trained_network_results(data_dict, z_dim=100, mmd_dimension=128, loss_fn='mse'):
    """Restore per-cell-type trVAEMulti models and save latent-space and prediction plots.

    For each cell type (``data_dict["spec_cell_types"]`` when given, otherwise every value of the
    cell-type column) the model is restored from
    ``../models/RCVAEMulti/{name}/{cell_type}/{z_dim}/``. The following are saved under
    ``../results/RCVAEMulti/{name}/{cell_type}/{z_dim}/Visualizations/``:

    * UMAPs of the z-latent and MMD-layer embeddings, computed with true and with fake condition
      labels. A UMAP of the training split is also saved: all cells except this cell type in the
      target conditions.
    * The plots produced by ``visualize_multi_perturbation_between`` for every entry of
      ``data_dict["perturbation"]``.
    * Violin plots of the first three selected marker genes. These are only made for datasets
      with a marker-gene step, and only when at least one perturbation produced predictions.

    Concatenated perturbation predictions, when there are any, are written to
    ``../data/reconstructed/RCVAEMulti/{name}.h5ad``.

    Args:
        data_dict: dataset config. Keys read: 'name' (required), 'source_conditions',
            'target_conditions', 'cell_type', 'need_merge', 'label_encoder', 'condition',
            'spec_cell_types', and 'perturbation'. Each 'perturbation' entry is a 5-tuple
            ``(source, dest, name, source_label, target_label)``.
        z_dim: latent size; also part of the model and results paths.
        mmd_dimension: MMD layer width passed to ``trVAEMulti``.
        loss_fn: loss the model was trained with. When the data is read from disk and this is
            not 'mse', ``normalize_hvg`` is applied first (as in ``train_network``).

    Note:
        ``data.X`` is densified in place the first time it is found to be sparse.
    """
    import matplotlib as mpl

    plt.close("all")
    data_name = data_dict['name']
    source_keys = data_dict.get("source_conditions")
    target_keys = data_dict.get("target_conditions")
    cell_type_key = data_dict.get("cell_type", None)
    need_merge = data_dict.get('need_merge', False)
    label_encoder = data_dict.get('label_encoder', None)
    condition_key = data_dict.get('condition', 'condition')
    # train_network builds pancreas models with leaky ReLU. The restored network has to use the
    # same activation, or the loaded weights are evaluated through a different architecture.
    use_leaky_relu = data_name == 'pancreas'
    reconstructed_dir = "../data/reconstructed/RCVAEMulti/"

    if need_merge:
        data, _ = merge_data(data_dict)
    else:
        data = sc.read(f"../data/{data_name}/{data_name}.h5ad")
        if loss_fn != 'mse':
            data = normalize_hvg(data, filter_min_counts=False, normalize_input=False,
                                 logtrans_input=True)

    cell_types = data.obs[cell_type_key].unique().tolist()
    spec_cell_type = data_dict.get("spec_cell_types", None)
    if spec_cell_type:
        cell_types = spec_cell_type

    for cell_type in cell_types:
        path_to_save = f"../results/RCVAEMulti/{data_name}/{cell_type}/{z_dim}/Visualizations/"
        os.makedirs(path_to_save, exist_ok=True)
        sc.settings.figdir = os.path.abspath(path_to_save)

        # Subset first, then copy: avoids duplicating the whole dataset for each subset.
        held_out = ((data.obs[condition_key].isin(target_keys))
                    & (data.obs[cell_type_key] == cell_type))
        train_data = data[~held_out].copy()
        cell_type_adata = data[data.obs[cell_type_key] == cell_type].copy()

        n_conditions = len(train_data.obs[condition_key].unique().tolist())

        network = trvae.trVAEMulti(
            x_dimension=data.shape[1],
            z_dimension=z_dim,
            loss_fn=loss_fn,
            n_conditions=n_conditions,
            mmd_dimension=mmd_dimension,
            model_path=f"../models/RCVAEMulti/{data_name}/{cell_type}/{z_dim}/",
            use_leaky_relu=use_leaky_relu,
        )
        network.restore_model()

        if sparse.issparse(data.X):
            data.X = data.X.A

        train_labels, _ = trvae.label_encoder(data, label_encoder, condition_key)

        latent_with_true_labels = network.to_z_latent(data.X, train_labels)
        # Fake labels: every cell is assigned condition index i.
        latent_with_fake_labels = [
            network.to_z_latent(data.X, np.zeros(train_labels.shape) + i)
            for i in range(n_conditions)
        ]
        mmd_latent_with_true_labels = network.to_mmd_layer(data, train_labels, feed_fake=0)
        # NOTE(review): for i == 0 this is the same call as the true-label embedding above. If
        # to_mmd_layer treats feed_fake=0 as "use the encoder labels", fake-label plot 0
        # duplicates the true-label plot. Confirm against trVAEMulti.to_mmd_layer.
        mmd_latent_with_fake_labels = [
            network.to_mmd_layer(data, train_labels, feed_fake=i) for i in range(n_conditions)
        ]

        top_100_genes, gene_list = _select_marker_genes(
            cell_type_adata, data_name, condition_key, source_keys, target_keys)

        pred_adatas = None
        for source, dest, name, source_label, target_label in data_dict.get("perturbation", []):
            print(source, dest, name)
            pred_adata = visualize_multi_perturbation_between(
                network, cell_type_adata, pred_adatas,
                source_condition=source, target_condition=dest, name=name,
                source_label=source_label, target_label=target_label,
                cell_type=cell_type, data_name=data_name,
                top_100_genes=top_100_genes, gene_list=gene_list,
                path_to_save=path_to_save, condition_key=condition_key)
            if pred_adatas is None:
                pred_adatas = pred_adata
            else:
                pred_adatas = pred_adatas.concatenate(pred_adata)

        # Without any configured perturbation there is nothing to write (pred_adatas is None).
        if pred_adatas is not None:
            os.makedirs(reconstructed_dir, exist_ok=True)
            # NOTE(review): the filename has no cell_type, so each loop iteration overwrites the
            # predictions of the previous cell type. Confirm this is intended.
            pred_adatas.write_h5ad(
                filename=f"../data/reconstructed/RCVAEMulti/{data_name}.h5ad")

        mpl.rcParams.update(mpl.rcParamsDefault)
        color = [condition_key, cell_type_key]

        latent_with_true_labels = _annotate_latent(
            latent_with_true_labels, data, condition_key, cell_type_key)
        mmd_latent_with_true_labels = _annotate_latent(
            mmd_latent_with_true_labels, data, condition_key, cell_type_key)

        for i, latent in enumerate(latent_with_fake_labels):
            _plot_umap(_annotate_latent(latent, data, condition_key, cell_type_key), color,
                       f"_{data_name}_{cell_type}_latent_with_fake_labels_{i}")
        for i, latent in enumerate(mmd_latent_with_fake_labels):
            _plot_umap(_annotate_latent(latent, data, condition_key, cell_type_key), color,
                       f"_{data_name}_{cell_type}_mmd_latent_with_fake_labels_{i}")

        _plot_umap(train_data, color, f'_{data_name}_{cell_type}_train_data')
        _plot_umap(latent_with_true_labels, color,
                   f"_{data_name}_{cell_type}_latent_with_true_labels")
        _plot_umap(mmd_latent_with_true_labels, color,
                   f"_{data_name}_{cell_type}_mmd_latent_with_true_labels")

        if gene_list is not None and pred_adatas is not None:
            for target_condition in target_keys:
                pred_adata = pred_adatas[
                    pred_adatas.obs[condition_key].str.endswith(target_condition)]
                violin_adata = cell_type_adata.concatenate(pred_adata)
                for gene in gene_list[:3]:
                    sc.pl.violin(
                        violin_adata,
                        keys=gene,
                        groupby=condition_key,
                        save=f"_{data_name}_{cell_type}_{gene}_{target_condition}.pdf",
                        show=False,
                        wspace=0.2,
                        rotation=90,
                        frameon=False)
        plt.close("all")
def visualize_batch_correction(data_dict, z_dim=100, mmd_dimension=128):
    """Plot and save the latent spaces of restored per-cell-type trVAEMulti models.

    For each cell type (``data_dict["spec_cell_types"]`` when given, otherwise every value of the
    cell-type column), the model is restored from
    ``../models/RCVAEMulti/{name}/{cell_type}/{z_dim}/``. All cells are then embedded with their
    true condition labels, in both the z-latent and the MMD layer. Three UMAPs coloured by
    condition are saved under ``../results/RCVAEMulti/{name}/{cell_type}/{z_dim}/Visualizations/``:
    the training split, the MMD embedding, and the z-latent embedding. Both embeddings are also
    written to ``../data/mmd.h5ad`` and ``../data/latent.h5ad``.

    Args:
        data_dict: dataset config. Keys read: 'name' (required), 'source_conditions',
            'target_conditions', 'cell_type', 'need_merge', 'label_encoder', 'condition',
            'spec_cell_types'.
        z_dim: latent size; also part of the model and results paths.
        mmd_dimension: MMD layer width passed to ``trVAEMulti``.

    Note:
        ``data.X`` is densified in place the first time it is found to be sparse.
    """
    import matplotlib as mpl

    plt.close("all")
    data_name = data_dict['name']
    source_keys = data_dict.get("source_conditions")
    target_keys = data_dict.get("target_conditions")
    cell_type_key = data_dict.get("cell_type", None)
    need_merge = data_dict.get('need_merge', False)
    label_encoder = data_dict.get('label_encoder', None)
    condition_key = data_dict.get('condition', 'condition')
    # train_network builds pancreas models with leaky ReLU. The restored network has to use the
    # same activation, or the loaded weights are evaluated through a different architecture.
    use_leaky_relu = data_name == 'pancreas'

    if need_merge:
        data, _ = merge_data(data_dict)
    else:
        data = sc.read(f"../data/{data_name}/train_{data_name}.h5ad")

    cell_types = data.obs[cell_type_key].unique().tolist()
    spec_cell_type = data_dict.get("spec_cell_types", None)
    if spec_cell_type:
        cell_types = spec_cell_type

    for cell_type in cell_types:
        path_to_save = f"../results/RCVAEMulti/{data_name}/{cell_type}/{z_dim}/Visualizations/"
        os.makedirs(path_to_save, exist_ok=True)
        sc.settings.figdir = os.path.abspath(path_to_save)

        # Subset first, then copy: avoids duplicating the whole dataset.
        held_out = ((data.obs[condition_key].isin(target_keys))
                    & (data.obs[cell_type_key] == cell_type))
        train_data = data[~held_out].copy()

        network = trvae.trVAEMulti(
            x_dimension=data.shape[1],
            z_dimension=z_dim,
            # NOTE(review): train_network derives n_conditions from the distinct conditions in
            # the training split, not from source_conditions. Confirm these counts agree.
            n_conditions=len(source_keys),
            mmd_dimension=mmd_dimension,
            model_path=f"../models/RCVAEMulti/{data_name}/{cell_type}/{z_dim}/",
            use_leaky_relu=use_leaky_relu,
        )
        network.restore_model()

        if sparse.issparse(data.X):
            data.X = data.X.A

        train_labels, _ = trvae.label_encoder(data, label_encoder, condition_key)
        # to_mmd_layer is called like in visualize_trained_network_results: the AnnData itself
        # (not its matrix), then the encoded labels, then feed_fake. It is a bound method, so
        # `network` must not be passed again.
        mmd_latent = network.to_mmd_layer(data, train_labels, feed_fake=0)
        z_latent = network.to_z_latent(data.X, train_labels)

        mpl.rcParams.update(mpl.rcParamsDefault)
        color = [condition_key]

        mmd_latent_with_true_labels = sc.AnnData(X=mmd_latent)
        mmd_latent_with_true_labels.obs[condition_key] = data.obs[condition_key].values
        mmd_latent_with_true_labels.obs[cell_type_key] = data.obs[cell_type_key].values

        latent_with_true_labels = sc.AnnData(X=z_latent)
        latent_with_true_labels.obs[condition_key] = data.obs[condition_key].values
        latent_with_true_labels.obs[cell_type_key] = data.obs[cell_type_key].values

        for adata, suffix in ((train_data, 'train_data'),
                              (mmd_latent_with_true_labels, 'mmd_latent_with_true_labels'),
                              (latent_with_true_labels, 'latent_with_true_labels')):
            sc.pp.neighbors(adata)
            sc.tl.umap(adata)
            sc.pl.umap(adata, color=color, save=f"_{data_name}_{cell_type}_{suffix}",
                       show=False, wspace=0.15, frameon=False)

        # NOTE(review): these paths contain neither data_name nor cell_type, so every iteration
        # (and every dataset) overwrites the previous files. Confirm this is intended.
        mmd_latent_with_true_labels.write_h5ad('../data/mmd.h5ad')
        latent_with_true_labels.write_h5ad('../data/latent.h5ad')
def train_network(
        data_dict=None,
        z_dim=100,
        mmd_dimension=256,
        alpha=0.001,
        beta=100,
        kernel='multi-scale-rbf',
        n_epochs=500,
        batch_size=512,
        early_stop_limit=50,
        dropout_rate=0.2,
        learning_rate=0.001,
        loss_fn='mse',
        verbose=2,
):
    """Train one trVAEMulti model per cell type, with that cell type's target conditions held out.

    The train/validation splits come either from ``merge_data(data_dict)`` (when 'need_merge' is
    set) or from ``../data/{name}/{name}.h5ad``. The disk file is passed through
    ``normalize_hvg`` when ``loss_fn`` is not 'mse', then split 80/20. For every cell type
    (``data_dict["spec_cell_types"]`` when given, otherwise every value of the cell-type
    column), cells of that type in any target condition are removed from both splits. A model is
    then trained and saved to ``../models/RCVAEMulti/{name}/{cell_type}/{z_dim}/``.

    Nothing is trained when ``data_dict`` has no 'cell_type' entry.

    Args:
        data_dict: dataset config. Keys read: 'name' (required), 'target_conditions',
            'cell_type', 'need_merge', 'label_encoder', 'condition', 'spec_cell_types'.
        z_dim, mmd_dimension, alpha, beta, kernel, dropout_rate, learning_rate, loss_fn:
            forwarded to the ``trVAEMulti`` constructor.
        n_epochs, batch_size, early_stop_limit, verbose: forwarded to ``trVAEMulti.train``.
    """
    dataset_name = data_dict['name']
    held_out_conditions = data_dict.get("target_conditions")
    cell_col = data_dict.get("cell_type", None)
    encoder_mapping = data_dict.get('label_encoder', None)
    condition_col = data_dict.get('condition', 'condition')

    if data_dict.get('need_merge', False):
        train_data, valid_data = merge_data(data_dict)
    else:
        adata = sc.read(f"../data/{dataset_name}/{dataset_name}.h5ad")
        if loss_fn != 'mse':
            adata = normalize_hvg(adata, filter_min_counts=False, normalize_input=False,
                                  logtrans_input=True)
        train_data, valid_data = train_test_split(adata, 0.80)

    requested_cell_types = data_dict.get("spec_cell_types", None)
    if cell_col is None:
        return

    observed_cell_types = train_data.obs[cell_col].unique().tolist()
    cell_types = requested_cell_types or observed_cell_types
    # Only the pancreas models use leaky ReLU; the choice is the same for every cell type.
    leaky = dataset_name == 'pancreas'

    def _drop_held_out(split, cell_type):
        """Return ``split`` minus the cells of ``cell_type`` that are in a held-out condition."""
        is_held_out = ((split.obs[cell_col] == cell_type)
                       & (split.obs[condition_col].isin(held_out_conditions)))
        return split.copy()[~is_held_out]

    for cell_type in cell_types:
        fit_split = _drop_held_out(train_data, cell_type)
        val_split = _drop_held_out(valid_data, cell_type)
        condition_count = len(fit_split.obs[condition_col].unique().tolist())

        model = trvae.trVAEMulti(
            x_dimension=fit_split.shape[1],
            z_dimension=z_dim,
            n_conditions=condition_count,
            mmd_dimension=mmd_dimension,
            alpha=alpha,
            beta=beta,
            kernel=kernel,
            learning_rate=learning_rate,
            loss_fn=loss_fn,
            model_path=f"../models/RCVAEMulti/{dataset_name}/{cell_type}/{z_dim}/",
            dropout_rate=dropout_rate,
            use_leaky_relu=leaky,
        )
        model.train(
            fit_split,
            encoder_mapping,
            condition_col,
            use_validation=True,
            valid_adata=val_split,
            n_epochs=n_epochs,
            batch_size=batch_size,
            verbose=verbose,
            early_stop_limit=early_stop_limit,
            shuffle=True,
            save=True,
        )
        print(f"Model for {cell_type} has been trained")