Esempio n. 1
0
def run_infer(weights_folder_path, cfg):
    """Run ensemble inference on the test set and write ``submission.csv``.

    Every ``*.pth`` checkpoint found in ``weights_folder_path`` is loaded into
    a ``RanzcrNet``. The sigmoid of each model's logits is averaged over all
    models and written to the ``cfg.label_cols`` columns of the table read
    from ``cfg.test_df``.

    Args:
        weights_folder_path: directory holding the ``*.pth`` checkpoints; each
            checkpoint must store its weights under the ``"model"`` key.
        cfg: configuration object. It is mutated: ``pretrained`` is set to
            ``False`` and ``data_folder`` to ``cfg.data_dir + "test/"``.

    Raises:
        ValueError: if ``weights_folder_path`` contains no ``*.pth`` file.
    """
    from contextlib import nullcontext

    cfg.pretrained = False
    # for local test, please modify the following path into actual path.
    cfg.data_folder = cfg.data_dir + "test/"

    # sorted so the ensemble order does not depend on the file system
    all_path = sorted(glob.glob(os.path.join(weights_folder_path, "*.pth")))
    if not all_path:
        # fail early with a clear message instead of an opaque np.stack error
        # after the whole test set has been iterated
        raise ValueError(
            f"no '*.pth' checkpoint found in {weights_folder_path!r}")

    to_device_transform = ToDeviced(keys=("input", "target", "mask",
                                          "is_annotated"),
                                    device=cfg.device)

    nets = []
    for path in all_path:
        # map_location lets checkpoints saved on another device be loaded.
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        state_dict = torch.load(path, map_location=cfg.device)["model"]
        # strip the prefix added when the weights were saved from a wrapped
        # (e.g. DataParallel) model
        new_state_dict = {
            k.replace("module.", ""): v
            for k, v in state_dict.items()
        }

        net = RanzcrNet(cfg).eval().to(cfg.device)
        net.load_state_dict(new_state_dict)

        # the segmentation branch is dropped after loading; only "logits"
        # is read from the model output below
        del net.decoder
        del net.segmentation_head

        nets.append(net)

    test_df = pd.read_csv(cfg.test_df)
    test_dataset = get_test_dataset(test_df, cfg)
    test_dataloader = get_test_dataloader(test_dataset, cfg)

    # run under autocast only when mixed precision is requested
    amp_context = autocast if cfg.mixed_precision else nullcontext

    with torch.no_grad():
        fold_preds = [[] for _ in nets]
        for batch in tqdm(test_dataloader):
            batch = to_device_transform(batch)
            for i, net in enumerate(nets):
                with amp_context():
                    logits = net(batch)["logits"].cpu().numpy()
                fold_preds[i].append(logits)
        fold_preds = [np.concatenate(p) for p in fold_preds]

    # average the per-model probabilities (sigmoid of the logits)
    preds = np.stack(fold_preds)
    preds = expit(preds)
    preds = np.mean(preds, axis=0)

    sub_df = test_df.copy()
    sub_df[cfg.label_cols] = preds

    submission = pd.read_csv(cfg.test_df)
    submission.loc[sub_df.index, cfg.label_cols] = sub_df[cfg.label_cols]
    submission.to_csv("submission.csv", index=False)
Esempio n. 2
0
 def test_value(self):
     """Each item cached through ``ToDeviced`` must equal its own index when read back in batches of one."""
     gpu = "cuda:0"
     samples = []
     for value in range(4):
         samples.append({"img": torch.tensor(value)})
     move_to_gpu = ToDeviced(keys="img", device=gpu, non_blocking=True)
     cached = CacheDataset(data=samples,
                           transform=move_to_gpu,
                           cache_rate=1.0)
     loader = ThreadDataLoader(dataset=cached, num_workers=0, batch_size=1)
     for index, batch in enumerate(loader):
         # batch_size=1 adds a leading dimension, hence the one-element tensor
         expected = torch.tensor([index], device=gpu)
         torch.testing.assert_allclose(batch["img"], expected)
    def test_train_timing(self):
        """Train a 3D UNet with cached, GPU-resident data and AMP, then check the Dice.

        The ``img*.nii.gz`` / ``seg*.nii.gz`` pairs found in ``self.data_dir``
        are split into the first 32 pairs for training and the last 9 for
        validation. Per-step and per-epoch timings are printed, and the best
        validation mean Dice over ``max_epochs`` epochs must exceed 0.95.
        """
        images = sorted(glob(os.path.join(self.data_dir, "img*.nii.gz")))
        segs = sorted(glob(os.path.join(self.data_dir, "seg*.nii.gz")))
        train_files = [{
            "image": img,
            "label": seg
        } for img, seg in zip(images[:32], segs[:32])]
        val_files = [{
            "image": img,
            "label": seg
        } for img, seg in zip(images[-9:], segs[-9:])]

        device = torch.device("cuda:0")
        # define transforms for train and validation
        train_transforms = Compose([
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Spacingd(keys=["image", "label"],
                     pixdim=(1.0, 1.0, 1.0),
                     mode=("bilinear", "nearest")),
            ScaleIntensityd(keys="image"),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            # pre-compute foreground and background indexes
            # and cache them to accelerate training
            FgBgToIndicesd(keys="label", fg_postfix="_fg", bg_postfix="_bg"),
            # change to execute transforms with Tensor data
            EnsureTyped(keys=["image", "label"]),
            # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
            ToDeviced(keys=["image", "label"], device=device),
            # randomly crop out patch samples from big
            # image based on pos / neg ratio
            # the image centers of negative samples
            # must be in valid image area
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=(64, 64, 64),
                pos=1,
                neg=1,
                num_samples=4,
                fg_indices_key="label_fg",
                bg_indices_key="label_bg",
            ),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(keys=["image", "label"], prob=0.5),
            RandRotate90d(keys=["image", "label"],
                          prob=0.5,
                          spatial_axes=(1, 2)),
            RandZoomd(keys=["image", "label"],
                      prob=0.5,
                      min_zoom=0.8,
                      max_zoom=1.2,
                      keep_size=True),
            RandRotated(
                keys=["image", "label"],
                prob=0.5,
                range_x=np.pi / 4,
                mode=("bilinear", "nearest"),
                align_corners=True,
                dtype=np.float64,
            ),
            RandAffined(keys=["image", "label"],
                        prob=0.5,
                        rotate_range=np.pi / 2,
                        mode=("bilinear", "nearest")),
            RandGaussianNoised(keys="image", prob=0.5),
            RandStdShiftIntensityd(keys="image",
                                   prob=0.5,
                                   factors=0.05,
                                   nonzero=True),
        ])

        val_transforms = Compose([
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Spacingd(keys=["image", "label"],
                     pixdim=(1.0, 1.0, 1.0),
                     mode=("bilinear", "nearest")),
            ScaleIntensityd(keys="image"),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            EnsureTyped(keys=["image", "label"]),
            # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
            ToDeviced(keys=["image", "label"], device=device),
        ])

        max_epochs = 5
        learning_rate = 2e-4
        val_interval = 1  # do validation for every epoch

        # set CacheDataset, ThreadDataLoader and DiceCE loss for MONAI fast training
        train_ds = CacheDataset(data=train_files,
                                transform=train_transforms,
                                cache_rate=1.0,
                                num_workers=8)
        val_ds = CacheDataset(data=val_files,
                              transform=val_transforms,
                              cache_rate=1.0,
                              num_workers=5)
        # disable multi-workers because `ThreadDataLoader` works with multi-threads
        train_loader = ThreadDataLoader(train_ds,
                                        num_workers=0,
                                        batch_size=4,
                                        shuffle=True)
        val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1)

        loss_function = DiceCELoss(to_onehot_y=True,
                                   softmax=True,
                                   squared_pred=True,
                                   batch=True)
        model = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=2,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm=Norm.BATCH,
        ).to(device)

        # Novograd paper suggests to use a bigger LR than Adam,
        # because Adam does normalization by element-wise second moments
        optimizer = Novograd(model.parameters(), learning_rate * 10)
        scaler = torch.cuda.amp.GradScaler()

        post_pred = Compose(
            [EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
        post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])

        dice_metric = DiceMetric(include_background=True,
                                 reduction="mean",
                                 get_not_nans=False)

        # loop invariants: number of steps per epoch (for the progress print)
        # and the sliding-window settings used for validation
        epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)
        roi_size = (96, 96, 96)
        sw_batch_size = 4

        best_metric = -1
        total_start = time.time()
        for epoch in range(max_epochs):
            epoch_start = time.time()
            print("-" * 10)
            print(f"epoch {epoch + 1}/{max_epochs}")
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step_start = time.time()
                step += 1
                optimizer.zero_grad()
                # set AMP for training
                with torch.cuda.amp.autocast():
                    outputs = model(batch_data["image"])
                    loss = loss_function(outputs, batch_data["label"])
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                epoch_loss += loss.item()
                print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}"
                      f" step time: {(time.time() - step_start):.4f}")
            epoch_loss /= step
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

            if (epoch + 1) % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    for val_data in val_loader:
                        # set AMP for validation
                        with torch.cuda.amp.autocast():
                            val_outputs = sliding_window_inference(
                                val_data["image"], roi_size, sw_batch_size,
                                model)

                        val_outputs = [
                            post_pred(i) for i in decollate_batch(val_outputs)
                        ]
                        val_labels = [
                            post_label(i)
                            for i in decollate_batch(val_data["label"])
                        ]
                        dice_metric(y_pred=val_outputs, y=val_labels)

                    # aggregate over all validation batches, then clear the
                    # accumulated state for the next validation round
                    metric = dice_metric.aggregate().item()
                    dice_metric.reset()
                    if metric > best_metric:
                        best_metric = metric
                    print(
                        f"epoch: {epoch + 1} current mean dice: {metric:.4f}, best mean dice: {best_metric:.4f}"
                    )
            print(
                f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}"
            )

        total_time = time.time() - total_start
        print(
            f"train completed, best_metric: {best_metric:.4f} total time: {total_time:.4f}"
        )
        # test expected metrics
        self.assertGreater(best_metric, 0.95)
Esempio n. 4
0
def main_worker(args):
    """Run distributed BraTS segmentation training, one process per GPU.

    Reads ``local_rank``, ``dir``, ``cache_rate``, ``batch_size``,
    ``network``, ``lr``, ``epochs`` and ``val_interval`` from ``args``.
    Whenever the validation mean Dice improves, the rank-0 process saves the
    model state dict to ``best_metric_model.pth``.

    Raises:
        FileNotFoundError: if ``args.dir`` does not exist.
    """
    # only local rank 0 of each node keeps its console output
    if args.local_rank != 0:
        devnull = open(os.devnull, "w")
        sys.stdout = devnull
        sys.stderr = devnull
    if not os.path.exists(args.dir):
        raise FileNotFoundError(f"missing directory {args.dir}")

    # every GPU is driven by its own process; rendezvous through env vars
    dist.init_process_group(backend="nccl", init_method="env://")
    gpu = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(gpu)
    # gradient scaler for automatic mixed precision
    grad_scaler = torch.cuda.amp.GradScaler()
    torch.backends.cudnn.benchmark = True

    total_start = time.time()
    both = ["image", "label"]
    interp = ("bilinear", "nearest")

    train_steps = [
        # load 4 Nifti images and stack them together
        LoadImaged(keys=both),
        EnsureChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Orientationd(keys=both, axcodes="RAS"),
        Spacingd(keys=both, pixdim=(1.0, 1.0, 1.0), mode=interp),
        EnsureTyped(keys=both),
        ToDeviced(keys=both, device=gpu),
        RandSpatialCropd(keys=both,
                         roi_size=[224, 224, 144],
                         random_size=False),
        RandFlipd(keys=both, prob=0.5, spatial_axis=0),
        RandFlipd(keys=both, prob=0.5, spatial_axis=1),
        RandFlipd(keys=both, prob=0.5, spatial_axis=2),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
    ]
    training_pipeline = Compose(train_steps)

    # training dataset and loader
    training_set = BratsCacheDataset(
        root_dir=args.dir,
        transform=training_pipeline,
        section="training",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=True,
    )
    # ThreadDataLoader can be faster if no IO operations when caching all the data in memory
    training_batches = ThreadDataLoader(training_set,
                                        num_workers=0,
                                        batch_size=args.batch_size,
                                        shuffle=True)

    # validation pipeline: same preprocessing, no random augmentation
    val_steps = [
        LoadImaged(keys=both),
        EnsureChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Orientationd(keys=both, axcodes="RAS"),
        Spacingd(keys=both, pixdim=(1.0, 1.0, 1.0), mode=interp),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        EnsureTyped(keys=both),
        ToDeviced(keys=both, device=gpu),
    ]
    validation_pipeline = Compose(val_steps)
    validation_set = BratsCacheDataset(
        root_dir=args.dir,
        transform=validation_pipeline,
        section="validation",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=False,
    )
    # ThreadDataLoader can be faster if no IO operations when caching all the data in memory
    validation_batches = ThreadDataLoader(validation_set,
                                          num_workers=0,
                                          batch_size=args.batch_size,
                                          shuffle=False)

    # network: SegResNet on request, UNet for any other value
    if args.network != "SegResNet":
        net = UNet(
            spatial_dims=3,
            in_channels=4,
            out_channels=3,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        )
    else:
        net = SegResNet(
            blocks_down=[1, 2, 2, 4],
            blocks_up=[1, 1, 1],
            init_filters=16,
            in_channels=4,
            out_channels=3,
            dropout_prob=0.0,
        )
    net = net.to(gpu)

    criterion = DiceFocalLoss(
        smooth_nr=1e-5,
        smooth_dr=1e-5,
        squared_pred=True,
        to_onehot_y=False,
        sigmoid=True,
        batch=True,
    )
    opt = Novograd(net.parameters(), lr=args.lr)
    schedule = torch.optim.lr_scheduler.CosineAnnealingLR(opt,
                                                          T_max=args.epochs)
    # the optimizer is built on the bare model, then the model is wrapped
    # for distributed training
    net = DistributedDataParallel(net, device_ids=[gpu])

    mean_dice = DiceMetric(include_background=True, reduction="mean")
    per_channel_dice = DiceMetric(include_background=True,
                                  reduction="mean_batch")

    # sigmoid + 0.5 threshold turn the raw outputs into binary masks
    binarize = Compose([
        EnsureType(),
        Activations(sigmoid=True),
        AsDiscrete(threshold=0.5),
    ])

    # training loop
    best_metric = -1
    best_metric_epoch = -1
    print(f"time elapsed before training: {time.time() - total_start}")
    train_start = time.time()
    for epoch in range(args.epochs):
        epoch_start = time.time()
        print("-" * 10)
        print(f"epoch {epoch + 1}/{args.epochs}")
        epoch_loss = train(training_batches, net, criterion, opt, schedule,
                           grad_scaler)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        validate_now = (epoch + 1) % args.val_interval == 0
        if validate_now:
            scores = evaluate(net, validation_batches, mean_dice,
                              per_channel_dice, binarize)
            metric, metric_tc, metric_wt, metric_et = scores

            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                # a single process writes the checkpoint
                if dist.get_rank() == 0:
                    torch.save(net.state_dict(), "best_metric_model.pth")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}"
                f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
            )

        print(
            f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}"
        )

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch},"
        f" total train time: {(time.time() - train_start):.4f}")
    dist.destroy_process_group()
Esempio n. 5
0
                clip=True,
            )
        ),
        Range()(CropForegroundd(keys=["image", "label"], source_key="image")),
        # pre-compute foreground and background indexes
        # and cache them to accelerate training
        Range("Indexing")(
            FgBgToIndicesd(
                keys="label",
                fg_postfix="_fg",
                bg_postfix="_bg",
                image_key="image",
            )
        ),
        EnsureTyped(keys=["image", "label"]),
        ToDeviced(keys=["image", "label"], device="cuda:0"),
        Range("RandCrop")(
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=(96, 96, 96),
                pos=1,
                neg=1,
                num_samples=4,
                fg_indices_key="label_fg",
                bg_indices_key="label_bg",
            )
        )
    ]
)
Esempio n. 6
0
def main(cfg):
    """Train ``RanzcrNet`` on one fold and checkpoint the best validation loss.

    Reads ``output_dir``, ``fold``, ``seed``, ``train_df``, ``train_val``,
    ``device``, ``mixed_precision``, ``epochs``, ``train``, ``eval_epochs``
    and ``eval_train_epochs`` from ``cfg``. ``cfg`` is mutated: a negative
    ``seed`` is replaced by a random one, and ``to_device_transform`` is
    attached to it.

    Whenever an evaluation yields a lower validation loss than any before, a
    checkpoint is saved to
    ``{cfg.output_dir}/fold{cfg.fold}/checkpoint_best_seed{cfg.seed}.pth``.
    """

    os.makedirs(str(cfg.output_dir + f"/fold{cfg.fold}/"), exist_ok=True)

    # set random seed, works when use all data to train
    if cfg.seed < 0:
        cfg.seed = np.random.randint(1_000_000)
    set_seed(cfg.seed)

    # set dataset, dataloader
    train = pd.read_csv(cfg.train_df)

    # fold == -1: validate on fold 0 while the training split below keeps
    # every row whose fold differs from -1
    # NOTE(review): presumably that is all rows (train on all data) — confirm
    if cfg.fold == -1:
        val_df = train[train["fold"] == 0]
    else:
        val_df = train[train["fold"] == cfg.fold]
    train_df = train[train["fold"] != cfg.fold]

    train_dataset = get_train_dataset(train_df, cfg)
    val_dataset = get_val_dataset(val_df, cfg)

    train_dataloader = get_train_dataloader(train_dataset, cfg)
    val_dataloader = get_val_dataloader(val_dataset, cfg)

    # optional second evaluation pass over the training split
    if cfg.train_val is True:
        train_val_dataset = get_val_dataset(train_df, cfg)
        train_val_dataloader = get_val_dataloader(train_val_dataset, cfg)

    to_device_transform = ToDeviced(keys=("input", "target", "mask",
                                          "is_annotated"),
                                    device=cfg.device)
    cfg.to_device_transform = to_device_transform
    # set model

    model = RanzcrNet(cfg)
    model.to(cfg.device)

    # set optimizer, lr scheduler
    total_steps = len(train_dataset)

    optimizer = get_optimizer(model, cfg)
    scheduler = get_scheduler(cfg, optimizer, total_steps)

    # set other tools
    if cfg.mixed_precision:
        scaler = GradScaler()
    else:
        scaler = None

    writer = SummaryWriter(str(cfg.output_dir + f"/fold{cfg.fold}/"))

    # train and val loop
    # NOTE(review): `step` and `i` are passed to run_train but never updated
    # here, so every epoch receives 0 for both — confirm this is intended.
    step = 0
    i = 0
    best_val_loss = np.inf
    # Start at +inf so the checkpoint test below is well defined (and false)
    # in epochs before the first evaluation; previously `val_loss` was
    # unbound there when cfg.eval_epochs > 1.
    val_loss = np.inf
    optimizer.zero_grad()
    for epoch in range(cfg.epochs):
        print("EPOCH:", epoch)
        gc.collect()
        if cfg.train is True:
            run_train(
                model=model,
                train_dataloader=train_dataloader,
                optimizer=optimizer,
                scheduler=scheduler,
                cfg=cfg,
                scaler=scaler,
                writer=writer,
                epoch=epoch,
                iteration=i,
                step=step,
            )

        # evaluate every cfg.eval_epochs epochs and always after the last one
        if (epoch + 1) % cfg.eval_epochs == 0 or (epoch + 1) == cfg.epochs:
            val_loss = run_eval(
                model=model,
                val_dataloader=val_dataloader,
                cfg=cfg,
                writer=writer,
                epoch=epoch,
            )

        if cfg.train_val is True:
            if (epoch + 1) % cfg.eval_train_epochs == 0 or (epoch +
                                                            1) == cfg.epochs:
                train_val_loss = run_eval(model, train_val_dataloader, cfg,
                                          writer, epoch)
                print(f"train_val_loss {train_val_loss:.5}")

        # in epochs without an evaluation val_loss is stale (or +inf), so
        # this comparison is false and nothing is saved
        if val_loss < best_val_loss:
            print(
                f"SAVING CHECKPOINT: val_loss {best_val_loss:.5} -> {val_loss:.5}"
            )
            best_val_loss = val_loss

            checkpoint = create_checkpoint(
                model,
                optimizer,
                epoch,
                scheduler=scheduler,
                scaler=scaler,
            )
            torch.save(
                checkpoint,
                f"{cfg.output_dir}/fold{cfg.fold}/checkpoint_best_seed{cfg.seed}.pth",
            )