Example #1
    def test_inverse_inferred_seg(self):

        test_data = []
        for _ in range(20):
            image, label = create_test_image_2d(100, 101)
            test_data.append({
                "image": image,
                "label": label.astype(np.float32)
            })

        batch_size = 10
        # num workers = 0 for mac
        num_workers = 2 if sys.platform != "darwin" else 0
        transforms = Compose([
            AddChanneld(KEYS),
            SpatialPadd(KEYS, (150, 153)),
            CenterSpatialCropd(KEYS, (110, 99))
        ])
        num_invertible_transforms = sum(1 for i in transforms.transforms
                                        if isinstance(i, InvertibleTransform))

        dataset = CacheDataset(test_data, transform=transforms, progress=False)
        loader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=num_workers)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = UNet(
            dimensions=2,
            in_channels=1,
            out_channels=1,
            channels=(2, 4),
            strides=(2, ),
        ).to(device)

        data = first(loader)
        labels = data["label"].to(device)
        segs = model(labels).detach().cpu()
        label_transform_key = "label" + InverseKeys.KEY_SUFFIX.value
        segs_dict = {
            "label": segs,
            label_transform_key: data[label_transform_key]
        }

        segs_dict_decollated = decollate_batch(segs_dict)

        # inverse of individual segmentation
        seg_dict = first(segs_dict_decollated)
        with allow_missing_keys_mode(transforms):
            inv_seg = transforms.inverse(seg_dict)["label"]
        self.assertEqual(len(data["label_transforms"]),
                         num_invertible_transforms)
        self.assertEqual(len(seg_dict["label_transforms"]),
                         num_invertible_transforms)
        self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape)
Example #2
    def test_inverse_inferred_seg(self, extra_transform):

        test_data = []
        for _ in range(20):
            image, label = create_test_image_2d(100, 101)
            test_data.append({
                "image": image,
                "label": label.astype(np.float32)
            })

        batch_size = 10
        # num_workers = 0 on non-Linux platforms (e.g. macOS)
        num_workers = 2 if sys.platform == "linux" else 0
        transforms = Compose([
            AddChanneld(KEYS),
            SpatialPadd(KEYS, (150, 153)), extra_transform
        ])

        dataset = CacheDataset(test_data, transform=transforms, progress=False)
        loader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=num_workers)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = UNet(spatial_dims=2,
                     in_channels=1,
                     out_channels=1,
                     channels=(2, 4),
                     strides=(1, )).to(device)

        data = first(loader)
        self.assertEqual(data["image"].shape[0], batch_size * NUM_SAMPLES)

        labels = data["label"].to(device)
        self.assertIsInstance(labels, MetaTensor)
        segs = model(labels).detach().cpu()
        segs_decollated = decollate_batch(segs)
        self.assertIsInstance(segs_decollated[0], MetaTensor)
        # inverse of individual segmentation
        seg_metatensor = first(segs_decollated)
        # convert the interpolation mode for one item of the model output batch
        convert_applied_interp_mode(seg_metatensor.applied_operations,
                                    mode="nearest",
                                    align_corners=None)

        # manually invert the last crop samples
        xform = seg_metatensor.applied_operations.pop(-1)
        shape_before_extra_xform = xform["orig_size"]
        resizer = ResizeWithPadOrCrop(spatial_size=shape_before_extra_xform)
        with resizer.trace_transform(False):
            seg_metatensor = resizer(seg_metatensor)

        with allow_missing_keys_mode(transforms):
            inv_seg = transforms.inverse({"label": seg_metatensor})["label"]
        self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape)
Example #3
    def test_inverse_inferred_seg(self, extra_transform):

        test_data = []
        for _ in range(20):
            image, label = create_test_image_2d(100, 101)
            test_data.append({"image": image, "label": label.astype(np.float32)})

        batch_size = 10
        # num_workers = 0 on non-Linux platforms (e.g. macOS)
        num_workers = 2 if sys.platform == "linux" else 0
        transforms = Compose([AddChanneld(KEYS), SpatialPadd(KEYS, (150, 153)), extra_transform])
        num_invertible_transforms = sum(1 for i in transforms.transforms if isinstance(i, InvertibleTransform))

        dataset = CacheDataset(test_data, transform=transforms, progress=False)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = UNet(spatial_dims=2, in_channels=1, out_channels=1, channels=(2, 4), strides=(2,)).to(device)

        data = first(loader)
        self.assertEqual(len(data["label_transforms"]), num_invertible_transforms)
        self.assertEqual(data["image"].shape[0], batch_size * NUM_SAMPLES)

        labels = data["label"].to(device)
        segs = model(labels).detach().cpu()
        label_transform_key = "label" + InverseKeys.KEY_SUFFIX
        segs_dict = {"label": segs, label_transform_key: data[label_transform_key]}

        segs_dict_decollated = decollate_batch(segs_dict)
        # inverse of individual segmentation
        seg_dict = first(segs_dict_decollated)
        # convert the interpolation mode for one item of the model output batch
        convert_inverse_interp_mode(seg_dict, mode="nearest", align_corners=None)

        with allow_missing_keys_mode(transforms):
            inv_seg = transforms.inverse(seg_dict)["label"]
        self.assertEqual(len(data["label_transforms"]), num_invertible_transforms)
        self.assertEqual(len(seg_dict["label_transforms"]), num_invertible_transforms)
        self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape)

        # Inverse of batch
        batch_inverter = BatchInverseTransform(transforms, loader, collate_fn=no_collation, detach=True)
        with allow_missing_keys_mode(transforms):
            inv_batch = batch_inverter(segs_dict)
        self.assertEqual(inv_batch[0]["label"].shape[1:], test_data[0]["label"].shape)
Example #4
    def first_key(self, data: Dict[Hashable, Any]):
        """
        Get the first available key of `self.keys` in the input `data` dictionary.
        If no key is available, return an empty list `[]`.

        Args:
            data: data that the transform will be applied to.

        """
        return first(self.key_iterator(data), [])
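
For illustration, a minimal usage sketch in recent MONAI versions, assuming a `MapTransform` subclass such as `ToTensord` created with `allow_missing_keys=True` (the variable names are illustrative, not from the snippet above):

from monai.transforms import ToTensord

# keys are searched in the order given; missing keys are skipped
t = ToTensord(keys=("image", "label"), allow_missing_keys=True)
print(t.first_key({"label": [1, 2, 3]}))  # -> "label" ("image" is absent)
print(t.first_key({}))                    # -> [] (no available key)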
Example #5
 def test_with_dataloader(self, file_path, level, expected_spatial_shape, expected_shape):
     train_transform = Compose(
         [
             LoadImaged(keys=["image"], reader=WSIReader, backend=self.backend, level=level),
             ToTensord(keys=["image"]),
         ]
     )
     dataset = Dataset([{"image": file_path}], transform=train_transform)
     data_loader = DataLoader(dataset)
     data: dict = first(data_loader)
     for s in data[PostFix.meta("image")]["spatial_shape"]:
         torch.testing.assert_allclose(s, expected_spatial_shape)
     self.assertTupleEqual(data["image"].shape, expected_shape)
Example #6
 def __call__(self, data: Dict[str, Any]) -> Any:
     decollated_data = decollate_batch(data, detach=self.detach, pad=self.pad_batch, fill_value=self.fill_value)
     inv_ds = _BatchInverseDataset(decollated_data, self.transform, self.pad_collation_used)
     inv_loader = DataLoader(
         inv_ds, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=self.collate_fn
     )
     try:
         return first(inv_loader)
     except RuntimeError as re:
         re_str = str(re)
         if "equal size" in re_str:
             re_str += "\nMONAI hint: try creating `BatchInverseTransform` with `collate_fn=lambda x: x`."
         raise RuntimeError(re_str) from re
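
As a rough usage sketch, assuming `transforms`, `loader`, and a collated `batch` like those in the test examples above (none of these names come from this snippet):

# invert every sample of a collated batch back to its pre-transform space
inverter = BatchInverseTransform(transform=transforms, loader=loader)
inverted = inverter(batch)  # list of per-sample dictionaries holding the inverted arrays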
Example #7
 def test_with_dataloader_batch(self, file_path, level, expected_spatial_shape, expected_shape):
     train_transform = Compose(
         [
             LoadImaged(keys=["image"], reader=WSIReader, backend=self.backend, level=level),
             FromMetaTensord(keys=["image"]),
             ToTensord(keys=["image"]),
         ]
     )
     dataset = Dataset([{"image": file_path}, {"image": file_path}], transform=train_transform)
     batch_size = 2
     data_loader = DataLoader(dataset, batch_size=batch_size)
     data: dict = first(data_loader)
     for s in data[PostFix.meta("image")]["spatial_shape"]:
         assert_allclose(s, expected_spatial_shape, type_test=False)
     self.assertTupleEqual(data["image"].shape, (batch_size, *expected_shape[1:]))
Example #8
    def __call__(self, data: Mapping[Hashable, np.ndarray]):
        d = dict(data)
        original_spatial_shape = d[first(self.keys)].shape[1:]

        for patch in zip(*[self.patch_iter(d[key]) for key in self.keys]):
            coords = patch[0][1]  # use the coordinate of the first item
            ret = {k: v[0] for k, v in zip(self.keys, patch)}
            # fill in the extra keys with unmodified data
            for k in set(d.keys()).difference(set(self.keys)):
                ret[k] = deepcopy(d[k])
            # also store the `coordinate`, `spatial shape of original image`, `start position` in the dictionary
            ret[self.coords_key] = coords
            ret[self.original_spatial_shape_key] = original_spatial_shape
            ret[self.start_pos_key] = self.patch_iter.start_pos
            yield ret, coords
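
This `__call__` resembles MONAI's `PatchIterd`; under that assumption, a minimal driving sketch (array sizes chosen purely for illustration):

import numpy as np
from monai.data import PatchIterd

patch_iterd = PatchIterd(keys=("image", "label"), patch_size=(32, 32))
img = np.zeros((1, 64, 64))  # channel-first image
lbl = np.zeros((1, 64, 64))
# each iteration yields (patch_dict, coords); extra keys such as "meta" pass through unchanged
for patch_dict, coords in patch_iterd({"image": img, "label": lbl, "meta": "case_01"}):
    print(patch_dict["image"].shape, patch_dict["meta"], coords)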
Example #9
    def is_import_statement(cls, config: Union[Dict, List, str]) -> bool:
        """
        Check whether the config is an import statement (a special case of expression).

        Args:
            config: input config content to check.
        """
        if not cls.is_expression(config):
            return False
        if "import" not in config:
            return False
        return isinstance(
            first(
                ast.iter_child_nodes(
                    ast.parse(f"{config[len(cls.prefix) :]}"))),
            (ast.Import, ast.ImportFrom))
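
A small sketch of what this classifies, assuming the method belongs to MONAI's `ConfigExpression` with its default `"$"` expression prefix:

from monai.bundle import ConfigExpression

print(ConfigExpression.is_import_statement("$import json"))                          # True
print(ConfigExpression.is_import_statement("$from monai.transforms import Resize"))  # True
print(ConfigExpression.is_import_statement("$1 + 1"))                                # False: an expression, not an import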
Example #10
File: utils.py Project: Irme/MONAI
def dense_patch_slices(
    image_size: Sequence[int],
    patch_size: Sequence[int],
    scan_interval: Sequence[int],
) -> List[Tuple[slice, ...]]:
    """
    Enumerate all slices defining ND patches of size `patch_size` from an `image_size` input image.

    Args:
        image_size: dimensions of image to iterate over
        patch_size: size of patches to generate slices
        scan_interval: dense patch sampling interval

    Returns:
        a list of slice objects defining each patch

    """
    num_spatial_dims = len(image_size)
    patch_size = get_valid_patch_size(image_size, patch_size)
    scan_interval = ensure_tuple_size(scan_interval, num_spatial_dims)

    scan_num = []
    for i in range(num_spatial_dims):
        if scan_interval[i] == 0:
            scan_num.append(1)
        else:
            num = int(math.ceil(float(image_size[i]) / scan_interval[i]))
            scan_dim = first(
                d for d in range(num)
                if d * scan_interval[i] + patch_size[i] >= image_size[i])
            scan_num.append(scan_dim + 1 if scan_dim is not None else 1)

    starts = []
    for dim in range(num_spatial_dims):
        dim_starts = []
        for idx in range(scan_num[dim]):
            start_idx = idx * scan_interval[dim]
            start_idx -= max(start_idx + patch_size[dim] - image_size[dim], 0)
            dim_starts.append(start_idx)
        starts.append(dim_starts)
    out = np.asarray(
        [x.flatten() for x in np.meshgrid(*starts, indexing="ij")]).T
    slices = [
        tuple(slice(s, s + patch_size[d]) for d, s in enumerate(x))
        for x in out
    ]
    return slices
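
For a concrete feel of the output, a small worked call (values chosen purely for illustration): a 2D 4x4 image scanned with 2x2 patches every 2 voxels yields 4 patch slices.

slices = dense_patch_slices(image_size=(4, 4), patch_size=(2, 2), scan_interval=(2, 2))
print(len(slices))  # 4
print(slices[0])    # (slice(0, 2, None), slice(0, 2, None))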
Example #11
 def _parse_import_string(self, import_string: str):
     """parse single import statement such as "from monai.transforms import Resize"""
     node = first(ast.iter_child_nodes(ast.parse(import_string)))
     if not isinstance(node, (ast.Import, ast.ImportFrom)):
         return None
     if len(node.names) < 1:
         return None
     if len(node.names) > 1:
         warnings.warn(f"ignoring multiple import alias '{import_string}'.")
     name, asname = f"{node.names[0].name}", node.names[0].asname
     asname = name if asname is None else f"{asname}"
     if isinstance(node, ast.ImportFrom):
         self.globals[asname], _ = optional_import(f"{node.module}",
                                                   name=f"{name}")
         return self.globals[asname]
     if isinstance(node, ast.Import):
         self.globals[asname], _ = optional_import(f"{name}")
         return self.globals[asname]
     return None
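
To see what the `first(ast.iter_child_nodes(...))` step yields on its own, a standalone sketch (independent of the enclosing parser):

import ast

from monai.utils import first

node = first(ast.iter_child_nodes(ast.parse("from monai.transforms import Resize")))
print(type(node).__name__)              # ImportFrom
print(node.module, node.names[0].name)  # monai.transforms Resize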
Example #12
def main():

    # Define file paths & output directory path
    json_Path = os.path.normpath('/scratch/data_2021/tcia_covid19/dataset_split_debug.json')
    data_Root = os.path.normpath('/scratch/data_2021/tcia_covid19')
    logdir_path = os.path.normpath('/home/vishwesh/monai_tutorial_testing/issue_467')

    if not os.path.exists(logdir_path):
        os.mkdir(logdir_path)

    # Load Json & Append Root Path
    with open(json_Path, 'r') as json_f:
        json_Data = json.load(json_f)

    train_Data = json_Data['training']
    val_Data = json_Data['validation']

    for idx, each_d in enumerate(train_Data):
        train_Data[idx]['image'] = os.path.join(data_Root, train_Data[idx]['image'])

    for idx, each_d in enumerate(val_Data):
        val_Data[idx]['image'] = os.path.join(data_Root, val_Data[idx]['image'])

    print('Total Number of Training Data Samples: {}'.format(len(train_Data)))
    print(train_Data)
    print('#' * 10)
    print('Total Number of Validation Data Samples: {}'.format(len(val_Data)))
    print(val_Data)
    print('#' * 10)

    # Set Determinism
    set_determinism(seed=123)

    # Define Training Transforms
    train_Transforms = Compose(
        [
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys=["image"]),
        Spacingd(keys=["image"], pixdim=(
            2.0, 2.0, 2.0), mode=("bilinear")),
        ScaleIntensityRanged(
            keys=["image"], a_min=-57, a_max=164,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        CropForegroundd(keys=["image"], source_key="image"),
        SpatialPadd(keys=["image"], spatial_size=(96, 96, 96)),
        RandSpatialCropSamplesd(keys=["image"], roi_size=(96, 96, 96), random_size=False, num_samples=2),
        CopyItemsd(keys=["image"], times=2, names=["gt_image", "image_2"], allow_missing_keys=False),
        OneOf(transforms=[
            RandCoarseDropoutd(keys=["image"], prob=1.0, holes=6, spatial_size=5, dropout_holes=True,
                               max_spatial_size=32),
            RandCoarseDropoutd(keys=["image"], prob=1.0, holes=6, spatial_size=20, dropout_holes=False,
                               max_spatial_size=64),
            ]
        ),
        RandCoarseShuffled(keys=["image"], prob=0.8, holes=10, spatial_size=8),
        # Note: if image and image_2 were augmented in the same transform call, determinism would make
        # both augmentations identical, which is not desired here; hence two separate OneOf calls are made
        OneOf(transforms=[
            RandCoarseDropoutd(keys=["image_2"], prob=1.0, holes=6, spatial_size=5, dropout_holes=True,
                               max_spatial_size=32),
            RandCoarseDropoutd(keys=["image_2"], prob=1.0, holes=6, spatial_size=20, dropout_holes=False,
                               max_spatial_size=64),
        ]
        ),
        RandCoarseShuffled(keys=["image_2"], prob=0.8, holes=10, spatial_size=8)
        ]
    )

    check_ds = Dataset(data=train_Data, transform=train_Transforms)
    check_loader = DataLoader(check_ds, batch_size=1)
    check_data = first(check_loader)
    image = (check_data["image"][0][0])
    print(f"image shape: {image.shape}")

    # Define Network ViT backbone & Loss & Optimizer
    device = torch.device("cuda:0")
    model = ViTAutoEnc(
                in_channels=1,
                img_size=(96, 96, 96),
                patch_size=(16, 16, 16),
                pos_embed='conv',
                hidden_size=768,
                mlp_dim=3072,
    )

    model = model.to(device)

    # Define Hyper-parameters for training loop
    max_epochs = 500
    val_interval = 2
    batch_size = 4
    lr = 1e-4
    epoch_loss_values = []
    step_loss_values = []
    epoch_cl_loss_values = []
    epoch_recon_loss_values = []
    val_loss_values = []
    best_val_loss = 1000.0

    recon_loss = L1Loss()
    contrastive_loss = ContrastiveLoss(batch_size=batch_size*2, temperature=0.05)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Define Dataset and DataLoader using MONAI (CacheDataset could be used here to accelerate loading)
    train_ds = Dataset(data=train_Data, transform=train_Transforms)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)

    val_ds = Dataset(data=val_Data, transform=train_Transforms)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=True, num_workers=4)

    for epoch in range(max_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        epoch_cl_loss = 0
        epoch_recon_loss = 0
        step = 0

        for batch_data in train_loader:
            step += 1
            start_time = time.time()

            inputs, inputs_2, gt_input = (
                batch_data["image"].to(device),
                batch_data["image_2"].to(device),
                batch_data["gt_image"].to(device),
            )
            optimizer.zero_grad()
            outputs_v1, hidden_v1 = model(inputs)
            outputs_v2, hidden_v2 = model(inputs_2)

            flat_out_v1 = outputs_v1.flatten(start_dim=1, end_dim=4)
            flat_out_v2 = outputs_v2.flatten(start_dim=1, end_dim=4)

            r_loss = recon_loss(outputs_v1, gt_input)
            cl_loss = contrastive_loss(flat_out_v1, flat_out_v2)

            # Adjust the CL loss by Recon Loss
            total_loss = r_loss + cl_loss * r_loss

            total_loss.backward()
            optimizer.step()
            epoch_loss += total_loss.item()
            step_loss_values.append(total_loss.item())

            # CL & Recon Loss Storage of Value
            epoch_cl_loss += cl_loss.item()
            epoch_recon_loss += r_loss.item()

            end_time = time.time()
            print(
                f"{step}/{len(train_ds) // train_loader.batch_size}, "
                f"train_loss: {total_loss.item():.4f}, "
                f"time taken: {end_time-start_time}s")

        epoch_loss /= step
        epoch_cl_loss /= step
        epoch_recon_loss /= step

        epoch_loss_values.append(epoch_loss)
        epoch_cl_loss_values.append(epoch_cl_loss)
        epoch_recon_loss_values.append(epoch_recon_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if epoch % val_interval == 0:
            print('Entering Validation for epoch: {}'.format(epoch+1))
            total_val_loss = 0
            val_step = 0
            model.eval()
            for val_batch in val_loader:
                val_step += 1
                start_time = time.time()
                inputs, gt_input = (
                    val_batch["image"].to(device),
                    val_batch["gt_image"].to(device),
                )
                print('Input shape: {}'.format(inputs.shape))
                outputs, _ = model(inputs)  # ViTAutoEnc returns (reconstruction, hidden states); only the reconstruction is needed here
                val_loss = recon_loss(outputs, gt_input)
                total_val_loss += val_loss.item()
                end_time = time.time()

            total_val_loss /= val_step
            val_loss_values.append(total_val_loss)
            print(f"epoch {epoch + 1} Validation average loss: {total_val_loss:.4f}, " f"time taken: {end_time-start_time}s")

            if total_val_loss < best_val_loss:
                print(f"Saving new model based on validation loss {total_val_loss:.4f}")
                best_val_loss = total_val_loss
                checkpoint = {'epoch': epoch + 1,  # record the epoch at save time
                              'state_dict': model.state_dict(),
                              'optimizer': optimizer.state_dict()
                              }
                torch.save(checkpoint, os.path.join(logdir_path, 'best_model.pt'))

            plt.figure(1, figsize=(8, 8))
            plt.subplot(2, 2, 1)
            plt.plot(epoch_loss_values)
            plt.grid()
            plt.title('Training Loss')

            plt.subplot(2, 2, 2)
            plt.plot(val_loss_values)
            plt.grid()
            plt.title('Validation Loss')

            plt.subplot(2, 2, 3)
            plt.plot(epoch_cl_loss_values)
            plt.grid()
            plt.title('Training Contrastive Loss')

            plt.subplot(2, 2, 4)
            plt.plot(epoch_recon_loss_values)
            plt.grid()
            plt.title('Training Recon Loss')

            plt.savefig(os.path.join(logdir_path, 'loss_plots.png'))
            plt.close(1)

    print('Done')
    return None
Example #13
def dense_patch_slices(
    image_size: Sequence[int],
    patch_size: Sequence[int],
    scan_interval: Sequence[int],
) -> List[Tuple[slice, ...]]:
    """
    Enumerate all slices defining 2D/3D patches of size `patch_size` from an `image_size` input image.

    Args:
        image_size: dimensions of image to iterate over
        patch_size: size of patches to generate slices
        scan_interval: dense patch sampling interval

    Raises:
        ValueError: When ``image_size`` length is not one of [2, 3].

    Returns:
        a list of slice objects defining each patch

    """
    num_spatial_dims = len(image_size)
    if num_spatial_dims not in (2, 3):
        raise ValueError(
            f"Unsupported image_size length: {len(image_size)}, available options are [2, 3]"
        )
    patch_size = get_valid_patch_size(image_size, patch_size)
    scan_interval = ensure_tuple_size(scan_interval, num_spatial_dims)

    scan_num = list()
    for i in range(num_spatial_dims):
        if scan_interval[i] == 0:
            scan_num.append(1)
        else:
            num = int(math.ceil(float(image_size[i]) / scan_interval[i]))
            scan_dim = first(
                d for d in range(num)
                if d * scan_interval[i] + patch_size[i] >= image_size[i])
            scan_num.append(scan_dim + 1 if scan_dim is not None else 1)  # guard: first() may return None

    slices: List[Tuple[slice, ...]] = []
    if num_spatial_dims == 3:
        for i in range(scan_num[0]):
            start_i = i * scan_interval[0]
            start_i -= max(start_i + patch_size[0] - image_size[0], 0)
            slice_i = slice(start_i, start_i + patch_size[0])

            for j in range(scan_num[1]):
                start_j = j * scan_interval[1]
                start_j -= max(start_j + patch_size[1] - image_size[1], 0)
                slice_j = slice(start_j, start_j + patch_size[1])

                for k in range(0, scan_num[2]):
                    start_k = k * scan_interval[2]
                    start_k -= max(start_k + patch_size[2] - image_size[2], 0)
                    slice_k = slice(start_k, start_k + patch_size[2])
                    slices.append((slice_i, slice_j, slice_k))
    else:
        for i in range(scan_num[0]):
            start_i = i * scan_interval[0]
            start_i -= max(start_i + patch_size[0] - image_size[0], 0)
            slice_i = slice(start_i, start_i + patch_size[0])

            for j in range(scan_num[1]):
                start_j = j * scan_interval[1]
                start_j -= max(start_j + patch_size[1] - image_size[1], 0)
                slice_j = slice(start_j, start_j + patch_size[1])
                slices.append((slice_i, slice_j))
    return slices
Example #14
    ScaleIntensityRanged(
        keys=["image"],
        a_min=-300,
        a_max=300,
        b_min=0.0,
        b_max=1.0,
        clip=True,
    ),
    #CropForegroundd(keys=["image", "label"], source_key="image"),
    ToTensord(keys=["image", "label"]),
])
"""## Check transforms in DataLoader"""

check_ds = Dataset(data=val_files, transform=val_transforms)
check_loader = DataLoader(check_ds, batch_size=1)
check_data = first(check_loader)
image, label = (check_data["image"][0][0], check_data["label"][0][0])
print(f"image shape: {image.shape}, label shape: {label.shape}")
# plot the slice [:, :, 80]
#fig = plt.figure("check", (12, 6)) #figure size
#plt.subplot(1, 2, 1)
#plt.title("image")
#plt.imshow(image[:, :, 80], cmap="gray")
#plt.subplot(1, 2, 2)
#plt.title("label")
#plt.imshow(label[:, :, 80])
#plt.show()
#fig.savefig('my_figure.png')
"""## Define CacheDataset and DataLoader for training and validation

Here we use CacheDataset to accelerate the training and validation process; it is roughly 10x faster than the regular Dataset.
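
A minimal sketch of the cached loaders this cell describes, assuming `train_files`/`val_files` data lists and `train_transforms`/`val_transforms` like the ones above (all four names are assumptions, not from the snippet):

from monai.data import CacheDataset, DataLoader

train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=4)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)

val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4)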
Example #15
def train(n_feat,
          crop_size,
          bs,
          ep,
          optimizer="rmsprop",
          lr=5e-4,
          pretrain=None):
    model_name = f"./HaN_{n_feat}_{bs}_{ep}_{crop_size}_{lr}_"
    print(f"save the best model as '{model_name}' during training.")

    crop_size = [int(cz) for cz in crop_size.split(",")]
    print(f"input image crop_size: {crop_size}")

    # starting training set loader
    train_images = ImageLabelDataset(path=TRAIN_PATH, n_class=N_CLASSES)
    if np.any([cz == -1 for cz in crop_size]):  # using full image
        train_transform = Compose([
            AddChannelDict(keys="image"),
            Rand3DElasticd(
                keys=("image", "label"),
                spatial_size=crop_size,
                sigma_range=(10, 50),  # 30
                magnitude_range=(600, 1200),  # 1000
                prob=0.8,
                rotate_range=(np.pi / 12, np.pi / 12, np.pi / 12),
                shear_range=(np.pi / 18, np.pi / 18, np.pi / 18),
                translate_range=tuple(sz * 0.05 for sz in crop_size),
                scale_range=(0.2, 0.2, 0.2),
                mode=("bilinear", "nearest"),
                padding_mode=("border", "zeros"),
            ),
        ])
        train_dataset = Dataset(train_images, transform=train_transform)
        # when bs > 1, the loader assumes that the full image sizes are the same across the dataset
        train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                       num_workers=4,
                                                       batch_size=bs,
                                                       shuffle=True)
    else:
        # draw balanced foreground/background window samples according to the ground truth label
        train_transform = Compose([
            AddChannelDict(keys="image"),
            SpatialPadd(
                keys=("image", "label"),
                spatial_size=crop_size),  # ensure image size >= crop_size
            RandCropByPosNegLabeld(keys=("image", "label"),
                                   label_key="label",
                                   spatial_size=crop_size,
                                   num_samples=bs),
            Rand3DElasticd(
                keys=("image", "label"),
                spatial_size=crop_size,
                sigma_range=(10, 50),  # 30
                magnitude_range=(600, 1200),  # 1000
                prob=0.8,
                rotate_range=(np.pi / 12, np.pi / 12, np.pi / 12),
                shear_range=(np.pi / 18, np.pi / 18, np.pi / 18),
                translate_range=tuple(sz * 0.05 for sz in crop_size),
                scale_range=(0.2, 0.2, 0.2),
                mode=("bilinear", "nearest"),
                padding_mode=("border", "zeros"),
            ),
        ])
        train_dataset = Dataset(train_images, transform=train_transform
                                )  # each dataset item is a list of windows
        train_dataloader = torch.utils.data.DataLoader(  # stack each dataset item into a single tensor
            train_dataset,
            num_workers=4,
            batch_size=1,
            shuffle=True,
            collate_fn=list_data_collate)
    first_sample = first(train_dataloader)
    print(first_sample["image"].shape)

    # starting validation set loader
    val_transform = Compose([AddChannelDict(keys="image")])
    val_dataset = Dataset(ImageLabelDataset(VAL_PATH, n_class=N_CLASSES),
                          transform=val_transform)
    val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                                 num_workers=1,
                                                 batch_size=1)
    print(val_dataset[0]["image"].shape)
    print(
        f"training images: {len(train_dataloader)}, validation images: {len(val_dataloader)}"
    )

    model = UNetPipe(spatial_dims=3,
                     in_channels=1,
                     out_channels=N_CLASSES,
                     n_feat=n_feat)
    model = flatten_sequential(model)
    lossweight = torch.from_numpy(
        np.array([2.22, 1.31, 1.99, 1.13, 1.93, 1.93, 1.0, 1.0, 1.90, 1.98],
                 np.float32))

    if optimizer.lower() == "rmsprop":
        optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)  # lr = 5e-4
    elif optimizer.lower() == "momentum":
        optimizer = torch.optim.SGD(model.parameters(), lr=lr,
                                    momentum=0.9)  # lr = 1e-4 for finetuning
    else:
        raise ValueError(
            f"Unknown optimizer type {optimizer}. (options are 'rmsprop' and 'momentum')."
        )

    # config GPipe
    x = first_sample["image"].float()
    x = torch.autograd.Variable(x.cuda())
    partitions = torch.cuda.device_count()
    print(f"partition: {partitions}, input: {x.size()}")
    balance = balance_by_size(partitions, model, x)
    model = GPipe(model, balance, chunks=4, checkpoint="always")

    # config loss functions
    dice_loss_func = DiceLoss(softmax=True, reduction="none")
    # use the same pipeline and loss in
    # AnatomyNet: Deep learning for fast and fully automated whole‐volume segmentation of head and neck anatomy,
    # Medical Physics, 2018.
    focal_loss_func = FocalLoss(reduction="none")

    if pretrain:
        print(f"loading from {pretrain}.")
        pretrained_dict = torch.load(pretrain)["weight"]
        model_dict = model.state_dict()
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)  # load the merged state dict, not just the filtered subset

    b_time = time.time()
    best_val_loss = [0] * (N_CLASSES - 1)  # foreground
    for epoch in range(ep):
        model.train()
        trainloss = 0
        for b_idx, data_dict in enumerate(train_dataloader):
            x_train = data_dict["image"]
            y_train = data_dict["label"]
            flagvec = data_dict["with_complete_groundtruth"]

            x_train = torch.autograd.Variable(x_train.cuda())
            y_train = torch.autograd.Variable(y_train.cuda().float())
            optimizer.zero_grad()
            o = model(x_train).to(0, non_blocking=True).float()

            loss = (dice_loss_func(o, y_train.to(o)) * flagvec.to(o) *
                    lossweight.to(o)).mean()
            loss += 0.5 * (focal_loss_func(o, y_train.to(o)) * flagvec.to(o) *
                           lossweight.to(o)).mean()
            loss.backward()
            optimizer.step()
            trainloss += loss.item()

            if b_idx % 20 == 0:
                print(
                    f"Train Epoch: {epoch} [{b_idx}/{len(train_dataloader)}] \tLoss: {loss.item()}"
                )
        print(f"epoch {epoch} TRAIN loss {trainloss / len(train_dataloader)}")

        if epoch % 10 == 0:
            model.eval()
            # check validation dice
            val_loss = [0] * (N_CLASSES - 1)
            n_val = [0] * (N_CLASSES - 1)
            for data_dict in val_dataloader:
                x_val = data_dict["image"]
                y_val = data_dict["label"]
                with torch.no_grad():
                    x_val = torch.autograd.Variable(x_val.cuda())
                o = model(x_val).to(0, non_blocking=True)
                loss = compute_meandice(o,
                                        y_val.to(o),
                                        mutually_exclusive=True,
                                        include_background=False)
                val_loss = [
                    l.item() + tl if l == l else tl  # "l == l" is False for NaN, so invalid scores are skipped
                    for l, tl in zip(loss[0], val_loss)
                ]
                n_val = [
                    n + 1 if l == l else n for l, n in zip(loss[0], n_val)
                ]
            val_loss = [l / n for l, n in zip(val_loss, n_val)]
            print(
                "validation scores %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f"
                % tuple(val_loss))
            for c in range(1, 10):
                if best_val_loss[c - 1] < val_loss[c - 1]:
                    best_val_loss[c - 1] = val_loss[c - 1]
                    state = {
                        "epoch": epoch,
                        "weight": model.state_dict(),
                        "score_" + str(c): best_val_loss[c - 1]
                    }
                    torch.save(state, f"{model_name}" + str(c))
            print(
                "best validation scores %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f"
                % tuple(best_val_loss))

    print("total time", time.time() - b_time)
Example #16
def train(cfg):
    log_dir = create_log_dir(cfg)
    device = set_device(cfg)
    # --------------------------------------------------------------------------
    # Data Loading and Preprocessing
    # --------------------------------------------------------------------------
    # __________________________________________________________________________
    # Build MONAI preprocessing
    train_preprocess = Compose([
        ToTensorD(keys="image"),
        TorchVisionD(keys="image",
                     name="ColorJitter",
                     brightness=64.0 / 255.0,
                     contrast=0.75,
                     saturation=0.25,
                     hue=0.04),
        ToNumpyD(keys="image"),
        RandFlipD(keys="image", prob=0.5),
        RandRotate90D(keys="image", prob=0.5),
        CastToTypeD(keys="image", dtype=np.float32),
        RandZoomD(keys="image", prob=0.5, min_zoom=0.9, max_zoom=1.1),
        ScaleIntensityRangeD(keys="image",
                             a_min=0.0,
                             a_max=255.0,
                             b_min=-1.0,
                             b_max=1.0),
        ToTensorD(keys=("image", "label")),
    ])
    valid_preprocess = Compose([
        CastToTypeD(keys="image", dtype=np.float32),
        ScaleIntensityRangeD(keys="image",
                             a_min=0.0,
                             a_max=255.0,
                             b_min=-1.0,
                             b_max=1.0),
        ToTensorD(keys=("image", "label")),
    ])
    # __________________________________________________________________________
    # Create MONAI dataset
    train_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="training",
        base_dir=cfg["data_root"],
    )
    valid_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="validation",
        base_dir=cfg["data_root"],
    )

    train_dataset = PatchWSIDataset(
        train_json_info_list,
        cfg["region_size"],
        cfg["grid_shape"],
        cfg["patch_size"],
        train_preprocess,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )
    valid_dataset = PatchWSIDataset(
        valid_json_info_list,
        cfg["region_size"],
        cfg["grid_shape"],
        cfg["patch_size"],
        valid_preprocess,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )

    # __________________________________________________________________________
    # DataLoaders
    train_dataloader = DataLoader(train_dataset,
                                  num_workers=cfg["num_workers"],
                                  batch_size=cfg["batch_size"],
                                  pin_memory=True)
    valid_dataloader = DataLoader(valid_dataset,
                                  num_workers=cfg["num_workers"],
                                  batch_size=cfg["batch_size"],
                                  pin_memory=True)

    # __________________________________________________________________________
    # Get sample batch and some info
    first_sample = first(train_dataloader)
    if first_sample is None:
        raise ValueError("Fist sample is None!")

    print("image: ")
    print("    shape", first_sample["image"].shape)
    print("    type: ", type(first_sample["image"]))
    print("    dtype: ", first_sample["image"].dtype)
    print("labels: ")
    print("    shape", first_sample["label"].shape)
    print("    type: ", type(first_sample["label"]))
    print("    dtype: ", first_sample["label"].dtype)
    print(f"batch size: {cfg['batch_size']}")
    print(f"train number of batches: {len(train_dataloader)}")
    print(f"valid number of batches: {len(valid_dataloader)}")

    # --------------------------------------------------------------------------
    # Deep Learning Classification Model
    # --------------------------------------------------------------------------
    # __________________________________________________________________________
    # initialize model
    model = TorchVisionFCModel("resnet18",
                               num_classes=1,
                               use_conv=True,
                               pretrained=cfg["pretrain"])
    model = model.to(device)

    # loss function
    loss_func = torch.nn.BCEWithLogitsLoss()
    loss_func = loss_func.to(device)

    # optimizer
    if cfg["novograd"]:
        optimizer = Novograd(model.parameters(), cfg["lr"])
    else:
        optimizer = SGD(model.parameters(), lr=cfg["lr"], momentum=0.9)

    # AMP scaler
    if cfg["amp"]:
        cfg["amp"] = True if monai.utils.get_torch_version_tuple() >= (
            1, 6) else False
    else:
        cfg["amp"] = False

    scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                               T_max=cfg["n_epochs"])

    # --------------------------------------------
    # Ignite Trainer/Evaluator
    # --------------------------------------------
    # Evaluator
    val_handlers = [
        CheckpointSaver(save_dir=log_dir,
                        save_dict={"net": model},
                        save_key_metric=True),
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=log_dir,
                                output_transform=lambda x: None),
    ]
    val_postprocessing = Compose([
        ActivationsD(keys="pred", sigmoid=True),
        AsDiscreteD(keys="pred", threshold=0.5)
    ])
    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=valid_dataloader,
        network=model,
        postprocessing=val_postprocessing,
        key_val_metric={
            "val_acc":
            Accuracy(output_transform=from_engine(["pred", "label"]))
        },
        val_handlers=val_handlers,
        amp=cfg["amp"],
    )

    # Trainer
    train_handlers = [
        LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
        CheckpointSaver(save_dir=cfg["logdir"],
                        save_dict={
                            "net": model,
                            "opt": optimizer
                        },
                        save_interval=1,
                        epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=from_engine(["loss"], first=True)),
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        TensorBoardStatsHandler(log_dir=cfg["logdir"],
                                tag_name="train_loss",
                                output_transform=from_engine(["loss"],
                                                             first=True)),
    ]
    train_postprocessing = Compose([
        ActivationsD(keys="pred", sigmoid=True),
        AsDiscreteD(keys="pred", threshold=0.5)
    ])

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=cfg["n_epochs"],
        train_data_loader=train_dataloader,
        network=model,
        optimizer=optimizer,
        loss_function=loss_func,
        postprocessing=train_postprocessing,
        key_train_metric={
            "train_acc":
            Accuracy(output_transform=from_engine(["pred", "label"]))
        },
        train_handlers=train_handlers,
        amp=cfg["amp"],
    )
    trainer.run()
Example #17
def main(cfg):
    # -------------------------------------------------------------------------
    # Configs
    # -------------------------------------------------------------------------
    # Create log/model dir
    log_dir = create_log_dir(cfg)

    # Set the logger
    logging.basicConfig(
        format="%(asctime)s %(levelname)2s %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    log_name = os.path.join(log_dir, "logs.txt")
    logger = logging.getLogger()
    fh = logging.FileHandler(log_name)
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)

    # Set TensorBoard summary writer
    writer = SummaryWriter(log_dir)

    # Save configs
    logging.info(json.dumps(cfg))
    with open(os.path.join(log_dir, "config.json"), "w") as fp:
        json.dump(cfg, fp, indent=4)

    # Set device cuda/cpu
    device = set_device(cfg)

    # Set cudnn benchmark/deterministic
    if cfg["benchmark"]:
        torch.backends.cudnn.benchmark = True
    else:
        set_determinism(seed=0)
    # -------------------------------------------------------------------------
    # Transforms and Datasets
    # -------------------------------------------------------------------------
    # Pre-processing
    preprocess_cpu_train = None
    preprocess_gpu_train = None
    preprocess_cpu_valid = None
    preprocess_gpu_valid = None
    if cfg["backend"] == "cucim":
        preprocess_cpu_train = Compose([ToTensorD(keys="label")])
        preprocess_gpu_train = Compose([
            Range()(ToCupy()),
            Range("ColorJitter")(RandCuCIM(name="color_jitter",
                                           brightness=64.0 / 255.0,
                                           contrast=0.75,
                                           saturation=0.25,
                                           hue=0.04)),
            Range("RandomFlip")(RandCuCIM(name="image_flip",
                                          apply_prob=cfg["prob"],
                                          spatial_axis=-1)),
            Range("RandomRotate90")(RandCuCIM(name="rand_image_rotate_90",
                                              prob=cfg["prob"],
                                              max_k=3,
                                              spatial_axis=(-2, -1))),
            Range()(CastToType(dtype=np.float32)),
            Range("RandomZoom")(RandCuCIM(name="rand_zoom",
                                          min_zoom=0.9,
                                          max_zoom=1.1)),
            Range("ScaleIntensity")(CuCIM(name="scale_intensity_range",
                                          a_min=0.0,
                                          a_max=255.0,
                                          b_min=-1.0,
                                          b_max=1.0)),
            Range()(ToTensor(device=device)),
        ])
        preprocess_cpu_valid = Compose([ToTensorD(keys="label")])
        preprocess_gpu_valid = Compose([
            Range("ValidToCupyAndCast")(ToCupy(dtype=np.float32)),
            Range("ValidScaleIntensity")(CuCIM(name="scale_intensity_range",
                                               a_min=0.0,
                                               a_max=255.0,
                                               b_min=-1.0,
                                               b_max=1.0)),
            Range("ValidToTensor")(ToTensor(device=device)),
        ])
    elif cfg["backend"] == "numpy":
        preprocess_cpu_train = Compose([
            Range()(ToTensorD(keys=("image", "label"))),
            Range("ColorJitter")(TorchVisionD(
                keys="image",
                name="ColorJitter",
                brightness=64.0 / 255.0,
                contrast=0.75,
                saturation=0.25,
                hue=0.04,
            )),
            Range()(ToNumpyD(keys="image")),
            Range("RandomFlip")(RandFlipD(keys="image",
                                          prob=cfg["prob"],
                                          spatial_axis=-1)),
            Range("RandomRotate90")(RandRotate90D(keys="image",
                                                  prob=cfg["prob"])),
            Range()(CastToTypeD(keys="image", dtype=np.float32)),
            Range("RandomZoom")(RandZoomD(keys="image",
                                          prob=cfg["prob"],
                                          min_zoom=0.9,
                                          max_zoom=1.1)),
            Range("ScaleIntensity")(ScaleIntensityRangeD(keys="image",
                                                         a_min=0.0,
                                                         a_max=255.0,
                                                         b_min=-1.0,
                                                         b_max=1.0)),
            Range()(ToTensorD(keys="image")),
        ])
        preprocess_cpu_valid = Compose([
            Range("ValidCastType")(CastToTypeD(keys="image",
                                               dtype=np.float32)),
            Range("ValidScaleIntensity")(ScaleIntensityRangeD(keys="image",
                                                              a_min=0.0,
                                                              a_max=255.0,
                                                              b_min=-1.0,
                                                              b_max=1.0)),
            Range("ValidToTensor")(ToTensorD(keys=("image", "label"))),
        ])
    else:
        raise ValueError(
            f"Backend should be either numpy or cucim! ['{cfg['backend']}' is provided.]"
        )

    # Post-processing
    postprocess = Compose([
        Activations(sigmoid=True),
        AsDiscrete(threshold=0.5),
    ])

    # Create MONAI dataset
    train_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="training",
        base_dir=cfg["data_root"],
    )
    valid_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="validation",
        base_dir=cfg["data_root"],
    )
    train_dataset = PatchWSIDataset(
        data=train_json_info_list,
        region_size=cfg["region_size"],
        grid_shape=cfg["grid_shape"],
        patch_size=cfg["patch_size"],
        transform=preprocess_cpu_train,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )
    valid_dataset = PatchWSIDataset(
        data=valid_json_info_list,
        region_size=cfg["region_size"],
        grid_shape=cfg["grid_shape"],
        patch_size=cfg["patch_size"],
        transform=preprocess_cpu_valid,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )

    # DataLoaders
    train_dataloader = DataLoader(train_dataset,
                                  num_workers=cfg["num_workers"],
                                  batch_size=cfg["batch_size"],
                                  pin_memory=cfg["pin"])
    valid_dataloader = DataLoader(valid_dataset,
                                  num_workers=cfg["num_workers"],
                                  batch_size=cfg["batch_size"],
                                  pin_memory=cfg["pin"])

    # Get sample batch and some info
    first_sample = first(train_dataloader)
    if first_sample is None:
        raise ValueError("First sample is None!")
    for d in ["image", "label"]:
        logging.info(f"[{d}] \n"
                     f"  {d} shape: {first_sample[d].shape}\n"
                     f"  {d} type:  {type(first_sample[d])}\n"
                     f"  {d} dtype: {first_sample[d].dtype}")
    logging.info(f"Batch size: {cfg['batch_size']}")
    logging.info(f"[Training] number of batches: {len(train_dataloader)}")
    logging.info(f"[Validation] number of batches: {len(valid_dataloader)}")
    # -------------------------------------------------------------------------
    # Deep Learning Model and Configurations
    # -------------------------------------------------------------------------
    # Initialize model
    model = TorchVisionFCModel("resnet18",
                               n_classes=1,
                               use_conv=True,
                               pretrained=cfg["pretrain"])
    model = model.to(device)

    # Loss function
    loss_func = torch.nn.BCEWithLogitsLoss()
    loss_func = loss_func.to(device)

    # Optimizer
    if cfg["novograd"] is True:
        optimizer = Novograd(model.parameters(), lr=cfg["lr"])
    else:
        optimizer = SGD(model.parameters(), lr=cfg["lr"], momentum=0.9)

    # AMP scaler
    cfg["amp"] = cfg["amp"] and monai.utils.get_torch_version_tuple() >= (1, 6)
    if cfg["amp"] is True:
        scaler = GradScaler()
    else:
        scaler = None

    # Learning rate scheduler
    if cfg["cos"] is True:
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                                   T_max=cfg["n_epochs"])
    else:
        scheduler = None

    # -------------------------------------------------------------------------
    # Training/Evaluating
    # -------------------------------------------------------------------------
    train_counter = {"n_epochs": cfg["n_epochs"], "epoch": 1, "step": 1}

    total_valid_time, total_train_time = 0.0, 0.0
    t_start = time.perf_counter()
    metric_summary = {"loss": np.Inf, "accuracy": 0, "best_epoch": 1}
    # Training/Validation Loop
    for _ in range(cfg["n_epochs"]):
        t_epoch = time.perf_counter()
        logging.info(
            f"[Training] learning rate: {optimizer.param_groups[0]['lr']}")

        # Training
        with Range("Training Epoch"):
            train_counter = training(
                train_counter,
                model,
                loss_func,
                optimizer,
                scaler,
                cfg["amp"],
                train_dataloader,
                preprocess_gpu_train,
                postprocess,
                device,
                writer,
                cfg["print_step"],
            )
        if scheduler is not None:
            scheduler.step()
        if cfg["save"]:
            torch.save(
                model.state_dict(),
                os.path.join(log_dir,
                             f"model_epoch_{train_counter['epoch']}.pt"))
        t_train = time.perf_counter()
        train_time = t_train - t_epoch
        total_train_time += train_time

        # Validation
        if cfg["validate"]:
            with Range("Validation"):
                valid_loss, valid_acc = validation(
                    model,
                    loss_func,
                    cfg["amp"],
                    valid_dataloader,
                    preprocess_gpu_valid,
                    postprocess,
                    device,
                    cfg["print_step"],
                )
            t_valid = time.perf_counter()
            valid_time = t_valid - t_train
            total_valid_time += valid_time
            if valid_loss < metric_summary["loss"]:
                metric_summary["loss"] = min(valid_loss,
                                             metric_summary["loss"])
                metric_summary["accuracy"] = max(valid_acc,
                                                 metric_summary["accuracy"])
                metric_summary["best_epoch"] = train_counter["epoch"]
            writer.add_scalar("valid/loss", valid_loss, train_counter["epoch"])
            writer.add_scalar("valid/accuracy", valid_acc,
                              train_counter["epoch"])

            logging.info(
                f"[Epoch: {train_counter['epoch']}/{cfg['n_epochs']}] loss: {valid_loss:.3f}, accuracy: {valid_acc:.2f}, "
                f"time: {t_valid - t_epoch:.1f}s (train: {train_time:.1f}s, valid: {valid_time:.1f}s)"
            )
        else:
            logging.info(
                f"[Epoch: {train_counter['epoch']}/{cfg['n_epochs']}] Train time: {train_time:.1f}s"
            )
        writer.flush()
    t_end = time.perf_counter()

    # Save final metrics
    metric_summary["train_time_per_epoch"] = total_train_time / cfg["n_epochs"]
    metric_summary["total_time"] = t_end - t_start
    writer.add_hparams(hparam_dict=cfg,
                       metric_dict=metric_summary,
                       run_name=log_dir)
    writer.close()
    logging.info(f"Metric Summary: {metric_summary}")

    # Save the best and final model
    if cfg["validate"] is True:
        copyfile(
            os.path.join(log_dir,
                         f"model_epoch_{metric_summary['best_epoch']}.pth"),
            os.path.join(log_dir, "model_best.pth"),
        )
        copyfile(
            os.path.join(log_dir, f"model_epoch_{cfg['n_epochs']}.pth"),
            os.path.join(log_dir, "model_final.pth"),
        )

    # Final prints
    logging.info(
        f"[Completed] {train_counter['epoch']} epochs -- time: {t_end - t_start:.1f}s "
        f"(training: {total_train_time:.1f}s, validation: {total_valid_time:.1f}s)",
    )
    logging.info(f"Logs and model was saved at: {log_dir}")
Example #18
def decollate_batch(batch, detach: bool = True):
    """De-collate a batch of data (for example, as produced by a `DataLoader`).

    Returns a list of structures with the original tensor's 0-th dimension sliced into elements using `torch.unbind`.

    Images originally stored as (B,C,H,W,[D]) will be returned as (C,H,W,[D]). Other information,
    such as metadata, may have been stored in a list (or a list inside nested dictionaries). In
    this case we return the element of the list corresponding to the batch idx.

    Return types aren't guaranteed to be the same as the original, since numpy arrays will have been
    converted to torch.Tensor, sequences may be converted to lists of tensors,
    mappings may be converted into dictionaries.

    For example:

    .. code-block:: python

        batch_data = {
            "image": torch.rand((2,1,10,10)),
            "image_meta_dict": {"scl_slope": torch.Tensor([0.0, 0.0])}
        }
        out = decollate_batch(batch_data)
        print(len(out))
        >>> 2

        print(out[0])
        >>> {'image': tensor([[[4.3549e-01...43e-01]]]), 'image_meta_dict': {'scl_slope': 0.0}}

        batch_data = [torch.rand((2,1,10,10)), torch.rand((2,3,5,5))]
        out = decollate_batch(batch_data)
        print(out[0])
        >>> [tensor([[[4.3549e-01...43e-01]]], tensor([[[5.3435e-01...45e-01]]])]

        batch_data = torch.rand((2,1,10,10))
        out = decollate_batch(batch_data)
        print(out[0])
        >>> tensor([[[4.3549e-01...43e-01]]])

    Args:
        batch: data to be de-collated.
        detach: whether to detach the tensors. Scalars tensors will be detached into number types
            instead of torch tensors.
    """
    if batch is None:
        return batch
    if isinstance(batch, (float, int, str, bytes)):
        return batch
    if isinstance(batch, torch.Tensor):
        if detach:
            batch = batch.detach()
        if batch.ndim == 0:
            return batch.item() if detach else batch
        out_list = torch.unbind(batch, dim=0)
        if out_list[0].ndim == 0 and detach:
            return [t.item() for t in out_list]
        return list(out_list)
    if isinstance(batch, Mapping):
        _dict_list = {
            key: decollate_batch(batch[key], detach)
            for key in batch
        }
        return [
            dict(zip(_dict_list, item)) for item in zip(*_dict_list.values())
        ]
    if isinstance(batch, Iterable):
        item_0 = first(batch)
        if (not isinstance(item_0, Iterable)
                or isinstance(item_0, (str, bytes))
                or (isinstance(item_0, torch.Tensor) and item_0.ndim == 0)):
            # Not running the usual list decollate here:
            # don't decollate ['test', 'test'] into [['t', 't'], ['e', 'e'], ['s', 's'], ['t', 't']]
            # torch.tensor(0) is iterable but iter(torch.tensor(0)) raises TypeError: iteration over a 0-d tensor
            return [decollate_batch(b, detach) for b in batch]
        return [
            list(item)
            for item in zip(*(decollate_batch(b, detach) for b in batch))
        ]
    raise NotImplementedError(
        f"Unable to de-collate: {batch}, type: {type(batch)}.")
Example #19
val_image_trans = Compose([ScaleIntensity(), AddChannel(), ToTensor(),])

val_seg_trans = Compose([AddChannel(), ToTensor()])


val_ds = ArrayDataset(test_images, val_image_trans, test_segs, val_seg_trans)
val_loader = DataLoader(
    dataset=val_ds,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=torch.cuda.is_available(),
)

# %timeit first(val_loader)

batch = first(val_loader)
im, seg = batch  # ArrayDataset yields (image, segmentation) pairs rather than a dictionary
print(im.shape, im.min(), im.max(), seg.shape)
plt.imshow(im[0, 0].numpy() + seg[0, 0].numpy(), cmap="gray")



#%%

#Using UNet for the segmentation network
#Training scheme: for each epoch, training runs over every batch of images from the training set, so each image is used once per epoch; the model is then evaluated on the validation set.

net = monai.networks.nets.UNet(
    dimensions=2,
    in_channels=1,