def __init__(
        self,
        betas=(0.5, 0.999),
        criterion_class='torch.nn.MSELoss',
        init_weights=True,
        lr=0.001,
        nn_class='fnet.nn_modules.fnet_nn_3d.Net',
        nn_kwargs=None,
        scheduler=None,
        weight_decay=0,
        gpu_ids=-1,
):
    """Configure the model and build its network, optimizer and scheduler.

    Parameters
    ----------
    betas
        Adam beta coefficients, used when ``_init_model`` builds the optimizer.
    criterion_class
        Dotted path of the loss class; it is instantiated with no arguments.
    init_weights
        If True, ``_init_model`` applies ``_weights_init`` to the network.
    lr
        Learning rate for the optimizer.
    nn_class
        Dotted path of the network class.
    nn_kwargs
        Keyword arguments for ``nn_class``. ``None`` (the default) means an
        empty dict; a fresh dict is created per instance.
    scheduler
        Optional scheduler spec consumed by ``_init_model``.
    weight_decay
        Weight decay forwarded to ``get_per_param_options``.
    gpu_ids
        An int or a list of ints. If the first id is negative the CPU is
        used, otherwise that CUDA device.
    """
    # A literal ``{}`` default would be one dict shared by every instance
    # (and by the recorded constructor kwargs). Rebind the parameter itself
    # instead of introducing a new local, so ``get_args`` below still
    # records a dict for ``nn_kwargs``, exactly as before.
    # NOTE(review): assumes ``get_args`` inspects this frame's arguments.
    if nn_kwargs is None:
        nn_kwargs = {}
    self.betas = betas
    self.criterion = str_to_class(criterion_class)()
    self.gpu_ids = [gpu_ids] if isinstance(gpu_ids, int) else gpu_ids
    self.init_weights = init_weights
    self.lr = lr
    self.nn_class = nn_class
    self.nn_kwargs = nn_kwargs
    self.scheduler = scheduler
    self.weight_decay = weight_decay
    self.count_iter = 0
    self.device = (
        torch.device('cuda', self.gpu_ids[0])
        if self.gpu_ids[0] >= 0
        else torch.device('cpu')
    )
    self.optimizer = None
    self._init_model()
    # Record the constructor arguments (minus ``self``) so the model can be
    # re-instantiated later from a save file.
    self.fnet_model_kwargs, self.fnet_model_posargs = get_args()
    self.fnet_model_kwargs.pop('self')
def _init_model(self):
    """Build the network, the Adam optimizer and the optional LR scheduler.

    Reads ``nn_class``, ``nn_kwargs``, ``init_weights``, ``device``, ``lr``,
    ``betas``, ``weight_decay`` and ``scheduler`` from ``self``. It sets
    ``self.net`` and ``self.optimizer``. If a scheduler spec was given,
    ``self.scheduler`` is replaced with the constructed scheduler object.

    Supported scheduler specs are ``('snapshot', period)`` and
    ``('step', step_size)``.

    Raises
    ------
    NotImplementedError
        If the scheduler spec names an unsupported scheduler type.
    """
    self.net = str_to_class(self.nn_class)(**self.nn_kwargs)
    if self.init_weights:
        self.net.apply(_weights_init)
    self.net.to(self.device)
    self.optimizer = torch.optim.Adam(
        get_per_param_options(self.net, wd=self.weight_decay),
        lr=self.lr,
        betas=self.betas,
    )
    if self.scheduler is not None:
        # NOTE(review): on entry ``self.scheduler`` is a (name, value) spec.
        # It is overwritten with a scheduler object below, so calling this
        # method a second time on the same instance would fail when indexing
        # it. Confirm it is only called once per instance.
        if self.scheduler[0] == 'snapshot':
            period = self.scheduler[1]
            # Cosine-shaped LR multiplier. It is 1.0 at the start of each
            # ``period``-step cycle, decays toward 0.01, then restarts.
            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                self.optimizer,
                lambda x: (
                    0.01 + (1 - 0.01) * (
                        0.5 + 0.5 * math.cos(math.pi * (x % period) / period)
                    )
                ),
            )
        elif self.scheduler[0] == 'step':
            step_size = self.scheduler[1]
            # Decay by StepLR's default gamma every ``step_size`` steps.
            self.scheduler = torch.optim.lr_scheduler.StepLR(
                self.optimizer, step_size)
        else:
            raise NotImplementedError(
                f'Unsupported scheduler type: {self.scheduler[0]!r}'
            )
def _load_model(path_model: str) -> Model:
    """Rebuild a model from its save file without restoring optimizer state.

    The save file must contain the ``fnet_model_class`` and
    ``fnet_model_kwargs`` entries. They are used to construct the model,
    whose state is then restored via ``load_state(..., no_optim=True)``.
    """
    saved = torch.load(path_model)
    class_path, ctor_kwargs = saved['fnet_model_class'], saved['fnet_model_kwargs']
    restored = str_to_class(class_path)(**ctor_kwargs)
    restored.load_state(saved, no_optim=True)
    return restored
def load_model(
    path_model: str,
    no_optim: bool = False,
    checkpoint: Optional[str] = None,
    path_options: Optional[str] = None,
) -> Model:
    """Load a saved FnetModel.

    Parameters
    ----------
    path_model
        Path to model as a directory or .p file. For a directory, the model
        is read from ``model.p`` inside it. When ``checkpoint`` is given, it
        is read from the directory's ``checkpoints`` subdirectory instead.
    no_optim
        Set to not load the model optimizer; forwarded to ``load_state``.
    checkpoint
        Optional string that identifies a model checkpoint. The first ``.p``
        file (in sorted order) in ``<path_model>/checkpoints`` whose basename
        contains this string is loaded. Only used when ``path_model`` is a
        directory.
    path_options
        Path to training options json. For legacy saved models where the
        FnetModel class/kwargs are not included in the model save file.

    Returns
    -------
    Model
        Loaded model.

    Raises
    ------
    ValueError
        If ``path_model``, the default ``model.p``, the ``checkpoints``
        directory, or a matching checkpoint cannot be found.
    """
    if not os.path.exists(path_model):
        raise ValueError(f'Model path does not exist: {path_model}')
    if os.path.isdir(path_model):
        if checkpoint is None:
            path_model = os.path.join(path_model, 'model.p')
            if not os.path.exists(path_model):
                raise ValueError(f'Default model not found: {path_model}')
        else:
            path_cp_dir = os.path.join(path_model, 'checkpoints')
            # Report a missing checkpoints directory with the same ValueError
            # as the other lookup failures, rather than letting os.scandir
            # raise FileNotFoundError.
            if not os.path.isdir(path_cp_dir):
                raise ValueError(
                    f'Model checkpoint directory not found: {path_cp_dir}'
                )
            with os.scandir(path_cp_dir) as entries:
                paths = sorted(e.path for e in entries if e.path.endswith('.p'))
            # NOTE(review): this is a substring match, so e.g. '100' also
            # matches '1000'. The first match in sorted order wins.
            for path in paths:
                if checkpoint in os.path.basename(path):
                    path_model = path
                    break
            else:
                raise ValueError(f'Model checkpoint not found: {checkpoint}')
    state = torch.load(path_model)
    if 'fnet_model_class' not in state:
        # Legacy save: recover the model class/kwargs from the training options.
        if path_options is not None:
            with open(path_options, 'r') as fi:
                train_options = json.load(fi)
            if 'fnet_model_class' in train_options:
                state['fnet_model_class'] = train_options['fnet_model_class']
                state['fnet_model_kwargs'] = train_options['fnet_model_kwargs']
    fnet_model_class = state.get('fnet_model_class', 'fnet.models.Model')
    fnet_model_kwargs = state.get('fnet_model_kwargs', {})
    model = str_to_class(fnet_model_class)(**fnet_model_kwargs)
    model.load_state(state, no_optim)
    return model
def load_model(
    path_model: str,
    no_optim: bool = False,
    checkpoint: Optional[str] = None,
    path_options: Optional[str] = None,
) -> Model:
    """Load a saved FnetModel.

    Parameters
    ----------
    path_model
        Path to model as a directory or .p file. For a directory, either its
        ``model.p`` is loaded or, when ``checkpoint`` is given, the path
        returned by ``_find_model_checkpoint``.
    no_optim
        Set to not load the model optimizer; forwarded to ``load_state``.
    checkpoint
        Optional string that identifies a model checkpoint. Only used when
        ``path_model`` is a directory.
    path_options
        Path to training options json. For legacy saved models where the
        FnetModel class/kwargs are not included in the model save file.

    Returns
    -------
    Model
        Loaded model.
    """
    if not os.path.exists(path_model):
        raise ValueError(f"Model path does not exist: {path_model}")

    # Resolve a directory to the concrete save file to read.
    if os.path.isdir(path_model):
        if checkpoint is not None:
            path_model = _find_model_checkpoint(path_model, checkpoint)
        else:
            path_model = os.path.join(path_model, "model.p")
            if not os.path.exists(path_model):
                raise ValueError(f"Default model not found: {path_model}")

    state = torch.load(path_model)

    # Legacy saves lack the class/kwargs entries; borrow them from the
    # training options when those provide them.
    if "fnet_model_class" not in state and path_options is not None:
        with open(path_options, "r") as fh:
            options = json.load(fh)
        if "fnet_model_class" in options:
            for key in ("fnet_model_class", "fnet_model_kwargs"):
                state[key] = options[key]

    class_path = state.get("fnet_model_class", "fnet.models.Model")
    ctor_kwargs = state.get("fnet_model_kwargs", {})
    loaded = str_to_class(class_path)(**ctor_kwargs)
    loaded.load_state(state, no_optim)
    return loaded
def load_model(
    path_model: str,
    no_optim: bool = False,
    path_options: Optional[str] = None,
) -> Model:
    """Load a saved FnetModel, or an ensemble when given a directory.

    Parameters
    ----------
    path_model
        Path to model. A directory is treated as holding an ensemble of
        models and is passed to ``FnetEnsemble``.
    no_optim
        Set to not load the model optimizer; forwarded to ``load_state``.
    path_options
        Path to training options json. For legacy saved models where the
        FnetModel class/kwargs are not included in the model save file.

    Returns
    -------
    Model or FnetEnsemble
        Loaded model.
    """
    if os.path.isdir(path_model):
        return FnetEnsemble(path_model)

    state = torch.load(path_model)
    needs_options = 'fnet_model_class' not in state and path_options is not None
    if needs_options:
        with open(path_options, 'r') as fh:
            options = json.load(fh)
        if 'fnet_model_class' in options:
            for key in ('fnet_model_class', 'fnet_model_kwargs'):
                state[key] = options[key]

    class_path = state.get('fnet_model_class', 'fnet.models.Model')
    ctor_kwargs = state.get('fnet_model_kwargs', {})
    loaded = str_to_class(class_path)(**ctor_kwargs)
    loaded.load_state(state, no_optim)
    return loaded
def get_dataloader(args, n_iter_remaining, validation=False):
    """Build a DataLoader over a BufferedPatchDataset.

    Parameters
    ----------
    args
        Options namespace. Reads ``dataset_kwargs``, ``path_dataset_csv``,
        ``path_dataset_val_csv``, ``dataset_class``, ``bpds_kwargs`` and
        ``batch_size``. The dicts on ``args`` are deep-copied, never mutated.
    n_iter_remaining
        Remaining training iterations. For training, the patch dataset is
        sized to ``n_iter_remaining * args.batch_size`` patches.
    validation
        If True, use ``args.path_dataset_val_csv``. The buffer size is set to
        the whole dataset, ``buffer_switch_frequency`` is set to -1, and
        ``4 * args.batch_size`` patches are drawn.

    Returns
    -------
    torch.utils.data.DataLoader
        Loader with batch size ``args.batch_size``.

    Raises
    ------
    ValueError
        If the dataset csv is given both via ``args`` and inside
        ``args.dataset_kwargs``, or if ``args.bpds_kwargs`` already
        contains ``dataset``.
    """
    dataset_kwargs = copy.deepcopy(args.dataset_kwargs)
    path_csv = args.path_dataset_val_csv if validation else args.path_dataset_csv
    if path_csv is not None:
        # Explicit raise instead of assert: asserts vanish under ``python -O``.
        if 'path_csv' in dataset_kwargs:
            raise ValueError('dataset csv specified twice')
        dataset_kwargs['path_csv'] = path_csv
    ds = str_to_class(args.dataset_class)(**dataset_kwargs)
    bpds_kwargs = copy.deepcopy(args.bpds_kwargs)
    if 'dataset' in bpds_kwargs:
        raise ValueError(
            "bpds_kwargs must not contain 'dataset'; it is set by get_dataloader"
        )
    if not validation:
        bpds_kwargs['npatches'] = n_iter_remaining * args.batch_size
    else:
        bpds_kwargs['buffer_size'] = len(ds)
        # NOTE(review): -1 presumably disables buffer switching; confirm
        # against BufferedPatchDataset.
        bpds_kwargs['buffer_switch_frequency'] = -1
        bpds_kwargs['npatches'] = 4 * args.batch_size
    print(bpds_kwargs)
    bpds = fnet.data.BufferedPatchDataset(dataset=ds, **bpds_kwargs)
    dataloader = torch.utils.data.DataLoader(
        bpds,
        batch_size=args.batch_size,
    )
    return dataloader
def __init__(
        self,
        betas=(0.5, 0.999),
        criterion_class='torch.nn.MSELoss',
        init_weights=True,
        lr=0.001,
        nn_class='fnet.nn_modules.fnet_nn_3d.Net',
        nn_kwargs=None,
        nn_module=None,
        scheduler=None,
        weight_decay=0,
        gpu_ids=-1,
):
    """Configure the model and build its network via ``_init_model``.

    Parameters
    ----------
    betas
        Adam beta coefficients, stored for ``_init_model``.
    criterion_class
        Dotted path of the loss class; it is instantiated with no arguments.
    init_weights
        Stored flag controlling weight initialization in ``_init_model``.
    lr
        Learning rate.
    nn_class
        Dotted path of the network class.
    nn_kwargs
        Keyword arguments for ``nn_class``. ``None`` (the default) means an
        empty dict; a fresh dict is created per instance.
    nn_module
        Legacy option found in old saves. If given, ``nn_class`` is
        overridden with ``nn_module + '.Net'``.
    scheduler
        Optional scheduler spec consumed by ``_init_model``.
    weight_decay
        Weight decay.
    gpu_ids
        An int or a list of ints. If the first id is negative the CPU is
        used, otherwise that CUDA device.
    """
    # A literal ``{}`` default would be one dict shared by every instance
    # (and by the recorded constructor kwargs). Rebind the parameter itself
    # instead of adding a new local, so ``get_args`` below still records a
    # dict for ``nn_kwargs``, exactly as before.
    # NOTE(review): assumes ``get_args`` inspects this frame's arguments.
    if nn_kwargs is None:
        nn_kwargs = {}
    self.betas = betas
    self.criterion = str_to_class(criterion_class)()
    self.gpu_ids = [gpu_ids] if isinstance(gpu_ids, int) else gpu_ids
    self.init_weights = init_weights
    self.lr = lr
    self.nn_class = nn_class
    self.nn_kwargs = nn_kwargs
    self.scheduler = scheduler
    self.weight_decay = weight_decay

    # *** Legacy support ***
    # self.nn_module might be specified in legacy saves.
    # If so, override self.nn_class
    if nn_module is not None:
        self.nn_class = nn_module + '.Net'

    self.count_iter = 0
    self.device = (
        torch.device('cuda', self.gpu_ids[0])
        if self.gpu_ids[0] >= 0
        else torch.device('cpu')
    )
    self._init_model()
    # Record the constructor arguments (minus ``self``) so the model can be
    # re-instantiated later from a save file.
    self.fnet_model_kwargs, self.fnet_model_posargs = get_args()
    self.fnet_model_kwargs.pop('self')
def load_model(
    path_model: str,
    no_optim: bool = False,
    path_options: Optional[str] = None,
):
    """Load a saved FnetModel from a single save file.

    Parameters
    ----------
    path_model
        Path to the file in which the model is saved.
    no_optim
        Set to not load the model optimizer; forwarded to ``load_state``.
    path_options
        Path to training options json. It is only consulted for legacy saves
        that lack ``fnet_model_class``. If the json provides that key, its
        ``fnet_model_class`` and ``fnet_model_kwargs`` are copied into the
        loaded state.

    Returns
    -------
    FnetModel
        Loaded FnetModel instance. Falls back to ``fnet.models.Model`` and
        empty kwargs when the save does not name them.
    """
    state = torch.load(path_model)
    is_legacy = 'fnet_model_class' not in state
    if is_legacy and path_options is not None:
        with open(path_options, 'r') as fh:
            options = json.load(fh)
        if 'fnet_model_class' in options:
            for key in ('fnet_model_class', 'fnet_model_kwargs'):
                state[key] = options[key]
    class_path = state.get('fnet_model_class', 'fnet.models.Model')
    ctor_kwargs = state.get('fnet_model_kwargs', {})
    loaded = str_to_class(class_path)(**ctor_kwargs)
    loaded.load_state(state, no_optim)
    return loaded
def load_or_init_model(path_model: str, path_options: str):
    """Load the saved model if it exists; otherwise initialize a new one.

    Parameters
    ----------
    path_model
        Path to saved model. If it exists, loading is delegated to
        ``load_model``, which receives ``path_options`` as well.
    path_options
        Path to json where model training options are saved. A new model is
        built from its ``fnet_model_class`` and ``fnet_model_kwargs``.

    Returns
    -------
    FnetModel
        Loaded or new FnetModel instance.
    """
    if os.path.exists(path_model):
        return load_model(path_model, path_options=path_options)
    with open(path_options, 'r') as fh:
        options = json.load(fh)
    print('DEBUG: Initializing new model!')
    class_path = options['fnet_model_class']
    ctor_kwargs = options['fnet_model_kwargs']
    return str_to_class(class_path)(**ctor_kwargs)