def train(args):
    """Train a region-detection model for one cross-validation fold.

    The fold is taken from the segmentation experiment's stored settings
    (``settings.yaml`` under ``args.src_path_data``), so one detection model
    is trained per segmentation fold. The function mutates ``args`` and the
    global ``config_detector.architecture`` in place, creates the output
    directory, persists settings, then runs the training loop with periodic
    checkpointing and visdom visualization.

    :param args: parsed command-line namespace; mutated in place
        (``fold`` and ``output_directory`` are filled in here).
    :return: None
    """
    # Fixed seeds for reproducibility of weight init and batch sampling.
    torch.manual_seed(5431232439)
    torch.cuda.manual_seed(5431232439)
    torch.backends.cudnn.enabled = True
    np.random.seed(6572345)
    # get fold from segmentation experiment settings. We train detection model
    # per cross-validation fold
    seg_exper_settings = loadExperimentSettings(
        os.path.join(args.src_path_data, 'settings.yaml'))
    args.fold = seg_exper_settings.fold
    print("WARNING - processing fold {}".format(args.fold))
    # Side effect: configures config_detector.architecture for this network.
    config_detector.get_architecture(args.network)
    # set number of input channels for initialization of model
    if args.input_channels != "allchannels":
        if args.input_channels == "mronly":
            # MR image only
            config_detector.architecture['n_channels_input'] = 1
        else:
            # presumably image + one extra map (uncertainty or seg mask) —
            # TODO confirm against BatchHandler's channel selection
            config_detector.architecture['n_channels_input'] = 2
        print("WARNING - Using {} channels as input".format(
            config_detector.architecture['n_channels_input']))
    if args.fn_penalty_weight is not None:
        # Command-line override of the false-negative penalty in the loss.
        config_detector.architecture[
            "fn_penalty_weight"] = args.fn_penalty_weight
        print("WARNING - Using args fn_penalty_weight {:.3f}".format(
            config_detector.architecture["fn_penalty_weight"]))
    if args.output_directory is None:
        # synthesize an output dir name from args + architecture settings
        output_dir = synthesize_output_dir(args, config_detector.architecture)
        args.output_directory = os.path.join(
            os.path.join(args.src_path_data, "dt_logs"), output_dir)
    else:
        args.output_directory = os.path.expanduser(args.output_directory)
    # exist_ok=False: refuse to overwrite a previous experiment's outputs.
    os.makedirs(args.output_directory, exist_ok=False)
    saveExperimentSettings(
        args, os.path.join(args.output_directory, 'settings.yaml'))
    saveExperimentSettings(
        config_detector.architecture,
        os.path.join(args.output_directory, 'architecture.yaml'))
    print(args)
    # get dataset (note: num_of_input_chnls is fixed at 3 here; the actual
    # model input width is governed by n_channels_input above)
    dataset = create_dataset(args.fold, args.src_path_data,
                             mc_dropout=args.mc_dropout,
                             num_of_input_chnls=3,
                             limited_load=args.limited_load,
                             dt_config_id=args.dt_config_id,
                             cardiac_phases=tuple(('ES', 'ED')))
    # and finally we initialize something for visualization in visdom
    seg_model = args.src_path_data.split("/")[-1]
    dt_log_dir = args.output_directory.split("/")[-1]
    env = 'Detection{}-{}-{}_{}'.format(args.dataset,
                                        seg_model.replace("_", '-'),
                                        args.input_channels, dt_log_dir)
    # Three visdom panels: loss curves, precision/recall metrics, detection rates.
    vis = Visualizer(env, args.port,
                     'Learning curves of fold {}'.format(args.fold),
                     ['training', 'validation'])
    vis_metrics = Visualizer(
        env, args.port,
        'Grid detection prec/rec metrics fold {}'.format(args.fold),
        ['precision', 'recall', 'pr_auc'])
    vis_detection_rate = Visualizer(
        env, args.port,
        'Slice/voxel detection rate fold {}'.format(args.fold),
        ['detection_rate', 'slice_tp_rate', 'slice_tn_rate'])
    do_balance_batch = True
    trainer = get_trainer(args, config_detector.architecture, model_file=None)
    try:
        for _ in tqdm(range(args.max_iters),
                      desc="Train {}".format(args.network)):
            # store model every `store_model_every` iterations (incl. iter 0)
            if not trainer._train_iter % args.store_model_every:
                trainer.save(args.output_directory)
            # store learning curves
            if not trainer._train_iter % args.store_curves_every:
                trainer.save_losses(args.output_directory)
            # visualize example from validation set
            if not trainer._train_iter % args.update_visualizer_every \
                    and trainer._train_iter > 0:
                vis(trainer.current_training_loss,
                    trainer.current_validation_loss)  # plot learning curve
            # A fresh BatchHandler is built per iteration; it samples one
            # (balanced) training batch when called below.
            train_batch = BatchHandler(data_set=dataset, is_train=True,
                                       verbose=False,
                                       keep_bounding_boxes=False,
                                       input_channels=args.input_channels,
                                       num_of_max_pool_layers=config_detector.
                                       architecture['num_of_max_pool'],
                                       app_config=config_detector)
            x_input, ref_labels = train_batch(batch_size=args.batch_size,
                                              do_balance=do_balance_batch)
            # ref_labels holds labels per grid spacing; train on coarsest grid.
            y_labels = ref_labels[config_detector.max_grid_spacing]
            trainer.train(x_input, y_labels,
                          y_labels_seg=train_batch.batch_labels_per_voxel)
            # Periodic validation + visualization of one example.
            if not trainer._train_iter % args.update_visualizer_every \
                    and trainer._train_iter > 0:
                val_batch = BatchHandler(
                    data_set=dataset, is_train=False, verbose=False,
                    keep_bounding_boxes=False,
                    input_channels=args.input_channels,
                    num_of_max_pool_layers=config_detector.
                    architecture['num_of_max_pool'],
                    app_config=config_detector)
                val_set_size = dataset.get_size(is_train=False)
                # Random start index for the evaluation window.
                # NOTE(review): assumes val_set_size > 101, otherwise
                # np.random.randint raises — confirm dataset sizes.
                val_batch.last_test_list_idx = np.random.randint(
                    0, val_set_size - 101, size=1)
                trainer.evaluate(val_batch, keep_batch=True)
                vis_metrics(trainer.validation_metrics['prec'],
                            trainer.validation_metrics['rec'],
                            trainer.validation_metrics['pr_auc'])
                # Voxel-level detection rate and slice-level TP/TN rates.
                dt_rate = trainer.validation_metrics[
                    'detected_voxel_count'] / trainer.validation_metrics[
                    'total_voxel_count']
                dt_slice_tp = trainer.validation_metrics['tp_slice'] / (
                    trainer.validation_metrics['tp_slice'] +
                    trainer.validation_metrics['fn_slice'])
                dt_slice_tn = trainer.validation_metrics['tn_slice'] / (
                    trainer.validation_metrics['tn_slice'] +
                    trainer.validation_metrics['fp_slice'])
                vis_detection_rate(dt_rate, dt_slice_tp, dt_slice_tn)
                # Show one fixed example (index 12) from the kept batch.
                idx = 12
                patid = val_batch.batch_patient_slice_id[idx][0]
                val_img = val_batch.keep_batch_images[idx][0][0]
                w, h, = val_img.shape
                # **.5 = sqrt for display contrast; divisors rescale for visdom.
                vis.image((val_img**.5), 'image {}'.format(patid), 11)
                vis.image((val_batch.keep_batch_images[idx][0][1] / 0.9),
                          'uncertainty {}'.format(patid), 12)
                vis.image(val_batch.keep_batch_label_slices[idx] / 1.001,
                          'reference', 13)
                vis.image((val_batch.keep_batch_images[idx][0][2] / 1.001),
                          'seg mask', 16)
                # p = predicted probability map for the positive class.
                p = np.squeeze(val_batch.batch_pred_probs[idx])[1]
                heat_map, grid_map, target_lbl_grid = create_grid_heat_map(
                    p, config_detector.max_grid_spacing, w, h,
                    prob_threshold=0.5)
                vis.image((heat_map**.5), 'grid predictions', 14)
                if args.network == "rsnup":
                    # This network also yields segmentation outputs to show.
                    p_mask = np.argmax(np.squeeze(trainer.val_segs[idx]),
                                       axis=0)
                    vis.image(p_mask / 1.001, 'predictions', 15)
                del val_batch
    except KeyboardInterrupt:
        print('interrupted')
    finally:
        # Always persist final model and loss curves, even on interrupt.
        trainer.save(args.output_directory)
        trainer.save_losses(args.output_directory)
def main():
    """Entry point for training a segmentation model on one fold.

    Obtains user arguments, seeds RNGs, creates the output directory, stores
    the experiment settings, builds trainer/datasets/dataloaders, then runs
    the training loop with periodic checkpointing and visdom visualization.
    On KeyboardInterrupt (or any exit) the model and loss curves are saved.

    Fix vs. original: the ``try:`` guarding the training loop had been
    commented out while its ``except``/``finally`` clauses remained, which is
    a syntax error and lost the save-on-interrupt behavior. The ``try`` is
    reinstated here.
    """
    # first we obtain the user arguments, set random seeds, make directories,
    # and store the experiment settings.
    args = parse_args()
    # Set resample always to True for ACDC
    args = get_network_settings(args)
    # End - overwriting args
    args.patch_size = tuple(args.patch_size)
    torch.manual_seed(5431232439)
    torch.cuda.manual_seed(5431232439)
    # Dedicated RandomState so augmentations are reproducible independently
    # of the global numpy RNG.
    rs = np.random.RandomState(78346)
    os.makedirs(args.output_directory, exist_ok=True)
    saveExperimentSettings(args,
                           path.join(args.output_directory, 'settings.yaml'))
    print(args)
    dta_settings = get_config(args.dataset)
    # we create a trainer
    n_classes = len(dta_settings.tissue_structure_labels)
    n_channels_input = 1
    trainer, pad = get_trainer(args, n_classes, n_channels_input)
    # we initialize datasets with augmentations. Validation uses a fixed,
    # minimal pipeline (pad, crop, blur, tensor conversion).
    training_augmentations = get_train_augmentations(args, rs, pad)
    validation_augmentations = [
        datasets.augmentations.PadInput(pad, args.patch_size),
        datasets.augmentations.RandomCrop(args.patch_size,
                                          input_padding=pad, rs=rs),
        datasets.augmentations.BlurImage(sigma=0.9),
        datasets.augmentations.ToTensor()
    ]
    training_set, validation_set = get_datasets(
        args, dta_settings,
        transforms.Compose(training_augmentations),
        transforms.Compose(validation_augmentations))
    # now we create dataloaders; sampling with replacement for exactly
    # batch_size * max_iters draws so the loop below ends at max_iters.
    tra_sampler = RandomSampler(training_set, replacement=True,
                                num_samples=args.batch_size * args.max_iters)
    val_sampler = RandomSampler(validation_set, replacement=True,
                                num_samples=args.batch_size * args.max_iters)
    data_loader_training = torch.utils.data.DataLoader(
        training_set,
        batch_size=args.batch_size,
        sampler=tra_sampler,
        num_workers=args.number_of_workers,
        collate_fn=None)  # _utils.collate.default_collate
    data_loader_validation = torch.utils.data.DataLoader(
        validation_set,
        batch_size=args.batch_size,
        sampler=val_sampler,
        num_workers=args.number_of_workers,
        collate_fn=None)
    # and finally we initialize something for visualization in visdom
    env_suffix = "f" + str(args.fold) + args.output_directory.split("_")[-1]
    vis = Visualizer(
        'Segmentation{}-{}_{}'.format(args.dataset, args.network, env_suffix),
        args.port, 'Learning curves of fold {}'.format(args.fold),
        ['training', 'validation', 'aleatoric'])
    try:
        for it, (training_batch, validation_batch) in tqdm(
                enumerate(zip(data_loader_training, data_loader_validation)),
                desc='Training', total=args.max_iters):
            # store model
            if not trainer._train_iter % args.store_model_every:
                trainer.save(args.output_directory)
            # store learning curves
            if not trainer._train_iter % args.store_curves_every:
                trainer.save_losses(args.output_directory)
            # visualize example from validation set
            if not trainer._train_iter % args.update_visualizer_every \
                    and trainer._train_iter > 20:
                image = validation_batch['image'][0][None]
                val_output = trainer.predict(image)
                prediction = val_output['predictions']
                reference = validation_batch['reference'][0]
                val_patient_id = validation_batch['patient_id'][0]
                image = image.detach().numpy()
                prediction = prediction.detach().numpy().astype(
                    float)  # .transpose(1, 2, 0)
                reference = reference.detach().numpy().astype(float)
                if pad > 0:
                    # Note: image has shape [batch, 1, x, y], we get rid off
                    # extra padding in last two dimensions
                    vis.image((image[0, 0, pad:-pad, pad:-pad]**.5),
                              'padded image {}'.format(val_patient_id), 12)
                else:
                    vis.image((image[0]**.5),
                              'image {}'.format(val_patient_id), 11)
                # Divisors rescale label values into [0, 1] for display.
                vis.image(reference / 3, 'reference', 13)
                vis.image(prediction / 3, 'prediction', 14)
                # used log_softmax values
                if 'aleatoric' in val_output.keys():
                    vis.image(val_output['aleatoric'] / 0.9, 'aleatoric', 15)
                # visualize learning curve
                vis(trainer.current_training_loss,
                    trainer.current_validation_loss,
                    trainer.current_aleatoric_loss)  # plot learning curve
            # train on training mini-batch
            trainer.train(
                training_batch['image'].to(device),
                training_batch['reference'].to(device),
                ignore_label=None
                if 'ignore_label' not in training_batch.keys()
                else training_batch['ignore_label'])
            # evaluate on validation mini-batch
            trainer.evaluate(
                validation_batch['image'].to(device),
                validation_batch['reference'].to(device),
                ignore_label=None
                if 'ignore_label' not in validation_batch.keys()
                else validation_batch['ignore_label'])
    except KeyboardInterrupt:
        print('interrupted')
    finally:
        # Always persist final model and loss curves, even on interrupt.
        trainer.save(args.output_directory)
        trainer.save_losses(args.output_directory)