def _strip_module_prefix(key):
    """Remove leading ``'module.'`` prefixes that ``DataParallel`` adds to state-dict keys.

    Only the prefix is removed (repeatedly, for nested wrappers); occurrences of
    ``module.`` inside the key, e.g. a submodule named ``se_module``, are kept.
    """
    prefix = 'module.'
    while key.startswith(prefix):
        key = key[len(prefix):]
    return key


def get_model(model, gpus=None, num_classes=1000, train_aug="default"):
    """Build a model by name and optionally load its registered checkpoint.

    :param model: model name; it is first looked up in ``_checkpoints`` (a
        missing name raises ``KeyError``). Names containing
        ``'efficientnet-l2t'`` build an ``EfficientNet``; ``'resnet50'`` builds
        ``resnet50(pretrained=True)``.
    :param gpus: ``None`` or ``'cpu'`` leaves the model where it was built;
        otherwise a sequence of device ids. More than one id wraps the model in
        ``DataParallel(device_ids=gpus)``; in every non-CPU case ``.cuda()`` is called.
    :param num_classes: number of output classes (only used for EfficientNet).
    :param train_aug: second-level key into ``_checkpoints[model]``; if it
        contains ``'omem'`` the checkpoint's ``'ema'`` entry is loaded.
    :return: the (possibly ``DataParallel``-wrapped) model.
    :raises ValueError: if ``model`` has a checkpoint entry but is not handled here.
    """
    checkpoint = _checkpoints[model][train_aug]
    print(model, train_aug, checkpoint)

    if 'efficientnet-l2t' in model:
        # Normalization type is selected through the `norm` environment variable.
        norm_type = os.environ.get('norm', 'batch')
        model = EfficientNet.from_name(model, num_classes=num_classes, multiple_feat=True, norm_type=norm_type)
    # IMAGENET -----------------------------------------------------
    elif model == 'resnet50':
        model = resnet50(pretrained=True)
    else:
        raise ValueError(model)

    if checkpoint and checkpoint != 'modelzoo':
        # Load onto the CPU: GPU-saved checkpoints would otherwise fail on
        # CUDA-less hosts. load_state_dict copies into the model's own tensors.
        state = torch.load(checkpoint, map_location='cpu')
        key = 'model' if 'model' in state else 'state_dict'
        if key in state and not isinstance(state[key], dict):
            key = 'state_dict'
        if 'omem' in train_aug:
            key = 'ema'
        print('model epoch=', state.get('epoch', -1))
        # Without the expected key, treat the whole file as a bare state dict.
        weights = state[key] if key in state else state
        model.load_state_dict({_strip_module_prefix(k): v for k, v in weights.items()})

    if gpus not in ['cpu', None]:
        if len(gpus) > 1:
            model = DataParallel(model, device_ids=gpus)
        model = model.cuda()
    return model
class nnUNetTrainerV2CascadeFullRes_DP(nnUNetTrainerV2CascadeFullRes):
    """Full-resolution cascade trainer that runs the network through ``DataParallel``.

    The network (``Generic_UNet_DP``) is called with the target and returns
    per-output cross-entropy values and tp/fp/fn statistics; this trainer
    combines them into the deep-supervision loss in :meth:`compute_loss`.
    """

    def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
                 unpack_data=True, deterministic=True, num_gpus=1, distribute_batch_size=False, fp16=False,
                 previous_trainer="nnUNetTrainerV2_DP"):
        super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
                         deterministic, previous_trainer, fp16)
        self.init_args = (plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
                          deterministic, num_gpus, distribute_batch_size, fp16, previous_trainer)
        self.num_gpus = num_gpus
        self.distribute_batch_size = distribute_batch_size
        self.dice_do_BG = False
        self.dice_smooth = 1e-5

        if self.output_folder is not None:
            # NOTE(review): assumes output_folder looks like .../<task>/<name>__<plans_identifier>/<leaf>
            task = self.output_folder.split("/")[-3]
            plans_identifier = self.output_folder.split("/")[-2].split("__")[-1]
            folder_with_segs_prev_stage = join(network_training_output_dir, "3d_lowres", task,
                                               previous_trainer + "__" + plans_identifier, "pred_next_stage")
            self.folder_with_segs_from_prev_stage = folder_with_segs_prev_stage
        else:
            self.folder_with_segs_from_prev_stage = None
        print(self.folder_with_segs_from_prev_stage)

    def get_basic_generators(self):
        """Load and split the dataset and return the (train, val) 3D data loaders.

        :raises NotImplementedError: for 2D configurations (no cascade there).
        """
        self.load_dataset()
        self.do_split()
        if not self.threeD:
            raise NotImplementedError("2D has no cascade")
        dl_tr = DataLoader3D(self.dataset_tr, self.basic_generator_patch_size, self.patch_size, self.batch_size,
                             True, oversample_foreground_percent=self.oversample_foreground_percent,
                             pad_mode="constant", pad_sides=self.pad_all_sides)
        dl_val = DataLoader3D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size, True,
                              oversample_foreground_percent=self.oversample_foreground_percent,
                              pad_mode="constant", pad_sides=self.pad_all_sides)
        return dl_tr, dl_val

    def process_plans(self, plans):
        """Process plans, then scale (or validate) the batch size for ``num_gpus`` GPUs."""
        super().process_plans(plans)
        if not self.distribute_batch_size:
            # Replicate the plans' batch size once per GPU.
            self.batch_size = self.num_gpus * self.plans['plans_per_stage'][self.stage]['batch_size']
        elif self.batch_size < self.num_gpus:
            print("WARNING: self.batch_size < self.num_gpus. Will not be able to use the GPUs well")
        elif self.batch_size % self.num_gpus != 0:
            print("WARNING: self.batch_size % self.num_gpus != 0. Will not be able to use the GPUs well")

    def initialize(self, training=True, force_load_plans=False):
        """Set up plans, loss weights, data pipeline (if training), network and optimizer once.

        :param training: also build data loaders/augmentation when True.
        :param force_load_plans: reload the plans file even if plans are already loaded.
        :raises RuntimeError: when training and the previous-stage segmentations folder is missing.
        """
        if not self.was_initialized:
            if force_load_plans or (self.plans is None):
                self.load_plans_file()

            self.process_plans(self.plans)
            self.setup_DA_params()

            ################# Here we wrap the loss for deep supervision ############
            # Output i gets weight 1/2**i; the last output gets 0; weights are normalized to sum to 1.
            net_numpool = len(self.net_num_pool_op_kernel_sizes)
            weights = np.array([1 / (2 ** i) for i in range(net_numpool)])
            mask = np.arange(net_numpool) < net_numpool - 1
            weights[~mask] = 0
            self.loss_weights = weights / weights.sum()
            ################# END ###################

            self.folder_with_preprocessed_data = join(self.dataset_directory,
                                                      self.plans['data_identifier'] + "_stage%d" % self.stage)
            if training:
                # The folder is None when no output_folder was given; fail with the intended message.
                if self.folder_with_segs_from_prev_stage is None or not isdir(self.folder_with_segs_from_prev_stage):
                    raise RuntimeError(
                        "Cannot run final stage of cascade. Run corresponding 3d_lowres first and predict the "
                        "segmentations for the next stage")

                self.dl_tr, self.dl_val = self.get_basic_generators()
                if self.unpack_data:
                    print("unpacking dataset")
                    unpack_dataset(self.folder_with_preprocessed_data)
                    print("done")
                else:
                    print("INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
                          "will wait all winter for your model to finish!")

                self.tr_gen, self.val_gen = get_moreDA_augmentation(
                    self.dl_tr, self.dl_val, self.data_aug_params['patch_size_for_spatialtransform'],
                    self.data_aug_params, deep_supervision_scales=self.deep_supervision_scales,
                    pin_memory=self.pin_memory)
                self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
                                       also_print_to_console=False)
                self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
                                       also_print_to_console=False)

            self.initialize_network()
            self.initialize_optimizer_and_scheduler()
            assert isinstance(self.network, (SegmentationNetwork, DataParallel))
        else:
            self.print_to_log_file('self.was_initialized is True, not running self.initialize again')
        self.was_initialized = True

    def initialize_network(self):
        """
        replace genericUNet with the implementation of above for super speeds
        """
        if self.threeD:
            conv_op = nn.Conv3d
            dropout_op = nn.Dropout3d
            norm_op = nn.InstanceNorm3d
        else:
            conv_op = nn.Conv2d
            dropout_op = nn.Dropout2d
            norm_op = nn.InstanceNorm2d

        norm_op_kwargs = {'eps': 1e-5, 'affine': True}
        dropout_op_kwargs = {'p': 0, 'inplace': True}
        net_nonlin = nn.LeakyReLU
        net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
        self.network = Generic_UNet_DP(self.num_input_channels, self.base_num_features, self.num_classes,
                                       len(self.net_num_pool_op_kernel_sizes),
                                       self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op,
                                       dropout_op_kwargs,
                                       net_nonlin, net_nonlin_kwargs, True, False, InitWeights_He(1e-2),
                                       self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True,
                                       True)
        if torch.cuda.is_available():
            self.network.cuda()
        self.network.inference_apply_nonlin = softmax_helper

    def run_training(self):
        """Train with the network wrapped in ``DataParallel`` and deep supervision on.

        The wrapper is removed and ``do_ds`` restored afterwards, also when
        training raises, so the trainer is never left holding the wrapper.
        """
        self.maybe_update_lr(self.epoch)

        # amp must be initialized before DP
        ds = self.network.do_ds
        self.network.do_ds = True
        self.network = DataParallel(self.network, tuple(range(self.num_gpus)))
        try:
            ret = nnUNetTrainer.run_training(self)
        finally:
            if isinstance(self.network, DataParallel):
                self.network = self.network.module
            self.network.do_ds = ds
        return ret

    def _run_forward_pass(self, data, target, run_online_evaluation):
        """Call the network and feed online evaluation; return ``(ces, tps, fps, fns)``."""
        ret = self.network(data, target, return_hard_tp_fp_fn=run_online_evaluation)
        if run_online_evaluation:
            ces, tps, fps, fns, tp_hard, fp_hard, fn_hard = ret
            self.run_online_evaluation(tp_hard, fp_hard, fn_hard)
        else:
            ces, tps, fps, fns = ret
        return ces, tps, fps, fns

    def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
        """Run one batch: forward, loss, and (if ``do_backprop``) a clipped optimizer step.

        :return: the loss as a numpy value.
        """
        data_dict = next(data_generator)
        data = maybe_to_torch(data_dict['data'])
        target = maybe_to_torch(data_dict['target'])

        if torch.cuda.is_available():
            data = to_cuda(data)
            target = to_cuda(target)

        self.optimizer.zero_grad()

        if self.fp16:
            with autocast():
                ces, tps, fps, fns = self._run_forward_pass(data, target, run_online_evaluation)
                del data, target
                l = self.compute_loss(ces, tps, fps, fns)

            # backward/step happen outside autocast; unscale before clipping so the norm is in real units
            if do_backprop:
                self.amp_grad_scaler.scale(l).backward()
                self.amp_grad_scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
                self.amp_grad_scaler.step(self.optimizer)
                self.amp_grad_scaler.update()
        else:
            ces, tps, fps, fns = self._run_forward_pass(data, target, run_online_evaluation)
            del data, target
            l = self.compute_loss(ces, tps, fps, fns)

            if do_backprop:
                l.backward()
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
                self.optimizer.step()

        return l.detach().cpu().numpy()

    def run_online_evaluation(self, tp_hard, fp_hard, fn_hard):
        """Average hard tp/fp/fn over dim 0 and record them plus the per-entry dice."""
        tp_hard = tp_hard.detach().cpu().numpy().mean(0)
        fp_hard = fp_hard.detach().cpu().numpy().mean(0)
        fn_hard = fn_hard.detach().cpu().numpy().mean(0)
        self.online_eval_foreground_dc.append(list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8)))
        self.online_eval_tp.append(list(tp_hard))
        self.online_eval_fp.append(list(fp_hard))
        self.online_eval_fn.append(list(fn_hard))

    def compute_loss(self, ces, tps, fps, fns):
        """Weighted sum over outputs of ``mean(ce) + (-dice)`` using ``self.loss_weights``."""
        loss = None
        for i, ce in enumerate(ces):
            tp, fp, fn = tps[i], fps[i], fns[i]
            if not self.dice_do_BG:
                # drop index 0 along dim 1 (the background entry, per dice_do_BG)
                tp, fp, fn = tp[:, 1:], fp[:, 1:], fn[:, 1:]
            if self.batch_dice:
                tp, fp, fn = tp.sum(0), fp.sum(0), fn.sum(0)

            nominator = 2 * tp + self.dice_smooth
            denominator = 2 * tp + fp + fn + self.dice_smooth
            dice_loss = (-nominator / denominator).mean()

            term = self.loss_weights[i] * (ce.mean() + dice_loss)
            loss = term if loss is None else loss + term
        return loss
class nnUNetTrainerV2_DP(nnUNetTrainerV2):
    """nnU-Net V2 trainer that runs the network through ``DataParallel``.

    The network (``Generic_UNet_DP``) is called with the target and returns
    per-output cross-entropy values and tp/fp/fn statistics; this trainer
    combines them into the deep-supervision loss in :meth:`compute_loss`.
    """

    def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
                 unpack_data=True, deterministic=True, num_gpus=1, distribute_batch_size=False, fp16=False):
        super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
                         deterministic, fp16)
        self.init_args = (plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
                          deterministic, num_gpus, distribute_batch_size, fp16)
        self.num_gpus = num_gpus
        self.distribute_batch_size = distribute_batch_size
        self.dice_smooth = 1e-5
        self.dice_do_BG = False
        self.loss = None
        self.loss_weights = None

    def setup_DA_params(self):
        """Use the parent's augmentation params with 8 worker threads per GPU."""
        super().setup_DA_params()
        self.data_aug_params['num_threads'] = 8 * self.num_gpus

    def process_plans(self, plans):
        """Process plans, then scale (or validate) the batch size for ``num_gpus`` GPUs."""
        super().process_plans(plans)
        if not self.distribute_batch_size:
            # Replicate the plans' batch size once per GPU.
            self.batch_size = self.num_gpus * self.plans['plans_per_stage'][self.stage]['batch_size']
        elif self.batch_size < self.num_gpus:
            print("WARNING: self.batch_size < self.num_gpus. Will not be able to use the GPUs well")
        elif self.batch_size % self.num_gpus != 0:
            print("WARNING: self.batch_size % self.num_gpus != 0. Will not be able to use the GPUs well")

    def initialize(self, training=True, force_load_plans=False):
        """
        - replaced get_default_augmentation with get_moreDA_augmentation
        - only run this code once
        - loss function wrapper for deep supervision

        :param training: also build data loaders/augmentation when True
        :param force_load_plans: reload the plans file even if plans are already loaded
        :return:
        """
        if not self.was_initialized:
            os.makedirs(self.output_folder, exist_ok=True)

            if force_load_plans or (self.plans is None):
                self.load_plans_file()

            self.process_plans(self.plans)
            self.setup_DA_params()

            ################# Here configure the loss for deep supervision ############
            # Output i gets weight 1/2**i; the last output gets 0; weights are normalized to sum to 1.
            net_numpool = len(self.net_num_pool_op_kernel_sizes)
            weights = np.array([1 / (2 ** i) for i in range(net_numpool)])
            mask = np.arange(net_numpool) < net_numpool - 1
            weights[~mask] = 0
            self.loss_weights = weights / weights.sum()
            ################# END ###################

            self.folder_with_preprocessed_data = join(self.dataset_directory,
                                                      self.plans['data_identifier'] + "_stage%d" % self.stage)
            if training:
                self.dl_tr, self.dl_val = self.get_basic_generators()
                if self.unpack_data:
                    print("unpacking dataset")
                    unpack_dataset(self.folder_with_preprocessed_data)
                    print("done")
                else:
                    print("INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
                          "will wait all winter for your model to finish!")

                self.tr_gen, self.val_gen = get_moreDA_augmentation(
                    self.dl_tr, self.dl_val, self.data_aug_params['patch_size_for_spatialtransform'],
                    self.data_aug_params, deep_supervision_scales=self.deep_supervision_scales,
                    pin_memory=self.pin_memory)
                self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
                                       also_print_to_console=False)
                self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
                                       also_print_to_console=False)

            self.initialize_network()
            self.initialize_optimizer_and_scheduler()
            assert isinstance(self.network, (SegmentationNetwork, DataParallel))
        else:
            self.print_to_log_file('self.was_initialized is True, not running self.initialize again')
        self.was_initialized = True

    def initialize_network(self):
        """
        replace genericUNet with the implementation of above for super speeds
        """
        if self.threeD:
            conv_op = nn.Conv3d
            dropout_op = nn.Dropout3d
            norm_op = nn.InstanceNorm3d
        else:
            conv_op = nn.Conv2d
            dropout_op = nn.Dropout2d
            norm_op = nn.InstanceNorm2d

        norm_op_kwargs = {'eps': 1e-5, 'affine': True}
        dropout_op_kwargs = {'p': 0, 'inplace': True}
        net_nonlin = nn.LeakyReLU
        net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
        self.network = Generic_UNet_DP(self.num_input_channels, self.base_num_features, self.num_classes,
                                       len(self.net_num_pool_op_kernel_sizes),
                                       self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op,
                                       dropout_op_kwargs,
                                       net_nonlin, net_nonlin_kwargs, True, False, InitWeights_He(1e-2),
                                       self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True,
                                       True)
        if torch.cuda.is_available():
            self.network.cuda()
        self.network.inference_apply_nonlin = softmax_helper

    def initialize_optimizer_and_scheduler(self):
        """Create SGD (momentum 0.99, Nesterov) over the network parameters; no lr scheduler object."""
        assert self.network is not None, "self.initialize_network must be called first"
        self.optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
                                         momentum=0.99, nesterov=True)
        self.lr_scheduler = None

    def run_training(self):
        """Train with the network wrapped in ``DataParallel`` and deep supervision on.

        The wrapper is removed and ``do_ds`` restored afterwards, also when
        training raises, so the trainer is never left holding the wrapper.
        """
        self.maybe_update_lr(self.epoch)

        # amp must be initialized before DP
        ds = self.network.do_ds
        self.network.do_ds = True
        self.network = DataParallel(self.network, tuple(range(self.num_gpus)))
        try:
            ret = nnUNetTrainer.run_training(self)
        finally:
            if isinstance(self.network, DataParallel):
                self.network = self.network.module
            self.network.do_ds = ds
        return ret

    def _run_forward_pass(self, data, target, run_online_evaluation):
        """Call the network and feed online evaluation; return ``(ces, tps, fps, fns)``."""
        ret = self.network(data, target, return_hard_tp_fp_fn=run_online_evaluation)
        if run_online_evaluation:
            ces, tps, fps, fns, tp_hard, fp_hard, fn_hard = ret
            self.run_online_evaluation(tp_hard, fp_hard, fn_hard)
        else:
            ces, tps, fps, fns = ret
        return ces, tps, fps, fns

    def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
        """Run one batch: forward, loss, and (if ``do_backprop``) a clipped optimizer step.

        :return: the loss as a numpy value.
        """
        data_dict = next(data_generator)
        data = maybe_to_torch(data_dict['data'])
        target = maybe_to_torch(data_dict['target'])

        if torch.cuda.is_available():
            data = to_cuda(data)
            target = to_cuda(target)

        self.optimizer.zero_grad()

        if self.fp16:
            with autocast():
                ces, tps, fps, fns = self._run_forward_pass(data, target, run_online_evaluation)
                del data, target
                l = self.compute_loss(ces, tps, fps, fns)

            # backward/step happen outside autocast; unscale before clipping so the norm is in real units
            if do_backprop:
                self.amp_grad_scaler.scale(l).backward()
                self.amp_grad_scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
                self.amp_grad_scaler.step(self.optimizer)
                self.amp_grad_scaler.update()
        else:
            ces, tps, fps, fns = self._run_forward_pass(data, target, run_online_evaluation)
            del data, target
            l = self.compute_loss(ces, tps, fps, fns)

            if do_backprop:
                l.backward()
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
                self.optimizer.step()

        return l.detach().cpu().numpy()

    def run_online_evaluation(self, tp_hard, fp_hard, fn_hard):
        """Average hard tp/fp/fn over dim 0 and record them plus the per-entry dice."""
        tp_hard = tp_hard.detach().cpu().numpy().mean(0)
        fp_hard = fp_hard.detach().cpu().numpy().mean(0)
        fn_hard = fn_hard.detach().cpu().numpy().mean(0)
        self.online_eval_foreground_dc.append(list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8)))
        self.online_eval_tp.append(list(tp_hard))
        self.online_eval_fp.append(list(fp_hard))
        self.online_eval_fn.append(list(fn_hard))

    def compute_loss(self, ces, tps, fps, fns):
        """Weighted sum over outputs of ``mean(ce) + (-dice)`` using ``self.loss_weights``.

        The network returns raw statistics rather than a loss, so the
        deep-supervision loss is reimplemented here.
        """
        loss = None
        for i, ce in enumerate(ces):
            tp, fp, fn = tps[i], fps[i], fns[i]
            if not self.dice_do_BG:
                # drop index 0 along dim 1 (the background entry, per dice_do_BG)
                tp, fp, fn = tp[:, 1:], fp[:, 1:], fn[:, 1:]
            if self.batch_dice:
                tp, fp, fn = tp.sum(0), fp.sum(0), fn.sum(0)

            nominator = 2 * tp + self.dice_smooth
            denominator = 2 * tp + fp + fn + self.dice_smooth
            dice_loss = (-nominator / denominator).mean()

            term = self.loss_weights[i] * (ce.mean() + dice_loss)
            loss = term if loss is None else loss + term
        return loss