def check(
    train_model: tf.keras.models.Model,
    pred_model: tf.keras.models.Model,
    models_dir: tk.typing.PathLike,
    dataset: typing.Optional[tk.data.Dataset] = None,
    train_data_loader: typing.Optional[tk.data.DataLoader] = None,
    pred_data_loader: typing.Optional[tk.data.DataLoader] = None,
    save_mode: str = "hdf5",
):
    """Quick smoke test for a training/inference model pair.

    Runs, in order: summary, graph plot, a save/load round trip of the
    inference model, and then (only when the required data is supplied)
    one evaluate pass, one predict pass and a single-epoch fit.

    Args:
        train_model: Model used for training.
        pred_model: Model used for inference.
        models_dir: Directory where information is saved
            (the graph image is written to ``models_dir / "model.png"``).
        dataset: Data used for the check (keep it small). When None, the
            evaluate / predict / fit steps are skipped.
        train_data_loader: DataLoader for training. When None, the
            evaluate and fit steps are skipped.
        pred_data_loader: DataLoader for inference. When None, the
            predict step is skipped.
        save_mode: Save format (one of "hdf5", "saved_model", "onnx",
            "tflite"). It is used as the file extension of the temporary
            save path.

    """
    models_dir = pathlib.Path(models_dir)

    # Show the model summary.
    tk.models.summary(train_model)

    # Write out the model graph.
    tk.models.plot(train_model, models_dir / "model.png")

    # Check that save/load works (for now, not crashing counts as OK).
    # The reloaded model replaces pred_model so the predict step below
    # exercises the round-tripped model.
    with tempfile.TemporaryDirectory() as tmpdir:
        save_path = pathlib.Path(tmpdir) / f"model.{save_mode}"
        tk.models.save(pred_model, save_path)
        pred_model = tk.models.load(save_path)

    # train_model.evaluate
    if dataset is not None and train_data_loader is not None:
        ds, steps = train_data_loader.get_ds(dataset, shuffle=True)
        logger.info(f"train_model.evaluate: {ds.element_spec} {steps=}")
        values = train_model.evaluate(ds, steps=steps, verbose=1)
        # With a single metric name the result is used as-is (one value);
        # otherwise names and values are paired up positionally.
        if len(train_model.metrics_names) == 1:
            evals = {train_model.metrics_names[0]: values}
        else:
            evals = dict(zip(train_model.metrics_names, values))
        logger.info(f"check.evaluate: {tk.evaluations.to_str(evals)}")

    # pred_model.predict
    if dataset is not None and pred_data_loader is not None:
        ds, steps = pred_data_loader.get_ds(dataset)
        # Label this log line as the predict step (it previously said
        # "pred_model.evaluate", which mislabelled the step in the logs).
        logger.info(f"pred_model.predict: {ds.element_spec} {steps=}")
        pred = pred_model.predict(ds, steps=steps, verbose=1)
        # Multi-output models yield a list/tuple of arrays.
        if isinstance(pred, (list, tuple)):
            logger.info(f"check.predict: shape={[p.shape for p in pred]}")
        else:
            logger.info(f"check.predict: shape={pred.shape}")

    # train_model.fit (runs last, after the evaluate and predict steps)
    if dataset is not None and train_data_loader is not None:
        ds, steps = train_data_loader.get_ds(dataset, shuffle=True)
        train_model.fit(ds, steps_per_epoch=steps, epochs=1, verbose=1)
def evaluate(
    model: tf.keras.models.Model,
    iterator: tk.data.Iterator,
    callbacks: typing.List[tf.keras.callbacks.Callback] = None,
    verbose: int = 1,
) -> typing.Dict[str, float]:
    """Evaluate a model.

    Args:
        model: Model to evaluate.
        iterator: Data (its ``dataset`` and ``data_loader`` are used).
        callbacks: Callbacks.
        verbose: 1 to show a progress bar.

    Returns:
        Dict mapping metric names to values.

    """
    with tk.log.trace("evaluate"):
        distributed = tk.hvd.is_active()

        # Progress output is suppressed unless tk.hvd.is_master() is true.
        if not tk.hvd.is_master():
            verbose = 0

        eval_callbacks = make_callbacks(callbacks, training=False)

        # When Horovod is active, the dataset goes through tk.hvd.split first.
        if distributed:
            source = tk.hvd.split(iterator.dataset)
        else:
            source = iterator.dataset

        tf_ds, steps = iterator.data_loader.get_ds(source)
        tk.log.get(__name__).info(f"evaluate: {tf_ds.element_spec} {steps=}")

        results = model.evaluate(
            tf_ds, steps=steps, verbose=verbose, callbacks=eval_callbacks
        )
        # When Horovod is active, the results go through tk.hvd.allreduce.
        if distributed:
            results = tk.hvd.allreduce(results)

        # With a single metric name the result is used as-is (one value);
        # otherwise names and values are paired up positionally.
        names = model.metrics_names
        if len(names) == 1:
            return {names[0]: results}
        return dict(zip(names, results))