def gen_heatmap(self, image_file):
    """Render a per-class heatmap overlay figure for one image.

    Args:
        image_file: str, path to a jpg file.

    Returns:
        prefix_name: str. It is ``'{prob:.4f}_'`` for the class named by
            ``self.prefix``, or ``''`` when ``self.prefix`` is None.
        figure_data: numpy array of the rendered (color) figure, as
            returned by ``fig2data``.

    Raises:
        ValueError: if ``self.prefix`` is not None and is not an entry of
            ``disease_classes`` (raised by ``list.index``).
    """
    image_tensor, image_color = self.image_reader(image_file)
    image_tensor = image_tensor.to(self.device)

    # Pure inference: no autograd graph is needed to build the heatmaps.
    with torch.no_grad():
        logits, logit_maps = self.model(image_tensor)
        logits = torch.stack(logits)
        logit_maps = torch.stack(logit_maps)

    # tensor -> numpy
    image_np = tensor2numpy(image_tensor)
    prob_maps_np = tensor2numpy(torch.sigmoid(logit_maps))
    logit_maps_np = tensor2numpy(logit_maps)

    num_tasks = len(disease_classes)
    # NOTE(review): presumably 3 overlay panels per row, with the extra row
    # leaving room for the raw image; confirm against set_overlay/set_rawimage.
    row_ = num_tasks // 3 + 1
    plt_fig = plt.figure(figsize=(10, row_ * 4), dpi=300)
    try:
        # -1 never matches a task index, so prefix_name stays '' when unset.
        prefix = -1 if self.prefix is None else \
            disease_classes.index(self.prefix)
        prefix_name = ''
        # vgg and resnet do not use pixel_std, densenet and inception use.
        # Takes channel 0 of batch item 0; the affine map appears to undo the
        # (x - mean) / std input normalization.
        ori_image = image_np[0, 0, :, :] * self.cfg.pixel_std + \
            self.cfg.pixel_mean
        for i in range(num_tasks):
            prob = tensor2numpy(torch.sigmoid(logits[i]))
            if prefix == i:
                prefix_name = '{:.4f}_'.format(prob[0][0])
            subtitle = '{}:{:.4f}'.format(disease_classes[i], prob[0][0])
            ax_overlay = self.set_overlay(plt_fig, row_, i, subtitle)
            # overlay_image is reassigned on every task; only the last one
            # is used, to attach the shared colorbar below.
            overlay_image = self.get_overlayed_img(ori_image,
                                                   logit_maps_np[i, :, :],
                                                   prob_maps_np[i, :, :],
                                                   ax_overlay)
        ax_rawimage = self.set_rawimage(plt_fig, row_)
        self.get_raw_image(image_color, ax_rawimage)

        divider = make_axes_locatable(ax_overlay)
        ax_colorbar = divider.append_axes("right", size="5%", pad=0.05)
        plt_fig.colorbar(overlay_image, cax=ax_colorbar)
        plt_fig.tight_layout()
        figure_data = fig2data(plt_fig)
    finally:
        # Close exactly this figure, even on error, so figures never leak.
        plt.close(plt_fig)
    return prefix_name, figure_data
def test(self, loader, ensemble=False):
    """Run evaluation over ``loader`` and compute every metric in ``self.metrics``.

    Gradients are disabled for the duration of the call and the previous
    grad mode is restored afterwards. The model is put in eval mode and is
    always switched back to train mode on exit, even if an error occurs.

    Args:
        loader (torch.utils.data.DataLoader): dataloader used for testing.
            Its dataset must expose ``_num_image`` (the total sample count),
            which is used to average the loss.
        ensemble (bool, optional): apply the FAEL weights
            (``self.ensemble_weights``) to the sigmoid outputs and score them
            with an MSE loss on the ``self.id_obs`` labels. Defaults to False.

    Returns:
        dict: metric name -> value. ``'loss'`` is the sample-weighted mean
        batch loss. Every other metric is computed once over all
        concatenated predictions and labels.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    import contextlib

    self.model.eval()
    running_metrics = dict.fromkeys(self.metrics.keys(), 0.0)
    ova_len = loader.dataset._num_image
    preds_list = []
    labels_list = []
    running_loss = 0.0
    try:
        with torch.no_grad():
            for data in tqdm.tqdm(loader):
                imgs = data[0].to(self.device)
                labels = data[1].to(self.device)
                amp_ctx = (torch.cuda.amp.autocast()
                           if self.cfg.mix_precision
                           else contextlib.nullcontext())
                with amp_ctx:
                    preds = self.model(imgs)
                    if self.cfg.split_output:
                        preds = torch.cat(list(preds), dim=-1)
                    if self.cfg.conditional_training:
                        preds = preds[:, self.id_leaf]
                        labels = labels[:, self.id_leaf]
                    loss = self.metrics['loss'](preds, labels)
                preds = nn.Sigmoid()(preds)
                if ensemble:
                    preds = torch.mm(preds, self.ensemble_weights)
                    labels = labels[:, self.id_obs]
                    loss = nn.MSELoss()(preds, labels)
                # Weight each batch by its share of the whole dataset.
                running_loss += loss.item() * imgs.size()[0] / ova_len
                preds_list.append(preds)
                labels_list.append(labels)

            if not preds_list:
                raise ValueError('loader yielded no batches')
            # Concatenate once, rather than once per batch.
            preds_stack = torch.cat(preds_list, 0)
            labels_stack = torch.cat(labels_list, 0)

            for key, metric in self.metrics.items():
                if key == 'loss':
                    running_metrics[key] = running_loss
                else:
                    running_metrics[key] = tensor2numpy(
                        metric(preds_stack, labels_stack))
    finally:
        self.model.train()
    return running_metrics
def save_FAEL_weight(self, path):
    """Pickle the FAEL ensemble weights to ``path``.

    The weight tensor is cast with ``.float()`` (float32) before it is
    converted by ``tensor2numpy``. Half-precision weights are therefore
    stored at float32 precision.

    Args:
        path (str): destination file, opened in binary write mode
            (truncated if it already exists).
    """
    with open(path, 'wb') as out_file:
        weights_fp32 = self.ensemble_weights.float()
        pickle.dump(tensor2numpy(weights_fp32), out_file)
def predict_from_file(self, image_file):
    """Run prediction on a single image loaded from disk.

    Args:
        image_file (str): image path. The image is read as grayscale.

    Returns:
        numpy array: sigmoid of the model prediction for a batch of one.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_file``.
    """
    image_gray = cv2.imread(image_file, 0)
    # cv2.imread reports a missing or unreadable file by returning None
    # instead of raising; fail here with a clear message.
    if image_gray is None:
        raise FileNotFoundError(
            'Could not read image file: {}'.format(image_file))
    image = transform(image_gray, self.cfg)
    # Add the batch dimension and place the input on the model's device,
    # as the other inference paths do.
    image = torch.from_numpy(image).unsqueeze(0).to(self.device)
    return tensor2numpy(nn.Sigmoid()(self.predict(image)))
test_loader, ensemble=False, cal_loss=False) preds_stack_test.append(preds) labels_stack_test.append(labels) preds_val = torch.mean(torch.stack(preds_stack_val, dim=0), dim=0) labels_val = torch.mean(torch.stack(labels_stack_val, dim=0), dim=0) preds_test = torch.mean(torch.stack(preds_stack_test, dim=0), dim=0) labels_test = torch.mean(torch.stack(labels_stack_test, dim=0), dim=0) auc_opt = AUC_ROC() thresh_val = auc_opt(preds_val, labels_val, thresholding=True) thresh_val = torch.Tensor(thresh_val).float().cuda() running_metrics_val = dict.fromkeys(metrics_dict.keys(), 0.0) running_metrics_test = dict.fromkeys(metrics_dict.keys(), 0.0) for key in list(metrics_dict.keys()): if key in ['f1', 'precision', 'recall', 'specificity', 'acc']: running_metrics_val[key] = tensor2numpy(metrics_dict[key]( preds_val, labels_val, thresh_val)) running_metrics_test[key] = tensor2numpy(metrics_dict[key]( preds, labels_test, thresh_val)) else: running_metrics_val[key] = tensor2numpy(metrics_dict[key]( preds_val, labels_val)) running_metrics_test[key] = tensor2numpy(metrics_dict[key]( preds_test, labels_test)) for key in metrics_dict.keys(): if key != 'loss': if key == 'auc' and cfg.type == 'chexmic': print(key, running_metrics_val[key], running_metrics_val[key][id_obs].mean()) running_metrics_val[key] = np.append(running_metrics_val[key],running_metrics_val[key][id_obs].mean()) else: print(key, running_metrics_val[key], running_metrics_val[key].mean())
def FAEL(self, loader, val_loader, type='basic', init_lr=0.01, log_iter=100,
         steps=20, lambda1=0.1, lambda2=2):
    """Run fully associative ensemble learning (FAEL).

    Fits a linear map ``w`` of shape ``(n_class, len(self.id_obs))`` from the
    outputs of ``self.predict`` to the observed labels by plain gradient
    descent. It then stores ``w`` in ``self.ensemble_weights`` and evaluates
    it on ``val_loader``.

    Args:
        loader (torch.utils.data.DataLoader): dataloader used for fitting
            the FAEL weights. At most ``log_iter * steps`` batches are used.
        val_loader (torch.utils.data.DataLoader): dataloader used for
            validating the FAEL model.
        type (str, optional): regularization type, either ``'basic'``
            (L2 only) or ``'b_constraint'`` (L2 plus a binary-constraint term
            built from ``self.M``). Defaults to 'basic'.
        init_lr (float, optional): learning rate, held constant.
            Defaults to 0.01.
        log_iter (int, optional): print timing every ``log_iter``
            iterations. Defaults to 100.
        steps (int, optional): number of ``log_iter``-sized periods to run.
            Defaults to 20.
        lambda1 (float, optional): L2 regularization weight. Defaults to 0.1.
        lambda2 (float, optional): binary-constraint weight. Defaults to 2.

    Returns:
        dict: metrics from ``self.test(val_loader, ensemble=True)``.

    Raises:
        ValueError: if ``type`` is not supported. This is a subclass of
            ``Exception``, which callers could previously catch.
    """
    # ``type`` shadows the builtin but is part of the public signature.
    if type not in ('basic', 'b_constraint'):
        raise ValueError("Not support this type!!!")

    n_class = len(self.cfg.num_classes)
    w = np.random.rand(n_class, len(self.id_obs))
    # Regularizer R added to P^T P. It is constant across iterations, so it
    # is built once.
    reg = lambda1 * np.eye(n_class)
    if type == 'b_constraint':
        reg = reg + lambda2 * self.M.T.dot(self.M)
    lr = init_lr
    start = time.time()
    for i, data in enumerate(tqdm.tqdm(loader, total=log_iter * steps)):
        imgs, labels = data[0].to(self.device), data[1]
        preds = tensor2numpy(self.predict(imgs))
        labels = tensor2numpy(labels)[:, self.id_obs]
        # (P^T P + R) w - P^T Y, i.e. the gradient of
        # 0.5*||P w - Y||^2 + 0.5*w^T R w with respect to w.
        grad = (preds.T.dot(preds) + reg).dot(w) - preds.T.dot(labels)
        w -= lr * grad
        if (i + 1) % log_iter == 0:
            end = time.time()
            print('iter {:d} time takes: {:.3f}s'.format(
                i + 1, end - start))
            start = time.time()
        if (i + 1) // log_iter == steps:
            break

    if self.cfg.mix_precision:
        self.ensemble_weights = torch.from_numpy(w).type(
            torch.HalfTensor).to(self.device)
    else:
        self.ensemble_weights = torch.from_numpy(w).float().to(self.device)
    print('Done Essemble!!!')
    metrics = self.test(val_loader, ensemble=True)
    return metrics
def train(self, train_loader, val_loader, epochs=10, iter_log=100,
          use_lr_sch=False, resume=False, ckp_dir='./experiment/checkpoint',
          writer=None, eval_metric='loss'):
    """Run training with periodic validation and checkpointing.

    Every ``iter_log`` iterations (plus a final shorter step when the epoch
    length is not a multiple of it), the method does the following:

    - logs the averaged training metrics;
    - validates on ``val_loader``;
    - writes ``latest.ckpt``;
    - if the validation metric improved, copies that file to
      ``epoch{E}_iter{I}.ckpt``.

    Each epoch ends with one more validation pass and, optionally, a
    learning-rate scheduler step.

    Args:
        train_loader (torch.utils.data.DataLoader): dataloader used for
            training. Each epoch makes exactly one pass over it.
        val_loader (torch.utils.data.DataLoader): dataloader used for
            validation.
        epochs (int, optional): number of training epochs. Defaults to 10.
        iter_log (int, optional): logging/validation interval in
            iterations. Defaults to 100.
        use_lr_sch (bool, optional): use a StepLR scheduler stepped once per
            epoch. Defaults to False.
        resume (bool, optional): resume from ``ckp_dir/latest.ckpt``.
            Defaults to False.
        ckp_dir (str, optional): checkpoint directory. It is created,
            including missing parents, if it does not exist. Defaults to
            './experiment/checkpoint'.
        writer (torch.utils.tensorboard.SummaryWriter, optional): tensorboard
            summary writer. Defaults to None.
        eval_metric (str, optional): validation metric used to select the
            best checkpoint. ``'loss'`` is minimized; any other metric is
            maximized. Defaults to 'loss'.
    """
    import contextlib

    if use_lr_sch:
        # StepLR requires step_size >= 1; int(epochs * 2 / 3) is 0 when
        # epochs < 2.
        # NOTE(review): StepLR's third positional argument is ``gamma`` (the
        # decay factor), so the LR is multiplied by cfg.lr / 3 at each decay.
        # Confirm this is intended rather than gamma=1/3.
        lr_sch = torch.optim.lr_scheduler.StepLR(
            self.optimizer, max(1, int(epochs * 2 / 3)), self.cfg.lr / 3)
    else:
        lr_sch = None

    # Lower loss is better; other metrics (e.g. AUC) are better when higher.
    minimize = eval_metric == 'loss'
    best_metric = float('inf') if minimize else 0.0

    def _mean_eval(metrics):
        """Return the mean of ``metrics[eval_metric]``.

        Per-class metrics are restricted to ``self.id_obs`` when training
        on the full class set without conditional training. Scalar metrics,
        such as the loss returned by ``test``, are used as-is.
        """
        value = np.asarray(metrics[eval_metric])
        use_obs_subset = (not self.cfg.conditional_training
                          and self.cfg.full_classes)
        if use_obs_subset and value.ndim > 0:
            value = value[self.id_obs]
        return value.mean()

    os.makedirs(ckp_dir, exist_ok=True)
    latest_ckp = os.path.join(ckp_dir, 'latest.ckpt')
    if resume:
        # NOTE(review): the restored iteration index is not used; training
        # restarts at the beginning of the restored epoch.
        epoch_resume, _ = self.load_ckp(latest_ckp)
    else:
        epoch_resume = 1

    scaler = None
    if self.cfg.mix_precision:
        print('Train with mix precision!')
        scaler = torch.cuda.amp.GradScaler()

    for epoch in range(epoch_resume - 1, epochs):
        start = time.time()
        running_metrics = dict.fromkeys(self.metrics.keys(), 0.0)
        running_metrics.pop(eval_metric, None)
        n_iter = len(train_loader)
        torch.set_grad_enabled(True)
        self.model.train()

        # Split the epoch into steps of iter_log iterations; the last step
        # holds any remainder. Each batch is weighted by 1 / (size of its
        # step), so running_metrics holds a per-step average.
        remainder = n_iter % iter_log
        step_per_epoch = n_iter // iter_log + (1 if remainder else 0)
        iter_per_step = np.full(step_per_epoch, iter_log, dtype=np.int64)
        batch_weights = np.full(n_iter, 1 / iter_log)
        if remainder:
            iter_per_step[-1] = remainder
            batch_weights[-remainder:] = 1 / remainder

        # One iterator per epoch, so every batch is visited exactly once.
        data_iter = iter(train_loader)
        i = 0  # number of iterations completed in this epoch
        for step in range(step_per_epoch):
            for _ in tqdm.tqdm(range(iter_per_step[step])):
                data = next(data_iter)
                imgs = data[0].to(self.device)
                labels = data[1].to(self.device)
                amp_ctx = (torch.cuda.amp.autocast()
                           if self.cfg.mix_precision
                           else contextlib.nullcontext())
                with amp_ctx:
                    preds = self.model(imgs)
                    if self.cfg.split_output:
                        preds = torch.cat(list(preds), dim=-1)
                    if self.cfg.conditional_training:
                        preds = preds[:, self.id_leaf]
                        labels = labels[:, self.id_leaf]
                    loss = self.metrics['loss'](preds, labels)
                preds = nn.Sigmoid()(preds)

                for key in list(running_metrics.keys()):
                    if key == 'loss':
                        running_metrics[key] += loss.item() * batch_weights[i]
                    else:
                        running_metrics[key] += tensor2numpy(
                            self.metrics[key](preds, labels)) * batch_weights[i]

                self.optimizer.zero_grad()
                if scaler is not None:
                    scaler.scale(loss).backward()
                    scaler.step(self.optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    self.optimizer.step()
                i += 1

            s = "Epoch [{}/{}] Iter [{}/{}]:\n".format(
                epoch + 1, epochs, i, n_iter)
            s = get_str(running_metrics, 'train', s)
            running_metrics_test = self.test(val_loader, False)
            s = get_str(running_metrics_test, 'val', s)
            metric_eval = _mean_eval(running_metrics_test)
            s = s[:-1] + "- mean_" + eval_metric + \
                " {:.3f}".format(metric_eval)
            self.save_ckp(latest_ckp, epoch, i)
            if writer is not None:
                for key in list(running_metrics.keys()):
                    writer.add_scalars(
                        key, {'train': running_metrics[key].mean()},
                        (epoch * n_iter) + (i + 1))
            running_metrics = dict.fromkeys(self.metrics.keys(), 0.0)
            running_metrics.pop(eval_metric, None)
            end = time.time()
            s += " ({:.1f}s)".format(end - start)
            print(s)

            improved = (metric_eval < best_metric if minimize
                        else metric_eval > best_metric)
            if improved:
                best_metric = metric_eval
                shutil.copyfile(
                    latest_ckp,
                    os.path.join(ckp_dir, 'epoch' + str(epoch + 1) +
                                 '_iter' + str(i + 1) + '.ckpt'))
                print('new checkpoint saved!')
            start = time.time()

        s = "Epoch [{}/{}] Iter [{}/{}]:\n".format(
            epoch + 1, epochs, n_iter, n_iter)
        running_metrics = self.test(val_loader)
        s = get_str(running_metrics, 'val', s)
        metric_eval = _mean_eval(running_metrics)
        s = s[:-1] + "- mean_" + eval_metric + \
            " {:.3f}".format(metric_eval)
        end = time.time()
        s += " ({:.1f}s)".format(end - start)
        print(s)
        if lr_sch is not None:
            lr_sch.step()
            # get_last_lr() reports the LR actually in use; get_lr() is
            # deprecated outside step().
            print('current lr: {:.4f}'.format(lr_sch.get_last_lr()[0]))