import copy
import time

import numpy as np
import torch


def train_model(model, dataloaders, criterion, optimizer, num_epochs=100):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = np.inf
    device = get_device()

    model.to(device)

    for epoch in range(num_epochs):
        print("Epoch {}/{}".format(epoch, num_epochs - 1))
        print("-" * 10)

        # Each epoch has a training and validation phase
        for phase in ["train", "val"]:
            if phase == "train":
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            # Iterate over data.
            for data in dataloaders[phase]:
                inputs = data[data_key].to(device)
                # labels are unused here: the autoencoder loss below
                # reconstructs the inputs themselves
                labels = data[label_key].to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == "train"):
                    # Get model outputs and calculate loss
                    outputs = model(inputs)["recons"]
                    loss = criterion(outputs, inputs)

                # backward + optimize only if in training phase
                if phase == "train":
                    loss.backward()
                    optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)

            print("{} Loss: {:.6f}".format(phase, epoch_loss))

            # deep copy the model if it achieves the best validation loss
            if phase == "val" and epoch_loss < best_loss:
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())

    time_elapsed = time.time() - since
    print("Training complete in {:.0f}m {:.0f}s".format(
        time_elapsed // 60, time_elapsed % 60))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
Example #2
    def __init__(
        self,
        output_dir: str,
        latent_structure_model_config: dict = None,
        train_val_test_split: List[float] = [0.7, 0.2, 0.1],
        batch_size: int = 64,
        num_epochs: int = 500,
        early_stopping: int = 20,
        random_state: int = 42,
    ):
        # I/O attributes
        self.output_dir = output_dir

        # Training attributes
        self.num_epochs = num_epochs
        self.early_stopping = early_stopping
        self.batch_size = batch_size
        self.train_val_test_split = train_val_test_split

        # Other attributes
        self.latent_structure_model_config = latent_structure_model_config
        self.random_state = random_state
        self.loss_dict = None
        self.device = get_device()

        # Fix random seeds for reproducibility and limit applicable algorithms to those believed to be deterministic
        torch.manual_seed(self.random_state)
        np.random.seed(self.random_state)
        seed(self.random_state)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
Example #3
    def __init__(
        self,
        output_dir: str,
        n_folds: int = 4,
        latent_structure_model_config: dict = None,
        train_val_split: List[float] = [0.8, 0.2],
        batch_size: int = 32,
        num_epochs: int = 500,
        early_stopping: int = 20,
        random_state: int = 42,
    ):
        self.domain_configs = None
        self.output_dir = output_dir
        self.latent_structure_model_config = latent_structure_model_config
        self.n_folds = n_folds
        self.train_val_split = train_val_split
        self.batch_size = batch_size
        self.num_epochs = num_epochs
        self.early_stopping = early_stopping
        self.random_state = random_state
        self.device = get_device()

        self.loss_dicts = None
        self.trained_models = None

        # Fix random seeds for reproducibility and limit applicable algorithms to those believed to be deterministic
        torch.manual_seed(self.random_state)
        np.random.seed(self.random_state)
        seed(self.random_state)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
Example #4
def analyze_guided_gradcam_for_genesets(
    image_to_geneset_translator: ImageToGeneSetTranslator,
    image_dataloader: DataLoader,
    image_data_key: str,
    target_layer: str,
    query_node: int,
):
    device = get_device()
    image_to_geneset_translator.eval().to(device)

    grad_cam = GradCam(
        model=image_to_geneset_translator,
        feature_module=image_to_geneset_translator.encoder,
        target_layer_names=[target_layer],
        device=device,
    )
    gb_model = GuidedBackpropReLUModel(
        model=copy.deepcopy(image_to_geneset_translator), device=device)
    guided_backpropagation_maps = []
    gradient_cams = []
    images = []

    for (i, sample) in enumerate(image_dataloader):
        inputs = sample[image_data_key].to(device)
        gradient_cams.append(
            grad_cam(inputs, target_node=query_node).cpu().numpy())
        guided_backpropagation_maps.extend(
            list(gb_model(inputs, target_node=query_node)))
        images.extend(list(inputs.detach().cpu().numpy()))
    data_dict = {
        "images": np.array(images),
        "gb_maps": np.array(guided_backpropagation_maps),
        "grad_cams": np.array(gradient_cams),
    }
    return data_dict
Example #5
def test_model(model, dataloaders, criterion):
    device = get_device()
    model.to(device)
    model.eval()
    running_loss = 0.0
    # Disable gradient tracking for inference
    with torch.no_grad():
        for data in dataloaders["test"]:
            inputs = data[data_key].to(device)
            outputs = model(inputs)["recons"]
            running_loss += criterion(outputs, inputs).item() * inputs.size(0)
    test_loss = running_loss / len(dataloaders["test"].dataset)
    print("Test loss: ", test_loss)
    return test_loss
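Under the same assumptions, evaluation only needs a "test" entry in the loader dict:

# Hypothetical call, reusing the sketch objects defined after train_model.
dataloaders["test"] = DataLoader(
    DictDataset(torch.randn(16, 16), torch.zeros(16)), batch_size=8)
test_loss = test_model(model, dataloaders, nn.MSELoss())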
Example #6
def get_latent_model_configuration(model_dict: dict, optimizer_dict: dict,
                                   loss_dict: dict,
                                   device: torch.device = None) -> dict:

    if device is None:
        device = get_device()

    model_type = model_dict.pop("type")
    if model_type == "LatentDiscriminator":
        model = LatentDiscriminator(**model_dict)
    elif model_type == "LatentClassifier":
        model = LatentClassifier(**model_dict)
    elif model_type == "LatentRegressor":
        model = LatentRegressor(**model_dict)
    else:
        raise NotImplementedError('Unknown model type "{}"'.format(model_type))

    optimizer = get_optimizer_for_model(optimizer_dict=optimizer_dict,
                                        model=model)

    if model_type != "LatentRegressor":
        try:
            weights = torch.FloatTensor(loss_dict.pop("weights")).to(device)
        except KeyError:
            weights = torch.ones(model_dict["n_classes"]).float().to(device)
    else:
        weights = None

    loss_type = loss_dict.pop("type")
    if loss_type == "ce":
        latent_loss = RobustCrossEntropyLoss(weight=weights)
    elif loss_type == "bce":
        latent_loss = RobustBCELoss(weight=weights)
    elif loss_type == "bce_ll":
        latent_loss = RobustBCEWithLogitsLoss(weight=weights)
    elif loss_type == "mse":
        latent_loss = RobustMSELoss()
    elif loss_type == "mae":
        latent_loss = RobustL1Loss()
    elif loss_type == "weighted_mse":
        latent_loss = RobustWeightedMSELoss()
    else:
        raise NotImplementedError('Unknown loss type "{}"'.format(loss_type))

    latent_model_config = {
        "model": model,
        "optimizer": optimizer,
        "loss": latent_loss
    }
    return latent_model_config
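A hypothetical call; beyond "type", "weights", and "n_classes" (which the function itself reads), the inner keys depend on the project's LatentClassifier and get_optimizer_for_model signatures and are illustrative only. Note that the dicts are consumed by pop(), so they cannot be reused across calls.

# Illustrative configs; all keys except "type"/"weights"/"n_classes" are
# assumptions for the sketch.
latent_model_config = get_latent_model_configuration(
    model_dict={"type": "LatentClassifier", "n_classes": 2},
    optimizer_dict={"type": "adam", "lr": 1e-4},
    loss_dict={"type": "ce", "weights": [1.0, 2.0]},
    device=None,
)
model = latent_model_config["model"]
optimizer = latent_model_config["optimizer"]
loss = latent_model_config["loss"]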
Example #7
def analyze_geneset_perturbation_in_image(
    geneset_ae: GeneSetAE,
    image_ae: Module,
    seq_dataloader: DataLoader,
    seq_data_key: str,
    silencing_node: int,
):
    device = get_device()
    geneset_ae.to(device).eval()
    image_ae.to(device).eval()
    perturbation_geneset_ae = PerturbationGeneSetAE(
        input_dim=geneset_ae.input_dim,
        latent_dim=geneset_ae.latent_dim,
        hidden_dims=geneset_ae.hidden_dims,
        geneset_adjacencies=geneset_ae.geneset_adjacencies,
    )
    perturbation_geneset_ae.to(device)
    perturbation_geneset_ae.load_state_dict(geneset_ae.state_dict())
    perturbation_geneset_ae.eval()
    geneset_ae.cpu()
    translated_images = []
    perturbed_translated_images = []
    recon_sequences = []
    perturbed_recon_sequences = []

    for (i, sample) in enumerate(seq_dataloader):
        inputs = sample[seq_data_key].to(device)
        output_dict = perturbation_geneset_ae(inputs, silencing_node)
        recon_sequences.extend(
            list(output_dict["recons"].detach().cpu().numpy()))
        latents = output_dict["latents"]
        geneset_activities = output_dict["geneset_activities"]
        perturbed_latents = output_dict["perturbed_latents"]
        perturbed_recon_sequences.extend(
            list(output_dict["perturbed_recons"].detach().cpu().numpy()))
        translated_images.extend(
            list(image_ae.decode(latents).detach().cpu().numpy()))
        perturbed_translated_images.extend(
            list(image_ae.decode(perturbed_latents).detach().cpu().numpy()))

    data_dict = {
        "seq_recons": np.array(recon_sequences),
        "perturbed_seq_recons": np.array(perturbed_recon_sequences),
        "trans_images": np.array(translated_images),
        "perturbed_trans_images": np.array(perturbed_translated_images),
    }
    return data_dict
Example #8
    def __init__(
        self,
        output_dir: str,
        data_config: dict,
        model_config: dict,
        domain_name: str,
        latent_structure_model_config: dict = None,
        train_val_test_split: List[float] = [0.7, 0.2, 0.1],
        batch_size: int = 64,
        num_epochs: int = 64,
        early_stopping: int = -1,
        random_state: int = 42,
    ):
        super().__init__(
            output_dir=output_dir,
            train_val_test_split=train_val_test_split,
            batch_size=batch_size,
            num_epochs=num_epochs,
            early_stopping=early_stopping,
            random_state=random_state,
        )

        self.data_config = data_config
        self.model_config = model_config
        self.domain_name = domain_name
        self.latent_structure_model_config = latent_structure_model_config

        self.data_set = None
        self.data_transform_pipeline_dict = None
        self.data_loader_dict = None
        self.data_key = None
        self.label_key = None
        self.domain_config = None

        self.trained_models = None
        self.loss_dict = None

        self.device = get_device()
Example #9
    def __init__(
        self,
        input_channels: int = 1,
        latent_dim: int = 128,
        hidden_dims: List[int] = [128, 256, 512, 1024, 1024],
        lrelu_slope: float = 0.2,
        batchnorm: bool = True,
    ) -> None:
        super().__init__()
        self.in_channels = input_channels
        self.latent_dim = latent_dim
        self.hidden_dims = hidden_dims
        self.lrelu_slope = lrelu_slope
        self.batchnorm = batchnorm
        self.updated = False
        self.n_latent_spaces = 1

        # Build encoder
        encoder_modules = [
            nn.Sequential(
                nn.Conv2d(
                    in_channels=self.in_channels,
                    out_channels=self.hidden_dims[0],
                    kernel_size=4,
                    stride=2,
                    padding=1,
                    bias=False,
                ),
                nn.ReLU(),
            )
        ]

        for i in range(1, len(self.hidden_dims)):
            encoder_modules.append(
                nn.Sequential(
                    nn.Conv2d(
                        in_channels=self.hidden_dims[i - 1],
                        out_channels=self.hidden_dims[i],
                        kernel_size=4,
                        stride=2,
                        padding=1,
                        bias=False,
                    ),
                    nn.BatchNorm2d(self.hidden_dims[i]),
                    nn.ReLU(),
                ))
        self.encoder = nn.Sequential(*encoder_modules)

        # After five stride-2 convolutions the encoder output has shape
        # (hidden_dims[-1], 2, 2); it is flattened for the latent mapper below
        self.device = get_device()

        if self.batchnorm:
            self.latent_mapper = nn.Sequential(
                nn.Linear(hidden_dims[-1] * 2 * 2, self.latent_dim),
                nn.BatchNorm1d(self.latent_dim),
            )
        else:
            self.latent_mapper = nn.Linear(hidden_dims[-1] * 2 * 2,
                                           self.latent_dim)

        self.inv_latent_mapper = nn.Sequential(
            nn.Linear(self.latent_dim, hidden_dims[-1] * 2 * 2),
            nn.ReLU(inplace=True))

        # decoder
        decoder_modules = []
        for i in range(len(hidden_dims) - 1):
            decoder_modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        in_channels=hidden_dims[-1 - i],
                        out_channels=hidden_dims[-2 - i],
                        kernel_size=4,
                        stride=2,
                        padding=1,
                        bias=False,
                    ),
                    nn.BatchNorm2d(hidden_dims[-2 - i]),
                    nn.ReLU(),
                ))
        decoder_modules.append(
            nn.Sequential(
                nn.ConvTranspose2d(
                    in_channels=hidden_dims[0],
                    out_channels=self.in_channels,
                    kernel_size=4,
                    stride=2,
                    padding=1,
                    bias=False,
                ),
                nn.Sigmoid(),
            ))

        self.decoder = nn.Sequential(*decoder_modules)
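Only the constructor of this convolutional autoencoder is shown. A plausible forward pass, matching the 2x2 flattening above and the "recons"/"latents" output keys used throughout these examples (an assumption, since the original method is not shown):

    def forward(self, x: torch.Tensor) -> dict:
        # Assumed forward pass: encode, flatten into the latent space, then
        # mirror the path back through the decoder.
        h = self.encoder(x)
        latents = self.latent_mapper(h.view(h.size(0), -1))
        h = self.inv_latent_mapper(latents)
        recons = self.decoder(h.view(h.size(0), self.hidden_dims[-1], 2, 2))
        return {"recons": recons, "latents": latents}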
Example #10
    def forward(self, x):
        x = self.feature_extractor(x)
        x = x.view(1, -1)
        return self.latent_mapper(x)


feature_extractor = ae_model.encoder
full_encoder = FullEncoder(feature_extractor, ae_model.latent_mapper)

image = io.imread(
    "/home/daniel/PycharmProjects/domain_translation/data/cd4/nuclear_crops_all_experiments/128px/labeled_scaled_max_intensity_resized_images/11J1_CD4T_488Coro1A_555RPL10A_D001_05_nucleus_472_97_2_13.tif"
)
image = np.array(image, dtype=np.float32)
input_image = torch.from_numpy(image.copy()).unsqueeze(0)
# transform_pipeline = Compose([ToPILImage(), ToTensor()])
# input_image = transform_pipeline(input_image)

device = get_device()

feature_extractor.to(device)
full_encoder.to(device)
input_image = input_image.to(device).unsqueeze(0)

query_node = 115

grad_cam = GradCam(
    model=full_encoder,
    feature_module=full_encoder.feature_extractor,
    target_layer_names=["4"],
    device=device,
)

grayscale_cam = grad_cam(input_image, query_node)
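To inspect the result, one might overlay the class activation map on the input image; a sketch assuming grayscale_cam is an HxW array scaled to [0, 1]:

import matplotlib.pyplot as plt

# Overlay the (assumed HxW, [0, 1]-scaled) CAM on the grayscale input.
plt.imshow(image, cmap="gray")
plt.imshow(grayscale_cam, cmap="jet", alpha=0.5)
plt.axis("off")
plt.savefig("grad_cam_query_node_115.png", bbox_inches="tight")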
Example #11
def perform_latent_walk_in_umap_space(domain_configs: List[DomainConfig],
                                      dataloader_type: str,
                                      random_state: int = 1234):
    if len(domain_configs) != 2:
        raise RuntimeError(
            "Expects two domain configurations (image and sequencing domain)")
    if domain_configs[0].name == "image" and domain_configs[1].name == "rna":
        image_domain_config = domain_configs[0]
        rna_domain_config = domain_configs[1]
    elif domain_configs[0].name == "rna" and domain_configs[1].name == "image":
        image_domain_config = domain_configs[1]
        rna_domain_config = domain_configs[0]
    else:
        raise RuntimeError(
            "Expected domain configuration types are >image< and >rna<.")

    rna_data_loader = rna_domain_config.data_loader_dict[dataloader_type]
    image_data_loader = image_domain_config.data_loader_dict[dataloader_type]
    device = get_device()

    geneset_ae = rna_domain_config.domain_model_config.model.to(device).eval()
    image_ae = image_domain_config.domain_model_config.model.to(device).eval()

    all_rna_latents = []
    all_rna_labels = []
    all_image_latents = []
    all_image_labels = []
    grid_sequences = []
    grid_geneset_activities = []
    grid_images = []
    rna_cell_ids = []
    image_cell_ids = []

    for i, sample in enumerate(rna_data_loader):
        rna_inputs = sample[rna_domain_config.data_key].to(device)
        rna_labels = sample[rna_domain_config.label_key]
        rna_cell_ids.extend(sample["id"])

        geneset_ae_output = geneset_ae(rna_inputs)
        latents = geneset_ae_output["latents"]
        all_rna_latents.extend(list(latents.clone().detach().cpu().numpy()))
        all_rna_labels.extend(list(rna_labels.clone().detach().cpu().numpy()))

    for i, sample in enumerate(image_data_loader):
        image_inputs = sample[image_domain_config.data_key].to(device)
        image_labels = sample[image_domain_config.label_key].to(device)
        image_cell_ids.extend(sample["id"])

        image_ae_output = image_ae(image_inputs)
        latents = image_ae_output["latents"]
        all_image_latents.extend(list(latents.clone().detach().cpu().numpy()))
        all_image_labels.extend(
            list(image_labels.clone().detach().cpu().numpy()))

    all_latents = np.concatenate(
        (np.array(all_image_latents), np.array(all_rna_latents)), axis=0)
    all_labels = np.concatenate(
        (np.array(all_image_labels), np.array(all_rna_labels)), axis=0)
    all_domain_labels = np.concatenate(
        (
            np.repeat("image", len(all_image_labels)),
            np.repeat("rna", len(all_rna_labels)),
        ),
        axis=0,
    )
    all_cell_ids = np.concatenate((image_cell_ids, rna_cell_ids), axis=0)

    mapper = UMAP(random_state=random_state)
    transformed = mapper.fit_transform(all_latents)
    min_umap_c1 = min(transformed[:, 0])
    max_umap_c1 = max(transformed[:, 0])
    min_umap_c2 = min(transformed[:, 1])
    max_umap_c2 = max(transformed[:, 1])

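    # Lay a 10x10 grid over the bounding box of the UMAP embedding by
    # bilinearly interpolating its four corner points.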
    test_pts = np.array([
        (np.array([min_umap_c1, max_umap_c2]) *
         (1 - x) + np.array([max_umap_c1, max_umap_c2]) * x) * (1 - y) +
        (np.array([min_umap_c1, min_umap_c2]) *
         (1 - x) + np.array([max_umap_c1, min_umap_c2]) * x) * y
        for y in np.linspace(0, 1, 10) for x in np.linspace(0, 1, 10)
    ])

    inv_transformed_points = mapper.inverse_transform(test_pts)
    test_pts_ds = torch.utils.data.TensorDataset(
        # cast the float64 UMAP output to float32 before the decoders
        torch.from_numpy(inv_transformed_points).float())
    test_pts_loader = torch.utils.data.DataLoader(test_pts_ds,
                                                  batch_size=64,
                                                  shuffle=False)

    for i, sample in enumerate(test_pts_loader):
        image_recons = image_ae.decode(sample[0].to(device))
        rna_recons, decoded_geneset_activities = geneset_ae.decode(
            sample[0].to(device))

        grid_images.extend(list(image_recons.clone().detach().cpu().numpy()))
        grid_sequences.extend(list(rna_recons.clone().detach().cpu().numpy()))
        grid_geneset_activities.extend(
            list(decoded_geneset_activities.clone().detach().cpu().numpy()))

    data_dict = {
        "grid_points": test_pts,
        "grid_images": grid_images,
        "grid_sequences": grid_sequences,
        "grid_geneset_activities": grid_geneset_activities,
        "all_latents": all_latents,
        "all_labels": all_labels,
        "all_domain_labels": all_domain_labels,
        "all_cell_ids": all_cell_ids,
    }

    return data_dict
Example #12
def get_geneset_activities_and_translated_images_sequences(
        domain_configs: List[DomainConfig], dataloader_type: str):
    if len(domain_configs) != 2:
        raise RuntimeError(
            "Expects two domain configurations (image and sequencing domain)")
    if domain_configs[0].name == "image" and domain_configs[1].name == "rna":
        image_domain_config = domain_configs[0]
        rna_domain_config = domain_configs[1]
    elif domain_configs[0].name == "rna" and domain_configs[1].name == "image":
        image_domain_config = domain_configs[1]
        rna_domain_config = domain_configs[0]
    else:
        raise RuntimeError(
            "Expected domain configuration types are >image< and >rna<.")

    rna_data_loader = rna_domain_config.data_loader_dict[dataloader_type]
    image_data_loader = image_domain_config.data_loader_dict[dataloader_type]
    device = get_device()

    rna_cell_ids = []
    all_rna_labels = []
    all_rna_inputs = []
    all_rna_latents = []
    all_geneset_activities = []
    all_reconstructed_geneset_activities = []
    all_reconstructed_rna_inputs = []
    all_translated_images = []
    all_image_latents = []
    all_reencoded_rna_latents = []

    image_cell_ids = []
    all_image_labels = []
    all_image_inputs = []
    all_translated_rna_seq = []
    all_translated_rna_latents = []
    all_translated_geneset_activities = []
    all_translated_image_latents = []

    geneset_ae = rna_domain_config.domain_model_config.model.to(device).eval()
    image_ae = image_domain_config.domain_model_config.model.to(device).eval()

    for i, sample in enumerate(rna_data_loader):
        rna_inputs = sample[rna_domain_config.data_key].to(device)
        rna_labels = sample[rna_domain_config.label_key]
        cell_ids = sample["id"]

        geneset_ae_output = geneset_ae(rna_inputs)
        latents = geneset_ae_output["latents"]
        geneset_activities = geneset_ae_output["geneset_activities"]
        reconstructed_geneset_activities = geneset_ae_output[
            "decoded_geneset_activities"]
        reconstructed_rna_inputs = geneset_ae_output["recons"]
        reencoded_rna_latents = geneset_ae(reconstructed_rna_inputs)["latents"]
        translated_images = image_ae.decode(latents)
        translated_image_latents = image_ae(translated_images)["latents"]

        rna_cell_ids.extend(cell_ids)
        all_rna_labels.extend(list(rna_labels.clone().detach().cpu().numpy()))
        all_rna_inputs.extend(list(rna_inputs.clone().detach().cpu().numpy()))
        all_geneset_activities.extend(
            list(geneset_activities.clone().detach().cpu().numpy()))
        all_translated_images.extend(
            list(translated_images.clone().detach().cpu().numpy()))
        all_translated_image_latents.extend(
            list(translated_image_latents.clone().detach().cpu().numpy()))

        all_rna_latents.extend(list(latents.clone().detach().cpu().numpy()))

        all_reconstructed_geneset_activities.extend(
            list(reconstructed_geneset_activities
                 .clone().detach().cpu().numpy()))
        all_reconstructed_rna_inputs.extend(
            list(reconstructed_rna_inputs.clone().detach().cpu().numpy()))
        all_reencoded_rna_latents.extend(
            list(reencoded_rna_latents.clone().detach().cpu().numpy()))

    for i, sample in enumerate(image_data_loader):
        image_inputs = sample[image_domain_config.data_key].to(device)
        image_labels = sample[image_domain_config.label_key].to(device)
        cell_ids = sample["id"]

        image_ae_output = image_ae(image_inputs)
        latents = image_ae_output["latents"]
        translated_sequences, translated_geneset_activities = geneset_ae.decode(
            latents)
        geneset_ae_output = geneset_ae(translated_sequences)
        translated_rna_latents = geneset_ae_output["latents"]

        image_cell_ids.extend(cell_ids)
        all_image_labels.extend(
            list(image_labels.clone().detach().cpu().numpy()))
        all_image_inputs.extend(
            list(image_inputs.clone().detach().cpu().numpy()))
        all_translated_geneset_activities.extend(
            list(translated_geneset_activities.clone().detach().cpu().numpy()))
        all_translated_rna_seq.extend(
            list(translated_sequences.clone().detach().cpu().numpy()))
        all_image_latents.extend(list(latents.clone().detach().cpu().numpy()))

        all_translated_rna_latents.extend(
            list(translated_rna_latents.clone().detach().cpu().numpy()))

    data_dict = {
        "rna_cell_ids": rna_cell_ids,
        "rna_labels": all_rna_labels,
        "rna_inputs": all_rna_inputs,
        "rna_latents": all_rna_latents,
        "geneset_activities": all_geneset_activities,
        "reconstructed_geneset_activities":
        all_reconstructed_geneset_activities,
        "reconstructed_rna_inputs": all_reconstructed_rna_inputs,
        "reencoded_rna_latents": all_reenconded_rna_latents,
        "translated_images": all_translated_images,
        "translated_image_latents": all_translated_image_latents,
        "image_cell_ids": image_cell_ids,
        "image_labels": all_image_labels,
        "image_inputs": all_image_inputs,
        "image_latents": all_image_latents,
        "translated_sequences": all_translated_rna_seq,
        "translated_sequence_latents": all_translated_rna_latents,
        "translated_geneset_activities": all_translated_geneset_activities,
    }
    return data_dict