예제 #1
0
def load_weights(
    model: tf.keras.models.Model,
    path: tk.typing.PathLike,
    by_name: bool = False,
    skip_mismatch: bool = False,
    skip_not_exist: bool = False,
    strict: bool = True,
    strict_fraction: float = 0.95,
) -> bool:
    """Load weights into ``model`` from a weights file or a model directory.

    Args:
        model: Model that receives the weights.
        path: File or directory to load from. A directory is loaded with
            ``tf.keras.models.load_model`` and its weights are copied into
            ``model``; a file is passed to ``model.load_weights``.
        by_name: If True, match weights by layer name; if False, by order.
            Only used when ``path`` is a file.
        skip_mismatch: If True, skip weights whose shape does not match
            (only effective with ``by_name=True``). Only used when ``path``
            is a file.
        skip_not_exist: If True, a missing ``path`` is logged and ``False``
            is returned instead of raising.
        strict: If True, raise when too few parameters differ from the
            values the model held before loading.
        strict_fraction: Minimum fraction of parameters that must have
            changed; a fraction below this raises when ``strict`` is True.

    Returns:
        True when weights were loaded. False only when ``path`` does not
        exist and ``skip_not_exist`` is True.

    Raises:
        RuntimeError: If ``path`` does not exist and ``skip_not_exist`` is
            False, or if ``strict`` is True and the changed fraction is
            below ``strict_fraction``.
    """
    path = pathlib.Path(path)
    if not path.exists():
        if skip_not_exist:
            tk.log.get(__name__).info(f"{path} is not found.")
            return False
        raise RuntimeError(f"{path} is not found.")

    with tk.log.trace(f"load_weights({path})"):
        logger = tk.log.get(__name__)
        # The pre-load snapshot is only needed for the strict comparison.
        old_weights = model.get_weights() if strict else None
        if path.is_dir():
            # A directory is loaded as a whole model (SavedModel may not be
            # supported by load_weights?), so by_name / skip_mismatch are
            # not applied on this path.
            # TODO: support by_name / skip_mismatch here?
            loaded_model = tf.keras.models.load_model(str(path), compile=False)
            model.set_weights(loaded_model.get_weights())
        else:
            model.load_weights(
                str(path), by_name=by_name, skip_mismatch=skip_mismatch
            )
        if strict:
            new_weights = model.get_weights()
            num_params = sum(int(w.size) for w in new_weights)
            if num_params > 0:
                changed_params = sum(
                    int(np.sum(np.not_equal(w1, w2)))
                    for w1, w2 in zip(old_weights, new_weights)
                )
                r = changed_params / num_params
                msg = f"{changed_params:,} params changed. ({r:.1%})"
                if r < strict_fraction:
                    raise RuntimeError(msg)
                logger.info(msg)
            else:
                # Nothing to compare; avoid a 0/0 division.
                logger.info("model has no params; strict check skipped.")
        # Log a fingerprint of the loaded weights as an extra sanity check.
        logger.info(f"fingerprint: {tk.models.fingerprint(model)}")
    return True
예제 #2
0
파일: models.py 프로젝트: ak110/pytoolkit
def fingerprint(model: tf.keras.models.Model) -> str:
    """Return a short "xx:xx:xx:xx" string for checking weight identity.

    The string is the first eight hex digits of a SHA-256 digest computed
    over the raw bytes of every array from ``model.get_weights()``, in order.
    """
    hasher = hashlib.sha256()
    for weight in model.get_weights():
        hasher.update(weight.tobytes())
    digest = hasher.hexdigest()
    # Four two-character groups taken from the start of the digest.
    return ":".join(digest[pos:pos + 2] for pos in range(0, 8, 2))
예제 #3
0
    def save_weights(self, model: tf.keras.models.Model):
        """Update in-memory and on-disk weights from `model`.

        Each value returned by ``model.get_weights()`` is stored in
        ``self.weights`` under the name of the matching variable in
        ``model.weights``; afterwards ``self._save_weights_to_disk()`` is
        called.

        Args:
            model: Model whose current weight values are recorded.

        Raises:
            ValueError: If ``model.weights`` and ``model.get_weights()`` do
                not have the same length, so names cannot be paired safely.
        """
        values = model.get_weights()
        # Names come from model.weights rather than model.trainable_weights:
        # get_weights() also returns non-trainable values, so indexing the
        # trainable list by position overruns it for such models.
        variables = model.weights
        if len(variables) != len(values):
            raise ValueError(
                f"cannot pair weight names with values: "
                f"{len(variables)} variables, {len(values)} values"
            )

        for variable, value in zip(variables, values):
            self.weights[variable.name] = value
        self._save_weights_to_disk()
예제 #4
0
    def load_weights(self, model: tf.keras.models.Model):
        """Load weights into `model` by name.

        The model's current values are taken as the starting point. Every
        weight whose name is also a key of ``self.weights`` is replaced by
        the stored value, except names matching the 1x1-conv pattern below.
        The resulting list is handed to ``model.set_weights``.

        Args:
            model: Model that receives the stored weights.

        Raises:
            ValueError: If ``model.weights`` and ``model.get_weights()`` do
                not have the same length, so names cannot be paired safely.
            AssertionError: If a stored weight's shape differs from the
                model's current weight of the same name.
        """
        values = model.get_weights()
        # Names come from model.weights rather than model.trainable_weights:
        # get_weights() also returns non-trainable values, so indexing the
        # trainable list by position overruns it for such models.
        variables = model.weights
        if len(variables) != len(values):
            raise ValueError(
                f"cannot pair weight names with values: "
                f"{len(variables)} variables, {len(values)} values"
            )
        weights = OrderedDict(
            (variable.name, value) for variable, value in zip(variables, values)
        )

        # Override random weights with the pre-trained ones.
        # The number of leaf outputs in a block can change with mutation, so
        # the weight matrix of the 1x1 conv on the block's output can change
        # size. There is no good way to share such a matrix (yet?), so those
        # weights are left at their current values.
        conv_1x1 = re.compile(r'repeat_\d+/(block_\d+|reduce)/conv')
        for weight_name, current in weights.items():
            if weight_name not in self.weights or conv_1x1.match(weight_name):
                continue
            stored = self.weights[weight_name]
            # Kept as an assert so the exception type callers see is unchanged.
            assert current.shape == stored.shape, f"{current.shape}, {stored.shape} ({weight_name})"
            # Only the value of an existing key is replaced, so iterating
            # over items() while assigning is safe.
            weights[weight_name] = stored

        model.set_weights(list(weights.values()))
예제 #5
0
 def synchronize(self, original: tf.keras.models.Model, target: tf.keras.models.Model):
     """Copy the weights of `original` into `target`.

     Reads ``original.get_weights()`` and passes the result unchanged to
     ``target.set_weights``.
     """
     source_weights = original.get_weights()
     target.set_weights(source_weights)