Example #1
0
def _fe_compile(model: Model, optimizer_fn: Union[str, Scheduler, Callable,
                                                  None],
                weight: Union[str, None], name: str, framework: str) -> Model:
    """A function to bundle models with their optimizers.

    Args:
        model: The model to be bundled.
        optimizer_fn: The optimizer definition to be associated with the given `model`. Either a single definition
            (str, callable, or None), or an `EpochScheduler` / `RepeatScheduler` whose entries are such definitions.
            Every definition is turned into an optimizer via `_build_optimizer`; scheduler entries are replaced in
            place, and the scheduler's value for epoch 1 becomes `model.current_optimizer`.
        weight: A path to weights to be associated with the `model`. A value starting with `GOOGLE_DRIVE_URL` is
            first downloaded (once) into a fresh temporary directory and loaded from there.
        name: The name of the model.
        framework: Which backend framework should be associated with this model (either 'tf' or 'torch').

    Returns:
        The `model` combined with its optimizer, weights, and name. Models will also have an 'fe_compiled' flag to
        indicate that they were built via this function.

    Raises:
        RuntimeError: If downloading `weight` from Google Drive does not yield a file.
    """
    if isinstance(optimizer_fn, EpochScheduler):
        for epoch, optimizer_def in optimizer_fn.epoch_dict.items():
            # Only values of existing keys are overwritten, so mutating the dict while iterating it is safe.
            optimizer_fn.epoch_dict[epoch] = _build_optimizer(
                optimizer_def, model, framework)
        model.current_optimizer = optimizer_fn.get_current_value(1)
    elif isinstance(optimizer_fn, RepeatScheduler):
        for idx, optimizer_def in enumerate(optimizer_fn.repeat_list):
            optimizer_fn.repeat_list[idx] = _build_optimizer(
                optimizer_def, model, framework)
        model.current_optimizer = optimizer_fn.get_current_value(1)
    else:
        optimizer_fn = _build_optimizer(optimizer_fn, model, framework)
        model.current_optimizer = optimizer_fn
    model.optimizer = optimizer_fn
    model.fe_compiled = True

    if weight:
        if weight.startswith(GOOGLE_DRIVE_URL):
            import shutil  # local import: only needed on this rarely used download path

            tmp_dir = tempfile.mkdtemp()
            # Called without an output path, gdown picks the file name and saves it relative to the current
            # working directory; the returned name is then relocated into tmp_dir.
            file_name = gdown.download(weight, quiet=False)
            if not file_name:
                # NOTE(review): some gdown versions report a failed download by returning None instead of raising.
                raise RuntimeError("Failed to download weights from {}".format(weight))
            weight = os.path.join(tmp_dir, os.path.basename(file_name))
            # shutil.move works across filesystems (os.rename raises OSError when the working directory and the
            # temp directory are on different devices). The file is already complete, so no second download.
            shutil.move(file_name, weight)

        load_model(model, weight)

    model.model_name = name
    return model
 def _load_files(self) -> None:
     """Reload the system state and every network model's weights from `self.directory`.

     The system state is read from the JSON file named by `self.system_file`. Each model in
     `self.system.network.models` is then restored, together with its optimizer state, from
     `<model_name>.<extension>`, where the extension comes from `self.model_extension` keyed by backend.

     Raises:
         ValueError: If a model is neither a `tf.keras.Model` nor a `torch.nn.Module`.
     """
     self.system.load_state(json_path=os.path.join(self.directory, self.system_file))
     for net_model in self.system.network.models:
         # Backend detection decides which file extension the saved weights use.
         if isinstance(net_model, tf.keras.Model):
             backend = "tf"
         elif isinstance(net_model, torch.nn.Module):
             backend = "torch"
         else:
             raise ValueError("Unknown model type {}".format(type(net_model)))
         weights_file = "{}.{}".format(net_model.model_name, self.model_extension[backend])
         load_model(net_model, weights_path=os.path.join(self.directory, weights_file), load_optimizer=True)
 def on_end(self, data: Data) -> None:
     """Reload the best saved weights into `self.model` when training finishes, if requested.

     Nothing happens unless `self.load_best_final` is truthy and `self.model_path` is non-empty.

     Args:
         data: Unused by this method.
     """
     if not (self.load_best_final and self.model_path):
         return
     best_path = self.model_path
     print("FastEstimator-BestModelSaver: Restoring model from {}".format(best_path))
     load_model(self.model, best_path)