コード例 #1
0
ファイル: hparams.py プロジェクト: bstriner/aae_losses
def write_hparams_events(model_dir, hparams):
    """Emit one HParams summary event for training and one for evaluation.

    The values returned by ``hparams.values()`` are copied twice; each copy
    gets a ``'mode'`` entry (``'train'`` or ``'eval'``) and is serialized via
    ``hp.hparams_pb``. The train event is written to ``model_dir`` and the
    eval event to the ``eval`` subdirectory of ``model_dir``.

    Args:
        model_dir: Directory handed to ``tf.summary.FileWriter``.
        hparams: Object whose ``values()`` returns a copyable mapping of
            hyperparameter names to values.
    """
    base_values = hparams.values()

    # Build both tagged dictionaries before serializing anything.
    tagged = {}
    for mode in ('train', 'eval'):
        values_for_mode = base_values.copy()
        values_for_mode['mode'] = mode
        tagged[mode] = values_for_mode

    train_event = hp.hparams_pb(tagged['train']).SerializeToString()
    eval_event = hp.hparams_pb(tagged['eval']).SerializeToString()

    # ``None`` means "write directly into model_dir"; the subdirectory path
    # is only computed when its writer is about to be opened.
    for subdir, event in ((None, train_event), ('eval', eval_event)):
        if subdir is None:
            logdir = model_dir
        else:
            logdir = os.path.join(model_dir, subdir)
        with tf.summary.FileWriter(logdir) as event_writer:
            event_writer.add_summary(event)
コード例 #2
0
def write_hparams_v1(writer, hparams: dict):
    """Write ``hparams`` as a serialized HParams summary through ``writer``.

    Args:
        writer: Either an object exposing ``add_summary`` or a ``str``; a
            string is resolved to a writer with ``SummaryWriterCache.get``.
        hparams: Hyperparameters to record. They are first passed through
            ``_copy_and_clean_hparams`` and ``_set_precision_if_missing``.
    """
    prepared = _set_precision_if_missing(_copy_and_clean_hparams(hparams))

    # Everything below runs with a freshly created v1 graph as the default.
    # NOTE(review): presumably this keeps the v1 writer in graph mode -
    # confirm against callers.
    with tf.compat.v1.Graph().as_default():
        target = (
            SummaryWriterCache.get(writer)
            if isinstance(writer, str)
            else writer
        )
        serialized = hp.hparams_pb(prepared).SerializeToString()
        target.add_summary(serialized)
コード例 #3
0
ファイル: tb_utils.py プロジェクト: HabanaAI/Model-References
def write_hparams_v1(writer, hparams: dict):
    """Write ``hparams`` as a serialized HParams summary through ``writer``.

    Args:
        writer: Either an object exposing ``add_summary`` or a ``str``; a
            string is resolved to a writer with ``SummaryWriterCache.get``.
        hparams: Hyperparameters to record. They are first passed through
            ``_copy_and_clean_hparams`` and ``_set_precision_if_missing``.
    """
    prepared = _set_precision_if_missing(_copy_and_clean_hparams(hparams))

    # A Session is opened on purpose: older topologies that run in graph
    # mode need one for the FileWriter to work.
    with tf.compat.v1.Session():
        target = (
            SummaryWriterCache.get(writer)
            if isinstance(writer, str)
            else writer
        )
        serialized = hp.hparams_pb(prepared).SerializeToString()
        target.add_summary(serialized)
コード例 #4
0
 def _check_hparams_equal(hp1, hp2):
     """Assert that ``hp1`` and ``hp2`` serialize to the same HParams proto.

     Both inputs are converted with ``hparams_api.hparams_pb`` using
     ``start_time_secs=0`` and compared by their serialized form.

     Raises:
         AssertionError: If the two serialized forms differ.
     """
     first_serialized = hparams_api.hparams_pb(
         hp1, start_time_secs=0).SerializeToString()
     second_serialized = hparams_api.hparams_pb(
         hp2, start_time_secs=0).SerializeToString()
     assert first_serialized == second_serialized