示例#1
0
    def test_compute(self):
        """Distributed ROCAUC: each rank contributes per-sample (list-style) data."""
        auc_metric = ROCAUC()
        act = Activations(softmax=True)
        to_onehot = AsDiscrete(to_onehot=2)

        rank = dist.get_rank()
        device = f"cuda:{rank}" if torch.cuda.is_available() else "cpu"
        if rank == 0:
            y_pred = [
                torch.tensor([0.1, 0.9], device=device),
                torch.tensor([0.3, 1.4], device=device),
            ]
            y = [
                torch.tensor([0], device=device),
                torch.tensor([1], device=device),
            ]
        if rank == 1:
            y_pred = [
                torch.tensor([0.2, 0.1], device=device),
                torch.tensor([0.1, 0.5], device=device),
                torch.tensor([0.3, 0.4], device=device),
            ]
            y = [
                torch.tensor([0], device=device),
                torch.tensor([1], device=device),
                torch.tensor([1], device=device),
            ]

        # softmax each prediction, one-hot each label, then aggregate across ranks
        auc_metric.update([[act(p) for p in y_pred],
                           [to_onehot(y_) for y_ in y]])

        np.testing.assert_allclose(0.66667, auc_metric.compute(), rtol=1e-4)
    def test_compute(self):
        """Distributed ROCAUC: each rank contributes one batched tensor pair."""
        auc_metric = ROCAUC()
        act = Activations(softmax=True)
        to_onehot = AsDiscrete(to_onehot=True, n_classes=2)

        rank = dist.get_rank()
        device = f"cuda:{rank}" if torch.cuda.is_available() else "cpu"
        if rank == 0:
            y_pred = torch.tensor([[0.1, 0.9], [0.3, 1.4]], device=device)
            y = torch.tensor([[0], [1]], device=device)
        if rank == 1:
            y_pred = torch.tensor([[0.2, 0.1], [0.1, 0.5], [0.3, 0.4]],
                                  device=device)
            y = torch.tensor([[0], [1], [1]], device=device)

        # apply softmax / one-hot, then aggregate across ranks
        auc_metric.update([act(y_pred), to_onehot(y)])

        np.testing.assert_allclose(0.66667, auc_metric.compute(), rtol=1e-4)
示例#3
0
    def test_compute(self):
        """ROCAUC with built-in one-hot + softmax, accumulated over two updates."""
        auc_metric = ROCAUC(to_onehot_y=True, softmax=True)

        batches = (
            (torch.Tensor([[0.1, 0.9], [0.3, 1.4]]), torch.Tensor([[0], [1]])),
            (torch.Tensor([[0.2, 0.1], [0.1, 0.5]]), torch.Tensor([[0], [1]])),
        )
        for y_pred, y in batches:
            auc_metric.update([y_pred, y])

        np.testing.assert_allclose(0.75, auc_metric.compute())
示例#4
0
    def __init__(
        self,
        device: torch.device,
        val_data_loader: Union[Iterable, DataLoader],
        network: torch.nn.Module,
        loss_function,
        n_classes,
        patience=20,
        summary_writer: SummaryWriter = None,
        non_blocking: bool = False,
        post_transform: Optional[Transform] = None,
        amp: bool = False,
        mode: Union[ForwardMode, str] = ForwardMode.EVAL,
    ) -> None:
        """Validator tracking AUC, accuracy and loss, with early stopping.

        Args:
            device: device on which validation runs.
            val_data_loader: iterable of validation batches.
            network: model to evaluate.
            loss_function: criterion backing the ``Valid_Loss`` metric.
            n_classes: number of output classes; labels are one-hot encoded
                when ``n_classes > 1``.
            patience: epochs without key-metric improvement before stopping.
            summary_writer: optional TensorBoard writer, stored for handlers.
            non_blocking: forwarded to the base evaluator's device transfers.
            post_transform: optional transform applied to engine output.
            amp: evaluate under automatic mixed precision when True.
            mode: forward mode for the base evaluator.
        """
        self.summary_writer = summary_writer
        # Early stopping keyed on whichever metric the engine designates as key.
        self.early_stop_handler = EarlyStopHandler(
            patience=patience,
            score_function=lambda engine: engine.state.metrics[
                engine.state.key_metric_name])

        if n_classes > 1:
            # Fix: one-hot width now follows n_classes (was hard-coded to 2,
            # which broke any multi-class use beyond binary).
            to_onehot = AsDiscrete(to_onehot=True, n_classes=n_classes)
        else:
            # Single-logit case: labels pass through unchanged.
            to_onehot = lambda x: x

        super().__init__(
            device,
            val_data_loader,
            network,
            non_blocking=non_blocking,
            iteration_update=self._iteration,
            post_transform=post_transform,
            key_val_metric={
                "Valid_AUC":
                ROCAUC(average="micro",
                       output_transform=lambda x:
                       (x["pred"], to_onehot(x["label"])))
            },
            additional_metrics={
                "Valid_ACC":
                Accuracy(output_transform=lambda x: (AsDiscrete(
                    threshold_values=True)(x["pred"]), to_onehot(x["label"]))),
                "Valid_Loss":
                Loss(loss_fn=loss_function,
                     output_transform=lambda x: (x["pred"], x["label"])),
            },
            amp=amp,
            mode=mode)
示例#5
0
    def __init__(
        self,
        device: torch.device,
        test_data_loader: Union[Iterable, DataLoader],
        network: torch.nn.Module,
        load_dir: str,
        out_dir: str,
        n_classes,
        non_blocking: bool = False,
        post_transform: Optional[Transform] = None,
        amp: bool = False,
        mode: Union[ForwardMode, str] = ForwardMode.EVAL,
    ) -> None:
        """Tester that restores a checkpointed network and reports AUC/accuracy.

        Args:
            device: device on which testing runs.
            test_data_loader: iterable of test batches.
            network: model to evaluate; weights are loaded from ``load_dir``.
            load_dir: directory containing a ``network_key_metric*`` checkpoint.
            out_dir: directory for outputs, stored for later use.
            n_classes: number of output classes; labels are one-hot encoded
                when ``n_classes > 1``.
            non_blocking: forwarded to the base evaluator's device transfers.
            post_transform: optional transform applied to engine output.
            amp: evaluate under automatic mixed precision when True.
            mode: forward mode for the base evaluator.
        """
        self.load_dir = load_dir
        self.out_dir = out_dir

        if n_classes > 1:
            # Fix: one-hot width now follows n_classes (was hard-coded to 2,
            # which broke any multi-class use beyond binary).
            to_onehot = AsDiscrete(to_onehot=True, n_classes=n_classes)
        else:
            # Single-logit case: labels pass through unchanged.
            to_onehot = lambda x: x
        super().__init__(
            device,
            test_data_loader,
            network,
            non_blocking=non_blocking,
            post_transform=post_transform,
            key_val_metric={
                "Test_AUC":
                ROCAUC(average="micro",
                       output_transform=lambda x:
                       (x["pred"], to_onehot(x["label"])))
            },
            additional_metrics={
                "Test_ACC":
                Accuracy(output_transform=lambda x: (AsDiscrete(
                    threshold_values=True)(x["pred"]), to_onehot(x["label"])))
            },
            amp=amp,
            mode=mode)

        # Restore the best checkpoint saved by key metric during training.
        load_path = glob(os.path.join(self.load_dir, 'network_key_metric*'))[0]
        handlers = [
            StatsHandler(output_transform=lambda x: None),
            CheckpointLoader(load_path=load_path,
                             load_dict={"network": self.network}),
        ]
        self._register_handlers(handlers)
示例#6
0
    def test_compute(self):
        """Distributed ROCAUC where each rank submits its own batch directly."""
        auc_metric = ROCAUC(to_onehot_y=True, softmax=True)
        rank = dist.get_rank()
        device = f"cuda:{rank}" if torch.cuda.is_available() else "cpu"
        if rank == 0:
            auc_metric.update([
                torch.tensor([[0.1, 0.9], [0.3, 1.4]], device=device),
                torch.tensor([[0], [1]], device=device),
            ])
        if rank == 1:
            auc_metric.update([
                torch.tensor([[0.2, 0.1], [0.1, 0.5]], device=device),
                torch.tensor([[0], [1]], device=device),
            ])

        np.testing.assert_allclose(0.75, auc_metric.compute())
示例#7
0
    def test_compute(self):
        """ROCAUC with explicit softmax/one-hot transforms applied per batch."""
        auc_metric = ROCAUC()
        act = Activations(softmax=True)
        to_onehot = AsDiscrete(to_onehot=True, n_classes=2)

        batches = (
            (torch.Tensor([[0.1, 0.9], [0.3, 1.4]]), torch.Tensor([[0], [1]])),
            (torch.Tensor([[0.2, 0.1], [0.1, 0.5]]), torch.Tensor([[0], [1]])),
        )
        for y_pred, y in batches:
            auc_metric.update([act(y_pred), to_onehot(y)])

        np.testing.assert_allclose(0.75, auc_metric.compute())
示例#8
0
    def test_compute(self):
        """ROCAUC fed with per-sample (decollated) predictions and labels."""
        auc_metric = ROCAUC()
        act = Activations(softmax=True)
        # NOTE(review): other examples in this file pass n_classes=2 here;
        # confirm the installed AsDiscrete actually accepts num_classes.
        to_onehot = AsDiscrete(to_onehot=True, num_classes=2)

        batches = (
            ([torch.Tensor([0.1, 0.9]), torch.Tensor([0.3, 1.4])],
             [torch.Tensor([0]), torch.Tensor([1])]),
            ([torch.Tensor([0.2, 0.1]), torch.Tensor([0.1, 0.5])],
             [torch.Tensor([0]), torch.Tensor([1])]),
        )
        for preds, labels in batches:
            auc_metric.update([[act(p) for p in preds],
                               [to_onehot(lab) for lab in labels]])

        np.testing.assert_allclose(0.75, auc_metric.compute())
示例#9
0
def main():
    """Two-rank NCCL demo: each rank updates ROCAUC, then the global AUC is checked."""
    dist.init_process_group(backend="nccl", init_method="env://")

    auc_metric = ROCAUC(to_onehot_y=True, softmax=True)
    rank = dist.get_rank()

    if rank == 0:
        dev = torch.device("cuda:0")
        auc_metric.update([
            torch.tensor([[0.1, 0.9], [0.3, 1.4]], device=dev),
            torch.tensor([[0], [1]], device=dev),
        ])
    if rank == 1:
        dev = torch.device("cuda:1")
        auc_metric.update([
            torch.tensor([[0.2, 0.1], [0.1, 0.5]], device=dev),
            torch.tensor([[0], [1]], device=dev),
        ])

    # compute() aggregates predictions/labels across all ranks
    np.testing.assert_allclose(0.75, auc_metric.compute())

    dist.destroy_process_group()
示例#10
0
def main():
    """End-to-end 3D classification demo on IXI T1 volumes using Ignite.

    Trains DenseNet121 for 2-class (gender) classification with model
    checkpointing, TensorBoard logging, per-epoch validation (Accuracy + AUC)
    and early stopping on the validation accuracy.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    images = [
        "/workspace/data/medical/ixi/IXI-T1/IXI314-IOP-0889-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI249-Guys-1072-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI609-HH-2600-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI173-HH-1590-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI020-Guys-0700-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI342-Guys-0909-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI134-Guys-0780-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI577-HH-2661-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI066-Guys-0731-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI130-HH-1528-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz",
    ]
    # 2 binary labels for gender classification: man and woman
    labels = np.array(
        [0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0])
    # first 10 volumes for training, last 10 for validation
    train_files = [{
        "img": img,
        "label": label
    } for img, label in zip(images[:10], labels[:10])]
    val_files = [{
        "img": img,
        "label": label
    } for img, label in zip(images[-10:], labels[-10:])]

    # define transforms for image
    train_transforms = Compose([
        LoadNiftid(keys=["img"]),
        AddChanneld(keys=["img"]),
        ScaleIntensityd(keys=["img"]),
        Resized(keys=["img"], spatial_size=(96, 96, 96)),
        RandRotate90d(keys=["img"], prob=0.8, spatial_axes=[0, 2]),
        ToTensord(keys=["img"]),
    ])
    # validation uses the same pipeline minus the random rotation
    val_transforms = Compose([
        LoadNiftid(keys=["img"]),
        AddChanneld(keys=["img"]),
        ScaleIntensityd(keys=["img"]),
        Resized(keys=["img"], spatial_size=(96, 96, 96)),
        ToTensord(keys=["img"]),
    ])

    # define dataset, data loader
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    check_loader = DataLoader(check_ds,
                              batch_size=2,
                              num_workers=4,
                              pin_memory=torch.cuda.is_available())
    # sanity-check one batch before training
    check_data = monai.utils.misc.first(check_loader)
    print(check_data["img"].shape, check_data["label"])

    # create DenseNet121, CrossEntropyLoss and Adam optimizer
    net = monai.networks.nets.densenet.densenet121(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
    )
    loss = torch.nn.CrossEntropyLoss()
    lr = 1e-5
    opt = torch.optim.Adam(net.parameters(), lr)
    device = torch.device("cuda:0")

    # Ignite trainer expects batch=(img, label) and returns output=loss at every iteration,
    # user can add output_transform to return other values, like: y_pred, y, etc.
    def prepare_batch(batch, device=None, non_blocking=False):
        # adapt MONAI's dict-style batch to Ignite's (input, target) convention
        return _prepare_batch((batch["img"], batch["label"]), device,
                              non_blocking)

    trainer = create_supervised_trainer(net,
                                        opt,
                                        loss,
                                        device,
                                        False,
                                        prepare_batch=prepare_batch)

    # adding checkpoint handler to save models (network params and optimizer stats) during training
    checkpoint_handler = ModelCheckpoint("./runs/",
                                         "net",
                                         n_saved=10,
                                         require_empty=False)
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={
                                  "net": net,
                                  "opt": opt
                              })

    # StatsHandler prints loss at every iteration and print metrics at every epoch,
    # we don't set metrics for trainer here, so just print loss, user can also customize print functions
    # and can use output_transform to convert engine.state.output if it's not loss value
    train_stats_handler = StatsHandler(name="trainer")
    train_stats_handler.attach(trainer)

    # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
    train_tensorboard_stats_handler = TensorBoardStatsHandler()
    train_tensorboard_stats_handler.attach(trainer)

    # set parameters for validation
    validation_every_n_epochs = 1

    metric_name = "Accuracy"
    # add evaluation metric to the evaluator engine
    # NOTE(review): the kwarg is add_softmax here (older MONAI API); other
    # examples in this file use softmax=True — confirm against the installed version.
    val_metrics = {
        metric_name: Accuracy(),
        "AUC": ROCAUC(to_onehot_y=True, add_softmax=True)
    }
    # Ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration,
    # user can add output_transform to return other values
    evaluator = create_supervised_evaluator(net,
                                            val_metrics,
                                            device,
                                            True,
                                            prepare_batch=prepare_batch)

    # add stats event handler to print validation stats via evaluator
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x:
        None,  # no need to print loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch,
    )  # fetch global epoch number from trainer
    val_stats_handler.attach(evaluator)

    # add handler to record metrics to TensorBoard at every epoch
    val_tensorboard_stats_handler = TensorBoardStatsHandler(
        output_transform=lambda x:
        None,  # no need to plot loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch,
    )  # fetch global epoch number from trainer
    val_tensorboard_stats_handler.attach(evaluator)

    # add early stopping handler to evaluator
    early_stopper = EarlyStopping(
        patience=4,
        score_function=stopping_fn_from_metric(metric_name),
        trainer=trainer)
    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                handler=early_stopper)

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=2,
                            num_workers=4,
                            pin_memory=torch.cuda.is_available())

    @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
    def run_validation(engine):
        # run the evaluator over the validation set after each training epoch
        evaluator.run(val_loader)

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(train_ds,
                              batch_size=2,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=torch.cuda.is_available())

    train_epochs = 30
    # final engine state kept for optional inspection by the caller
    state = trainer.run(train_loader, train_epochs)
示例#11
0
    def train(self,
              train_info,
              valid_info,
              hyperparameters,
              run_data_check=False):
        """Train the CT model and return the directory the artifacts were saved to.

        Args:
            train_info: iterable of (sample, label) pairs used to build the
                training set; labels are 0/1 (used for class balancing).
            valid_info: same structure, for the validation set.
            hyperparameters: dict with keys 'learning_rate', 'weight_decay',
                'total_epoch', 'multiplicator', 'batch_size',
                'validation_epoch', 'validation_interval', 'H', 'L'.
            run_data_check: when True, only visualize the first training batch
                and exit() instead of training.

        Returns:
            Path to the folder where the model, metrics and hyperparameters
            were saved.
        """
        logging.basicConfig(stream=sys.stdout, level=logging.INFO)

        if not run_data_check:
            start_dt = datetime.datetime.now()
            start_dt_string = start_dt.strftime('%d/%m/%Y %H:%M:%S')
            print(f'Training started: {start_dt_string}')

            # 1. Create folders to save the model
            timedate_info = str(
                datetime.datetime.now()).split(' ')[0] + '_' + str(
                    datetime.datetime.now().strftime("%H:%M:%S")).replace(
                        ':', '-')
            path_to_model = os.path.join(
                self.out_dir, 'trained_models',
                self.unique_name + '_' + timedate_info)
            os.mkdir(path_to_model)

        # 2. Load hyperparameters
        learning_rate = hyperparameters['learning_rate']
        weight_decay = hyperparameters['weight_decay']
        total_epoch = hyperparameters['total_epoch']
        multiplicator = hyperparameters['multiplicator']
        batch_size = hyperparameters['batch_size']
        validation_epoch = hyperparameters['validation_epoch']
        validation_interval = hyperparameters['validation_interval']
        H = hyperparameters['H']
        L = hyperparameters['L']

        # 3. Consider class imbalance
        negative, positive = 0, 0
        for _, label in train_info:
            if int(label) == 0:
                negative += 1
            elif int(label) == 1:
                positive += 1

        # weight positives by the negative/positive ratio for BCEWithLogitsLoss
        pos_weight = torch.Tensor([(negative / positive)]).to(self.device)

        # 4. Create train and validation loaders, batch_size = 10 for validation loader (10 central slices)
        train_data = get_data_from_info(self.image_data_dir, self.seg_data_dir,
                                        train_info)
        valid_data = get_data_from_info(self.image_data_dir, self.seg_data_dir,
                                        valid_info)
        large_image_splitter(train_data, self.cache_dir)

        set_determinism(seed=100)
        train_trans, valid_trans = self.transformations(H, L)
        train_dataset = PersistentDataset(
            data=train_data[:],
            transform=train_trans,
            cache_dir=self.persistent_dataset_dir)
        valid_dataset = PersistentDataset(
            data=valid_data[:],
            transform=valid_trans,
            cache_dir=self.persistent_dataset_dir)

        train_loader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  pin_memory=self.pin_memory,
                                  num_workers=self.num_workers,
                                  collate_fn=PadListDataCollate(
                                      Method.SYMMETRIC, NumpyPadMode.CONSTANT))
        # NOTE(review): shuffling the validation loader is unusual but does not
        # change the epoch-level metrics; kept as-is.
        valid_loader = DataLoader(valid_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  pin_memory=self.pin_memory,
                                  num_workers=self.num_workers,
                                  collate_fn=PadListDataCollate(
                                      Method.SYMMETRIC, NumpyPadMode.CONSTANT))

        # Perform data checks: visualize the first batch and stop
        if run_data_check:
            check_data = monai.utils.misc.first(train_loader)
            print(check_data["image"].shape, check_data["label"])
            for i in range(batch_size):
                multi_slice_viewer(
                    check_data["image"][i, 0, :, :, :],
                    check_data["image_meta_dict"]["filename_or_obj"][i])
            exit()

        # 5. Prepare model
        model = ModelCT().to(self.device)

        # 6. Define loss function, optimizer and scheduler
        loss_function = torch.nn.BCEWithLogitsLoss(
            pos_weight)  # pos_weight for class imbalance
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=learning_rate,
                                     weight_decay=weight_decay)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                           multiplicator,
                                                           last_epoch=-1)
        # 7. Create post validation transforms and handlers
        path_to_tensorboard = os.path.join(self.out_dir, 'tensorboard')
        writer = SummaryWriter(log_dir=path_to_tensorboard)
        valid_post_transforms = Compose([
            Activationsd(keys="pred", sigmoid=True),
        ])
        valid_handlers = [
            StatsHandler(output_transform=lambda x: None),
            TensorBoardStatsHandler(summary_writer=writer,
                                    output_transform=lambda x: None),
            CheckpointSaver(save_dir=path_to_model,
                            save_dict={"model": model},
                            save_key_metric=True),
            MetricsSaver(save_dir=path_to_model,
                         metrics=['Valid_AUC', 'Valid_ACC']),
        ]
        # 8. Create validatior
        discrete = AsDiscrete(threshold_values=True)
        evaluator = SupervisedEvaluator(
            device=self.device,
            val_data_loader=valid_loader,
            network=model,
            post_transform=valid_post_transforms,
            key_val_metric={
                "Valid_AUC":
                ROCAUC(output_transform=lambda x: (x["pred"], x["label"]))
            },
            additional_metrics={
                # Fix: metric key renamed from "Valid_Accuracy" to "Valid_ACC" —
                # MetricsSaver above and the np.save below look up 'Valid_ACC',
                # which otherwise never exists (KeyError at save time).
                "Valid_ACC":
                Accuracy(output_transform=lambda x:
                         (discrete(x["pred"]), x["label"]))
            },
            val_handlers=valid_handlers,
            amp=self.amp,
        )
        # 9. Create trainer

        # Loss function does the last sigmoid, so we dont need it here.
        train_post_transforms = Compose([
            # Empty
        ])
        logger = MetricLogger(evaluator=evaluator)
        train_handlers = [
            logger,
            LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
            ValidationHandlerCT(validator=evaluator,
                                start=validation_epoch,
                                interval=validation_interval,
                                epoch_level=True),
            StatsHandler(tag_name="loss",
                         output_transform=lambda x: x["loss"]),
            TensorBoardStatsHandler(summary_writer=writer,
                                    tag_name="Train_Loss",
                                    output_transform=lambda x: x["loss"]),
            CheckpointSaver(save_dir=path_to_model,
                            save_dict={
                                "model": model,
                                "opt": optimizer
                            },
                            save_interval=1,
                            n_saved=1),
        ]

        trainer = SupervisedTrainer(
            device=self.device,
            max_epochs=total_epoch,
            train_data_loader=train_loader,
            network=model,
            optimizer=optimizer,
            loss_function=loss_function,
            post_transform=train_post_transforms,
            train_handlers=train_handlers,
            amp=self.amp,
        )
        # 10. Run trainer
        trainer.run()
        # 11. Save results
        np.save(path_to_model + '/AUCS.npy',
                np.array(logger.metrics['Valid_AUC']))
        np.save(path_to_model + '/ACCS.npy',
                np.array(logger.metrics['Valid_ACC']))
        np.save(path_to_model + '/LOSSES.npy', np.array(logger.loss))
        np.save(path_to_model + '/PARAMETERS.npy', np.array(hyperparameters))

        return path_to_model
示例#12
0
# StatsHandler prints loss at every iteration and print metrics at every epoch,
# we don't set metrics for trainer here, so just print loss, user can also customize print functions
# and can use output_transform to convert engine.state.output if it's not loss value
train_stats_handler = StatsHandler(name='trainer')
train_stats_handler.attach(trainer)

# TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
train_tensorboard_stats_handler = TensorBoardStatsHandler()
train_tensorboard_stats_handler.attach(trainer)

# set parameters for validation
validation_every_n_epochs = 1

metric_name = 'Accuracy'
# add evaluation metric to the evaluator engine
# NOTE(review): add_softmax is the older MONAI kwarg; newer versions use softmax — confirm installed version
val_metrics = {metric_name: Accuracy(), 'AUC': ROCAUC(to_onehot_y=True, add_softmax=True)}
# ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration,
# user can add output_transform to return other values
evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch)

# add stats event handler to print validation stats via evaluator
val_stats_handler = StatsHandler(
    name='evaluator',
    output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
    global_epoch_transform=lambda x: trainer.state.epoch)  # fetch global epoch number from trainer
val_stats_handler.attach(evaluator)

# add handler to record metrics to TensorBoard at every epoch
# NOTE(review): this handler is never attached to the evaluator in this excerpt —
# confirm a val_tensorboard_stats_handler.attach(evaluator) call follows elsewhere.
val_tensorboard_stats_handler = TensorBoardStatsHandler(
    output_transform=lambda x: None,  # no need to plot loss value, so disable per iteration output
    global_epoch_transform=lambda x: trainer.state.epoch)  # fetch global epoch number from trainer
示例#13
0
# and can use output_transform to convert engine.state.output if it's not loss value
train_stats_handler = StatsHandler(name='trainer')
train_stats_handler.attach(trainer)

# TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
train_tensorboard_stats_handler = TensorBoardStatsHandler()
train_tensorboard_stats_handler.attach(trainer)

# set parameters for validation
validation_every_n_epochs = 1

metric_name = 'Accuracy'
# add evaluation metric to the evaluator engine
# NOTE(review): add_softmax is the older MONAI kwarg; newer versions use softmax — confirm installed version
val_metrics = {
    metric_name: Accuracy(),
    'AUC': ROCAUC(to_onehot_y=True, add_softmax=True)
}
# ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration,
# user can add output_transform to return other values
evaluator = create_supervised_evaluator(net,
                                        val_metrics,
                                        device,
                                        True,
                                        prepare_batch=prepare_batch)

# add stats event handler to print validation stats via evaluator
val_stats_handler = StatsHandler(
    name='evaluator',
    output_transform=lambda x:
    None,  # no need to print loss value, so disable per iteration output
    global_epoch_transform=lambda x: trainer.state.epoch