def load_weights(
    model: tf.keras.models.Model,
    path: tk.typing.PathLike,
    by_name: bool = False,
    skip_mismatch: bool = False,
    skip_not_exist: bool = False,
    strict: bool = True,
    strict_fraction: float = 0.95,
) -> bool:
    """Load model weights from a weights file or a SavedModel directory.

    Args:
        model: Target model whose weights are overwritten in place.
        path: Path to a weights file or a SavedModel directory.
        by_name: If True, load weights by matching layer names; if False, by order.
        skip_mismatch: If True, skip weights whose shapes mismatch
            (effective only when by_name=True).
        skip_not_exist: If True, do not raise when the file does not exist.
        strict: If True, raise when the weights barely changed after loading
            (a sign that the file did not actually match the model).
        strict_fraction: Minimum fraction of parameters that must have changed;
            an error is raised if the actual fraction is below this value.

    Returns:
        True if the weights were loaded. False is possible only when
        skip_not_exist=True and the file was missing.

    Raises:
        RuntimeError: If the path does not exist (and skip_not_exist is False),
            or if strict=True and too few parameters changed.
    """
    path = pathlib.Path(path)
    if not path.exists():
        # Guard clause: missing file is either a soft skip or a hard error.
        if skip_not_exist:
            tk.log.get(__name__).info(f"{path} is not found.")
            return False
        raise RuntimeError(f"{path} is not found.")
    with tk.log.trace(f"load_weights({path})"):
        # Snapshot the weights before loading so we can measure how much changed.
        old_weights = model.get_weights() if strict else None
        if path.is_dir():
            # SavedModel does not appear to support load_weights directly.
            # TODO: support by_name / skip_mismatch for the SavedModel case?
            loaded_model = tf.keras.models.load_model(str(path), compile=False)
            model.set_weights(loaded_model.get_weights())
        else:
            model.load_weights(str(path), by_name=by_name, skip_mismatch=skip_mismatch)
        if strict:
            new_weights = model.get_weights()
            changed_params = np.sum(
                [np.sum(np.not_equal(w1, w2)) for w1, w2 in zip(old_weights, new_weights)]
            )
            num_params = np.sum([w.size for w in new_weights])
            r = changed_params / num_params
            # BUGFIX: log message typo "chagnged" corrected to "changed".
            msg = f"{changed_params:,} params changed. ({r:.1%})"
            if r < strict_fraction:
                raise RuntimeError(msg)
            tk.log.get(__name__).info(msg)
            # Also log a weight fingerprint so the loaded state can be verified later.
            tk.log.get(__name__).info(f"fingerprint: {tk.models.fingerprint(model)}")
    return True
def fingerprint(model: tf.keras.models.Model) -> str:
    """Build and return a short identity string for the model's weights, in "xx:xx:xx:xx" form."""
    digest = hashlib.sha256()
    for weight in model.get_weights():
        digest.update(weight.tobytes())
    hexed = digest.hexdigest()
    # First four bytes of the digest, colon-separated.
    pairs = [hexed[i : i + 2] for i in range(0, 8, 2)]
    return ":".join(pairs)
def save_weights(self, model: tf.keras.models.Model):
    """Update in-memory and on-disk weights from `model`.

    Each weight value from `model.get_weights()` is stored under the name of the
    corresponding entry in `model.trainable_weights`, then everything is persisted.
    """
    values = model.get_weights()
    for index, value in enumerate(values):
        name = model.trainable_weights[index].name
        self.weights[name] = value
    self._save_weights_to_disk()
def load_weights(self, model: tf.keras.models.Model):
    """Load stored weights into `model` by variable name.

    Weights present in the store are copied into `model` when the names and shapes
    match; everything else keeps the model's current (random) values.
    """
    weights = model.get_weights()
    weights = OrderedDict(
        (model.trainable_weights[i].name, weights[i]) for i in range(len(weights))
    )
    # override random weights with the pre-trained ones
    # because the number of leaf outputs in a block can change with mutation, the size of the
    # weight matrix of the conv 1x1 on the block's output can change. We don't have a good way
    # to account for and share a weight matrix for this (yet?), so ignore such weights
    # BUGFIX: raw string for the regex — '\d' in a plain string is an invalid escape
    # (SyntaxWarning since Python 3.12). Also use the compiled pattern's own .match().
    conv_1x1 = re.compile(r"repeat_\d+/(block_\d+|reduce)/conv")
    for weight_name in weights:
        if weight_name in self.weights and not conv_1x1.match(weight_name):
            assert weights[weight_name].shape == self.weights[weight_name].shape, f"{weights[weight_name].shape}, {self.weights[weight_name].shape} ({weight_name})"
            weights[weight_name] = self.weights[weight_name]
    model.set_weights(list(weights.values()))
def synchronize(self, original: tf.keras.models.Model, target: tf.keras.models.Model):
    """Copy every weight from `original` into `target`."""
    source_weights = original.get_weights()
    target.set_weights(source_weights)