def validate(self, dataset, save_path=None, prefix=''):
    """
    Compute the average reconstruction loss of the autoencoder on a dataset.

    Each batch loss is L1 + MSE + lambda_GDL * GDL (all with mean reduction).
    Optionally saves, for every sample, the input and its reconstruction side
    by side as a PNG image.
    ----------
    INPUT
        |---- dataset (torch.utils.data.Dataset) the dataset to validate on. It should output the image and the
        |           sample index.
        |---- save_path (str) folder where the side-by-side PNGs are written. Nothing is saved if falsy.
        |---- prefix (str) text appended to each saved filename, after the sample index.
    OUTPUT
        |---- (float) the sum of batch losses divided by the number of batches.
    """
    with torch.no_grad():
        loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size,
                                             num_workers=self.num_workers, shuffle=False,
                                             worker_init_fn=lambda _: np.random.seed())
        n_batch = len(loader)
        l1_loss = nn.L1Loss(reduction='mean')
        l2_loss = nn.MSELoss(reduction='mean')
        gdl_loss = GDL(reduction='mean', device=self.device)
        self.ae.eval()

        total_loss = 0.0
        for batch_idx, (im, idx) in enumerate(loader):
            im = im.to(self.device).float()
            idx = idx.cpu().numpy()
            im_rec = self.ae(im)
            # combined reconstruction loss: L1 + L2 + weighted GDL term
            l1 = l1_loss(im_rec, im).item()
            l2 = l2_loss(im_rec, im).item()
            gdl = gdl_loss(im, im_rec).item()
            total_loss += l1 + l2 + self.lambda_GDL * gdl

            if save_path:
                for i in range(im.shape[0]):
                    original = im[i].permute(1, 2, 0).squeeze().cpu().numpy()
                    reconstructed = im_rec[i].permute(1, 2, 0).squeeze().cpu().numpy()
                    # original on the left, reconstruction on the right
                    side_by_side = np.concatenate([original, reconstructed], axis=1)
                    io.imsave(os.path.join(save_path, f'valid_im{idx[i]}{prefix}.png'),
                              img_as_ubyte(side_by_side), check_contrast=False)

            print_progessbar(batch_idx, n_batch, Name='Valid Batch', Size=100, erase=True)

        return total_loss / n_batch
def main(input_data_path, output_data_path, window):
    """
    Convert volumetric CT scans and masks (NIfTI) into per-slice 2D files.

    For every volume id listed in '<input_data_path>/info.csv', the CT slices are
    saved as tif images and the non-empty mask slices as bitmaps. Summary csv files
    are written alongside: 'slice_info.csv' (one row per slice) and 'volume_info.csv'.
    ----------
    INPUT
        |---- input_data_path (str) folder containing 'info.csv', 'ct_scans/<id>.nii' and 'masks/<id>.nii.gz'.
        |---- output_data_path (str) destination folder (created, with its parents, if missing).
        |---- window (tuple (center, width) or None) CT windowing applied before saving. None means no windowing.
    OUTPUT
        |---- None
    """
    info_df = pd.read_csv(os.path.join(input_data_path, 'info.csv'), index_col=0)
    # makedirs (rather than mkdir) also creates missing parent folders
    os.makedirs(output_data_path, exist_ok=True)

    n_volume = len(info_df.id)
    output_info = []
    for n, vol_id in enumerate(info_df.id.values):
        ct_nii = nib.load(os.path.join(input_data_path, f'ct_scans/{vol_id}.nii'))
        mask_nii = nib.load(os.path.join(input_data_path, f'masks/{vol_id}.nii.gz'))
        ct_vol = ct_nii.get_fdata()
        mask_vol = skimage.img_as_bool(mask_nii.get_fdata())
        # rotate 90° counter clockwise so that the head points upward
        ct_vol = np.rot90(ct_vol, axes=(0, 1))
        mask_vol = np.rot90(mask_vol, axes=(0, 1))
        # window the CT volume to get better contrast of soft tissues
        if window is not None:
            ct_vol = window_ct(ct_vol, win_center=window[0], win_width=window[1], out_range=(0, 1))
        # a shape mismatch is only reported; slicing below may still fail on it
        if mask_vol.shape != ct_vol.shape:
            print(f'>>> Warning! The ct volume of patient {vol_id} does not have '
                  f'the same dimension as the ground truth. CT ({ct_vol.shape}) vs Mask ({mask_vol.shape})')

        os.makedirs(os.path.join(output_data_path, f'{vol_id:03}/ct/'), exist_ok=True)
        os.makedirs(os.path.join(output_data_path, f'{vol_id:03}/mask/'), exist_ok=True)

        n_slice = ct_vol.shape[2]
        for s in range(n_slice):
            ct_slice = ct_vol[:, :, s]
            mask_slice = mask_vol[:, :, s]
            # slice numbers in filenames are 1-based
            ct_slice_fn = f'{vol_id:03}/ct/{s + 1}.tif'
            skimage.io.imsave(os.path.join(output_data_path, ct_slice_fn), ct_slice, check_contrast=False)
            is_low = bool(skimage.exposure.is_low_contrast(ct_slice))
            # only save a mask file when the slice contains some foreground
            if np.any(mask_slice):
                mask_slice_fn = f'{vol_id:03}/mask/{s + 1}_Seg.bmp'
                skimage.io.imsave(os.path.join(output_data_path, mask_slice_fn),
                                  skimage.img_as_ubyte(mask_slice), check_contrast=False)
            else:
                mask_slice_fn = 'None'
            output_info.append({'volume': vol_id, 'slice': s + 1, 'ct_fn': ct_slice_fn,
                                'mask_fn': mask_slice_fn, 'low_contrast_ct': is_low})
            print_progessbar(s, n_slice, Name=f'Volume {vol_id:03} {n + 1:03}/{n_volume:03}',
                             Size=20, erase=False)

    slice_info_fn = os.path.join(output_data_path, 'slice_info.csv')
    pd.DataFrame(output_info).to_csv(slice_info_fn)
    print('>>> Slice informations saved at ' + slice_info_fn)
    volume_info_fn = os.path.join(output_data_path, 'volume_info.csv')
    info_df.to_csv(volume_info_fn)
    print('>>> Volume informations saved at ' + volume_info_fn)
def evaluate(self, dataset):
    """
    Evaluate the passed network on the given dataset for the Context restoration task.

    The bottleneck representation of each sample is average-pooled to 4x4, flattened,
    and embedded in 2D with t-SNE. The result is stored in self.outputs['eval']['repr']
    as a list of (sample index, [x, y]) pairs, and the elapsed time in
    self.outputs['eval']['time'].
    ----------
    INPUT
        |---- dataset (torch.utils.data.Dataset) the dataset to use for evaluation. It should output the original
        |           image and the sample index.
    OUTPUT
        |---- None
    """
    logger = logging.getLogger()
    loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=False,
                                         num_workers=self.num_workers,
                                         worker_init_fn=lambda _: np.random.seed())
    self.net = self.net.to(self.device)

    logger.info('Start Evaluating the context restoration model.')
    start_time = time.time()
    # parameter-free pooling to shrink representations (reduces memory), built once
    pool = nn.AdaptiveAvgPool2d((4, 4))
    idx_repr = []  # (sample index, flattened pooled bottleneck) pairs
    n_batch = len(loader)
    self.net.eval()
    self.net.return_bottleneck = True
    try:
        with torch.no_grad():
            for b, (im, idx) in enumerate(loader):
                # data loaded in the standard way (no patch-swapped image)
                im = im.to(self.device).float()
                _, z = self.net(im)
                z = pool(z)
                idx_repr += list(zip(idx.tolist(), z.view(z.shape[0], -1).cpu().tolist()))
                if self.print_progress:
                    print_progessbar(b, n_batch, Name='\t\tEvaluation Batch', Size=40, erase=True)
    finally:
        # always restore the default forward behaviour, even if evaluation fails
        self.net.return_bottleneck = False

    indices, reprs = zip(*idx_repr)
    reprs = np.array(reprs)
    logger.info('Computing the t-SNE representation.')
    repr_2D = TSNE(n_components=2).fit_transform(reprs)
    self.outputs['eval']['repr'] = list(zip(indices, repr_2D.tolist()))
    logger.info('Succesfully computed the t-SNE representation.')

    self.outputs['eval']['time'] = time.time() - start_time
    logger.info(f"Finished evaluating on the context restoration task in {timedelta(seconds=int(self.outputs['eval']['time']))}")
def _pixelwise_error(self, input, grid_masks, verbose=False): """ Generate a the pixelwise error sample of input when masked with grid. ---------- INPUT |---- input (torch.tensor) the image on which the error sample is computed with the grids. It should have | dimension [C, H, W]. |---- grid_masks (np.array) the set of grid mask to use for inpainting. It should have dimension [N_grid, H, W]. |---- verbose (bool) whether to print a progress bar of the inpainting process. OUTPUT |---- err (np.array) the pixelwise sample of inpainting errors with dimension [N_grid, C, H, W]. """ assert input.ndim == 3, f"Input must be 3 dimensional (C x H x W). Got {input.shape}" grid_dataset = data.TensorDataset( torch.tensor(grid_masks).unsqueeze(1)) grid_loader = data.DataLoader(grid_dataset, batch_size=self.batch_size) input = input.unsqueeze(0) # add a batch dimension error_list = [] for b, grid_mask_batch in enumerate(grid_loader): grid_mask_batch = grid_mask_batch[0] # repeat input to along batch dimension input_rep = input.repeat(grid_mask_batch.shape[0], 1, 1, 1).to(self.device) # inpaint im inpaint_im = self._inpaint(input_rep, grid_mask_batch) # compute difference to input error_list.append(inpaint_im - input_rep) if verbose: print_progessbar(b, len(grid_loader), Name='Grid Inpainting', Size=50, erase=True) # keep only inpainting error where the grid mask were present measures = torch.cat(error_list, dim=0) # [N_measure x C x H x W] grid_sample = torch.tensor(grid_masks).unsqueeze(1).repeat( 1, input.shape[1], 1, 1) # use the whole grid array and repeat the channel dimension err = measures.permute(1, 2, 3, 0)[grid_sample.permute(1, 2, 3, 0) == 1] c, h, w = input.shape[1:] err = err.view(c, h, w, -1).permute(3, 0, 1, 2) # [N_err, C, H, W] return err.cpu().numpy()
def evaluate(self, dataset):
    """
    Evaluate the network on the given dataset for the Contrastive task (t-SNE of the representations).

    Only available for the global task; otherwise a warning is emitted and nothing is done.
    The bottleneck representation of every sample is flattened and embedded in 2D with t-SNE.
    The result is stored in self.outputs['eval']['repr'] as (sample index, [x, y]) pairs,
    and the elapsed time in self.outputs['eval']['time'].
    ----------
    INPUT
        |---- dataset (torch.utils.data.Dataset) the dataset to use for evaluation. It should output the original
        |           image and the sample index.
    OUTPUT
        |---- None
    """
    if not self.is_global:
        warnings.warn("Evaluation is only possible with a global contrastive task.")
        return

    logger = logging.getLogger()
    loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=False,
                                         num_workers=self.num_workers,
                                         worker_init_fn=lambda _: np.random.seed())
    logger.info("Start Evaluating the network on global contrastive task.")
    start_time = time.time()
    idx_repr = []  # (sample index, flattened bottleneck representation) pairs
    n_batch = len(loader)
    self.net.eval()
    self.net.return_bottleneck = True
    try:
        with torch.no_grad():
            for b, (im, idx) in enumerate(loader):
                im = im.to(self.device).float()
                _, z = self.net(im)
                # flatten per sample: squeeze() would also drop the batch dimension
                # when a batch holds a single sample, producing ragged representations
                z = z.reshape(z.shape[0], -1)
                idx_repr += list(zip(idx.tolist(), z.cpu().tolist()))
                if self.print_progress:
                    print_progessbar(b, n_batch, Name='\t\tEvaluation Batch', Size=40, erase=True)
    finally:
        # always restore the default forward behaviour, even if evaluation fails
        self.net.return_bottleneck = False

    indices, reprs = zip(*idx_repr)
    reprs = np.array(reprs)
    logger.info('Computing the t-SNE representation.')
    repr_2D = TSNE(n_components=2).fit_transform(reprs)
    self.outputs['eval']['repr'] = list(zip(indices, repr_2D.tolist()))
    logger.info('Succesfully computed the t-SNE representation.')

    self.outputs['eval']['time'] = time.time() - start_time
    logger.info(f"Finished evaluating of encoder on the global contrastive task in {timedelta(seconds=int(self.outputs['eval']['time']))}")
def main(src_data_path, src_data_info, dst_data_path, dst_data_info, which):
    """
    Copy the 'ad_<which>' anomaly maps of a source dataset into a destination dataset.

    The source and destination info csv files are merged on ('id', 'slice'). Each
    available map is copied to '<dst_data_path>/<id:03>/anomaly/<basename>', and its
    relative path is recorded in a new 'attention_fn' column ('None' when absent). The
    existing '*.png' files of every destination anomaly folder are removed first. The
    merged table (without the source filename column) is saved as
    '<dst_data_path>/info.csv'.
    ----------
    INPUT
        |---- src_data_path (str) root folder the source map filenames are relative to.
        |---- src_data_info (str) path to the source info csv (columns 'id', 'slice', 'ad_<which>_fn').
        |---- dst_data_path (str) root folder of the destination dataset.
        |---- dst_data_info (str) path to the destination info csv (columns 'id', 'slice', ...).
        |---- which (str) suffix selecting the source column 'ad_<which>_fn'.
    OUTPUT
        |---- None
    """
    src_col = f'ad_{which}_fn'

    def _is_missing(fn):
        # missing maps are stored as 'None'; depending on the pandas version,
        # read_csv may parse that string as NaN, so both forms are handled
        return pd.isna(fn) or os.path.basename(fn) == 'None'

    src_df = pd.read_csv(src_data_info, index_col=0)[['id', 'slice', src_col]]
    dst_df = pd.read_csv(dst_data_info, index_col=0)
    df = pd.merge(dst_df, src_df, on=['id', 'slice'])
    df['attention_fn'] = df.apply(
        lambda row: 'None' if _is_missing(row[src_col])
        else f"{row['id']:03}" + os.sep + 'anomaly' + os.sep + os.path.basename(row[src_col]),
        axis=1)

    # empty existing anomaly folders (png only) or create the missing ones
    vol_ids = df.id.unique()
    for i, id_i in enumerate(vol_ids):
        dir_i = os.path.join(dst_data_path, f'{id_i:03}/anomaly/')
        if os.path.isdir(dir_i):
            for fn in glob.glob(os.path.join(dir_i, '*.png')):
                os.remove(fn)
        else:
            os.makedirs(dir_i)
        print_progessbar(i, len(vol_ids), Name='Folder cleaning', Size=50)

    # transfer every available map
    for i, (_, row) in enumerate(df.iterrows()):
        if not _is_missing(row[src_col]):
            shutil.copy2(os.path.join(src_data_path, row[src_col]),
                         os.path.join(dst_data_path, row['attention_fn']))
        print_progessbar(i, len(df), Name='Sample', Size=50)

    # drop the source filename column and save the new info csv
    df = df.drop(columns=[src_col])
    info_fn = os.path.join(dst_data_path, 'info.csv')
    df.to_csv(info_fn)
    print(f">>> new info csv saved at {info_fn}")
def validate(self, dataset):
    """
    Validate the network on a labelled dataset.

    The anomaly score of a sample is the sum, over its output feature map, of
    sqrt(f**2 + 1) - 1. The ROC AUC of these scores against the labels is returned
    together with the mean HSC loss over batches.
    ----------
    INPUT
        |---- dataset (torch.utils.data.Dataset) the dataset to validate on. It should output the image, the label
        |           and a third (unused) element.
    OUTPUT
        |---- (float) the validation loss averaged over batches.
        |---- (float) the ROC AUC of the anomaly scores.
    """
    with torch.no_grad():
        loader = torch.utils.data.DataLoader(
            dataset, batch_size=self.batch_size, num_workers=self.num_workers,
            shuffle=False, worker_init_fn=lambda _: np.random.seed())
        n_batch = len(loader)
        criterion = HSCLoss(reduction='mean')
        self.net.eval()

        total_loss = 0.0
        pairs = []  # (label, anomaly score) per sample
        for b, (im, label, _) in enumerate(loader):
            im = im.to(self.device).float()
            label = label.to(self.device)
            feat_map = self.net(im)
            total_loss += criterion(feat_map, label).item()
            # per-sample anomaly score: summed pseudo-Huber transform of the feature map
            pseudo_huber = torch.sqrt(feat_map ** 2 + 1) - 1
            score = pseudo_huber.reshape(feat_map.shape[0], -1).sum(-1)
            pairs += list(zip(label.cpu().tolist(), score.cpu().tolist()))
            if self.print_progress:
                print_progessbar(b, n_batch, Name='Valid Batch', Size=100, erase=True)

        labels, scores = zip(*pairs)
        auc = roc_auc_score(np.array(labels), np.array(scores))
        return total_loss / n_batch, auc
def main(input_path, out_folder):
    """
    Convert the qure.ai CQ500 dataset (at input_path) from series of DICOM files to 3D NIfTI volumes.

    Each sub-folder of input_path is one patient. Its '*.dcm' files are read, decompressed and
    converted into '<out_folder>/<ID>.nii'. A summary 'info.csv' is written in out_folder. It is
    the outer merge of the conversion info with '<input_path>/ICH_probabilities.csv' on the patient ID.
    ----------
    INPUT
        |---- input_path (str) the dataset root folder (a trailing separator is no longer required).
        |---- out_folder (str) the output folder (created if missing).
    OUTPUT
        |---- None
    """
    import os  # local import keeps this script entry point self-contained

    print(f'>>> Start converting dicom series from {input_path} to NIfTI volumes saved at {out_folder}.')
    # adjust dicom conversion settings to not check for orthogonality
    dicom2nifti.settings.disable_validate_orthogonal()
    os.makedirs(out_folder, exist_ok=True)

    out_info_list = []
    # the '*/' pattern matches sub-folders only; sorted for a deterministic order
    dir_list = sorted(glob.glob(os.path.join(input_path, '*/')))
    for n, patient_dir in enumerate(dir_list):
        # patient_dir ends with a separator: normpath removes it so basename is the folder name
        ID = os.path.basename(os.path.normpath(patient_dir))
        dcm_list = []
        for dcm_fn in glob.glob(os.path.join(patient_dir, '*.dcm')):
            ds = pydicom.dcmread(dcm_fn)
            ds.decompress()
            dcm_list.append(ds)
        out_path = os.path.join(out_folder, ID + '.nii')
        dicom2nifti.convert_dicom.dicom_array_to_nifti(dcm_list, out_path)
        out_info_list.append({'id': int(ID), 'filename': ID + '.nii', 'n_slice': len(dcm_list)})
        print_progessbar(n, len(dir_list), '\tCT Scan', Size=40, erase=False)
    print(f'>>> {len(out_info_list)} NIfTI volumes successfully saved at {out_folder}.')

    # read info csv and merge with the converted file information
    in_df = pd.read_csv(os.path.join(input_path, 'ICH_probabilities.csv'), index_col=0)
    fn_df = pd.DataFrame(out_info_list)
    df = pd.merge(fn_df, in_df, left_on='id', right_index=True, how='outer')
    info_fn = os.path.join(out_folder, 'info.csv')
    df.to_csv(info_fn)
    print(f">>> NIfTI file informations saved at {info_fn}.")
def get_min_max(self, loader, reception=True, std=None, cpu=True, q_min=0.025, q_max=0.975):
    """
    Compute robust Min and Max values of the heat maps over a dataset.

    For each batch, the candidate min (resp. max) is the q_min (resp. q_max) quantile of all the
    batch's heat-map entries, or the exact min (resp. max) when q_min <= 0 (resp. q_max >= 1).
    The smallest candidate min and the largest candidate max over all batches are returned.
    ----------
    INPUT
        |---- loader (iterable) yields batches whose first element is the image tensor.
        |---- reception, std, cpu passed unchanged to self.generate_heatmap.
        |---- q_min (float) quantile used for the minimum.
        |---- q_max (float) quantile used for the maximum.
    OUTPUT
        |---- min_val, max_val (0-d torch.Tensor each, or +/-inf floats if the loader is empty).
    """
    def _kth_smallest(values, q):
        # torch.kthvalue is 1-indexed: clamp k to [1, n] so that a small q (or a
        # small heat map) selects the minimum instead of raising for k = 0
        n = values.numel()
        k = min(max(int(q * n), 1), n)
        return torch.kthvalue(values, k)[0]

    self.net.eval()
    min_val, max_val = np.inf, -np.inf
    for b, batch in enumerate(loader):
        im = batch[0].to(self.device).float()
        heatmap = self.generate_heatmap(im, reception=reception, std=std, cpu=cpu).reshape(-1)

        qmax = _kth_smallest(heatmap, q_max) if q_max < 1.0 else heatmap.max()
        if qmax > max_val:
            max_val = qmax
        qmin = _kth_smallest(heatmap, q_min) if q_min > 0 else heatmap.min()
        if qmin < min_val:
            min_val = qmin

        if self.print_progress:
            print_progessbar(b, len(loader), Name='Getting Scaling Factor', Size=100, erase=True)

    return min_val, max_val
def evaluate(self, net, dataset, return_score=False, print_to_logger=True, save_path=None):
    """
    Evaluate the network with the given dataset. The evaluation score is given for the 3D prediction.

    The prediction is the argmax over the output channels, and pixels of class != 0 count as
    positive. Confusion counts are computed per sample and then summed per patient ID. The 3D
    Dice and IoU are computed from these sums and averaged over patients. This means a 2D model's
    scores are computed on the concatenation of a patient's slices.
    ----------
    INPUT
        |---- net (nn.Module) the network architecture to evaluate.
        |---- dataset (torch.utils.data.Dataset) the dataset to use for evaluation. It must return an input image, a
        |           target binary mask, the patientID.
        |---- return_score (bool) whether to return the mean Dice and mean IoU scores of 3D segmentation.
        |---- print_to_logger (bool) whether to print information to the logger.
        |---- save_path (str) the folder path where to save each prediction as '<patientID>.nii' and the
        |           per-sample scores as 'prediction_scores.csv'. If None, nothing is saved.
    OUTPUT
        |---- (Dice) (float) the average Dice coefficient for the 3D segmentation (only if return_score).
        |---- (IoU) (float) the average IoU coefficient for the 3D segmentation (only if return_score).
    """
    if print_to_logger:
        logger = logging.getLogger()
    loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=False,
                                         num_workers=self.num_workers)
    net = net.to(self.device)

    if print_to_logger:
        logger.info('Start evaluating the UNet.')
    start_time = time.time()
    id_pred = []  # one record of confusion counts per sample
    net.eval()
    with torch.no_grad():
        for b, (im, target, pID) in enumerate(loader):
            im = im.to(self.device).float()
            target = target.to(self.device)
            pred = net(im).argmax(dim=1)

            for pid, target_samp, pred_samp in zip(pID, target, pred):
                # store plain Python scalars: tensors hash by identity, which would
                # prevent the per-patient groupby below from merging rows
                pid = pid.item() if torch.is_tensor(pid) else pid
                # binary confusion counts (always defined, even for single-class samples)
                t = target_samp.cpu().numpy().ravel() != 0
                p = pred_samp.cpu().numpy().ravel() != 0
                tp = int(np.sum(t & p))
                tn = int(np.sum(~t & ~p))
                fp = int(np.sum(~t & p))
                fn = int(np.sum(t & ~p))

                if save_path:
                    # NIfTI has no boolean datatype: store the binary mask as uint8
                    # NOTE(review): samples sharing a patient ID (2D case) overwrite the same file
                    img = nib.Nifti1Image(pred_samp.cpu().numpy().astype(np.uint8), np.eye(4))
                    pred_path = f'{save_path}/{pid}.nii'
                    nib.save(img, pred_path)
                else:
                    pred_path = 'None'

                id_pred.append({'PatientID': pid, 'TP': tp, 'TN': tn, 'FP': fp, 'FN': fn,
                                'pred_fn': pred_path})

            if self.print_progress:
                print_progessbar(b, len(loader), Name='\t\tEvaluation Batch', Size=40, erase=True)

    # per-sample Dice & Jaccard (IoU)
    result_df = pd.DataFrame(id_pred)
    result_df['Dice'] = (2 * result_df.TP + 1e-9) / (2 * result_df.TP + result_df.FP + result_df.FN + 1e-9)
    result_df['IoU'] = (result_df.TP + 1e-9) / (result_df.TP + result_df.FP + result_df.FN + 1e-9)
    if save_path:
        result_df.to_csv(f'{save_path}/prediction_scores.csv')

    # aggregate counts per patient, recompute 3D Dice & Jaccard, then average over patients
    result_3D_df = result_df[['PatientID', 'TP', 'TN', 'FP', 'FN']].groupby('PatientID').sum()
    result_3D_df['Dice'] = (2 * result_3D_df.TP + 1e-9) / (2 * result_3D_df.TP + result_3D_df.FP + result_3D_df.FN + 1e-9)
    result_3D_df['IoU'] = (result_3D_df.TP + 1e-9) / (result_3D_df.TP + result_3D_df.FP + result_3D_df.FN + 1e-9)
    avg_results = result_3D_df[['Dice', 'IoU']].mean(axis=0)

    self.eval_time = time.time() - start_time
    self.eval_dice = avg_results.Dice
    self.eval_IoU = avg_results.IoU

    if print_to_logger:
        logger.info(f'Evaluation time: {self.eval_time} [s].')
        logger.info(f'Evaluation Dice: {self.eval_dice:.3%}.')
        logger.info(f'Evaluation IoU: {self.eval_IoU:.3%}.')
        logger.info('Finished evaluating the UNet.')

    if return_score:
        return avg_results.Dice, avg_results.IoU
def train(self, dataset, valid_dataset=None, checkpoint_path=None):
    """
    Train the network with the given dataset(s).
    ----------
    INPUT
        |---- dataset (torch.utils.data.Dataset) the dataset to use for training. It must return an input image, a
        |           target binary mask, the patientID, and the slice number.
        |---- valid_dataset (torch.utils.data.Dataset) the optional validation dataset. If provided, the model is
        |           validated at each epoch. It must have the same struture as the train dataset.
        |---- checkpoint_path (str) the filename for a possible checkpoint to start the training from. If None,
        |           training starts from scratch and no checkpoint is saved.
    OUTPUT
        |---- None. The trained network is self.unet. self.outputs['train'] receives the training time and
        |           the per-epoch evolution [epoch, train loss, valid Dice, valid Dice (positive)], where the
        |           Dice entries are None when no validation dataset is given.
    """
    logger = logging.getLogger()
    train_loader = torch.utils.data.DataLoader(
        dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers,
        worker_init_fn=lambda _: np.random.seed())
    self.unet = self.unet.to(self.device)
    optimizer = optim.Adam(self.unet.parameters(), lr=self.lr, weight_decay=self.weight_decay)
    scheduler = self.lr_scheduler(optimizer, **self.lr_scheduler_kwargs)
    loss_fn = self.loss_fn(**self.loss_fn_kwargs)

    # Load checkpoint if present. torch.load(None) does not raise FileNotFoundError,
    # so the default checkpoint_path=None must be handled explicitly.
    checkpoint = None
    if checkpoint_path is not None:
        try:
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
        except FileNotFoundError:
            checkpoint = None
    if checkpoint is not None:
        n_epoch_finished = checkpoint['n_epoch_finished']
        self.unet.load_state_dict(checkpoint['net_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        scheduler.load_state_dict(checkpoint['lr_state'])
        epoch_loss_list = checkpoint['loss_evolution']
        logger.info(f'Checkpoint loaded with {n_epoch_finished} epoch finished.')
    else:
        logger.info('No Checkpoint found. Training from beginning.')
        n_epoch_finished = 0
        epoch_loss_list = []  # Placeholder for epoch evolution

    logger.info('Start training the U-Net 2.5D.')
    start_time = time.time()
    n_batch = len(train_loader)

    for epoch in range(n_epoch_finished, self.n_epoch):
        self.unet.train()
        epoch_loss = 0.0
        epoch_start_time = time.time()

        for b, (im, target, _, _) in enumerate(train_loader):
            # gradients are only needed for the weights, not for the data tensors
            im = im.to(self.device).float()
            target = target.to(self.device).float()
            optimizer.zero_grad()
            pred = self.unet(im)
            loss = loss_fn(pred, target)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            if self.print_progress:
                print_progessbar(b, n_batch, Name='\t\tTrain Batch', Size=40, erase=True)

        # validation performance (only when a validation set is given; otherwise the
        # Dice entries stay None instead of reading missing/stale self.outputs values)
        valid_dice = ''
        dice_all, dice_positive = None, None
        if valid_dataset:
            self.evaluate(valid_dataset, print_to_logger=False, save_path=None)
            dice_all = self.outputs['eval']['dice']['all']
            dice_positive = self.outputs['eval']['dice']['positive']
            valid_dice = f"| Valid Dice: {dice_all:.5f} " + \
                         f"| Valid Dice (Positive Slices): {dice_positive:.5f} "

        logger.info(
            f'\t| Epoch: {epoch + 1:03}/{self.n_epoch:03} '
            f'| Train time: {timedelta(seconds=int(time.time() - epoch_start_time))} '
            f'| Train Loss: {epoch_loss / n_batch:.6f} ' + valid_dice +
            f'| lr: {scheduler.get_last_lr()[0]:.7f} |')

        epoch_loss_list.append([epoch + 1, epoch_loss / n_batch, dice_all, dice_positive])
        scheduler.step()

        # Save Checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0 and checkpoint_path:
            checkpoint = {
                'n_epoch_finished': epoch + 1,
                'net_state': self.unet.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'lr_state': scheduler.state_dict(),
                'loss_evolution': epoch_loss_list
            }
            torch.save(checkpoint, checkpoint_path)
            logger.info('\tCheckpoint saved.')

    self.outputs['train']['time'] = time.time() - start_time
    self.outputs['train']['evolution'] = epoch_loss_list
    logger.info(
        f"Finished training U-Net 2D in {timedelta(seconds=int(self.outputs['train']['time']))}"
    )
def segement_volume(self, vol, save_fn=None, window=None, input_size=(256, 256), return_pred=False):
    """
    Segement each slice of the passed Nifti volume and save the results as a Nifti volume.
    ----------
    INPUT
        |---- vol (nibabel.nifti1.Nifti1Pair) the nibabel volume with metadata to segement.
        |---- save_fn (str) where to save the segmentation. Nothing is saved if None.
        |---- window (tuple (center, width)) the windowing to apply to the ct-scan.
        |---- input_size (tuple (h, w)) the input size for the network.
        |---- return_pred (bool) whether to return the volume of prediction.
    OUTPUT
        |---- (mask_vol) (nibabel.nifti1.Nifti1Pair) the prediction volume (only if return_pred).
    """
    pred_list = []
    vol_data = np.rot90(vol.get_fdata(), axes=(0, 1))  # 90° counterclockwise rotation
    # in-plane size of the *rotated* volume: predictions are resized back to it before
    # the inverse rotation so that the result matches the original (dim1, dim2) layout
    rotated_hw = vol_data.shape[:2]
    if window:
        vol_data = window_ct(vol_data, win_center=window[0], win_width=window[1], out_range=(0, 1))
    transform = tf.Compose(tf.Resize(H=input_size[0], W=input_size[1]), tf.ToTorchTensor())

    self.unet.eval()
    self.unet.to(self.device)
    with torch.no_grad():
        for s in range(0, vol_data.shape[2], self.batch_size):
            # get slices in the network input size and as a tensor
            batch_in = transform(vol_data[:, :, s:s + self.batch_size]).to(
                self.device).float().permute(3, 0, 1, 2)
            pred = self.unet(batch_in)
            pred = torch.where(pred >= 0.5, torch.ones_like(pred, device=self.device),
                               torch.zeros_like(pred, device=self.device))
            # store pred (H x W x B)
            pred_list.append(
                pred.squeeze(dim=1).permute(1, 2, 0).cpu().numpy().astype(np.uint8) * 255)
            if self.print_progress:
                print_progessbar(s + pred.shape[0] - 1, Max=vol_data.shape[2], Name='Slice',
                                 Size=20, erase=True)

    vol_pred = np.concatenate(pred_list, axis=2)
    # resize to the rotated input size, then rotate 90° clockwise back.
    # NOTE: with skimage's default preserve_range=False, resize rescales the uint8
    # {0, 255} mask to {0.0, 1.0}, so the saved volume holds {0, 1} values.
    vol_pred = np.rot90(skimage.transform.resize(vol_pred, rotated_hw, order=0), axes=(1, 0))

    vol_pred_nii = nib.Nifti1Pair(vol_pred.astype(np.uint8), vol.affine)
    if save_fn:
        nib.save(vol_pred_nii, save_fn)
    if return_pred:
        return vol_pred_nii
def evaluate(self, dataset, print_to_logger=True, save_path=None):
    """
    Evaluate the U-Net on a dataset and store the mean 3D Dice scores.

    Predictions are binarized at 0.5. Per-slice confusion counts are summed per volume to
    compute a 3D Dice with +1 smoothing. That Dice is then averaged over all volumes and over
    the volumes whose target has at least one positive pixel.
    ----------
    INPUT
        |---- dataset (torch.utils.data.Dataset) the dataset to evaluate on. It must return an input image, a
        |           target binary mask, the volume ID and the slice number.
        |---- print_to_logger (bool) whether to print information to the logger.
        |---- save_path (str) the folder where each predicted slice is saved as '<volID>/<slice>.bmp', together
        |           with 'slice_prediction_scores.csv' and 'volume_prediction_scores.csv'. Nothing saved if None.
    OUTPUT
        |---- None. self.outputs['eval'] receives 'time' and 'dice' ({'all', 'positive'}).
    """
    if print_to_logger:
        logger = logging.getLogger()

    loader = torch.utils.data.DataLoader(
        dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers,
        worker_init_fn=lambda _: np.random.seed())
    self.unet = self.unet.to(self.device)

    if print_to_logger:
        logger.info('Start evaluating the U-Net 2.5D.')
    start_time = time.time()

    columns = ('volID', 'slice', 'label', 'TP', 'TN', 'FP', 'FN', 'pred_fn')
    records = {col: [] for col in columns}  # one entry per slice in each column
    self.unet.eval()
    with torch.no_grad():
        for b, (im, target, vol_ids, slice_nbrs) in enumerate(loader):
            im = im.to(self.device).float()
            target = target.to(self.device).float()
            # binarize the network output at 0.5
            pred = self.unet(im)
            pred = torch.where(pred >= 0.5, torch.ones_like(pred, device=self.device),
                               torch.zeros_like(pred, device=self.device))
            tn, fp, fn, tp = batch_binary_confusion_matrix(pred, target)

            if save_path:
                pred_fns = []
                for vol_id, s_nbr, pred_samp in zip(vol_ids, slice_nbrs, pred):
                    os.makedirs(os.path.join(save_path, f'{vol_id}/'), exist_ok=True)
                    slice_fn = f'{vol_id}/{s_nbr}.bmp'
                    # binary prediction saved as uint8 scaled to 255
                    io.imsave(os.path.join(save_path, slice_fn),
                              pred_samp[0, :, :].cpu().numpy().astype(np.uint8) * 255,
                              check_contrast=False)
                    pred_fns.append(slice_fn)
            else:
                pred_fns = ['-'] * vol_ids.shape[0]

            records['volID'] += vol_ids.cpu().tolist()
            records['slice'] += slice_nbrs.cpu().tolist()
            # slice label: 1 if the target mask has any positive pixel, else 0
            records['label'] += target.view(target.shape[0], -1).max(dim=1)[0].cpu().tolist()
            for key, counts in (('TP', tp), ('TN', tn), ('FP', fp), ('FN', fn)):
                records[key] += counts.cpu().tolist()
            records['pred_fn'] += pred_fns

            if self.print_progress:
                print_progessbar(b, len(loader), Name='\t\tEvaluation Batch', Size=40, erase=True)

    # per-slice Dice
    result_df = pd.DataFrame(records)
    result_df['Dice'] = (2 * result_df.TP + 1) / (2 * result_df.TP + result_df.FP + result_df.FN + 1)
    if save_path:
        result_df.to_csv(os.path.join(save_path, 'slice_prediction_scores.csv'))

    # per-volume Dice from the summed confusion counts
    agg_spec = {'label': 'max', 'TP': 'sum', 'TN': 'sum', 'FP': 'sum', 'FN': 'sum'}
    result_3D_df = result_df[['volID', *agg_spec]].groupby('volID').agg(agg_spec)
    result_3D_df['Dice'] = (2 * result_3D_df.TP + 1) / (
        2 * result_3D_df.TP + result_3D_df.FP + result_3D_df.FN + 1)
    if save_path:
        result_3D_df.to_csv(os.path.join(save_path, 'volume_prediction_scores.csv'))

    # averages over positive volumes only and over all volumes
    dice_positive = result_3D_df.loc[result_3D_df.label == 1, 'Dice'].mean(axis=0)
    dice_all = result_3D_df.Dice.mean(axis=0)

    self.outputs['eval']['time'] = time.time() - start_time
    self.outputs['eval']['dice'] = {'all': dice_all, 'positive': dice_positive}

    if print_to_logger:
        logger.info(f"Evaluation time: {timedelta(seconds=int(self.outputs['eval']['time']))}")
        logger.info(f"Evaluation Dice: {self.outputs['eval']['dice']['all']:.5f}.")
        logger.info(f"Evaluation Dice (Positive only): {self.outputs['eval']['dice']['positive']:.5f}.")
        logger.info("Finished evaluating the U-Net 2.5D.")
def train(self, dataset, checkpoint_path=None):
    """
    Train the network on the given dataset for the contrastive task (global or local).
    ----------
    INPUT
        |---- dataset (torch.utils.data.Dataset) the dataset to use for training. It should output two augmented
        |           versions of the same image as well as the image index.
        |---- checkpoint_path (str) the filename for a possible checkpoint to start the training from. If given,
        |           a checkpoint is written after every epoch. If None, training starts from scratch and the
        |           network's weights are not saved during training.
    OUTPUT
        |---- None. Training time and the per-epoch [epoch, loss] evolution go to self.outputs['train'].
    """
    logger = logging.getLogger()
    loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=True,
                                         num_workers=self.num_workers, drop_last=True,
                                         worker_init_fn=lambda _: np.random.seed())
    loss_fn = self.loss_fn(**self.loss_fn_kwargs)
    optimizer = optim.Adam(self.net.parameters(), lr=self.lr, weight_decay=self.weight_decay)
    scheduler = self.lr_scheduler(optimizer, **self.lr_scheduler_kwargs)

    # Load checkpoint if any. torch.load(None) does not raise FileNotFoundError,
    # so the default checkpoint_path=None must be handled explicitly.
    checkpoint = None
    if checkpoint_path is not None:
        try:
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
        except FileNotFoundError:
            checkpoint = None
    if checkpoint is not None:
        n_epoch_finished = checkpoint['n_epoch_finished']
        self.net.load_state_dict(checkpoint['net_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        scheduler.load_state_dict(checkpoint['lr_state'])
        epoch_loss_list = checkpoint['loss_evolution']
        logger.info(f'Checkpoint loaded with {n_epoch_finished} epoch finished.')
    else:
        logger.info('No Checkpoint found. Training from beginning.')
        n_epoch_finished = 0
        epoch_loss_list = []

    task = 'global' if self.is_global else 'local'
    logger.info(f"Start trianing the network on the {task} contrastive task.")
    start_time = time.time()
    n_batch = len(loader)

    for epoch in range(n_epoch_finished, self.n_epoch):
        self.net.train()
        epoch_loss = 0.0
        epoch_start_time = time.time()

        for b, (im1, im2, _) in enumerate(loader):
            # gradients are only needed for the weights, not for the input images
            im1 = im1.to(self.device).float()
            im2 = im2.to(self.device).float()
            optimizer.zero_grad()
            z1 = self.net(im1)
            z2 = self.net(im2)
            # the global task compares L2-normalized representations
            if self.is_global:
                z1 = nn.functional.normalize(z1, dim=1)
                z2 = nn.functional.normalize(z2, dim=1)
            loss = loss_fn(z1, z2)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            if self.print_progress:
                print_progessbar(b, n_batch, Name='\t\tTrain Batch', Size=40, erase=True)

        logger.info(f'\t| Epoch : {epoch + 1:03}/{self.n_epoch:03} '
                    f'| Train time: {timedelta(seconds=int(time.time() - epoch_start_time))} '
                    f'| Train Loss: {epoch_loss / n_batch:.6f} '
                    f'| lr: {scheduler.get_last_lr()[0]:.7f} |')
        epoch_loss_list.append([epoch + 1, epoch_loss / n_batch])
        scheduler.step()

        # checkpoint saved after every epoch when a path is given
        if checkpoint_path:
            checkpoint = {'n_epoch_finished': epoch + 1,
                          'net_state': self.net.state_dict(),
                          'optimizer_state': optimizer.state_dict(),
                          'lr_state': scheduler.state_dict(),
                          'loss_evolution': epoch_loss_list}
            torch.save(checkpoint, checkpoint_path)
            logger.info('\tCheckpoint saved.')

    self.outputs['train']['time'] = time.time() - start_time
    self.outputs['train']['evolution'] = epoch_loss_list
    logger.info(f"Finished training on the network on the {task} contrastive task in {timedelta(seconds=int(self.outputs['train']['time']))}")
def main(config_path):
    """
    Generate FCDD anomaly heat-maps for every slice of the public ICH segmentation dataset, threshold them into
    anomaly masks, score them against the ground-truth masks and save maps, masks, scores and config.
    ----------
    INPUT
        |---- config_path (str) path to the JSON experiment configuration (read with AttrDict.from_json_path).
    OUTPUT
        |---- None. Maps and masks are saved under <output>/<exp_name>/pred/<id>/, the per-slice and per-volume
        |           scores as slice_predictions.csv / volume_predictions.csv, and the config as config.json.
    """
    # load config
    cfg = AttrDict.from_json_path(config_path)
    # make outputs dir
    out_path = os.path.join(cfg.path.output, cfg.exp_name)
    os.makedirs(out_path, exist_ok=True)
    # initialize seed (-1 means no seeding)
    if cfg.seed != -1:
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
        torch.cuda.manual_seed(cfg.seed)
        torch.cuda.manual_seed_all(cfg.seed)
        torch.backends.cudnn.deterministic = True
    # initialize logger
    logger = initialize_logger(os.path.join(out_path, 'log.txt'))
    logger.info(f"Experiment : {cfg.exp_name}")
    # set device (first GPU if available when not given explicitly)
    if cfg.device:
        cfg.device = torch.device(cfg.device)
    else:
        cfg.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    logger.info(f"Device set to {cfg.device}.")

    #-------------------------------------------
    # Make Dataset
    #-------------------------------------------
    data_info_df = pd.read_csv(os.path.join(cfg.path.data, 'ct_info.csv'), index_col=0)
    augmentations = [getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in cfg.data.augmentation.items()]
    dataset = public_SegICH_Dataset2D(data_info_df, cfg.path.data,
                                      augmentation_transform=augmentations,
                                      output_size=cfg.data.size,
                                      window=(cfg.data.win_center, cfg.data.win_width))

    #-------------------------------------------
    # Load FCDD Model
    #-------------------------------------------
    cfg_fcdd = AttrDict.from_json_path(cfg.fcdd_cfg_path)
    fcdd_net = FCDD_CNN_VGG(in_shape=(cfg_fcdd.net.in_channels, 256, 256), bias=cfg_fcdd.net.bias)
    loaded_state_dict = torch.load(cfg.fcdd_model_path, map_location=cfg.device)
    fcdd_net.load_state_dict(loaded_state_dict)
    fcdd_net = fcdd_net.to(cfg.device).eval()
    logger.info(f"FCDD model succesfully loaded from {cfg.fcdd_model_path}")
    # make FCDD object
    fcdd = FCDD(fcdd_net, batch_size=cfg.batch_size, num_workers=cfg.num_workers, device=cfg.device,
                print_progress=cfg.print_progress)

    #-------------------------------------------
    # Load Classifier Model (optional)
    #-------------------------------------------
    use_classifier = cfg.classifier_model_path is not None
    if use_classifier:
        classifier_state_path = os.path.join(cfg.classifier_model_path, 'resnet_state_dict.pt')
        cfg_classifier = AttrDict.from_json_path(os.path.join(cfg.classifier_model_path, 'config.json'))
        classifier = getattr(rn, cfg_classifier.net.resnet)(num_classes=cfg_classifier.net.num_classes,
                                                            input_channels=cfg_classifier.net.input_channels)
        classifier_state_dict = torch.load(classifier_state_path, map_location=cfg.device)
        classifier.load_state_dict(classifier_state_dict)
        classifier = classifier.to(cfg.device)
        classifier.eval()
        logger.info(f"ResNet classifier model succesfully loaded from {classifier_state_path}")

    #-------------------------------------------
    # Generate Heat-Map for each slice
    #-------------------------------------------
    with torch.no_grad():
        # make loader
        loader = torch.utils.data.DataLoader(dataset, batch_size=cfg.batch_size, num_workers=cfg.num_workers,
                                             shuffle=False, worker_init_fn=lambda _: np.random.seed())
        fcdd_net.eval()
        # global min/max over the dataset used to rescale every heat-map to [0, 1]
        min_val, max_val = fcdd.get_min_max(loader, **cfg.heatmap_param)

        # computing and saving heatmaps
        out = dict(id=[], slice=[], label=[], ad_map_fn=[], ad_mask_fn=[], TP=[], TN=[], FP=[], FN=[], AUC=[],
                   classifier_pred=[])
        for b, data in enumerate(loader):
            im, mask, pid, slice_nb = data
            im = im.to(cfg.device).float()
            mask = mask.to(cfg.device).float()
            # get heatmap
            heatmap = fcdd.generate_heatmap(im, reception=cfg.heatmap_param.reception,
                                            std=cfg.heatmap_param.std, cpu=cfg.heatmap_param.cpu)
            # scaling
            heatmap = ((heatmap - min_val) / (max_val - min_val)).clamp(0, 1)
            # Threshold
            ad_mask = torch.where(heatmap >= cfg.heatmap_threshold,
                                  torch.ones_like(heatmap, device=heatmap.device),
                                  torch.zeros_like(heatmap, device=heatmap.device))
            # Compute CM
            tn, fp, fn, tp = batch_binary_confusion_matrix(ad_mask, mask.to(heatmap.device))

            # Save heatmaps/mask
            map_fn, mask_fn = [], []
            for i in range(im.shape[0]):
                # Save AD Map
                ad_map_fn = f"{pid[i]}/{slice_nb[i]}_map_anomalies.png"
                save_path = os.path.join(out_path, 'pred/', ad_map_fn)
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
                ad_map = heatmap[i].squeeze().cpu().numpy()
                io.imsave(save_path, img_as_ubyte(ad_map), check_contrast=False)
                # save ad_mask
                ad_mask_fn = f"{pid[i]}/{slice_nb[i]}_anomalies.bmp"
                save_path = os.path.join(out_path, 'pred/', ad_mask_fn)
                io.imsave(save_path, img_as_ubyte(ad_mask[i].squeeze().cpu().numpy()), check_contrast=False)
                map_fn.append(ad_map_fn)
                mask_fn.append(ad_mask_fn)

            # apply classifier ResNet-18. Both branches yield a plain list so the
            # `classifier_pred` column can be extended uniformly (a list has no .cpu()).
            if use_classifier:
                # take columns of softmax of positive class as score
                pred_score = nn.functional.softmax(classifier(im), dim=1)[:, 1]
                clss_pred = torch.where(pred_score >= cfg.classification_threshold,
                                        torch.ones_like(pred_score, device=pred_score.device),
                                        torch.zeros_like(pred_score, device=pred_score.device)).cpu().tolist()
            else:
                clss_pred = [None] * im.shape[0]

            # Save Values
            out['id'] += pid.cpu().tolist()
            out['slice'] += slice_nb.cpu().tolist()
            out['label'] += mask.reshape(mask.shape[0], -1).max(dim=1)[0].cpu().tolist()
            out['ad_map_fn'] += map_fn
            out['ad_mask_fn'] += mask_fn
            out['TN'] += tn.cpu().tolist()
            out['FP'] += fp.cpu().tolist()
            out['FN'] += fn.cpu().tolist()
            out['TP'] += tp.cpu().tolist()
            # pixel-wise AUC only defined for slices containing some positive pixel
            out['AUC'] += [roc_auc_score(mask[i].cpu().numpy().ravel(), heatmap[i].cpu().numpy().ravel())
                           if torch.any(mask[i] > 0) else 'None' for i in range(im.shape[0])]
            out['classifier_pred'] += clss_pred

            if cfg.print_progress:
                print_progessbar(b, len(loader), Name='Heatmap Generation Batch', Size=100, erase=True)

    # make df and save as csv (Dice smoothed with +1 so empty prediction on empty target scores 1)
    slice_df = pd.DataFrame(out)
    volume_df = slice_df[['id', 'label', 'TP', 'TN', 'FP', 'FN']].groupby('id').agg(
        {'label': 'max', 'TP': 'sum', 'TN': 'sum', 'FP': 'sum', 'FN': 'sum'})
    slice_df['Dice'] = (2 * slice_df.TP + 1) / (2 * slice_df.TP + slice_df.FP + slice_df.FN + 1)
    volume_df['Dice'] = (2 * volume_df.TP + 1) / (2 * volume_df.TP + volume_df.FP + volume_df.FN + 1)
    logger.info(f"Mean slice dice : {slice_df.Dice.mean(axis=0):.3f}")
    logger.info(f"Mean volume dice : {volume_df.Dice.mean(axis=0):.3f}")
    logger.info(f"Mean posiitve slice AUC {slice_df[slice_df.label == 1].AUC.mean(axis=0):.3f}")

    # Save Scores and Config
    slice_df.to_csv(os.path.join(out_path, 'slice_predictions.csv'))
    logger.info(f"Slice prediction csv saved at {os.path.join(out_path, 'slice_predictions.csv')}")
    volume_df.to_csv(os.path.join(out_path, 'volume_predictions.csv'))
    logger.info(f"Volume prediction csv saved at {os.path.join(out_path, 'volume_predictions.csv')}")
    # torch.device is not JSON-serializable
    cfg.device = str(cfg.device)
    with open(os.path.join(out_path, 'config.json'), 'w') as f:
        json.dump(cfg, f)
    logger.info(f"Config file saved at {os.path.join(out_path, 'config.json')}")
def main(video_path, output_path):
    """
    Read the equation shown on the video at `video_path` by tracking the robot.

    Each time the tracked bounding box overlaps (frac=1.0) a detected digit/operator after a period without
    overlap, that element's value is appended to the equation. Once the equation ends with '=', the part before
    it is evaluated with numexpr and the result is appended. The video with the written equation and the robot
    trajectory is saved at `output_path` (played at twice the input fps).
    """
    # The reader is used as a context manager so the underlying file/decoder is released.
    with imageio.get_reader(video_path) as video:
        meta = video.get_meta_data()
        fps = meta['fps']
        # approximate number of frames, only used for the progress bar
        n_frames = int(meta['duration'] * fps)

        # Read 1st frame and analyse the environment (classify operator and digits)
        frame1 = video.get_data(0)
        detector = Detector(frame1, '../models/Digit_model.pickle', '../models/Operators_model.pickle',
                            '../models/KMeans_centers.json')
        eq_element_list = detector.analyse_frame(verbose=True)

        # initialize equation, output-video frames and tracker
        equation = ''
        output_frames = []
        tracker = Tracker()

        print('>>> Equation reading with arrow tracking.')
        is_free = True  # whether the arrow left the previous element (absence of overlap since the last read)
        unsolved = True  # whether the equation has been solved yet
        for i, frame in enumerate(video):
            # get position <- track robot position
            tracker.track(frame)
            trajectory = tracker.position_list
            arrow_bbox = tracker.bbox
            # check if bbox overlap with any digit/operator
            overlap_elem_list = [elem for elem in eq_element_list if elem.has_overlap(arrow_bbox, frac=1.0)]
            if overlap_elem_list:
                # append character to equation string only once per visit
                if is_free:
                    equation += overlap_elem_list[0].value
                    is_free = False
            else:
                # reset the is_free if no element overlapped
                is_free = True
            # solve expression once '=' is detected
            if unsolved and equation.endswith('='):
                results = numexpr.evaluate(equation[:-1]).item()
                equation += str(results)
                unsolved = False
            # draw track and equation on output frame
            output_frames.append(draw_output_frame(frame, trajectory, equation, eq_elem_list=eq_element_list))
            print_progessbar(i, n_frames, Name='Frame', Size=40, erase=False)

    # check if equation has been solved.
    if unsolved:
        print('>>> WARNING : The equation could not be solved!')
    else:
        print('>>> Successfully read and solve the equation.')

    # save output-video at output_path
    save_video(output_path, output_frames, fps=fps * 2)
    print(f'>>> Output video saved at {output_path}.')
def train(self, dataset, checkpoint_path=None, valid_dataset=None):
    """
    Train the FCDD network `self.net` with the `HSCLoss` on the given dataset.
    ----------
    INPUT
        |---- dataset (torch.utils.data.Dataset) training dataset returning (image, label, index).
        |---- checkpoint_path (str) file of a checkpoint to resume from and to save to every
        |           `self.checkpoint_freq` epochs. If None, nothing is loaded nor saved.
        |---- valid_dataset (torch.utils.data.Dataset) optional dataset passed to self.validate after each epoch.
    OUTPUT
        |---- None. Training time and [epoch, train loss, valid loss, valid AUC] evolution are stored in
        |           self.outputs['train'].
    """
    logger = logging.getLogger()
    # make dataloader
    loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers,
                                         shuffle=True, worker_init_fn=lambda _: np.random.seed())
    # make optimizer, scheduler and loss function
    optimizer = optim.Adam(self.net.parameters(), lr=self.lr, weight_decay=self.weight_decay)
    scheduler = self.lr_scheduler(optimizer, **self.lr_scheduler_kwargs)
    loss_fn = HSCLoss(reduction='mean')

    # Load checkpoint if present. torch.load(None) does not raise FileNotFoundError, so the
    # `checkpoint_path=None` case is handled explicitly.
    checkpoint = None
    if checkpoint_path is not None:
        try:
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
        except FileNotFoundError:
            checkpoint = None
    if checkpoint is not None:
        n_epoch_finished = checkpoint['n_epoch_finished']
        self.net.load_state_dict(checkpoint['net_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        scheduler.load_state_dict(checkpoint['lr_state'])
        epoch_loss_list = checkpoint['loss_evolution']
        logger.info(f'Checkpoint loaded with {n_epoch_finished} epoch finished.')
    else:
        logger.info('No Checkpoint found. Training from beginning.')
        n_epoch_finished = 0
        epoch_loss_list = []  # Placeholder for epoch evolution

    logger.info('Start training FCDD.')
    start_time = time.time()
    n_batch = len(loader)
    # train loop
    for epoch in range(n_epoch_finished, self.n_epoch):
        self.net.train()
        epoch_loss = 0.0
        epoch_start_time = time.time()
        for b, data in enumerate(loader):
            im, label, _ = data
            im = im.to(self.device).float().requires_grad_(True)
            label = label.to(self.device)

            optimizer.zero_grad()
            feat_map = self.net(im)
            loss = loss_fn(feat_map, label)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            if self.print_progress:
                print_progessbar(b, n_batch, Name='Train Batch', Size=100, erase=True)

        # validate
        valid_loss, valid_auc = 0.0, 0.0
        if valid_dataset:
            valid_loss, valid_auc = self.validate(valid_dataset)

        # print epoch summary
        logger.info(f"| Epoch {epoch+1:03}/{self.n_epoch:03} "
                    f"| Time {timedelta(seconds=int(time.time() - epoch_start_time))} "
                    f"| Loss {epoch_loss/n_batch:.5f} "
                    f"| Valid Loss {valid_loss:.5f} | Valid AUC {valid_auc:.2%} "
                    f"| lr {scheduler.get_last_lr()[0]:.6f} |")
        epoch_loss_list.append([epoch + 1, epoch_loss / n_batch, valid_loss, valid_auc])

        # update lr
        scheduler.step()

        # save checkpoint
        if (epoch + 1) % self.checkpoint_freq == 0 and checkpoint_path is not None:
            checkpoint = {
                'n_epoch_finished': epoch + 1,
                'net_state': self.net.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'lr_state': scheduler.state_dict(),
                'loss_evolution': epoch_loss_list
            }
            torch.save(checkpoint, checkpoint_path)
            logger.info('\tCheckpoint saved.')

    self.outputs['train']['time'] = time.time() - start_time
    self.outputs['train']['evolution'] = {
        'col_name': ['Epoch', 'Train_Loss', 'Valid_Loss', 'Valid_AUC'],
        'data': epoch_loss_list
    }
    logger.info(f"Finished training FCDD in {timedelta(seconds=int(self.outputs['train']['time']))}")
def validate(self, dataset, save_path=None, epoch=0):
    """
    Validate the generator inpainting capabilities on a samll sample of data.
    ----------
    INPUT
        |---- dataset (torch.utils.data.Dataset) dataset returning a small sample of fixed pairs (image, mask, idx)
        |           on which to validate the GAN over training. It must return an image tensor of dimension
        |           [C, H, W], a mask tensor of dimension [1, H, W] and a tensor of index of dimension [1].
        |---- save_path (str) path to directory where to save the inpaint results of the valida_data as .png. Each
        |           image is saved as save_path/valid_imY_epXXX.png (and valid_imY_coarse_epXXX.png for the coarse
        |           output) where Y is the image index and XXX is the epoch. If None, nothing is saved.
        |---- epoch (int) the current epoch number.
    OUTPUT
        |---- l1_loss (float) the mean Discounted L1Loss over the validation batches.
    """
    # remember the generator mode so validation does not leave it in eval mode
    was_training = self.generator.training
    with torch.no_grad():
        # make loader
        valid_loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size,
                                                   num_workers=self.num_workers, shuffle=False,
                                                   worker_init_fn=lambda _: np.random.seed())
        n_batch = len(valid_loader)
        l1_loss_fn = DiscountedL1(gamma=self.gammaL1, reduction='mean', device=self.device)

        self.generator.eval()
        try:
            # validate data by batch
            l1_loss = 0.0
            for b, data in enumerate(valid_loader):
                im_v, mask_v, idx = data
                im_v = im_v.to(self.device).float()
                mask_v = mask_v.to(self.device).float()
                idx = idx.cpu().numpy()

                # inpaint (refined and coarse outputs)
                im_inpaint, coarse = self.generator(im_v, mask_v)
                # recover non-masked regions from the input image
                im_inpaint = im_v * (1 - mask_v) + im_inpaint * mask_v
                coarse = im_v * (1 - mask_v) + coarse * mask_v
                # compute L1 loss
                l1_loss += l1_loss_fn(im_inpaint, im_v, mask_v).item()

                # save results (check_contrast=False as for every other imsave of the project)
                if save_path:
                    for i in range(im_inpaint.shape[0]):
                        arr = im_inpaint[i].permute(1, 2, 0).squeeze().cpu().numpy()
                        io.imsave(os.path.join(save_path, f'valid_im{idx[i]}_ep{epoch}.png'),
                                  img_as_ubyte(arr), check_contrast=False)
                        arr = coarse[i].permute(1, 2, 0).squeeze().cpu().numpy()
                        io.imsave(os.path.join(save_path, f'valid_im{idx[i]}_coarse_ep{epoch}.png'),
                                  img_as_ubyte(arr), check_contrast=False)

                print_progessbar(b, n_batch, Name='Valid Batch', Size=40, erase=True)
        finally:
            self.generator.train(was_training)

    return l1_loss / n_batch
def main(input_data_path, output_data_path, window):
    """
    Convert the Volumetric CT data and mask (in NIfTI format) to a dataset of 2D images in tif and masks in bitmap.
    ----------
    INPUT
        |---- input_data_path (str) input folder (with trailing separator) containing
        |           hemorrhage_diagnosis_raw_ct.csv, Patient_demographics.csv, ct_scans/ and masks/.
        |---- output_data_path (str) output folder (with trailing separator). Slices are written to
        |           Patient_CT/<patient>/<slice>.tif and positive masks to Patient_CT/<patient>/<slice>_ICH_Seg.bmp.
        |---- window (tuple or None) (center, width) CT window applied before saving; None disables windowing.
    OUTPUT
        |---- None. Slice information is saved as ct_info.csv and patient information as patient_info.csv.
    """
    # open data info dataframe and replace No_Hemorrhage by Hemorrhage
    info_df = pd.read_csv(input_data_path + 'hemorrhage_diagnosis_raw_ct.csv')
    info_df['Hemorrhage'] = 1 - info_df.No_Hemorrhage
    info_df.drop(columns='No_Hemorrhage', inplace=True)

    # open patient info dataframe
    patient_df = pd.read_csv(input_data_path + 'Patient_demographics.csv', header=1, skipfooter=2, engine='python') \
                   .rename(columns={'Unnamed: 0': 'PatientNumber', 'Unnamed: 1': 'Age', 'Unnamed: 2': 'Gender',
                                    'Unnamed: 8': 'Fracture', 'Unnamed: 9': 'Note'})
    indicator_cols = patient_df.columns[3:9]
    patient_df[indicator_cols] = patient_df[indicator_cols].fillna(0).astype(int)
    # add column Hemorrhage (any ICH type, columns 3 to 7)
    patient_df['Hemorrhage'] = patient_df[patient_df.columns[3:8]].max(axis=1)

    # make output directories
    os.makedirs(output_data_path + 'Patient_CT/', exist_ok=True)

    # iterate over volume to extract data
    patient_ids = info_df.PatientNumber.unique()
    n_patient = len(patient_ids)
    output_info = []
    for n, pid in enumerate(patient_ids):
        # read nii volumes as arrays
        ct_vol = nib.load(input_data_path + f'ct_scans/{pid:03}.nii').get_fdata()
        mask_vol = skimage.img_as_bool(nib.load(input_data_path + f'masks/{pid:03}.nii').get_fdata())

        # rotate 90° counter clockwise for head pointing upward
        ct_vol = np.rot90(ct_vol, axes=(0, 1))
        mask_vol = np.rot90(mask_vol, axes=(0, 1))

        # window the ct volume to get better contrast of soft tissues
        if window is not None:
            ct_vol = window_ct(ct_vol, win_center=window[0], win_width=window[1], out_range=(0, 1))

        if mask_vol.shape != ct_vol.shape:
            print(f'>>> Warning! The ct volume of patient {pid} does not have '
                  f'the same dimension as the ground truth. CT ({ct_vol.shape}) vs Mask ({mask_vol.shape})')

        # make patient directory
        os.makedirs(output_data_path + f'Patient_CT/{pid:03}/', exist_ok=True)

        # iterate over slices (last axis) to save them; slice numbers are 1-based in file names
        n_slice = ct_vol.shape[2]
        for slice_idx in range(n_slice):
            ct_slice = ct_vol[:, :, slice_idx]
            mask_slice = mask_vol[:, :, slice_idx]
            # save CT slice
            ct_slice_fn = f'Patient_CT/{pid:03}/{slice_idx+1}.tif'
            skimage.io.imsave(output_data_path + ct_slice_fn, ct_slice, check_contrast=False)
            is_low = bool(skimage.exposure.is_low_contrast(ct_slice))
            # save mask only if some positive ICH
            if np.any(mask_slice):
                mask_slice_fn = f'Patient_CT/{pid:03}/{slice_idx+1}_ICH_Seg.bmp'
                skimage.io.imsave(output_data_path + mask_slice_fn, skimage.img_as_ubyte(mask_slice),
                                  check_contrast=False)
            else:
                mask_slice_fn = 'None'
            # add info to output list
            output_info.append({'PatientNumber': pid, 'SliceNumber': slice_idx + 1, 'CT_fn': ct_slice_fn,
                                'mask_fn': mask_slice_fn, 'low_contrast_CT': is_low})

            print_progessbar(slice_idx, n_slice, Name=f'Patient {pid:03} {n+1:03}/{n_patient:03}',
                             Size=20, erase=False)

    # Merge outputs with input info (only slices present in both are kept) and save
    output_info_df = pd.DataFrame(output_info)
    info_df = pd.merge(info_df, output_info_df, how='inner', on=['PatientNumber', 'SliceNumber'])
    info_df.to_csv(output_data_path + 'ct_info.csv')
    print('>>> Data informations saved at ' + output_data_path + 'ct_info.csv')
    # save patient df
    patient_df.to_csv(output_data_path + 'patient_info.csv')
    print('>>> Patient informations saved at ' + output_data_path + 'patient_info.csv')
def train(self, dataset, checkpoint_path=None, valid_dataset=None, valid_path=None, valid_freq=5):
    """
    Train the auto-encoder `self.ae` with a L1 + L2 + lambda_GDL * GDL reconstruction loss.
    ----------
    INPUT
        |---- dataset (torch.utils.data.Dataset) training dataset returning a pair whose first element is the
        |           image (the second one is ignored).
        |---- checkpoint_path (str) file of a checkpoint to resume from and to save to every
        |           `self.checkpoint_freq` epochs. If None, nothing is loaded nor saved.
        |---- valid_dataset (torch.utils.data.Dataset) optional dataset passed to self.validate after each epoch.
        |---- valid_path (str) folder where self.validate saves reconstructions every `valid_freq` epochs.
        |---- valid_freq (int) epoch frequency at which validation reconstructions are saved.
    OUTPUT
        |---- None. Training time and per-epoch [epoch, total, L1, L2, GDL] losses are stored in
        |           self.outputs['train'].
    """
    logger = logging.getLogger()
    # make dataloader
    loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers,
                                         shuffle=True, worker_init_fn=lambda _: np.random.seed())
    # make optimizer and scheduler
    optimizer = optim.Adam(self.ae.parameters(), lr=self.lr, weight_decay=self.weight_decay)
    scheduler = self.lr_scheduler(optimizer, **self.lr_scheduler_kwargs)
    # make loss functions
    gdl_fn = GDL(reduction='mean', device=self.device)
    mae_fn = nn.L1Loss(reduction='mean')
    mse_fn = nn.MSELoss(reduction='mean')

    # Load checkpoint if present. torch.load(None) does not raise FileNotFoundError, so the
    # `checkpoint_path=None` case is handled explicitly.
    checkpoint = None
    if checkpoint_path is not None:
        try:
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
        except FileNotFoundError:
            checkpoint = None
    if checkpoint is not None:
        n_epoch_finished = checkpoint['n_epoch_finished']
        self.ae.load_state_dict(checkpoint['net_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        scheduler.load_state_dict(checkpoint['lr_state'])
        epoch_loss_list = checkpoint['loss_evolution']
        logger.info(f'Checkpoint loaded with {n_epoch_finished} epoch finished.')
    else:
        logger.info('No Checkpoint found. Training from beginning.')
        n_epoch_finished = 0
        epoch_loss_list = []  # Placeholder for epoch evolution

    # The lambda_GDL schedule (keys = str(epoch index)) is not stored in the checkpoint: when resuming, re-apply
    # the updates of the already finished epochs so training continues with the scheduled weight.
    replayed = False
    for past_epoch in range(n_epoch_finished):
        if str(past_epoch) in self.ep_GDL:
            self.lambda_GDL = self.ep_GDL[str(past_epoch)]
            replayed = True
    if replayed:
        logger.info(f"Lambda GLD set to {self.lambda_GDL}.")

    logger.info('Start training the inpainting AE.')
    start_time = time.time()
    n_batch = len(loader)
    # train loop
    for epoch in range(n_epoch_finished, self.n_epoch):
        self.ae.train()
        epoch_loss, epoch_loss_l1, epoch_loss_l2, epoch_loss_gdl = 0.0, 0.0, 0.0, 0.0
        epoch_start_time = time.time()
        # update Lambda GDL according to the schedule
        if str(epoch) in self.ep_GDL:
            self.lambda_GDL = self.ep_GDL[str(epoch)]
            logger.info(f"Lambda GLD set to {self.lambda_GDL}.")

        for b, data in enumerate(loader):
            im, _ = data
            im = im.to(self.device).float().requires_grad_(True)

            optimizer.zero_grad()
            rec = self.ae(im)
            loss_l1 = mae_fn(rec, im)
            loss_l2 = mse_fn(rec, im)
            loss_gdl = self.lambda_GDL * gdl_fn(im, rec)
            loss = loss_l1 + loss_l2 + loss_gdl
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            epoch_loss_l1 += loss_l1.item()
            epoch_loss_l2 += loss_l2.item()
            epoch_loss_gdl += loss_gdl.item()
            if self.print_progress:
                print_progessbar(b, n_batch, Name='Train Batch', Size=100, erase=True)

        # validate (reconstructions are only saved every valid_freq epochs)
        valid_loss = 0.0
        if valid_dataset:
            save_path = valid_path if (epoch + 1) % valid_freq == 0 else None
            valid_loss = self.validate(valid_dataset, save_path=save_path, prefix=f'_ep{epoch+1}')

        # print epoch summary
        logger.info(f"| Epoch {epoch+1:03}/{self.n_epoch:03} "
                    f"| Time {timedelta(seconds=int(time.time() - epoch_start_time))} "
                    f"| Loss (L1 + L2 + GDL) {epoch_loss/n_batch:.5f} = {epoch_loss_l1/n_batch:.5f} + {epoch_loss_l2/n_batch:.5f} + {epoch_loss_gdl/n_batch:.5f} "
                    f"| Valid Loss {valid_loss:.5f} | lr {scheduler.get_last_lr()[0]:.6f} |")
        epoch_loss_list.append([epoch + 1, epoch_loss / n_batch, epoch_loss_l1 / n_batch,
                                epoch_loss_l2 / n_batch, epoch_loss_gdl / n_batch])

        # update lr
        scheduler.step()

        # save checkpoint
        if (epoch + 1) % self.checkpoint_freq == 0 and checkpoint_path is not None:
            checkpoint = {
                'n_epoch_finished': epoch + 1,
                'net_state': self.ae.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'lr_state': scheduler.state_dict(),
                'loss_evolution': epoch_loss_list
            }
            torch.save(checkpoint, checkpoint_path)
            logger.info('\tCheckpoint saved.')

    self.outputs['train']['time'] = time.time() - start_time
    self.outputs['train']['evolution'] = {'col_name': ['Epoch', 'Loss_total', 'L1_loss', 'L2_loss', 'GDL_loss'],
                                          'data': epoch_loss_list}
    logger.info(f"Finished training inpainter AE in {timedelta(seconds=int(self.outputs['train']['time']))}")
def train(self, net, dataset, valid_dataset=None, checkpoint_path=None):
    """
    Train the passed network with the given dataset(s).
    ----------
    INPUT
        |---- net (nn.Module) the network architecture to train. Its raw output (presumably class logits of
        |           shape [B, n_class, ...] — confirm against the configured loss) is passed to the loss.
        |---- dataset (torch.utils.data.Dataset) the dataset to use for training. It must return an input image, a
        |           target binary mask, the patientID.
        |---- valid_dataset (torch.utils.data.Dataset) the optional validation dataset. If provided, the model is
        |           validated at each epoch. It must have the same struture as the train dataset.
        |---- checkpoint_path (str) the filename for a possible checkpoint to start the training from and to save
        |           to every 10 epochs. If None, nothing is loaded nor saved.
    OUTPUT
        |---- net (nn.Module) the trained network.
    """
    logger = logging.getLogger()
    # make the dataloader
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=True,
                                               num_workers=self.num_workers)
    # put net to device
    net = net.to(self.device)
    # define optimizer, then the lr scheduler (it must exist before the checkpoint is read so that its state
    # can be restored; otherwise the lr schedule restarts from scratch on resume)
    optimizer = torch.optim.Adam(net.parameters(), lr=self.lr, weight_decay=self.weight_decay)
    scheduler = self.lr_scheduler(optimizer, **self.lr_scheduler_kwargs)
    # define the loss function
    loss_fn = self.loss_fn(**self.loss_fn_kwargs)

    # Load checkpoint if present. torch.load(None) does not raise FileNotFoundError, so the
    # `checkpoint_path=None` case is handled explicitly.
    n_epoch_finished = 0
    epoch_loss_list = []  # Placeholder for epoch evolution
    checkpoint = None
    if checkpoint_path is not None:
        try:
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
        except FileNotFoundError:
            checkpoint = None
    if checkpoint is not None:
        n_epoch_finished = checkpoint['n_epoch_finished']
        net.load_state_dict(checkpoint['net_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        # older checkpoints did not store the scheduler state nor the loss evolution
        if 'lr_state' in checkpoint:
            scheduler.load_state_dict(checkpoint['lr_state'])
        epoch_loss_list = checkpoint.get('loss_evolution', [])
        logger.info(f'Checkpoint loaded with {n_epoch_finished} epoch finished.')
    else:
        logger.info('No Checkpoint found. Training from begining.')

    # start training
    logger.info('Start training the UNet.')
    start_time = time.time()
    n_batch = len(train_loader)

    for epoch in range(n_epoch_finished, self.n_epoch):
        net.train()
        epoch_loss = 0.0
        epoch_start_time = time.time()

        for b, data in enumerate(train_loader):
            # get data and put data tensors on device
            input, target, _ = data
            input = input.to(self.device).float().requires_grad_(True)
            target = target.to(self.device)
            # zero the networks' gradients
            optimizer.zero_grad()
            # The loss is computed on the raw network output: argmax is not differentiable, so a loss on
            # argmax predictions cannot backpropagate to the weights.
            loss = loss_fn(net(input), target)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            # print process
            if self.print_progress:
                print_progessbar(b, n_batch, Name='\t\tTrain Batch', Size=40, erase=True)

        # Get validation performance if required
        valid_dice, dice, IoU = '', None, None
        if valid_dataset:
            dice, IoU = self.evaluate(net, valid_dataset, return_score=True, print_to_logger=False,
                                      save_path=None)
            valid_dice = f'| Valid Dice: {dice:.3%} | Valid IoU: {IoU:.3%} '

        # log the epoch statistics
        logger.info(f'\t| Epoch: {epoch + 1:03}/{self.n_epoch:03} '
                    f'| Train time: {time.time() - epoch_start_time:.3f} [s] '
                    f'| Train Loss: {epoch_loss / n_batch:.6f} '
                    f'{valid_dice}'
                    f'| lr: {scheduler.get_last_lr()[0]:g} |')
        # Store epoch loss and epoch dice
        epoch_loss_list.append([epoch + 1, epoch_loss / n_batch, dice])

        # update scheduler
        scheduler.step()

        # Save Checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0 and checkpoint_path:
            checkpoint = {'n_epoch_finished': epoch + 1,
                          'net_state': net.state_dict(),
                          'optimizer_state': optimizer.state_dict(),
                          'lr_state': scheduler.state_dict(),
                          'loss_evolution': epoch_loss_list}
            torch.save(checkpoint, checkpoint_path)
            logger.info('\tCheckpoint saved.')

    # End training
    self.train_time = time.time() - start_time
    self.train_evolution = epoch_loss_list
    logger.info(f'Finished training UNet in {self.train_time:.3f} [s].')

    return net


def evaluate(self, net, dataset, return_score=False, print_to_logger=True, save_path=None):
    """
    Evaluate the network with the given dataset. The evaluation score is given for the 3D prediction.
    ----------
    INPUT
        |---- net (nn.Module) the network architecture to evaluate. The predicted class is the argmax over dim 1
        |           of its output.
        |---- dataset (torch.utils.data.Dataset) the dataset to use for evaluation. It must return an input image,
        |           a binary target mask (values 0/1), the patientID.
        |---- return_score (bool) whether to return the mean Dice and mean IoU scores of 3D segmentation (for
        |           the 2D case the Dice is computed on the concatenation of prediction for a patient).
        |---- print_to_logger (bool) whether to print information to the logger.
        |---- save_path (str) the folder path where to save segemntation map (as NIfTI <save_path>/<patientID>.nii)
        |           and the preformance dataframe. If not provided (i.e. None) nothing is saved.
        |           NOTE(review): the file name only depends on the patient ID, so several 2D samples of the same
        |           patient overwrite each other — confirm this is intended.
    OUTPUT
        |---- (Dice) (float) the average Dice coefficient for the 3D segemntation.
        |---- (IoU) (float) the average IoU coefficient for the 3D segementation.
    """
    logger = logging.getLogger() if print_to_logger else None
    # make dataloader
    loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=False,
                                         num_workers=self.num_workers)
    # put net on device
    net = net.to(self.device)

    # Evaluation
    if print_to_logger:
        logger.info('Start evaluating the UNet.')
    start_time = time.time()
    id_pred = []  # Placeholder for each sample's confusion-matrix counts
    net.eval()
    with torch.no_grad():
        for b, data in enumerate(loader):
            # get data on device
            input, target, pID = data
            input = input.to(self.device).float()
            target = target.to(self.device)
            # make prediction
            pred = net(input).argmax(dim=1)

            # confusion matrix of each sample of the batch
            for pid, target_samp, pred_samp in zip(pID, target, pred):
                # numeric IDs are collated into a tensor: use plain Python values so that grouping by patient
                # and file naming work (tensors hash by identity, not by value)
                if isinstance(pid, torch.Tensor):
                    pid = pid.item()
                # Counts restricted to the labels {0, 1}; unlike sklearn's confusion_matrix without `labels`,
                # this always yields 4 values, even for slices where target and prediction are all 0.
                t = target_samp.reshape(-1)
                p = pred_samp.reshape(-1)
                tp = int(((t == 1) & (p == 1)).sum().item())
                tn = int(((t == 0) & (p == 0)).sum().item())
                fp = int(((t == 0) & (p == 1)).sum().item())
                fn = int(((t == 1) & (p == 0)).sum().item())

                # save prediction if required
                if save_path:
                    img = nib.Nifti1Image(pred_samp.cpu().numpy().astype(bool), np.eye(4))
                    pred_path = f'{save_path}/{pid}.nii'
                    nib.save(img, pred_path)
                else:
                    pred_path = 'None'

                # add to list
                id_pred.append({'PatientID': pid, 'TP': tp, 'TN': tn, 'FP': fp, 'FN': fn,
                                'pred_fn': pred_path})

            if self.print_progress:
                print_progessbar(b, len(loader), Name='\t\tEvaluation Batch', Size=40, erase=True)

    # make DataFrame from id_pred to compute Dice score per sample and per volume
    result_df = pd.DataFrame(id_pred)
    # compute Dice & Jaccard (IoU) per sample (1e-9 smoothing: empty target & prediction scores 1) + save
    result_df['Dice'] = (2 * result_df.TP + 1e-9) / (2 * result_df.TP + result_df.FP + result_df.FN + 1e-9)
    result_df['IoU'] = (result_df.TP + 1e-9) / (result_df.TP + result_df.FP + result_df.FN + 1e-9)
    if save_path:
        result_df.to_csv(f'{save_path}/prediction_scores.csv')

    # aggregate by patient TP/TN/FP/FN (sum) + recompute 3D Dice & Jaccard then take mean
    result_3D_df = result_df[['PatientID', 'TP', 'TN', 'FP', 'FN']].groupby('PatientID').sum()
    result_3D_df['Dice'] = (2 * result_3D_df.TP + 1e-9) / (2 * result_3D_df.TP + result_3D_df.FP + result_3D_df.FN + 1e-9)
    result_3D_df['IoU'] = (result_3D_df.TP + 1e-9) / (result_3D_df.TP + result_3D_df.FP + result_3D_df.FN + 1e-9)
    avg_results = result_3D_df[['Dice', 'IoU']].mean(axis=0)

    self.eval_time = time.time() - start_time
    self.eval_dice = avg_results.Dice
    self.eval_IoU = avg_results.IoU

    if print_to_logger:
        logger.info(f'Evaluation time: {self.eval_time} [s].')
        logger.info(f'Evaluation Dice: {self.eval_dice:.3%}.')
        logger.info(f'Evaluation IoU: {self.eval_IoU:.3%}.')
        logger.info('Finished evaluating the UNet.')

    if return_score:
        return avg_results.Dice, avg_results.IoU
def localize_anomalies(self, dataset, save_path=None, reception=True, std=None, cpu=True, q_min=0.025, q_max=0.975):
    """
    Generate the anomaly heat-map of every image in `dataset` and optionally save them.

    The heat-maps are rescaled to [0, 1] with the (min, max) returned by self.get_min_max on the whole dataset
    (q_min / q_max are forwarded to it). When `save_path` is given, each image is saved side by side with its
    heat-map as <save_path>/heatmap_<idx>_<label>.png. The dataset must return (image, label, index).
    """
    with torch.no_grad():
        # make loader
        loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers,
                                             shuffle=False, worker_init_fn=lambda _: np.random.seed())
        self.net.eval()
        min_val, max_val = self.get_min_max(loader, reception=reception, std=std, cpu=cpu, q_min=q_min,
                                            q_max=q_max)
        # computing and saving heatmaps
        for b, data in enumerate(loader):
            im, label, idx = data
            im = im.to(self.device).float()
            label = label.to(self.device)

            heatmap = self.generate_heatmap(im, reception=reception, std=std, cpu=cpu)
            # scaling
            heatmap = ((heatmap - min_val) / (max_val - min_val)).clamp(0, 1)

            if save_path:
                # repeat the heat-map to the image channel count once per batch (not once per sample)
                heatmap_rep = heatmap.repeat(1, im.shape[1], 1, 1)
                for i in range(im.shape[0]):
                    # channel-last so that multi-channel images are concatenated side by side (width axis);
                    # identical to the previous behaviour for single-channel images
                    arr = np.concatenate([im[i].permute(1, 2, 0).squeeze().cpu().numpy(),
                                          heatmap_rep[i].permute(1, 2, 0).squeeze().cpu().numpy()], axis=1)
                    io.imsave(os.path.join(save_path, f'heatmap_{idx[i]}_{label[i].item()}.png'),
                              img_as_ubyte(arr), check_contrast=False)

            if self.print_progress:
                print_progessbar(b, len(loader), Name='Heatmap Generation Batch', Size=100, erase=True)
def train(self, dataset, valid_dataset=None, checkpoint_path=None):
    """
    Train the passed network on the given dataset for the binary classification task.
    ----------
    INPUT
        |---- dataset (torch.utils.data.Dataset) the dataset to use for training. It should output the image,
        |           the binary label, and the sample index.
        |---- valid_dataset (torch.utils.data.Dataset) the dataset to use for validation at each epoch. It should
        |           output the image, the binary label, and the sample index.
        |---- checkpoint_path (str) the filename for a possible checkpoint to start the training from and to save
        |           to after every epoch. If None, no checkpoint is loaded and the network's weights are not saved
        |           regularily during training.
    OUTPUT
        |---- None. Training time and [epoch, loss, valid AUC] evolution are stored in self.outputs['train'].
    """
    logger = logging.getLogger()
    # make dataloader
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=True,
                                               num_workers=self.num_workers, pin_memory=False,
                                               worker_init_fn=lambda _: np.random.seed())
    # put net on device
    self.net = self.net.to(self.device)
    # define optimizer, lr scheduler and loss function
    optimizer = optim.Adam(self.net.parameters(), lr=self.lr, weight_decay=self.weight_decay)
    scheduler = self.lr_scheduler(optimizer, **self.lr_scheduler_kwargs)
    loss_fn = self.loss_fn(**self.loss_fn_kwargs)

    # load checkpoint if any. torch.load(None) does not raise FileNotFoundError, so the documented
    # `checkpoint_path=None` case is handled explicitly.
    checkpoint = None
    if checkpoint_path is not None:
        try:
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
        except FileNotFoundError:
            checkpoint = None
    if checkpoint is not None:
        n_epoch_finished = checkpoint['n_epoch_finished']
        self.net.load_state_dict(checkpoint['net_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        scheduler.load_state_dict(checkpoint['lr_state'])
        epoch_loss_list = checkpoint['loss_evolution']
        logger.info(f'Checkpoint loaded with {n_epoch_finished} epoch finished.')
    else:
        logger.info('No Checkpoint found. Training from beginning.')
        n_epoch_finished = 0
        epoch_loss_list = []

    # Start Training
    logger.info('Start training the Binary Classification task.')
    start_time = time.time()
    n_batch = len(train_loader)

    for epoch in range(n_epoch_finished, self.n_epoch):
        self.net.train()
        epoch_loss = 0.0
        epoch_start_time = time.time()

        for b, data in enumerate(train_loader):
            # get data (the sample index is unused) and put it on device
            input, label, _ = data
            input = input.to(self.device).float().requires_grad_(True)
            label = label.to(self.device).long()
            # zero the networks' gradients
            optimizer.zero_grad()
            # NOTE(review): the loss receives softmax probabilities; if loss_fn is nn.CrossEntropyLoss it applies
            # log-softmax again (double softmax), and evaluate() scores outputs with a sigmoid — confirm intended.
            pred = self.net(input)
            pred = nn.functional.softmax(pred, dim=1)
            loss = loss_fn(pred, label)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            # print progress
            if self.print_progress:
                print_progessbar(b, n_batch, Name='\t\tTrain Batch', Size=50, erase=True)

        auc = None
        if valid_dataset:
            auc, acc, recall, precision, f1 = self.evaluate(valid_dataset, save_tsne=False, return_scores=True)
            valid_summary = f'| AUC {auc:.3%} | Accuracy {acc:.3%} | Recall {recall:.3%} | Precision {precision:.3%} | F1 {f1:.3%} '
        else:
            valid_summary = ''

        # Print epoch statistics
        logger.info(f'\t| Epoch : {epoch + 1:03}/{self.n_epoch:03} '
                    f'| Train time: {timedelta(seconds=int(time.time() - epoch_start_time))} '
                    f'| Train Loss: {epoch_loss / n_batch:.6f} '
                    f'{valid_summary}'
                    f'| lr: {scheduler.get_last_lr()[0]:.7f} |')
        # Store epoch loss
        epoch_loss_list.append([epoch + 1, epoch_loss / n_batch, auc])
        # update lr
        scheduler.step()
        # Save Checkpoint every epoch
        if checkpoint_path:
            checkpoint = {
                'n_epoch_finished': epoch + 1,
                'net_state': self.net.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'lr_state': scheduler.state_dict(),
                'loss_evolution': epoch_loss_list
            }
            torch.save(checkpoint, checkpoint_path)
            logger.info('\tCheckpoint saved.')

    # End training
    self.outputs['train']['time'] = time.time() - start_time
    self.outputs['train']['evolution'] = {
        'col_name': ['Epoch', 'loss', 'Valid_AUC'],
        'data': epoch_loss_list
    }
    logger.info(f"Finished training inpainter Binary Classifier in {timedelta(seconds=int(self.outputs['train']['time']))}")
def evaluate(self, dataset, save_tsne=False, return_scores=False):
    """
    Evaluate the passed network on the given dataset for the classification task.
    ----------
    INPUT
        |---- dataset (torch.utils.data.Dataset) the dataset to use for evaluation. Each item should be a tuple
        |           (input image, label, sample index).
        |---- save_tsne (bool) whether to compute and store in self.outputs['eval']['repr'] the t-SNE embedding of
        |           the network's bottleneck representation (requested through `self.net.return_bottleneck`).
        |---- return_scores (bool) whether to return the measured ROC AUC, accuracy, recall, precision and f1-score.
    OUTPUT
        |---- (auc) (float) the ROC AUC on the dataset.
        |---- (acc) (float) the accuracy on the flattened labels / predictions.
        |---- (recall) (float) the recall on the dataset.
        |---- (precision) (float) the precision on the dataset.
        |---- (f1) (float) the f1-score on the dataset.
        |     The five scores are returned only if return_scores is True (otherwise None). The accuracy computed
        |     on the non-flattened arrays is stored in self.outputs['eval']['subset_acc'] but not returned.
    RAISES
        |---- ValueError if the dataset is empty (no metric can be computed).
    """
    logger = logging.getLogger()
    # make loader
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=self.batch_size, shuffle=False,
        num_workers=self.num_workers,
        worker_init_fn=lambda _: np.random.seed())
    # put net on device
    self.net = self.net.to(self.device)

    start_time = time.time()
    idx_repr = []  # (sample index, flattened bottleneck features) — filled only when save_tsne
    idx_label_pred = []  # (sample index, label, thresholded prediction, sigmoid score)
    n_batch = len(loader)
    self.net.eval()

    # When save_tsne, the network is asked to also return its bottleneck; the flag is always
    # restored in `finally` so an interrupted evaluation does not leave the network altered.
    if save_tsne:
        self.net.return_bottleneck = True
    try:
        with torch.no_grad():
            for b, data in enumerate(loader):
                # get data : load in standard way (no patch swaped image)
                img, label, idx = data
                img = img.to(self.device).float()

                if save_tsne:
                    logits, feat = self.net(img)
                    # add ravelled representations to placeholder
                    idx_repr += list(zip(idx.tolist(),
                                         feat.reshape(feat.shape[0], -1).cpu().tolist()))
                else:
                    logits = self.net(img)

                # sigmoid score per output unit, thresholded at 0.5 (prediction keeps the score dtype)
                pred_score = torch.sigmoid(logits)
                pred = (pred_score > 0.5).to(pred_score.dtype)

                idx_label_pred += list(zip(idx.tolist(),
                                           label.float().tolist(),
                                           pred.cpu().tolist(),
                                           pred_score.cpu().tolist()))

                if self.print_progress:
                    print_progessbar(b, n_batch, Name='\t\tEvaluation Batch', Size=50, erase=True)
    finally:
        if save_tsne:
            self.net.return_bottleneck = False

    if not idx_label_pred:
        raise ValueError('Cannot evaluate the network on an empty dataset.')

    # compute tSNE for representation
    if save_tsne:
        tsne_idx, feats = zip(*idx_repr)
        feats = np.array(feats)
        logger.info('Computing the t-SNE representation.')
        repr_2D = TSNE(n_components=2).fit_transform(feats)
        self.outputs['eval']['repr'] = list(zip(tsne_idx, repr_2D.tolist()))
        logger.info('Succesfully computed the t-SNE representation.')

    # Compute scores
    _, label, pred, pred_score = zip(*idx_label_pred)
    label, pred, pred_score = np.array(label), np.array(pred), np.array(pred_score)

    auc = roc_auc_score(label, pred_score, average=self.score_average)
    acc = accuracy_score(label.ravel(), pred.ravel())
    # accuracy on the non-flattened arrays (subset accuracy for 2-D label arrays)
    sub_acc = accuracy_score(label, pred)
    recall = recall_score(label, pred, average=self.score_average)
    precision = precision_score(label, pred, average=self.score_average)
    f1 = f1_score(label, pred, average=self.score_average)

    self.outputs['eval']['auc'] = auc
    self.outputs['eval']['acc'] = acc
    self.outputs['eval']['subset_acc'] = sub_acc
    self.outputs['eval']['recall'] = recall
    self.outputs['eval']['precision'] = precision
    self.outputs['eval']['f1'] = f1
    self.outputs['eval']['pred'] = idx_label_pred
    # finish evaluation
    self.outputs['eval']['time'] = time.time() - start_time

    if return_scores:
        # Documented 5-tuple; the training loop unpacks exactly these five values.
        return auc, acc, recall, precision, f1
def train(self, dataset, checkpoint_path=None):
    """
    Train the passed network on the given dataset for the Context restoration task.
    ----------
    INPUT
        |---- dataset (torch.utils.data.Dataset) the dataset to use for training. It should output the original image,
        |           the corrupted image (Patch swapped) and the sample index.
        |---- checkpoint_path (str) the filename for a possible checkpoint to start the training from. The checkpoint
        |           is (over)written after every epoch. If None (or empty), no checkpoint is loaded and the
        |           network's weights are not saved during training.
    OUTPUT
        |---- None. The training time and the list of [epoch, mean train loss] are stored in
        |           self.outputs['train']['time'] and self.outputs['train']['evolution'].
    """
    logger = logging.getLogger()
    # make dataloader
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=True,
                                               num_workers=self.num_workers,
                                               worker_init_fn=lambda _: np.random.seed())
    # put net on device
    self.net = self.net.to(self.device)
    # define optimizer
    optimizer = optim.Adam(self.net.parameters(), lr=self.lr, weight_decay=self.weight_decay)
    # define lr scheduler
    scheduler = self.lr_scheduler(optimizer, **self.lr_scheduler_kwargs)
    # define the loss function
    loss_fn = self.loss_fn(**self.loss_fn_kwargs)

    # Load checkpoint if any. torch.load(None) does not raise FileNotFoundError, so the documented
    # checkpoint_path=None default must be checked explicitly rather than relying on the except clause.
    checkpoint = None
    if checkpoint_path:
        try:
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
        except FileNotFoundError:
            checkpoint = None

    if checkpoint is not None:
        n_epoch_finished = checkpoint['n_epoch_finished']
        self.net.load_state_dict(checkpoint['net_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        scheduler.load_state_dict(checkpoint['lr_state'])
        epoch_loss_list = checkpoint['loss_evolution']
        logger.info(f'Checkpoint loaded with {n_epoch_finished} epoch finished.')
    else:
        logger.info('No Checkpoint found. Training from beginning.')
        n_epoch_finished = 0
        epoch_loss_list = []

    # Start Training
    logger.info('Start training the Context Restoration task.')
    start_time = time.time()
    n_batch = len(train_loader)

    for epoch in range(n_epoch_finished, self.n_epoch):
        self.net.train()
        epoch_loss = 0.0
        epoch_start_time = time.time()

        for b, data in enumerate(train_loader):
            # get data : target is original image, corrupted is the network input
            target, corrupted, _ = data
            # put data on device (only the network weights need gradients, not the data)
            target = target.to(self.device).float()
            corrupted = corrupted.to(self.device).float()
            # zero the networks' gradients
            optimizer.zero_grad()
            # optimize the weights with backpropagation on the batch
            rec = self.net(corrupted)
            loss = loss_fn(rec, target)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            # print progress
            if self.print_progress:
                print_progessbar(b, n_batch, Name='\t\tTrain Batch', Size=40, erase=True)

        # Print epoch statistics (lr reported before the scheduler step of this epoch)
        logger.info(f'\t| Epoch : {epoch + 1:03}/{self.n_epoch:03} '
                    f'| Train time: {timedelta(seconds=int(time.time() - epoch_start_time))} '
                    f'| Train Loss: {epoch_loss / n_batch:.6f} '
                    f'| lr: {scheduler.get_last_lr()[0]:.7f} |')
        # Store epoch loss
        epoch_loss_list.append([epoch + 1, epoch_loss / n_batch])
        # update lr
        scheduler.step()
        # Save Checkpoint every epoch
        if checkpoint_path:
            checkpoint = {'n_epoch_finished': epoch + 1,
                          'net_state': self.net.state_dict(),
                          'optimizer_state': optimizer.state_dict(),
                          'lr_state': scheduler.state_dict(),
                          'loss_evolution': epoch_loss_list}
            torch.save(checkpoint, checkpoint_path)
            logger.info('\tCheckpoint saved.')

    # End training
    self.outputs['train']['time'] = time.time() - start_time
    self.outputs['train']['evolution'] = epoch_loss_list
    logger.info(f"Finished training on the context restoration task in {timedelta(seconds=int(self.outputs['train']['time']))}")