Example #1
0
def load_inputs(image_path,
                metadata_path,
                use_heatmaps,
                benign_heatmap_path=None,
                malignant_heatmap_path=None):
    """
    Load one input example, optionally together with its heatmaps.

    Both heatmap paths must be given when ``use_heatmaps`` is truthy, and both
    must be ``None`` otherwise; a mismatch raises ``AssertionError`` before any
    file is read.

    Returns:
        ModelInput(image=..., heatmaps=..., metadata=...), where ``heatmaps``
        is ``None`` when ``use_heatmaps`` is falsy.
    """
    paths_expected = bool(use_heatmaps)
    assert (benign_heatmap_path is not None) == paths_expected
    assert (malignant_heatmap_path is not None) == paths_expected

    # NOTE(review): unpickle_from_file presumably uses pickle, which can
    # execute arbitrary code -- only point this at trusted metadata files.
    metadata = pickling.unpickle_from_file(metadata_path)
    # The image and the heatmaps are both loaded with the view / flip
    # recorded in the metadata file.
    view = metadata["full_view"]
    flip = metadata["horizontal_flip"]

    image = loading.load_image(
        image_path=image_path,
        view=view,
        horizontal_flip=flip,
    )

    heatmaps = None
    if use_heatmaps:
        heatmaps = loading.load_heatmaps(
            benign_heatmap_path=benign_heatmap_path,
            malignant_heatmap_path=malignant_heatmap_path,
            view=view,
            horizontal_flip=flip,
        )
    return ModelInput(image=image, heatmaps=heatmaps, metadata=metadata)
def run(parameters):
    """
    Run the model on augmented copies of a single input and print the averaged
    benign / malignant predictions as a JSON object.

    ``parameters`` is a dict. The keys read directly here are ``seed``,
    ``cropped_mammogram_path``, ``metadata_path``, ``use_heatmaps``,
    ``heatmap_path_benign``, ``heatmap_path_malignant``, ``view``,
    ``num_epochs`` and ``batch_size``. ``load_model`` and
    ``process_augment_inputs`` may read more.

    Raises:
        AssertionError: if the view stored in the metadata differs from
            ``parameters["view"]``.
    """
    random_number_generator = np.random.RandomState(parameters["seed"])
    model, device = load_model(parameters)

    model_input = load_inputs(
        image_path=parameters["cropped_mammogram_path"],
        metadata_path=parameters["metadata_path"],
        use_heatmaps=parameters["use_heatmaps"],
        benign_heatmap_path=parameters["heatmap_path_benign"],
        malignant_heatmap_path=parameters["heatmap_path_malignant"],
    )
    assert model_input.metadata["full_view"] == parameters["view"]

    # Forward hook that records the output of ``all_views_avg_pool`` into
    # ``activation['out_resnet']`` via tools.get_activation. The captured
    # activations are not consumed further in this function.
    activation = {'out_resnet': []}
    handle = model.all_views_avg_pool.register_forward_hook(
        tools.get_activation(activation, 'out_resnet'))

    all_predictions = []
    try:
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            # num_epochs augmented copies of the same input, batch_size at a time.
            for data_batch in tools.partition_batch(
                    range(parameters["num_epochs"]), parameters["batch_size"]):
                batch = [
                    process_augment_inputs(
                        model_input=model_input,
                        random_number_generator=random_number_generator,
                        parameters=parameters,
                    ) for _ in data_batch
                ]
                tensor_batch = batch_to_tensor(batch, device)
                y_hat = model(tensor_batch)
                # y_hat is indexed as (batch, outputs, 2); exp() is applied,
                # presumably because the model emits log-probabilities. Keep
                # the first two outputs and column 1 of the last axis.
                predictions = np.exp(y_hat.cpu().detach().numpy())[:, :2, 1]
                all_predictions.append(predictions)
    finally:
        # Don't leave the hook attached to the model after this call.
        handle.remove()

    agg_predictions = np.concatenate(all_predictions, axis=0).mean(0)
    predictions_dict = {
        "benign": float(agg_predictions[0]),
        "malignant": float(agg_predictions[1]),
    }
    print(json.dumps(predictions_dict))
Example #3
0
def run_model(model, parameters, data_path):
    """
    Run the model over every file listed in ``data_path`` and save one
    visualization PNG per image under
    ``parameters["output_path"]/visualization/<name>.png``.

    Images are read from ``parameters["image_path"]`` using the file names
    found in ``data_path``.
    """
    exam_list = os.listdir(data_path)

    if (parameters["device_type"] == "gpu") and torch.has_cudnn:
        device = torch.device("cuda:{}".format(parameters["gpu_number"]))
    else:
        device = torch.device("cpu")
    model = model.to(device)
    model.eval()

    # Create the directory that the save paths below point into, once.
    # Previously an unrelated hard-coded directory was created on every
    # iteration, and this one was never created.
    visualization_dir = os.path.join(parameters["output_path"],
                                     "visualization")
    os.makedirs(visualization_dir, exist_ok=True)

    with torch.no_grad():
        for image in tqdm.tqdm(exam_list):
            # The image is already flipped, so no need to do it again.
            loaded_image = loading.load_image(
                image_path=os.path.join(parameters["image_path"], image),
                horizontal_flip=False,
            )

            # Convert the 2D array into a 4D tensor in N,C,H,W format.
            loaded_image = np.expand_dims(np.expand_dims(loaded_image, 0),
                                          0).copy()
            tensor_batch = torch.Tensor(loaded_image).to(device)
            # The forward pass populates the saliency_map / patch_* attributes
            # read below; its return value is not needed here.
            model(tensor_batch)

            saliency_maps = model.saliency_map.data.cpu().numpy()
            patch_locations = model.patch_locations
            patch_imgs = model.patches
            patch_attentions = model.patch_attns[0, :].data.cpu().numpy()

            # File name up to the first '.'.
            short_file_path = image.split('.')[0]
            save_dir = os.path.join(visualization_dir,
                                    "{0}.png".format(short_file_path))
            visualize_example(loaded_image, saliency_maps, patch_locations,
                              patch_imgs, patch_attentions, save_dir,
                              parameters)
Example #4
0
def ori_image_prepare(image_path, view, horizontal_flip, parameters):
    """
    Load an image as floats, normalize it, and build its patch stride lists.

    Returns:
        (image, width_stride_list, length_stride_list), where the two stride
        lists are produced by ``stride_list_generator`` for the first and the
        second dimension of ``image.shape`` respectively.
    """
    patch_settings = (
        parameters['patch_size'],
        parameters['more_patches'],
        parameters['stride_fixed'],
    )

    image = loading.load_image(image_path, view, horizontal_flip).astype(float)
    # Called for its effect on ``image``; the return value is not used.
    loading.standard_normalize_single_image(image)

    img_width, img_length = image.shape
    width_stride_list, length_stride_list = [
        stride_list_generator(extent, *patch_settings)
        for extent in (img_width, img_length)
    ]
    return image, width_stride_list, length_stride_list
Example #5
0
def ori_image_prepare(short_file_path, view, horizontal_flip, parameters):
    """
    Resolve ``short_file_path`` inside ``parameters['orginal_image_path']``,
    load it as floats, normalize it, and build its patch stride lists.

    The file extension is ``.hdf5`` when ``parameters['use_hdf5']`` is truthy,
    otherwise ``.png``.

    Returns:
        (image, width_stride_list, length_stride_list).
    """
    # NOTE: the key really is spelled 'orginal_image_path' in the config.
    base_dir = parameters['orginal_image_path']
    patch_settings = (
        parameters['patch_size'],
        parameters['more_patches'],
        parameters['stride_fixed'],
    )

    suffix = '.hdf5' if parameters['use_hdf5'] else '.png'
    full_path = os.path.join(base_dir, short_file_path + suffix)

    image = loading.load_image(full_path, view, horizontal_flip).astype(float)
    # Called for its effect on ``image``; the return value is not used.
    loading.standard_normalize_single_image(image)

    img_width, img_length = image.shape
    width_stride_list, length_stride_list = [
        stride_list_generator(extent, *patch_settings)
        for extent in (img_width, img_length)
    ]
    return image, width_stride_list, length_stride_list
Example #6
0
File: run_model.py  Project: BigOsoft/IMA
def run_model(model, device, exam_list, parameters):
    """
    Returns predictions of image only model or image+heatmaps model.
    Prediction for each exam is averaged for a given number of epochs.

    For each exam in ``exam_list``, every image of every view in VIEWS.LIST is
    loaded once (plus heatmaps when ``parameters["use_heatmaps"]`` is truthy).
    ``parameters["num_epochs"]`` augmented samples are then run through the
    model, ``parameters["batch_size"]`` at a time.

    Returns:
        np.ndarray with one row per exam; each row is the mean over all
        epochs of the per-label predictions, in LABELS.LIST order.
    """
    random_number_generator = np.random.RandomState(parameters["seed"])

    image_extension = ".hdf5" if parameters["use_hdf5"] else ".png"

    with torch.no_grad():
        predictions_ls = []
        for datum in tqdm.tqdm(exam_list):
            predictions_for_datum = []
            loaded_image_dict = {view: [] for view in VIEWS.LIST}
            loaded_heatmaps_dict = {view: [] for view in VIEWS.LIST}
            for view in VIEWS.LIST:
                for short_file_path in datum[view]:
                    loaded_image = loading.load_image(
                        image_path=os.path.join(
                            parameters["image_path"],
                            short_file_path + image_extension),
                        view=view,
                        horizontal_flip=datum["horizontal_flip"],
                    )
                    if parameters["use_heatmaps"]:
                        loaded_heatmaps = loading.load_heatmaps(
                            benign_heatmap_path=os.path.join(
                                parameters["heatmaps_path"], "heatmap_benign",
                                short_file_path + ".hdf5"),
                            malignant_heatmap_path=os.path.join(
                                parameters["heatmaps_path"],
                                "heatmap_malignant",
                                short_file_path + ".hdf5"),
                            view=view,
                            horizontal_flip=datum["horizontal_flip"],
                        )
                    else:
                        loaded_heatmaps = None

                    loaded_image_dict[view].append(loaded_image)
                    loaded_heatmaps_dict[view].append(loaded_heatmaps)
            for data_batch in tools.partition_batch(
                    range(parameters["num_epochs"]), parameters["batch_size"]):
                batch_dict = {view: [] for view in VIEWS.LIST}
                for _ in data_batch:
                    for view in VIEWS.LIST:
                        # With augmentation, a random image of this view is
                        # used; otherwise always the first one.
                        image_index = 0
                        if parameters["augmentation"]:
                            image_index = random_number_generator.randint(
                                low=0, high=len(datum[view]))
                        cropped_image, cropped_heatmaps = loading.augment_and_normalize_image(
                            image=loaded_image_dict[view][image_index],
                            auxiliary_image=loaded_heatmaps_dict[view]
                            [image_index],
                            view=view,
                            best_center=datum["best_center"][view]
                            [image_index],
                            random_number_generator=random_number_generator,
                            augmentation=parameters["augmentation"],
                            max_crop_noise=parameters["max_crop_noise"],
                            max_crop_size_noise=parameters[
                                "max_crop_size_noise"],
                        )
                        # Channels-last sample: the image as one channel,
                        # followed by the heatmap channels when present.
                        if loaded_heatmaps_dict[view][image_index] is None:
                            batch_dict[view].append(cropped_image[:, :,
                                                                  np.newaxis])
                        else:
                            batch_dict[view].append(
                                np.concatenate([
                                    cropped_image[:, :, np.newaxis],
                                    cropped_heatmaps,
                                ],
                                               axis=2))

                # Stack to (batch, H, W, C), then permute to (batch, C, H, W).
                tensor_batch = {
                    view: torch.tensor(np.stack(batch_dict[view])).permute(
                        0, 3, 1, 2).to(device)
                    for view in VIEWS.LIST
                }
                output = model(tensor_batch)
                batch_predictions = compute_batch_predictions(
                    output, mode=parameters["model_mode"])
                # Column 1 of each prediction array; the columns form a
                # (label, view_angle) MultiIndex.
                pred_df = pd.DataFrame(
                    {k: v[:, 1]
                     for k, v in batch_predictions.items()})
                pred_df.columns.names = ["label", "view_angle"]
                # Average over view angles for each label. numeric_only=True
                # drops the string "view_angle" column: pandas >= 2.0 raises
                # TypeError when asked to average it, and older pandas
                # silently dropped it anyway.
                predictions = pred_df.T.reset_index().groupby(
                    "label").mean(numeric_only=True).T[LABELS.LIST].values
                predictions_for_datum.append(predictions)
            predictions_ls.append(
                np.mean(np.concatenate(predictions_for_datum, axis=0), axis=0))

    return np.array(predictions_ls)
Example #7
0
            def run_model(model, exam_list, parameters, turn_on_visualization):
                """
                Run the model over images in sample_data and return the
                predictions as a pandas DataFrame.

                For every exam in ``exam_list`` and every view in VIEWS.LIST,
                only the first image of that view (``datum[view][0]``) is
                processed. When ``turn_on_visualization`` is truthy, a PNG is
                written to
                ``parameters["output_path"]/visualization/<short_file_path>.png``.

                Returns:
                    DataFrame with columns image_index, benign_pred and
                    malignant_pred, plus benign_label / malignant_label columns
                    that hold the constant 'no input'.
                """
                if (parameters["device_type"] == "gpu") and torch.has_cudnn:
                    device = torch.device("cuda:{}".format(
                        parameters["gpu_number"]))
                else:
                    device = torch.device("cpu")
                model = model.to(device)
                model.eval()

                # Data holders. The label entries are scalars, which pandas
                # broadcasts to every row.
                pred_dict = {
                    "image_index": [],
                    "benign_pred": [],
                    "malignant_pred": [],
                    "benign_label": 'no input',
                    "malignant_label": 'no input'
                }
                with torch.no_grad():
                    for datum in tqdm.tqdm(exam_list):
                        for view in VIEWS.LIST:
                            short_file_path = datum[view][0]
                            # The image is already flipped, so no need to do
                            # it again.
                            loaded_image = loading.load_image(
                                image_path=os.path.join(
                                    parameters["image_path"],
                                    short_file_path + ".png"),
                                view=view,
                                horizontal_flip=False,
                            )
                            loading.standard_normalize_single_image(
                                loaded_image)
                            # Segmentations are optional: when a file is
                            # missing, use an all-zero mask of the image's
                            # shape. This used to be a hard-coded 1920x2944
                            # shape that need not match the image.
                            benign_seg_path = os.path.join(
                                parameters["segmentation_path"],
                                "{0}_{1}".format(short_file_path,
                                                 "benign.png"))
                            malignant_seg_path = os.path.join(
                                parameters["segmentation_path"],
                                "{0}_{1}".format(short_file_path,
                                                 "malignant.png"))

                            if os.path.exists(benign_seg_path):
                                benign_seg = loading.load_image(
                                    image_path=benign_seg_path,
                                    view=view,
                                    horizontal_flip=False,
                                )
                            else:
                                benign_seg = np.zeros(loaded_image.shape,
                                                      dtype=int)
                            if os.path.exists(malignant_seg_path):
                                malignant_seg = loading.load_image(
                                    image_path=malignant_seg_path,
                                    view=view,
                                    horizontal_flip=False,
                                )
                            else:
                                malignant_seg = np.zeros(loaded_image.shape,
                                                         dtype=int)
                            # Convert the 2D array into a 4D tensor in
                            # N,C,H,W format.
                            loaded_image = np.expand_dims(
                                np.expand_dims(loaded_image, 0), 0).copy()
                            tensor_batch = torch.Tensor(loaded_image).to(
                                device)
                            output = model(tensor_batch)
                            pred_numpy = output.data.cpu().numpy()
                            benign_pred, malignant_pred = pred_numpy[
                                0, 0], pred_numpy[0, 1]
                            if turn_on_visualization:
                                saliency_maps = model.saliency_map.data.cpu(
                                ).numpy()
                                patch_locations = model.patch_locations
                                patch_imgs = model.patches
                                patch_attentions = model.patch_attns[
                                    0, :].data.cpu().numpy()
                                save_dir = os.path.join(
                                    parameters["output_path"], "visualization",
                                    "{0}.png".format(short_file_path))
                                visualize_example(loaded_image, saliency_maps,
                                                  [benign_seg, malignant_seg],
                                                  patch_locations, patch_imgs,
                                                  patch_attentions, save_dir,
                                                  parameters)
                            pred_dict["image_index"].append(short_file_path)
                            pred_dict["benign_pred"].append(benign_pred)
                            pred_dict["malignant_pred"].append(malignant_pred)
                return pd.DataFrame(pred_dict)
    def __getitem__(self, idx):
        """
        Build batch number ``idx`` of model inputs and targets.

        Reads ``self.x`` (sequence of exam dicts), ``self.y`` (indexed by
        patient name), ``self.batch_size``, ``self.parameters`` and
        ``self.random_number_generator``. The augmentation logic is adapted
        from run_model.py.

        Returns:
            (batch_x, batch_y). ``batch_x`` maps each view in VIEWS.LIST to an
            array with the per-sample channel axis moved from last to axis 1,
            i.e. (batch, channels, H, W). ``batch_y`` is ``np.stack`` of the
            per-patient targets.
        """
        batch_exams = self.x[idx * self.batch_size:(idx + 1) *
                             self.batch_size]

        batch_y = []
        batch_x_feed_new = {view: [] for view in VIEWS.LIST}

        image_extension = ".hdf5" if self.parameters["use_hdf5"] else ".png"

        for datum in batch_exams:
            # Patient name: second '_'-separated field of the first L-CC file.
            patient = datum['L-CC'][0].split('_')[1]

            for view in VIEWS.LIST:
                # Choose this view's image index *before* loading, so that the
                # loaded file and its best_center entry refer to the same
                # image. Previously the file was picked with the index drawn
                # for the previous view.
                image_index = 0
                if self.parameters["augmentation"]:
                    image_index = self.random_number_generator.randint(
                        low=0, high=len(datum[view]))
                short_file_path = datum[view][image_index]

                loaded_image = loading.load_image(
                    image_path=os.path.join(self.parameters["image_path"],
                                            short_file_path + image_extension),
                    view=view,
                    horizontal_flip=datum["horizontal_flip"],
                )
                if self.parameters["use_heatmaps"]:
                    loaded_heatmaps = loading.load_heatmaps(
                        benign_heatmap_path=os.path.join(
                            self.parameters["heatmaps_path"], "heatmap_benign",
                            short_file_path + ".hdf5"),
                        malignant_heatmap_path=os.path.join(
                            self.parameters["heatmaps_path"],
                            "heatmap_malignant", short_file_path + ".hdf5"),
                        view=view,
                        horizontal_flip=datum["horizontal_flip"],
                    )
                else:
                    loaded_heatmaps = None

                cropped_image, cropped_heatmaps = loading.augment_and_normalize_image(
                    image=loaded_image,
                    auxiliary_image=loaded_heatmaps,
                    view=view,
                    best_center=datum["best_center"][view][image_index],
                    random_number_generator=self.random_number_generator,
                    augmentation=self.parameters["augmentation"],
                    max_crop_noise=self.parameters["max_crop_noise"],
                    max_crop_size_noise=self.parameters["max_crop_size_noise"],
                )
                # Channels-last sample: the image as one channel, followed by
                # the heatmap channels when present.
                if loaded_heatmaps is None:
                    sample = cropped_image[:, :, np.newaxis]
                else:
                    sample = np.concatenate(
                        [cropped_image[:, :, np.newaxis], cropped_heatmaps],
                        axis=2)
                batch_x_feed_new[view].append(sample)

            batch_y.append(self.y[patient])

        batch_y = np.stack(np.array(batch_y), axis=0)

        # Stack each view's samples and move channels from the last axis to
        # axis 1: (batch, H, W, C) -> (batch, C, H, W).
        for view in VIEWS.LIST:
            batch_x_feed_new[view] = np.moveaxis(
                np.stack(np.array(batch_x_feed_new[view]), axis=0), -1, 1)

        return batch_x_feed_new, batch_y
def run_sub_model(model, exam_list, parameters):
    """
    Returns predictions of image only model or image+heatmaps model.
    Prediction for each exam is averaged for a given number of epochs.

    The model is moved to the configured device and put in eval mode. For each
    exam, every image of every view in VIEWS.LIST is loaded once. Heatmaps are
    loaded only when ``parameters.get("use_heatmaps")`` is truthy.
    ``parameters["num_epochs"]`` augmented samples are then run,
    ``parameters["batch_size"]`` at a time.

    Returns:
        np.ndarray with one row per exam; each row is the mean over all
        epochs of the per-label predictions, in LABELS.LIST order.
    """
    if (parameters["device_type"] == "gpu") and torch.has_cudnn:
        device = torch.device("cuda:{}".format(parameters["gpu_number"]))
    else:
        device = torch.device("cpu")
    model = model.to(device)
    # Evaluation mode affects modules such as Dropout or BatchNorm.
    model.eval()

    random_number_generator = np.random.RandomState(parameters["seed"])

    image_extension = ".hdf5" if parameters["use_hdf5"] else ".png"
    # Optional key: absent means image-only, which is what this function
    # previously did.
    use_heatmaps = parameters.get("use_heatmaps", False)

    with torch.no_grad():
        predictions_ls = []
        for datum in tqdm.tqdm(exam_list):
            predictions_for_datum = []
            loaded_image_dict = {view: [] for view in VIEWS.LIST}
            loaded_heatmaps_dict = {view: [] for view in VIEWS.LIST}
            for view in VIEWS.LIST:
                # For one exam, load all images of this view.
                for short_file_path in datum[view]:
                    loaded_image = loading.load_image(
                        image_path=os.path.join(
                            parameters["image_path"],
                            short_file_path + image_extension),
                        view=view,
                        horizontal_flip=datum["horizontal_flip"],
                    )
                    if use_heatmaps:
                        loaded_heatmaps = loading.load_heatmaps(
                            benign_heatmap_path=os.path.join(
                                parameters["heatmaps_path"], "heatmap_benign",
                                short_file_path + ".hdf5"),
                            malignant_heatmap_path=os.path.join(
                                parameters["heatmaps_path"],
                                "heatmap_malignant",
                                short_file_path + ".hdf5"),
                            view=view,
                            horizontal_flip=datum["horizontal_flip"],
                        )
                    else:
                        loaded_heatmaps = None
                    # Keep both lists index-aligned; the heatmap list was
                    # previously never filled, causing an IndexError below.
                    loaded_image_dict[view].append(loaded_image)
                    loaded_heatmaps_dict[view].append(loaded_heatmaps)

            for data_batch in tools.partition_batch(
                    range(parameters["num_epochs"]), parameters["batch_size"]):
                batch_dict = {view: [] for view in VIEWS.LIST}
                for _ in data_batch:
                    for view in VIEWS.LIST:
                        # Each view draws its own random image index when
                        # augmenting.
                        image_index = 0
                        if parameters["augmentation"]:
                            image_index = random_number_generator.randint(
                                low=0, high=len(datum[view]))
                        cropped_image, cropped_heatmaps = loading.augment_and_normalize_image(
                            image=loaded_image_dict[view][image_index],
                            auxiliary_image=loaded_heatmaps_dict[view][image_index],
                            view=view,
                            best_center=datum["best_center"][view][image_index],
                            random_number_generator=random_number_generator,
                            augmentation=parameters["augmentation"],
                            max_crop_noise=parameters["max_crop_noise"],
                            max_crop_size_noise=parameters["max_crop_size_noise"],
                        )
                        # Channels-last sample: the image as one channel,
                        # followed by the heatmap channels when present.
                        if loaded_heatmaps_dict[view][image_index] is None:
                            batch_dict[view].append(
                                cropped_image[:, :, np.newaxis])
                        else:
                            batch_dict[view].append(np.concatenate([
                                cropped_image[:, :, np.newaxis],
                                cropped_heatmaps,
                            ], axis=2))

                # Stack to (batch, H, W, C), then permute to (batch, C, H, W).
                tensor_batch = {
                    view: torch.tensor(np.stack(batch_dict[view])).permute(
                        0, 3, 1, 2).to(device)
                    for view in VIEWS.LIST
                }

                output = model(tensor_batch)
                batch_predictions = compute_batch_predictions(output)
                # Column 1 of each prediction array; the columns form a
                # (label, view_angle) MultiIndex.
                pred_df = pd.DataFrame(
                    {k: v[:, 1] for k, v in batch_predictions.items()})
                pred_df.columns.names = ["label", "view_angle"]
                # Average over view angles per label. numeric_only=True drops
                # the string "view_angle" column, which pandas >= 2.0 refuses
                # to average.
                predictions = pred_df.T.reset_index().groupby(
                    "label").mean(numeric_only=True).T[LABELS.LIST].values
                predictions_for_datum.append(predictions)
            predictions_ls.append(
                np.mean(np.concatenate(predictions_for_datum, axis=0), axis=0))

    return np.array(predictions_ls)
Example #10
0
def run_model(model, device, exam_list, parameters):
    """
    Returns predictions of image only model or image+heatmaps model.
    Prediction for each exam is averaged for a given number of epochs.

    For each exam in ``exam_list``, every image of every view in VIEWS.LIST
    (VIEWS.LIST is the list of view names) is loaded once, plus heatmaps when
    ``parameters["use_heatmaps"]`` is truthy. ``parameters["num_epochs"]``
    augmented samples are then run through the model,
    ``parameters["batch_size"]`` at a time.

    Returns:
        np.ndarray with one row per exam; each row is the mean over all
        epochs of the per-label predictions, in LABELS.LIST order.
    """
    random_number_generator = np.random.RandomState(parameters["seed"])

    image_extension = ".hdf5" if parameters["use_hdf5"] else ".png"

    with torch.no_grad():
        predictions_ls = []
        for datum in tqdm.tqdm(exam_list):
            predictions_for_datum = []
            loaded_image_dict = {view: [] for view in VIEWS.LIST}
            loaded_heatmaps_dict = {view: [] for view in VIEWS.LIST}
            for view in VIEWS.LIST:
                # For one exam, load all images of this view.
                for short_file_path in datum[view]:
                    loaded_image = loading.load_image(
                        image_path=os.path.join(
                            parameters["image_path"],
                            short_file_path + image_extension),
                        view=view,
                        horizontal_flip=datum["horizontal_flip"],
                    )
                    if parameters["use_heatmaps"]:
                        loaded_heatmaps = loading.load_heatmaps(
                            benign_heatmap_path=os.path.join(
                                parameters["heatmaps_path"], "heatmap_benign",
                                short_file_path + ".hdf5"),
                            malignant_heatmap_path=os.path.join(
                                parameters["heatmaps_path"],
                                "heatmap_malignant",
                                short_file_path + ".hdf5"),
                            view=view,
                            horizontal_flip=datum["horizontal_flip"],
                        )
                    else:
                        loaded_heatmaps = None

                    loaded_image_dict[view].append(loaded_image)
                    loaded_heatmaps_dict[view].append(loaded_heatmaps)
            for data_batch in tools.partition_batch(
                    range(parameters["num_epochs"]), parameters["batch_size"]):
                batch_dict = {view: [] for view in VIEWS.LIST}
                for _ in data_batch:
                    for view in VIEWS.LIST:
                        # Each view draws its own random image index when
                        # augmenting.
                        image_index = 0
                        if parameters["augmentation"]:
                            image_index = random_number_generator.randint(
                                low=0, high=len(datum[view]))

                        cropped_image, cropped_heatmaps = loading.augment_and_normalize_image(
                            image=loaded_image_dict[view][image_index],
                            auxiliary_image=loaded_heatmaps_dict[view]
                            [image_index],
                            view=view,
                            best_center=datum["best_center"][view]
                            [image_index],
                            random_number_generator=random_number_generator,
                            augmentation=parameters["augmentation"],
                            max_crop_noise=parameters["max_crop_noise"],
                            max_crop_size_noise=parameters[
                                "max_crop_size_noise"],
                        )

                        # Channels-last sample: the image as one channel,
                        # followed by the heatmap channels when present.
                        if loaded_heatmaps_dict[view][image_index] is None:
                            batch_dict[view].append(cropped_image[:, :,
                                                                  np.newaxis])
                        else:
                            batch_dict[view].append(
                                np.concatenate([
                                    cropped_image[:, :, np.newaxis],
                                    cropped_heatmaps,
                                ],
                                               axis=2))

                # np.stack adds the batch axis: (batch, H, W, C), then
                # permute to (batch, C, H, W).
                tensor_batch = {
                    view: torch.tensor(np.stack(batch_dict[view])).permute(
                        0, 3, 1, 2).to(device)
                    for view in VIEWS.LIST
                }

                output = model(tensor_batch)
                batch_predictions = compute_batch_predictions(
                    output, mode=parameters["model_mode"])
                # Keep column 1 of each prediction array; the columns form a
                # (label, view_angle) MultiIndex.
                pred_df = pd.DataFrame(
                    {k: v[:, 1]
                     for k, v in batch_predictions.items()})
                pred_df.columns.names = ["label", "view_angle"]
                # Average over view angles for each label. numeric_only=True
                # drops the string "view_angle" column: pandas >= 2.0 raises
                # TypeError when asked to average it, and older pandas
                # silently dropped it anyway.
                predictions = pred_df.T.reset_index().groupby(
                    "label").mean(numeric_only=True).T[LABELS.LIST].values
                predictions_for_datum.append(predictions)
            predictions_ls.append(
                np.mean(np.concatenate(predictions_for_datum, axis=0), axis=0))

    return np.array(predictions_ls)