Code Example #1
    def validate(self, dataset, save_path=None, prefix=''):
        """

        """
        with torch.no_grad():
            # make loader
            valid_loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers,
                                                       shuffle=False, worker_init_fn=lambda _: np.random.seed())
            n_batch = len(valid_loader)
            l1_fn, l2_fn, gdl_fn = nn.L1Loss(reduction='mean'), nn.MSELoss(reduction='mean'), GDL(reduction='mean', device=self.device)

            self.ae.eval()
            # validate data by batch
            valid_loss = 0.0
            for b, data in enumerate(valid_loader):
                im, idx = data
                im = im.to(self.device).float()
                idx = idx.cpu().numpy()
                # reconstruct
                im_rec = self.ae(im)
                # compute combined L1, L2 and gradient difference (GDL) losses
                valid_loss += l1_fn(im_rec, im).item() + l2_fn(im_rec, im).item() + self.lambda_GDL*gdl_fn(im, im_rec).item()
                # save results
                if save_path:
                    for i in range(im.shape[0]):
                        arr = np.concatenate([im[i].permute(1,2,0).squeeze().cpu().numpy(), im_rec[i].permute(1,2,0).squeeze().cpu().numpy()], axis=1)
                        io.imsave(os.path.join(save_path, f'valid_im{idx[i]}{prefix}.png'), img_as_ubyte(arr), check_contrast=False)

                print_progessbar(b, n_batch, Name='Valid Batch', Size=100, erase=True)

        return valid_loss / n_batch
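
Note: the `GDL` loss used above is not defined in this snippet. Below is a minimal sketch of a gradient difference loss, assuming it penalizes the mismatch between the spatial gradients of the target and the reconstruction (the class name and the `gdl_fn(im, im_rec)` call signature mirror the call site; the original implementation may differ):

import torch
import torch.nn as nn

class GDL(nn.Module):
    """Sketch of a Gradient Difference Loss for [B, C, H, W] tensors."""
    def __init__(self, reduction='mean', device='cpu'):
        super().__init__()
        self.reduction = reduction  # only 'mean' is implemented in this sketch
        self.device = device

    def forward(self, target, pred):
        # absolute intensity gradients along height and width
        dh = lambda x: (x[:, :, 1:, :] - x[:, :, :-1, :]).abs()
        dw = lambda x: (x[:, :, :, 1:] - x[:, :, :, :-1]).abs()
        # mean absolute difference between target and prediction gradients
        return (dh(target) - dh(pred)).abs().mean() + (dw(target) - dw(pred)).abs().mean()
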
Code Example #2
def main(input_data_path, output_data_path, window):
    """
    Convert the volumetric CT data and masks (in NIfTI format) into a dataset of 2D images (TIFF) and masks (bitmap) for brain extraction.
    """
    # open data info dataframe
    info_df = pd.read_csv(os.path.join(input_data_path, 'info.csv'), index_col=0)
    # make output directory
    if not os.path.exists(output_data_path): os.mkdir(output_data_path)
    # iterate over volumes to extract data
    output_info = []
    for n, id in enumerate(info_df.id.values):
        # read nii volume
        ct_nii = nib.load(os.path.join(input_data_path, f'ct_scans/{id}.nii'))
        mask_nii = nib.load(os.path.join(input_data_path, f'masks/{id}.nii.gz'))
        # get np.array
        ct_vol = ct_nii.get_fdata()
        mask_vol = skimage.img_as_bool(mask_nii.get_fdata())
        # rotate 90° counterclockwise so the head points upward
        ct_vol = np.rot90(ct_vol, axes=(0,1))
        mask_vol = np.rot90(mask_vol, axes=(0,1))
        # window the ct volume to get better contrast of soft tissues
        if window is not None:
            ct_vol = window_ct(ct_vol, win_center=window[0], win_width=window[1], out_range=(0,1))

        if mask_vol.shape != ct_vol.shape:
            print(f'>>> Warning! The ct volume of patient {id} does not have '
                  f'the same dimension as the ground truth. CT ({ct_vol.shape}) vs Mask ({mask_vol.shape})')
        # make patient directory
        if not os.path.exists(os.path.join(output_data_path, f'{id:03}/ct/')): os.makedirs(os.path.join(output_data_path, f'{id:03}/ct/'))
        if not os.path.exists(os.path.join(output_data_path, f'{id:03}/mask/')): os.makedirs(os.path.join(output_data_path, f'{id:03}/mask/'))
        # iterate over slices to save slices
        for slice in range(ct_vol.shape[2]):
            ct_slice_fn = f'{id:03}/ct/{slice+1}.tif'
            # save CT slice
            skimage.io.imsave(os.path.join(output_data_path, ct_slice_fn), ct_vol[:,:,slice], check_contrast=False)
            is_low = skimage.exposure.is_low_contrast(ct_vol[:,:,slice])
            # save mask if some brain on slice
            if np.any(mask_vol[:,:,slice]):
                mask_slice_fn = f'{id:03}/mask/{slice+1}_Seg.bmp'
                skimage.io.imsave(os.path.join(output_data_path, mask_slice_fn), skimage.img_as_ubyte(mask_vol[:,:,slice]), check_contrast=False)
            else:
                mask_slice_fn = 'None'
            # add info to output list
            output_info.append({'volume':id, 'slice':slice+1, 'ct_fn':ct_slice_fn, 'mask_fn':mask_slice_fn, 'low_contrast_ct':is_low})

            print_progessbar(slice, ct_vol.shape[2], Name=f'Volume {id:03} {n+1:03}/{len(info_df.id):03}',
                             Size=20, erase=False)

    # Make dataframe of outputs
    output_info_df = pd.DataFrame(output_info)
    # save df
    output_info_df.to_csv(os.path.join(output_data_path, 'slice_info.csv'))
    print('>>> Slice information saved at ' + os.path.join(output_data_path, 'slice_info.csv'))
    # save patient df
    info_df.to_csv(os.path.join(output_data_path, 'volume_info.csv'))
    print('>>> Volume information saved at ' + os.path.join(output_data_path, 'volume_info.csv'))
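
Note: `window_ct` is not shown here. A minimal sketch, assuming it implements standard CT intensity windowing (clip to the window, then rescale linearly to `out_range`):

import numpy as np

def window_ct(ct_vol, win_center=40, win_width=120, out_range=(0, 1)):
    """Sketch: clip Hounsfield units to the window and rescale to out_range."""
    low, high = win_center - win_width / 2, win_center + win_width / 2
    vol = np.clip(ct_vol, low, high)
    vol = (vol - low) / (high - low)  # normalize to [0, 1]
    return vol * (out_range[1] - out_range[0]) + out_range[0]
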
Code Example #3
    def evaluate(self, dataset):
        """
        Evaluate the passed network on the given dataset for the context restoration task.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) the dataset to use for evaluation. It should output the original image,
            |           and the sample index.
        OUTPUT
            |---- None
        """
        logger = logging.getLogger()
        # make loader
        loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=False,
                                             num_workers=self.num_workers, worker_init_fn=lambda _: np.random.seed())
        # put net on device
        self.net = self.net.to(self.device)

        # Evaluate
        logger.info('Start Evaluating the context restoration model.')
        start_time = time.time()
        idx_repr = [] # placeholder for bottleneck representation
        n_batch = len(loader)
        self.net.eval()
        with torch.no_grad():
            for b, data in enumerate(loader):
                # get data : load in standard way (no patch-swapped image)
                input, idx = data
                input = input.to(self.device).float()
                idx = idx.to(self.device)
                # get representation
                self.net.return_bottleneck = True
                _, repr = self.net(input)
                # down sample representation for reduced memory impact
                repr = nn.AdaptiveAvgPool2d((4,4))(repr)
                # add ravelled representations to placeholder
                idx_repr += list(zip(idx.cpu().data.tolist(), repr.view(repr.shape[0], -1).cpu().data.tolist()))
                # print_progress
                if self.print_progress:
                    print_progessbar(b, n_batch, Name='\t\tEvaluation Batch', Size=40, erase=True)
            # reset the network attributes
            self.net.return_bottleneck = False

        # compute tSNE for representation
        idx, repr = zip(*idx_repr)
        repr = np.array(repr)
        logger.info('Computing the t-SNE representation.')
        repr_2D = TSNE(n_components=2).fit_transform(repr)
        self.outputs['eval']['repr'] = list(zip(idx, repr_2D.tolist()))
        logger.info('Successfully computed the t-SNE representation.')
        # finish evaluation
        self.outputs['eval']['time'] = time.time() - start_time
        logger.info(f"Finished evaluating on the context restoration task in {timedelta(seconds=int(self.outputs['eval']['time']))}")
Code Example #4
    def _pixelwise_error(self, input, grid_masks, verbose=False):
        """
        Generate the pixelwise error sample of the input when masked with the grids.
        ----------
        INPUT
            |---- input (torch.tensor) the image on which the error sample is computed with the grids. It should have
            |               dimension [C, H, W].
            |---- grid_masks (np.array) the set of grid masks to use for inpainting. It should have dimension [N_grid, H, W].
            |---- verbose (bool) whether to print a progress bar of the inpainting process.
        OUTPUT
            |---- err (np.array) the pixelwise sample of inpainting errors with dimension [N_grid, C, H, W].
        """
        assert input.ndim == 3, f"Input must be 3 dimensional (C x H x W). Got {input.shape}"

        grid_dataset = data.TensorDataset(
            torch.tensor(grid_masks).unsqueeze(1))
        grid_loader = data.DataLoader(grid_dataset, batch_size=self.batch_size)

        input = input.unsqueeze(0)  # add a batch dimension

        error_list = []
        for b, grid_mask_batch in enumerate(grid_loader):
            grid_mask_batch = grid_mask_batch[0]
            # repeat input along the batch dimension
            input_rep = input.repeat(grid_mask_batch.shape[0], 1, 1,
                                     1).to(self.device)
            # inpaint im
            inpaint_im = self._inpaint(input_rep, grid_mask_batch)
            # compute difference to input
            error_list.append(inpaint_im - input_rep)

            if verbose:
                print_progessbar(b,
                                 len(grid_loader),
                                 Name='Grid Inpainting',
                                 Size=50,
                                 erase=True)

        # keep only inpainting errors where the grid masks were present
        measures = torch.cat(error_list, dim=0)  # [N_measure x C x H x W]
        grid_sample = torch.tensor(grid_masks).unsqueeze(1).repeat(
            1, input.shape[1], 1,
            1)  # use the whole grid array and repeat the channel dimension
        err = measures.permute(1, 2, 3,
                               0)[grid_sample.permute(1, 2, 3, 0) == 1]
        c, h, w = input.shape[1:]
        err = err.view(c, h, w, -1).permute(3, 0, 1, 2)  # [N_err, C, H, W]

        return err.cpu().numpy()
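
Note: the `grid_masks` argument is built elsewhere. A hypothetical construction using two complementary checkerboard grids so that every pixel is masked exactly once (image and cell sizes are assumptions):

import numpy as np

H, W, cell = 256, 256, 8  # assumed image size and grid cell size
yy, xx = np.mgrid[0:H, 0:W]
base = ((yy // cell + xx // cell) % 2).astype(np.float32)
grid_masks = np.stack([base, 1.0 - base])  # [N_grid, H, W]
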
Code Example #5
 def evaluate(self, dataset):
     """
     Evaluate the network on the given dataset for the contrastive task (get t-SNE representation of samples). Only performed if the task is global.
     ----------
     INPUT
         |---- dataset (torch.utils.data.Dataset) the dataset to use for evaluation. It should output the original image,
         |           and the sample index.
     OUTPUT
         |---- None
     """
     if self.is_global:
         logger = logging.getLogger()
         # initialize Dataloader
         loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers,
                                              worker_init_fn=lambda _: np.random.seed())
         # Evaluate
         logger.info("Start Evaluating the network on global contrastive task.")
         start_time = time.time()
         idx_repr = [] # placeholder for bottleneck representation
         n_batch = len(loader)
         self.net.eval()
         self.net.return_bottleneck = True
         with torch.no_grad():
             for b, data in enumerate(loader):
                 im, idx = data
                 im = im.to(self.device).float()
                 idx = idx.to(self.device)
                 # get representations
                 _, z = self.net(im)
                 # keep representations (bottleneck)
                 idx_repr += list(zip(idx.cpu().data.tolist(), z.squeeze().cpu().data.tolist()))
                 # print_progress
                 if self.print_progress:
                     print_progessbar(b, n_batch, Name='\t\tEvaluation Batch', Size=40, erase=True)
             # reset the network attributes
             self.net.return_bottleneck = False
         # compute tSNE for representation
         idx, repr = zip(*idx_repr)
         repr = np.array(repr)
         logger.info('Computing the t-SNE representation.')
         repr_2D = TSNE(n_components=2).fit_transform(repr)
         self.outputs['eval']['repr'] = list(zip(idx, repr_2D.tolist()))
         logger.info('Successfully computed the t-SNE representation.')
         # finish evaluation
         self.outputs['eval']['time'] = time.time() - start_time
         logger.info(f"Finished evaluating of encoder on the global contrastive task in {timedelta(seconds=int(self.outputs['eval']['time']))}")
     else:
         warnings.warn("Evaluation is only possible with a global contrastive task.")
Code Example #6
def main(src_data_path, src_data_info, dst_data_path, dst_data_info, which):
    """

    """
    # load src csv
    src_df = pd.read_csv(src_data_info, index_col=0)
    src_df = src_df[['id', 'slice', f'ad_{which}_fn']]

    # load dst csv
    dst_df = pd.read_csv(dst_data_info, index_col=0)

    # merge df
    df = pd.merge(dst_df, src_df, on=['id', 'slice'])
    df["attention_fn"] = df.apply(
        lambda row: f"{row['id']:03}" + os.sep + 'anomaly' + os.sep + os.path.
        basename(row[f'ad_{which}_fn'])
        if row[f'ad_{which}_fn'] != 'None' else 'None',
        axis=1
    )  #df['id'] + os.sep + 'anomaly' + os.sep + os.path.basename(df[f'ad_{which}_fn'])

    # remove old folder and make new empty ones
    for i, id_i in enumerate(df.id.unique()):
        dir_i = os.path.join(dst_data_path, f'{id_i:03}/anomaly/')
        if os.path.isdir(dir_i):
            for fn in glob.glob(os.path.join(dir_i, '*.png')):
                os.remove(fn)
        else:
            os.makedirs(dir_i)
        print_progessbar(i,
                         len(df.id.unique()),
                         Name='Folder cleaning',
                         Size=50)

    # for each sample : transfer file
    for i, (_, row) in enumerate(df.iterrows()):
        if row[f'ad_{which}_fn'] != 'None':
            _ = shutil.copy2(
                os.path.join(src_data_path, row[f'ad_{which}_fn']),
                os.path.join(dst_data_path, row['attention_fn']))
        print_progessbar(i, len(df), Name='Sample', Size=50)

    # remove src fn and save df
    df = df.drop(columns=[f'ad_{which}_fn'])
    df.to_csv(os.path.join(dst_data_path, 'info.csv'))
    print(
        f">>> new info csv saved at {os.path.join(dst_data_path, 'info.csv')}")
Code Example #7
    def validate(self, dataset):
        """
        Validate on the given dataset and return the loss and the AUC computed on the anomaly score (i.e. the sum over the output feature map).
        """
        with torch.no_grad():
            # make loader
            valid_loader = torch.utils.data.DataLoader(
                dataset,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                shuffle=False,
                worker_init_fn=lambda _: np.random.seed())
            n_batch = len(valid_loader)
            loss_fn = HSCLoss(reduction='mean')

            self.net.eval()
            # validate data by batch
            valid_loss = 0.0
            label_score = []
            for b, data in enumerate(valid_loader):
                im, label, _ = data
                im = im.to(self.device).float()
                label = label.to(self.device)
                # forward pass (output feature map)
                feat_map = self.net(im)
                # compute loss
                valid_loss += loss_fn(feat_map, label).item()
                # compute score
                ad_score = (torch.sqrt(feat_map**2 + 1) - 1).reshape(
                    feat_map.shape[0], -1).sum(-1)
                # save label and scores
                label_score += list(
                    zip(label.cpu().data.tolist(),
                        ad_score.cpu().data.tolist()))

                if self.print_progress:
                    print_progessbar(b,
                                     n_batch,
                                     Name='Valid Batch',
                                     Size=100,
                                     erase=True)

            # compute AUC
            label, score = zip(*label_score)
            auc = roc_auc_score(np.array(label), np.array(score))

        return valid_loss / n_batch, auc
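
Note: `HSCLoss` is defined elsewhere. A sketch of a hypersphere classifier loss in the spirit of FCDD, assuming label 1 marks anomalous samples (the pseudo-Huber norm of the feature map is pulled to zero for normal samples and pushed up for anomalies; the exact original formulation may differ):

import torch
import torch.nn as nn

class HSCLoss(nn.Module):
    """Sketch of an FCDD-style hypersphere classifier loss."""
    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, feat_map, label):
        # per-sample pseudo-Huber norm of the feature map
        l = (torch.sqrt(feat_map ** 2 + 1) - 1).reshape(feat_map.shape[0], -1).mean(-1)
        # normal samples (label 0) minimize l; anomalies (label 1) maximize it
        loss = torch.where(label == 1, -torch.log(1 - torch.exp(-l) + 1e-9), l)
        return loss.mean() if self.reduction == 'mean' else loss
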
Code Example #8
def main(input_path, out_folder):
    """
    Convert the qureAI CQ500 dataset (at input_path) from series of dicom to 3D NIfTI volumes.
    """
    print(
        f'>>> Start converting dicom series from {input_path} to NIfTI volumes saved at {out_folder}.'
    )
    # adjust dicom conversion settings to not check for orthogonality
    dicom2nifti.settings.disable_validate_orthogonal()
    # Placeholder for NIfTI file information
    out_info_list = []
    # iterate over subfolders
    dir_list = glob.glob(input_path + '*/')
    for n, patient_dir in enumerate(dir_list):
        # get patient CT ID
        ID = patient_dir.split('/')[-2]
        # iterate over patient's series
        dcm_list = []
        for dcm_fn in glob.glob(patient_dir + '*.dcm'):
            # read dicom and decompress it
            ds = pydicom.dcmread(dcm_fn)
            ds.decompress()
            dcm_list.append(ds)
        # convert the dicom series into a nifti file
        out_path = out_folder + ID + '.nii'
        _ = dicom2nifti.convert_dicom.dicom_array_to_nifti(dcm_list, out_path)
        out_info_list.append({
            'id': int(ID),
            'filename': ID + '.nii',
            'n_slice': len(dcm_list)
        })

        print_progessbar(n, len(dir_list), '\tCT Scan', Size=40, erase=False)

    print(
        f'>>> {len(out_info_list)} NIfTI volumes successfully saved at {out_folder}.'
    )
    # read info csv and merge with filepath
    in_df = pd.read_csv(input_path + 'ICH_probabilities.csv', index_col=0)
    fn_df = pd.DataFrame(out_info_list)
    df = pd.merge(fn_df, in_df, left_on='id', right_index=True, how='outer')
    # save data info
    df.to_csv(out_folder + 'info.csv')
    print(f">>> NIfTI file informations saved at {out_folder + 'info.csv'}.")
Code Example #9
    def get_min_max(self,
                    loader,
                    reception=True,
                    std=None,
                    cpu=True,
                    q_min=0.025,
                    q_max=0.975):
        """
        Compute the Min and Max values of the heat maps on the dataset. For each batch, a candidate min or max value
        is selected as the q_min or q_max quantile of the heatmap entries.
        """
        self.net.eval()
        # get scaling parameters with one forward pass
        min_val, max_val = np.inf, -np.inf
        for b, data in enumerate(loader):
            im = data[0]
            im = im.to(self.device).float()
            heatmap = self.generate_heatmap(im,
                                            reception=reception,
                                            std=std,
                                            cpu=cpu)
            qmax = torch.kthvalue(heatmap.reshape(-1),
                                  int(q_max * heatmap.reshape(-1).size(0))
                                  )[0] if q_max < 1.0 else heatmap.max()
            if qmax > max_val:
                max_val = qmax
            qmin = torch.kthvalue(heatmap.reshape(-1),
                                  int(q_min * heatmap.reshape(-1).size(0))
                                  )[0] if q_min > 0 else heatmap.min()
            if qmin < min_val:
                min_val = qmin

            if self.print_progress:
                print_progessbar(b,
                                 len(loader),
                                 Name='Getting Scaling Factor',
                                 Size=100,
                                 erase=True)

        return min_val, max_val
Code Example #10
        def evaluate(self, net, dataset, return_score=False, print_to_logger=True, save_path=None):
            """
            Evaluate the network with the given dataset. The evaluation score is given for the 3D prediction.
            ----------
            INPUT
                |---- net (nn.Module) the network architecture to train.
                |---- dataset (torch.utils.data.Dataset) the dataset to use for evaluation. It must return an input image, a
                |           target binary mask, the patientID.
                |---- return_score (bool) whether to return the mean Dice and mean IoU scores of 3D segmentation (for
                |           the 2D case the Dice is computed on the concatenation of prediction for a patient).
                |---- print_to_logger (bool) whether to print information to the logger.
                |---- save_path (str) the folder path where to save the segmentation maps (as bitmap for each slice) and the
                |           performance dataframe. If not provided (i.e. None) nothing is saved.
            OUTPUT
                |---- (Dice) (float) the average Dice coefficient for the 3D segmentation.
                |---- (IoU) (float) the average IoU coefficient for the 3D segmentation.
            """
            # save predictions if a save path is given (the dataset provides the patient ID and slice number as well)
            # Dice is reported for the 3D volume (even for a 2D model) --> 2D predictions must be re-aggregated per patient

            if print_to_logger:
                logger = logging.getLogger()

            # make dataloader
            loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=False,
                                                 num_workers=self.num_workers)
            # put net on device
            net = net.to(self.device)

            # Evaluation
            if print_to_logger:
                logger.info('Start evaluating the UNet.')
            start_time = time.time()
            id_pred = [] # Placeholder for per-slice prediction scores

            net.eval()
            with torch.no_grad():
                for b, data in enumerate(loader):
                    # get data on device
                    input, target, pID = data
                    input = input.to(self.device).float()
                    target = target.to(self.device)
                    # make prediction
                    pred = net(input).argmax(dim=1)

                    # get confusion matrix for each slice of each sample:
                    # decompose the volume into slices and process them separately
                    for id, target_samp, pred_samp in zip(pID, target, pred): # iterate over batch
                        tn, fp, fn, tp = confusion_matrix(target_samp.cpu().data.numpy().ravel(),
                                                          pred_samp.cpu().data.numpy().ravel()).ravel()
                        # save slice prediction if required
                        if save_path:
                            img = nib.Nifti1Image(pred_samp.cpu().numpy().astype(bool), np.eye(4))
                            pred_path = f'{save_path}/{id}.nii'
                            nib.save(img, pred_path)
                        else:
                            pred_path = 'None'

                        # add to list
                        id_pred.append({'PatientID':id.cpu(), 'TP':tp, 'TN':tn, 'FP':fp, 'FN':fn, 'pred_fn':pred_path})

                    if self.print_progress:
                        print_progessbar(b, len(loader), Name='\t\tEvaluation Batch', Size=40, erase=True)

            # make DataFrame from ID_pred to compute Dice score per image and per volume
            result_df = pd.DataFrame(id_pred)

            # compute Dice & Jaccard (IoU) per Slice + save DF if required
            result_df['Dice'] = (2*result_df.TP + 1e-9) / (2*result_df.TP + result_df.FP + result_df.FN + 1e-9)
            result_df['IoU'] = (result_df.TP + 1e-9) / (result_df.TP + result_df.FP + result_df.FN + 1e-9)
            if save_path:
                result_df.to_csv(f'{save_path}/prediction_scores.csv')

            # aggregate by patient TP/TN/FP/FN (sum) + recompute 3D Dice & Jaccard then take mean and return values
            result_3D_df = result_df[['PatientID', 'TP', 'TN', 'FP', 'FN']].groupby('PatientID').sum()
            result_3D_df['Dice'] = (2*result_3D_df.TP + 1e-9) / (2*result_3D_df.TP + result_3D_df.FP + result_3D_df.FN + 1e-9)
            result_3D_df['IoU'] = (result_3D_df.TP + 1e-9) / (result_3D_df.TP + result_3D_df.FP + result_3D_df.FN + 1e-9)
            avg_results = result_3D_df[['Dice', 'IoU']].mean(axis=0)

            self.eval_time = time.time() - start_time
            self.eval_dice = avg_results.Dice
            self.eval_IoU = avg_results.IoU

            if print_to_logger:
                logger.info(f'Evaluation time: {self.eval_time} [s].')
                logger.info(f'Evaluation Dice: {self.eval_dice:.3%}.')
                logger.info(f'Evaluation IoU: {self.eval_IoU:.3%}.')
                logger.info('Finished evaluating the UNet.')

            if return_score:
                return avg_results.Dice, avg_results.IoU
Code Example #11
    def train(self, dataset, valid_dataset=None, checkpoint_path=None):
        """
        Train the network with the given dataset(s).
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) the dataset to use for training. It must return an input image, a
            |           target binary mask, the patientID, and the slice number.
            |---- valid_dataset (torch.utils.data.Dataset) the optional validation dataset. If provided, the model is
            |           validated at each epoch. It must have the same structure as the train dataset.
            |---- checkpoint_path (str) the filename for a possible checkpoint to start the training from.
        OUTPUT
            |---- net (nn.Module) the trained network.
        """
        logger = logging.getLogger()
        # make the dataloader
        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            worker_init_fn=lambda _: np.random.seed())
        # put net to device
        self.unet = self.unet.to(self.device)
        # define optimizer
        optimizer = optim.Adam(self.unet.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay)
        # define the lr scheduler
        scheduler = self.lr_scheduler(optimizer, **self.lr_scheduler_kwargs)
        # define the loss function
        loss_fn = self.loss_fn(**self.loss_fn_kwargs)
        # Load checkpoint if present
        try:
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            n_epoch_finished = checkpoint['n_epoch_finished']
            self.unet.load_state_dict(checkpoint['net_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            scheduler.load_state_dict(checkpoint['lr_state'])
            epoch_loss_list = checkpoint['loss_evolution']
            logger.info(
                f'Checkpoint loaded with {n_epoch_finished} epoch finished.')
        except FileNotFoundError:
            logger.info('No Checkpoint found. Training from beginning.')
            n_epoch_finished = 0
            epoch_loss_list = []  # Placeholder for epoch evolution

        # start training
        logger.info('Start training the U-Net 2.5D.')
        start_time = time.time()
        n_batch = len(train_loader)

        for epoch in range(n_epoch_finished, self.n_epoch):
            self.unet.train()
            epoch_loss = 0.0
            epoch_start_time = time.time()

            for b, data in enumerate(train_loader):
                # get data
                input, target, _, _ = data
                # put data tensors on device
                input = input.to(self.device).float().requires_grad_(True)
                target = target.to(self.device).float()  # targets need no gradients
                # zero the networks' gradients
                optimizer.zero_grad()
                # optimize weights with backpropagation on the batch
                pred = self.unet(input)
                loss = loss_fn(pred, target)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                # print process
                if self.print_progress:
                    print_progessbar(b,
                                     n_batch,
                                     Name='\t\tTrain Batch',
                                     Size=40,
                                     erase=True)

            # Get validation performance if required
            valid_dice = ''
            if valid_dataset:
                self.evaluate(valid_dataset,
                              print_to_logger=False,
                              save_path=None)
                valid_dice = f"| Valid Dice: {self.outputs['eval']['dice']['all']:.5f} " + \
                             f"| Valid Dice (Positive Slices): {self.outputs['eval']['dice']['positive']:.5f} "

            # log the epoch statistics
            logger.info(
                f'\t| Epoch: {epoch + 1:03}/{self.n_epoch:03} '
                f'| Train time: {timedelta(seconds=int(time.time() - epoch_start_time))} '
                f'| Train Loss: {epoch_loss / n_batch:.6f} ' + valid_dice +
                f'| lr: {scheduler.get_last_lr()[0]:.7f} |')
            # Store epoch loss and epoch dice (dice is only available when a validation set is given)
            epoch_loss_list.append([
                epoch + 1, epoch_loss / n_batch,
                self.outputs['eval']['dice']['all'] if valid_dataset else None,
                self.outputs['eval']['dice']['positive'] if valid_dataset else None
            ])
            # update scheduler
            scheduler.step()
            # Save Checkpoint every 10 epochs
            if (epoch + 1) % 10 == 0 and checkpoint_path:
                checkpoint = {
                    'n_epoch_finished': epoch + 1,
                    'net_state': self.unet.state_dict(),
                    'optimizer_state': optimizer.state_dict(),
                    'lr_state': scheduler.state_dict(),
                    'loss_evolution': epoch_loss_list
                }
                torch.save(checkpoint, checkpoint_path)
                logger.info('\tCheckpoint saved.')

        # End training
        self.outputs['train']['time'] = time.time() - start_time
        self.outputs['train']['evolution'] = epoch_loss_list
        logger.info(
            f"Finished training U-Net 2D in {timedelta(seconds=int(self.outputs['train']['time']))}"
        )
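
Note: `self.loss_fn`, `self.lr_scheduler` and their kwargs are injected at construction time and instantiated inside `train`. A hypothetical configuration showing the intended pattern (the concrete classes here are assumptions, not the original choices):

import torch.nn as nn
import torch.optim as optim

# Hypothetical constructor arguments: classes are stored and instantiated in train().
trainer_kwargs = dict(
    loss_fn=nn.BCEWithLogitsLoss,                      # e.g. a Dice loss class in the original
    loss_fn_kwargs={},
    lr_scheduler=optim.lr_scheduler.ExponentialLR,
    lr_scheduler_kwargs={'gamma': 0.95},
)
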
Code Example #12
    def segement_volume(self,
                        vol,
                        save_fn=None,
                        window=None,
                        input_size=(256, 256),
                        return_pred=False):
        """
        Segment each slice of the passed NIfTI volume and save the result as a NIfTI volume.
        ----------
        INPUT
            |---- vol (nibabel.nifti1.Nifti1Pair) the nibabel volume with metadata to segment.
            |---- save_fn (str) where to save the segmentation.
            |---- window (tuple (center, width)) the windowing to apply to the ct-scan.
            |---- input_size (tuple (h, w)) the input size for the network.
            |---- return_pred (bool) whether to return the volume of prediction.
        OUTPUT
            |---- (mask_vol) (nibabel.nifti1.Nifti1Pair) the prediction volume.
        """
        pred_list = []
        vol_data = np.rot90(vol.get_fdata(),
                            axes=(0, 1))  # 90° counterclockwise rotation
        if window:
            vol_data = window_ct(vol_data,
                                 win_center=window[0],
                                 win_width=window[1],
                                 out_range=(0, 1))
        transform = tf.Compose(tf.Resize(H=input_size[0], W=input_size[1]),
                               tf.ToTorchTensor())
        self.unet.eval()
        self.unet.to(self.device)
        with torch.no_grad():
            for s in range(0, vol_data.shape[2], self.batch_size):
                # get the slices at the right size and as a tensor
                input = transform(vol_data[:, :, s:s + self.batch_size]).to(
                    self.device).float().permute(3, 0, 1, 2)
                # predict
                pred = self.unet(input)
                pred = torch.where(pred >= 0.5,
                                   torch.ones_like(pred, device=self.device),
                                   torch.zeros_like(pred, device=self.device))
                # store pred (B x H x W)
                pred_list.append(
                    pred.squeeze(dim=1).permute(1, 2, 0).cpu().numpy().astype(
                        np.uint8) * 255)
                if self.print_progress:
                    print_progessbar(s + pred.shape[0] - 1,
                                     Max=vol_data.shape[2],
                                     Name='Slice',
                                     Size=20,
                                     erase=True)

        # make the prediction volume
        vol_pred = np.concatenate(pred_list, axis=2)
        # resize back to the original in-plane size and rotate 90° clockwise
        vol_pred = np.rot90(skimage.transform.resize(
            vol_pred, (vol.header['dim'][1], vol.header['dim'][2]), order=0),
                            axes=(1, 0))
        # make NIfTI and save it
        vol_pred_nii = nib.Nifti1Pair(vol_pred.astype(np.uint8), vol.affine)
        if save_fn:
            nib.save(vol_pred_nii, save_fn)
        # return Nifti prediction
        if return_pred:
            return vol_pred_nii
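
A hypothetical usage of `segement_volume` (the paths and the `segmenter` instance are placeholders, not from the original code):

import nibabel as nib

vol = nib.load('data/ct_scans/001.nii')  # placeholder path
pred_nii = segmenter.segement_volume(vol, save_fn='out/001_mask.nii',
                                     window=(40, 120), return_pred=True)
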
Code Example #13
    def evaluate(self, dataset, print_to_logger=True, save_path=None):
        """
        Evaluate the network with the given dataset. The evaluation score is given for the 3D prediction.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) the dataset to use for evaluation. It must return an input image, a
            |           target binary mask, the volume ID and the slice number.
            |---- print_to_logger (bool) whether to print information to the logger.
            |---- save_path (str) the folder path where to save the segmentation maps (as bitmap for each slice) and the
            |           performance dataframe. If not provided (i.e. None) nothing is saved.
        OUTPUT
            |---- None
        """
        if print_to_logger:
            logger = logging.getLogger()

        # make dataloader
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            worker_init_fn=lambda _: np.random.seed())
        # put net on device
        self.unet = self.unet.to(self.device)

        # Evaluation
        if print_to_logger:
            logger.info('Start evaluating the U-Net 2.5D.')
        start_time = time.time()
        id_pred = {
            'volID': [],
            'slice': [],
            'label': [],
            'TP': [],
            'TN': [],
            'FP': [],
            'FN': [],
            'pred_fn': []
        }  # Placeholder for per-slice prediction scores
        self.unet.eval()
        with torch.no_grad():
            for b, data in enumerate(loader):
                # get data on device
                input, target, ID, slice_nbr = data
                input = input.to(self.device).float()
                target = target.to(self.device).float()
                # make prediction
                pred = self.unet(input)
                # threshold to binarize
                pred = torch.where(pred >= 0.5,
                                   torch.ones_like(pred, device=self.device),
                                   torch.zeros_like(pred, device=self.device))
                # get confusion matrix for each slice of each volume
                tn, fp, fn, tp = batch_binary_confusion_matrix(pred, target)
                # save prediction if required
                if save_path:
                    pred_path = []
                    for id, s_nbr, pred_samp in zip(ID, slice_nbr, pred):
                        # save slice prediction if required
                        os.makedirs(os.path.join(save_path, f'{id}/'),
                                    exist_ok=True)
                        io.imsave(
                            os.path.join(save_path, f'{id}/{s_nbr}.bmp'),
                            pred_samp[0, :, :].cpu().numpy().astype(np.uint8) *
                            255,
                            check_contrast=False
                        )  # image are binary --> put in uint8 and scale to 255
                        pred_path.append(
                            f'{id}/{s_nbr}.bmp'
                        )  # file name with volume and slice number
                else:
                    pred_path = ['-'] * ID.shape[0]
                # add data to placeholder
                id_pred['volID'] += ID.cpu().tolist()
                id_pred['slice'] += slice_nbr.cpu().tolist()
                id_pred['label'] += target.view(
                    target.shape[0], -1).max(dim=1)[0].cpu().tolist(
                    )  # 1 if mask has some positive, else 0
                id_pred['TP'] += tp.cpu().tolist()
                id_pred['TN'] += tn.cpu().tolist()
                id_pred['FP'] += fp.cpu().tolist()
                id_pred['FN'] += fn.cpu().tolist()
                id_pred['pred_fn'] += pred_path

                if self.print_progress:
                    print_progessbar(b,
                                     len(loader),
                                     Name='\t\tEvaluation Batch',
                                     Size=40,
                                     erase=True)

        # make DataFrame from ID_pred to compute Dice score per image and per volume
        result_df = pd.DataFrame(id_pred)

        # compute Dice per Slice + save DF if required
        result_df['Dice'] = (2 * result_df.TP + 1) / (
            2 * result_df.TP + result_df.FP + result_df.FN + 1)
        if save_path:
            result_df.to_csv(
                os.path.join(save_path, 'slice_prediction_scores.csv'))

        # aggregate by patient TP/TN/FP/FN (sum) + recompute 3D Dice & Jaccard then take mean and return values
        result_3D_df = result_df[['volID', 'label', 'TP', 'TN', 'FP',
                                  'FN']].groupby('volID').agg({
                                      'label': 'max',
                                      'TP': 'sum',
                                      'TN': 'sum',
                                      'FP': 'sum',
                                      'FN': 'sum'
                                  })
        result_3D_df['Dice'] = (2 * result_3D_df.TP + 1) / (
            2 * result_3D_df.TP + result_3D_df.FP + result_3D_df.FN + 1)
        if save_path:
            result_3D_df.to_csv(
                os.path.join(save_path, 'volume_prediction_scores.csv'))

        # take average over positive volumes only and all together
        avg_results_ICH = result_3D_df.loc[result_3D_df.label == 1,
                                           'Dice'].mean(axis=0)
        avg_results = result_3D_df.Dice.mean(axis=0)
        self.outputs['eval']['time'] = time.time() - start_time
        self.outputs['eval']['dice'] = {
            'all': avg_results,
            'positive': avg_results_ICH
        }

        if print_to_logger:
            logger.info(
                f"Evaluation time: {timedelta(seconds=int(self.outputs['eval']['time']))}"
            )
            logger.info(
                f"Evaluation Dice: {self.outputs['eval']['dice']['all']:.5f}.")
            logger.info(
                f"Evaluation Dice (Positive only): {self.outputs['eval']['dice']['positive']:.5f}."
            )
            logger.info("Finished evaluating the U-Net 2.5D.")
Code Example #14
    def train(self, dataset, checkpoint_path=None):
        """
        Train the network on the given dataset for the contrastive task (global or local).
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) the dataset to use for training. It should output two augmented
            |           version of the same image as well as the image index.
            |---- checkpoint_path (str) the filename for a possible checkpoint to start the training from. If None, the
            |           network's weights are not saved regularly during training.
        OUTPUT
            |---- None.
        """
        logger = logging.getLogger()
        # initialize dataloader
        loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers,
                                             drop_last=True, worker_init_fn=lambda _: np.random.seed())
        # initialize loss function
        loss_fn = self.loss_fn(**self.loss_fn_kwargs)
        # initialize optimizer
        optimizer = optim.Adam(self.net.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        # initialize scheduler
        scheduler = self.lr_scheduler(optimizer, **self.lr_scheduler_kwargs)
        # load checkpoint if any
        try:
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            n_epoch_finished = checkpoint['n_epoch_finished']
            self.net.load_state_dict(checkpoint['net_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            scheduler.load_state_dict(checkpoint['lr_state'])
            epoch_loss_list = checkpoint['loss_evolution']
            logger.info(f'Checkpoint loaded with {n_epoch_finished} epoch finished.')
        except FileNotFoundError:
            logger.info('No Checkpoint found. Training from beginning.')
            n_epoch_finished = 0
            epoch_loss_list = []
        # Train Loop
        logger.info(f"Start trianing the network on the {'global' if self.is_global else 'local'} contrastive task.")
        start_time = time.time()
        n_batch = len(loader)
        for epoch in range(n_epoch_finished, self.n_epoch):
            self.net.train()
            epoch_loss = 0.0
            epoch_start_time = time.time()

            for b, data in enumerate(loader):
                # get data
                im1, im2, _ = data
                im1 = im1.to(self.device).float().requires_grad_(True)
                im2 = im2.to(self.device).float().requires_grad_(True)
                # zero the gradients
                optimizer.zero_grad()
                # get image representations
                z1 = self.net(im1)
                z2 = self.net(im2)
                # normalize representations
                if self.is_global:
                    z1 = nn.functional.normalize(z1, dim=1)
                    z2 = nn.functional.normalize(z2, dim=1)
                # compute loss and backpropagate
                loss = loss_fn(z1, z2)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                # print progress
                if self.print_progress:
                    print_progessbar(b, n_batch, Name='\t\tTrain Batch', Size=40, erase=True)

            # print epoch statistics
            logger.info(f'\t| Epoch : {epoch + 1:03}/{self.n_epoch:03} '
                        f'| Train time: {timedelta(seconds=int(time.time() - epoch_start_time))} '
                        f'| Train Loss: {epoch_loss / n_batch:.6f} '
                        f'| lr: {scheduler.get_last_lr()[0]:.7f} |')
            # store epoch loss
            epoch_loss_list.append([epoch+1, epoch_loss/n_batch])
            # update lr
            scheduler.step()
            # save checkpoint if needed
            if (epoch+1)%1 == 0 and checkpoint_path:
                checkpoint = {'n_epoch_finished': epoch+1,
                              'net_state': self.net.state_dict(),
                              'optimizer_state': optimizer.state_dict(),
                              'lr_state': scheduler.state_dict(),
                              'loss_evolution': epoch_loss_list}
                torch.save(checkpoint, checkpoint_path)
                logger.info('\tCheckpoint saved.')

        # End training
        self.outputs['train']['time'] = time.time() - start_time
        self.outputs['train']['evolution'] = epoch_loss_list
        logger.info(f"Finished training on the network on the {'global' if self.is_global else 'local'} contrastive task in {timedelta(seconds=int(self.outputs['train']['time']))}")
Code Example #15
def main(config_path):
    """  """
    # load config
    cfg = AttrDict.from_json_path(config_path)

    # make outputs dir
    out_path = os.path.join(cfg.path.output, cfg.exp_name)
    os.makedirs(out_path, exist_ok=True)

    # initialize seed
    if cfg.seed != -1:
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
        torch.cuda.manual_seed(cfg.seed)
        torch.cuda.manual_seed_all(cfg.seed)
        torch.backends.cudnn.deterministic = True

    # initialize logger
    logger = initialize_logger(os.path.join(out_path, 'log.txt'))
    logger.info(f"Experiment : {cfg.exp_name}")

    # set device
    if cfg.device:
        cfg.device = torch.device(cfg.device)
    else:
        cfg.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    logger.info(f"Device set to {cfg.device}.")

    #-------------------------------------------
    #       Make Dataset
    #-------------------------------------------

    data_info_df = pd.read_csv(os.path.join(cfg.path.data, 'ct_info.csv'), index_col=0)
    dataset = public_SegICH_Dataset2D(data_info_df, cfg.path.data,
                    augmentation_transform=[getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in cfg.data.augmentation.items()],
                    output_size=cfg.data.size, window=(cfg.data.win_center, cfg.data.win_width))

    #-------------------------------------------
    #       Load FCDD Model
    #-------------------------------------------

    cfg_fcdd = AttrDict.from_json_path(cfg.fcdd_cfg_path)
    fcdd_net = FCDD_CNN_VGG(in_shape=(cfg_fcdd.net.in_channels, 256, 256), bias=cfg_fcdd.net.bias)
    loaded_state_dict = torch.load(cfg.fcdd_model_path, map_location=cfg.device)
    fcdd_net.load_state_dict(loaded_state_dict)
    fcdd_net = fcdd_net.to(cfg.device).eval()
    logger.info(f"FCDD model succesfully loaded from {cfg.fcdd_model_path}")

    # make FCDD object
    fcdd = FCDD(fcdd_net, batch_size=cfg.batch_size, num_workers=cfg.num_workers,
                device=cfg.device, print_progress=cfg.print_progress)

    #-------------------------------------------
    #       Load Classifier Model
    #-------------------------------------------

    # Load Classifier
    if cfg.classifier_model_path is not None:
        cfg_classifier = AttrDict.from_json_path(os.path.join(cfg.classifier_model_path, 'config.json'))
        classifier = getattr(rn, cfg_classifier.net.resnet)(num_classes=cfg_classifier.net.num_classes, input_channels=cfg_classifier.net.input_channels)
        classifier_state_dict = torch.load(os.path.join(cfg.classifier_model_path, 'resnet_state_dict.pt'), map_location=cfg.device)
        classifier.load_state_dict(classifier_state_dict)
        classifier = classifier.to(cfg.device)
        classifier.eval()
        logger.info(f"ResNet classifier model succesfully loaded from {os.path.join(cfg.classifier_model_path, 'resnet_state_dict.pt')}")

    #-------------------------------------------
    #       Generate Heat-Map for each slice
    #-------------------------------------------

    with torch.no_grad():
        # make loader
        loader = torch.utils.data.DataLoader(dataset, batch_size=cfg.batch_size, num_workers=cfg.num_workers,
                                             shuffle=False, worker_init_fn=lambda _: np.random.seed())
        fcdd_net.eval()

        min_val, max_val = fcdd.get_min_max(loader, **cfg.heatmap_param)

        # computing and saving heatmaps
        out = dict(id=[], slice=[], label=[], ad_map_fn=[], ad_mask_fn=[],
                   TP=[], TN=[], FP=[], FN=[], AUC=[], classifier_pred=[])
        for b, data in enumerate(loader):
            im, mask, id, slice = data
            im = im.to(cfg.device).float()
            mask = mask.to(cfg.device).float()

            # get heatmap
            heatmap = fcdd.generate_heatmap(im, reception=cfg.heatmap_param.reception, std=cfg.heatmap_param.std,
                                            cpu=cfg.heatmap_param.cpu)
            # scaling
            heatmap = ((heatmap - min_val) / (max_val - min_val)).clamp(0,1)

            # Threshold
            ad_mask = torch.where(heatmap >= cfg.heatmap_threshold, torch.ones_like(heatmap, device=heatmap.device),
                                                                    torch.zeros_like(heatmap, device=heatmap.device))

            # Compute CM
            tn, fp, fn, tp  = batch_binary_confusion_matrix(ad_mask, mask.to(heatmap.device))

            # Save heatmaps/mask
            map_fn, mask_fn = [], []
            for i in range(im.shape[0]):
                # Save AD Map
                ad_map_fn = f"{id[i]}/{slice[i]}_map_anomalies.png"
                save_path = os.path.join(out_path, 'pred/', ad_map_fn)
                if not os.path.isdir(os.path.dirname(save_path)):
                    os.makedirs(os.path.dirname(save_path), exist_ok=True)

                ad_map = heatmap[i].squeeze().cpu().numpy()
                io.imsave(save_path, img_as_ubyte(ad_map), check_contrast=False)
                # save ad_mask
                ad_mask_fn = f"{id[i]}/{slice[i]}_anomalies.bmp"
                save_path = os.path.join(out_path, 'pred/', ad_mask_fn)
                io.imsave(save_path, img_as_ubyte(ad_mask[i].squeeze().cpu().numpy()), check_contrast=False)

                map_fn.append(ad_map_fn)
                mask_fn.append(ad_mask_fn)

            # apply classifier ResNet-18
            if cfg.classifier_model_path is not None:
                pred_score = nn.functional.softmax(classifier(im), dim=1)[:,1] # take columns of softmax of positive class as score
                clss_pred = torch.where(pred_score >= cfg.classification_threshold, torch.ones_like(pred_score, device=pred_score.device),
                                                                                    torch.zeros_like(pred_score, device=pred_score.device))
            else:
                clss_pred = [None]*im.shape[0]

            # Save Values
            out['id'] += id.cpu().tolist()
            out['slice'] += slice.cpu().tolist()
            out['label'] += mask.reshape(mask.shape[0], -1).max(dim=1)[0].cpu().tolist()
            out['ad_map_fn'] += map_fn
            out['ad_mask_fn'] += mask_fn
            out['TN'] += tn.cpu().tolist()
            out['FP'] += fp.cpu().tolist()
            out['FN'] += fn.cpu().tolist()
            out['TP'] += tp.cpu().tolist()
            out['AUC'] += [roc_auc_score(mask[i].cpu().numpy().ravel(), heatmap[i].cpu().numpy().ravel()) if torch.any(mask[i]>0) else 'None' for i in range(im.shape[0])]
            out['classifier_pred'] += clss_pred.cpu().tolist()

            if cfg.print_progress:
                print_progessbar(b, len(loader), Name='Heatmap Generation Batch', Size=100, erase=True)

    # make df and save as csv
    slice_df = pd.DataFrame(out)
    volume_df = slice_df[['id', 'label', 'TP', 'TN', 'FP', 'FN']].groupby('id').agg({'label':'max', 'TP':'sum', 'TN':'sum', 'FP':'sum', 'FN':'sum'})

    slice_df['Dice'] = (2*slice_df.TP + 1) / (2*slice_df.TP + slice_df.FP + slice_df.FN + 1)
    volume_df['Dice'] = (2*volume_df.TP + 1) / (2*volume_df.TP + volume_df.FP + volume_df.FN + 1)
    logger.info(f"Mean slice dice : {slice_df.Dice.mean(axis=0):.3f}")
    logger.info(f"Mean volume dice : {volume_df.Dice.mean(axis=0):.3f}")
    logger.info(f"Mean posiitve slice AUC {slice_df[slice_df.label == 1].AUC.mean(axis=0):.3f}")

    # Save Scores and Config
    slice_df.to_csv(os.path.join(out_path, 'slice_predictions.csv'))
    logger.info(f"Slice prediction csv saved at {os.path.join(out_path, 'slice_predictions.csv')}")
    volume_df.to_csv(os.path.join(out_path, 'volume_predictions.csv'))
    logger.info(f"Volume prediction csv saved at {os.path.join(out_path, 'volume_predictions.csv')}")
    cfg.device = str(cfg.device)
    with open(os.path.join(out_path, 'config.json'), 'w') as f:
        json.dump(cfg, f)
    logger.info(f"Config file saved at {os.path.join(out_path, 'config.json')}")
Code Example #16
def main(video_path, output_path):
    """
    Read the equation showed on the video at `video_path` by tracking the robot.
    The video with the written equation and the robot trajectory is saved at
    `output_path`.
    """
    # load video from video_path
    video = imageio.get_reader(video_path)
    N_frames = int(video.get_meta_data()['duration'] *
                   video.get_meta_data()['fps'])

    # Read 1st frame and analyse the environment (classify operator and digits)
    frame1 = video.get_data(0)

    detector = Detector(frame1, '../models/Digit_model.pickle',
                        '../models/Operators_model.pickle',
                        '../models/KMeans_centers.json')
    eq_element_list = detector.analyse_frame(verbose=True)

    # initialize equation and output-video
    equation = ''
    output_frames = []  # list of frame

    # initialize tracker
    tracker = Tracker()

    # iterate over frames
    print('>>> Equation reading with arrow tracking.')
    is_free = True  # whether the arrow is free to register a new element (i.e. there has been no overlap since the last one)
    unsolved = True  # whether the equation has not been solved yet
    for i, frame in enumerate(video):
        # get position <- track robot position
        tracker.track(frame)
        trajectory = tracker.position_list
        arrow_bbox = tracker.bbox
        # check if bbox overlap with any digit/operator
        overlap_elem_list = [
            elem for elem in eq_element_list
            if elem.has_overlap(arrow_bbox, frac=1.0)
        ]
        # append character to equation string
        if len(overlap_elem_list) >= 1 and is_free:
            equation += overlap_elem_list[0].value
            is_free = False

        # reset the is_free if no element overlapped
        if len(overlap_elem_list) == 0:
            is_free = True

        # solve expression if '=' is detected
        if len(equation) > 0:
            if equation[-1] == '=' and unsolved:
                # evaluate equation
                results = numexpr.evaluate(equation[:-1]).item()
                # add results to equation
                equation += str(results)
                unsolved = False

        # draw track and equation on output frame
        output_frames.append(
            draw_output_frame(frame,
                              trajectory,
                              equation,
                              eq_elem_list=eq_element_list))

        # print progress bar
        print_progessbar(i, N_frames, Name='Frame', Size=40, erase=False)

    # check if equation has been solved.
    if unsolved:
        print('>>> WARNING : The equation could not be solved!')
    else:
        print('>>> Successfully read and solved the equation.')

    # save output-video at output_path
    save_video(output_path,
               output_frames,
               fps=video.get_meta_data()['fps'] * 2)
    print(f'>>> Output video saved at {output_path}.')
Code Example #17
    def train(self, dataset, checkpoint_path=None, valid_dataset=None):
        """

        """
        logger = logging.getLogger()
        # make dataloader
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            worker_init_fn=lambda _: np.random.seed())
        # make optimizers
        optimizer = optim.Adam(self.net.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay)
        # make scheduler
        scheduler = self.lr_scheduler(optimizer, **self.lr_scheduler_kwargs)
        # make loss functions
        loss_fn = HSCLoss(reduction='mean')
        # Load checkpoint if present
        try:
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            n_epoch_finished = checkpoint['n_epoch_finished']
            self.net.load_state_dict(checkpoint['net_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            scheduler.load_state_dict(checkpoint['lr_state'])
            epoch_loss_list = checkpoint['loss_evolution']
            logger.info(
                f'Checkpoint loaded with {n_epoch_finished} epoch finished.')
        except FileNotFoundError:
            logger.info('No Checkpoint found. Training from beginning.')
            n_epoch_finished = 0
            epoch_loss_list = []  # Placeholder for epoch evolution

        logger.info('Start training FCDD.')
        start_time = time.time()
        n_batch = len(loader)
        # train loop
        for epoch in range(n_epoch_finished, self.n_epoch):
            self.net.train()
            epoch_loss = 0.0
            epoch_start_time = time.time()

            for b, data in enumerate(loader):
                im, label, _ = data
                im = im.to(self.device).float().requires_grad_(True)
                label = label.to(self.device)

                optimizer.zero_grad()
                feat_map = self.net(im)

                loss = loss_fn(feat_map, label)

                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()

                if self.print_progress:
                    print_progessbar(b,
                                     n_batch,
                                     Name='Train Batch',
                                     Size=100,
                                     erase=True)

            # validate
            valid_loss, valid_auc = 0.0, 0.0
            if valid_dataset:
                valid_loss, valid_auc = self.validate(valid_dataset)

            # print epoch summary
            logger.info(
                f"| Epoch {epoch+1:03}/{self.n_epoch:03} "
                f"| Time {timedelta(seconds=int(time.time() - epoch_start_time))} "
                f"| Loss {epoch_loss/n_batch:.5f} "
                f"| Valid Loss {valid_loss:.5f} | Valid AUC {valid_auc:.2%} "
                f"| lr {scheduler.get_last_lr()[0]:.6f} |")

            epoch_loss_list.append(
                [epoch + 1, epoch_loss / n_batch, valid_loss, valid_auc])

            # update lr
            scheduler.step()

            # save checkpoint
            if (epoch + 1) % self.checkpoint_freq == 0 and checkpoint_path is not None:
                checkpoint = {
                    'n_epoch_finished': epoch + 1,
                    'net_state': self.net.state_dict(),
                    'optimizer_state': optimizer.state_dict(),
                    'lr_state': scheduler.state_dict(),
                    'loss_evolution': epoch_loss_list
                }
                torch.save(checkpoint, checkpoint_path)
                logger.info('\tCheckpoint saved.')

        self.outputs['train']['time'] = time.time() - start_time
        self.outputs['train']['evolution'] = {
            'col_name': ['Epoch', 'Train_Loss', 'Valid_Loss', 'Valid_AUC'],
            'data': epoch_loss_list
        }
        logger.info(
            f"Finished training FCDD in {timedelta(seconds=int(self.outputs['train']['time']))}"
        )
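The checkpoint handling above follows a simple load-or-start-fresh pattern. A minimal sketch of that pattern in isolation (the function name and return layout are illustrative, not part of the source):

import torch

def load_or_init(checkpoint_path, net, optimizer, device='cpu'):
    # hypothetical helper mirroring the try/except resume logic used above
    try:
        ckpt = torch.load(checkpoint_path, map_location=device)
        net.load_state_dict(ckpt['net_state'])
        optimizer.load_state_dict(ckpt['optimizer_state'])
        return ckpt['n_epoch_finished'], ckpt['loss_evolution']
    except FileNotFoundError:
        return 0, []  # no checkpoint: start from epoch 0 with an empty history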
コード例 #18
0
    def validate(self, dataset, save_path=None, epoch=0):
        """
        Validate the generator's inpainting capabilities on a small sample of data.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) dataset returning a small sample of fixed pairs (image, mask, idx)
            |               on which to validate the GAN over training. It must return an image tensor of dimension
            |               [C, H, W], a mask tensor of dimension [1, H, W] and a tensor of index of dimension [1].
            |               If None, no validation is performed during training.
            |---- save_path (str) path to directory where to save the inpainting results of the validation data as .png.
            |               Each image is saved as save_path/valid_imY_epXXX.png where Y is the image index and XXX is the epoch.
            |---- epoch (int) the current epoch number.
        OUTPUT
            |---- l1_loss (float) the mean Discounted L1Loss over the validation images.
        """
        with torch.no_grad():
            # make loader
            valid_loader = torch.utils.data.DataLoader(
                dataset,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                shuffle=False,
                worker_init_fn=lambda _: np.random.seed())
            n_batch = len(valid_loader)
            l1_loss_fn = DiscountedL1(gamma=self.gammaL1,
                                      reduction='mean',
                                      device=self.device)

            self.generator.eval()
            # validate data by batch
            l1_loss = 0.0
            for b, data in enumerate(valid_loader):
                im_v, mask_v, idx = data
                im_v = im_v.to(self.device).float()
                mask_v = mask_v.to(self.device).float()
                idx = idx.cpu().numpy()

                # inpaint
                im_inpaint, coarse = self.generator(im_v, mask_v)
                # recover non-masked regions
                im_inpaint = im_v * (1 - mask_v) + im_inpaint * mask_v
                coarse = im_v * (1 - mask_v) + coarse * mask_v
                # compute L1 loss
                l1_loss += l1_loss_fn(im_inpaint, im_v, mask_v).item()
                # save results
                if save_path:
                    for i in range(im_inpaint.shape[0]):
                        arr = im_inpaint[i].permute(1, 2,
                                                    0).squeeze().cpu().numpy()
                        io.imsave(
                            os.path.join(save_path,
                                         f'valid_im{idx[i]}_ep{epoch}.png'),
                            img_as_ubyte(arr))

                        arr = coarse[i].permute(1, 2,
                                                0).squeeze().cpu().numpy()
                        io.imsave(
                            os.path.join(
                                save_path,
                                f'valid_im{idx[i]}_coarse_ep{epoch}.png'),
                            img_as_ubyte(arr))

                print_progessbar(b,
                                 n_batch,
                                 Name='Valid Batch',
                                 Size=40,
                                 erase=True)

        return l1_loss / n_batch
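The compositing step im_v * (1 - mask_v) + im_inpaint * mask_v keeps the original pixels outside the hole and the generated pixels inside it (the mask is 1 on the masked region). A minimal sketch with made-up tensors:

import torch

im = torch.rand(1, 1, 4, 4)          # original image
inpainted = torch.rand(1, 1, 4, 4)   # generator output
mask = torch.zeros(1, 1, 4, 4)
mask[..., 1:3, 1:3] = 1.0            # hole to fill
composite = im * (1 - mask) + inpainted * mask
assert torch.equal(composite[..., 0, 0], im[..., 0, 0])  # pixels outside the hole are untouched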
コード例 #19
0
def main(input_data_path, output_data_path, window):
    """
    Convert the Volumetric CT data and mask (in NIfTI format) to a dataset of 2D images in tif and masks in bitmap.
    """
    # open data info dataframe
    info_df = pd.read_csv(input_data_path + 'hemorrhage_diagnosis_raw_ct.csv')
    # convert the No_Hemorrhage flag into a Hemorrhage flag
    info_df['Hemorrhage'] = 1 - info_df.No_Hemorrhage
    info_df.drop(columns='No_Hemorrhage', inplace=True)
    # open patient info dataframe
    patient_df = pd.read_csv(input_data_path + 'Patient_demographics.csv', header=1, skipfooter=2, engine='python') \
                   .rename(columns={'Unnamed: 0':'PatientNumber', 'Unnamed: 1':'Age',
                                    'Unnamed: 2':'Gender', 'Unnamed: 8':'Fracture', 'Unnamed: 9':'Note'})
    patient_df[patient_df.columns[3:9]] = patient_df[
        patient_df.columns[3:9]].fillna(0).astype(int)
    # add column Hemorrhage (any ICH)
    patient_df['Hemorrhage'] = patient_df[patient_df.columns[3:8]].max(axis=1)

    # make patient directory
    if not os.path.exists(output_data_path): os.mkdir(output_data_path)
    if not os.path.exists(output_data_path + 'Patient_CT/'):
        os.mkdir(output_data_path + 'Patient_CT/')
    # iterate over volume to extract data
    output_info = []
    for n, id in enumerate(info_df.PatientNumber.unique()):
        # read nii volume
        ct_nii = nib.load(input_data_path + f'ct_scans/{id:03}.nii')
        mask_nii = nib.load(input_data_path + f'masks/{id:03}.nii')
        # get np.array
        ct_vol = ct_nii.get_fdata()
        mask_vol = skimage.img_as_bool(mask_nii.get_fdata())
        # rotate 90° counter clockwise for head pointing upward
        ct_vol = np.rot90(ct_vol, axes=(0, 1))
        mask_vol = np.rot90(mask_vol, axes=(0, 1))
        # window the ct volume to get better contrast of soft tissues
        if window is not None:
            ct_vol = window_ct(ct_vol,
                               win_center=window[0],
                               win_width=window[1],
                               out_range=(0, 1))

        if mask_vol.shape != ct_vol.shape:
            print(
                f'>>> Warning! The ct volume of patient {id} does not have '
                f'the same dimension as the ground truth. CT ({ct_vol.shape}) vs Mask ({mask_vol.shape})'
            )
        # make patient directory
        if not os.path.exists(output_data_path + f'Patient_CT/{id:03}/'):
            os.mkdir(output_data_path + f'Patient_CT/{id:03}/')
        # iterate over slices to save slices
        for i, slice in enumerate(range(ct_vol.shape[2])):
            ct_slice_fn = f'Patient_CT/{id:03}/{slice+1}.tif'
            # save CT slice
            skimage.io.imsave(output_data_path + ct_slice_fn,
                              ct_vol[:, :, slice],
                              check_contrast=False)
            is_low = skimage.exposure.is_low_contrast(ct_vol[:, :, slice])
            # save mask if some positive ICH
            if np.any(mask_vol[:, :, slice]):
                mask_slice_fn = f'Patient_CT/{id:03}/{slice+1}_ICH_Seg.bmp'
                skimage.io.imsave(output_data_path + mask_slice_fn,
                                  skimage.img_as_ubyte(mask_vol[:, :, slice]),
                                  check_contrast=False)
            else:
                mask_slice_fn = 'None'
            # add info to output list
            output_info.append({
                'PatientNumber': id,
                'SliceNumber': slice + 1,
                'CT_fn': ct_slice_fn,
                'mask_fn': mask_slice_fn,
                'low_contrast_CT': is_low
            })

            print_progessbar(
                i,
                ct_vol.shape[2],
                Name=f'Patient {id:03} {n+1:03}/{len(info_df.PatientNumber.unique()):03}',
                Size=20,
                erase=False)

    # Make dataframe of outputs
    output_info_df = pd.DataFrame(output_info)
    # Merge with input info
    info_df = pd.merge(info_df,
                       output_info_df,
                       how='inner',
                       on=['PatientNumber', 'SliceNumber'])
    # save df
    info_df.to_csv(output_data_path + 'ct_info.csv')
    print('>>> Data information saved at ' + output_data_path + 'ct_info.csv')
    # save patient df
    patient_df.to_csv(output_data_path + 'patient_info.csv')
    print('>>> Patient information saved at ' + output_data_path +
          'patient_info.csv')
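window_ct is the project's own helper; as a point of reference, a hypothetical stand-in implementing standard CT intensity windowing (clip to the window, rescale to out_range) could look like:

import numpy as np

def window_ct(ct_vol, win_center=40, win_width=120, out_range=(0, 1)):
    # hypothetical stand-in for the project's helper, not its actual code:
    # clip Hounsfield units to the window, then rescale to out_range
    lo, hi = win_center - win_width / 2, win_center + win_width / 2
    vol = (np.clip(ct_vol, lo, hi) - lo) / (hi - lo)
    return vol * (out_range[1] - out_range[0]) + out_range[0]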
コード例 #20
0
    def train(self, dataset, checkpoint_path=None, valid_dataset=None, valid_path=None, valid_freq=5):
        """

        """
        logger = logging.getLogger()
        # make dataloader
        loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers,
                                             shuffle=True, worker_init_fn=lambda _: np.random.seed())
        # make optimizers
        optimizer = optim.Adam(self.ae.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        # make scheduler
        scheduler = self.lr_scheduler(optimizer, **self.lr_scheduler_kwargs)
        # make loss functions
        gdl_fn = GDL(reduction='mean', device=self.device)
        mae_fn = nn.L1Loss(reduction='mean')
        mse_fn = nn.MSELoss(reduction='mean')
        # Load checkpoint if present
        try:
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            n_epoch_finished = checkpoint['n_epoch_finished']
            self.ae.load_state_dict(checkpoint['net_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            scheduler.load_state_dict(checkpoint['lr_state'])
            epoch_loss_list = checkpoint['loss_evolution']
            logger.info(f'Checkpoint loaded with {n_epoch_finished} epoch finished.')
        except FileNotFoundError:
            logger.info('No Checkpoint found. Training from beginning.')
            n_epoch_finished = 0
            epoch_loss_list = [] # Placeholder for epoch evolution

        logger.info('Start training the inpainting AE.')
        start_time = time.time()
        n_batch = len(loader)
        # train loop
        for epoch in range(n_epoch_finished, self.n_epoch):
            self.ae.train()
            epoch_loss, epoch_loss_l1, epoch_loss_l2, epoch_loss_gdl = 0.0, 0.0, 0.0, 0.0
            epoch_start_time = time.time()

            # update Lambda GDL
            if str(epoch) in self.ep_GDL:
                self.lambda_GDL = self.ep_GDL[str(epoch)]
                logger.info(f"Lambda GDL set to {self.lambda_GDL}.")

            for b, data in enumerate(loader):
                im, _ = data
                im = im.to(self.device).float().requires_grad_(True)

                optimizer.zero_grad()
                rec = self.ae(im)

                loss_l1 = mae_fn(rec, im)
                loss_l2 = mse_fn(rec, im)
                #loss_gdl = self.lambda_GDL*gdl_fn(im, rec) if epoch+1 >= self.ep_GDL else 0.0*gdl_fn(im, rec) # consider gdl loss only after some epoch
                loss_gdl = self.lambda_GDL*gdl_fn(im, rec)
                loss = loss_l1 + loss_l2 + loss_gdl

                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                epoch_loss_l1 += loss_l1.item()
                epoch_loss_l2 += loss_l2.item()
                epoch_loss_gdl += loss_gdl.item()

                if self.print_progress:
                    print_progessbar(b, n_batch, Name='Train Batch', Size=100, erase=True)

            # validate
            valid_loss = 0.0
            if valid_dataset:
                save_path = valid_path if (epoch+1)%valid_freq == 0 else None
                valid_loss = self.validate(valid_dataset, save_path=save_path, prefix=f'_ep{epoch+1}')

            # print epoch summary
            logger.info(f"| Epoch {epoch+1:03}/{self.n_epoch:03} "
                        f"| Time {timedelta(seconds=int(time.time() - epoch_start_time))} "
                        f"| Loss (L1 + L2 + GDL) {epoch_loss/n_batch:.5f} = {epoch_loss_l1/n_batch:.5f} + {epoch_loss_l2/n_batch:.5f} + {epoch_loss_gdl/n_batch:.5f} "
                        f"| Valid Loss {valid_loss:.5f} | lr {scheduler.get_last_lr()[0]:.6f} |")

            epoch_loss_list.append([epoch+1, epoch_loss/n_batch, epoch_loss_l1/n_batch, epoch_loss_l2/n_batch, epoch_loss_gdl/n_batch])

            # update lr
            scheduler.step()

            # save checkpoint
            if (epoch+1)%self.checkpoint_freq == 0 and checkpoint_path is not None:
                checkpoint = {
                    'n_epoch_finished': epoch+1,
                    'net_state': self.ae.state_dict(),
                    'optimizer_state': optimizer.state_dict(),
                    'lr_state': scheduler.state_dict(),
                    'loss_evolution': epoch_loss_list
                }
                torch.save(checkpoint, checkpoint_path)
                logger.info('\tCheckpoint saved.')

        self.outputs['train']['time'] = time.time() - start_time
        self.outputs['train']['evolution'] = {'col_name': ['Epoch', 'Loss_total', 'L1_loss', 'L2_loss', 'GDL_loss'],
                                              'data': epoch_loss_list}
        logger.info(f"Finished training inpainter AE in {timedelta(seconds=int(self.outputs['train']['time']))}")
コード例 #21
0
    def train(self, net, dataset, valid_dataset=None, checkpoint_path=None):
        """
        Train the passed network with the given dataset(s).
        ----------
        INPUT
            |---- net (nn.Module) the network architecture to train.
            |---- dataset (torch.utils.data.Dataset) the dataset to use for training. It must return an input image, a
            |           target binary mask, the patientID.
            |---- valid_dataset (torch.utils.data.Dataset) the optional validation dataset. If provided, the model is
            |           validated at each epoch. It must have the same structure as the train dataset.
            |---- checkpoint_path (str) the filename for a possible checkpoint to start the training from.
        OUTPUT
            |---- net (nn.Module) the trained network.
        """
        logger = logging.getLogger()

        # make the dataloader
        train_loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=True,
                                                   num_workers=self.num_workers)
        # put net to device
        net = net.to(self.device)

        # define optimizer
        optimizer = torch.optim.Adam(net.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        # Load checkpoint if present
        try:
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            n_epoch_finished = checkpoint['n_epoch_finished']
            net.load_state_dict(checkpoint['net_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            logger.info(f'Checkpoint loaded with {n_epoch_finished} epoch finished.')
        except FileNotFoundError:
            logger.info('No Checkpoint found. Training from beginning.')
            n_epoch_finished = 0

        # define the lr scheduler
        scheduler = self.lr_scheduler(optimizer, **self.lr_scheduler_kwargs)

        # define the loss function
        loss_fn = self.loss_fn(**self.loss_fn_kwargs)

        # start training
        logger.info('Start training the UNet.')
        start_time = time.time()
        epoch_loss_list = [] # Placeholder for epoch evolution
        n_batch = len(train_loader)

        for epoch in range(n_epoch_finished, self.n_epoch):
            net.train()
            epoch_loss = 0.0
            epoch_start_time = time.time()

            for b, data in enumerate(train_loader):
                # get data
                input, target, _ = data
                # put data tensors on device
                input = input.to(self.device).float().requires_grad_(True)
                target = target.to(self.device)
                # zero the networks' gradients
                optimizer.zero_grad()
                # optimize weights with backpropagation on the batch
                # (the loss is computed on the raw network output; argmax is
                # non-differentiable and would break backpropagation)
                pred = net(input)
                loss = loss_fn(pred, target)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()

                # print process
                if self.print_progress:
                    print_progessbar(b, n_batch, Name='\t\tTrain Batch', Size=40, erase=True)

            # Get validation performance if required
            valid_dice, dice, IoU = '', None, None
            if valid_dataset:
                dice, IoU = self.evaluate(net, valid_dataset, return_score=True, print_to_logger=False, save_path=None)
                valid_dice = f'| Valid Dice: {dice:.3%} | Valid IoU: {IoU:.3%} '

            # log the epoch statistics
            logger.info(f'\t| Epoch: {epoch + 1:03}/{self.n_epoch:03} '
                        f'| Train time: {time.time() - epoch_start_time:.3f} [s] '
                        f'| Train Loss: {epoch_loss / n_batch:.6f} ' + valid_dice +
                        f'| lr: {scheduler.get_last_lr()[0]:g} |')
            # Store epoch loss and epoch dice
            epoch_loss_list.append([epoch+1, epoch_loss/n_batch, dice])

            # update scheduler
            scheduler.step()

            # Save Checkpoint every 10 epochs
            if (epoch+1) % 10 == 0 and checkpoint_path:
                checkpoint = {'n_epoch_finished': epoch+1,
                              'net_state': net.state_dict(),
                              'optimizer_state': optimizer.state_dict()}
                torch.save(checkpoint, checkpoint_path)
                logger.info('\tCheckpoint saved.')

        # End training (outside the epoch loop, otherwise training would stop
        # and return after the first epoch)
        self.train_time = time.time() - start_time
        self.train_evolution = epoch_loss_list
        logger.info(f'Finished training UNet in {self.train_time:.3f} [s].')

        return net

    def evaluate(self, net, dataset, return_score=False, print_to_logger=True, save_path=None):
        """
        Evaluate the network with the given dataset. The evaluation score is given for the 3D prediction.
        ----------
        INPUT
            |---- net (nn.Module) the network architecture to evaluate.
            |---- dataset (torch.utils.data.Dataset) the dataset to use for evaluation. It must return an input image, a
            |           target binary mask, the patientID.
            |---- return_score (bool) whether to return the mean Dice and mean IoU scores of the 3D segmentation (for
            |           the 2D case the Dice is computed on the concatenation of predictions for a patient).
            |---- print_to_logger (bool) whether to print information to the logger.
            |---- save_path (str) the folder path where to save the segmentation maps (as a bitmap for each slice) and the
            |           performance dataframe. If not provided (i.e. None) nothing is saved.
        OUTPUT
            |---- (Dice) (float) the average Dice coefficient for the 3D segmentation.
            |---- (IoU) (float) the average IoU coefficient for the 3D segmentation.
        """
        # manage to save predictions if save path is given (dataset gives patient ID and slice number as well)
        # Need to report Dice for 3D input (even if 2D model) --> need to rearrange the 2D predictions

        if print_to_logger:
            logger = logging.getLogger()

        # make dataloader
        loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=False,
                                             num_workers=self.num_workers)
        # put net on device
        net = net.to(self.device)

        # Evaluation
        if print_to_logger:
            logger.info('Start evaluating the UNet.')
        start_time = time.time()
        id_pred = [] # Placeholder for each 2D prediction's scores

        net.eval()
        with torch.no_grad():
            for b, data in enumerate(loader):
                # get data on device
                input, target, pID = data
                input = input.to(self.device).float()
                target = target.to(self.device)
                # make prediction
                pred = net(input).argmax(dim=1)

                # get confusion matrix for each slice of each sample
                # decompose the volume into slices and process them separately
                for id, target_samp, pred_samp in zip(pID, target, pred): # iterate over batch
                    # labels fixed so the matrix is 2x2 even for slices containing a single class
                    tn, fp, fn, tp = confusion_matrix(target_samp.cpu().data.numpy().ravel(),
                                                      pred_samp.cpu().data.numpy().ravel(),
                                                      labels=[0, 1]).ravel()
                    # save slice prediction if required
                    if save_path:
                        # NIfTI has no boolean dtype, so cast the mask to uint8
                        img = nib.Nifti1Image(pred_samp.cpu().numpy().astype(np.uint8), np.eye(4))
                        pred_path = f'{save_path}/{id}.nii'
                        nib.save(img, pred_path)
                    else:
                        pred_path = 'None'

                    # add to list
                    id_pred.append({'PatientID':id.cpu(), 'TP':tp, 'TN':tn, 'FP':fp, 'FN':fn, 'pred_fn':pred_path})

                if self.print_progress:
                    print_progessbar(b, len(loader), Name='\t\tEvaluation Batch', Size=40, erase=True)

        # make DataFrame from id_pred to compute Dice score per image and per volume
        result_df = pd.DataFrame(id_pred)

        # compute Dice & Jaccard (IoU) per Slice + save DF if required
        result_df['Dice'] = (2*result_df.TP + 1e-9) / (2*result_df.TP + result_df.FP + result_df.FN + 1e-9)
        result_df['IoU'] = (result_df.TP + 1e-9) / (result_df.TP + result_df.FP + result_df.FN + 1e-9)
        if save_path:
            result_df.to_csv(f'{save_path}/prediction_scores.csv')

        # aggregate by patient TP/TN/FP/FN (sum) + recompute 3D Dice & Jaccard then take mean and return values
        result_3D_df = result_df[['PatientID', 'TP', 'TN', 'FP', 'FN']].groupby('PatientID').sum()
        result_3D_df['Dice'] = (2*result_3D_df.TP + 1e-9) / (2*result_3D_df.TP + result_3D_df.FP + result_3D_df.FN + 1e-9)
        result_3D_df['IoU'] = (result_3D_df.TP + 1e-9) / (result_3D_df.TP + result_3D_df.FP + result_3D_df.FN + 1e-9)
        avg_results = result_3D_df[['Dice', 'IoU']].mean(axis=0)

        self.eval_time = time.time() - start_time
        self.eval_dice = avg_results.Dice
        self.eval_IoU = avg_results.IoU

        if print_to_logger:
            logger.info(f'Evaluation time: {self.eval_time} [s].')
            logger.info(f'Evaluation Dice: {self.eval_dice:.3%}.')
            logger.info(f'Evaluation IoU: {self.eval_IoU:.3%}.')
            logger.info('Finished evaluating the UNet.')

        if return_score:
            return avg_results.Dice, avg_results.IoU
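The 3D scores above are obtained by summing the slice-level confusion counts per patient and recomputing Dice on the sums. A self-contained sketch of that aggregation (the counts are made up):

import pandas as pd

df = pd.DataFrame([
    {'PatientID': 1, 'TP': 10, 'FP': 2, 'FN': 3},   # slice-level counts (made-up values)
    {'PatientID': 1, 'TP': 0,  'FP': 1, 'FN': 0},
    {'PatientID': 2, 'TP': 5,  'FP': 0, 'FN': 5},
])
vol = df.groupby('PatientID').sum()                  # per-patient 3D counts
vol['Dice'] = (2 * vol.TP + 1e-9) / (2 * vol.TP + vol.FP + vol.FN + 1e-9)
print(vol.Dice.mean())                               # mean 3D Dice over patients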
コード例 #22
0
    def localize_anomalies(self,
                           dataset,
                           save_path=None,
                           reception=True,
                           std=None,
                           cpu=True,
                           q_min=0.025,
                           q_max=0.975):
        """
        Generate heat map for image in dataset and save them.
        """
        with torch.no_grad():
            # make loader
            loader = torch.utils.data.DataLoader(
                dataset,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                shuffle=False,
                worker_init_fn=lambda _: np.random.seed())
            self.net.eval()

            min_val, max_val = self.get_min_max(loader,
                                                reception=reception,
                                                std=std,
                                                cpu=cpu,
                                                q_min=q_min,
                                                q_max=q_max)

            # computing and saving heatmaps
            for b, data in enumerate(loader):
                im, label, idx = data
                im = im.to(self.device).float()
                label = label.to(self.device)

                heatmap = self.generate_heatmap(im,
                                                reception=reception,
                                                std=std,
                                                cpu=cpu)  #, qu=qu)
                # scaling
                heatmap = ((heatmap - min_val) / (max_val - min_val)).clamp(
                    0, 1)

                if save_path:
                    for i in range(im.shape[0]):
                        arr = np.concatenate([
                            im[i].squeeze().cpu().numpy(),
                            heatmap.repeat(1, im.shape[1], 1, 1)[i].squeeze().cpu().numpy()
                        ], axis=1)
                        io.imsave(os.path.join(
                            save_path,
                            f'heatmap_{idx[i]}_{label[i].item()}.png'),
                                  img_as_ubyte(arr),
                                  check_contrast=False)

                if self.print_progress:
                    print_progessbar(b,
                                     len(loader),
                                     Name='Heatmap Generation Batch',
                                     Size=100,
                                     erase=True)
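The heatmap scaling above normalises with dataset-level bounds returned by get_min_max and clamps to [0, 1]. A minimal sketch of that normalisation, using per-tensor quantile bounds that mirror the q_min/q_max defaults (the scores are stand-ins):

import torch

heatmap = torch.randn(2, 1, 8, 8)                     # stand-in anomaly scores
min_val, max_val = heatmap.quantile(0.025), heatmap.quantile(0.975)
heatmap = ((heatmap - min_val) / (max_val - min_val)).clamp(0, 1)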
    def train(self, dataset, valid_dataset=None, checkpoint_path=None):
        """
        Train the passed network on the given dataset for the Context retoration task.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) the dataset to use for training. It should output the image,
            |           the binary label, and the sample index.
            |---- valid_dataset (torch.utils.data.Dataset) the dataset to use for validation at each epoch. It should
            |           output the image, the binary label, and the sample index.
            |---- checkpoint_path (str) the filename for a possible checkpoint to start the training from. If None, the
            |           network's weights are not saved regularily during training.
        OUTPUT
            |---- None.
        """
        logger = logging.getLogger()
        # make dataloader
        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=False,
            worker_init_fn=lambda _: np.random.seed())
        # put net on device
        self.net = self.net.to(self.device)
        # define optimizer
        optimizer = optim.Adam(self.net.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay)
        # define lr scheduler
        scheduler = self.lr_scheduler(optimizer, **self.lr_scheduler_kwargs)
        # define the loss function
        loss_fn = self.loss_fn(**self.loss_fn_kwargs)
        # load checkpoint if any
        try:
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            n_epoch_finished = checkpoint['n_epoch_finished']
            self.net.load_state_dict(checkpoint['net_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            scheduler.load_state_dict(checkpoint['lr_state'])
            epoch_loss_list = checkpoint['loss_evolution']
            logger.info(
                f'Checkpoint loaded with {n_epoch_finished} epoch finished.')
        except FileNotFoundError:
            logger.info('No Checkpoint found. Training from beginning.')
            n_epoch_finished = 0
            epoch_loss_list = []

        # Start Training
        logger.info('Start training the Binary Classification task.')
        start_time = time.time()
        n_batch = len(train_loader)

        for epoch in range(n_epoch_finished, self.n_epoch):
            self.net.train()
            epoch_loss = 0.0
            epoch_start_time = time.time()

            for b, data in enumerate(train_loader):
                # get data : input image and its binary label
                input, label, _ = data
                # put data on device
                input = input.to(self.device).float().requires_grad_(True)
                label = label.to(self.device).long()
                # zero the networks' gradients
                optimizer.zero_grad()
                # optimize the weights with backpropagation on the batch
                pred = self.net(input)
                pred = nn.functional.softmax(pred, dim=1)
                loss = loss_fn(pred, label)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                # print progress
                if self.print_progress:
                    print_progessbar(b,
                                     n_batch,
                                     Name='\t\tTrain Batch',
                                     Size=50,
                                     erase=True)

            auc = None
            if valid_dataset:
                # evaluate() also returns the subset accuracy, unused here
                auc, acc, _, recall, precision, f1 = self.evaluate(
                    valid_dataset, save_tsne=False, return_scores=True)
                valid_summary = f'| AUC {auc:.3%} | Accuracy {acc:.3%} | Recall {recall:.3%} | Precision {precision:.3%} | F1 {f1:.3%} '
            else:
                valid_summary = ''

            # Print epoch statistics
            logger.info(
                f'\t| Epoch : {epoch + 1:03}/{self.n_epoch:03} '
                f'| Train time: {timedelta(seconds=int(time.time() - epoch_start_time))} '
                f'| Train Loss: {epoch_loss / n_batch:.6f} '
                f'{valid_summary}'
                f'| lr: {scheduler.get_last_lr()[0]:.7f} |')
            # Store epoch loss
            epoch_loss_list.append([epoch + 1, epoch_loss / n_batch, auc])
            # update lr
            scheduler.step()
            # Save Checkpoint every epochs
            if (epoch + 1) % 1 == 0 and checkpoint_path:
                checkpoint = {
                    'n_epoch_finished': epoch + 1,
                    'net_state': self.net.state_dict(),
                    'optimizer_state': optimizer.state_dict(),
                    'lr_state': scheduler.state_dict(),
                    'loss_evolution': epoch_loss_list
                }
                torch.save(checkpoint, checkpoint_path)
                logger.info('\tCheckpoint saved.')

        # End training
        self.outputs['train']['time'] = time.time() - start_time
        self.outputs['train']['evolution'] = {
            'col_name': ['Epoch', 'loss', 'Valid_AUC'],
            'data': epoch_loss_list
        }
        logger.info(
            f"Finished training inpainter Binary Classifier in {timedelta(seconds=int(self.outputs['train']['time']))}"
        )
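All of these trainers share the same Adam + LR-scheduler skeleton. A minimal runnable sketch of that skeleton (the network and the decay factor are placeholders):

import torch
from torch import nn, optim

net = nn.Linear(4, 2)                                          # placeholder network
optimizer = optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-6)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
for epoch in range(3):
    # ... one epoch of batches: zero_grad / forward / backward / step ...
    scheduler.step()                                           # decay lr once per epoch
    print(scheduler.get_last_lr()[0])                          # lr for the next epoch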
    def evaluate(self, dataset, save_tsne=False, return_scores=False):
        """
        Evaluate the passed network on the given dataset for the Context retoration task.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) the dataset to use for evaluation. It should output the original image,
            |           and the sample index.
            |---- save_tsne (bool) whether to compute and store in self.outputs the tsne representation of the feature map
            |           after the average pooling layer and before the MLP
            |---- return_scores (bool) whether to return the measured ROC AUC, accuracy, recall, precision and f1-score.
        OUTPUT
            |---- (auc) (float) the ROC AUC on the dataset.
            |---- (acc) (float) the accuracy on the dataset.
            |---- (recall) (float) the recall on the dataset.
            |---- (precision) (float) the precision on the dataset.
            |---- (f1) (float) the f1-score on the dataset.
        """
        logger = logging.getLogger()
        # make loader
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            worker_init_fn=lambda _: np.random.seed())
        # put net on device
        self.net = self.net.to(self.device)

        # Evaluate
        start_time = time.time()
        idx_repr = []  # placeholder for bottleneck representation
        idx_label_pred = []  # placeholder for label & prediction
        n_batch = len(loader)
        self.net.eval()
        with torch.no_grad():
            for b, data in enumerate(loader):
                # get data : load in standard way (no patch swaped image)
                input, label, idx = data
                input = input.to(self.device).float()
                label = label.to(self.device).float()
                idx = idx.to(self.device)
                if save_tsne:
                    # get representation
                    self.net.return_bottleneck = True
                    pred_score, repr = self.net(input)
                    pred_score = torch.sigmoid(pred_score)
                    pred = torch.where(
                        pred_score > 0.5,
                        torch.ones_like(pred_score, device=self.device),
                        torch.zeros_like(pred_score, device=self.device))
                    # down sample representation for reduced memory impact
                    #repr = nn.AdaptiveAvgPool2d((4,4))(repr)
                    # add ravelled representations to placeholder
                    idx_repr += list(
                        zip(idx.cpu().data.tolist(),
                            repr.view(repr.shape[0], -1).cpu().data.tolist()))
                    idx_label_pred += list(
                        zip(idx.cpu().data.tolist(),
                            label.cpu().data.tolist(),
                            pred.cpu().data.tolist(),
                            pred_score.cpu().data.tolist())
                    )  # pred score is the sigmoid activation per class
                else:
                    pred_score = self.net(input)
                    pred_score = torch.sigmoid(pred_score)  # B x N_class
                    pred = torch.where(
                        pred_score > 0.5,
                        torch.ones_like(pred_score, device=self.device),
                        torch.zeros_like(pred_score, device=self.device))
                    idx_label_pred += list(
                        zip(idx.cpu().data.tolist(),
                            label.cpu().data.tolist(),
                            pred.cpu().data.tolist(),
                            pred_score.cpu().data.tolist()))
                # print_progress
                if self.print_progress:
                    print_progessbar(b,
                                     n_batch,
                                     Name='\t\tEvaluation Batch',
                                     Size=50,
                                     erase=True)
            # reset the network attributes
            if save_tsne:
                self.net.return_bottleneck = False

        # compute tSNE for representation
        if save_tsne:
            idx, repr = zip(*idx_repr)
            repr = np.array(repr)
            logger.info('Computing the t-SNE representation.')
            repr_2D = TSNE(n_components=2).fit_transform(repr)
            self.outputs['eval']['repr'] = list(zip(idx, repr_2D.tolist()))
            logger.info('Successfully computed the t-SNE representation.')

        # Compute Accuracy
        _, label, pred, pred_score = zip(*idx_label_pred)
        label, pred, pred_score = np.array(label), np.array(pred), np.array(
            pred_score)
        auc = roc_auc_score(label, pred_score, average=self.score_average)
        acc = accuracy_score(label.ravel(), pred.ravel())
        sub_acc = accuracy_score(label, pred)
        recall = recall_score(label, pred, average=self.score_average)
        precision = precision_score(label, pred, average=self.score_average)
        f1 = f1_score(label, pred, average=self.score_average)
        self.outputs['eval']['auc'] = auc
        self.outputs['eval']['acc'] = acc
        self.outputs['eval']['subset_acc'] = sub_acc
        self.outputs['eval']['recall'] = recall
        self.outputs['eval']['precision'] = precision
        self.outputs['eval']['f1'] = f1
        self.outputs['eval']['pred'] = idx_label_pred

        # finish evaluation
        self.outputs['eval']['time'] = time.time() - start_time

        if return_scores:
            return auc, acc, sub_acc, recall, precision, f1
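The hard labels above come from thresholding the sigmoid scores at 0.5; torch.where is used so the output stays on the same device as the scores. A minimal sketch with stand-in logits:

import torch

pred_score = torch.sigmoid(torch.randn(4, 3))       # per-class probabilities
pred = torch.where(pred_score > 0.5,
                   torch.ones_like(pred_score),
                   torch.zeros_like(pred_score))
# equivalent shorthand: pred = (pred_score > 0.5).float()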
コード例 #25
0
    def train(self, dataset, checkpoint_path=None):
        """
        Train the passed network on the given dataset for the Context restoration task.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) the dataset to use for training. It should output the original image,
            |           the corrupted image (Patch swapped) and the sample index.
            |---- checkpoint_path (str) the filename for a possible checkpoint to start the training from. If None, the
            |           network's weights are not saved regularly during training.
        OUTPUT
            |---- None.
        """
        logger = logging.getLogger()
        # make dataloader
        train_loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=True,
                                                   num_workers=self.num_workers, worker_init_fn=lambda _: np.random.seed())
        # put net on device
        self.net = self.net.to(self.device)
        # define optimizer
        optimizer = optim.Adam(self.net.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        # define lr scheduler
        scheduler = self.lr_scheduler(optimizer, **self.lr_scheduler_kwargs)
        # define the loss function
        loss_fn = self.loss_fn(**self.loss_fn_kwargs)
        # load checkpoint if any
        try:
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            n_epoch_finished = checkpoint['n_epoch_finished']
            self.net.load_state_dict(checkpoint['net_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            scheduler.load_state_dict(checkpoint['lr_state'])
            epoch_loss_list = checkpoint['loss_evolution']
            logger.info(f'Checkpoint loaded with {n_epoch_finished} epoch finished.')
        except FileNotFoundError:
            logger.info('No Checkpoint found. Training from beginning.')
            n_epoch_finished = 0
            epoch_loss_list = []

        # Start Training
        logger.info('Start training the Context Restoration task.')
        start_time = time.time()
        n_batch = len(train_loader)

        for epoch in range(n_epoch_finished, self.n_epoch):
            self.net.train()
            epoch_loss = 0.0
            epoch_start_time = time.time()

            for b, data in enumerate(train_loader):
                # get data : target is original image, input is the corrupted image
                target, input, _ = data
                # put data on device
                target = target.to(self.device).float()  # the reconstruction target needs no gradient
                input = input.to(self.device).float().requires_grad_(True)
                # zero the networks' gradients
                optimizer.zero_grad()
                # optimize the weights with backpropagation on the batch
                rec = self.net(input)
                loss = loss_fn(rec, target)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                # print progress
                if self.print_progress:
                    print_progessbar(b, n_batch, Name='\t\tTrain Batch', Size=40, erase=True)

            # Print epoch statistics
            logger.info(f'\t| Epoch : {epoch + 1:03}/{self.n_epoch:03} '
                        f'| Train time: {timedelta(seconds=int(time.time() - epoch_start_time))} '
                        f'| Train Loss: {epoch_loss / n_batch:.6f} '
                        f'| lr: {scheduler.get_last_lr()[0]:.7f} |')
            # Store epoch loss
            epoch_loss_list.append([epoch+1, epoch_loss/n_batch])
            # update lr
            scheduler.step()
            # Save Checkpoint every epochs
            if (epoch+1)%1 == 0 and checkpoint_path:
                checkpoint = {'n_epoch_finished': epoch+1,
                              'net_state': self.net.state_dict(),
                              'optimizer_state': optimizer.state_dict(),
                              'lr_state': scheduler.state_dict(),
                              'loss_evolution': epoch_loss_list}
                torch.save(checkpoint, checkpoint_path)
                logger.info('\tCheckpoint saved.')

        # End training
        self.outputs['train']['time'] = time.time() - start_time
        self.outputs['train']['evolution'] = epoch_loss_list
        logger.info(f"Finished training on the context restoration task in {timedelta(seconds=int(self.outputs['train']['time']))}")