def _fe_compile(model: Model,
                optimizer_fn: Union[str, Scheduler, Callable, None],
                weight: Union[str, None],
                name: str,
                framework: str) -> Model:
    """A function to bundle models with their optimizers.

    Args:
        model: The model to be bundled.
        optimizer_fn: The optimizer to be associated with the given `model`. If it is an `EpochScheduler` or a
            `RepeatScheduler`, every optimizer definition it holds is built in place and the value for epoch 1
            becomes `model.current_optimizer`.
        weight: A path to weights to be associated with the `model`. A value starting with `GOOGLE_DRIVE_URL` is
            downloaded via `gdown` into a temporary directory, loaded, and the temporary directory is removed.
        name: The name of the model.
        framework: Which backend framework should be associated with this model (either 'tf' or 'torch').

    Returns:
        The `model` combined with its optimizer, weights, and name. Models will also have an 'fe_compiled' flag to
        indicate that they were built via this function.
    """
    import shutil

    if isinstance(optimizer_fn, EpochScheduler):
        # Replace each scheduled definition with a built optimizer (same keys, so in-place update is safe).
        for epoch, optimizer_def in optimizer_fn.epoch_dict.items():
            optimizer_fn.epoch_dict[epoch] = _build_optimizer(optimizer_def, model, framework)
        model.current_optimizer = optimizer_fn.get_current_value(1)
    elif isinstance(optimizer_fn, RepeatScheduler):
        for idx, optimizer_def in enumerate(optimizer_fn.repeat_list):
            optimizer_fn.repeat_list[idx] = _build_optimizer(optimizer_def, model, framework)
        model.current_optimizer = optimizer_fn.get_current_value(1)
    else:
        optimizer_fn = _build_optimizer(optimizer_fn, model, framework)
        model.current_optimizer = optimizer_fn
    model.optimizer = optimizer_fn
    model.fe_compiled = True

    if weight:
        tmp_dir = None
        if weight.startswith(GOOGLE_DRIVE_URL):
            tmp_dir = tempfile.mkdtemp()
            # With no output path, gdown writes into the cwd and returns the chosen file name. Download once and
            # move the file, instead of downloading the same URL a second time.
            file_name = gdown.download(weight, quiet=False)
            local_path = os.path.join(tmp_dir, file_name)
            # shutil.move (unlike os.rename) also works when cwd and the temp dir are on different filesystems.
            shutil.move(os.path.join('./', file_name), local_path)
            weight = local_path
        try:
            load_model(model, weight)
        finally:
            # mkdtemp() leaves cleanup to the caller; drop the downloaded copy once it has been loaded.
            if tmp_dir is not None:
                shutil.rmtree(tmp_dir, ignore_errors=True)
    model.model_name = name
    return model
def _load_files(self) -> None:
    """Restore the system state and every network model from `self.directory`.

    The system state is read from the JSON file `self.system_file` inside `self.directory`. Then each model in
    `self.system.network.models` is passed to `load_model` (with `load_optimizer=True`) using the file
    `<directory>/<model_name>.<extension>`. The extension is looked up in `self.model_extension` under the
    model's backend key, "tf" or "torch".

    Raises:
        ValueError: If a model is neither a `tf.keras.Model` nor a `torch.nn.Module`.
    """
    state_json = os.path.join(self.directory, self.system_file)
    self.system.load_state(json_path=state_json)
    for net_model in self.system.network.models:
        # The first matching base type decides the backend; the TF check deliberately comes first.
        for base_type, backend in ((tf.keras.Model, "tf"), (torch.nn.Module, "torch")):
            if isinstance(net_model, base_type):
                break
        else:
            raise ValueError("Unknown model type {}".format(type(net_model)))
        extension = self.model_extension[backend]
        checkpoint_file = "{}.{}".format(net_model.model_name, extension)
        load_model(net_model, weights_path=os.path.join(self.directory, checkpoint_file), load_optimizer=True)
def on_end(self, data: Data) -> None:
    """Reload the saved best model into `self.model` when requested.

    Nothing happens unless both `self.load_best_final` and `self.model_path` are truthy. Otherwise a message
    naming the path is printed and `load_model(self.model, self.model_path)` is called.

    Args:
        data: Not read by this method; accepted for interface compatibility.
    """
    if not (self.load_best_final and self.model_path):
        return
    message = "FastEstimator-BestModelSaver: Restoring model from {}".format(self.model_path)
    print(message)
    load_model(self.model, self.model_path)