Example #1
0
    def _init_train_components_ensemble(self, reinitialise=False):
        """Build all training components for the ensemble: validation metrics,
        criterion, one (model, optimizer, lr-scheduler) triple per ensemble
        member, the trainer/evaluator engines, and the event handlers.

        When ``reinitialise`` is True the components are rebuilt silently
        (no logging of the ensemble/optimizer choice).
        """
        ensemble_loss = BCEAndJaccardLoss(eval_ensemble=True, gpu_node=self.config.gpu_node)
        self.val_metrics = {
            'loss': metrics.Loss(ensemble_loss),
            'segment_metrics': SegmentationMetrics(num_classes=self.data_loaders.num_classes,
                                                   threshold=self.config.binarize_threshold,
                                                   eval_ensemble=True)
        }

        # Propagate dataset-derived dimensions into the network configuration.
        network_params = self.model_cfg.network_params
        network_params.input_channels = self.data_loaders.input_channels
        network_params.num_classes = self.data_loaders.num_classes
        network_params.image_size = self.data_loaders.image_size

        self.criterion = self._init_criterion(not reinitialise)
        self.ens_models = []
        self.ens_optimizers = []
        self.ens_lr_schedulers = []

        optimizer_cls = get_optimizer(self.optim_cfg)
        accepted_params = retrieve_class_init_parameters(optimizer_cls)
        # Forward only the config entries the optimizer's __init__ accepts.
        optimizer_params = {name: value for name, value in self.optim_cfg.items()
                            if name in accepted_params}

        # One independent (model, optimizer, scheduler) triple per ensemble member.
        for _ in range(self.len_models):
            member = get_model(self.model_cfg).to(device=self.device)
            member_optimizer = optimizer_cls(member.parameters(), **optimizer_params)
            self.ens_models.append(member)
            self.ens_optimizers.append(member_optimizer)
            self.ens_lr_schedulers.append(self._init_lr_scheduler(member_optimizer))

        if not reinitialise:
            self.main_logger.info(f'Using ensemble of {self.len_models} {self.ens_models[0]}')
            self.main_logger.info(f'Using optimizers {self.ens_optimizers[0].__class__.__name__}')

        self.trainer, self.evaluator = self._init_engines()

        self._init_handlers()
Example #2
0
def get_model(model_cfg):
    """Instantiate the network described by ``model_cfg``.

    The model class is resolved from ``model_cfg.name`` / ``model_cfg.type``;
    only the entries of ``model_cfg.network_params`` that match the class'
    constructor signature are forwarded to it.
    """
    model_cls = _get_model_instance(model_cfg.name, model_cfg.type)

    # Forward only the config entries the model's __init__ accepts.
    accepted = retrieve_class_init_parameters(model_cls)
    kwargs = {name: value for name, value in model_cfg.network_params.items()
              if name in accepted}

    return model_cls(**kwargs)
Example #3
0
    def _init_lr_scheduler(self, optimizer):
        if self.optim_cfg.scheduler == 'step':
            scheduler_class = optim.lr_scheduler.StepLR
        elif self.optim_cfg.scheduler == 'plateau':
            scheduler_class = optim.lr_scheduler.ReduceLROnPlateau
        else:
            return None

        init_param_names = retrieve_class_init_parameters(scheduler_class)
        scheduler_params = {k: v for k, v in self.optim_cfg.scheduler_params.items() if k in init_param_names}

        lr_scheduler = scheduler_class(optimizer, **scheduler_params)

        return lr_scheduler
Example #4
0
    def _init_optimizer(self, init=True):
        """Create the optimizer for ``self.model`` from ``self.optim_cfg``.

        Args:
            init: when True, log the optimizer choice and — if a resume
                checkpoint with a saved optimizer is configured — restore
                the optimizer state from disk.

        Returns:
            The constructed (and possibly state-restored) optimizer.
        """
        optimizer_cls = get_optimizer(self.optim_cfg)

        # Forward only the config entries the optimizer's __init__ accepts.
        init_param_names = retrieve_class_init_parameters(optimizer_cls)
        optimizer_params = {k: v for k, v in self.optim_cfg.items() if k in init_param_names}

        optimizer = optimizer_cls(self.model.parameters(), **optimizer_params)
        if init:
            self.main_logger.info(f'Using optimizer {optimizer.__class__.__name__}')

            # Restore optimizer state when resuming from a checkpoint.
            # NOTE: the original code repeated an `if init:` guard inside a
            # branch that already required `init` — the redundant check is
            # removed here; behavior is unchanged.
            if self.resume_cfg.resume_from is not None and self.resume_cfg.saved_optimizer is not None:
                optimizer_path = get_resume_optimizer_path(self.resume_cfg.resume_from,
                                                           self.resume_cfg.saved_optimizer)
                self.main_logger.info(f'Loading optimizer from {optimizer_path}')
                optimizer.load_state_dict(torch.load(optimizer_path))

        return optimizer
Example #5
0
def get_loss_function(loss_cfg):
    """Resolve and instantiate the loss described by ``loss_cfg``.

    Falls back to ``nn.BCEWithLogitsLoss`` when ``loss_cfg.name`` is None.

    Raises:
        NotImplementedError: if the named loss is not registered in ``loss2class``.
    """
    loss_name = loss_cfg.name

    # Default loss when the config names none.
    if loss_name is None:
        return nn.BCEWithLogitsLoss()

    if loss_name not in loss2class:
        raise NotImplementedError(f"Loss {loss_name} not implemented")

    loss_cls = loss2class[loss_name]

    # Forward only the config entries the loss' __init__ accepts.
    accepted = retrieve_class_init_parameters(loss_cls)
    kwargs = {name: value for name, value in loss_cfg.items() if name in accepted}

    return loss_cls(**kwargs)