def de_stats(vae, data_loader, M_sampling=100, use_cuda=True):
    """
    Collect posterior scale samples and matching labels over a whole dataset,
    repeating each cell M_sampling times.

    :param vae: The generative vae and encoder network
    :param data_loader: a data loader for a particular dataset
    :param M_sampling: number of posterior samples per cell
    :return: (px_scale, all_labels) where px_scale has shape
        (n_cells * M_sampling, n_genes) and all_labels has shape
        (n_cells * M_sampling, 1)
    """
    px_scales = []
    all_labels = []
    for tensors in data_loader:
        if use_cuda:
            tensors = to_cuda(tensors)
        sample_batch, _, _, batch_index, labels = tensors
        sample_batch = sample_batch.type(torch.float32)
        # Repeat each cell M_sampling times so that get_sample_scale draws
        # M_sampling posterior samples per cell
        sample_batch = sample_batch.repeat(1, M_sampling).view(-1, sample_batch.size(1))
        batch_index = batch_index.repeat(1, M_sampling).view(-1, 1)
        labels = labels.repeat(1, M_sampling).view(-1, 1)
        px_scales += [vae.get_sample_scale(sample_batch, batch_index=batch_index, y=labels).cpu()]
        all_labels += [labels.cpu()]

    px_scale = torch.cat(px_scales)
    all_labels = torch.cat(all_labels)
    return px_scale, all_labels

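# Example usage (sketch): assuming a trained `vae` and a `data_loader` built elsewhere in
# this repository, the scale samples and matching labels can then be split by cell type
# for differential expression, as done in `get_statistics` below:
#
#     px_scale, all_labels = de_stats(vae, data_loader, M_sampling=100, use_cuda=True)
#     scales_a = px_scale[all_labels.view(-1) == 0]  # hypothetical cell type 0
#     scales_b = px_scale[all_labels.view(-1) == 1]  # hypothetical cell type 1
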
def fit(self, n_epochs=20, lr=1e-3):
    optimizer = self.optimizer if hasattr(self, 'optimizer') else \
        torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), lr=lr)
    self.epoch = 0
    self.n_epochs = n_epochs
    self.compute_metrics()
    with trange(n_epochs, desc="training", file=sys.stdout, disable=self.verbose) as pbar:
        # We have to use tqdm this way so it works in Jupyter notebook.
        # See https://stackoverflow.com/questions/42212810/tqdm-in-jupyter-notebook
        self.on_epoch_begin()
        for epoch in pbar:
            pbar.update(1)
            for tensors_list in self.data_loaders:
                loss = self.loss(*[to_cuda(tensors, use_cuda=self.use_cuda)
                                   for tensors in tensors_list])
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            if not self.on_epoch_end():
                break
    if self.save_best_state_metric is not None:
        self.model.load_state_dict(self.best_state_dict)
        self.compute_metrics()

def imputation(vae, data_loader, rate=0.1):
    """Corrupt a fraction `rate` of the nonzero entries, impute them with the model,
    and return the median absolute error on the corrupted entries."""
    distance_list = torch.FloatTensor([])
    for tensorlist in data_loader:
        if vae.use_cuda:
            tensorlist = to_cuda(tensorlist)
        sample_batch, local_l_mean, local_l_var, batch_index, labels = tensorlist
        sample_batch = sample_batch.type(torch.float32)

        # Zero out a random subset of the nonzero entries
        dropout_batch = sample_batch.clone()
        indices = torch.nonzero(dropout_batch)
        i, j = indices[:, 0], indices[:, 1]
        ix = torch.LongTensor(np.random.choice(range(len(i)),
                                               int(np.floor(rate * len(i))),
                                               replace=False))
        dropout_batch[i[ix], j[ix]] *= 0
        if vae.use_cuda:
            # `async` is a reserved word in Python >= 3.7; `non_blocking` assumes the
            # helper forwards this keyword to `.cuda()` (PyTorch >= 0.4 naming)
            ix, i, j = to_cuda([ix, i, j], non_blocking=False)

        px_rate = vae.get_sample_rate(dropout_batch, labels, batch_index=batch_index)
        distance_list = torch.cat([
            distance_list,
            torch.abs(px_rate[i[ix], j[ix]] - sample_batch[i[ix], j[ix]]).cpu()
        ])
    return torch.median(distance_list)

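# Example usage (sketch): assuming a trained `vae` and a `data_loader` built elsewhere in
# this repository, the imputation benchmark reports the median absolute error on the
# entries that were artificially zeroed out:
#
#     imputation_error = imputation(vae, data_loader, rate=0.1)
#     print("median imputation error: %.4f" % imputation_error.item())
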
def get_latent(vae, data_loader, use_cuda=True):
    latent = []
    batch_indices = []
    labels = []
    for tensors in data_loader:
        tensors = to_cuda(tensors, use_cuda=use_cuda)
        sample_batch, local_l_mean, local_l_var, batch_index, label = tensors
        sample_batch = sample_batch.type(torch.float32)
        latent += [vae.sample_from_posterior_z(sample_batch, y=label)]
        batch_indices += [batch_index]
        labels += [label]
    # Move results to CPU (and off the autograd graph) before converting to numpy
    return (torch.cat(latent).detach().cpu().numpy(),
            torch.cat(batch_indices).detach().cpu().numpy(),
            torch.cat(labels).detach().cpu().numpy().ravel())

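# Example usage (sketch): assuming a trained `vae` and a `data_loader`, this variant
# returns numpy arrays ready for downstream visualization or clustering; the variant
# below keeps everything as torch tensors instead:
#
#     latent, batch_indices, labels = get_latent(vae, data_loader, use_cuda=True)
#     # latent: (n_cells, n_latent), batch_indices: (n_cells, 1), labels: (n_cells,)
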
def get_latent(vae, data_loader):
    latent = []
    batch_indices = []
    labels = []
    for tensors in data_loader:
        if vae.use_cuda:
            tensors = to_cuda(tensors)
        sample_batch, local_l_mean, local_l_var, batch_index, label = tensors
        sample_batch = sample_batch.type(torch.float32)
        latent += [vae.sample_from_posterior_z(sample_batch, y=label)]
        batch_indices += [batch_index]
        labels += [label]
    return torch.cat(latent), torch.cat(batch_indices), torch.cat(labels)

def compute_log_likelihood(vae, data_loader, use_cuda=True):
    # Iterate once over the data_loader and compute the total log-likelihood
    log_lkl = 0
    for i_batch, tensors in enumerate(data_loader):
        tensors = to_cuda(tensors, use_cuda=use_cuda)
        sample_batch, local_l_mean, local_l_var, batch_index, labels = tensors
        sample_batch = sample_batch.type(torch.float32)
        reconst_loss, kl_divergence = vae(sample_batch, local_l_mean, local_l_var,
                                          batch_index=batch_index, y=labels)
        log_lkl += torch.sum(reconst_loss).item()
    # Average over the cells actually seen by the data_loader (its sampler's indices
    # when a subset sampler is used, otherwise the full dataset)
    n_samples = (len(data_loader.dataset)
                 if not (hasattr(data_loader, 'sampler') and hasattr(data_loader.sampler, 'indices'))
                 else len(data_loader.sampler.indices))
    return log_lkl / n_samples

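# Note (sketch): the returned value is the mean of `reconst_loss` per cell, i.e. a
# reconstruction-based proxy for the negative log-likelihood (lower is better).
# Assuming a trained `vae` and a `data_loader`:
#
#     ll = compute_log_likelihood(vae, data_loader, use_cuda=True)
#     print("average reconstruction loss per cell: %.2f" % ll)
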
def adapt_encoder(vae, dataloader, n_path=10, n_epochs=50, record_freq=5):
    parameters = list(vae.z_encoder.parameters()) + list(vae.l_encoder.parameters())
    z_encoder_state = vae.z_encoder.state_dict()
    l_encoder_state = vae.l_encoder.state_dict()
    optimizer = torch.optim.Adam(parameters, eps=0.01)

    # Getting access to the stats during training
    stats = Stats(n_epochs=n_epochs, record_freq=record_freq, names=['test'], verbose=False)
    stats.callback(vae, dataloader)
    best_ll = stats.history["LL_test"][0]

    # Training the model
    for i in range(n_path):
        # Re-initialize to create new path
        vae.z_encoder.load_state_dict(z_encoder_state)
        vae.l_encoder.load_state_dict(l_encoder_state)

        for epoch in range(n_epochs):
            for i_batch, tensors in enumerate(dataloader):
                if vae.use_cuda:
                    tensors = to_cuda(tensors)
                sample_batch, local_l_mean, local_l_var, batch_index, labels = tensors
                sample_batch = sample_batch.type(torch.float32)
                reconst_loss, _ = vae(sample_batch, local_l_mean, local_l_var,
                                      batch_index=batch_index, y=labels)
                train_loss = torch.mean(reconst_loss)
                optimizer.zero_grad()
                train_loss.backward()
                optimizer.step()
            stats.callback(vae, dataloader)
            best_ll = min(min(stats.history["LL_test"]), best_ll)
    return best_ll

def compute_accuracy(vae, data_loader, classifier=None, use_cuda=True):
    all_y_pred = []
    all_labels = []
    for i_batch, tensors in enumerate(data_loader):
        tensors = to_cuda(tensors, use_cuda=use_cuda)
        sample_batch, _, _, _, labels = tensors
        sample_batch = sample_batch.type(torch.float32)
        all_labels += [labels.view(-1)]

        if hasattr(vae, 'classify'):
            y_pred = vae.classify(sample_batch).argmax(dim=-1)
        elif classifier is not None:
            # Then we use the specified classifier
            if vae is not None:
                sample_batch, _, _ = vae.z_encoder(sample_batch)
            y_pred = classifier(sample_batch).argmax(dim=-1)
        all_y_pred += [y_pred]

    accuracy = (torch.cat(all_y_pred) == torch.cat(all_labels)).type(torch.float32).mean().item()
    return accuracy

def compute_accuracy(vae, data_loader, classifier=None):
    all_y_pred = []
    all_labels = []
    for i_batch, tensors in enumerate(data_loader):
        if vae.use_cuda:
            tensors = to_cuda(tensors)
        sample_batch, _, _, _, labels = tensors
        sample_batch = sample_batch.type(torch.float32)
        all_labels += [labels.view(-1)]

        if classifier is not None:
            # Then we use the specified classifier
            mu_z, _, _ = vae.z_encoder(sample_batch)
            y_pred = classifier(mu_z).argmax(dim=-1)
        else:
            # Then the vae must implement a classify function
            y_pred = vae.classify(sample_batch).argmax(dim=-1)
        all_y_pred += [y_pred]

    accuracy = (torch.cat(all_y_pred) == torch.cat(all_labels)).type(torch.float32).mean().item()
    return accuracy

def get_statistics(vae, data_loader, M_sampling=100, M_permutation=100000, permutation=False):
    """
    Output the average DE statistic in a symmetric way (A against B);
    the group assignments are forgotten if permutation is True.

    :param vae: The vae model
    :param data_loader: a data loader for a particular dataset
    :param M_sampling: number of posterior samples per cell (200 in Romain's reference code)
    :param M_permutation: number of cell pairs sampled (10000 in Romain's reference code)
    :param permutation: if True, sample both members of each pair from A and B pooled
    :return: A 1-d vector of statistics of size n_genes
    """
    # Compute the posterior scale samples for the whole dataset
    px_scales = []
    all_labels = []
    for tensors in data_loader:
        if vae.use_cuda:
            tensors = to_cuda(tensors)
        sample_batch, _, _, batch_index, labels = tensors
        sample_batch = sample_batch.type(torch.float32)
        sample_batch = sample_batch.repeat(1, M_sampling).view(-1, sample_batch.size(1))
        batch_index = batch_index.repeat(1, M_sampling).view(-1, 1)
        labels = labels.repeat(1, M_sampling).view(-1, 1)
        px_scales += [vae.get_sample_scale(sample_batch, y=labels, batch_index=batch_index)]
        all_labels += [labels]

    cell_types = np.array(['astrocytes_ependymal', 'endothelial-mural', 'interneurons',
                           'microglia', 'oligodendrocytes', 'pyramidal CA1', 'pyramidal SS'],
                          dtype=str)
    # oligodendrocytes (#4) VS pyramidal CA1 (#5)
    couple_celltypes = (4, 5)  # the pair of cell types on which to study DE
    print("\nDifferential Expression A/B for cell types\nA: %s\nB: %s\n" %
          tuple(cell_types[couple_celltypes[i]] for i in [0, 1]))

    # Instead of fixing A, B = 200, 400 cells up front, compute on the whole dataset
    # and then select the cells of each type by label
    px_scale = torch.cat(px_scales)
    all_labels = torch.cat(all_labels)
    sample_rate_a = px_scale[all_labels.view(-1) == couple_celltypes[0]].view(-1, px_scale.size(1)).cpu().numpy()
    sample_rate_b = px_scale[all_labels.view(-1) == couple_celltypes[1]].view(-1, px_scale.size(1)).cpu().numpy()

    # aggregate dataset
    samples = np.vstack((sample_rate_a, sample_rate_b))

    # prepare the pairs for sampling
    list_1 = list(np.arange(sample_rate_a.shape[0]))
    list_2 = list(sample_rate_a.shape[0] + np.arange(sample_rate_b.shape[0]))
    if not permutation:
        # case 1: no permutation, sample from A and then from B
        u, v = np.random.choice(list_1, size=M_permutation), np.random.choice(list_2, size=M_permutation)
    else:
        # case 2: permutation, sample from A + B twice
        u, v = (np.random.choice(list_1 + list_2, size=M_permutation),
                np.random.choice(list_1 + list_2, size=M_permutation))

    # then construct the pairs and compute the statistic for each gene
    first_set = samples[u]
    second_set = samples[v]
    res = np.mean(first_set >= second_set, 0)
    res = np.log(res + 1e-8) - np.log(1 - res + 1e-8)

    genes_of_interest = ["Thy1", "Mbp"]
    gene_names = data_loader.dataset.gene_names
    result = [(gene_name, res[np.where(gene_names == gene_name.upper())[0]][0])
              for gene_name in genes_of_interest]
    print('\n'.join([gene_name + " : " + str(r) for (gene_name, r) in result]))
    return res

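# Illustration only (not part of the pipeline): the statistic above estimates, for each
# gene, p = P(scale_A >= scale_B) over randomly paired cells and reports the log-Bayes
# factor log(p) - log(1 - p). A minimal self-contained sketch on synthetic arrays:
#
#     import numpy as np
#     sample_rate_a = np.random.rand(200, 10)                # 200 cells of type A, 10 genes
#     sample_rate_b = np.random.rand(300, 10)                # 300 cells of type B, 10 genes
#     samples = np.vstack((sample_rate_a, sample_rate_b))
#     u = np.random.choice(np.arange(200), size=1000)        # indices into A
#     v = np.random.choice(200 + np.arange(300), size=1000)  # indices into B
#     p = np.mean(samples[u] >= samples[v], 0)                # P(A >= B), per gene
#     bayes_factor = np.log(p + 1e-8) - np.log(1 - p + 1e-8)
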