def test_class_value(self, y_pred, y, softmax, to_onehot, average, expected_value):
    """Check that ROCAUCMetric reproduces ``expected_value`` for one parameterized case.

    ``y_pred`` optionally goes through softmax (``softmax``), ``y`` is optionally one-hot
    encoded over 2 classes (``to_onehot``), and the metric is aggregated with ``average``.
    The aggregate must match ``expected_value`` within a relative tolerance of 1e-5.
    """
    activated = Activations(softmax=softmax)(y_pred)
    targets = AsDiscrete(to_onehot=to_onehot, n_classes=2)(y)
    roc_auc = ROCAUCMetric(average=average)
    roc_auc(y_pred=activated, y=targets)
    np.testing.assert_allclose(expected_value, roc_auc.aggregate(), rtol=1e-5)
def test_class_value(self, y_pred, y, softmax, to_onehot, average, expected_value):
    """Check ROCAUCMetric on decollated, per-item post-processed predictions and labels.

    Each item of the batch ``y_pred`` is converted to a tensor and optionally softmaxed.
    Each item of ``y`` is converted to a tensor and passed to ``AsDiscrete(to_onehot=...)``.
    The aggregated metric (reset afterwards) must match ``expected_value`` within rtol=1e-5.
    """
    pred_post = Compose([ToTensor(), Activations(softmax=softmax)])
    label_post = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot)])
    # Predictions are processed before labels, the same order as the batch-level variant.
    preds = list(map(pred_post, decollate_batch(y_pred)))
    labels = list(map(label_post, decollate_batch(y)))

    roc_auc = ROCAUCMetric(average=average)
    roc_auc(y_pred=preds, y=labels)
    aggregated = roc_auc.aggregate()
    roc_auc.reset()
    np.testing.assert_allclose(expected_value, aggregated, rtol=1e-5)
def __init__(self, average: Union[Average, str] = Average.MACRO, output_transform: Callable = lambda x: x) -> None:
    """Wrap a ``ROCAUCMetric`` instance as this handler's metric function.

    Args:
        average: averaging mode for ``ROCAUCMetric``. It is normalized through
            ``Average(average)``, so an ``Average`` member or (per the annotation) a str works.
        output_transform: callable forwarded unchanged to the base class; identity by default.

    Per-item details are not kept (``save_details=False`` is passed to the base class).
    """
    super().__init__(
        metric_fn=ROCAUCMetric(average=Average(average)),
        output_transform=output_transform,
        save_details=False,
    )
def run_training_test(root_dir, train_x, train_y, val_x, val_y, device="cuda:0", num_workers=10):
    """Train a 2D DenseNet121 classifier and track validation ROC AUC each epoch.

    Args:
        root_dir: directory in which the best checkpoint ``best_metric_model.pth`` is saved.
        train_x, train_y: training inputs and labels passed to ``MedNISTDataset``
            (presumably image file paths and integer class ids -- TODO confirm with caller).
        val_x, val_y: validation inputs and labels, in the same form as the training ones.
        device: device string the model and every batch are moved to.
        num_workers: worker count for both ``DataLoader`` instances.

    Returns:
        ``(epoch_loss_values, best_metric, best_metric_epoch)``: the mean training loss of
        each epoch, the best validation AUC seen, and the 1-based epoch where it occurred.
    """
    monai.config.print_config()
    # define transforms for image and classification
    train_transforms = Compose(
        [
            LoadImage(image_only=True),
            AddChannel(),
            Transpose(indices=[0, 2, 1]),
            ScaleIntensity(),
            RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True, dtype=np.float64),
            RandFlip(spatial_axis=0, prob=0.5),
            RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
            ToTensor(),
        ]
    )
    # Seed the random augmentations so their sequence is deterministic from run to run.
    train_transforms.set_random_state(1234)
    val_transforms = Compose(
        [LoadImage(image_only=True), AddChannel(), Transpose(indices=[0, 2, 1]), ScaleIntensity(), ToTensor()]
    )
    # Per-item post-processing used only for AUC. Predictions get softmax; labels are
    # one-hot encoded with the class count taken from the distinct training labels.
    y_pred_trans = Compose([ToTensor(), Activations(softmax=True)])
    y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=len(np.unique(train_y)))])
    auc_metric = ROCAUCMetric()

    # create train, val data loaders
    train_ds = MedNISTDataset(train_x, train_y, train_transforms)
    train_loader = DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=num_workers)

    val_ds = MedNISTDataset(val_x, val_y, val_transforms)
    val_loader = DataLoader(val_ds, batch_size=300, num_workers=num_workers)

    # Output channels = number of distinct training labels (same count as the one-hot above).
    model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=len(np.unique(train_y))).to(device)
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), 1e-5)
    epoch_num = 4
    val_interval = 1

    # start training validation
    # -1 is below any AUC value, so the first validation always saves a checkpoint.
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = []
    metric_values = []
    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    for epoch in range(epoch_num):
        print("-" * 10)
        print(f"Epoch {epoch + 1}/{epoch_num}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            # Batches are indexed positionally: [0] = images, [1] = labels.
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        # NOTE(review): raises ZeroDivisionError if train_loader yields no batches (step == 0).
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss:{epoch_loss:0.4f}")

        if (epoch + 1) % val_interval == 0:
            # eval_mode puts the model in eval mode for this block
            # (presumably also disables autograd -- confirm in MONAI's eval_mode).
            with eval_mode(model):
                # Concatenate every validation output and label into single tensors.
                y_pred = torch.tensor([], dtype=torch.float32, device=device)
                y = torch.tensor([], dtype=torch.long, device=device)
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                    y_pred = torch.cat([y_pred, model(val_images)], dim=0)
                    y = torch.cat([y, val_labels], dim=0)

                # compute accuracy
                acc_value = torch.eq(y_pred.argmax(dim=1), y)
                acc_metric = acc_value.sum().item() / len(acc_value)
                # decollate prediction and label and execute post processing
                y_pred = [y_pred_trans(i) for i in decollate_batch(y_pred)]
                y = [y_trans(i) for i in decollate_batch(y)]
                # compute AUC
                auc_metric(y_pred, y)
                auc_value = auc_metric.aggregate()
                # Clear accumulated state so the next epoch's AUC starts fresh.
                auc_metric.reset()
                metric_values.append(auc_value)
                # Checkpoint selection is by AUC, not accuracy.
                if auc_value > best_metric:
                    best_metric = auc_value
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), model_filename)
                    print("saved new best metric model")
                print(
                    f"current epoch {epoch +1} current AUC: {auc_value:0.4f} "
                    f"current accuracy: {acc_metric:0.4f} best AUC: {best_metric:0.4f} at epoch {best_metric_epoch}"
                )
    print(f"train completed, best_metric: {best_metric:0.4f} at epoch: {best_metric_epoch}")
    return epoch_loss_values, best_metric, best_metric_epoch
def main():
    """Train and validate a 3D DenseNet121 gender classifier on a 20-image IXI T1 subset.

    Images are read from ``./workspace/data/medical/ixi/IXI-T1`` (relative to the working
    directory). The first 10 are used for training and the last 10 for validation.

    Training runs for ``max_epochs`` epochs. Every ``val_interval`` epochs the model is
    evaluated, and accuracy and ROC AUC are printed. The state dict with the best validation
    accuracy is saved to ``best_metric_model_classification3d_dict.pth`` in the working
    directory.

    Per-step training loss and per-validation accuracy are recorded with
    ``SummaryWriter.add_scalar``. The writer is closed even if training raises.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    # the path of ixi IXI-T1 dataset
    data_path = os.sep.join([".", "workspace", "data", "medical", "ixi", "IXI-T1"])
    images = [
        "IXI314-IOP-0889-T1.nii.gz",
        "IXI249-Guys-1072-T1.nii.gz",
        "IXI609-HH-2600-T1.nii.gz",
        "IXI173-HH-1590-T1.nii.gz",
        "IXI020-Guys-0700-T1.nii.gz",
        "IXI342-Guys-0909-T1.nii.gz",
        "IXI134-Guys-0780-T1.nii.gz",
        "IXI577-HH-2661-T1.nii.gz",
        "IXI066-Guys-0731-T1.nii.gz",
        "IXI130-HH-1528-T1.nii.gz",
        "IXI607-Guys-1097-T1.nii.gz",
        "IXI175-HH-1570-T1.nii.gz",
        "IXI385-HH-2078-T1.nii.gz",
        "IXI344-Guys-0905-T1.nii.gz",
        "IXI409-Guys-0960-T1.nii.gz",
        "IXI584-Guys-1129-T1.nii.gz",
        "IXI253-HH-1694-T1.nii.gz",
        "IXI092-HH-1436-T1.nii.gz",
        "IXI574-IOP-1156-T1.nii.gz",
        "IXI585-Guys-1130-T1.nii.gz",
    ]
    images = [os.sep.join([data_path, f]) for f in images]

    # 2 binary labels for gender classification: man and woman
    labels = np.array([0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=np.int64)
    train_files = [{"img": img, "label": label} for img, label in zip(images[:10], labels[:10])]
    val_files = [{"img": img, "label": label} for img, label in zip(images[-10:], labels[-10:])]

    # Define transforms for image
    train_transforms = Compose(
        [
            LoadImaged(keys=["img"]),
            AddChanneld(keys=["img"]),
            ScaleIntensityd(keys=["img"]),
            Resized(keys=["img"], spatial_size=(96, 96, 96)),
            RandRotate90d(keys=["img"], prob=0.8, spatial_axes=[0, 2]),
            EnsureTyped(keys=["img"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["img"]),
            AddChanneld(keys=["img"]),
            ScaleIntensityd(keys=["img"]),
            Resized(keys=["img"], spatial_size=(96, 96, 96)),
            EnsureTyped(keys=["img"]),
        ]
    )
    # Per-item post-processing for AUC: softmax on predictions, 2-class one-hot on labels.
    post_pred = Compose([EnsureType(), Activations(softmax=True)])
    post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])

    # Define dataset, data loader
    # Sanity check: load one batch and print its image shape and labels.
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())
    check_data = monai.utils.misc.first(check_loader)
    print(check_data["img"].shape, check_data["label"])

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4, pin_memory=torch.cuda.is_available())

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())

    # Create DenseNet121, CrossEntropyLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device)
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), 1e-5)
    auc_metric = ROCAUCMetric()

    # start a typical PyTorch training
    max_epochs = 5
    val_interval = 2
    # -1 is below any accuracy, so the first validation always saves a checkpoint.
    best_metric = -1
    best_metric_epoch = -1
    # Batches per epoch, including a final partial batch. len(train_ds) // batch_size would
    # drop that partial batch, so "step/epoch_len" could exceed its total and the
    # TensorBoard global step (epoch_len * epoch + step) would collide across epochs.
    epoch_len = len(train_loader)
    writer = SummaryWriter()
    try:
        for epoch in range(max_epochs):
            print("-" * 10)
            print(f"epoch {epoch + 1}/{max_epochs}")
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs, labels = batch_data["img"].to(device), batch_data["label"].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
            epoch_loss /= step
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

            if (epoch + 1) % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    # Concatenate every validation output and label into single tensors.
                    y_pred = torch.tensor([], dtype=torch.float32, device=device)
                    y = torch.tensor([], dtype=torch.long, device=device)
                    for val_data in val_loader:
                        val_images, val_labels = val_data["img"].to(device), val_data["label"].to(device)
                        y_pred = torch.cat([y_pred, model(val_images)], dim=0)
                        y = torch.cat([y, val_labels], dim=0)

                    acc_value = torch.eq(y_pred.argmax(dim=1), y)
                    acc_metric = acc_value.sum().item() / len(acc_value)
                    y_onehot = [post_label(i) for i in decollate_batch(y)]
                    y_pred_act = [post_pred(i) for i in decollate_batch(y_pred)]
                    auc_metric(y_pred_act, y_onehot)
                    auc_result = auc_metric.aggregate()
                    # Clear accumulated state so the next validation's AUC starts fresh.
                    auc_metric.reset()
                    del y_pred_act, y_onehot
                    # Checkpoint selection is by accuracy; AUC is only reported.
                    if acc_metric > best_metric:
                        best_metric = acc_metric
                        best_metric_epoch = epoch + 1
                        torch.save(model.state_dict(), "best_metric_model_classification3d_dict.pth")
                        print("saved new best metric model")
                    print(
                        f"current epoch: {epoch + 1} current accuracy: {acc_metric:.4f} "
                        f"current AUC: {auc_result:.4f} best accuracy: {best_metric:.4f} "
                        f"at epoch {best_metric_epoch}"
                    )
                    writer.add_scalar("val_accuracy", acc_metric, epoch + 1)
        print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    finally:
        # Flush and close the writer even if training fails.
        writer.close()