示例#1
0
    def __init__(self,
                 trained_model: Type[torch.nn.Module],
                 refine: bool = False,
                 resize: Union[Tuple, List] = None,
                 use_gpu: bool = False,
                 logits: bool = True,
                 **kwargs: Union[int, float, bool]) -> None:
        """
        Set up a segmentation predictor around an already trained model.

        Args:
            trained_model: trained PyTorch model; forwarded, together with
                ``use_gpu``, to the parent initializer.
            refine: stored as ``self.refine``.
            resize: stored as ``self.resize``.
            use_gpu: forwarded to the parent initializer and stored as
                ``self.use_gpu``.
            logits: stored as ``self.logits``.
            **kwargs: ``nb_classes`` and ``downsampling`` (derived from the
                model via ``get_nb_classes`` / ``get_downsample_factor`` when
                missing or None), plus ``d`` (default None), ``thresh``
                (default 0.5) and ``verbose`` (default True).
        """
        super(SegPredictor, self).__init__(trained_model, use_gpu)
        set_train_rng(1)

        # Explicit values win; otherwise fall back to inspecting the model.
        requested_classes = kwargs.get('nb_classes')
        self.nb_classes = (get_nb_classes(trained_model)
                           if requested_classes is None
                           else requested_classes)
        requested_factor = kwargs.get('downsampling')
        self.downsampling = (get_downsample_factor(trained_model)
                             if requested_factor is None
                             else requested_factor)

        self.resize, self.logits, self.refine = resize, logits, refine
        self.d = kwargs.get("d")
        self.thresh = kwargs.get("thresh", .5)
        self.use_gpu = use_gpu
        self.verbose = kwargs.get("verbose", True)
示例#2
0
 def __init__(self):
     """
     Populate the trainer with empty/default state.

     Seeds the training RNGs with 1, picks 'cuda' when available (else
     'cpu'), and leaves network, loss, optimizer and data unset so that
     subclasses or later setup calls can fill them in.
     """
     set_train_rng(1)
     self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
     # Model components are attached later.
     self.net = self.criterion = self.optimizer = None
     # Training-mode switches.
     self.compute_accuracy = False
     self.full_epoch = True
     self.swa = self.perturb_weights = False
     self.running_weights = {}
     self.training_cycles = 0
     # Two distinct lists (must not alias each other).
     self.batch_idx_train, self.batch_idx_test = [], []
     self.batch_size = 1
     self.nb_classes = None
     # Data placeholders.
     self.X_train = self.y_train = None
     self.X_test = self.y_test = None
     self.train_loader = torch.utils.data.TensorDataset()
     self.test_loader = torch.utils.data.TensorDataset()
     self.data_is_set = False
     # Augmentation, bookkeeping and checkpoint metadata.
     self.augdict = {}
     self.augment_fn = None
     self.filename = "model"
     self.print_loss = 1
     self.meta_state_dict = {}
     # One fresh list per history series.
     self.loss_acc = {
         series: []
         for series in ("train_loss", "test_loss",
                        "train_accuracy", "test_accuracy")
     }
示例#3
0
 def __init__(self,
              trained_model: Type[torch.nn.Module],
              output_dim: Union[int, Tuple[int, ...]],
              use_gpu: bool = False,
              **kwargs: bool) -> None:
     """
     Initialize predictor

     Args:
         trained_model: trained PyTorch model, forwarded (with ``use_gpu``)
             to the parent predictor initializer.
         output_dim: size of the predicted output. A bare int is wrapped
             into a one-element tuple. One value is expected for spectra
             and two values for images.
         use_gpu: forwarded to the parent predictor initializer.
         **kwargs: ``verbose`` (default True).

     Raises:
         ValueError: if ``output_dim`` does not contain one or two values.
     """
     # Check the arguments before handing the model to the parent
     # initializer, so invalid input fails fast.
     if isinstance(output_dim, int):
         output_dim = (output_dim, )
     if len(output_dim) not in (1, 2):
         raise ValueError(
             "output_dim must be a two-value tuple for images" +
             " and a single-value tuple for spectra")
     super(ImSpecPredictor, self).__init__(trained_model, use_gpu)
     set_train_rng(1)
     self.output_dim = output_dim
     self.verbose = kwargs.get("verbose", True)
示例#4
0
 def __init__(self,
              model: Union[Type[torch.nn.Module], str] = "Unet",
              nb_classes: int = 1,
              **kwargs: Union[int, List, str, bool]) -> None:
     """
     Build a trainer for a single fully convolutional network.

     Args:
         model: model name or model object, passed to ``init_fcnn_model``.
         nb_classes: number of output classes, stored as ``self.nb_classes``.
         **kwargs: ``seed`` (default 1) seeds the training RNGs;
             ``batch_seed`` defaults to ``seed``. All kwargs (including the
             filled-in ``batch_seed``) are forwarded to ``init_fcnn_model``.
     """
     super(SegTrainer, self).__init__()
     seed = kwargs.get("seed", 1)
     # Only fills batch_seed when the caller did not provide one.
     kwargs.setdefault("batch_seed", seed)
     set_train_rng(seed)
     self.nb_classes = nb_classes
     network, meta = init_fcnn_model(model, nb_classes, **kwargs)
     self.net, self.meta_state_dict = network, meta
     self.net.to(self.device)
     on_cpu = self.device == 'cpu'
     if on_cpu:
         warnings.warn("No GPU found. The training can be EXTREMELY slow",
                       UserWarning)
     # Keep the (device-placed) initial weights alongside the metadata.
     self.meta_state_dict["weights"] = self.net.state_dict()
示例#5
0
    def __init__(self,
                 in_dim: Tuple[int, ...],
                 out_dim: Tuple[int, ...],
                 latent_dim: int = 2,
                 **kwargs: Union[int, bool, str]) -> None:
        """
        Initialize trainer's parameters

        Args:
            in_dim: input dimensions, stored as ``self.in_dim`` and passed
                to ``init_imspec_model``.
            out_dim: output dimensions, stored as ``self.out_dim`` and passed
                to ``init_imspec_model``.
            latent_dim: latent dimension passed to ``init_imspec_model``.
            **kwargs: ``seed`` (default 1) seeds the training RNGs;
                ``batch_seed`` defaults to ``seed``. All kwargs are
                forwarded to ``init_imspec_model``.
        """
        # The docstring must precede any statement; previously it followed
        # the super() call and was a no-op string expression.
        super(ImSpecTrainer, self).__init__()
        seed = kwargs.get("seed", 1)
        kwargs["batch_seed"] = kwargs.get("batch_seed", seed)
        set_train_rng(seed)

        self.in_dim, self.out_dim = in_dim, out_dim
        (self.net,
         self.meta_state_dict) = init_imspec_model(in_dim, out_dim, latent_dim,
                                                   **kwargs)

        self.net.to(self.device)
        self.meta_state_dict["weights"] = self.net.state_dict()
示例#6
0
 def __init__(self,
              X_train: np.ndarray,
              y_train: np.ndarray,
              X_test: np.ndarray = None,
              y_test: np.ndarray = None,
              n_models: int = 30,
              model: str = "dilUnet",
              strategy: str = "from_baseline",
              swa: bool = False,
              training_cycles_base: int = 1000,
              training_cycles_ensemble: int = 50,
              filename: str = "./model",
              **kwargs: Dict) -> None:
     """
     Initializes parameters of ensemble trainer

     Args:
         X_train, y_train: training data.
         X_test, y_test: test data. If either is None, the training data
             are split with ``train_test_split`` (shuffled,
             ``random_state=0``, ``test_size`` from kwargs, default 0.15),
             and any single test array that was supplied is discarded.
         n_models: number of ensemble members, stored as ``self.n_models``.
         model: model type, stored as ``self.model_type``.
         strategy: one of 'from_baseline', 'from_scratch' or 'swag'.
         swa: if True (or when strategy is 'swag'), sets ``swa=True`` in
             the stored kwargs (``self.kdict``).
         training_cycles_base: stored as ``self.iter_base``.
         training_cycles_ensemble: stored as ``self.iter_ensemble``, only
             for the 'from_baseline' strategy.
         filename: stored as ``self.filename``.
         **kwargs: stored as ``self.kdict``.

     Raises:
         NotImplementedError: if ``strategy`` is not supported.
     """
     # Reject an unsupported strategy before doing any work on the data.
     if strategy not in ("from_baseline", "from_scratch", "swag"):
         raise NotImplementedError(
             "Select 'from_baseline', 'from_scratch', or 'swag' strategy")
     if X_test is None or y_test is None:
         X_train, X_test, y_train, y_test = train_test_split(
             X_train,
             y_train,
             test_size=kwargs.get("test_size", 0.15),
             shuffle=True,
             random_state=0)
     set_train_rng(seed=1)
     self.X_train, self.y_train = X_train, y_train
     self.X_test, self.y_test = X_test, y_test
     self.model_type, self.n_models = model, n_models
     self.strategy = strategy
     self.iter_base = training_cycles_base
     if self.strategy == "from_baseline":
         # Only set for this strategy; the attribute is absent otherwise.
         self.iter_ensemble = training_cycles_ensemble
     self.filename, self.kdict = filename, kwargs
     if swa or self.strategy == 'swag':
         self.kdict["swa"] = True
         # NOTE: batchnorm used to be disabled here as well
         # (kdict["use_batchnorm"] = False) because of issues combining
         # batchnorm with SWA in PyTorch 1.4.
     self.ensemble_state_dict = {}
示例#7
0
    def __init__(self,
                 trained_model: Type[torch.nn.Module],
                 refine: bool = False,
                 resize: Union[Tuple, List] = None,
                 use_gpu: bool = False,
                 logits: bool = True,
                 seed: int = 1,
                 **kwargs: Union[int, float, bool]) -> None:
        """
        Initializes predictive object

        Args:
            trained_model: trained PyTorch model. It is stored as
                ``self.model``, moved to CUDA when ``use_gpu`` is True and
                CUDA is available, and moved to CPU otherwise.
            refine: stored as ``self.refine``.
            resize: stored as ``self.resize``.
            use_gpu: request GPU placement, stored as ``self.use_gpu``.
            logits: stored as ``self.logits``.
            seed: passed to ``set_train_rng`` when truthy. 0/None skip
                seeding.
            **kwargs: ``nb_classes`` and ``downsampling`` (inferred from a
                single mock forward pass when missing or None), ``d``
                (default None), ``thresh`` (default 0.5) and ``verbose``
                (default True).
        """
        if seed:
            set_train_rng(seed)
        model = trained_model
        self.nb_classes = kwargs.get('nb_classes', None)
        self.downsampling = kwargs.get('downsampling', None)
        if self.nb_classes is None or self.downsampling is None:
            # Attach one hook per top-level child and run a single mock
            # forward pass. Both inferred values are read from the same
            # recorded outputs, instead of hooking and running the model twice.
            # NOTE(review): these hooks are never unregistered; if Hook
            # exposes a removal method, call it once the shapes are read
            # (confirm against Hook's definition).
            hooks = [Hook(module) for module in model._modules.values()]
            mock_forward(model)
            if self.nb_classes is None:
                # Size of axis 1 of the last child's output (presumably the
                # channel axis).
                self.nb_classes = hooks[-1].output.shape[1]
            if self.downsampling is None:
                # Ratio of the largest to the smallest last-axis size among
                # the children's outputs.
                imsize = [hook.output.shape[-1] for hook in hooks]
                self.downsampling = max(imsize) / min(imsize)
        self.model = model
        if use_gpu and torch.cuda.is_available():
            self.model.cuda()
        else:
            self.model.cpu()

        self.resize = resize
        self.logits = logits
        self.refine = refine
        self.d = kwargs.get("d", None)
        self.thresh = kwargs.get("thresh", .5)
        self.use_gpu = use_gpu
        self.verbose = kwargs.get("verbose", True)
示例#8
0
    def __init__(self,
                 X_train: training_data_types,
                 y_train: training_data_types,
                 X_test: training_data_types,
                 y_test: training_data_types,
                 training_cycles: int,
                 model: str = 'dilUnet',
                 IoU: bool = False,
                 seed: int = 1,
                 batch_seed: int = None,
                 **kwargs: Union[int, List, str, bool]) -> None:
        """
        Initialize single model trainer

        Preprocesses the data (and builds torch dataloaders when
        ``full_epoch`` is True), builds or adopts the network, selects the
        loss function, creates an Adam optimizer (lr=1e-3) and records the
        metadata stored with checkpoints.

        Args:
            X_train, y_train, X_test, y_test: data passed to
                ``preprocess_training_data``.
            training_cycles: number of training iterations. When
                ``full_epoch`` is False, this is also the number of random
                batch indices drawn for train and test.
            model: 'dilUnet', 'dilnet', or a non-string object with a
                ``state_dict`` attribute, which is used as the network as-is.
            IoU: if True, initializes ``iou_score`` / ``iou_score_test``.
            seed: passed to ``set_train_rng`` when truthy. It also seeds
                numpy's global RNG when ``batch_seed`` is None.
            batch_seed: seed for numpy's global RNG.
            **kwargs: batch_size (32), full_epoch (False), use_batchnorm
                (True), use_dropouts (False), upsampling ("bilinear"), swa
                (False), perturb_weights (False), with_dilation (True,
                dilUnet only), nb_filters (16 dilUnet / 25 dilnet), layers,
                loss ("ce"), augmentation options, print_loss, filename
                ("./model"), plot_training_history (True).

        Raises:
            NotImplementedError: for an unknown model name, or an
                unsupported combination of loss and number of classes.
        """
        if seed:
            set_train_rng(seed)
        # numpy's global RNG drives the random batch indices drawn below.
        if batch_seed is None:
            np.random.seed(seed)
        else:
            np.random.seed(batch_seed)
        self.batch_size = kwargs.get("batch_size", 32)
        self.full_epoch = kwargs.get("full_epoch", False)
        (self.X_train, self.y_train, self.X_test, self.y_test,
         self.num_classes) = preprocess_training_data(X_train, y_train, X_test,
                                                      y_test, self.batch_size)
        if self.full_epoch:
            self.train_loader, self.test_loader = init_torch_dataloaders(
                self.X_train, self.y_train, self.X_test, self.y_test,
                self.batch_size, self.num_classes)

        use_batchnorm = kwargs.get('use_batchnorm', True)
        use_dropouts = kwargs.get('use_dropouts', False)
        upsampling = kwargs.get('upsampling', "bilinear")

        self.swa = kwargs.get("swa", False)
        if self.swa:
            self.recent_weights = {}
        self.perturb_weights = kwargs.get("perturb_weights", False)
        if self.perturb_weights:
            # Weight perturbation is always run without batchnorm.
            use_batchnorm = False
            if isinstance(self.perturb_weights, bool):
                # Default perturbation settings; e_p depends on whether
                # training runs over full epochs.
                e_p = 1 if self.full_epoch else 50
                self.perturb_weights = {"a": .01, "gamma": 1.5, "e_p": e_p}

        if not isinstance(model, str) and hasattr(model, "state_dict"):
            self.net = model
        elif isinstance(model, str) and model == 'dilUnet':
            with_dilation = kwargs.get('with_dilation', True)
            nb_filters = kwargs.get('nb_filters', 16)
            layers = kwargs.get("layers", [1, 2, 2, 3])
            self.net = dilUnet(self.num_classes,
                               nb_filters,
                               use_dropouts,
                               use_batchnorm,
                               upsampling,
                               with_dilation,
                               layers=layers)
        elif isinstance(model, str) and model == 'dilnet':
            nb_filters = kwargs.get('nb_filters', 25)
            layers = kwargs.get("layers", [1, 3, 3, 3])
            self.net = dilnet(self.num_classes,
                              nb_filters,
                              use_dropouts,
                              use_batchnorm,
                              upsampling,
                              layers=layers)
        else:
            raise NotImplementedError(
                "Currently implemented models are 'dilUnet' and 'dilnet'")
        if torch.cuda.is_available():
            self.net.cuda()
        else:
            warnings.warn("No GPU found. The training can be EXTREMELY slow",
                          UserWarning)
        loss = kwargs.get('loss', "ce")
        if loss == 'dice':
            self.criterion = losses_metrics.dice_loss()
        elif loss == 'focal':
            self.criterion = losses_metrics.focal_loss()
        elif loss == 'ce' and self.num_classes == 1:
            self.criterion = torch.nn.BCEWithLogitsLoss()
        elif loss == 'ce' and self.num_classes > 1:
            # `> 1` (not `> 2`): exactly two classes must use cross-entropy
            # rather than falling through to the "select a loss" error.
            self.criterion = torch.nn.CrossEntropyLoss()
        else:
            raise NotImplementedError(
                "Select Dice loss ('dice'), focal loss ('focal') or"
                " cross-entropy loss ('ce')")
        if not self.full_epoch:
            # Random batch indices for every training cycle.
            self.batch_idx_train = np.random.randint(0, len(self.X_train),
                                                     training_cycles)
            self.batch_idx_test = np.random.randint(0, len(self.X_test),
                                                    training_cycles)
            auglist = [
                "custom_transform", "zoom", "gauss_noise", "jitter",
                "poisson_noise", "contrast", "salt_and_pepper", "blur",
                "resize", "rotation", "background"
            ]
            # Keep only the augmentation options the caller supplied.
            self.augdict = {k: kwargs[k] for k in auglist if k in kwargs}
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=1e-3)
        self.training_cycles = training_cycles
        self.iou = IoU
        if self.iou:
            self.iou_score, self.iou_score_test = [], []
        self.print_loss = kwargs.get("print_loss")
        if self.print_loss is None:
            if not self.full_epoch:
                self.print_loss = 100
            else:
                self.print_loss = 1
        self.filename = kwargs.get("filename", "./model")
        self.plot_training_history = kwargs.get("plot_training_history", True)
        self.train_loss, self.test_loss = [], []
        if isinstance(model, str):
            # Only 'dilUnet' / 'dilnet' reach this point; both defined
            # nb_filters and layers above.
            self.meta_state_dict = {
                'model_type': model,
                'batchnorm': use_batchnorm,
                'dropout': use_dropouts,
                'upsampling': upsampling,
                'nb_filters': nb_filters,
                'layers': layers,
                'nb_classes': self.num_classes,
                'weights': self.net.state_dict()
            }
            if model == 'dilUnet':
                self.meta_state_dict["with_dilation"] = with_dilation
        else:
            self.meta_state_dict = {
                'nb_classes': self.num_classes,
                'weights': self.net.state_dict()
            }
示例#9
0
 def _reset_rng(self, seed: int) -> None:
     """
     Re-seed the training random number generators.

     Thin delegation to ``set_train_rng``, which (re)seeds the PyTorch and
     NumPy generators.

     Args:
         seed: value handed straight to ``set_train_rng``.
     """
     set_train_rng(seed)