def test_image_pipeline_and_pin_memory(self):
    """
    Smoke test: an image pipeline ending in NumpyToTensor, run through a MultiThreadedAugmenter with
    pin_memory=True, must not crash and must yield pinned torch tensors.

    Returns early (counted as a pass) when torch is not installed.
    """
    try:
        import torch
    except ImportError:
        # torch is optional; without it there is nothing to test here
        return

    from batchgenerators.transforms import MirrorTransform, NumpyToTensor, TransposeAxesTransform, Compose

    pipeline = Compose([
        MirrorTransform(),
        TransposeAxesTransform(transpose_any_of_these=(0, 1), p_per_sample=0.5),
        NumpyToTensor(keys='data', cast_to='float'),
    ])

    # positional args: data loader, transform, 4 worker processes, 1 cached batch per worker,
    # seeds=None, pin_memory=True
    augmenter = MultiThreadedAugmenter(self.dl_images, pipeline, 4, 1, None, True)

    for _ in range(50):
        batch = augmenter.next()

    data = batch['data']
    assert isinstance(data, torch.Tensor)
    assert data.is_pinned()

    # Give the workers time to finish caching; otherwise an error is printed on shutdown. That error
    # does not affect the outcome of the test, it only clutters the output.
    sleep(2)
def get_no_augmentation(dataloader_train, dataloader_val, patch_size, params=default_3D_augmentation_params,
                        border_val_seg=-1):
    """
    Drop-in replacement for get_default_augmentation that turns off all data augmentation.

    Both pipelines only apply optional channel selection, remap label -1 to 0, rename 'seg' to 'target'
    and convert 'data'/'target' to float tensors. patch_size and border_val_seg are accepted only so the
    signature matches get_default_augmentation; they are not used.

    :param dataloader_train: data loader feeding the training augmenter
    :param dataloader_val: data loader feeding the validation augmenter
    :param patch_size: unused
    :param params: dict-like; reads 'selected_data_channels', 'selected_seg_channels', 'num_threads'
        and 'num_cached_per_thread'
    :param border_val_seg: unused
    :return: (batchgenerator_train, batchgenerator_val); restart() has been called on both
    """
    def channel_selection():
        # fresh transform instances on every call so train and val do not share objects
        selected = []
        data_channels = params.get("selected_data_channels")
        if data_channels is not None:
            selected.append(DataChannelSelectionTransform(data_channels))
        seg_channels = params.get("selected_seg_channels")
        if seg_channels is not None:
            selected.append(SegChannelSelectionTransform(seg_channels))
        return selected

    train_pipeline = Compose(channel_selection() + [
        RemoveLabelTransform(-1, 0),
        RenameTransform('seg', 'target', True),
        NumpyToTensor(['data', 'target'], 'float'),
    ])
    num_threads = params.get('num_threads')
    batchgenerator_train = MultiThreadedAugmenter(dataloader_train, train_pipeline, num_threads,
                                                  params.get("num_cached_per_thread"),
                                                  seeds=range(num_threads), pin_memory=True)
    batchgenerator_train.restart()

    # validation removes the -1 label before (rather than after) channel selection
    val_pipeline = Compose([RemoveLabelTransform(-1, 0)] + channel_selection() + [
        RenameTransform('seg', 'target', True),
        NumpyToTensor(['data', 'target'], 'float'),
    ])
    num_val_threads = max(params.get('num_threads') // 2, 1)
    batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_pipeline, num_val_threads,
                                                params.get("num_cached_per_thread"),
                                                seeds=range(num_val_threads), pin_memory=True)
    batchgenerator_val.restart()
    return batchgenerator_train, batchgenerator_val
def get_validation_transforms(self):
    """
    Build the non-augmenting validation pipeline from self.params.

    Channel selection is added only when the corresponding param is truthy. The pipeline then
    center-crops to self.patch_size, remaps label -1 to 0, renames 'seg' to 'target' and converts
    'data'/'target' to float tensors.

    :return: Compose of the validation transforms
    """
    steps = []
    data_channels = self.params.get("selected_data_channels")
    if data_channels:
        steps.append(DataChannelSelectionTransform(data_channels))
    seg_channels = self.params.get("selected_seg_channels")
    if seg_channels:
        steps.append(SegChannelSelectionTransform(seg_channels))
    steps += [
        CenterCropTransform(self.patch_size),
        RemoveLabelTransform(-1, 0),
        RenameTransform('seg', 'target', True),
        NumpyToTensor(['data', 'target'], 'float'),
    ]
    return Compose(steps)
def get_default_augmentation_withEDT(dataloader_train, dataloader_val, patch_size, idx_of_edts,
                                     params=default_3D_augmentation_params, border_val_seg=-1,
                                     pin_memory=True, seeds_train=None, seeds_val=None):
    """
    Build training and validation MultiThreadedAugmenters for inputs that carry EDT channels.

    The channels at ``idx_of_edts`` are moved out of 'data' into a separate 'bound' key via
    AppendChannelsTransform(remove_from_input=True), so intensity transforms applied afterwards
    (e.g. gamma) do not touch them. In the training pipeline the move happens after SpatialTransform,
    so 'bound' is spatially transformed together with the image.

    :param dataloader_train: data loader for the training augmenter
    :param dataloader_val: data loader for the validation augmenter
    :param patch_size: patch size passed to SpatialTransform
    :param idx_of_edts: indices of the EDT channels in 'data', counted after the optional channel selection
    :param params: dict-like augmentation parameters
    :param border_val_seg: constant border value for the segmentation in SpatialTransform
    :param pin_memory: forwarded to both MultiThreadedAugmenters
    :param seeds_train: worker seeds for the training augmenter
    :param seeds_val: worker seeds for the validation augmenter
    :return: (batchgenerator_train, batchgenerator_val); restart() is not called on either here
    """
    tr_transforms = []
    if params.get("selected_data_channels") is not None:
        tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    dummy_2d = params.get("dummy_2D")
    if dummy_2d:
        tr_transforms.append(Convert3DTo2DTransform())

    tr_transforms.append(
        SpatialTransform(patch_size, patch_center_dist_from_border=None,
                         do_elastic_deform=params.get("do_elastic"),
                         alpha=params.get("elastic_deform_alpha"),
                         sigma=params.get("elastic_deform_sigma"),
                         do_rotation=params.get("do_rotation"),
                         angle_x=params.get("rotation_x"),
                         angle_y=params.get("rotation_y"),
                         angle_z=params.get("rotation_z"),
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0, order_data=3,
                         border_mode_seg="constant", border_cval_seg=border_val_seg, order_seg=1,
                         random_crop=params.get("random_crop"),
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot")))

    if dummy_2d:
        tr_transforms.append(Convert2DTo3DTransform())

    # Move the EDT channels to a different key so that they are not intensity transformed.
    tr_transforms.append(AppendChannelsTransform("data", "bound", idx_of_edts, remove_from_input=True))

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"), False, True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    # `is not None` (not truthiness) is deliberate: it matches the original check exactly
    if params.get("mask_was_used_for_normalization") is not None:
        tr_transforms.append(
            MaskTransform(params.get("mask_was_used_for_normalization"), mask_idx_in_seg=0, set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data"):
        tr_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))
        # The original condition read `x and not None and x`; `not None` is always True, so this is
        # simply a truthiness check. These augmentations act on the trailing one-hot channels that
        # MoveSegAsOneHotToData just appended to 'data', hence the nesting.
        if params.get("advanced_pyramid_augmentations"):
            onehot_idx = range(-len(params.get("all_segmentation_labels")), 0)
            tr_transforms.append(
                ApplyRandomBinaryOperatorTransform(channel_idx=list(onehot_idx), p_per_sample=0.4,
                                                   key="data", strel_size=(1, 8)))
            tr_transforms.append(
                RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                    channel_idx=list(onehot_idx), key="data", p_per_sample=0.2,
                    fill_with_other_class_p=0.0, dont_do_if_covers_more_than_X_percent=0.15))

    tr_transforms.append(RenameTransform('seg', 'target', True))
    tr_transforms.append(NumpyToTensor(['data', 'target', 'bound'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train, tr_transforms, params.get('num_threads'),
        params.get("num_cached_per_thread"), seeds=seeds_train, pin_memory=pin_memory)

    val_transforms = [RemoveLabelTransform(-1, 0)]
    if params.get("selected_data_channels") is not None:
        val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # Move the EDT channels to a different key, mirroring the training pipeline.
    val_transforms.append(AppendChannelsTransform("data", "bound", idx_of_edts, remove_from_input=True))

    if params.get("move_last_seg_chanel_to_data"):
        val_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))
    val_transforms.append(NumpyToTensor(['data', 'target', 'bound'], 'float'))
    val_transforms = Compose(val_transforms)

    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val, val_transforms, max(params.get('num_threads') // 2, 1),
        params.get("num_cached_per_thread"), seeds=seeds_val, pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
def get_training_transforms(self):
    """
    Build the training augmentation pipeline from self.params.

    The pipeline runs in this order:
    - optional channel selection;
    - SpatialTransform, wrapped in a 3D->2D->3D round trip when 'dummy_2D' is set;
    - intensity augmentations: noise, blur, multiplicative and optional additive brightness, contrast,
      simulated low resolution, inverted gamma and optional regular gamma;
    - optional mirroring and masking;
    - -1 -> 0 label remapping, renaming 'seg' to 'target', and conversion to float tensors.

    :return: Compose of the training transforms
    :raises AssertionError: if params still uses the old 'mirror' key instead of 'do_mirror'
    """
    p = self.params
    assert p.get('mirror') is None, "old version of params, use new keyword do_mirror"

    steps = []
    if p.get("selected_data_channels"):
        steps.append(DataChannelSelectionTransform(p.get("selected_data_channels")))
    if p.get("selected_seg_channels"):
        steps.append(SegChannelSelectionTransform(p.get("selected_seg_channels")))

    dummy_2d = p.get("dummy_2D", False)
    # in dummy-2D mode SimulateLowResolutionTransform is told to ignore axis 0
    ignore_axes = (0, ) if dummy_2d else None
    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if dummy_2d:
        steps.append(Convert3DTo2DTransform())

    # (SpatialTransform keyword, params key) pairs
    spatial_keys = (
        ("do_elastic_deform", "do_elastic"),
        ("alpha", "elastic_deform_alpha"),
        ("sigma", "elastic_deform_sigma"),
        ("do_rotation", "do_rotation"),
        ("angle_x", "rotation_x"),
        ("angle_y", "rotation_y"),
        ("angle_z", "rotation_z"),
        ("do_scale", "do_scaling"),
        ("scale", "scale_range"),
        ("order_data", "order_data"),
        ("border_mode_data", "border_mode_data"),
        ("border_cval_data", "border_cval_data"),
        ("order_seg", "order_seg"),
        ("border_mode_seg", "border_mode_seg"),
        ("border_cval_seg", "border_cval_seg"),
        ("random_crop", "random_crop"),
        ("p_el_per_sample", "p_eldef"),
        ("p_scale_per_sample", "p_scale"),
        ("p_rot_per_sample", "p_rot"),
        ("independent_scale_for_each_axis", "independent_scale_factor_for_each_axis"),
    )
    spatial_kwargs = {kwarg: p.get(key) for kwarg, key in spatial_keys}
    steps.append(SpatialTransform(self._spatial_transform_patch_size,
                                  patch_center_dist_from_border=None, **spatial_kwargs))

    if dummy_2d:
        steps.append(Convert2DTo3DTransform())

    # Color augmentations must come after the dummy 2d part (if applicable); otherwise the
    # overloaded color channel gets in the way.
    steps.append(GaussianNoiseTransform(p_per_sample=0.15))
    steps.append(GaussianBlurTransform((0.5, 1.5), different_sigma_per_channel=True,
                                       p_per_sample=0.2, p_per_channel=0.5))
    steps.append(BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.3), p_per_sample=0.15))
    if p.get("do_additive_brightness"):
        steps.append(BrightnessTransform(p.get("additive_brightness_mu"),
                                         p.get("additive_brightness_sigma"),
                                         True,
                                         p_per_sample=p.get("additive_brightness_p_per_sample"),
                                         p_per_channel=p.get("additive_brightness_p_per_channel")))
    steps.append(ContrastAugmentationTransform(contrast_range=(0.65, 1.5), p_per_sample=0.15))
    steps.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True, p_per_channel=0.5,
                                                order_downsample=0, order_upsample=3, p_per_sample=0.25,
                                                ignore_axes=ignore_axes))
    # inverted gamma
    steps.append(GammaTransform(p.get("gamma_range"), True, True,
                                retain_stats=p.get("gamma_retain_stats"), p_per_sample=0.15))
    if p.get("do_gamma"):
        steps.append(GammaTransform(p.get("gamma_range"), False, True,
                                    retain_stats=p.get("gamma_retain_stats"), p_per_sample=p["p_gamma"]))

    # the 'mirror' fallback only matters when asserts are stripped (python -O)
    if p.get("do_mirror") or p.get("mirror"):
        steps.append(MirrorTransform(p.get("mirror_axes")))

    use_mask_for_norm = p.get("use_mask_for_norm")
    if use_mask_for_norm:
        steps.append(MaskTransform(use_mask_for_norm, mask_idx_in_seg=0, set_outside_to=0))

    steps.extend([
        RemoveLabelTransform(-1, 0),
        RenameTransform('seg', 'target', True),
        NumpyToTensor(['data', 'target'], 'float'),
    ])
    return Compose(steps)
def get_training_transforms(self):
    """
    Build the (minimal) training augmentation pipeline from self.params.

    The pipeline runs in this order:
    - optional channel selection;
    - SpatialTransform, wrapped in a 3D->2D->3D round trip when 'dummy_2D' is set;
    - optional gamma, mirroring and masking;
    - -1 -> 0 label remapping, renaming 'seg' to 'target', and conversion to float tensors.

    :return: Compose of the training transforms
    :raises AssertionError: if params still uses the old 'mirror' key instead of 'do_mirror'
    """
    cfg = self.params
    assert cfg.get('mirror') is None, "old version of params, use new keyword do_mirror"

    pipeline = []
    if cfg.get("selected_data_channels"):
        pipeline.append(DataChannelSelectionTransform(cfg.get("selected_data_channels")))
    if cfg.get("selected_seg_channels"):
        pipeline.append(SegChannelSelectionTransform(cfg.get("selected_seg_channels")))

    use_dummy_2d = cfg.get("dummy_2D", False)
    if use_dummy_2d:
        # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
        pipeline.append(Convert3DTo2DTransform())

    spatial = SpatialTransform(
        self._spatial_transform_patch_size,
        patch_center_dist_from_border=None,
        do_elastic_deform=cfg.get("do_elastic"),
        alpha=cfg.get("elastic_deform_alpha"),
        sigma=cfg.get("elastic_deform_sigma"),
        do_rotation=cfg.get("do_rotation"),
        angle_x=cfg.get("rotation_x"),
        angle_y=cfg.get("rotation_y"),
        angle_z=cfg.get("rotation_z"),
        do_scale=cfg.get("do_scaling"),
        scale=cfg.get("scale_range"),
        order_data=cfg.get("order_data"),
        border_mode_data=cfg.get("border_mode_data"),
        border_cval_data=cfg.get("border_cval_data"),
        order_seg=cfg.get("order_seg"),
        border_mode_seg=cfg.get("border_mode_seg"),
        border_cval_seg=cfg.get("border_cval_seg"),
        random_crop=cfg.get("random_crop"),
        p_el_per_sample=cfg.get("p_eldef"),
        p_scale_per_sample=cfg.get("p_scale"),
        p_rot_per_sample=cfg.get("p_rot"),
        independent_scale_for_each_axis=cfg.get("independent_scale_factor_for_each_axis"),
    )
    pipeline.append(spatial)

    if use_dummy_2d:
        pipeline.append(Convert2DTo3DTransform())

    if cfg.get("do_gamma", False):
        pipeline.append(GammaTransform(cfg.get("gamma_range"), False, True,
                                       retain_stats=cfg.get("gamma_retain_stats"),
                                       p_per_sample=cfg["p_gamma"]))

    if cfg.get("do_mirror", False):
        pipeline.append(MirrorTransform(cfg.get("mirror_axes")))

    use_mask_for_norm = cfg.get("use_mask_for_norm")
    if use_mask_for_norm:
        pipeline.append(MaskTransform(use_mask_for_norm, mask_idx_in_seg=0, set_outside_to=0))

    pipeline += [
        RemoveLabelTransform(-1, 0),
        RenameTransform('seg', 'target', True),
        NumpyToTensor(['data', 'target'], 'float'),
    ]
    return Compose(pipeline)
transposed = zip(*batch) return [default_collate(samples) for samples in transposed] raise TypeError((error_msg.format(type(batch[0])))) if __name__ == '__main__': ### current implementation of betchgenerators stuff for this script does not use _use_shared_memory! from time import time batch_size = 50 num_workers = 8 pin_memory = False num_epochs = 3 dataset_dir = '/media/fabian/data/data/cifar10' numpy_to_tensor = NumpyToTensor(['data', 'labels'], cast_to=None) fname = os.path.join(dataset_dir, 'cifar10_training_data.npz') dataset = np.load(fname) cifar_dataset_as_arrays = (dataset['data'], dataset['labels'], dataset['filenames']) print('batch_size', batch_size) print('num_workers', num_workers) print('pin_memory', pin_memory) print('num_epochs', num_epochs) tr_transforms = [ SpatialTransform((32, 32)) ] * 1 # SpatialTransform is computationally expensive and we need some # load on CPU so we just stack 5 of them on top of each other tr_transforms.append(numpy_to_tensor) tr_transforms = Compose(tr_transforms)
def Transforms(patch_size, params=default_3D_augmentation_params, border_val_seg=-1):
    """
    Build the default training augmentation pipeline (transforms only, no data loader or threading).

    :param patch_size: patch size passed to SpatialTransform
    :param params: dict-like augmentation parameters
    :param border_val_seg: constant border value for the segmentation in SpatialTransform
    :return: Compose of the training transforms
    """
    tr_transforms = []
    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(params.get("selected_data_channels"), data_key="data"))
    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    dummy_2d = params.get("dummy_2D")
    if dummy_2d:
        tr_transforms.append(Convert3DTo2DTransform())

    tr_transforms.append(
        SpatialTransform(patch_size, patch_center_dist_from_border=None,
                         do_elastic_deform=params.get("do_elastic"),
                         alpha=params.get("elastic_deform_alpha"),
                         sigma=params.get("elastic_deform_sigma"),
                         do_rotation=params.get("do_rotation"),
                         angle_x=params.get("rotation_x"),
                         angle_y=params.get("rotation_y"),
                         angle_z=params.get("rotation_z"),
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0, order_data=3,
                         border_mode_seg="constant", border_cval_seg=border_val_seg, order_seg=1,
                         random_crop=params.get("random_crop"),
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot")))

    if dummy_2d:
        tr_transforms.append(Convert2DTo3DTransform())

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"), False, True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    # `is not None` (not truthiness) is deliberate: it matches the original check exactly
    if params.get("mask_was_used_for_normalization") is not None:
        tr_transforms.append(
            MaskTransform(params.get("mask_was_used_for_normalization"), mask_idx_in_seg=0, set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data"):
        tr_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))
        # The original condition read `x and not None and x`; `not None` is always True, so this is
        # simply a truthiness check. These augmentations act on the trailing one-hot channels that
        # MoveSegAsOneHotToData just appended to 'data', hence the nesting.
        if params.get("advanced_pyramid_augmentations"):
            onehot_idx = range(-len(params.get("all_segmentation_labels")), 0)
            tr_transforms.append(
                ApplyRandomBinaryOperatorTransform(channel_idx=list(onehot_idx), p_per_sample=0.4,
                                                   key="data", strel_size=(1, 8)))
            tr_transforms.append(
                RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                    channel_idx=list(onehot_idx), key="data", p_per_sample=0.2,
                    fill_with_other_class_p=0.0, dont_do_if_covers_more_than_X_percent=0.15))

    tr_transforms.append(RenameTransform('seg', 'target', True))
    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    return Compose(tr_transforms)
def train(optimizer_options, data_options, logger_options, model_options, scheduler_options):
    """
    Run k-fold training of a ResNet-101 feature extractor on workflow videos, logging to Visdom.

    For each fold produced by kFoldWorkflowSplitMT, the function:
    - builds a fresh model, optionally loading weights from model_options['pretrained'];
    - trains it for optimizer_options['epochs'] epochs via runEpoch;
    - validates every optimizer_options['validation_interval_epochs'] epochs (when > 0);
    - calls model_checkpoint.step every epoch.

    The fold loop stops after folds_pbar.total folds; ProgressBar receives
    optimizer_options['run_nfolds'] as pb_len (presumably that is the total — confirm in ProgressBar).

    :param optimizer_options: dict with 'device', 'epochs', 'optimizer', 'learning_rate', 'run_nfolds',
        'use_half_precision', 'switch_optimizer', 'validation_interval_epochs', 'max_valid_iterations',
        plus whatever get_optimizer/runEpoch read. It is NOT modified.
    :param data_options: dict with 'base_path', 'n_folds', 'batch_size', 'n_threads'
    :param logger_options: dict with 'vislogger_env', 'vislogger_port', 'save_model', 'suffix'
        (the last only when 'save_model' is non-empty), plus whatever runEpoch reads
    :param model_options: dict with 'pretrained' (checkpoint path or None)
    :param scheduler_options: forwarded to get_optimizer
    """
    vis = Visdom(env=logger_options['vislogger_env'], port=logger_options['vislogger_port'])
    device = torch.device(optimizer_options['device'])
    epochs = optimizer_options['epochs']

    # ---------------- Model checkpointing ----------------
    if logger_options['save_model'] == "":
        model_checkpoint = ModelCheckpoint()
    else:
        suffix = (optimizer_options['optimizer'] + "_" + str(optimizer_options['learning_rate'])
                  + "_" + logger_options['suffix'])
        # TODO: fill in the remaining ModelCheckpoint parameters
        model_checkpoint = ModelCheckpoint(save_model=True, save_path=logger_options['save_model'],
                                           use_loss=True, suffix=suffix)

    # ---------------- Data ----------------
    image_transform = Compose([
        RangeNormalize(0., 1.0),  # faster than the one in BatchGenerators
        ChannelFirst(),
        MeanStdNormalizationTransform(mean=[0.3610, 0.2131, 0.2324], std=[0.0624, 0.0463, 0.0668]),
        NumpyToTensor(keys=['data', 'target']),
    ])
    kfoldWorkflowSet = kFoldWorkflowSplitMT(data_options['base_path'],
                                            image_transform=image_transform,
                                            video_extn='.avi', shuffle=True,
                                            n_folds=data_options['n_folds'], num_phases=14,
                                            batch_size=data_options['batch_size'],
                                            num_workers=data_options['n_threads'],
                                            video_folder='videos_480x272')

    folds_pbar = ProgressBar(kfoldWorkflowSet, desc="Folds", pb_len=optimizer_options['run_nfolds'])
    max_folds = folds_pbar.total
    for iFold, (train_loader, val_loader) in enumerate(folds_pbar):
        fold_tag = str(iFold + 1)

        # ---------------- Plots ----------------
        create_plot_window(vis, "Epochs+Iterations", "CE Loss", "Training loss Fold " + fold_tag,
                           tag='Training_Loss_Fold_' + fold_tag, name='Training Loss Fold ' + fold_tag)
        create_plot_window(vis, "Epochs+Iterations", "CE Loss", "Validation loss Fold " + fold_tag,
                           tag='Validation_Loss_Fold_' + fold_tag, name='Validation Loss Fold ' + fold_tag)
        # trace name must match the one used by vis.line(... name='Validation Score Fold N') below,
        # otherwise the appended points land in a second, differently named trace
        create_plot_window(vis, "Epochs+Iterations", "Score", "Validation Score Fold " + fold_tag,
                           tag='Validation_Score_Fold_' + fold_tag, name='Validation Score Fold ' + fold_tag)

        # ---------------- Model ----------------
        # TODO: Pass 'models.resnet50' as string
        model = ResFeatureExtractor(pretrained_model=models.resnet101, device=device)
        if model_options['pretrained'] is not None:
            # map_location lets a checkpoint saved on another device (e.g. a GPU) load onto `device`
            checkpoint = torch.load(model_options['pretrained'], map_location=device)
            model.load_state_dict(checkpoint['model'])

        criterion_CE = nn.CrossEntropyLoss().to(device)

        epoch_pbar = ProgressBar(range(epochs), desc="Epochs")
        epoch_training_avg_loss = CumulativeMovingAvgStd()
        epoch_training_avg_score = CumulativeMovingAvgStd()
        epoch_validation_loss = BestScore()
        epoch_msg_dict = {}

        evaluator = Engine(model, None, criterion_CE, None, val_loader, 0, device, False,
                           use_half_precision=optimizer_options["use_half_precision"], score_type="f1")

        for epoch in epoch_pbar:
            switch_every = optimizer_options['switch_optimizer']
            if switch_every > 0 and (epoch + 1) % switch_every == 0:
                # Work on a copy: the previous `temp = optimizer_options` aliased the caller's dict, so
                # the SGD switch leaked permanently into all later epochs and folds.
                epoch_optimizer_options = dict(optimizer_options)
                epoch_optimizer_options['optimizer'] = 'sgd'
                epoch_optimizer_options['learning_rate'] = 1e-3
            else:
                epoch_optimizer_options = optimizer_options
            # NOTE(review): the optimizer and scheduler are re-created every epoch, which discards any
            # internal state they keep between epochs — confirm this is intended.
            optimizer, scheduler = get_optimizer(model.parameters(), epoch_optimizer_options,
                                                 scheduler_options, train_loader, vis)

            runEpoch(train_loader, model, criterion_CE, optimizer, scheduler, device, vis, epoch, iFold,
                     folds_pbar, epoch_training_avg_loss, epoch_training_avg_score, logger_options,
                     optimizer_options, epoch_msg_dict)

            # ---------------- Validation ----------------
            validation_loss, validation_score = None, None
            interval = optimizer_options["validation_interval_epochs"]
            if interval > 0 and (epoch + 1) % interval == 0:
                validation_loss, validation_score = predict(evaluator, optimizer_options['max_valid_iterations'],
                                                            device, vis)
                epoch_validation_loss.step(validation_loss, [validation_score])
                vis.line(X=np.array([epoch]), Y=np.array([validation_loss]), update='append',
                         win='Validation_Loss_Fold_' + fold_tag, name='Validation Loss Fold ' + fold_tag)
                vis.line(X=np.array([epoch]), Y=np.array([validation_score]), update='append',
                         win='Validation_Score_Fold_' + fold_tag, name='Validation Score Fold ' + fold_tag)
                best = epoch_validation_loss.score()
                epoch_msg_dict['CVL'] = validation_loss
                epoch_msg_dict['CVS'] = validation_score
                epoch_msg_dict['BVL'] = best[0]
                epoch_msg_dict['BVS'] = best[1][0]
                folds_pbar.update_message(msg_dict=epoch_msg_dict)

            # ---------------- Save model ----------------
            # NOTE(review): on epochs without validation curr_loss is None — confirm ModelCheckpoint
            # handles that.
            model_checkpoint.step(curr_loss=validation_loss, model=model, suffix='_Fold_' + str(iFold))
            vis.save([logger_options['vislogger_env']])

        if (iFold + 1) == max_folds:
            folds_pbar.refresh()
            folds_pbar.close()
            break

    print("\n\n\n\n=================================== DONE ===================================\n\n")