def _init_train_components_ensemble(self, reinitialise=False):
    # Validation metrics are computed over the whole ensemble.
    self.val_metrics = {
        'loss': metrics.Loss(BCEAndJaccardLoss(eval_ensemble=True, gpu_node=self.config.gpu_node)),
        'segment_metrics': SegmentationMetrics(num_classes=self.data_loaders.num_classes,
                                               threshold=self.config.binarize_threshold,
                                               eval_ensemble=True)
    }

    # Propagate data-dependent shapes from the data loaders into the model config.
    self.model_cfg.network_params.input_channels = self.data_loaders.input_channels
    self.model_cfg.network_params.num_classes = self.data_loaders.num_classes
    self.model_cfg.network_params.image_size = self.data_loaders.image_size

    self.criterion = self._init_criterion(not reinitialise)

    self.ens_models = list()
    self.ens_optimizers = list()
    self.ens_lr_schedulers = list()

    # Keep only the optimizer config entries that match the optimizer's constructor arguments.
    optimizer_cls = get_optimizer(self.optim_cfg)
    init_param_names = retrieve_class_init_parameters(optimizer_cls)
    optimizer_params = {k: v for k, v in self.optim_cfg.items() if k in init_param_names}

    # Build one model, optimizer and LR scheduler per ensemble member.
    for _ in range(self.len_models):
        self.ens_models.append(get_model(self.model_cfg).to(device=self.device))
        self.ens_optimizers.append(optimizer_cls(self.ens_models[-1].parameters(), **optimizer_params))

        lr_scheduler = self._init_lr_scheduler(self.ens_optimizers[-1])
        self.ens_lr_schedulers.append(lr_scheduler)

    if not reinitialise:
        self.main_logger.info(f'Using ensemble of {self.len_models} {self.ens_models[0]}')
        self.main_logger.info(f'Using optimizers {self.ens_optimizers[0].__class__.__name__}')

    self.trainer, self.evaluator = self._init_engines()
    self._init_handlers()
def get_model(model_cfg):
    model_name = model_cfg.name
    model_cls = _get_model_instance(model_name, model_cfg.type)

    init_param_names = retrieve_class_init_parameters(model_cls)
    param_dict = {k: v for k, v in model_cfg.network_params.items() if k in init_param_names}

    model = model_cls(**param_dict)

    return model
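# Note: retrieve_class_init_parameters is defined elsewhere in the project.
# The sketch below is an assumption about its behaviour, shown only to make the
# config-filtering pattern above explicit: it inspects a class constructor and
# returns the names of its arguments, so that only matching config keys are
# forwarded when the class is instantiated.
import inspect


def retrieve_class_init_parameters_sketch(cls):
    """Illustrative stand-in: return the constructor argument names of `cls`."""
    signature = inspect.signature(cls.__init__)
    return [name for name in signature.parameters if name != 'self']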
def _init_lr_scheduler(self, optimizer):
    if self.optim_cfg.scheduler == 'step':
        scheduler_class = optim.lr_scheduler.StepLR
    elif self.optim_cfg.scheduler == 'plateau':
        scheduler_class = optim.lr_scheduler.ReduceLROnPlateau
    else:
        return None

    init_param_names = retrieve_class_init_parameters(scheduler_class)
    scheduler_params = {k: v for k, v in self.optim_cfg.scheduler_params.items() if k in init_param_names}
    lr_scheduler = scheduler_class(optimizer, **scheduler_params)

    return lr_scheduler
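# Minimal usage sketch of the same filter-by-constructor-arguments pattern used
# in _init_lr_scheduler; the config values below are assumptions, not project
# defaults. Only keys that StepLR actually accepts are forwarded, so the
# unrelated entry is silently dropped.
import inspect

import torch
import torch.optim as optim

_model = torch.nn.Linear(4, 2)
_optimizer = optim.SGD(_model.parameters(), lr=0.01)
_scheduler_cfg = {'step_size': 30, 'gamma': 0.1, 'unrelated_key': None}
_accepted = set(inspect.signature(optim.lr_scheduler.StepLR.__init__).parameters) - {'self'}
_lr_scheduler = optim.lr_scheduler.StepLR(
    _optimizer, **{k: v for k, v in _scheduler_cfg.items() if k in _accepted})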
def _init_optimizer(self, init=True):
    # Instantiate the configured optimizer, forwarding only the config entries
    # that its constructor accepts.
    optimizer_cls = get_optimizer(self.optim_cfg)
    init_param_names = retrieve_class_init_parameters(optimizer_cls)
    optimizer_params = {k: v for k, v in self.optim_cfg.items() if k in init_param_names}
    optimizer = optimizer_cls(self.model.parameters(), **optimizer_params)

    if init:
        self.main_logger.info(f'Using optimizer {optimizer.__class__.__name__}')

    # Restore a saved optimizer state when resuming a run.
    if self.resume_cfg.resume_from is not None and self.resume_cfg.saved_optimizer is not None and init:
        optimizer_path = get_resume_optimizer_path(self.resume_cfg.resume_from,
                                                   self.resume_cfg.saved_optimizer)
        self.main_logger.info(f'Loading optimizer from {optimizer_path}')
        optimizer.load_state_dict(torch.load(optimizer_path))

    return optimizer
def get_loss_function(loss_cfg):
    loss_name = loss_cfg.name
    if loss_name is None:
        return nn.BCEWithLogitsLoss()

    if loss_name not in loss2class:
        raise NotImplementedError(f"Loss {loss_name} not implemented")

    loss_cls = loss2class[loss_name]
    init_param_names = retrieve_class_init_parameters(loss_cls)
    loss_params = {k: v for k, v in loss_cfg.items() if k in init_param_names}
    loss = loss_cls(**loss_params)

    return loss
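# Usage sketch for get_loss_function, assuming (as the attribute access plus
# .items() above suggests) that loss_cfg is an OmegaConf DictConfig; this is an
# assumption, and the value below is illustrative only.
from omegaconf import OmegaConf

_loss_cfg = OmegaConf.create({'name': None})
_loss = get_loss_function(_loss_cfg)  # no name configured -> nn.BCEWithLogitsLoss()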