Example #1
    def setUp(self):
        self.bundle_dir = tempfile.TemporaryDirectory()
        self.dir_name = os.path.join(self.bundle_dir.name, "TestBundle")
        self.configs_name = os.path.join(self.dir_name, "configs")
        self.models_name = os.path.join(self.dir_name, "models")
        self.metadata_name = os.path.join(self.configs_name, "metadata.json")
        self.test_name = os.path.join(self.configs_name, "test.json")
        self.modelpt_name = os.path.join(self.models_name, "model.pt")

        self.zip_file = os.path.join(self.bundle_dir.name, "TestBundle.zip")
        self.ts_file = os.path.join(self.bundle_dir.name, "TestBundle.ts")

        # create the directories for the bundle
        os.mkdir(self.dir_name)
        os.mkdir(self.configs_name)
        os.mkdir(self.models_name)

        # fill bundle configs

        with open(self.metadata_name, "w") as o:
            o.write(metadata)

        with open(self.test_name, "w") as o:
            o.write(test_json)

        # save network
        net = UNet(2, 1, 1, [4, 8], [2])
        torch.save(net.state_dict(), self.modelpt_name)
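No matching tearDown appears in the excerpt; a minimal sketch, assuming the class derives from unittest.TestCase as setUp suggests, that releases the temporary directory created above:

    def tearDown(self):
        # TemporaryDirectory.cleanup() removes the whole bundle tree built in setUp
        self.bundle_dir.cleanup()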
Example #2
    def test_bundle(self):
        with tempfile.TemporaryDirectory() as tempdir:
            net = UNet(2, 1, 1, [4, 8], [2])
            torch.save(net.state_dict(), tempdir + "/test.pt")

            bundle_root = tempdir + "/test_bundle"

            cmd = ["coverage", "run", "-m", "monai.bundle", "init_bundle", bundle_root, tempdir + "/test.pt"]
            subprocess.check_call(cmd)

            self.assertTrue(os.path.exists(bundle_root + "/configs/metadata.json"))
            self.assertTrue(os.path.exists(bundle_root + "/configs/inference.json"))
            self.assertTrue(os.path.exists(bundle_root + "/models/model.pt"))
Example #3
            # aggregate the final mean dice result
            metric = dice_metric.aggregate().item()
            # reset the status for next validation round
            dice_metric.reset()
            metric_values.append(metric)
            scheduler.step(metric)
            writer.add_scalar("val_mean_dice", metric, epoch + 1)
            writer.add_scalar("Learning rate", optimizer.param_groups[0]['lr'],
                              epoch + 1)

            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(),
                           os.path.join(out_dir, "best_metric_model.pth"))
                print("saved new best metric model")
            print(f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                  f"\nbest mean dice: {best_metric:.4f}"
                  f"at epoch: {best_metric_epoch}")

print(f"train completed, best_metric: {best_metric:.4f}"
      f"at epoch: {best_metric_epoch}")
"""## Plot the loss and metric"""

#fig2=plt.figure("train", (12, 6))
#plt.subplot(1, 2, 1)
#plt.title("Epoch Average Loss")
#x = [i + 1 for i in range(len(epoch_loss_values))]
#y = epoch_loss_values
Example #4
def main_worker(args):
    # disable logging for processes except 0 on every node
    if args.local_rank != 0:
        f = open(os.devnull, "w")
        sys.stdout = sys.stderr = f
    if not os.path.exists(args.dir):
        raise FileNotFoundError(f"Missing directory {args.dir}")

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")

    total_start = time.time()
    train_transforms = Compose([
        # load 4 Nifti images and stack them together
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Spacingd(keys=["image", "label"],
                 pixdim=(1.5, 1.5, 2.0),
                 mode=("bilinear", "nearest")),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        RandSpatialCropd(keys=["image", "label"],
                         roi_size=[128, 128, 64],
                         random_size=False),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
        ToTensord(keys=["image", "label"]),
    ])

    # create a training data loader
    train_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=train_transforms,
        section="training",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=True,
    )
    train_loader = DataLoader(train_ds,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True)

    # validation transforms and dataset
    val_transforms = Compose([
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Spacingd(keys=["image", "label"],
                 pixdim=(1.5, 1.5, 2.0),
                 mode=("bilinear", "nearest")),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        CenterSpatialCropd(keys=["image", "label"], roi_size=[128, 128, 64]),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        ToTensord(keys=["image", "label"]),
    ])
    val_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=val_transforms,
        section="validation",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=False,
    )
    val_loader = DataLoader(val_ds,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True)

    if dist.get_rank() == 0:
        # Logging for TensorBoard
        writer = SummaryWriter(log_dir=args.log_dir)

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device(f"cuda:{args.local_rank}")
    if args.network == "UNet":
        model = UNet(
            dimensions=3,
            in_channels=4,
            out_channels=3,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)
    else:
        model = SegResNet(in_channels=4,
                          out_channels=3,
                          init_filters=16,
                          dropout_prob=0.2).to(device)
    loss_function = DiceLoss(to_onehot_y=False,
                             sigmoid=True,
                             squared_pred=True)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=1e-5,
                                 amsgrad=True)
    # wrap the model with DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[args.local_rank])

    # start a typical PyTorch training
    total_epoch = args.epochs
    best_metric = -1000000
    best_metric_epoch = -1
    epoch_time = AverageMeter("Time", ":6.3f")
    progress = ProgressMeter(total_epoch, [epoch_time], prefix="Epoch: ")
    end = time.time()
    print(f"Time elapsed before training: {end-total_start}")
    for epoch in range(total_epoch):

        train_loss = train(train_loader, model, loss_function, optimizer,
                           epoch, args, device)
        epoch_time.update(time.time() - end)

        if epoch % args.print_freq == 0:
            progress.display(epoch)

        if dist.get_rank() == 0:
            writer.add_scalar("Loss/train", train_loss, epoch)

        if (epoch + 1) % args.val_interval == 0:
            metric, metric_tc, metric_wt, metric_et = evaluate(
                model, val_loader, device)

            if dist.get_rank() == 0:
                writer.add_scalar("Mean Dice/val", metric, epoch)
                writer.add_scalar("Mean Dice TC/val", metric_tc, epoch)
                writer.add_scalar("Mean Dice WT/val", metric_wt, epoch)
                writer.add_scalar("Mean Dice ET/val", metric_et, epoch)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                print(
                    f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                    f" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}"
                    f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
                )
        end = time.time()
        print(f"Time elapsed after epoch {epoch + 1} is {end - total_start}")

    if dist.get_rank() == 0:
        print(
            f"train completed, best_metric: {best_metric:.4f}  at epoch: {best_metric_epoch}"
        )
        # all processes should see same parameters as they all start from same
        # random parameters and gradients are synchronized in backward passes,
        # therefore, saving it in one process is sufficient
        torch.save(model.state_dict(), "final_model.pth")
        writer.flush()
    dist.destroy_process_group()
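main_worker reads its rank and world size from environment variables (init_method="env://" together with args.local_rank), so it is meant to be started by the PyTorch launcher, one process per GPU. A typical invocation for 2 GPUs on one node, assuming the example is saved as brats_training_ddp.py and that its (unshown) argument parser exposes --dir; both names are assumptions:

python -m torch.distributed.launch --nproc_per_node=2 brats_training_ddp.py --dir ./brats_data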
Example #5
                    rng_compute_metric = nvtx.start_range(message="compute metric", color="yellow")
                    # compute metric for current iteration
                    dice_metric(y_pred=val_outputs, y=val_labels)
                    nvtx.end_range(rng_compute_metric)

                metric = dice_metric.aggregate().item()
                dice_metric.reset()
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    best_metrics_epochs_and_time[0].append(best_metric)
                    best_metrics_epochs_and_time[1].append(best_metric_epoch)
                    best_metrics_epochs_and_time[2].append(time.time() - total_start)
                    torch.save(
                        model.state_dict(), os.path.join(out_dir, "best_metric_model.pth")
                    )
                    print("saved new best metric model")
                print(
                    f"current epoch: {epoch + 1} "
                    f"current mean dice: {metric:.4f} "
                    f"best mean dice: {best_metric:.4f} "
                    f" at epoch: {best_metric_epoch}"
                )
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
        nvtx.end_range(rng_epoch)
        print(
            f"time elapsed for epoch {epoch + 1}:"
            f" {(time.time() - epoch_start):.4f}"
        )
        epoch_times.append(time.time() - epoch_start)
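The NVTX calls above follow a strict pairing discipline; a minimal sketch of the pattern, using the same nvtx package as the fragment:

import nvtx

# start_range returns a handle; passing it to end_range closes the span,
# which then shows up as one named, colored range in Nsight Systems
rng = nvtx.start_range(message="validation step", color="green")
# ... work to be profiled goes here ...
nvtx.end_range(rng)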
Example #6
def train_process(fast=False):
    epoch_num = 10
    val_interval = 1
    train_trans, val_trans = transformations()
    train_ds = Dataset(data=train_files, transform=train_trans)
    val_ds = Dataset(data=val_files, transform=val_trans)

    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=1)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n1 = 16
    model = UNet(dimensions=3,
                 in_channels=1,
                 out_channels=2,
                 channels=(n1 * 1, n1 * 2, n1 * 4, n1 * 8, n1 * 16),
                 strides=(2, 2, 2, 2)).to(device)
    loss_function = DiceLoss(to_onehot_y=True, softmax=True)
    post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=2)
    post_label = AsDiscrete(to_onehot=True, n_classes=2)
    optimizer = torch.optim.Adam(model.parameters(), 1e-4, weight_decay=1e-5)

    best_metric = -1
    best_metric_epoch = -1
    best_metrics_epochs_and_time = [[], [], []]
    epochs_no_improve = 0  # counter incremented below; must be initialised before the loop
    epoch_loss_values = list()
    metric_values = list()

    for epoch in range(epoch_num):
        print(f"epoch {epoch + 1}/{epoch_num}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs = batch_data["image"].to(device)
            labels = batch_data["label"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.
                metric_count = 0
                for val_data in val_loader:
                    val_inputs = val_data["image"].to(device)
                    val_labels = val_data["label"].to(device)
                    val_outputs = model(val_inputs)
                    val_outputs = post_pred(val_outputs)
                    val_labels = post_label(val_labels)
                    value = compute_meandice(y_pred=val_outputs,
                                             y=val_labels,
                                             include_background=False)
                    metric_count += len(value)
                    metric_sum += value.sum().item()
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    epochs_no_improve = 0
                    best_metric_epoch = epoch + 1
                    best_metrics_epochs_and_time[0].append(best_metric)
                    best_metrics_epochs_and_time[1].append(best_metric_epoch)
                    torch.save(model.state_dict(), 'sLUMRTL644.pth')
                else:
                    epochs_no_improve += 1

            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f" best mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
            )

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )
    return epoch_num, epoch_loss_values, metric_values, best_metrics_epochs_and_time
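train_process accumulates compute_meandice results by hand (metric_sum / metric_count); Examples #3, #5 and #8 use the stateful monai.metrics.DiceMetric instead, which hides that bookkeeping. A minimal sketch of the equivalent validation loop, reusing the model, loaders and post-transforms defined above:

from monai.metrics import DiceMetric

dice_metric = DiceMetric(include_background=False, reduction="mean")
for val_data in val_loader:
    # accumulate per-batch dice scores inside the metric object
    dice_metric(y_pred=post_pred(model(val_data["image"].to(device))),
                y=post_label(val_data["label"].to(device)))
metric = dice_metric.aggregate().item()  # mean dice over all batches
dice_metric.reset()  # clear internal buffers for the next validation round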
Example #7
                val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, model)
                val_outputs = post_pred(val_outputs)
                val_labels = post_label(val_labels)
                value = compute_meandice(
                    y_pred=val_outputs,
                    y=val_labels,
                    include_background=False,
                )
                metric_count += len(value)
                metric_sum += value.sum().item()
            metric = metric_sum / metric_count
            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), os.path.join(root_dir, "best_metric_model.pth"))
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
            )

print(f"train completed, best_metric: {best_metric:.4f}  at epoch: {best_metric_epoch}")

"""## Plot the loss and metric"""

plt.figure("train", (12, 6))
plt.subplot(1, 2, 1)
plt.title("Epoch Average Loss")
x = [i + 1 for i in range(len(epoch_loss_values))]
y = epoch_loss_values
Example #8
                            # compute metric for current iteration
                            dice_metric(y_pred=val_outputs, y=val_labels)

                    metric = dice_metric.aggregate().item()
                    dice_metric.reset()
                    metric_values.append(metric)
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        best_metrics_epochs_and_time[0].append(best_metric)
                        best_metrics_epochs_and_time[1].append(best_metric_epoch)
                        best_metrics_epochs_and_time[2].append(time.time() - total_start)
                        torch.save(
                            model.state_dict(),
                            os.path.join(out_dir, "best_metric_model.pth"))
                        print("saved new best metric model")
                    print(f"current epoch: {epoch + 1} "
                          f"current mean dice: {metric:.4f} "
                          f"best mean dice: {best_metric:.4f} "
                          f" at epoch: {best_metric_epoch}")
                    writer.add_scalar("val_mean_dice", metric, epoch + 1)
        print(f"time consuming of epoch {epoch + 1} is:"
              f" {(time.time() - epoch_start):.4f}")
        epoch_times.append(time.time() - epoch_start)

total_time = time.time() - total_start
print(f"train completed, best_metric: {best_metric:.4f}"
      f" at epoch: {best_metric_epoch}"
      f" total time: {total_time:.4f}")