def train_model(model, dataloaders, criterion, optimizer, num_epochs=100):
    """Train an autoencoder-style model and print per-phase reconstruction loss.

    Runs ``num_epochs`` passes over the ``"train"`` and ``"val"`` loaders in
    ``dataloaders``, optimizing ``criterion(model(x)["recons"], x)`` in the
    train phase and tracking the best validation loss.

    NOTE(review): indexes each batch dict with the module-level globals
    ``data_key`` and ``label_key`` — confirm they are defined before calling.

    Args:
        model: module whose forward returns a dict containing a "recons" entry.
        dataloaders: dict with "train" and "val" DataLoaders yielding dict batches.
        criterion: reconstruction loss called as ``criterion(outputs, inputs)``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of epochs to run.

    Returns:
        The model with its last-epoch weights (the best-weight restore below
        is intentionally commented out).
    """
    since = time.time()
    # Snapshot kept only for the (currently disabled) best-weight restore below.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = np.inf  # fixed: np.infty was removed in NumPy 2.0; np.inf is canonical
    device = get_device()
    model.to(device)
    for epoch in range(num_epochs):
        print("Epoch {}/{}".format(epoch, num_epochs - 1))
        print("-" * 10)
        # Each epoch has a training and a validation phase.
        for phase in ["train", "val"]:
            if phase == "train":
                model.train()  # set model to training mode
            else:
                model.eval()  # set model to evaluate mode
            running_loss = 0.0
            # Iterate over data.
            for data in dataloaders[phase]:
                inputs = data[data_key].to(device)
                # NOTE(review): labels are fetched but never used here; kept so a
                # missing label_key still fails fast — confirm that is intended.
                labels = data[label_key].to(device)
                # Zero the parameter gradients.
                optimizer.zero_grad()
                # Forward pass; track gradient history only in the train phase.
                with torch.set_grad_enabled(phase == "train"):
                    outputs = model(inputs)["recons"]
                    loss = criterion(outputs, inputs)
                    # Backward + optimize only in the training phase.
                    if phase == "train":
                        loss.backward()
                        optimizer.step()
                # Statistics: scale the batch-mean loss back to a batch sum.
                running_loss += loss.item() * inputs.size(0)
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            print("{} Loss: {:.6f}".format(phase, epoch_loss))
            # Remember the best validation loss seen so far.
            if phase == "val" and epoch_loss < best_loss:
                best_loss = epoch_loss
                # best_model_wts = copy.deepcopy(model.state_dict())
    time_elapsed = time.time() - since
    print("Training complete in {:.0f}m {:.0f}s".format(
        time_elapsed // 60, time_elapsed % 60))
    # load best model weights
    # model.load_state_dict(best_model_wts)
    return model
def __init__(
    self,
    output_dir: str,
    latent_structure_model_config: dict = None,
    train_val_test_split: List = None,
    batch_size: int = 64,
    num_epochs: int = 500,
    early_stopping: int = 20,
    random_state: int = 42,
):
    """Store experiment configuration and fix all random seeds.

    Args:
        output_dir: directory where outputs are written.
        latent_structure_model_config: optional config dict for the latent
            structure model.
        train_val_test_split: train/val/test fractions; defaults to
            [0.7, 0.2, 0.1] when None.
        batch_size: mini-batch size.
        num_epochs: maximum number of training epochs.
        early_stopping: early-stopping patience in epochs.
        random_state: seed applied to torch, numpy and random.
    """
    # Fixed: avoid a shared mutable default argument for the split list.
    if train_val_test_split is None:
        train_val_test_split = [0.7, 0.2, 0.1]
    # I/O attributes
    self.output_dir = output_dir
    # Training attributes
    self.num_epochs = num_epochs
    self.early_stopping = early_stopping
    self.batch_size = batch_size
    self.train_val_test_split = train_val_test_split
    # Other attributes
    self.latent_structure_model_config = latent_structure_model_config
    self.random_state = random_state
    self.loss_dict = None
    self.device = get_device()
    # Fix random seeds for reproducibility and limit applicable algorithms to
    # those believed to be deterministic.
    torch.manual_seed(self.random_state)
    np.random.seed(self.random_state)
    seed(self.random_state)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def __init__(
    self,
    output_dir: str,
    n_folds: int = 4,
    latent_structure_model_config: dict = None,
    train_val_split: List = None,
    batch_size: int = 32,
    num_epochs: int = 500,
    early_stopping: int = 20,
    random_state: int = 42,
):
    """Store k-fold experiment configuration and fix all random seeds.

    Args:
        output_dir: directory where outputs are written.
        n_folds: number of cross-validation folds.
        latent_structure_model_config: optional config dict for the latent
            structure model.
        train_val_split: train/val fractions; defaults to [0.8, 0.2] when None.
        batch_size: mini-batch size.
        num_epochs: maximum number of training epochs.
        early_stopping: early-stopping patience in epochs.
        random_state: seed applied to torch, numpy and random.
    """
    # Fixed: avoid a shared mutable default argument for the split list.
    if train_val_split is None:
        train_val_split = [0.8, 0.2]
    self.domain_configs = None
    self.output_dir = output_dir
    self.latent_structure_model_config = latent_structure_model_config
    self.n_folds = n_folds
    self.train_val_split = train_val_split
    self.batch_size = batch_size
    self.num_epochs = num_epochs
    self.early_stopping = early_stopping
    self.random_state = random_state
    self.device = get_device()
    self.loss_dicts = None
    self.trained_models = None
    # Fix random seeds for reproducibility and limit applicable algorithms to
    # those believed to be deterministic.
    torch.manual_seed(self.random_state)
    np.random.seed(self.random_state)
    seed(self.random_state)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def analyze_guided_gradcam_for_genesets(
    image_to_geneset_translator: ImageToGeneSetTranslator,
    image_dataloader: DataLoader,
    image_data_key: str,
    target_layer: str,
    query_node: int,
):
    """Compute Grad-CAM and guided-backpropagation maps for one gene-set node.

    Args:
        image_to_geneset_translator: trained translator model (moved to the
            current device and set to eval mode).
        image_dataloader: loader yielding dict batches of images.
        image_data_key: key of the image tensor within each batch dict.
        target_layer: name of the encoder layer to attach Grad-CAM to.
        query_node: output node whose attribution is computed.

    Returns:
        dict with arrays "images", "gb_maps" and "grad_cams".
    """
    device = get_device()
    image_to_geneset_translator.eval().to(device)
    grad_cam = GradCam(
        model=image_to_geneset_translator,
        feature_module=image_to_geneset_translator.encoder,
        target_layer_names=[target_layer],
        device=device,
    )
    # Presumably deepcopied so the guided-backprop model can be modified
    # without affecting the original translator — TODO confirm.
    gb_model = GuidedBackpropReLUModel(
        model=copy.deepcopy(image_to_geneset_translator), device=device)
    guided_backpropagation_maps = []
    gradient_cams = []
    images = []
    # Fixed: dropped the unused enumerate index.
    for sample in image_dataloader:
        inputs = sample[image_data_key].to(device)
        gradient_cams.append(
            grad_cam(inputs, target_node=query_node).cpu().numpy())
        guided_backpropagation_maps.extend(
            list(gb_model(inputs, target_node=query_node)))
        images.extend(list(inputs.detach().cpu().numpy()))
    data_dict = {
        "images": np.array(images),
        "gb_maps": np.array(guided_backpropagation_maps),
        "grad_cams": np.array(gradient_cams),
    }
    return data_dict
def test_model(model, dataloaders, criterion):
    """Evaluate dataset-averaged reconstruction loss on the "test" loader.

    NOTE(review): indexes each batch dict with the module-level global
    ``data_key`` — confirm it is defined before calling.

    Args:
        model: module whose forward returns a dict containing a "recons" entry.
        dataloaders: dict containing a "test" DataLoader of dict batches.
        criterion: reconstruction loss called as ``criterion(outputs, inputs)``.

    Returns:
        The average test loss (also printed).
    """
    device = get_device()
    model.to(device)
    model.eval()
    running_loss = 0  # fixed typo: was "runninq_loss"
    # Fixed: evaluate under no_grad so no autograd graph is built; the
    # computed loss values are unchanged.
    with torch.no_grad():
        for data in dataloaders["test"]:
            inputs = data[data_key].to(device)
            outputs = model(inputs)["recons"]
            # Scale the batch-mean loss back to a batch sum.
            running_loss += criterion(outputs, inputs).item() * inputs.size(0)
    test_loss = running_loss / len(dataloaders["test"].dataset)
    print("Test loss: ", test_loss)
    return test_loss
def get_latent_model_configuration(model_dict: dict,
                                   optimizer_dict: dict,
                                   loss_dict: dict,
                                   device=None) -> dict:
    """Build the latent model, its optimizer and its loss from config dicts.

    NOTE(review): mutates its inputs — the "type" entries of ``model_dict``
    and ``loss_dict`` (and "weights", if present) are popped.

    Args:
        model_dict: must contain "type" plus the model class's kwargs.
        optimizer_dict: optimizer config forwarded to get_optimizer_for_model.
        loss_dict: must contain "type"; may contain "weights" for class losses.
        device: target device. Fixed: now defaults to None so ``get_device()``
            is used when omitted — previously the parameter was annotated
            ``device: None`` with no default, so callers had to pass it
            explicitly for the None branch to ever run.

    Returns:
        dict with keys "model", "optimizer" and "loss".

    Raises:
        NotImplementedError: for an unknown model or loss type.
    """
    if device is None:
        device = get_device()
    model_type = model_dict.pop("type")
    if model_type == "LatentDiscriminator":
        model = LatentDiscriminator(**model_dict)
    elif model_type == "LatentClassifier":
        model = LatentClassifier(**model_dict)
    elif model_type == "LatentRegressor":
        model = LatentRegressor(**model_dict)
    else:
        raise NotImplementedError('Unknown model type "{}"'.format(model_type))
    optimizer = get_optimizer_for_model(optimizer_dict=optimizer_dict,
                                        model=model)
    # Class weights only apply to the classification-style losses; a regressor
    # gets weights=None.
    if model_type != "LatentRegressor":
        try:
            weights = torch.FloatTensor(loss_dict.pop("weights")).to(device)
        except KeyError:
            # No explicit weights configured: fall back to uniform weights.
            weights = torch.ones(model_dict["n_classes"]).float().to(device)
    else:
        weights = None
    loss_type = loss_dict.pop("type")
    if loss_type == "ce":
        latent_loss = RobustCrossEntropyLoss(weight=weights)
    elif loss_type == "bce":
        latent_loss = RobustBCELoss(weight=weights)
    elif loss_type == "bce_ll":
        latent_loss = RobustBCEWithLogitsLoss(weight=weights)
    elif loss_type == "mse":
        latent_loss = RobustMSELoss()
    elif loss_type == "mae":
        latent_loss = RobustL1Loss()
    elif loss_type == "weighted_mse":
        latent_loss = RobustWeightedMSELoss()
    else:
        raise NotImplementedError('Unknown loss type "{}"'.format(loss_type))
    latent_model_config = {
        "model": model,
        "optimizer": optimizer,
        "loss": latent_loss
    }
    return latent_model_config
def analyze_geneset_perturbation_in_image(
    geneset_ae: GeneSetAE,
    image_ae: Module,
    seq_dataloader: DataLoader,
    seq_data_key: str,
    silencing_node: int,
):
    """Silence one gene-set node and collect normal vs. perturbed outputs.

    Rebuilds ``geneset_ae`` as a PerturbationGeneSetAE with the same weights,
    runs every batch through it with ``silencing_node`` silenced, and decodes
    both the original and perturbed latents through ``image_ae``.

    Args:
        geneset_ae: trained gene-set autoencoder to perturb.
        image_ae: image autoencoder whose decoder translates latents to images.
        seq_dataloader: loader yielding dict batches of sequencing data.
        seq_data_key: key of the sequence tensor within each batch dict.
        silencing_node: index of the gene-set node to silence.

    Returns:
        dict with arrays "seq_recons", "perturbed_seq_recons", "trans_images"
        and "perturbed_trans_images".
    """
    device = get_device()
    geneset_ae.to(device).eval()
    image_ae.to(device).eval()
    # Re-instantiate the AE as its perturbation-aware variant and copy weights.
    perturbation_geneset_ae = PerturbationGeneSetAE(
        input_dim=geneset_ae.input_dim,
        latent_dim=geneset_ae.latent_dim,
        hidden_dims=geneset_ae.hidden_dims,
        geneset_adjacencies=geneset_ae.geneset_adjacencies,
    )
    perturbation_geneset_ae.to(device)
    perturbation_geneset_ae.load_state_dict(geneset_ae.state_dict())
    perturbation_geneset_ae.eval()
    geneset_ae.cpu()  # original model no longer needed on the device
    translated_images = []
    perturbed_translated_images = []
    recon_sequences = []
    perturbed_recon_sequences = []
    # Fixed: dropped the unused enumerate index and the unused (and misspelled)
    # "geneset_activites" local.
    for sample in seq_dataloader:
        inputs = sample[seq_data_key].to(device)
        output_dict = perturbation_geneset_ae(inputs, silencing_node)
        recon_sequences.extend(
            list(output_dict["recons"].detach().cpu().numpy()))
        latents = output_dict["latents"]
        perturbed_latents = output_dict["perturbed_latents"]
        perturbed_recon_sequences.extend(
            list(output_dict["perturbed_recons"].detach().cpu().numpy()))
        translated_images.extend(
            list(image_ae.decode(latents).detach().cpu().numpy()))
        perturbed_translated_images.extend(
            list(image_ae.decode(perturbed_latents).detach().cpu().numpy()))
    data_dict = {
        "seq_recons": np.array(recon_sequences),
        "perturbed_seq_recons": np.array(perturbed_recon_sequences),
        "trans_images": np.array(translated_images),
        "perturbed_trans_images": np.array(perturbed_translated_images),
    }
    return data_dict
def __init__(
    self,
    output_dir: str,
    data_config: dict,
    model_config: dict,
    domain_name: str,
    latent_structure_model_config: dict = None,
    train_val_test_split: List[float] = None,
    batch_size: int = 64,
    num_epochs: int = 64,
    early_stopping: int = -1,
    random_state: int = 42,
):
    """Configure a single-domain experiment on top of the base experiment.

    Args:
        output_dir: directory where outputs are written.
        data_config: dataset configuration dict.
        model_config: model configuration dict.
        domain_name: name of the domain (e.g. "image" or "rna").
        latent_structure_model_config: optional latent-structure model config.
        train_val_test_split: train/val/test fractions; defaults to
            [0.7, 0.2, 0.1] when None.
        batch_size: mini-batch size.
        num_epochs: maximum number of training epochs.
        early_stopping: early-stopping patience; -1 presumably disables it —
            TODO confirm against the training loop.
        random_state: seed forwarded to the base class.
    """
    # Fixed: avoid a shared mutable default argument for the split list.
    if train_val_test_split is None:
        train_val_test_split = [0.7, 0.2, 0.1]
    super().__init__(
        output_dir=output_dir,
        train_val_test_split=train_val_test_split,
        batch_size=batch_size,
        num_epochs=num_epochs,
        early_stopping=early_stopping,
        random_state=random_state,
    )
    self.data_config = data_config
    self.model_config = model_config
    self.domain_name = domain_name
    self.latent_structure_model_config = latent_structure_model_config
    self.data_set = None
    self.data_transform_pipeline_dict = None
    self.data_loader_dict = None
    self.data_key = None
    self.label_key = None
    self.domain_config = None
    self.trained_models = None
    self.loss_dict = None
    self.device = get_device()
def __init__(
    self,
    input_channels: int = 1,
    latent_dim: int = 128,
    hidden_dims: List[int] = None,
    lrelu_slope: float = 0.2,
    batchnorm: bool = True,
) -> None:
    """Build a convolutional autoencoder (stride-2 conv encoder, linear latent
    mapper, transposed-conv decoder with a final sigmoid).

    Args:
        input_channels: number of channels of the input images.
        latent_dim: size of the latent vector.
        hidden_dims: channel widths of the encoder stages; defaults to
            [128, 256, 512, 1024, 1024] when None.
        lrelu_slope: stored slope value. Fixed annotation: was ``int`` for a
            0.2 default. NOTE(review): the layers below use plain ReLU, so
            this value is stored but never applied here — confirm intent.
        batchnorm: if True, apply BatchNorm1d after the latent mapper.
    """
    super().__init__()
    # Fixed: avoid a shared mutable default list.
    if hidden_dims is None:
        hidden_dims = [128, 256, 512, 1024, 1024]
    self.in_channels = input_channels
    self.latent_dim = latent_dim
    self.hidden_dims = hidden_dims
    self.lrelu_slope = lrelu_slope
    self.batchnorm = batchnorm
    self.updated = False
    self.n_latent_spaces = 1
    # Build encoder: first stage has no batchnorm, later stages do.
    encoder_modules = [
        nn.Sequential(
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.hidden_dims[0],
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.ReLU(),
        )
    ]
    for i in range(1, len(self.hidden_dims)):
        encoder_modules.append(
            nn.Sequential(
                nn.Conv2d(
                    in_channels=self.hidden_dims[i - 1],
                    out_channels=self.hidden_dims[i],
                    kernel_size=4,
                    stride=2,
                    padding=1,
                    bias=False,
                ),
                nn.BatchNorm2d(self.hidden_dims[i]),
                nn.ReLU(),
            ))
    self.encoder = nn.Sequential(*encoder_modules)
    # NOTE(review): the latent mapper below assumes the encoder output
    # flattens to hidden_dims[-1] * 2 * 2 (i.e. a 2x2 spatial map), not the
    # 4x4 the original comment claimed — confirm against the input size.
    self.device = get_device()
    if self.batchnorm:
        self.latent_mapper = nn.Sequential(
            nn.Linear(hidden_dims[-1] * 2 * 2, self.latent_dim),
            nn.BatchNorm1d(self.latent_dim),
        )
    else:
        self.latent_mapper = nn.Linear(hidden_dims[-1] * 2 * 2,
                                       self.latent_dim)
    self.inv_latent_mapper = nn.Sequential(
        nn.Linear(self.latent_dim, hidden_dims[-1] * 2 * 2),
        nn.ReLU(inplace=True))
    # Decoder: mirror the encoder with transposed convolutions.
    decoder_modules = []
    for i in range(len(hidden_dims) - 1):
        decoder_modules.append(
            nn.Sequential(
                nn.ConvTranspose2d(
                    in_channels=hidden_dims[-1 - i],
                    out_channels=hidden_dims[-2 - i],
                    kernel_size=4,
                    stride=2,
                    padding=1,
                    bias=False,
                ),
                nn.BatchNorm2d(hidden_dims[-2 - i]),
                nn.ReLU(),
            ))
    # Final stage maps back to the input channel count with a sigmoid,
    # i.e. outputs are in [0, 1].
    decoder_modules.append(
        nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=hidden_dims[0],
                out_channels=self.in_channels,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.Sigmoid(),
        ))
    self.decoder = nn.Sequential(*decoder_modules)
    # NOTE(review): this chunk begins mid-definition — the two lines below are
    # the tail of a latent-mapping method (presumably FullEncoder.forward)
    # whose `def` line lies outside the visible chunk; left untouched.
    x = x.view(1, -1)
    return self.latent_mapper(x)


# --- Module-level Grad-CAM demo script ---
# Wraps a trained autoencoder's encoder + latent mapper into a single encoder
# and runs Grad-CAM for one latent node on a single example image.
feature_extractor = ae_model.encoder
full_encoder = FullEncoder(feature_extractor, ae_model.latent_mapper)
# NOTE(review): hard-coded absolute path — only valid on the original
# author's machine.
image = io.imread(
    "/home/daniel/PycharmProjects/domain_translation/data/cd4/nuclear_crops_all_experiments/128px/labeled_scaled_max_intensity_resized_images/11J1_CD4T_488Coro1A_555RPL10A_D001_05_nucleus_472_97_2_13.tif"
)
image = np.array(image, dtype=np.float32)
input_image = torch.from_numpy(image.copy()).unsqueeze(0)
# transform_pipeline = Compose([ToPILImage(), ToTensor()])
# input_image = transform_pipeline(input_image)
device = get_device()
feature_extractor.to(device)
full_encoder.to(device)
# Second unsqueeze presumably adds the batch dimension -> (1, 1, H, W);
# confirm against FullEncoder's expected input shape.
input_image = input_image.to(device).unsqueeze(0)
query_node = 115  # latent dimension to explain
grad_cam = GradCam(
    model=full_encoder,
    feature_module=full_encoder.feature_extractor,
    target_layer_names=["4"],
    device=device,
)
grayscale_cam = grad_cam(input_image, query_node)
def perform_latent_walk_in_umap_space(domain_configs: List[DomainConfig],
                                      dataloader_type: str,
                                      random_state: int = 1234):
    """Embed image and RNA latents in a joint 2-D UMAP space and decode a
    10x10 grid of UMAP points back through both decoders.

    Expects exactly two domain configurations named "image" and "rna" (in
    either order). Encodes every sample from both loaders, fits a UMAP on the
    concatenated latents (image rows first), builds a 10x10 grid spanning the
    UMAP bounding box, inverse-transforms the grid to latent space, and
    decodes each grid point with both the image and gene-set decoders.

    Args:
        domain_configs: the two domain configurations ("image" and "rna").
        dataloader_type: which loader to use from each data_loader_dict
            (e.g. "train"/"val"/"test" — confirm available keys).
        random_state: seed for UMAP.

    Returns:
        dict with the grid points, decoded grid images/sequences/gene-set
        activities, the concatenated latents, labels, domain labels and
        cell ids.

    Raises:
        RuntimeError: if not exactly two configs, or names are not
            "image"/"rna".
    """
    if len(domain_configs) != 2:
        raise RuntimeError(
            "Expects two domain configurations (image and sequencing domain)")
    # Resolve which config is which, independent of argument order.
    if domain_configs[0].name == "image" and domain_configs[1].name == "rna":
        image_domain_config = domain_configs[0]
        rna_domain_config = domain_configs[1]
    elif domain_configs[0].name == "rna" and domain_configs[1].name == "image":
        image_domain_config = domain_configs[1]
        rna_domain_config = domain_configs[0]
    else:
        raise RuntimeError(
            "Expected domain configuration types are >image< and >rna<.")
    rna_data_loader = rna_domain_config.data_loader_dict[dataloader_type]
    image_data_loader = image_domain_config.data_loader_dict[dataloader_type]
    device = get_device()
    geneset_ae = rna_domain_config.domain_model_config.model.to(device).eval()
    image_ae = image_domain_config.domain_model_config.model.to(device).eval()
    all_rna_latents = []
    all_rna_labels = []
    all_image_latents = []
    all_image_labels = []
    grid_sequences = []
    grid_geneset_activities = []
    grid_images = []
    rna_cell_ids = []
    image_cell_ids = []
    # Encode every RNA sample to its latent vector.
    for i, sample in enumerate(rna_data_loader):
        rna_inputs = sample[rna_domain_config.data_key].to(device)
        rna_labels = sample[rna_domain_config.label_key]
        rna_cell_ids.extend(sample["id"])
        geneset_ae_output = geneset_ae(rna_inputs)
        latents = geneset_ae_output["latents"]
        all_rna_latents.extend(list(latents.clone().detach().cpu().numpy()))
        all_rna_labels.extend(list(rna_labels.clone().detach().cpu().numpy()))
    # Encode every image sample to its latent vector.
    for i, sample in enumerate(image_data_loader):
        image_inputs = sample[image_domain_config.data_key].to(device)
        image_labels = sample[image_domain_config.label_key].to(device)
        image_cell_ids.extend(sample["id"])
        image_ae_output = image_ae(image_inputs)
        latents = image_ae_output["latents"]
        all_image_latents.extend(list(latents.clone().detach().cpu().numpy()))
        all_image_labels.extend(
            list(image_labels.clone().detach().cpu().numpy()))
    # Image rows first, then RNA rows — the same order is used for labels,
    # domain labels and cell ids below, so the arrays stay aligned.
    all_latents = np.concatenate(
        (np.array(all_image_latents), np.array(all_rna_latents)), axis=0)
    all_labels = np.concatenate(
        (np.array(all_image_labels), np.array(all_rna_labels)), axis=0)
    all_domain_labels = np.concatenate(
        (
            np.repeat("image", len(all_image_labels)),
            np.repeat("rna", len(all_rna_labels)),
        ),
        axis=0,
    )
    all_cell_ids = np.concatenate((image_cell_ids, rna_cell_ids), axis=0)
    # Fit UMAP on the joint latents of both domains.
    mapper = UMAP(random_state=random_state)
    transformed = mapper.fit_transform(all_latents)
    min_umap_c1 = min(transformed[:, 0])
    max_umap_c1 = max(transformed[:, 0])
    min_umap_c2 = min(transformed[:, 1])
    max_umap_c2 = max(transformed[:, 1])
    # 10x10 grid of points bilinearly interpolated over the UMAP bounding box.
    test_pts = np.array([
        (np.array([min_umap_c1, max_umap_c2]) * (1 - x) +
         np.array([max_umap_c1, max_umap_c2]) * x) * (1 - y) +
        (np.array([min_umap_c1, min_umap_c2]) * (1 - x) +
         np.array([max_umap_c1, min_umap_c2]) * x) * y
        for y in np.linspace(0, 1, 10) for x in np.linspace(0, 1, 10)
    ])
    # Map grid points back into the original latent space.
    inv_transformed_points = mapper.inverse_transform(test_pts)
    test_pts_ds = torch.utils.data.TensorDataset(
        torch.from_numpy(inv_transformed_points))
    test_pts_loader = torch.utils.data.DataLoader(test_pts_ds,
                                                  batch_size=64,
                                                  shuffle=False)
    # Decode each grid latent with both decoders.
    for i, sample in enumerate(test_pts_loader):
        image_recons = image_ae.decode(sample[0].to(device))
        rna_recons, decoded_geneset_activities = geneset_ae.decode(
            sample[0].to(device))
        grid_images.extend(list(image_recons.clone().detach().cpu().numpy()))
        grid_sequences.extend(list(rna_recons.clone().detach().cpu().numpy()))
        grid_geneset_activities.extend(
            list(decoded_geneset_activities.clone().detach().cpu().numpy()))
    data_dict = {
        "grid_points": test_pts,
        "grid_images": grid_images,
        "grid_sequences": grid_sequences,
        "grid_geneset_activities": grid_geneset_activities,
        "all_latents": all_latents,
        "all_labels": all_labels,
        "all_domain_labels": all_domain_labels,
        "all_cell_ids": all_cell_ids,
    }
    return data_dict
def get_geneset_activities_and_translated_images_sequences(
        domain_configs: List[DomainConfig], dataloader_type: str):
    """Collect encodings, reconstructions and cross-domain translations for
    every sample of the image and RNA domains.

    Expects exactly two domain configurations named "image" and "rna" (in
    either order). For each RNA sample: encodes it, reconstructs it,
    re-encodes the reconstruction, and decodes its latent with the image
    decoder (plus re-encoding that translated image). For each image sample:
    encodes it and decodes its latent with the gene-set decoder (plus
    re-encoding the translated sequence).

    Args:
        domain_configs: the two domain configurations ("image" and "rna").
        dataloader_type: which loader to use from each data_loader_dict.

    Returns:
        dict of numpy-converted lists keyed by cell ids, labels, inputs,
        latents, gene-set activities, reconstructions and translations for
        both domains.

    Raises:
        RuntimeError: if not exactly two configs, or names are not
            "image"/"rna".
    """
    if len(domain_configs) != 2:
        raise RuntimeError(
            "Expects two domain configurations (image and sequencing domain)")
    # Resolve which config is which, independent of argument order.
    if domain_configs[0].name == "image" and domain_configs[1].name == "rna":
        image_domain_config = domain_configs[0]
        rna_domain_config = domain_configs[1]
    elif domain_configs[0].name == "rna" and domain_configs[1].name == "image":
        image_domain_config = domain_configs[1]
        rna_domain_config = domain_configs[0]
    else:
        raise RuntimeError(
            "Expected domain configuration types are >image< and >rna<.")
    rna_data_loader = rna_domain_config.data_loader_dict[dataloader_type]
    image_data_loader = image_domain_config.data_loader_dict[dataloader_type]
    device = get_device()
    # RNA-domain accumulators.
    rna_cell_ids = []
    all_rna_labels = []
    all_rna_inputs = []
    all_rna_latents = []
    all_geneset_activities = []
    all_reconstructed_geneset_activities = []
    all_reconstructed_rna_inputs = []
    all_translated_images = []
    all_image_latents = []
    all_reenconded_rna_latents = []
    # Image-domain accumulators.
    image_cell_ids = []
    all_image_labels = []
    all_image_inputs = []
    all_translated_rna_seq = []
    all_translated_rna_latents = []
    all_translated_geneset_activities = []
    all_translated_image_latents = []
    geneset_ae = rna_domain_config.domain_model_config.model.to(device).eval()
    image_ae = image_domain_config.domain_model_config.model.to(device).eval()
    # Pass 1: RNA samples — encode, reconstruct, re-encode the reconstruction,
    # and translate the latent to an image (then re-encode that image).
    for i, sample in enumerate(rna_data_loader):
        rna_inputs = sample[rna_domain_config.data_key].to(device)
        rna_labels = sample[rna_domain_config.label_key]
        cell_ids = sample["id"]
        geneset_ae_output = geneset_ae(rna_inputs)
        latents = geneset_ae_output["latents"]
        geneset_activities = geneset_ae_output["geneset_activities"]
        reconstructed_geneset_activities = geneset_ae_output[
            "decoded_geneset_activities"]
        reconstructed_rna_inputs = geneset_ae_output["recons"]
        # Cycle check: re-encode the reconstruction back to latent space.
        reencoded_rna_latents = geneset_ae(reconstructed_rna_inputs)["latents"]
        # Cross-domain translation: decode the RNA latent as an image, then
        # re-encode the translated image.
        translated_images = image_ae.decode(latents)
        translated_image_latents = image_ae(translated_images)["latents"]
        rna_cell_ids.extend(cell_ids)
        all_rna_labels.extend(list(rna_labels.clone().detach().cpu().numpy()))
        all_rna_inputs.extend(list(rna_inputs.clone().detach().cpu().numpy()))
        all_geneset_activities.extend(
            list(geneset_activities.clone().detach().cpu().numpy()))
        all_translated_images.extend(
            list(translated_images.clone().detach().cpu().numpy()))
        all_translated_image_latents.extend(
            list(translated_image_latents.clone().detach().cpu().numpy()))
        all_rna_latents.extend(list(latents.clone().detach().cpu().numpy()))
        all_reconstructed_geneset_activities.extend(
            list(reconstructed_geneset_activities.clone().detach().cpu().numpy(
            )))
        all_reconstructed_rna_inputs.extend(
            list(reconstructed_rna_inputs.clone().detach().cpu().numpy()))
        all_reenconded_rna_latents.extend(
            list(reencoded_rna_latents.clone().detach().cpu().numpy()))
    # Pass 2: image samples — encode, translate the latent to a sequence, and
    # re-encode the translated sequence.
    for i, sample in enumerate(image_data_loader):
        image_inputs = sample[image_domain_config.data_key].to(device)
        image_labels = sample[image_domain_config.label_key].to(device)
        cell_ids = sample["id"]
        image_ae_output = image_ae(image_inputs)
        latents = image_ae_output["latents"]
        translated_sequences, translated_geneset_activities = geneset_ae.decode(
            latents)
        geneset_ae_output = geneset_ae(translated_sequences)
        translated_rna_latents = geneset_ae_output["latents"]
        image_cell_ids.extend(cell_ids)
        all_image_labels.extend(
            list(image_labels.clone().detach().cpu().numpy()))
        all_image_inputs.extend(
            list(image_inputs.clone().detach().cpu().numpy()))
        all_translated_geneset_activities.extend(
            list(translated_geneset_activities.clone().detach().cpu().numpy()))
        all_translated_rna_seq.extend(
            list(translated_sequences.clone().detach().cpu().numpy()))
        all_image_latents.extend(list(latents.clone().detach().cpu().numpy()))
        all_translated_rna_latents.extend(
            list(translated_rna_latents.clone().detach().cpu().numpy()))
    data_dict = {
        "rna_cell_ids": rna_cell_ids,
        "rna_labels": all_rna_labels,
        "rna_inputs": all_rna_inputs,
        "rna_latents": all_rna_latents,
        "geneset_activities": all_geneset_activities,
        "reconstructed_geneset_activities":
        all_reconstructed_geneset_activities,
        "reconstructed_rna_inputs": all_reconstructed_rna_inputs,
        "reencoded_rna_latents": all_reenconded_rna_latents,
        "translated_images": all_translated_images,
        "translated_image_latents": all_translated_image_latents,
        "image_cell_ids": image_cell_ids,
        "image_labels": all_image_labels,
        "image_inputs": all_image_inputs,
        "image_latents": all_image_latents,
        "translated_sequences": all_translated_rna_seq,
        "translated_sequence_latents": all_translated_rna_latents,
        "translated_geneset_activities": all_translated_geneset_activities,
    }
    return data_dict