Ejemplo n.º 1
0
    def __init__(
        self,
        betas=(0.5, 0.999),
        criterion_class='torch.nn.MSELoss',
        init_weights=True,
        lr=0.001,
        nn_class='fnet.nn_modules.fnet_nn_3d.Net',
        nn_kwargs={},
        scheduler=None,
        weight_decay=0,
        gpu_ids=-1,
    ):
        """Configure the model and build its network/optimizer.

        Parameters
        ----------
        betas
            Adam ``betas`` coefficients.
        criterion_class
            Dotted path of the loss class; instantiated with no arguments.
        init_weights
            Whether ``_init_model`` should apply custom weight initialization.
        lr
            Optimizer learning rate.
        nn_class
            Dotted path of the network class.
        nn_kwargs
            Keyword arguments forwarded to ``nn_class``. Copied, so later
            changes to ``self.nn_kwargs`` never touch the caller's dict or
            the shared default.
        scheduler
            Optional learning-rate scheduler spec consumed by ``_init_model``.
        weight_decay
            Weight decay used when building the optimizer parameter groups.
        gpu_ids
            A single int or a sequence of ints. The first entry selects the
            CUDA device; a negative first entry selects the CPU.
        """
        # NOTE: no extra local variables are created in this method on
        # purpose. ``get_args()`` below returns a mapping that contains
        # 'self', so it presumably inspects this frame's arguments -- TODO
        # confirm whether stray locals would end up in the saved kwargs.
        self.betas = betas
        self.criterion = str_to_class(criterion_class)()
        self.gpu_ids = [gpu_ids] if isinstance(gpu_ids, int) else gpu_ids
        self.init_weights = init_weights
        self.lr = lr
        self.nn_class = nn_class
        # Copy instead of aliasing: the default ``{}`` is a single dict
        # shared by every call, so mutating it through one instance would
        # leak into all later instances. The parameter itself is left
        # untouched so whatever ``get_args()`` records is unchanged.
        self.nn_kwargs = dict(nn_kwargs)
        self.scheduler = scheduler
        self.weight_decay = weight_decay

        self.count_iter = 0
        # First GPU id chooses the device; negative means run on CPU.
        self.device = (torch.device('cuda', self.gpu_ids[0])
                       if self.gpu_ids[0] >= 0 else torch.device('cpu'))
        self.optimizer = None
        self._init_model()
        self.fnet_model_kwargs, self.fnet_model_posargs = get_args()
        # The instance itself must not be stored as a constructor kwarg.
        self.fnet_model_kwargs.pop('self')
Ejemplo n.º 2
0
 def _init_model(self):
     """Create the network, its Adam optimizer, and an optional LR scheduler.

     Reads ``nn_class``, ``nn_kwargs``, ``init_weights``, ``device``,
     ``lr``, ``betas``, ``weight_decay`` and ``scheduler`` from the
     instance and sets ``self.net`` and ``self.optimizer``.

     ``self.scheduler`` may be ``None`` or an indexable spec. Its first
     item names the scheduler and its second item parameterizes it:

     * ``'snapshot'`` -- ``LambdaLR`` whose multiplier follows a cosine
       that restarts every ``scheduler[1]`` scheduler steps, going from
       1.0 down toward 0.01;
     * ``'step'`` -- ``StepLR`` with ``step_size=scheduler[1]``.

     The spec is overwritten by the constructed scheduler object. A second
     call with a scheduler configured would therefore try to index that
     object instead of a spec.

     Raises
     ------
     NotImplementedError
         If the scheduler name is not recognized.
     """
     self.net = str_to_class(self.nn_class)(**self.nn_kwargs)
     if self.init_weights:
         self.net.apply(_weights_init)
     self.net.to(self.device)
     self.optimizer = torch.optim.Adam(
         get_per_param_options(self.net, wd=self.weight_decay),
         lr=self.lr,
         betas=self.betas,
     )
     if self.scheduler is None:
         return
     name = self.scheduler[0]
     if name == 'snapshot':
         period = self.scheduler[1]
         # Warm-restart cosine: the multiplier is 1.0 at the start of each
         # period and decays toward the 0.01 floor by its end.
         self.scheduler = torch.optim.lr_scheduler.LambdaLR(
             self.optimizer,
             lambda x: (0.01 + (1 - 0.01) *
                        (0.5 + 0.5 * math.cos(math.pi *
                                              (x % period) / period))),
         )
     elif name == 'step':
         step_size = self.scheduler[1]
         self.scheduler = torch.optim.lr_scheduler.StepLR(
             self.optimizer, step_size)
     else:
         # Name the bad value so misconfigured options are easy to spot.
         raise NotImplementedError(f'Unsupported scheduler: {name!r}')
Ejemplo n.º 3
0
def _load_model(path_model: str) -> Model:
    """Load a saved model from ``path_model`` with ``no_optim=True``.

    The model class and constructor kwargs are read from the save file.
    Legacy saves that lack them fall back to ``'fnet.models.Model'`` with
    no kwargs instead of failing with ``KeyError``.
    """
    state = torch.load(path_model)
    fnet_model_class = state.get('fnet_model_class', 'fnet.models.Model')
    fnet_model_kwargs = state.get('fnet_model_kwargs', {})
    model = str_to_class(fnet_model_class)(**fnet_model_kwargs)
    model.load_state(state, no_optim=True)
    return model
Ejemplo n.º 4
0
def load_model(
    path_model: str,
    no_optim: bool = False,
    checkpoint: Optional[str] = None,
    path_options: Optional[str] = None,
) -> Model:
    """Load a saved FnetModel.

    Parameters
    ----------
    path_model
        Path to model as a directory or .p file. For a directory, the model
        is read from ``model.p`` inside it, or from its ``checkpoints``
        subdirectory when ``checkpoint`` is given.
    no_optim
        Set to skip loading the optimizer state (forwarded to
        ``load_state``).
    checkpoint
        Optional string that identifies a model checkpoint. The first
        ``.p`` file in ``<path_model>/checkpoints`` (sorted by path) whose
        basename contains this string is loaded. Only used when
        ``path_model`` is a directory.
    path_options
        Path to training options json. For legacy saved models where the
        FnetModel class/kwargs are not included in the model save file.

    Returns
    -------
    Model
        Loaded model. When neither the save file nor the options json name
        a model class, ``fnet.models.Model`` with no kwargs is used.

    Raises
    ------
    ValueError
        If ``path_model``, the default ``model.p``, the ``checkpoints``
        directory, or a matching checkpoint does not exist.

    """
    if not os.path.exists(path_model):
        raise ValueError(f'Model path does not exist: {path_model}')
    if os.path.isdir(path_model):
        if checkpoint is None:
            path_model = os.path.join(path_model, 'model.p')
            if not os.path.exists(path_model):
                raise ValueError(f'Default model not found: {path_model}')
        else:
            dir_checkpoints = os.path.join(path_model, 'checkpoints')
            if not os.path.isdir(dir_checkpoints):
                raise ValueError(
                    f'Checkpoint directory not found: {dir_checkpoints}'
                )
            # Context manager releases the directory handle promptly.
            with os.scandir(dir_checkpoints) as entries:
                paths = sorted(
                    e.path for e in entries if e.path.endswith('.p')
                )
            for path in paths:
                if checkpoint in os.path.basename(path):
                    path_model = path
                    break
            else:
                raise ValueError(f'Model checkpoint not found: {checkpoint}')
    state = torch.load(path_model)
    # Legacy saves do not record the model class; recover it from the
    # training options when they are available.
    if 'fnet_model_class' not in state and path_options is not None:
        with open(path_options, 'r', encoding='utf-8') as fi:
            train_options = json.load(fi)
        if 'fnet_model_class' in train_options:
            state['fnet_model_class'] = train_options['fnet_model_class']
            state['fnet_model_kwargs'] = train_options.get(
                'fnet_model_kwargs', {}
            )
    fnet_model_class = state.get('fnet_model_class', 'fnet.models.Model')
    fnet_model_kwargs = state.get('fnet_model_kwargs', {})
    model = str_to_class(fnet_model_class)(**fnet_model_kwargs)
    model.load_state(state, no_optim)
    return model
Ejemplo n.º 5
0
def load_model(
    path_model: str,
    no_optim: bool = False,
    checkpoint: Optional[str] = None,
    path_options: Optional[str] = None,
) -> Model:
    """Load a saved FnetModel.

    Parameters
    ----------
    path_model
        Path to model as a directory or .p file. For a directory, the model
        is read from ``model.p`` inside it unless ``checkpoint`` is given.
    no_optim
        Set to skip loading the optimizer state (forwarded to
        ``load_state``).
    checkpoint
        Optional string that identifies a model checkpoint; resolved via
        ``_find_model_checkpoint`` when ``path_model`` is a directory.
    path_options
        Path to training options json. For legacy saved models where the
        FnetModel class/kwargs are not included in the model save file.

    Returns
    -------
    Model
        Loaded model. When neither the save file nor the options json name
        a model class, ``fnet.models.Model`` with no kwargs is used.

    Raises
    ------
    ValueError
        If ``path_model`` or the default ``model.p`` does not exist. Errors
        for an unresolvable checkpoint come from ``_find_model_checkpoint``.

    """
    if not os.path.exists(path_model):
        raise ValueError(f"Model path does not exist: {path_model}")
    if os.path.isdir(path_model):
        if checkpoint is None:
            path_model = os.path.join(path_model, "model.p")
            if not os.path.exists(path_model):
                raise ValueError(f"Default model not found: {path_model}")
        else:
            path_model = _find_model_checkpoint(path_model, checkpoint)
    state = torch.load(path_model)
    # Legacy saves do not record the model class; recover it from the
    # training options when they are available.
    if "fnet_model_class" not in state and path_options is not None:
        with open(path_options, "r", encoding="utf-8") as fi:
            train_options = json.load(fi)
        if "fnet_model_class" in train_options:
            state["fnet_model_class"] = train_options["fnet_model_class"]
            state["fnet_model_kwargs"] = train_options.get(
                "fnet_model_kwargs", {}
            )
    fnet_model_class = state.get("fnet_model_class", "fnet.models.Model")
    fnet_model_kwargs = state.get("fnet_model_kwargs", {})
    model = str_to_class(fnet_model_class)(**fnet_model_kwargs)
    model.load_state(state, no_optim)
    return model
Ejemplo n.º 6
0
def load_model(
    path_model: str,
    no_optim: bool = False,
    path_options: Optional[str] = None,
) -> Model:
    """Load a saved FnetModel, or an ensemble from a directory.

    Parameters
    ----------
    path_model
        Path to model. If path is a directory, assumes directory contains an
        ensemble of models and returns ``FnetEnsemble(path_model)``.
    no_optim
        Set to skip loading the optimizer state (forwarded to
        ``load_state``). Not used for ensembles.
    path_options
        Path to training options json. For legacy saved models where the
        FnetModel class/kwargs are not included in the model save file.

    Returns
    -------
    Model or FnetEnsemble
        Loaded model. When neither the save file nor the options json name
        a model class, ``fnet.models.Model`` with no kwargs is used.

    """
    if os.path.isdir(path_model):
        return FnetEnsemble(path_model)
    state = torch.load(path_model)
    # Legacy saves do not record the model class; recover it from the
    # training options when they are available.
    if 'fnet_model_class' not in state and path_options is not None:
        with open(path_options, 'r', encoding='utf-8') as fi:
            train_options = json.load(fi)
        if 'fnet_model_class' in train_options:
            state['fnet_model_class'] = train_options['fnet_model_class']
            state['fnet_model_kwargs'] = train_options.get(
                'fnet_model_kwargs', {}
            )
    fnet_model_class = state.get('fnet_model_class', 'fnet.models.Model')
    fnet_model_kwargs = state.get('fnet_model_kwargs', {})
    model = str_to_class(fnet_model_class)(**fnet_model_kwargs)
    model.load_state(state, no_optim)
    return model
Ejemplo n.º 7
0
def get_dataloader(args, n_iter_remaining, validation=False):
    """Build a DataLoader over a BufferedPatchDataset.

    Parameters
    ----------
    args
        Training options object. Reads ``dataset_kwargs``,
        ``path_dataset_csv`` / ``path_dataset_val_csv``, ``dataset_class``,
        ``bpds_kwargs`` and ``batch_size``. It is not modified.
    n_iter_remaining
        Training iterations left. The training patch count is
        ``n_iter_remaining * batch_size``. Ignored when ``validation`` is
        True.
    validation
        If True, use the validation csv, buffer the whole dataset, set
        ``buffer_switch_frequency`` to -1 (presumably "never switch" --
        TODO confirm), and draw ``4 * batch_size`` patches.

    Returns
    -------
    torch.utils.data.DataLoader

    Raises
    ------
    ValueError
        If the csv path is given both via ``args`` and ``dataset_kwargs``,
        or if ``bpds_kwargs`` already contains a ``'dataset'`` entry.
    """
    # Deep copies so the edits below never mutate the caller's ``args``.
    dataset_kwargs = copy.deepcopy(args.dataset_kwargs)
    path_csv = (
        args.path_dataset_val_csv if validation else args.path_dataset_csv
    )
    if path_csv is not None:
        # Explicit raise (not assert) so the check survives ``python -O``.
        if 'path_csv' in dataset_kwargs:
            raise ValueError('dataset csv specified twice')
        dataset_kwargs['path_csv'] = path_csv
    ds = str_to_class(args.dataset_class)(**dataset_kwargs)
    bpds_kwargs = copy.deepcopy(args.bpds_kwargs)
    if 'dataset' in bpds_kwargs:
        raise ValueError("bpds_kwargs must not contain 'dataset'")
    if not validation:
        bpds_kwargs['npatches'] = n_iter_remaining * args.batch_size
    else:
        bpds_kwargs['buffer_size'] = len(ds)
        bpds_kwargs['buffer_switch_frequency'] = -1
        bpds_kwargs['npatches'] = 4 * args.batch_size
    print(bpds_kwargs)
    bpds = fnet.data.BufferedPatchDataset(dataset=ds, **bpds_kwargs)
    dataloader = torch.utils.data.DataLoader(
        bpds,
        batch_size=args.batch_size,
    )
    return dataloader
Ejemplo n.º 8
0
    def __init__(
            self,
            betas=(0.5, 0.999),
            criterion_class='torch.nn.MSELoss',
            init_weights=True,
            lr=0.001,
            nn_class='fnet.nn_modules.fnet_nn_3d.Net',
            nn_kwargs={},
            nn_module=None,
            scheduler=None,
            weight_decay=0,
            gpu_ids=-1,
    ):
        """Configure the model and build its network/optimizer.

        Parameters
        ----------
        betas
            Adam ``betas`` coefficients.
        criterion_class
            Dotted path of the loss class; instantiated with no arguments.
        init_weights
            Whether ``_init_model`` should apply custom weight initialization.
        lr
            Optimizer learning rate.
        nn_class
            Dotted path of the network class.
        nn_kwargs
            Keyword arguments forwarded to the network class. Copied, so
            later changes to ``self.nn_kwargs`` never touch the caller's
            dict or the shared default.
        nn_module
            Legacy option. If given, overrides ``nn_class`` with
            ``nn_module + '.Net'``.
        scheduler
            Optional learning-rate scheduler spec consumed by ``_init_model``.
        weight_decay
            Weight decay used when building the optimizer.
        gpu_ids
            A single int or a sequence of ints. The first entry selects the
            CUDA device; a negative first entry selects the CPU.
        """
        # NOTE: no extra local variables are created in this method on
        # purpose. ``get_args()`` below returns a mapping that contains
        # 'self', so it presumably inspects this frame's arguments -- TODO
        # confirm whether stray locals would end up in the saved kwargs.
        self.betas = betas
        self.criterion = str_to_class(criterion_class)()
        self.gpu_ids = [gpu_ids] if isinstance(gpu_ids, int) else gpu_ids
        self.init_weights = init_weights
        self.lr = lr
        self.nn_class = nn_class
        # Copy instead of aliasing the single shared ``{}`` default, which
        # would otherwise leak mutations across instances. The parameter
        # itself is left untouched so ``get_args()`` sees the same value.
        self.nn_kwargs = dict(nn_kwargs)
        self.scheduler = scheduler
        self.weight_decay = weight_decay

        # *** Legacy support ***
        # self.nn_module might be specified in legacy saves.
        # If so, override self.nn_class
        if nn_module is not None:
            self.nn_class = nn_module + '.Net'
        self.count_iter = 0
        # First GPU id chooses the device; negative means run on CPU.
        self.device = (
            torch.device('cuda', self.gpu_ids[0])
            if self.gpu_ids[0] >= 0
            else torch.device('cpu')
        )
        self._init_model()
        self.fnet_model_kwargs, self.fnet_model_posargs = get_args()
        # The instance itself must not be stored as a constructor kwarg.
        self.fnet_model_kwargs.pop('self')
Ejemplo n.º 9
0
def load_model(
    path_model: str,
    no_optim: bool = False,
    path_options: Optional[str] = None,
):
    """Load a saved FnetModel.

    Parameters
    ----------
    path_model
        Path to file in which saved model is saved.
    no_optim
        Set to skip loading the optimizer state (forwarded to
        ``load_state``).
    path_options
        Path to training options json. For legacy saved models where the
        FnetModel class/kwargs are not included in the model save file.

    Returns
    -------
    FnetModel
        Loaded FnetModel instance. When neither the save file nor the
        options json name a model class, ``fnet.models.Model`` with no
        kwargs is used.

    """
    state = torch.load(path_model)
    # Legacy saves do not record the model class; recover it from the
    # training options when they are available.
    if 'fnet_model_class' not in state and path_options is not None:
        with open(path_options, 'r', encoding='utf-8') as fi:
            train_options = json.load(fi)
        if 'fnet_model_class' in train_options:
            state['fnet_model_class'] = train_options['fnet_model_class']
            state['fnet_model_kwargs'] = train_options.get(
                'fnet_model_kwargs', {}
            )
    fnet_model_class = state.get('fnet_model_class', 'fnet.models.Model')
    fnet_model_kwargs = state.get('fnet_model_kwargs', {})
    model = str_to_class(fnet_model_class)(**fnet_model_kwargs)
    model.load_state(state, no_optim)
    return model
Ejemplo n.º 10
0
def load_or_init_model(path_model: str, path_options: str):
    """Return the model saved at ``path_model``, or a new one if absent.

    Parameters
    ----------
    path_model
        Location of the saved model file.
    path_options
        JSON file of training options. If no saved model exists, its
        ``fnet_model_class`` and ``fnet_model_kwargs`` entries build a new
        model. Otherwise it is forwarded to ``load_model``.

    Returns
    -------
    FnetModel
        The loaded model, or a newly constructed one.

    """
    if os.path.exists(path_model):
        return load_model(path_model, path_options=path_options)
    # Nothing saved yet: construct a fresh model from the training options.
    with open(path_options, 'r') as fi:
        options = json.load(fi)
    print('DEBUG: Initializing new model!')
    cls_path = options['fnet_model_class']
    init_kwargs = options['fnet_model_kwargs']
    return str_to_class(cls_path)(**init_kwargs)