def test_image_pipeline_and_pin_memory(self):
        """
        Smoke test: push the image data loader through a small transform chain with
        memory pinning enabled. Nothing is checked beyond "it does not crash" and
        that the last batch holds a pinned torch tensor under 'data'.
        :return:
        """
        try:
            import torch
        except ImportError:
            # nothing to test when torch is not installed
            return

        from batchgenerators.transforms import MirrorTransform, NumpyToTensor, TransposeAxesTransform, Compose

        composed = Compose([
            MirrorTransform(),
            TransposeAxesTransform(transpose_any_of_these=(0, 1), p_per_sample=0.5),
            NumpyToTensor(keys='data', cast_to='float'),
        ])

        # positional arguments are presumably (workers, cached per worker, seeds, pin_memory)
        # -- confirm against the MultiThreadedAugmenter signature
        mt = MultiThreadedAugmenter(self.dl_images, composed, 4, 1, None, True)

        # draw 50 batches; only the last one is inspected below
        for _ in range(50):
            res = mt.next()

        data = res['data']
        assert isinstance(data, torch.Tensor)
        assert data.is_pinned()

        # give mt time to finish caching; otherwise it prints an error (harmless, the test would
        # still succeed, but the output is ugly)
        sleep(2)
def get_no_augmentation(dataloader_train, dataloader_val, patch_size, params=default_3D_augmentation_params, border_val_seg=-1):
    """
    Drop-in replacement for get_default_augmentation that performs no data augmentation.

    Only the optional channel selection, label cleanup, the 'seg' -> 'target' rename and the
    tensor conversion are applied. Both returned augmenters have already been restarted.
    :param dataloader_train: data loader wrapped by the training augmenter
    :param dataloader_val: data loader wrapped by the validation augmenter
    :param patch_size: not used by this function (kept for signature compatibility)
    :param params: mapping read for selected_data_channels, selected_seg_channels, num_threads
        and num_cached_per_thread
    :param border_val_seg: not used by this function (kept for signature compatibility)
    :return: (batchgenerator_train, batchgenerator_val)
    """
    data_channels = params.get("selected_data_channels")
    seg_channels = params.get("selected_seg_channels")
    n_cached = params.get("num_cached_per_thread")

    # ---- training pipeline ----
    train_steps = []
    if data_channels is not None:
        train_steps.append(DataChannelSelectionTransform(data_channels))
    if seg_channels is not None:
        train_steps.append(SegChannelSelectionTransform(seg_channels))
    train_steps += [
        RemoveLabelTransform(-1, 0),
        RenameTransform('seg', 'target', True),
        NumpyToTensor(['data', 'target'], 'float'),
    ]
    tr_transforms = Compose(train_steps)

    n_train_workers = params.get('num_threads')
    batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms, n_train_workers, n_cached,
                                                  seeds=range(n_train_workers), pin_memory=True)
    batchgenerator_train.restart()

    # ---- validation pipeline (label removal comes first here) ----
    val_steps = [RemoveLabelTransform(-1, 0)]
    if data_channels is not None:
        val_steps.append(DataChannelSelectionTransform(data_channels))
    if seg_channels is not None:
        val_steps.append(SegChannelSelectionTransform(seg_channels))
    val_steps += [
        RenameTransform('seg', 'target', True),
        NumpyToTensor(['data', 'target'], 'float'),
    ]
    val_transforms = Compose(val_steps)

    # validation uses half the workers, but never fewer than one
    n_val_workers = max(params.get('num_threads') // 2, 1)
    batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms, n_val_workers, n_cached,
                                                seeds=range(n_val_workers), pin_memory=True)
    batchgenerator_val.restart()

    return batchgenerator_train, batchgenerator_val
예제 #3
0
 def get_validation_transforms(self):
     """
     Assemble the validation pipeline: optional data/seg channel selection, a
     center crop to self.patch_size, label cleanup, the 'seg' -> 'target'
     rename and the conversion to float tensors.
     :return: the composed transform
     """
     cfg = self.params
     steps = []

     data_channels = cfg.get("selected_data_channels")
     if data_channels:
         steps.append(DataChannelSelectionTransform(data_channels))

     seg_channels = cfg.get("selected_seg_channels")
     if seg_channels:
         steps.append(SegChannelSelectionTransform(seg_channels))

     steps += [
         CenterCropTransform(self.patch_size),
         RemoveLabelTransform(-1, 0),
         RenameTransform('seg', 'target', True),
         NumpyToTensor(['data', 'target'], 'float'),
     ]
     return Compose(steps)
def get_default_augmentation_withEDT(dataloader_train,
                                     dataloader_val,
                                     patch_size,
                                     idx_of_edts,
                                     params=default_3D_augmentation_params,
                                     border_val_seg=-1,
                                     pin_memory=True,
                                     seeds_train=None,
                                     seeds_val=None):
    """
    Build the training and validation augmenters, moving the EDT channels out of 'data'.

    The channels listed in idx_of_edts are moved from the 'data' key to a separate 'bound' key
    (after the spatial transform in training) so that the intensity transforms which follow do
    not touch them. Neither returned augmenter is restarted here.

    :param dataloader_train: data loader wrapped by the training augmenter
    :param dataloader_val: data loader wrapped by the validation augmenter
    :param patch_size: output patch size handed to SpatialTransform
    :param idx_of_edts: channel indices handed to AppendChannelsTransform (moved 'data' -> 'bound')
    :param params: mapping with the augmentation settings; "p_gamma" must be present when
        "do_gamma" is truthy (it is read with [] and raises KeyError otherwise)
    :param border_val_seg: constant border value used for the segmentation in SpatialTransform
    :param pin_memory: forwarded to both MultiThreadedAugmenter instances
    :param seeds_train: seeds forwarded to the training augmenter
    :param seeds_val: seeds forwarded to the validation augmenter
    :return: (batchgenerator_train, batchgenerator_val)
    """
    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    # (a plain truthiness test already covers the "missing / None" case)
    dummy_2d = params.get("dummy_2D")
    if dummy_2d:
        tr_transforms.append(Convert3DTo2DTransform())

    tr_transforms.append(
        SpatialTransform(patch_size,
                         patch_center_dist_from_border=None,
                         do_elastic_deform=params.get("do_elastic"),
                         alpha=params.get("elastic_deform_alpha"),
                         sigma=params.get("elastic_deform_sigma"),
                         do_rotation=params.get("do_rotation"),
                         angle_x=params.get("rotation_x"),
                         angle_y=params.get("rotation_y"),
                         angle_z=params.get("rotation_z"),
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0,
                         order_data=3,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg,
                         order_seg=1,
                         random_crop=params.get("random_crop"),
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot")))
    if dummy_2d:
        tr_transforms.append(Convert2DTo3DTransform())

    # Move the EDT channels to their own key so that they do not get intensity transformed.
    tr_transforms.append(
        AppendChannelsTransform("data",
                                "bound",
                                idx_of_edts,
                                remove_from_input=True))

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"),
                           False,
                           True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    mask_was_used_for_normalization = params.get(
        "mask_was_used_for_normalization")
    if mask_was_used_for_normalization is not None:
        tr_transforms.append(
            MaskTransform(mask_was_used_for_normalization,
                          mask_idx_in_seg=0,
                          set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    move_last_seg_channel = params.get("move_last_seg_chanel_to_data")
    if move_last_seg_channel:
        tr_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))
        # The original condition read "x and not None and x"; "not None" is always True, so
        # the effective test has always been plain truthiness of the setting.
        if params.get("advanced_pyramid_augmentations"):
            num_labels = len(params.get("all_segmentation_labels"))
            # negative indices address the trailing channels of 'data'
            tr_transforms.append(
                ApplyRandomBinaryOperatorTransform(
                    channel_idx=list(range(-num_labels, 0)),
                    p_per_sample=0.4,
                    key="data",
                    strel_size=(1, 8)))
            tr_transforms.append(
                RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                    channel_idx=list(range(-num_labels, 0)),
                    key="data",
                    p_per_sample=0.2,
                    fill_with_other_class_p=0.0,
                    dont_do_if_covers_more_than_X_percent=0.15))

    tr_transforms.append(RenameTransform('seg', 'target', True))
    tr_transforms.append(NumpyToTensor(['data', 'target', 'bound'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train,
        tr_transforms,
        params.get('num_threads'),
        params.get("num_cached_per_thread"),
        seeds=seeds_train,
        pin_memory=pin_memory)

    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # Move the EDT channels to their own key (same as in the training pipeline).
    val_transforms.append(
        AppendChannelsTransform("data",
                                "bound",
                                idx_of_edts,
                                remove_from_input=True))

    if move_last_seg_channel:
        val_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))
    val_transforms.append(NumpyToTensor(['data', 'target', 'bound'], 'float'))
    val_transforms = Compose(val_transforms)

    # validation runs with half the workers, but never fewer than one
    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val,
        val_transforms,
        max(params.get('num_threads') // 2, 1),
        params.get("num_cached_per_thread"),
        seeds=seeds_val,
        pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
예제 #5
0
    def get_training_transforms(self):
        """
        Assemble the training pipeline from self.params.

        Order: optional channel selection, spatial transform (wrapped in the
        3D->2D / 2D->3D conversion when "dummy_2D" is set), intensity
        augmentations, optional mirroring and masking, then label cleanup,
        the 'seg' -> 'target' rename and the conversion to float tensors.
        :return: the composed transform
        """
        cfg = self.params
        assert cfg.get('mirror') is None, \
            "old version of params, use new keyword do_mirror"

        steps = []

        data_channels = cfg.get("selected_data_channels")
        if data_channels:
            steps.append(DataChannelSelectionTransform(data_channels))
        seg_channels = cfg.get("selected_seg_channels")
        if seg_channels:
            steps.append(SegChannelSelectionTransform(seg_channels))

        # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
        dummy_2d = cfg.get("dummy_2D", False)
        ignore_axes = (0, ) if dummy_2d else None
        if dummy_2d:
            steps.append(Convert3DTo2DTransform())

        spatial_kwargs = dict(
            patch_center_dist_from_border=None,
            do_elastic_deform=cfg.get("do_elastic"),
            alpha=cfg.get("elastic_deform_alpha"),
            sigma=cfg.get("elastic_deform_sigma"),
            do_rotation=cfg.get("do_rotation"),
            angle_x=cfg.get("rotation_x"),
            angle_y=cfg.get("rotation_y"),
            angle_z=cfg.get("rotation_z"),
            do_scale=cfg.get("do_scaling"),
            scale=cfg.get("scale_range"),
            order_data=cfg.get("order_data"),
            border_mode_data=cfg.get("border_mode_data"),
            border_cval_data=cfg.get("border_cval_data"),
            order_seg=cfg.get("order_seg"),
            border_mode_seg=cfg.get("border_mode_seg"),
            border_cval_seg=cfg.get("border_cval_seg"),
            random_crop=cfg.get("random_crop"),
            p_el_per_sample=cfg.get("p_eldef"),
            p_scale_per_sample=cfg.get("p_scale"),
            p_rot_per_sample=cfg.get("p_rot"),
            independent_scale_for_each_axis=cfg.get(
                "independent_scale_factor_for_each_axis"),
        )
        steps.append(
            SpatialTransform(self._spatial_transform_patch_size,
                             **spatial_kwargs))

        if cfg.get("dummy_2D"):
            steps.append(Convert2DTo3DTransform())

        # the color augmentations have to come after the dummy 2d part (if applicable),
        # otherwise the overloaded color channel gets in the way
        steps.append(GaussianNoiseTransform(p_per_sample=0.15))
        steps.append(
            GaussianBlurTransform((0.5, 1.5),
                                  different_sigma_per_channel=True,
                                  p_per_sample=0.2,
                                  p_per_channel=0.5))
        steps.append(
            BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.3),
                                              p_per_sample=0.15))
        if cfg.get("do_additive_brightness"):
            steps.append(
                BrightnessTransform(
                    cfg.get("additive_brightness_mu"),
                    cfg.get("additive_brightness_sigma"),
                    True,
                    p_per_sample=cfg.get("additive_brightness_p_per_sample"),
                    p_per_channel=cfg.get("additive_brightness_p_per_channel")))
        steps.append(
            ContrastAugmentationTransform(contrast_range=(0.65, 1.5),
                                          p_per_sample=0.15))
        steps.append(
            SimulateLowResolutionTransform(zoom_range=(0.5, 1),
                                           per_channel=True,
                                           p_per_channel=0.5,
                                           order_downsample=0,
                                           order_upsample=3,
                                           p_per_sample=0.25,
                                           ignore_axes=ignore_axes))
        # inverted gamma, always part of the pipeline
        steps.append(
            GammaTransform(cfg.get("gamma_range"),
                           True,
                           True,
                           retain_stats=cfg.get("gamma_retain_stats"),
                           p_per_sample=0.15))

        # regular gamma, only when requested
        if cfg.get("do_gamma"):
            steps.append(
                GammaTransform(cfg.get("gamma_range"),
                               False,
                               True,
                               retain_stats=cfg.get("gamma_retain_stats"),
                               p_per_sample=cfg["p_gamma"]))

        if cfg.get("do_mirror") or cfg.get("mirror"):
            steps.append(MirrorTransform(cfg.get("mirror_axes")))

        use_mask_for_norm = cfg.get("use_mask_for_norm")
        if use_mask_for_norm:
            steps.append(
                MaskTransform(use_mask_for_norm,
                              mask_idx_in_seg=0,
                              set_outside_to=0))

        steps += [
            RemoveLabelTransform(-1, 0),
            RenameTransform('seg', 'target', True),
            NumpyToTensor(['data', 'target'], 'float'),
        ]
        return Compose(steps)
예제 #6
0
    def get_training_transforms(self):
        """
        Assemble the training pipeline from self.params.

        Order: optional channel selection, spatial transform (wrapped in the
        3D->2D / 2D->3D conversion when "dummy_2D" is set), optional gamma,
        mirroring and masking, then label cleanup, the 'seg' -> 'target'
        rename and the conversion to float tensors.
        :return: the composed transform
        """
        cfg = self.params
        assert cfg.get('mirror') is None, \
            "old version of params, use new keyword do_mirror"

        steps = []

        data_channels = cfg.get("selected_data_channels")
        if data_channels:
            steps.append(DataChannelSelectionTransform(data_channels))

        seg_channels = cfg.get("selected_seg_channels")
        if seg_channels:
            steps.append(SegChannelSelectionTransform(seg_channels))

        # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
        if cfg.get("dummy_2D", False):
            steps.append(Convert3DTo2DTransform())

        spatial_kwargs = dict(
            patch_center_dist_from_border=None,
            do_elastic_deform=cfg.get("do_elastic"),
            alpha=cfg.get("elastic_deform_alpha"),
            sigma=cfg.get("elastic_deform_sigma"),
            do_rotation=cfg.get("do_rotation"),
            angle_x=cfg.get("rotation_x"),
            angle_y=cfg.get("rotation_y"),
            angle_z=cfg.get("rotation_z"),
            do_scale=cfg.get("do_scaling"),
            scale=cfg.get("scale_range"),
            order_data=cfg.get("order_data"),
            border_mode_data=cfg.get("border_mode_data"),
            border_cval_data=cfg.get("border_cval_data"),
            order_seg=cfg.get("order_seg"),
            border_mode_seg=cfg.get("border_mode_seg"),
            border_cval_seg=cfg.get("border_cval_seg"),
            random_crop=cfg.get("random_crop"),
            p_el_per_sample=cfg.get("p_eldef"),
            p_scale_per_sample=cfg.get("p_scale"),
            p_rot_per_sample=cfg.get("p_rot"),
            independent_scale_for_each_axis=cfg.get(
                "independent_scale_factor_for_each_axis"),
        )
        steps.append(
            SpatialTransform(self._spatial_transform_patch_size,
                             **spatial_kwargs))

        if cfg.get("dummy_2D", False):
            steps.append(Convert2DTo3DTransform())

        if cfg.get("do_gamma", False):
            steps.append(
                GammaTransform(cfg.get("gamma_range"),
                               False,
                               True,
                               retain_stats=cfg.get("gamma_retain_stats"),
                               p_per_sample=cfg["p_gamma"]))

        if cfg.get("do_mirror", False):
            steps.append(MirrorTransform(cfg.get("mirror_axes")))

        use_mask_for_norm = cfg.get("use_mask_for_norm")
        if use_mask_for_norm:
            steps.append(
                MaskTransform(use_mask_for_norm,
                              mask_idx_in_seg=0,
                              set_outside_to=0))

        steps += [
            RemoveLabelTransform(-1, 0),
            RenameTransform('seg', 'target', True),
            NumpyToTensor(['data', 'target'], 'float'),
        ]
        return Compose(steps)
예제 #7
0
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError((error_msg.format(type(batch[0]))))


if __name__ == '__main__':
    # Script entry point: sets up the configuration, the CIFAR-10 arrays and the transform chain.
    ### current implementation of batchgenerators stuff for this script does not use _use_shared_memory!

    from time import time
    # run configuration (echoed to stdout below)
    batch_size = 50
    num_workers = 8
    pin_memory = False
    num_epochs = 3
    # NOTE(review): hard-coded, machine-specific dataset location
    dataset_dir = '/media/fabian/data/data/cifar10'
    # cast_to=None: presumably keeps the dtype of the source arrays -- confirm against NumpyToTensor
    numpy_to_tensor = NumpyToTensor(['data', 'labels'], cast_to=None)
    fname = os.path.join(dataset_dir, 'cifar10_training_data.npz')
    dataset = np.load(fname)
    # pull the three arrays out of the npz archive
    cifar_dataset_as_arrays = (dataset['data'], dataset['labels'],
                               dataset['filenames'])
    print('batch_size', batch_size)
    print('num_workers', num_workers)
    print('pin_memory', pin_memory)
    print('num_epochs', num_epochs)

    tr_transforms = [
        SpatialTransform((32, 32))
    ] * 1  # SpatialTransform is computationally expensive and we need some
    # load on CPU, so several of them can be stacked on top of each other.
    # NOTE(review): the original comment said 5 are stacked, but the multiplier is 1 -- confirm which is intended
    tr_transforms.append(numpy_to_tensor)
    tr_transforms = Compose(tr_transforms)
예제 #8
0
def Transforms(patch_size,
               params=default_3D_augmentation_params,
               border_val_seg=-1):
    """
    Build the composed training transform pipeline.

    Order: optional channel selection, spatial transform (wrapped in the 3D->2D / 2D->3D
    conversion when "dummy_2D" is set), optional gamma, mirroring, optional masking, label
    cleanup, optional one-hot move of the segmentation into 'data', the 'seg' -> 'target'
    rename and the conversion to float tensors.

    :param patch_size: output patch size handed to SpatialTransform
    :param params: mapping with the augmentation settings; "p_gamma" must be present when
        "do_gamma" is truthy (it is read with [] and raises KeyError otherwise)
    :param border_val_seg: constant border value used for the segmentation in SpatialTransform
    :return: the composed transform
    """
    tr_transforms = []
    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(params.get("selected_data_channels"),
                                          data_key="data"))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    # (a plain truthiness test already covers the "missing / None" case)
    dummy_2d = params.get("dummy_2D")
    if dummy_2d:
        tr_transforms.append(Convert3DTo2DTransform())
    tr_transforms.append(
        SpatialTransform(patch_size,
                         patch_center_dist_from_border=None,
                         do_elastic_deform=params.get("do_elastic"),
                         alpha=params.get("elastic_deform_alpha"),
                         sigma=params.get("elastic_deform_sigma"),
                         do_rotation=params.get("do_rotation"),
                         angle_x=params.get("rotation_x"),
                         angle_y=params.get("rotation_y"),
                         angle_z=params.get("rotation_z"),
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0,
                         order_data=3,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg,
                         order_seg=1,
                         random_crop=params.get("random_crop"),
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot")))
    if dummy_2d:
        tr_transforms.append(Convert2DTo3DTransform())

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"),
                           False,
                           True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    mask_was_used_for_normalization = params.get(
        "mask_was_used_for_normalization")
    if mask_was_used_for_normalization is not None:
        tr_transforms.append(
            MaskTransform(mask_was_used_for_normalization,
                          mask_idx_in_seg=0,
                          set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data"):
        tr_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))
        # The original condition read "x and not None and x"; "not None" is always True, so
        # the effective test has always been plain truthiness of the setting.
        if params.get("advanced_pyramid_augmentations"):
            num_labels = len(params.get("all_segmentation_labels"))
            # negative indices address the trailing channels of 'data'
            tr_transforms.append(
                ApplyRandomBinaryOperatorTransform(
                    channel_idx=list(range(-num_labels, 0)),
                    p_per_sample=0.4,
                    key="data",
                    strel_size=(1, 8)))
            tr_transforms.append(
                RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                    channel_idx=list(range(-num_labels, 0)),
                    key="data",
                    p_per_sample=0.2,
                    fill_with_other_class_p=0.0,
                    dont_do_if_covers_more_than_X_percent=0.15))

    tr_transforms.append(RenameTransform('seg', 'target', True))
    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)
    return tr_transforms
예제 #9
0
def train(optimizer_options, data_options, logger_options, model_options, scheduler_options):
    """
    Run k-fold training with per-epoch validation, Visdom plotting and checkpointing.

    All arguments are mappings. Keys read directly by this function:

    :param optimizer_options: 'device', 'epochs', 'optimizer', 'learning_rate', 'run_nfolds',
        'use_half_precision', 'switch_optimizer', 'validation_interval_epochs',
        'max_valid_iterations' (also forwarded to get_optimizer and runEpoch)
    :param data_options: 'base_path', 'n_folds', 'batch_size', 'n_threads'
    :param logger_options: 'vislogger_env', 'vislogger_port', 'save_model', 'suffix'
        (also forwarded to runEpoch)
    :param model_options: 'pretrained' (checkpoint path, or None to skip loading)
    :param scheduler_options: forwarded unchanged to get_optimizer
    :return: None
    """
    vis = Visdom(env=logger_options['vislogger_env'], port=logger_options['vislogger_port'])
    device = torch.device(optimizer_options['device'])
    epochs = optimizer_options['epochs']

    ## ======================================= Save model ======================================= ##
    if logger_options['save_model'] == "":
        model_checkpoint = ModelCheckpoint()
    else:
        suffix = optimizer_options['optimizer']+"_"+str(optimizer_options['learning_rate'])+"_"+logger_options['suffix']
        model_checkpoint = ModelCheckpoint(save_model=True, save_path=logger_options['save_model'],
                                           use_loss=True, suffix=suffix)
    ## ======================================= Save model ======================================= ##

    ## ======================================= Data ======================================= ##
    image_transform = Compose([
        RangeNormalize(0., 1.0),  # Faster than the one in BatchGenerators
        ChannelFirst(),
        MeanStdNormalizationTransform(mean=[0.3610,0.2131,0.2324],
                                      std=[0.0624,0.0463,0.0668]),
        NumpyToTensor(keys=['data', 'target'])
    ])

    kfoldWorkflowSet = kFoldWorkflowSplitMT(data_options['base_path'],
                                            image_transform=image_transform,
                                            video_extn='.avi', shuffle=True,
                                            n_folds=data_options['n_folds'], num_phases=14,
                                            batch_size=data_options['batch_size'],
                                            num_workers=data_options['n_threads'],
                                            video_folder='videos_480x272')
    ## ======================================= Data ======================================= ##

    # NOTE(review): these three accumulators are created but never updated in this function
    nfolds_training_loss_avg = CumulativeMovingAvgStd()
    nfolds_validation_loss_avg = CumulativeMovingAvgStd()
    nfolds_validation_score_avg = CumulativeMovingAvgStd()

    folds_pbar = ProgressBar(kfoldWorkflowSet, desc="Folds", pb_len=optimizer_options['run_nfolds'])
    max_folds = folds_pbar.total

    for iFold, (train_loader, val_loader) in enumerate(folds_pbar):
        fold_label = str(iFold+1)  # 1-based fold number used in plot titles, tags and trace names

        ## ======================================= Create Plot ======================================= ##
        create_plot_window(vis, "Epochs+Iterations", "CE Loss", "Training loss Fold "+fold_label,
                           tag='Training_Loss_Fold_'+fold_label, name='Training Loss Fold '+fold_label)
        create_plot_window(vis, "Epochs+Iterations", "CE Loss", "Validation loss Fold "+fold_label,
                           tag='Validation_Loss_Fold_'+fold_label, name='Validation Loss Fold '+fold_label)
        # The trace name has to match the one used by the vis.line() append further down;
        # it used to be a copy of the loss name ('Validation Loss Fold N').
        create_plot_window(vis, "Epochs+Iterations", "Score", "Validation Score Fold "+fold_label,
                           tag='Validation_Score_Fold_'+fold_label, name='Validation Score Fold '+fold_label)
        ## ======================================= Create Plot ======================================= ##

        ## ======================================= Model ======================================= ##
        # TODO: Pass 'models.resnet50' as string
        model = ResFeatureExtractor(pretrained_model=models.resnet101,
                                    device=device)

        if model_options['pretrained'] is not None:
            checkpoint = torch.load(model_options['pretrained'])
            model.load_state_dict(checkpoint['model'])
        ## ======================================= Model ======================================= ##

        ### ============================== Parts of Training step ============================== ###
        criterion_CE = nn.CrossEntropyLoss().to(device)
        ### ============================== Parts of Training step ============================== ###

        epoch_pbar = ProgressBar(range(epochs), desc="Epochs")
        epoch_training_avg_loss = CumulativeMovingAvgStd()
        epoch_training_avg_score = CumulativeMovingAvgStd()
        epoch_validation_loss = BestScore()
        epoch_msg_dict = {}

        evaluator = Engine(model, None, criterion_CE, None, val_loader, 0, device, False,
                           use_half_precision=optimizer_options["use_half_precision"],
                           score_type="f1")

        for epoch in epoch_pbar:
            # A fresh optimizer/scheduler pair is requested at the start of every epoch.
            if (optimizer_options['switch_optimizer'] > 0) and ((epoch+1) % optimizer_options['switch_optimizer'] == 0):
                # NOTE(review): this is an alias, not a copy -- the two assignments below modify the
                # caller's optimizer_options, so every later epoch and fold also runs with
                # 'sgd' / 1e-3. Confirm whether a permanent switch is intended before copying.
                temp_optimizer_options = optimizer_options
                temp_optimizer_options['optimizer'] = 'sgd'
                temp_optimizer_options['learning_rate'] = 1e-3
                optimizer, scheduler = get_optimizer(model.parameters(), temp_optimizer_options, scheduler_options, train_loader, vis)
            else:
                optimizer, scheduler = get_optimizer(model.parameters(), optimizer_options, scheduler_options, train_loader, vis)

            runEpoch(train_loader, model, criterion_CE, optimizer, scheduler, device,
                     vis, epoch, iFold, folds_pbar, epoch_training_avg_loss,
                     epoch_training_avg_score, logger_options, optimizer_options,
                     epoch_msg_dict)

            ### ============================== Validation ============================== ###
            # Both stay None on epochs where validation is skipped.
            validation_loss, validation_score = None, None
            validation_interval = optimizer_options["validation_interval_epochs"]
            if (validation_interval > 0) and ((epoch+1) % validation_interval == 0):
                validation_loss, validation_score = predict(evaluator,
                                                            optimizer_options['max_valid_iterations'],
                                                            device, vis)
                epoch_validation_loss.step(validation_loss, [validation_score])
                vis.line(X=np.array([epoch]),
                         Y=np.array([validation_loss]),
                         update='append', win='Validation_Loss_Fold_'+fold_label,
                         name='Validation Loss Fold '+fold_label)
                vis.line(X=np.array([epoch]),
                         Y=np.array([validation_score]),
                         update='append', win='Validation_Score_Fold_'+fold_label,
                         name='Validation Score Fold '+fold_label)
                # CVL/CVS: current validation loss/score, BVL/BVS: best so far
                epoch_msg_dict['CVL'] = validation_loss
                epoch_msg_dict['CVS'] = validation_score
                epoch_msg_dict['BVL'] = epoch_validation_loss.score()[0]
                epoch_msg_dict['BVS'] = epoch_validation_loss.score()[1][0]
                folds_pbar.update_message(msg_dict=epoch_msg_dict)
            ### ============================== Validation ============================== ###

            ### ============================== Save model ============================== ###
            # NOTE(review): curr_loss is None on epochs without validation -- confirm that
            # ModelCheckpoint.step handles that. The suffix uses the 0-based fold index.
            model_checkpoint.step(curr_loss=validation_loss,
                                  model=model, suffix='_Fold_'+str(iFold))
            vis.save([logger_options['vislogger_env']])
            ### ============================== Save model ============================== ###

        # stop once the requested number of folds has been processed
        if (iFold+1) == max_folds:
            folds_pbar.refresh()
            folds_pbar.close()
            break

    print("\n\n\n\n=================================== DONE ===================================\n\n")