def fit(self, tr_loader, val_loader, epochs, val_interval, loss, val_metrics, opt):
    """Trains the NN, validating periodically and checkpointing every epoch.

    Args:
      `tr_loader`: DataLoader with the training set.
      `val_loader`: DataLoader with the validation set.
      `epochs`: Number of epochs to train the model. If 0, no train.
      `val_interval`: After how many epochs to perform validation.
      `loss`: Loss function, called as `loss(output, Y, W)`; its result
        must support `.backward()` and `.item()`.
      `val_metrics`: Which metrics to measure at validation time.
      `opt`: Optimizer.
    """
    t0 = time.time()
    e = 1
    # Expected classes of our dataset
    measure_classes = {0: "background", 1: "contra", 2: "R_hemisphere"}
    # Which classes are averaged into the reported validation score
    measure_classes_mean = np.array([1, 2])

    while e <= epochs:
        # ---- Training ----
        self.train()
        tr_loss = 0.0
        for X, Y, info, W in tr_loader:
            X = [x.to(self.device) for x in X]
            Y = [y.to(self.device) for y in Y]
            W = [w.to(self.device) for w in W]

            output = self(X)
            tr_loss_tmp = loss(output, Y, W)
            # Accumulate a plain Python number: summing the tensor itself
            # would carry autograd history along for the whole epoch.
            tr_loss += tr_loss_tmp.item()

            # Optimization
            opt.zero_grad()
            tr_loss_tmp.backward()
            opt.step()

        tr_loss /= len(tr_loader)

        # ---- Validation ----
        if len(val_loader) != 0 and e % val_interval == 0:
            log("Validation", self.out_path)
            self.eval()

            val_loss = 0.0
            # val_scores stores all needed metrics for assessing validation
            # shape: num_metrics x num_batches x num_classes
            val_scores = np.zeros(
                (len(val_metrics), len(val_loader), len(measure_classes)))
            measure = Metric(val_metrics, onehot=softmax2onehot,
                             classes=measure_classes, multiprocess=False)

            with torch.no_grad():
                for val_i, (X, Y, info, W) in enumerate(val_loader):
                    X = [x.to(self.device) for x in X]
                    Y = [y.to(self.device) for y in Y]
                    W = [w.to(self.device) for w in W]

                    output = self(X)
                    val_loss += loss(output, Y, W).item()

                    y_true_cpu = Y[0].cpu().numpy()
                    y_pred_cpu = output[0].cpu().numpy()

                    # Record all needed metrics
                    # If batch_size > 1, Measure.all() returns an avg.
                    tmp_res = measure.all(y_pred_cpu, y_true_cpu)
                    for i, m in enumerate(val_metrics):
                        val_scores[i, val_i] = tmp_res[m]

            # Validation loss
            val_loss /= len(val_loader)
            val_str = " Val Loss: {}".format(val_loss)

            for i, m in enumerate(val_metrics):
                # tmp shape: num_classes (averaged over num_batches
                # when val != -1)
                tmp = np.array(measure._getMean(val_scores[i]))
                # Keep only the classes of interest; entries equal to -1
                # are excluded from the mean.
                tmp_val = tmp[measure_classes_mean]
                # Note: if the result is NaN, it means that the classes of
                # interest (measure_classes_mean) were not found in the
                # validation set.
                tmp_val = np.mean(tmp_val[tmp_val != -1])
                val_str += ". Val " + m + ": " + str(tmp_val)
        else:
            val_str = ""

        # ETA extrapolates the average epoch duration so far over the
        # remaining epochs.
        eta = " ETA: " + datetime.fromtimestamp(
            time.time() + (epochs - e) * (time.time() - t0) / e
        ).strftime("%Y-%m-%d %H:%M:%S")
        log("Epoch: {}. Loss: {}.".format(e, tr_loss) + val_str + eta,
            self.out_path)

        # Save model after every epoch, keeping only the latest checkpoint
        model_prefix = self.out_path + "model/MedicDeepLabv3Plus-model-"
        torch.save(self.state_dict(), model_prefix + str(e))
        previous = model_prefix + str(e - 1)
        if e > 1 and os.path.exists(previous):
            os.remove(previous)

        e += 1
def evaluate(self, test_loader, metrics, remove_islands, save_output=True):
    """Tests/Evaluates the NN and optionally saves the segmentations.

    For every sample the prediction is (optionally) post-processed, its
    labels are combined, and, when ground truth is available, the requested
    metrics are computed. Collected metrics are written to
    `<out_path>stats.json` if there is at least one result.

    Args:
      `test_loader`: DataLoader containing the test set. Batch_size = 1.
      `metrics`: Metrics to measure.
      `remove_islands`: (bool) whether to apply post-processing.
      `save_output`: (bool) whether to save the output segmentations.
    """
    # Expected classes of our dataset
    measure_classes = {0: "background", 1: "contra", 2: "R_hemisphere"}

    results = {}
    self.eval()
    measure = Metric(metrics, onehot=sigmoid2onehot,
                     classes=measure_classes, multiprocess=True)

    # The metric object is created with multiprocess=True, so make sure it
    # is closed even if evaluation fails half-way.
    try:
        with torch.no_grad():
            # The fourth element (weights) is not needed for evaluation.
            for test_i, (X, Y, info, _) in enumerate(test_loader):
                print("{}/{}".format(test_i + 1, len(test_loader)))
                X = [x.to(self.device) for x in X]
                Y = [y.to(self.device) for y in Y]
                id_ = info["id"][0]

                output = self(X)

                y_pred_cpu = output[0].cpu().numpy()
                y_true_cpu = Y[0].cpu().numpy()

                if remove_islands:
                    y_pred_cpu = removeSmallIslands(y_pred_cpu, thr=20)

                # Predictions (and GT) separate the two hemispheres
                # combineLabels will combine these such that it creates
                # brainmask and contra-hemisphere ROIs instead of
                # two different hemisphere ROIs.
                y_pred_cpu = combineLabels(y_pred_cpu)

                # If GT was provided it measures the performance
                if len(y_true_cpu.shape) > 1:
                    y_true_cpu = combineLabels(y_true_cpu)
                    results[id_] = measure.all(y_pred_cpu, y_true_cpu)

                # Honour the documented flag instead of always saving.
                if save_output:
                    test_loader.dataset.save(
                        y_pred_cpu[0], info, self.out_path + id_)

        # Gather results (multiprocessing)
        for k in results:
            results[k] = results[k].get()

        if results:
            with open(self.out_path + "stats.json", "w") as f:
                f.write(json.dumps(results))
    finally:
        # If we are using multiprocessing we need to close the pool
        measure.close()