def load_weight(
    self, base_directory: str, filename: str, model: keras.models.Model
) -> keras.models.Model:
    """Load saved weights into ``model`` and hand the same model back.

    :param base_directory: directory containing the default weight file;
        consulted only when ``filename`` is the empty string
    :param filename: explicit path of the weight file to load; pass ``''`` to
        fall back to ``self.weight_filename`` inside ``base_directory``
    :param model: the model whose ``load_weights`` is called
    :return: the same ``model`` object that was passed in
    """
    # Resolve the path lazily so ``self.weight_filename`` is only read when
    # no explicit filename was supplied.
    weight_path = (
        os.path.join(base_directory, self.weight_filename)
        if filename == ""
        else filename
    )
    model.load_weights(weight_path)
    return model
def load_checkpoint(self, fs: FSBase, model: keras.models.Model) -> None:
    """Restore ``model`` weights from the ``model.h5`` file stored in ``fs``.

    The checkpoint is copied out of ``fs`` into a local temporary file, and
    ``model.load_weights`` is then called on that local path.

    :param fs: filesystem holding the checkpoint; ``"model.h5"`` is opened
        from it in binary mode
    :param model: the model whose ``load_weights`` is called
    """
    # A TemporaryDirectory is used instead of NamedTemporaryFile because the
    # local file is reopened by name (once to write, once to load). On Windows
    # a NamedTemporaryFile cannot be opened a second time while it is still
    # open, so the previous approach failed there.
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Keep the ".h5" suffix on the local copy, as the original did.
        local_path = os.path.join(tmp_dir, "model.h5")
        local_fs = FileSystem()
        with fs.open("model.h5", "rb") as fin:
            local_fs.writefile(local_path, fin)
        model.load_weights(local_path)