Ejemplo n.º 1
0
 def train_pre_transforms(self, context: Context):
     """Return the dictionary transforms applied to each training sample.

     Stages, in order:
       * load ``image`` and ``label`` as uint8, apply ``FilterImaged``
         (``min_size=5``), move the image channel axis first and add a
         channel axis to the label;
       * apply a torchvision ``ColorJitter`` to the image (the image is
         converted to a tensor for it and back to numpy afterwards);
       * randomly rotate image and label together by 90 degrees over axes
         (0, 1) with probability 0.5, then linearly map image intensities
         from [0, 255] to [-1, 1];
       * add seed-point guidance derived from ``label`` and the guidance
         signal channels to ``image``, then ensure both keys are typed.

     ``context`` is not used by this implementation.
     """
     both = ("image", "label")

     loading = [
         LoadImaged(keys=both, dtype=np.uint8),
         FilterImaged(keys="image", min_size=5),
         AsChannelFirstd(keys="image"),
         AddChanneld(keys="label"),
     ]
     # TorchVisiond wraps a torchvision transform, so the image is handed to it
     # as a tensor and converted back to numpy for the transforms that follow.
     color_augmentation = [
         ToTensord(keys="image"),
         TorchVisiond(
             keys="image",
             name="ColorJitter",
             brightness=64.0 / 255.0,
             contrast=0.75,
             saturation=0.25,
             hue=0.04,
         ),
         ToNumpyd(keys="image"),
     ]
     spatial_and_intensity = [
         RandRotate90d(keys=both, prob=0.5, spatial_axes=(0, 1)),
         ScaleIntensityRangeD(
             keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0
         ),
     ]
     guidance_stage = [
         AddInitialSeedPointExd(label="label", guidance="guidance"),
         AddGuidanceSignald(
             image="image", guidance="guidance", number_intensity_ch=3
         ),
         EnsureTyped(keys=both),
     ]
     return loading + color_augmentation + spatial_and_intensity + guidance_stage
 def pre_transforms(self, data=None) -> Sequence[Callable]:
     """Return the inference pre-processing transforms for an image patch.

     Loads ``image`` as an RGB uint8 patch, applies ``FilterImaged``, moves
     the channel axis first and linearly maps intensities from [0, 255] to
     [-1, 1]. ``data`` is not used here.
     """
     intensity_window = dict(a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0)
     loader = LoadImagePatchd(keys="image", conversion="RGB", dtype=np.uint8)
     return [
         loader,
         FilterImaged(keys="image"),
         AsChannelFirstd(keys="image"),
         ScaleIntensityRangeD(keys="image", **intensity_window),
     ]
Ejemplo n.º 3
0
 def pre_transforms(self, data=None):
     """Return the inference pre-processing transforms with click guidance.

     Loads ``image`` as an RGB uint8 patch, applies ``FilterImaged``, moves
     the channel axis first, maps intensities from [0, 255] to [-1, 1], adds
     click guidance under ``"guidance"`` and appends the guidance signal to
     the image (3 intensity channels).

     If ``data`` is truthy, ``data.get("device")`` is forwarded as the
     ``device`` of the final ``EnsureTyped``; otherwise ``device=None``.
     """
     stages = [
         LoadImagePatchd(keys="image", conversion="RGB", dtype=np.uint8),
         FilterImaged(keys="image"),
         AsChannelFirstd(keys="image"),
         ScaleIntensityRangeD(
             keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0
         ),
         AddClickGuidanced(image="image", guidance="guidance"),
         AddGuidanceSignald(
             image="image", guidance="guidance", number_intensity_ch=3
         ),
     ]
     target_device = None
     if data:
         target_device = data.get("device")
     stages.append(EnsureTyped(keys="image", device=target_device))
     return stages
def main_worker(gpu, args):
    """Train and/or validate a MIL whole-slide classifier on one GPU.

    When ``args.distributed`` is set, this process joins the
    ``torch.distributed`` process group, its rank is recomputed from the local
    ``gpu`` index, and the model is wrapped in ``DistributedDataParallel``
    (with SyncBatchNorm).

    Args:
        gpu: local CUDA device index used by this process.
        args: parsed command-line namespace. Attributes read here include
            ``distributed``, ``rank``, ``dist_backend``, ``dist_url``,
            ``world_size``, ``batch_size``, ``epochs``, ``dataset_json``,
            ``data_root``, ``quick``, ``num_classes``, ``tile_count``,
            ``tile_size``, ``workers``, ``mil_mode``, ``checkpoint``,
            ``validate``, ``optim_lr``, ``weight_decay``, ``logdir``,
            ``amp`` and ``val_every``. ``args.gpu`` is set to ``gpu`` and,
            in distributed mode, ``args.rank`` is overwritten.

    If ``args.validate`` is set, a single validation pass is run and the
    process exits with status 0 instead of training.

    When ``args.checkpoint`` is given, the model weights, the start epoch and
    the stored ``best_acc`` are restored. ``best_acc`` seeds the best-metric
    tracker, so a resumed run only replaces ``model.pt`` when it improves on
    it.
    """
    import sys  # sys.exit: the ``exit`` builtin is only meant for interactive use

    args.gpu = gpu

    if args.distributed:
        args.rank = args.rank * torch.cuda.device_count() + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    print(args.rank, " gpu", args.gpu)

    # Make this GPU the default CUDA device for the process.
    torch.cuda.set_device(args.gpu)
    torch.backends.cudnn.benchmark = True

    if args.rank == 0:
        print("Batch size is:", args.batch_size, "epochs", args.epochs)

    #############
    # Create MONAI dataset
    training_list = load_decathlon_datalist(
        data_list_file_path=args.dataset_json,
        data_list_key="training",
        base_dir=args.data_root,
    )
    validation_list = load_decathlon_datalist(
        data_list_file_path=args.dataset_json,
        data_list_key="validation",
        base_dir=args.data_root,
    )

    if args.quick:  # for debugging on a small subset
        training_list = training_list[:16]
        validation_list = validation_list[:16]

    train_transform = Compose([
        LoadImageD(keys=["image"],
                   reader=WSIReader,
                   backend="TiffFile",
                   dtype=np.uint8,
                   level=1,
                   image_only=True),
        LabelEncodeIntegerGraded(keys=["label"], num_classes=args.num_classes),
        TileOnGridd(
            keys=["image"],
            tile_count=args.tile_count,
            tile_size=args.tile_size,
            random_offset=True,
            background_val=255,
            return_list_of_dicts=True,
        ),
        RandFlipd(keys=["image"], spatial_axis=0, prob=0.5),
        RandFlipd(keys=["image"], spatial_axis=1, prob=0.5),
        RandRotate90d(keys=["image"], prob=0.5),
        # a_min=255 / a_max=0 deliberately inverts intensities
        # (white background -> 0).
        ScaleIntensityRangeD(keys=["image"],
                             a_min=np.float32(255),
                             a_max=np.float32(0)),
        ToTensord(keys=["image", "label"]),
    ])

    valid_transform = Compose([
        LoadImageD(keys=["image"],
                   reader=WSIReader,
                   backend="TiffFile",
                   dtype=np.uint8,
                   level=1,
                   image_only=True),
        LabelEncodeIntegerGraded(keys=["label"], num_classes=args.num_classes),
        TileOnGridd(
            keys=["image"],
            tile_count=None,
            tile_size=args.tile_size,
            random_offset=False,
            background_val=255,
            return_list_of_dicts=True,
        ),
        ScaleIntensityRangeD(keys=["image"],
                             a_min=np.float32(255),
                             a_max=np.float32(0)),
        ToTensord(keys=["image", "label"]),
    ])

    dataset_train = Dataset(data=training_list, transform=train_transform)
    dataset_valid = Dataset(data=validation_list, transform=valid_transform)

    train_sampler = DistributedSampler(
        dataset_train) if args.distributed else None
    val_sampler = DistributedSampler(
        dataset_valid, shuffle=False) if args.distributed else None

    train_loader = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        multiprocessing_context="spawn",
        sampler=train_sampler,
        collate_fn=list_data_collate,
    )
    valid_loader = torch.utils.data.DataLoader(
        dataset_valid,
        batch_size=1,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
        multiprocessing_context="spawn",
        sampler=val_sampler,
        collate_fn=list_data_collate,
    )

    if args.rank == 0:
        print("Dataset training:", len(dataset_train), "validation:",
              len(dataset_valid))

    model = milmodel.MILModel(num_classes=args.num_classes,
                              pretrained=True,
                              mil_mode=args.mil_mode)

    best_acc = 0
    start_epoch = 0
    if args.checkpoint is not None:
        checkpoint = torch.load(args.checkpoint, map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"])
        if "epoch" in checkpoint:
            start_epoch = checkpoint["epoch"]
        if "best_acc" in checkpoint:
            best_acc = checkpoint["best_acc"]
        print("=> loaded checkpoint '{}' (epoch {}) (bestacc {})".format(
            args.checkpoint, start_epoch, best_acc))

    model.cuda(args.gpu)

    if args.distributed:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu], output_device=args.gpu)

    if args.validate:
        # if we only want to validate existing checkpoint
        epoch_time = time.time()
        val_loss, val_acc, qwk = val_epoch(model,
                                           valid_loader,
                                           epoch=0,
                                           args=args,
                                           max_tiles=args.tile_count)
        if args.rank == 0:
            print(
                "Final validation loss: {:.4f}".format(val_loss),
                "acc: {:.4f}".format(val_acc),
                "qwk: {:.4f}".format(qwk),
                "time {:.2f}s".format(time.time() - epoch_time),
            )

        sys.exit(0)

    params = model.parameters()

    if args.mil_mode in ["att_trans", "att_trans_pyramid"]:
        # The transformer gets its own, much smaller learning rate and a
        # stronger weight decay than the rest of the network.
        m = model if not args.distributed else model.module
        params = [
            {
                "params":
                list(m.attention.parameters()) + list(m.myfc.parameters()) +
                list(m.net.parameters())
            },
            {
                "params": list(m.transformer.parameters()),
                "lr": 6e-6,
                "weight_decay": 0.1
            },
        ]

    optimizer = torch.optim.AdamW(params,
                                  lr=args.optim_lr,
                                  weight_decay=args.weight_decay)
    # NOTE(review): the scheduler is not fast-forwarded to start_epoch when
    # resuming, so the cosine schedule restarts from args.optim_lr.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                           T_max=args.epochs,
                                                           eta_min=0)

    if args.logdir is not None and args.rank == 0:
        writer = SummaryWriter(log_dir=args.logdir)
        print("Writing Tensorboard logs to ", writer.log_dir)
    else:
        writer = None

    ###RUN TRAINING
    n_epochs = args.epochs
    # Seed the best-so-far metric with the value restored from the checkpoint
    # (0 when training from scratch). A resumed run then only replaces
    # model.pt when it actually improves on it.
    val_acc_max = best_acc

    scaler = None
    if args.amp:  # new native amp
        scaler = GradScaler()

    for epoch in range(start_epoch, n_epochs):

        if args.distributed:
            train_sampler.set_epoch(epoch)
            torch.distributed.barrier()

        print(args.rank, time.ctime(), "Epoch:", epoch)

        epoch_time = time.time()
        train_loss, train_acc = train_epoch(model,
                                            train_loader,
                                            optimizer,
                                            scaler=scaler,
                                            epoch=epoch,
                                            args=args)

        if args.rank == 0:
            print(
                "Final training  {}/{}".format(epoch, n_epochs - 1),
                "loss: {:.4f}".format(train_loss),
                "acc: {:.4f}".format(train_acc),
                "time {:.2f}s".format(time.time() - epoch_time),
            )

        if args.rank == 0 and writer is not None:
            writer.add_scalar("train_loss", train_loss, epoch)
            writer.add_scalar("train_acc", train_acc, epoch)

        if args.distributed:
            torch.distributed.barrier()

        b_new_best = False
        val_acc = 0
        if (epoch + 1) % args.val_every == 0:

            epoch_time = time.time()
            val_loss, val_acc, qwk = val_epoch(model,
                                               valid_loader,
                                               epoch=epoch,
                                               args=args,
                                               max_tiles=args.tile_count)
            if args.rank == 0:
                print(
                    "Final validation  {}/{}".format(epoch, n_epochs - 1),
                    "loss: {:.4f}".format(val_loss),
                    "acc: {:.4f}".format(val_acc),
                    "qwk: {:.4f}".format(qwk),
                    "time {:.2f}s".format(time.time() - epoch_time),
                )
                if writer is not None:
                    writer.add_scalar("val_loss", val_loss, epoch)
                    writer.add_scalar("val_acc", val_acc, epoch)
                    writer.add_scalar("val_qwk", qwk, epoch)

                # Model selection is driven by the quadratic weighted kappa.
                val_acc = qwk

                if val_acc > val_acc_max:
                    print("qwk ({:.6f} --> {:.6f})".format(
                        val_acc_max, val_acc))
                    val_acc_max = val_acc
                    b_new_best = True

        if args.rank == 0 and args.logdir is not None:
            # Store the best-so-far metric (not this epoch's), so the value
            # restored as "best_acc" on resume is really the best.
            save_checkpoint(model,
                            epoch,
                            args,
                            best_acc=val_acc_max,
                            filename="model_final.pt")
            if b_new_best:
                print("Copying to model.pt new best model!!!!")
                shutil.copyfile(os.path.join(args.logdir, "model_final.pt"),
                                os.path.join(args.logdir, "model.pt"))

        scheduler.step()

    # Flush pending TensorBoard events and release the event file.
    if writer is not None:
        writer.close()

    print("ALL DONE")
Ejemplo n.º 5
0
def main(cfg):
    """Train (and optionally validate) a binary WSI patch classifier.

    Builds ``PatchWSIDataset`` loaders from a Decathlon-style datalist and
    trains a ``TorchVisionFCModel("resnet18")`` with a single output and
    ``BCEWithLogitsLoss``, using the module's ``training``/``validation``
    helpers. Logs go to stderr, to ``logs.txt`` and to TensorBoard, all in the
    directory returned by ``create_log_dir(cfg)``.

    Args:
        cfg: configuration dict. Keys read here include ``benchmark``,
            ``backend`` (``"cucim"`` or ``"numpy"``), ``prob``,
            ``dataset_json``, ``data_root``, ``region_size``, ``grid_shape``,
            ``patch_size``, ``use_openslide``, ``num_workers``,
            ``batch_size``, ``pin``, ``pretrain``, ``novograd``, ``lr``,
            ``amp``, ``cos``, ``n_epochs``, ``print_step``, ``save`` and
            ``validate``. ``cfg["amp"]`` is overwritten with whether AMP is
            actually usable (torch >= 1.6).

    Raises:
        ValueError: if ``cfg["backend"]`` is unknown or the first training
            batch is ``None``.
    """
    # -------------------------------------------------------------------------
    # Configs
    # -------------------------------------------------------------------------
    # Create log/model dir
    log_dir = create_log_dir(cfg)

    # Set the logger
    logging.basicConfig(
        format="%(asctime)s %(levelname)2s %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    log_name = os.path.join(log_dir, "logs.txt")
    logger = logging.getLogger()
    fh = logging.FileHandler(log_name)
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)

    # Set TensorBoard summary writer
    writer = SummaryWriter(log_dir)

    # Save configs
    logging.info(json.dumps(cfg))
    with open(os.path.join(log_dir, "config.json"), "w") as fp:
        json.dump(cfg, fp, indent=4)

    # Set device cuda/cpu
    device = set_device(cfg)

    # Set cudnn benchmark/deterministic
    if cfg["benchmark"]:
        torch.backends.cudnn.benchmark = True
    else:
        set_determinism(seed=0)
    # -------------------------------------------------------------------------
    # Transforms and Datasets
    # -------------------------------------------------------------------------
    # Pre-processing. The *_cpu_* pipelines are attached to the datasets; the
    # *_gpu_* pipelines (cuCIM backend only) are passed to training()/
    # validation() below.
    preprocess_cpu_train = None
    preprocess_gpu_train = None
    preprocess_cpu_valid = None
    preprocess_gpu_valid = None
    if cfg["backend"] == "cucim":
        preprocess_cpu_train = Compose([ToTensorD(keys="label")])
        preprocess_gpu_train = Compose([
            Range()(ToCupy()),
            Range("ColorJitter")(RandCuCIM(name="color_jitter",
                                           brightness=64.0 / 255.0,
                                           contrast=0.75,
                                           saturation=0.25,
                                           hue=0.04)),
            Range("RandomFlip")(RandCuCIM(name="image_flip",
                                          apply_prob=cfg["prob"],
                                          spatial_axis=-1)),
            Range("RandomRotate90")(RandCuCIM(name="rand_image_rotate_90",
                                              prob=cfg["prob"],
                                              max_k=3,
                                              spatial_axis=(-2, -1))),
            Range()(CastToType(dtype=np.float32)),
            Range("RandomZoom")(RandCuCIM(name="rand_zoom",
                                          min_zoom=0.9,
                                          max_zoom=1.1)),
            Range("ScaleIntensity")(CuCIM(name="scale_intensity_range",
                                          a_min=0.0,
                                          a_max=255.0,
                                          b_min=-1.0,
                                          b_max=1.0)),
            Range()(ToTensor(device=device)),
        ])
        preprocess_cpu_valid = Compose([ToTensorD(keys="label")])
        preprocess_gpu_valid = Compose([
            Range("ValidToCupyAndCast")(ToCupy(dtype=np.float32)),
            Range("ValidScaleIntensity")(CuCIM(name="scale_intensity_range",
                                               a_min=0.0,
                                               a_max=255.0,
                                               b_min=-1.0,
                                               b_max=1.0)),
            Range("ValidToTensor")(ToTensor(device=device)),
        ])
    elif cfg["backend"] == "numpy":
        preprocess_cpu_train = Compose([
            Range()(ToTensorD(keys=("image", "label"))),
            Range("ColorJitter")(TorchVisionD(
                keys="image",
                name="ColorJitter",
                brightness=64.0 / 255.0,
                contrast=0.75,
                saturation=0.25,
                hue=0.04,
            )),
            Range()(ToNumpyD(keys="image")),
            Range("RandomFlip")(RandFlipD(keys="image",
                                          prob=cfg["prob"],
                                          spatial_axis=-1)),
            Range("RandomRotate90")(RandRotate90D(keys="image",
                                                  prob=cfg["prob"])),
            Range()(CastToTypeD(keys="image", dtype=np.float32)),
            Range("RandomZoom")(RandZoomD(keys="image",
                                          prob=cfg["prob"],
                                          min_zoom=0.9,
                                          max_zoom=1.1)),
            Range("ScaleIntensity")(ScaleIntensityRangeD(keys="image",
                                                         a_min=0.0,
                                                         a_max=255.0,
                                                         b_min=-1.0,
                                                         b_max=1.0)),
            Range()(ToTensorD(keys="image")),
        ])
        preprocess_cpu_valid = Compose([
            Range("ValidCastType")(CastToTypeD(keys="image",
                                               dtype=np.float32)),
            Range("ValidScaleIntensity")(ScaleIntensityRangeD(keys="image",
                                                              a_min=0.0,
                                                              a_max=255.0,
                                                              b_min=-1.0,
                                                              b_max=1.0)),
            Range("ValidToTensor")(ToTensorD(keys=("image", "label"))),
        ])
    else:
        raise ValueError(
            f"Backend should be either numpy or cucim! ['{cfg['backend']}' is provided.]"
        )

    # Post-processing: sigmoid on the single logit, then threshold at 0.5.
    postprocess = Compose([
        Activations(sigmoid=True),
        AsDiscrete(threshold=0.5),
    ])

    # Create MONAI dataset
    train_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="training",
        base_dir=cfg["data_root"],
    )
    valid_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="validation",
        base_dir=cfg["data_root"],
    )
    train_dataset = PatchWSIDataset(
        data=train_json_info_list,
        region_size=cfg["region_size"],
        grid_shape=cfg["grid_shape"],
        patch_size=cfg["patch_size"],
        transform=preprocess_cpu_train,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )
    valid_dataset = PatchWSIDataset(
        data=valid_json_info_list,
        region_size=cfg["region_size"],
        grid_shape=cfg["grid_shape"],
        patch_size=cfg["patch_size"],
        transform=preprocess_cpu_valid,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )

    # DataLoaders
    train_dataloader = DataLoader(train_dataset,
                                  num_workers=cfg["num_workers"],
                                  batch_size=cfg["batch_size"],
                                  pin_memory=cfg["pin"])
    valid_dataloader = DataLoader(valid_dataset,
                                  num_workers=cfg["num_workers"],
                                  batch_size=cfg["batch_size"],
                                  pin_memory=cfg["pin"])

    # Get sample batch and some info
    first_sample = first(train_dataloader)
    if first_sample is None:
        raise ValueError("First sample is None!")
    for d in ["image", "label"]:
        logging.info(f"[{d}] \n"
                     f"  {d} shape: {first_sample[d].shape}\n"
                     f"  {d} type:  {type(first_sample[d])}\n"
                     f"  {d} dtype: {first_sample[d].dtype}")
    logging.info(f"Batch size: {cfg['batch_size']}")
    logging.info(f"[Training] number of batches: {len(train_dataloader)}")
    logging.info(f"[Validation] number of batches: {len(valid_dataloader)}")
    # -------------------------------------------------------------------------
    # Deep Learning Model and Configurations
    # -------------------------------------------------------------------------
    # Initialize model (``num_classes`` replaces the deprecated ``n_classes``).
    model = TorchVisionFCModel("resnet18",
                               num_classes=1,
                               use_conv=True,
                               pretrained=cfg["pretrain"])
    model = model.to(device)

    # Loss function
    loss_func = torch.nn.BCEWithLogitsLoss()
    loss_func = loss_func.to(device)

    # Optimizer
    if cfg["novograd"] is True:
        optimizer = Novograd(model.parameters(), lr=cfg["lr"])
    else:
        optimizer = SGD(model.parameters(), lr=cfg["lr"], momentum=0.9)

    # AMP scaler (native AMP needs torch >= 1.6)
    cfg["amp"] = cfg["amp"] and monai.utils.get_torch_version_tuple() >= (1, 6)
    if cfg["amp"] is True:
        scaler = GradScaler()
    else:
        scaler = None

    # Learning rate scheduler
    if cfg["cos"] is True:
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                                   T_max=cfg["n_epochs"])
    else:
        scheduler = None

    # -------------------------------------------------------------------------
    # Training/Evaluating
    # -------------------------------------------------------------------------
    # training() receives and returns this counter; its "epoch" entry names
    # the per-epoch checkpoint files saved below.
    train_counter = {"n_epochs": cfg["n_epochs"], "epoch": 1, "step": 1}

    total_valid_time, total_train_time = 0.0, 0.0
    t_start = time.perf_counter()
    # np.inf (lower-case) because the np.Inf alias was removed in NumPy 2.0.
    metric_summary = {"loss": np.inf, "accuracy": 0, "best_epoch": 1}
    # Training/Validation Loop
    for _ in range(cfg["n_epochs"]):
        t_epoch = time.perf_counter()
        logging.info(
            f"[Training] learning rate: {optimizer.param_groups[0]['lr']}")

        # Training
        with Range("Training Epoch"):
            train_counter = training(
                train_counter,
                model,
                loss_func,
                optimizer,
                scaler,
                cfg["amp"],
                train_dataloader,
                preprocess_gpu_train,
                postprocess,
                device,
                writer,
                cfg["print_step"],
            )
        if scheduler is not None:
            scheduler.step()
        if cfg["save"]:
            torch.save(
                model.state_dict(),
                os.path.join(log_dir,
                             f"model_epoch_{train_counter['epoch']}.pt"))
        t_train = time.perf_counter()
        train_time = t_train - t_epoch
        total_train_time += train_time

        # Validation
        if cfg["validate"]:
            with Range("Validation"):
                valid_loss, valid_acc = validation(
                    model,
                    loss_func,
                    cfg["amp"],
                    valid_dataloader,
                    preprocess_gpu_valid,
                    postprocess,
                    device,
                    cfg["print_step"],
                )
            t_valid = time.perf_counter()
            valid_time = t_valid - t_train
            total_valid_time += valid_time
            if valid_loss < metric_summary["loss"]:
                # Record the metrics of the lowest-validation-loss epoch; that
                # epoch's checkpoint becomes model_best.pth.
                metric_summary["loss"] = valid_loss
                metric_summary["accuracy"] = valid_acc
                metric_summary["best_epoch"] = train_counter["epoch"]
            writer.add_scalar("valid/loss", valid_loss, train_counter["epoch"])
            writer.add_scalar("valid/accuracy", valid_acc,
                              train_counter["epoch"])

            logging.info(
                f"[Epoch: {train_counter['epoch']}/{cfg['n_epochs']}] loss: {valid_loss:.3f}, accuracy: {valid_acc:.2f}, "
                f"time: {t_valid - t_epoch:.1f}s (train: {train_time:.1f}s, valid: {valid_time:.1f}s)"
            )
        else:
            logging.info(
                f"[Epoch: {train_counter['epoch']}/{cfg['n_epochs']}] Train time: {train_time:.1f}s"
            )
        writer.flush()
    t_end = time.perf_counter()

    # Save final metrics (guard against n_epochs == 0)
    n_epochs = cfg["n_epochs"]
    metric_summary["train_time_per_epoch"] = (total_train_time / n_epochs
                                              if n_epochs else 0.0)
    metric_summary["total_time"] = t_end - t_start
    writer.add_hparams(hparam_dict=cfg,
                       metric_dict=metric_summary,
                       run_name=log_dir)
    writer.close()
    logging.info(f"Metric Summary: {metric_summary}")

    # Save the best and final model. Per-epoch checkpoints are only written
    # when cfg["save"] is set, and are named model_epoch_<epoch>.pt in the
    # loop above; the last one uses the final value of train_counter["epoch"].
    if cfg["validate"] is True:
        if cfg["save"]:
            for src_epoch, dst_name in (
                (metric_summary["best_epoch"], "model_best.pth"),
                (train_counter["epoch"], "model_final.pth"),
            ):
                src = os.path.join(log_dir, f"model_epoch_{src_epoch}.pt")
                if os.path.isfile(src):
                    copyfile(src, os.path.join(log_dir, dst_name))
                else:
                    logging.warning(
                        f"Checkpoint not found, skipping copy to {dst_name}: {src}"
                    )
        else:
            logging.warning(
                "cfg['save'] is disabled: no epoch checkpoints to copy.")

    # Final prints
    logging.info(
        f"[Completed] {train_counter['epoch']} epochs -- time: {t_end - t_start:.1f}s "
        f"(training: {total_train_time:.1f}s, validation: {total_valid_time:.1f}s)",
    )
    logging.info(f"Logs and model was saved at: {log_dir}")
Ejemplo n.º 6
0
def train(cfg):
    """Train a binary WSI patch classifier with MONAI's Ignite workflows.

    Builds ``PatchWSIDataset`` loaders from a Decathlon-style datalist and
    trains a ``TorchVisionFCModel("resnet18")`` (one output,
    ``BCEWithLogitsLoss``, cosine LR schedule) with a ``SupervisedTrainer``.
    A ``SupervisedEvaluator`` runs after every epoch. Checkpoints and
    TensorBoard logs from both trainer and evaluator are written to the
    directory returned by ``create_log_dir(cfg)``.

    Args:
        cfg: configuration dict. Keys read here include ``dataset_json``,
            ``data_root``, ``region_size``, ``grid_shape``, ``patch_size``,
            ``use_openslide``, ``num_workers``, ``batch_size``, ``pretrain``,
            ``novograd``, ``lr``, ``amp`` and ``n_epochs``.
            ``cfg["amp"]`` is overwritten with whether AMP is actually
            usable (torch >= 1.6).

    Raises:
        ValueError: if the first training batch is ``None``.
    """
    log_dir = create_log_dir(cfg)
    device = set_device(cfg)
    # --------------------------------------------------------------------------
    # Data Loading and Preprocessing
    # --------------------------------------------------------------------------
    # __________________________________________________________________________
    # Build MONAI preprocessing
    train_preprocess = Compose([
        # ColorJitter runs through TorchVisionD on a tensor; convert back to
        # numpy for the remaining augmentations.
        ToTensorD(keys="image"),
        TorchVisionD(keys="image",
                     name="ColorJitter",
                     brightness=64.0 / 255.0,
                     contrast=0.75,
                     saturation=0.25,
                     hue=0.04),
        ToNumpyD(keys="image"),
        RandFlipD(keys="image", prob=0.5),
        RandRotate90D(keys="image", prob=0.5),
        CastToTypeD(keys="image", dtype=np.float32),
        RandZoomD(keys="image", prob=0.5, min_zoom=0.9, max_zoom=1.1),
        ScaleIntensityRangeD(keys="image",
                             a_min=0.0,
                             a_max=255.0,
                             b_min=-1.0,
                             b_max=1.0),
        ToTensorD(keys=("image", "label")),
    ])
    valid_preprocess = Compose([
        CastToTypeD(keys="image", dtype=np.float32),
        ScaleIntensityRangeD(keys="image",
                             a_min=0.0,
                             a_max=255.0,
                             b_min=-1.0,
                             b_max=1.0),
        ToTensorD(keys=("image", "label")),
    ])
    # __________________________________________________________________________
    # Create MONAI dataset
    train_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="training",
        base_dir=cfg["data_root"],
    )
    valid_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="validation",
        base_dir=cfg["data_root"],
    )

    train_dataset = PatchWSIDataset(
        train_json_info_list,
        cfg["region_size"],
        cfg["grid_shape"],
        cfg["patch_size"],
        train_preprocess,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )
    valid_dataset = PatchWSIDataset(
        valid_json_info_list,
        cfg["region_size"],
        cfg["grid_shape"],
        cfg["patch_size"],
        valid_preprocess,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )

    # __________________________________________________________________________
    # DataLoaders
    train_dataloader = DataLoader(train_dataset,
                                  num_workers=cfg["num_workers"],
                                  batch_size=cfg["batch_size"],
                                  pin_memory=True)
    valid_dataloader = DataLoader(valid_dataset,
                                  num_workers=cfg["num_workers"],
                                  batch_size=cfg["batch_size"],
                                  pin_memory=True)

    # __________________________________________________________________________
    # Get sample batch and some info
    first_sample = first(train_dataloader)
    if first_sample is None:
        raise ValueError("First sample is None!")

    for key, title in (("image", "image: "), ("label", "labels: ")):
        print(title)
        print("    shape", first_sample[key].shape)
        print("    type: ", type(first_sample[key]))
        print("    dtype: ", first_sample[key].dtype)
    print(f"batch size: {cfg['batch_size']}")
    print(f"train number of batches: {len(train_dataloader)}")
    print(f"valid number of batches: {len(valid_dataloader)}")

    # --------------------------------------------------------------------------
    # Deep Learning Classification Model
    # --------------------------------------------------------------------------
    # __________________________________________________________________________
    # initialize model
    model = TorchVisionFCModel("resnet18",
                               num_classes=1,
                               use_conv=True,
                               pretrained=cfg["pretrain"])
    model = model.to(device)

    # loss function
    loss_func = torch.nn.BCEWithLogitsLoss()
    loss_func = loss_func.to(device)

    # optimizer
    if cfg["novograd"]:
        optimizer = Novograd(model.parameters(), cfg["lr"])
    else:
        optimizer = SGD(model.parameters(), lr=cfg["lr"], momentum=0.9)

    # AMP: enabled only when requested and supported (torch >= 1.6).
    cfg["amp"] = bool(cfg["amp"]) and monai.utils.get_torch_version_tuple() >= (1, 6)

    scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                               T_max=cfg["n_epochs"])

    # --------------------------------------------
    # Ignite Trainer/Evaluator
    # --------------------------------------------
    # Evaluator
    val_handlers = [
        CheckpointSaver(save_dir=log_dir,
                        save_dict={"net": model},
                        save_key_metric=True),
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=log_dir,
                                output_transform=lambda x: None),
    ]
    val_postprocessing = Compose([
        ActivationsD(keys="pred", sigmoid=True),
        AsDiscreteD(keys="pred", threshold=0.5)
    ])
    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=valid_dataloader,
        network=model,
        postprocessing=val_postprocessing,
        key_val_metric={
            "val_acc":
            Accuracy(output_transform=from_engine(["pred", "label"]))
        },
        val_handlers=val_handlers,
        amp=cfg["amp"],
    )

    # Trainer. Its checkpoints and TensorBoard logs go to the same log_dir as
    # the evaluator's, so one run's outputs stay together.
    train_handlers = [
        LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
        CheckpointSaver(save_dir=log_dir,
                        save_dict={
                            "net": model,
                            "opt": optimizer
                        },
                        save_interval=1,
                        epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=from_engine(["loss"], first=True)),
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        TensorBoardStatsHandler(log_dir=log_dir,
                                tag_name="train_loss",
                                output_transform=from_engine(["loss"],
                                                             first=True)),
    ]
    train_postprocessing = Compose([
        ActivationsD(keys="pred", sigmoid=True),
        AsDiscreteD(keys="pred", threshold=0.5)
    ])

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=cfg["n_epochs"],
        train_data_loader=train_dataloader,
        network=model,
        optimizer=optimizer,
        loss_function=loss_func,
        postprocessing=train_postprocessing,
        key_train_metric={
            "train_acc":
            Accuracy(output_transform=from_engine(["pred", "label"]))
        },
        train_handlers=train_handlers,
        amp=cfg["amp"],
    )
    trainer.run()