def test_nvtx_transforms_dict(self, input):
    # with prob == 0.0
    transforms = Compose([
        RandMarkD("Mark: Transforms (p=0) Start!"),
        RandRangePushD("Range: RandFlipD"),
        RandFlipD(keys="image", prob=0.0),
        RandRangePopD(),
        RangePushD("Range: ToTensorD"),
        ToTensorD(keys="image"),
        RangePopD(),
        MarkD("Mark: Transforms (p=0) End!"),
    ])
    output = transforms(input)
    self.assertIsInstance(output["image"], torch.Tensor)
    np.testing.assert_array_equal(input["image"], output["image"])
    # with prob == 1.0
    transforms = Compose([
        RandMarkD("Mark: Transforms (p=1) Start!"),
        RandRangePushD("Range: RandFlipD"),
        RandFlipD(keys="image", prob=1.0),
        RandRangePopD(),
        RangePushD("Range: ToTensorD"),
        ToTensorD(keys="image"),
        RangePopD(),
        MarkD("Mark: Transforms (p=1) End!"),
    ])
    output = transforms(input)
    self.assertIsInstance(output["image"], torch.Tensor)
    np.testing.assert_array_equal(input["image"],
                                  Flip()(output["image"].numpy()))
def test_transform_dict(self, input):
    transforms = Compose([
        Range("random flip dict")(FlipD(keys="image")),
        Range()(ToTensorD("image"))
    ])
    # Apply transforms
    output = transforms(input)["image"]

    # Decorate with NVTX Range
    transforms1 = Range()(transforms)
    transforms2 = Range("Transforms2")(transforms)
    transforms3 = Range(name="Transforms3", methods="__call__")(transforms)

    # Apply transforms with Range
    output1 = transforms1(input)["image"]
    output2 = transforms2(input)["image"]
    output3 = transforms3(input)["image"]

    # Check the outputs
    self.assertIsInstance(output, torch.Tensor)
    self.assertIsInstance(output1, torch.Tensor)
    self.assertIsInstance(output2, torch.Tensor)
    self.assertIsInstance(output3, torch.Tensor)
    np.testing.assert_equal(output.numpy(), output1.numpy())
    np.testing.assert_equal(output.numpy(), output2.numpy())
    np.testing.assert_equal(output.numpy(), output3.numpy())
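# Note: the NVTX marks/ranges above show up as named events on an Nsight
# Systems timeline when the test is run under a profiler, e.g.
# `nsys profile python -m pytest ...` (an illustrative command, not part
# of the original test).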
Example #3
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    set_determinism(12345)
    device = torch.device("cuda:0")

    # load real data
    mednist_url = "https://www.dropbox.com/s/5wwskxctvcxiuea/MedNIST.tar.gz?dl=1"
    md5_value = "0bc7306e7427e00ad1c5526a6677552d"
    extract_dir = "data"
    tar_save_path = os.path.join(extract_dir, "MedNIST.tar.gz")
    download_and_extract(mednist_url, tar_save_path, extract_dir, md5_value)
    hand_dir = os.path.join(extract_dir, "MedNIST", "Hand")
    real_data = [{
        "hand": os.path.join(hand_dir, filename)
    } for filename in os.listdir(hand_dir)]

    # define real data transforms
    train_transforms = Compose([
        LoadPNGD(keys=["hand"]),
        AddChannelD(keys=["hand"]),
        ScaleIntensityD(keys=["hand"]),
        RandRotateD(keys=["hand"], range_x=15, prob=0.5, keep_size=True),
        RandFlipD(keys=["hand"], spatial_axis=0, prob=0.5),
        RandZoomD(keys=["hand"], min_zoom=0.9, max_zoom=1.1, prob=0.5),
        ToTensorD(keys=["hand"]),
    ])

    # create dataset and dataloader
    real_dataset = CacheDataset(real_data, train_transforms)
    batch_size = 300
    real_dataloader = DataLoader(real_dataset,
                                 batch_size=batch_size,
                                 shuffle=True,
                                 num_workers=10)

    # define function to process batchdata for input into discriminator
    def prepare_batch(batchdata):
        """
        Process the DataLoader batchdata dict and return the image tensors for the discriminator Inferer
        """
        return batchdata["hand"]

    # define networks
    disc_net = Discriminator(in_shape=(1, 64, 64),
                             channels=(8, 16, 32, 64, 1),
                             strides=(2, 2, 2, 2, 1),
                             num_res_units=1,
                             kernel_size=5).to(device)

    latent_size = 64
    gen_net = Generator(latent_shape=latent_size,
                        start_shape=(latent_size, 8, 8),
                        channels=[32, 16, 8, 1],
                        strides=[2, 2, 2, 1])

    # initialize both networks
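    # `normal_init` is defined elsewhere in the full script (presumably
    # MONAI's monai.networks.utils.normal_init, which draws conv/norm
    # weights from a normal distribution with std 0.02; an assumption,
    # not shown in this snippet)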
    disc_net.apply(normal_init)
    gen_net.apply(normal_init)

    # input images are scaled to [0, 1], so enforce the same range for generated outputs
    gen_net.conv.add_module("activation", torch.nn.Sigmoid())
    gen_net = gen_net.to(device)

    # create optimizers and loss functions
    learning_rate = 2e-4
    betas = (0.5, 0.999)
    disc_opt = torch.optim.Adam(disc_net.parameters(),
                                learning_rate,
                                betas=betas)
    gen_opt = torch.optim.Adam(gen_net.parameters(),
                               learning_rate,
                               betas=betas)

    disc_loss_criterion = torch.nn.BCELoss()
    gen_loss_criterion = torch.nn.BCELoss()
    real_label = 1
    fake_label = 0

    def discriminator_loss(gen_images, real_images):
        """
        The discriminator loss is calculated by comparing the
        discriminator's predictions for real and generated images
        against the real and fake labels.
        """
        real = real_images.new_full((real_images.shape[0], 1), real_label)
        gen = gen_images.new_full((gen_images.shape[0], 1), fake_label)

        realloss = disc_loss_criterion(disc_net(real_images), real)
        genloss = disc_loss_criterion(disc_net(gen_images.detach()), gen)

        return (genloss + realloss) / 2

    def generator_loss(gen_images):
        """
        The generator loss is calculated from how realistic the
        discriminator judges the generated images to be.
        """
        output = disc_net(gen_images)
        cats = output.new_full(output.shape, real_label)
        return gen_loss_criterion(output, cats)

    # initialize current run dir
    run_dir = "model_out"
    print("Saving model output to: %s " % run_dir)

    # create workflow handlers
    handlers = [
        StatsHandler(
            name="batch_training_loss",
            output_transform=lambda x: {
                Keys.GLOSS: x[Keys.GLOSS],
                Keys.DLOSS: x[Keys.DLOSS]
            },
        ),
        CheckpointSaver(
            save_dir=run_dir,
            save_dict={
                "g_net": gen_net,
                "d_net": disc_net
            },
            save_interval=10,
            save_final=True,
            epoch_level=True,
        ),
    ]

    # define key metric
    key_train_metric = None

    # create adversarial trainer
    disc_train_steps = 5
    num_epochs = 50

    trainer = GanTrainer(
        device,
        num_epochs,
        real_dataloader,
        gen_net,
        gen_opt,
        generator_loss,
        disc_net,
        disc_opt,
        discriminator_loss,
        d_prepare_batch=prepare_batch,
        d_train_steps=disc_train_steps,
        latent_shape=latent_size,
        key_train_metric=key_train_metric,
        train_handlers=handlers,
    )

    # run GAN training
    trainer.run()

    # Training completed, save a few random generated images.
    print("Saving trained generator sample output.")
    test_img_count = 10
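    # `make_latent` (defined elsewhere) is assumed to sample latent vectors
    # of shape (test_img_count, latent_size), e.g. via torch.randn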
    test_latents = make_latent(test_img_count, latent_size).to(device)
    fakes = gen_net(test_latents)
    for i, image in enumerate(fakes):
        filename = "gen-fake-final-%d.png" % (i)
        save_path = os.path.join(run_dir, filename)
        img_array = image[0].cpu().data.numpy()
        png_writer.write_png(img_array, save_path, scale=255)
train_datadict = [{"im": fname} for fname in all_filenames[:num_train]]
test_datadict = [{"im": fname} for fname in all_filenames[-num_test:]]
# test_datadict = [{"im": fname} for fname in all_testNames[:len(all_testNames)]]

print(f"total number of images: {num_ims}")
print(f"number of images for training: {len(train_datadict)}")
print(f"number of images for testing: {len(test_datadict)}")

batch_size = 16
num_workers = 0

transforms = Compose([
    LoadImageD(keys=["im"]),
    AddChannelD(keys=["im"]),
    ScaleIntensityD(keys=["im"]),
    ToTensorD(keys=["im"]),
])

train_ds = CacheDataset(train_datadict, transforms, num_workers=num_workers)
train_loader = DataLoader(train_ds,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=num_workers)
test_ds = CacheDataset(test_datadict, transforms, num_workers=num_workers)
test_loader = DataLoader(test_ds,
                         batch_size=batch_size,
                         shuffle=True,
                         num_workers=num_workers)

# For our loss we'll want to use a combination of a reconstruction loss
# (here, BCE) and KLD. By increasing the importance of the KLD loss with
# beta, we encourage the network to disentangle the latent generative
# factors.
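# A minimal sketch of that beta-weighted VAE loss (an illustration; the
# function name and the arguments `recon_x`, `x`, `mu`, `logvar`, `beta`
# are assumptions, not part of the original notebook):
import torch
import torch.nn.functional as F

def beta_vae_loss(recon_x, x, mu, logvar, beta=1.0):
    # Reconstruction term: pixel-wise binary cross-entropy (inputs in [0, 1]).
    bce = F.binary_cross_entropy(recon_x, x, reduction="sum")
    # Closed-form KL divergence between N(mu, sigma^2) and the unit Gaussian.
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # beta > 1 up-weights the KLD term, encouraging disentanglement.
    return bce + beta * kld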
Example #5
def main(cfg):
    # -------------------------------------------------------------------------
    # Configs
    # -------------------------------------------------------------------------
    # Create log/model dir
    log_dir = create_log_dir(cfg)

    # Set the logger
    logging.basicConfig(
        format="%(asctime)s %(levelname)2s %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    log_name = os.path.join(log_dir, "logs.txt")
    logger = logging.getLogger()
    fh = logging.FileHandler(log_name)
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)

    # Set TensorBoard summary writer
    writer = SummaryWriter(log_dir)

    # Save configs
    logging.info(json.dumps(cfg))
    with open(os.path.join(log_dir, "config.json"), "w") as fp:
        json.dump(cfg, fp, indent=4)

    # Set device cuda/cpu
    device = set_device(cfg)

    # Set cudnn benchmark/deterministic
    if cfg["benchmark"]:
        torch.backends.cudnn.benchmark = True
    else:
        set_determinism(seed=0)
    # -------------------------------------------------------------------------
    # Transforms and Datasets
    # -------------------------------------------------------------------------
    # Pre-processing
    preprocess_cpu_train = None
    preprocess_gpu_train = None
    preprocess_cpu_valid = None
    preprocess_gpu_valid = None
    if cfg["backend"] == "cucim":
        preprocess_cpu_train = Compose([ToTensorD(keys="label")])
        preprocess_gpu_train = Compose([
            Range()(ToCupy()),
            Range("ColorJitter")(RandCuCIM(name="color_jitter",
                                           brightness=64.0 / 255.0,
                                           contrast=0.75,
                                           saturation=0.25,
                                           hue=0.04)),
            Range("RandomFlip")(RandCuCIM(name="image_flip",
                                          apply_prob=cfg["prob"],
                                          spatial_axis=-1)),
            Range("RandomRotate90")(RandCuCIM(name="rand_image_rotate_90",
                                              prob=cfg["prob"],
                                              max_k=3,
                                              spatial_axis=(-2, -1))),
            Range()(CastToType(dtype=np.float32)),
            Range("RandomZoom")(RandCuCIM(name="rand_zoom",
                                          min_zoom=0.9,
                                          max_zoom=1.1)),
            Range("ScaleIntensity")(CuCIM(name="scale_intensity_range",
                                          a_min=0.0,
                                          a_max=255.0,
                                          b_min=-1.0,
                                          b_max=1.0)),
            Range()(ToTensor(device=device)),
        ])
        preprocess_cpu_valid = Compose([ToTensorD(keys="label")])
        preprocess_gpu_valid = Compose([
            Range("ValidToCupyAndCast")(ToCupy(dtype=np.float32)),
            Range("ValidScaleIntensity")(CuCIM(name="scale_intensity_range",
                                               a_min=0.0,
                                               a_max=255.0,
                                               b_min=-1.0,
                                               b_max=1.0)),
            Range("ValidToTensor")(ToTensor(device=device)),
        ])
    elif cfg["backend"] == "numpy":
        preprocess_cpu_train = Compose([
            Range()(ToTensorD(keys=("image", "label"))),
            Range("ColorJitter")(TorchVisionD(
                keys="image",
                name="ColorJitter",
                brightness=64.0 / 255.0,
                contrast=0.75,
                saturation=0.25,
                hue=0.04,
            )),
            Range()(ToNumpyD(keys="image")),
            Range("RandomFlip")(RandFlipD(keys="image",
                                          prob=cfg["prob"],
                                          spatial_axis=-1)),
            Range("RandomRotate90")(RandRotate90D(keys="image",
                                                  prob=cfg["prob"])),
            Range()(CastToTypeD(keys="image", dtype=np.float32)),
            Range("RandomZoom")(RandZoomD(keys="image",
                                          prob=cfg["prob"],
                                          min_zoom=0.9,
                                          max_zoom=1.1)),
            Range("ScaleIntensity")(ScaleIntensityRangeD(keys="image",
                                                         a_min=0.0,
                                                         a_max=255.0,
                                                         b_min=-1.0,
                                                         b_max=1.0)),
            Range()(ToTensorD(keys="image")),
        ])
        preprocess_cpu_valid = Compose([
            Range("ValidCastType")(CastToTypeD(keys="image",
                                               dtype=np.float32)),
            Range("ValidScaleIntensity")(ScaleIntensityRangeD(keys="image",
                                                              a_min=0.0,
                                                              a_max=255.0,
                                                              b_min=-1.0,
                                                              b_max=1.0)),
            Range("ValidToTensor")(ToTensorD(keys=("image", "label"))),
        ])
    else:
        raise ValueError(
            f"Backend should be either 'numpy' or 'cucim'; '{cfg['backend']}' was provided."
        )

    # Post-processing
    postprocess = Compose([
        Activations(sigmoid=True),
        AsDiscrete(threshold=0.5),
    ])

    # Create MONAI dataset
    train_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="training",
        base_dir=cfg["data_root"],
    )
    valid_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="validation",
        base_dir=cfg["data_root"],
    )
    train_dataset = PatchWSIDataset(
        data=train_json_info_list,
        region_size=cfg["region_size"],
        grid_shape=cfg["grid_shape"],
        patch_size=cfg["patch_size"],
        transform=preprocess_cpu_train,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )
    valid_dataset = PatchWSIDataset(
        data=valid_json_info_list,
        region_size=cfg["region_size"],
        grid_shape=cfg["grid_shape"],
        patch_size=cfg["patch_size"],
        transform=preprocess_cpu_valid,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )

    # DataLoaders
    train_dataloader = DataLoader(train_dataset,
                                  num_workers=cfg["num_workers"],
                                  batch_size=cfg["batch_size"],
                                  pin_memory=cfg["pin"])
    valid_dataloader = DataLoader(valid_dataset,
                                  num_workers=cfg["num_workers"],
                                  batch_size=cfg["batch_size"],
                                  pin_memory=cfg["pin"])

    # Get sample batch and some info
    first_sample = first(train_dataloader)
    if first_sample is None:
        raise ValueError("First sample is None!")
    for d in ["image", "label"]:
        logging.info(f"[{d}] \n"
                     f"  {d} shape: {first_sample[d].shape}\n"
                     f"  {d} type:  {type(first_sample[d])}\n"
                     f"  {d} dtype: {first_sample[d].dtype}")
    logging.info(f"Batch size: {cfg['batch_size']}")
    logging.info(f"[Training] number of batches: {len(train_dataloader)}")
    logging.info(f"[Validation] number of batches: {len(valid_dataloader)}")
    # -------------------------------------------------------------------------
    # Deep Learning Model and Configurations
    # -------------------------------------------------------------------------
    # Initialize model
    model = TorchVisionFCModel("resnet18",
                               n_classes=1,
                               use_conv=True,
                               pretrained=cfg["pretrain"])
    model = model.to(device)

    # Loss function
    loss_func = torch.nn.BCEWithLogitsLoss()
    loss_func = loss_func.to(device)

    # Optimizer
    if cfg["novograd"] is True:
        optimizer = Novograd(model.parameters(), lr=cfg["lr"])
    else:
        optimizer = SGD(model.parameters(), lr=cfg["lr"], momentum=0.9)

    # AMP scaler
    cfg["amp"] = cfg["amp"] and monai.utils.get_torch_version_tuple() >= (1, 6)
    if cfg["amp"] is True:
        scaler = GradScaler()
    else:
        scaler = None

    # Learning rate scheduler
    if cfg["cos"] is True:
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                                   T_max=cfg["n_epochs"])
    else:
        scheduler = None

    # -------------------------------------------------------------------------
    # Training/Evaluating
    # -------------------------------------------------------------------------
    train_counter = {"n_epochs": cfg["n_epochs"], "epoch": 1, "step": 1}

    total_valid_time, total_train_time = 0.0, 0.0
    t_start = time.perf_counter()
    metric_summary = {"loss": np.inf, "accuracy": 0, "best_epoch": 1}
    # Training/Validation Loop
    for _ in range(cfg["n_epochs"]):
        t_epoch = time.perf_counter()
        logging.info(
            f"[Training] learning rate: {optimizer.param_groups[0]['lr']}")

        # Training
        with Range("Training Epoch"):
            train_counter = training(
                train_counter,
                model,
                loss_func,
                optimizer,
                scaler,
                cfg["amp"],
                train_dataloader,
                preprocess_gpu_train,
                postprocess,
                device,
                writer,
                cfg["print_step"],
            )
        if scheduler is not None:
            scheduler.step()
        if cfg["save"]:
            torch.save(
                model.state_dict(),
                os.path.join(log_dir,
                             f"model_epoch_{train_counter['epoch']}.pt"))
        t_train = time.perf_counter()
        train_time = t_train - t_epoch
        total_train_time += train_time

        # Validation
        if cfg["validate"]:
            with Range("Validation"):
                valid_loss, valid_acc = validation(
                    model,
                    loss_func,
                    cfg["amp"],
                    valid_dataloader,
                    preprocess_gpu_valid,
                    postprocess,
                    device,
                    cfg["print_step"],
                )
            t_valid = time.perf_counter()
            valid_time = t_valid - t_train
            total_valid_time += valid_time
            if valid_loss < metric_summary["loss"]:
                metric_summary["loss"] = min(valid_loss,
                                             metric_summary["loss"])
                metric_summary["accuracy"] = max(valid_acc,
                                                 metric_summary["accuracy"])
                metric_summary["best_epoch"] = train_counter["epoch"]
            writer.add_scalar("valid/loss", valid_loss, train_counter["epoch"])
            writer.add_scalar("valid/accuracy", valid_acc,
                              train_counter["epoch"])

            logging.info(
                f"[Epoch: {train_counter['epoch']}/{cfg['n_epochs']}] loss: {valid_loss:.3f}, accuracy: {valid_acc:.2f}, "
                f"time: {t_valid - t_epoch:.1f}s (train: {train_time:.1f}s, valid: {valid_time:.1f}s)"
            )
        else:
            logging.info(
                f"[Epoch: {train_counter['epoch']}/{cfg['n_epochs']}] Train time: {train_time:.1f}s"
            )
        writer.flush()
    t_end = time.perf_counter()

    # Save final metrics
    metric_summary["train_time_per_epoch"] = total_train_time / cfg["n_epochs"]
    metric_summary["total_time"] = t_end - t_start
    writer.add_hparams(hparam_dict=cfg,
                       metric_dict=metric_summary,
                       run_name=log_dir)
    writer.close()
    logging.info(f"Metric Summary: {metric_summary}")

    # Save the best and final model
    # assumes cfg["save"] was enabled so the per-epoch checkpoints exist
    if cfg["validate"] is True:
        copyfile(
            os.path.join(log_dir,
                         f"model_epoch_{metric_summary['best_epoch']}.pt"),
            os.path.join(log_dir, "model_best.pt"),
        )
        copyfile(
            os.path.join(log_dir, f"model_epoch_{cfg['n_epochs']}.pt"),
            os.path.join(log_dir, "model_final.pt"),
        )

    # Final prints
    logging.info(
        f"[Completed] {train_counter['epoch']} epochs -- time: {t_end - t_start:.1f}s "
        f"(training: {total_train_time:.1f}s, validation: {total_valid_time:.1f}s)",
    )
    logging.info(f"Logs and model was saved at: {log_dir}")
Example #6
def train(cfg):
    log_dir = create_log_dir(cfg)
    device = set_device(cfg)
    # --------------------------------------------------------------------------
    # Data Loading and Preprocessing
    # --------------------------------------------------------------------------
    # __________________________________________________________________________
    # Build MONAI preprocessing
    train_preprocess = Compose([
        ToTensorD(keys="image"),
        TorchVisionD(keys="image",
                     name="ColorJitter",
                     brightness=64.0 / 255.0,
                     contrast=0.75,
                     saturation=0.25,
                     hue=0.04),
        ToNumpyD(keys="image"),
        RandFlipD(keys="image", prob=0.5),
        RandRotate90D(keys="image", prob=0.5),
        CastToTypeD(keys="image", dtype=np.float32),
        RandZoomD(keys="image", prob=0.5, min_zoom=0.9, max_zoom=1.1),
        ScaleIntensityRangeD(keys="image",
                             a_min=0.0,
                             a_max=255.0,
                             b_min=-1.0,
                             b_max=1.0),
        ToTensorD(keys=("image", "label")),
    ])
    valid_preprocess = Compose([
        CastToTypeD(keys="image", dtype=np.float32),
        ScaleIntensityRangeD(keys="image",
                             a_min=0.0,
                             a_max=255.0,
                             b_min=-1.0,
                             b_max=1.0),
        ToTensorD(keys=("image", "label")),
    ])
    # __________________________________________________________________________
    # Create MONAI dataset
    train_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="training",
        base_dir=cfg["data_root"],
    )
    valid_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="validation",
        base_dir=cfg["data_root"],
    )

    train_dataset = PatchWSIDataset(
        train_json_info_list,
        cfg["region_size"],
        cfg["grid_shape"],
        cfg["patch_size"],
        train_preprocess,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )
    valid_dataset = PatchWSIDataset(
        valid_json_info_list,
        cfg["region_size"],
        cfg["grid_shape"],
        cfg["patch_size"],
        valid_preprocess,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )

    # __________________________________________________________________________
    # DataLoaders
    train_dataloader = DataLoader(train_dataset,
                                  num_workers=cfg["num_workers"],
                                  batch_size=cfg["batch_size"],
                                  pin_memory=True)
    valid_dataloader = DataLoader(valid_dataset,
                                  num_workers=cfg["num_workers"],
                                  batch_size=cfg["batch_size"],
                                  pin_memory=True)

    # __________________________________________________________________________
    # Get sample batch and some info
    first_sample = first(train_dataloader)
    if first_sample is None:
        raise ValueError("Fist sample is None!")

    print("image: ")
    print("    shape", first_sample["image"].shape)
    print("    type: ", type(first_sample["image"]))
    print("    dtype: ", first_sample["image"].dtype)
    print("labels: ")
    print("    shape", first_sample["label"].shape)
    print("    type: ", type(first_sample["label"]))
    print("    dtype: ", first_sample["label"].dtype)
    print(f"batch size: {cfg['batch_size']}")
    print(f"train number of batches: {len(train_dataloader)}")
    print(f"valid number of batches: {len(valid_dataloader)}")

    # --------------------------------------------------------------------------
    # Deep Learning Classification Model
    # --------------------------------------------------------------------------
    # __________________________________________________________________________
    # initialize model
    model = TorchVisionFCModel("resnet18",
                               num_classes=1,
                               use_conv=True,
                               pretrained=cfg["pretrain"])
    model = model.to(device)

    # loss function
    loss_func = torch.nn.BCEWithLogitsLoss()
    loss_func = loss_func.to(device)

    # optimizer
    if cfg["novograd"]:
        optimizer = Novograd(model.parameters(), cfg["lr"])
    else:
        optimizer = SGD(model.parameters(), lr=cfg["lr"], momentum=0.9)

    # AMP scaler
    if cfg["amp"]:
        cfg["amp"] = True if monai.utils.get_torch_version_tuple() >= (
            1, 6) else False
    else:
        cfg["amp"] = False

    scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                               T_max=cfg["n_epochs"])

    # --------------------------------------------
    # Ignite Trainer/Evaluator
    # --------------------------------------------
    # Evaluator
    val_handlers = [
        CheckpointSaver(save_dir=log_dir,
                        save_dict={"net": model},
                        save_key_metric=True),
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=log_dir,
                                output_transform=lambda x: None),
    ]
    val_postprocessing = Compose([
        ActivationsD(keys="pred", sigmoid=True),
        AsDiscreteD(keys="pred", threshold=0.5)
    ])
    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=valid_dataloader,
        network=model,
        postprocessing=val_postprocessing,
        key_val_metric={
            "val_acc":
            Accuracy(output_transform=from_engine(["pred", "label"]))
        },
        val_handlers=val_handlers,
        amp=cfg["amp"],
    )

    # Trainer
    train_handlers = [
        LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
        CheckpointSaver(save_dir=cfg["logdir"],
                        save_dict={
                            "net": model,
                            "opt": optimizer
                        },
                        save_interval=1,
                        epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=from_engine(["loss"], first=True)),
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        TensorBoardStatsHandler(log_dir=cfg["logdir"],
                                tag_name="train_loss",
                                output_transform=from_engine(["loss"],
                                                             first=True)),
    ]
    train_postprocessing = Compose([
        ActivationsD(keys="pred", sigmoid=True),
        AsDiscreteD(keys="pred", threshold=0.5)
    ])

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=cfg["n_epochs"],
        train_data_loader=train_dataloader,
        network=model,
        optimizer=optimizer,
        loss_function=loss_func,
        postprocessing=train_postprocessing,
        key_train_metric={
            "train_acc":
            Accuracy(output_transform=from_engine(["pred", "label"]))
        },
        train_handlers=train_handlers,
        amp=cfg["amp"],
    )
    trainer.run()