Exemplo n.º 1
0
def pad_list_data_collate(
    batch: Sequence,
    method: Union[Method, str] = Method.SYMMETRIC,
    mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT,
    **np_kwargs,
):
    """
    Pad-collate a batch; functional form of :py:class:`monai.transforms.croppad.batch.PadListDataCollate`.

    Behaves like MONAI's ``list_data_collate``, except that every tensor is first centrally padded so that all
    items match the largest tensor along each dimension. Useful when earlier transforms produce batch items of
    differing sizes.

    Accepts both list data and dictionary data. For dictionary data the padding is appended to the list of
    invertible transforms, so it can later be undone with the static method
    `monai.transforms.croppad.batch.PadListDataCollate.inverse`.

    Args:
        batch: batch of data to pad-collate
        method: padding method (see :py:class:`monai.transforms.SpatialPad`)
        mode: padding mode (see :py:class:`monai.transforms.SpatialPad`)
        np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension.
            more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html

    Returns:
        The result of calling a freshly built ``PadListDataCollate`` on ``batch``.
    """
    # Importing at module level would be circular, so the class is pulled in on demand.
    from monai.transforms.croppad.batch import PadListDataCollate

    collator = PadListDataCollate(method=method, mode=mode, **np_kwargs)
    return collator(batch)
Exemplo n.º 2
0
    def __getitem__(self, index: int) -> Dict[Hashable, np.ndarray]:
        """
        Return the ``index``-th stored item with its transforms inverted.

        The item is shallow-copied into a new dict. When pad collation was used, the collation padding is
        stripped with ``PadListDataCollate.inverse`` before ``self.invertible_transform.inverse`` runs.
        """
        item = dict(self.data[index])
        # Collation padding sits on top of every other transform, so it must be undone first.
        unpadded = PadListDataCollate.inverse(item) if self.pad_collation_used else item
        return self.invertible_transform.inverse(unpadded)
Exemplo n.º 3
0
def pad_list_data_collate(
    batch: Sequence,
    method: Union[Method, str] = Method.SYMMETRIC,
    mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT,
):
    """
    Pad-collate a batch; functional form of :py:class:`monai.transforms.croppad.batch.PadListDataCollate`.

    Behaves like MONAI's ``list_data_collate``, except that every tensor is first centrally padded so that all
    items match the largest tensor along each dimension. Useful when earlier transforms produce batch items of
    differing sizes.

    Accepts both list data and dictionary data. For dictionary data the padding is appended to the list of
    invertible transforms, so it can later be undone with the static method
    `monai.transforms.croppad.batch.PadListDataCollate.inverse`.

    Args:
        batch: batch of data to pad-collate
        method: padding method (see :py:class:`monai.transforms.SpatialPad`)
        mode: padding mode (see :py:class:`monai.transforms.SpatialPad`)

    Returns:
        The result of calling a freshly built ``PadListDataCollate`` on ``batch``.
    """
    # Deferred import: importing this at module level would be circular.
    from monai.transforms.croppad.batch import PadListDataCollate

    collate = PadListDataCollate(method, mode)
    return collate(batch)
Exemplo n.º 4
0
    def __call__(
        self,
        data: Dict[str, Any],
        num_examples: int = 10
    ) -> Union[Tuple[NdarrayOrTensor, NdarrayOrTensor, NdarrayOrTensor, float],
               NdarrayOrTensor]:
        """
        Run ``self.transform`` and ``self.inferrer_fn`` on ``num_examples`` copies of ``data`` and combine the outputs.

        Each copy goes through ``self.transform`` via a ``Dataset``/``DataLoader`` that pad-collates batches.
        The model prediction for each batch is stored under ``self._pred_key``. Every decollated item then has its
        collation padding removed (``PadListDataCollate.inverse``) before ``self.inverter`` is applied.

        Args:
            data: dictionary data to be processed.
            num_examples: number of realisations to be processed and results combined. Must be a positive
                multiple of ``self.batch_size``.

        Returns:
            - if `return_full_data==False`: mode, mean, std, vvc. The mode, mean and standard deviation are
                calculated across `num_examples` outputs at each voxel. The volume variation coefficient (VVC)
                is `std/mean` across the whole output, including `num_examples`. See original paper for clarification.
            - if `return_full_data==True`: the inverted predictions are returned as-is, stacked along a new first
                dimension of size `num_examples`. This allows the user to perform their own analysis if desired.

        Raises:
            ValueError: if ``num_examples`` is smaller than 1 or is not a multiple of ``self.batch_size``.
        """
        d = dict(data)

        # Zero/negative values would pass the modulo check below and then fail obscurely in `stack`.
        if num_examples < 1:
            raise ValueError("num_examples should be a positive integer.")
        # check num examples is multiple of batch size
        if num_examples % self.batch_size != 0:
            raise ValueError("num_examples should be multiple of batch size.")

        # One independent deep copy per realisation, so no copy shares mutable state with another.
        data_in = [deepcopy(d) for _ in range(num_examples)]
        ds = Dataset(data_in, self.transform)
        dl = DataLoader(ds,
                        num_workers=self.num_workers,
                        batch_size=self.batch_size,
                        collate_fn=pad_list_data_collate)

        outs: List = []

        for b in tqdm(dl) if has_tqdm and self.progress else dl:
            # do model forward pass
            b[self._pred_key] = self.inferrer_fn(b[self.image_key].to(self.device))
            # Per item: strip the collation padding first, then invert the remaining transforms.
            outs.extend(
                self.inverter(PadListDataCollate.inverse(i))[self._pred_key]
                for i in decollate_batch(b)
            )

        output: NdarrayOrTensor = stack(outs, 0)

        if self.return_full_data:
            return output

        # calculate metrics
        _mode = mode(output, dim=0)
        mean = output.mean(0)
        std = output.std(0)
        # VVC is computed over every element of the stacked output, not per voxel.
        vvc = (output.std() / output.mean()).item()

        return _mode, mean, std, vvc
Exemplo n.º 5
0
    def __getitem__(self, index: int):
        """
        Return the ``index``-th stored item with its transforms inverted, when possible.

        The item is shallow-copied into a new dict. When pad collation was used, the collation padding is
        stripped with ``PadListDataCollate.inverse`` first. If ``self.invertible_transform`` is not an
        ``InvertibleTransform``, a warning is emitted and the (possibly un-padded) item is returned unchanged.
        """
        sample = dict(self.data[index])
        if self.pad_collation_used:
            # Collation padding sits on top of every other transform, so it is undone first.
            sample = PadListDataCollate.inverse(sample)

        transform = self.invertible_transform
        if isinstance(transform, InvertibleTransform):
            return transform.inverse(sample)
        warnings.warn("transform is not invertible, can't invert transform for the input data.")
        return sample
Exemplo n.º 6
0
    def train(self,
              train_info,
              valid_info,
              hyperparameters,
              run_data_check=False):
        """
        Train ``ModelCT`` on ``train_info`` and validate it on ``valid_info``.

        Args:
            train_info: iterable of ``(image, label)`` pairs; labels are compared via ``int(label)`` with
                0 (negative) and 1 (positive) to compute the positive-class weight.
            valid_info: validation entries in the same format as ``train_info``.
            hyperparameters: mapping with the keys ``learning_rate``, ``weight_decay``, ``total_epoch``,
                ``multiplicator``, ``batch_size``, ``validation_epoch``, ``validation_interval``, ``H`` and ``L``.
            run_data_check: if True, show the first training batch with ``multi_slice_viewer`` and exit the
                process instead of training.

        Returns:
            Path of the directory where checkpoints and the saved metric arrays are written.

        Raises:
            ValueError: if ``train_info`` contains no label-1 samples, since ``negative / positive`` is undefined.
        """
        logging.basicConfig(stream=sys.stdout, level=logging.INFO)

        if not run_data_check:
            start_dt = datetime.datetime.now()
            start_dt_string = start_dt.strftime('%d/%m/%Y %H:%M:%S')
            print(f'Training started: {start_dt_string}')

            # 1. Create folders to save the model.
            # Build the suffix from the single `start_dt` reading so the date and time parts always agree
            # (two separate now() calls could straddle midnight). Format: YYYY-MM-DD_HH-MM-SS.
            timedate_info = start_dt.strftime('%Y-%m-%d_%H-%M-%S')
            path_to_model = os.path.join(
                self.out_dir, 'trained_models',
                self.unique_name + '_' + timedate_info)
            # makedirs also creates a missing 'trained_models' parent; like mkdir it still raises if the
            # exact run folder already exists.
            os.makedirs(path_to_model)

        # 2. Load hyperparameters
        learning_rate = hyperparameters['learning_rate']
        weight_decay = hyperparameters['weight_decay']
        total_epoch = hyperparameters['total_epoch']
        multiplicator = hyperparameters['multiplicator']
        batch_size = hyperparameters['batch_size']
        validation_epoch = hyperparameters['validation_epoch']
        validation_interval = hyperparameters['validation_interval']
        H = hyperparameters['H']
        L = hyperparameters['L']

        # 3. Consider class imbalance
        negative, positive = 0, 0
        for _, label in train_info:
            if int(label) == 0:
                negative += 1
            elif int(label) == 1:
                positive += 1
        if positive == 0:
            raise ValueError(
                'train_info contains no positive (label 1) samples; cannot compute pos_weight.')

        pos_weight = torch.Tensor([(negative / positive)]).to(self.device)

        # 4. Create train and validation loaders (both use the `batch_size` hyperparameter)
        train_data = get_data_from_info(self.image_data_dir, self.seg_data_dir,
                                        train_info)
        valid_data = get_data_from_info(self.image_data_dir, self.seg_data_dir,
                                        valid_info)
        large_image_splitter(train_data, self.cache_dir)

        set_determinism(seed=100)
        train_trans, valid_trans = self.transformations(H, L)
        train_dataset = PersistentDataset(
            data=train_data[:],
            transform=train_trans,
            cache_dir=self.persistent_dataset_dir)
        valid_dataset = PersistentDataset(
            data=valid_data[:],
            transform=valid_trans,
            cache_dir=self.persistent_dataset_dir)

        train_loader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  pin_memory=self.pin_memory,
                                  num_workers=self.num_workers,
                                  collate_fn=PadListDataCollate(
                                      Method.SYMMETRIC, NumpyPadMode.CONSTANT))
        valid_loader = DataLoader(valid_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  pin_memory=self.pin_memory,
                                  num_workers=self.num_workers,
                                  collate_fn=PadListDataCollate(
                                      Method.SYMMETRIC, NumpyPadMode.CONSTANT))

        # Perform data checks: visualise the first training batch, then stop the process.
        if run_data_check:
            check_data = monai.utils.misc.first(train_loader)
            print(check_data["image"].shape, check_data["label"])
            for i in range(batch_size):
                multi_slice_viewer(
                    check_data["image"][i, 0, :, :, :],
                    check_data["image_meta_dict"]["filename_or_obj"][i])
            sys.exit()

        # 5. Prepare model
        model = ModelCT().to(self.device)

        # 6. Define loss function, optimizer and scheduler.
        # pos_weight must be passed by keyword: the first positional parameter of BCEWithLogitsLoss is
        # `weight` (a per-element rescaling), which would only scale the loss uniformly instead of
        # up-weighting the positive class.
        loss_function = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=learning_rate,
                                     weight_decay=weight_decay)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                           multiplicator,
                                                           last_epoch=-1)
        # 7. Create post validation transforms and handlers
        path_to_tensorboard = os.path.join(self.out_dir, 'tensorboard')
        writer = SummaryWriter(log_dir=path_to_tensorboard)
        valid_post_transforms = Compose([
            Activationsd(keys="pred", sigmoid=True),
        ])
        valid_handlers = [
            StatsHandler(output_transform=lambda x: None),
            TensorBoardStatsHandler(summary_writer=writer,
                                    output_transform=lambda x: None),
            CheckpointSaver(save_dir=path_to_model,
                            save_dict={"model": model},
                            save_key_metric=True),
            MetricsSaver(save_dir=path_to_model,
                         metrics=['Valid_AUC', 'Valid_ACC']),
        ]
        # 8. Create validator
        discrete = AsDiscrete(threshold_values=True)
        evaluator = SupervisedEvaluator(
            device=self.device,
            val_data_loader=valid_loader,
            network=model,
            post_transform=valid_post_transforms,
            key_val_metric={
                "Valid_AUC":
                ROCAUC(output_transform=lambda x: (x["pred"], x["label"]))
            },
            # NOTE(review): this metric is registered as "Valid_Accuracy", but MetricsSaver above and
            # logger.metrics below refer to "Valid_ACC" — confirm MetricLogger's key mapping.
            additional_metrics={
                "Valid_Accuracy":
                Accuracy(output_transform=lambda x:
                         (discrete(x["pred"]), x["label"]))
            },
            val_handlers=valid_handlers,
            amp=self.amp,
        )
        # 9. Create trainer

        # Loss function does the last sigmoid, so we dont need it here.
        train_post_transforms = Compose([
            # Empty
        ])
        logger = MetricLogger(evaluator=evaluator)
        train_handlers = [
            logger,
            LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
            ValidationHandlerCT(validator=evaluator,
                                start=validation_epoch,
                                interval=validation_interval,
                                epoch_level=True),
            StatsHandler(tag_name="loss",
                         output_transform=lambda x: x["loss"]),
            TensorBoardStatsHandler(summary_writer=writer,
                                    tag_name="Train_Loss",
                                    output_transform=lambda x: x["loss"]),
            CheckpointSaver(save_dir=path_to_model,
                            save_dict={
                                "model": model,
                                "opt": optimizer
                            },
                            save_interval=1,
                            n_saved=1),
        ]

        trainer = SupervisedTrainer(
            device=self.device,
            max_epochs=total_epoch,
            train_data_loader=train_loader,
            network=model,
            optimizer=optimizer,
            loss_function=loss_function,
            post_transform=train_post_transforms,
            train_handlers=train_handlers,
            amp=self.amp,
        )
        # 10. Run trainer
        trainer.run()
        # 11. Save results
        np.save(os.path.join(path_to_model, 'AUCS.npy'),
                np.array(logger.metrics['Valid_AUC']))
        np.save(os.path.join(path_to_model, 'ACCS.npy'),
                np.array(logger.metrics['Valid_ACC']))
        np.save(os.path.join(path_to_model, 'LOSSES.npy'), np.array(logger.loss))
        np.save(os.path.join(path_to_model, 'PARAMETERS.npy'), np.array(hyperparameters))

        return path_to_model
Exemplo n.º 7
0
Arquivo: test.py Projeto: ckbr0/RIS
def main(train_output):
    """
    Evaluate a trained DenseNet121 on the held-out HACKATHON test split.

    Args:
        train_output: directory handed to ``Tester`` as ``load_dir``; presumably the output directory of a
            previous training run (TODO confirm).
    """
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    print_config()

    # Setup directories
    dirs = setup_directories()

    # Setup torch device
    device, using_gpu = create_device("cuda")

    # HACKATHON image and segmentation data.
    # Each line of train.txt is comma-separated; the first field is kept as-is, the second parsed as an int label.
    hackathon_dir = os.path.join(dirs["data"], 'HACKATHON')
    with open(os.path.join(hackathon_dir, "train.txt"), 'r') as fp:
        train_info_hackathon = [
            (fields[0], int(fields[1]))
            for fields in (entry.strip().split(',') for entry in fp)
        ]
    image_dir = os.path.join(hackathon_dir, 'images', 'train')
    seg_dir = os.path.join(hackathon_dir, 'segmentations', 'train')
    _train_data_hackathon = get_data_from_info(image_dir,
                                               seg_dir,
                                               train_info_hackathon,
                                               dual_output=False)
    # NOTE(review): the return value of large_image_splitter is discarded here; confirm this preprocessing
    # matches the one applied before training, otherwise the test split below may differ from the one held out.
    large_image_splitter(_train_data_hackathon, dirs["cache"])

    balance_training_data(_train_data_hackathon, seed=72)

    # Hold out 20% of the data as the test split; only that part is used here.
    _, test_data_hackathon = train_test_split(_train_data_hackathon,
                                              test_size=0.2,
                                              shuffle=True,
                                              random_state=42)

    # Setup transforms
    # Crop foreground
    # NOTE(review): uses CropForegroundd's default select_fn; confirm it matches the training-time setting.
    crop_foreground = CropForegroundd(
        keys=["image"],
        source_key="image",
        margin=(5, 5, 0),
    )
    # Crop Z
    crop_z = RelativeCropZd(keys=["image"], relative_z_roi=(0.07, 0.12))
    # Window width and level (window center)
    WW, WL = 1500, -600
    ct_window = CTWindowd(keys=["image"], width=WW, level=WL)
    # Pad image to have height at least 30
    spatial_pad = SpatialPadd(keys=["image"], spatial_size=(-1, -1, 30))
    resize = Resized(keys=["image"],
                     spatial_size=(int(512 * 0.50), int(512 * 0.50), -1),
                     mode="trilinear")

    # Create transforms
    common_transform = Compose([
        LoadImaged(keys=["image"]),
        ct_window,
        CTSegmentation(keys=["image"]),
        AddChanneld(keys=["image"]),
        resize,
        crop_foreground,
        crop_z,
        spatial_pad,
    ])
    hackathon_test_transform = Compose([
        common_transform,
        ToTensord(keys=["image"]),
    ]).flatten()

    # Setup data
    test_dataset = PersistentDataset(data=test_data_hackathon[:],
                                     transform=hackathon_test_transform,
                                     cache_dir=dirs["persistent"])
    # No shuffling: PadListDataCollate pads each batch to its largest item, so a random batch composition
    # would change the network's inputs (and predictions) from run to run.
    test_loader = DataLoader(test_dataset,
                             batch_size=2,
                             shuffle=False,
                             pin_memory=using_gpu,
                             num_workers=1,
                             collate_fn=PadListDataCollate(
                                 Method.SYMMETRIC, NumpyPadMode.CONSTANT))

    # Setup network
    network = nets.DenseNet121(spatial_dims=3, in_channels=1,
                               out_channels=1).to(device)

    # Sigmoid on the logits before the tester consumes the predictions
    test_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
    ])

    # Setup tester
    tester = Tester(device=device,
                    test_data_loader=test_loader,
                    load_dir=train_output,
                    out_dir=dirs["out"],
                    network=network,
                    post_transform=test_post_transforms,
                    non_blocking=using_gpu,
                    amp=using_gpu)

    # Run tester
    tester.run()
Exemplo n.º 8
0
def main():
    """
    Train DenseNet121 on the HACKATHON data, validating during training, then evaluate on a held-out split.

    The data are split 80/20 into (train+valid)/test (random_state=42), and the first part again 80/20 into
    train/valid (random_state=43).
    """
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    print_config()

    # Setup directories
    dirs = setup_directories()

    # Setup torch device
    device, using_gpu = create_device("cuda")

    # HACKATHON image and segmentation data.
    # Each line of train.txt is comma-separated; the first field is kept as-is, the second parsed as an int label.
    hackathon_dir = os.path.join(dirs["data"], 'HACKATHON')
    with open(os.path.join(hackathon_dir, "train.txt"), 'r') as fp:
        train_info_hackathon = [
            (fields[0], int(fields[1]))
            for fields in (entry.strip().split(',') for entry in fp)
        ]
    image_dir = os.path.join(hackathon_dir, 'images', 'train')
    seg_dir = os.path.join(hackathon_dir, 'segmentations', 'train')
    _train_data_hackathon = get_data_from_info(image_dir,
                                               seg_dir,
                                               train_info_hackathon,
                                               dual_output=False)
    _train_data_hackathon = large_image_splitter(_train_data_hackathon,
                                                 dirs["cache"])
    copy_list = transform_and_copy(_train_data_hackathon, dirs['cache'])
    balance_training_data2(_train_data_hackathon, copy_list, seed=72)

    # Split data into train, validate and test
    train_split, test_data_hackathon = train_test_split(_train_data_hackathon,
                                                        test_size=0.2,
                                                        shuffle=True,
                                                        random_state=42)
    train_data_hackathon, valid_data_hackathon = train_test_split(
        train_split, test_size=0.2, shuffle=True, random_state=43)

    # Setup transforms
    # Crop foreground
    crop_foreground = CropForegroundd(keys=["image"],
                                      source_key="image",
                                      margin=(5, 5, 0),
                                      select_fn=lambda x: x != 0)
    # Crop Z
    crop_z = RelativeCropZd(keys=["image"], relative_z_roi=(0.07, 0.12))
    # Window width and level (window center)
    WW, WL = 1500, -600
    ct_window = CTWindowd(keys=["image"], width=WW, level=WL)
    # Random axis flip
    rand_x_flip = RandFlipd(keys=["image"], spatial_axis=0, prob=0.50)
    rand_y_flip = RandFlipd(keys=["image"], spatial_axis=1, prob=0.50)
    rand_z_flip = RandFlipd(keys=["image"], spatial_axis=2, prob=0.50)
    # Rand affine transform
    rand_affine = RandAffined(keys=["image"],
                              prob=0.5,
                              rotate_range=(0, 0, np.pi / 12),
                              shear_range=(0.07, 0.07, 0.0),
                              translate_range=(0, 0, 0),
                              scale_range=(0.07, 0.07, 0.0),
                              padding_mode="zeros")
    # Pad image to have height at least 30
    spatial_pad = SpatialPadd(keys=["image"], spatial_size=(-1, -1, 30))
    resize = Resized(keys=["image"],
                     spatial_size=(int(512 * 0.50), int(512 * 0.50), -1),
                     mode="trilinear")
    # Apply Gaussian noise
    rand_gaussian_noise = RandGaussianNoised(keys=["image"],
                                             prob=0.25,
                                             mean=0.0,
                                             std=0.1)

    # Create transforms: random augmentations are applied to the training split only
    common_transform = Compose([
        LoadImaged(keys=["image"]),
        ct_window,
        CTSegmentation(keys=["image"]),
        AddChanneld(keys=["image"]),
        resize,
        crop_foreground,
        crop_z,
        spatial_pad,
    ])
    hackathon_train_transform = Compose([
        common_transform,
        rand_x_flip,
        rand_y_flip,
        rand_z_flip,
        rand_affine,
        rand_gaussian_noise,
        ToTensord(keys=["image"]),
    ]).flatten()
    hackathon_valid_transform = Compose([
        common_transform,
        ToTensord(keys=["image"]),
    ]).flatten()
    hackathon_test_transform = Compose([
        common_transform,
        ToTensord(keys=["image"]),
    ]).flatten()

    # Setup data
    train_dataset = PersistentDataset(data=train_data_hackathon[:],
                                      transform=hackathon_train_transform,
                                      cache_dir=dirs["persistent"])
    valid_dataset = PersistentDataset(data=valid_data_hackathon[:],
                                      transform=hackathon_valid_transform,
                                      cache_dir=dirs["persistent"])
    test_dataset = PersistentDataset(data=test_data_hackathon[:],
                                     transform=hackathon_test_transform,
                                     cache_dir=dirs["persistent"])
    # The sampler replaces shuffling (DataLoader forbids shuffle=True together with a sampler).
    train_loader = DataLoader(
        train_dataset,
        batch_size=4,
        pin_memory=using_gpu,
        num_workers=2,
        sampler=ImbalancedDatasetSampler(
            train_data_hackathon,
            callback_get_label=lambda x, i: x[i]['_label']),
        collate_fn=PadListDataCollate(Method.SYMMETRIC, NumpyPadMode.CONSTANT))
    # NOTE(review): validation also goes through ImbalancedDatasetSampler, so validation metrics are computed
    # on a resampled set rather than the plain validation split — confirm this is intended.
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=4,
        shuffle=False,
        pin_memory=using_gpu,
        num_workers=2,
        sampler=ImbalancedDatasetSampler(
            valid_data_hackathon,
            callback_get_label=lambda x, i: x[i]['_label']),
        collate_fn=PadListDataCollate(Method.SYMMETRIC, NumpyPadMode.CONSTANT))
    test_loader = DataLoader(test_dataset,
                             batch_size=4,
                             shuffle=False,
                             pin_memory=using_gpu,
                             num_workers=2,
                             collate_fn=PadListDataCollate(
                                 Method.SYMMETRIC, NumpyPadMode.CONSTANT))

    # Setup network, loss function, optimizer and scheduler
    network = nets.DenseNet121(spatial_dims=3, in_channels=1,
                               out_channels=1).to(device)
    # NOTE(review): this tensor is passed positionally, i.e. as BCEWithLogitsLoss's `weight` argument, not as
    # `pos_weight`. With a single-output network it presumably acts as a uniform loss scale rather than a class
    # re-weighting. The train loader already balances classes with ImbalancedDatasetSampler, so switching to a
    # real pos_weight would likely over-correct — confirm the intended weighting before changing this.
    _, n, p = calculate_class_imbalance(train_data_hackathon)
    pos_weight = torch.Tensor([n, p]).to(device)
    loss_function = torch.nn.BCEWithLogitsLoss(pos_weight)
    optimizer = torch.optim.Adam(network.parameters(), lr=1e-4, weight_decay=0)
    # NOTE(review): this scheduler is not handed to Trainer (lr_scheduler=None below) — confirm intent.
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                       gamma=0.95,
                                                       last_epoch=-1)

    # Setup validator and trainer
    valid_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
    ])
    validator = Validator(device=device,
                          val_data_loader=valid_loader,
                          network=network,
                          post_transform=valid_post_transforms,
                          amp=using_gpu,
                          non_blocking=using_gpu)

    trainer = Trainer(device=device,
                      out_dir=dirs["out"],
                      out_name="DenseNet121",
                      max_epochs=120,
                      validation_epoch=1,
                      validation_interval=1,
                      train_data_loader=train_loader,
                      network=network,
                      optimizer=optimizer,
                      loss_function=loss_function,
                      lr_scheduler=None,
                      validator=validator,
                      amp=using_gpu,
                      non_blocking=using_gpu)

    # Run trainer
    train_output = trainer.run()

    # Setup tester
    tester = Tester(device=device,
                    test_data_loader=test_loader,
                    load_dir=train_output,
                    out_dir=dirs["out"],
                    network=network,
                    post_transform=valid_post_transforms,
                    non_blocking=using_gpu,
                    amp=using_gpu)

    # Run tester
    tester.run()