def after_run(self, _, run_values):
    """Print a TRAIN progress line when the current step is due for logging.

    ``global_step`` is always popped from ``run_values.results``.  In TRAIN
    mode, at step 1 and at every multiple of ``self.every_n_iter``:

    - the wall time elapsed since ``self._start_time`` is turned into
      per-step rates;
    - the timer is reset;
    - ``self.template`` is filled with those rates, the step, the epoch
      (``step / self.epoch_steps``) and every other fetched result;
    - the resulting line is printed through ``cprnt``.
    """
    results = run_values.results
    step = results.pop("global_step")[0]
    if self.mode != ModeKeys.TRAIN:
        return
    if not (step % self.every_n_iter == 0 or step == 1):
        return

    # Step 1 is reported on its own, so its elapsed time is attributed to a
    # single step; later reports cover a window of ``every_n_iter`` steps.
    window = 1 if step == 1 else self.every_n_iter
    elapsed = time.time() - self._start_time
    sec_per_step = float(elapsed / window)
    step_per_sec = float(window / elapsed)
    self._start_time = time.time()

    fields = {
        "duration": elapsed,
        "sec_per_step": sec_per_step,
        "step_per_sec": step_per_sec,
        "step": step,
        "epoch": step / self.epoch_steps,
        **results,
    }
    cprnt(tf=True, TRAIN=self.template.format_map(fields))
def _init_uid(self):
    """Log the feature settings and derive ``self._uid`` / ``self._name``.

    The uid is ``hash_data`` over the embedding uid, every dataset uid and
    the OOV policy (train threshold, number of buckets).  The name joins the
    embedding name, the dataset names and that uid with ``"--"``.
    """
    embedding = self._embedding
    dataset_uids = [ds.uid for ds in self.datasets]
    dataset_names = [ds.name for ds in self.datasets]
    oov_policy = [self._oov_train_threshold, self._num_oov_buckets]

    # ``literal_eval(stringify(...))`` must yield a mapping; it is the only
    # source for the ``function``/``args``/``kwargs`` placeholders below.
    summary = ("Feature Data: Embedding: {embedding_uid} \t CASE-{ci}SENSITIVE "
               "Dataset(s): {datasets_uids} "
               "OOV Threshold: {train_oov_threshold} "
               "OOV Buckets: {oov_buckets} "
               "OOV Init Fn: {function} \t args: {args} \t kwargs: {kwargs} "
               ).format_map({
                   "embedding_uid": embedding.uid,
                   "ci": "IN" if embedding.case_insensitive else "",
                   "datasets_uids": "\t".join(dataset_uids),
                   "train_oov_threshold": self._oov_train_threshold,
                   "oov_buckets": self._num_oov_buckets,
                   **literal_eval(stringify(self._oov_fn)),
               })
    cprnt(tf=True, INFO=summary)

    self._uid = hash_data([embedding.uid] + dataset_uids + oov_policy)
    self._name = "--".join([embedding.name] + dataset_names + [self._uid])
def make_default_run_cofig(self, custom_config=None):
    """Build the keyword arguments for the estimator run configuration.

    Args:
        custom_config: optional dict of overrides.  Keys whose first
            ``_``-separated token is ``session`` or ``sess`` are treated as
            session options.  All other keys override the run settings.

    ``save_checkpoints_steps`` (or, failing that, ``save_checkpoints_secs``)
    and ``save_summary_steps`` are moved into ``self.model.aux_config``,
    because checkpoints and summaries are handled by their addons.

    ``keep_checkpoint_max`` values below 2 are raised to 2.  ``None`` is
    left untouched; RunConfig documents it as "keep all checkpoints".

    Returns:
        dict: run-configuration keyword arguments.
    """
    custom_config = custom_config or {}
    # Setting session_config to anything other than None in a distributed
    # setting prevents device filters from being set automatically, and the
    # job then hangs indefinitely at the final checkpoint.  The session
    # options are therefore assembled here but deliberately not returned
    # (they would go into "session_config" as a tf.ConfigProto).
    session_conf_keywords = ["session", "sess"]
    custom_sess_config = extract_config_subset(
        config_objs=[custom_config], keywords=session_conf_keywords)
    default_session_config = {
        "allow_soft_placement": True,
        "log_device_placement": False,
    }
    default_session_config.update(custom_sess_config)

    custom_run_config = {
        key: value
        for key, value in custom_config.items()
        if key.split("_")[0] not in session_conf_keywords
    }

    # Checkpoint/summary frequencies are delegated to the addons: move them
    # out of the run config and into the model's aux config.
    aux_config = self.model.aux_config
    if "save_checkpoints_steps" in custom_run_config:
        steps = custom_run_config.pop("save_checkpoints_steps")
        aux_config["checkpoints_freq"] = steps
        aux_config["early_stopping_freq"] = steps
    elif "save_checkpoints_secs" in custom_run_config:
        secs = custom_run_config.pop("save_checkpoints_secs")
        aux_config["checkpoints_secs"] = secs
        aux_config["early_stopping_secs"] = secs
    if "save_summary_steps" in custom_run_config:
        aux_config["summaries_freq"] = custom_run_config.pop(
            "save_summary_steps")

    default_run_config = {
        "model_dir": self.experiment_dir,
        "tf_random_seed": TF_RANDOM_SEED,
        "log_step_count_steps": LOG_STEP_COUNT_STEPS,
        "keep_checkpoint_max": KEEP_CHECKPOINT_MAX,
        # These settings are delegated to the respective addons.
        "save_checkpoints_steps": None,
        "save_checkpoints_secs": None,
        "save_summary_steps": None,
    }
    default_run_config.update(custom_run_config)

    if default_run_config["tf_random_seed"] is not None:
        cprnt(
            tf=True,
            warn="TF Random seed is set to {}".format(
                default_run_config.get("tf_random_seed")),
        )

    keep_max = default_run_config.get("keep_checkpoint_max")
    # None means "keep every checkpoint", which already satisfies the
    # minimum; comparing None with an int would raise TypeError.
    if keep_max is not None and keep_max < 2:
        default_run_config["keep_checkpoint_max"] = 2
        # Report the value that was requested, not the replacement.
        cprnt(
            tf=True,
            warn=("TSAPLAY requires keep_checkpoint_max >=2" +
                  " got {}, using 2").format(keep_max),
        )
    return default_run_config
def after_run(self, _, run_values):
    """Write the session run metadata to the summary writer when due.

    TRAIN mode tags the metadata ``step-<global_step>`` and counts by
    global step.  EVAL mode counts evaluation batches in ``self.counter``
    and tags ``step-<global_step>-batch-<n>``.

    Metadata is written when any of these hold:

    - ``self.freq`` is ``"once"``; it then becomes ``"done"`` and the hook
      stops writing;
    - the counter is 1;
    - the counter is a multiple of ``self.freq``.

    Other modes are ignored.  ``ValueError`` from the writer is reported
    as a warning.
    """
    if self.freq == "done":
        # A "once" hook has already written its metadata.  Checking this
        # for every mode avoids ``counter % "done"`` raising a TypeError.
        return
    global_step = run_values.results["global_step"][0]
    if self.mode == ModeKeys.TRAIN:
        counter = global_step
        tag = "step-{}".format(counter)
    elif self.mode == ModeKeys.EVAL:
        self.counter += 1
        counter = self.counter
        tag = "step-{}-batch-{}".format(global_step, counter)
    else:
        # No counter/tag is defined for other modes.
        return

    if self.freq == "once" or counter == 1 or counter % self.freq == 0:
        if self.freq == "once":
            self.freq = "done"
        try:
            self._summary_writer.add_run_metadata(run_values.run_metadata, tag)
            cprnt(tf=True, **{self.mode: "Logged Metadata ({})".format(tag)})
        except ValueError:
            cprnt(
                tf=True,
                warn="{} Metadata ERR ({}) \n {}".format(
                    self.mode.upper(), tag, traceback.format_exc()),
            )
def run_next_experiment(batch_file_path, job_dir=None, defaults=None):
    """Run the next pending task of a batch file, then re-exec for the rest.

    The task index is read from the ``TSATASK`` environment variable
    (default 0).  Task arguments come from
    ``parse_batch_file(batch_file_path, defaults)``, with ``--job-dir``
    appended when ``job_dir`` is given.  A task that raises is reported and
    skipped.

    Afterwards ``TSATASK`` is incremented and the process is replaced
    (``execvpe``) by ``python3 -m tsaplay.task batch ...``, so the next task
    runs in a fresh interpreter.  Once every task has run, ``TSATASK`` is
    cleared and the function returns.
    """
    task_parser = argument_parser()
    tasks = parse_batch_file(batch_file_path, defaults)
    if job_dir:
        tasks = [task + ["--job-dir", job_dir] for task in tasks]

    task_index = int(environ.get("TSATASK", 0))
    if task_index >= len(tasks):
        # All tasks done (or none at all).  ``pop`` avoids a KeyError when
        # the variable was never set, e.g. for an empty batch file.
        environ.pop("TSATASK", None)
        return

    try:
        task_args = task_parser.parse_args(tasks[task_index])
        cprnt(tf=True,
              info="RUNNING TASK {0}: {1}".format(task_index, task_args))
        run_experiment(task_args, experiment_index=task_index)
    except Exception:  # pylint: disable=W0703
        # Keep the batch going even if this task failed.
        traceback.print_exc()

    environ["TSATASK"] = str(task_index + 1)
    # Build argv as a list, not a split string, so that paths containing
    # whitespace are passed through as single arguments.
    next_argv = ["python3", "-m", "tsaplay.task", "batch", batch_file_path]
    if job_dir:
        next_argv += ["--job-dir", job_dir]
    if defaults:
        next_argv += ["--defaults", *defaults]
    execvpe("python3", next_argv, environ)
def after_save(self, session, global_step_value):
    """Delete the newest checkpoint if early stopping has been signalled.

    Reads the ``signal_early_stopping/STOP:0`` tensor.  If it is set:

    - deletes the files of the latest checkpoint in ``self.model_dir``
      (through ``tf.gfile`` for ``gs://`` paths, ``os.remove`` otherwise);
    - logs an END timestamp and closes the Comet experiment, if one is
      attached;
    - flushes the model directory's summary writer.

    ``global_step_value`` is not used.
    """
    stop_tensor = tf.get_default_graph().get_tensor_by_name(
        "signal_early_stopping/STOP:0")
    if not stop_tensor.eval(session=session).astype(bool):
        return

    ckpt_dir, ckpt_prefix = path.split(
        tf.train.latest_checkpoint(self.model_dir))
    if ckpt_dir.startswith("gs://"):
        # Cloud storage: list the directory through tf.gfile and match on
        # the checkpoint's file-name prefix.
        doomed = [
            path.join(ckpt_dir, fname)
            for fname in tf.gfile.ListDirectory(ckpt_dir)
            if fname.startswith(ckpt_prefix)
        ]
        for doomed_path in doomed:
            tf.gfile.Remove(doomed_path)
            cprnt(tf=True, warn="DEL: {}".format(doomed_path))
    else:
        for doomed_path in search_dir(ckpt_dir, query=ckpt_prefix):
            os.remove(doomed_path)
            cprnt(tf=True, warn="DEL: {}".format(doomed_path))

    if self.comet is not None:
        self.comet.log_other("END", str(datetime.timestamp(datetime.now())))
        self.comet.disable_mp()
        self.comet.end()

    tf.summary.FileWriterCache.get(self.model_dir).flush()
def histograms(model, features, labels, spec, params):
    """Add histogram summaries for every trainable variable.

    The resolved step frequency is recorded under
    ``model.aux_config["_resolved_freqs"]["HISTOGRAMS"]``.  When a Comet
    experiment is attached, a ``LogHistogramsToComet`` training hook is also
    appended.  Returns the (possibly updated) ``spec``.
    """
    freq = resolve_summary_step_freq(
        config_objs=[params, model.aux_config],
        keywords=["summaries", "logging", "histograms"],
        epochs=params.get("epochs"),
        epoch_steps=params["epoch_steps"],
        default=SAVE_SUMMARY_STEPS,
    )
    resolved = dict(model.aux_config.get("_resolved_freqs", {}))
    resolved["HISTOGRAMS"] = freq
    model.aux_config["_resolved_freqs"] = resolved
    cprnt(tf=True, info="HISTOGRAMS every {} steps".format(freq))

    variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    # ":" is not allowed in summary names ("dense/kernel:0").
    tags = [var.name.replace(":", "_") for var in variables]
    for tag, var in zip(tags, variables):
        tf.summary.histogram(tag, var)

    if not model.comet_experiment:
        return spec
    hooks = list(spec.training_hooks)
    hooks.append(
        LogHistogramsToComet(
            comet=model.comet_experiment,
            names=tags,
            trainables=variables,
            every_n_iter=freq,
        ))
    return spec._replace(training_hooks=hooks)
def main():
    """Command-line entry point: import a dataset described by the CLI args.

    Prints a success message with the dataset name, or a warning if the
    import raised.  Nothing is printed here when ``generate_dataset_files``
    returns None (dataset already present and not overwritten), since it
    emits its own warning in that case.
    """
    parser = argument_parser()
    args = parser.parse_args()
    try:
        ds_name = generate_dataset_files(args)
    except Exception as e:  # pylint: disable=W0703
        cprnt(warn="Error importing dataset. {}".format(e))
        return
    if ds_name is not None:
        cprnt(success="Imported '{}' successfully.".format(ds_name))
def after_run(self, _, run_values):
    """Record the global step and write the fetched summary when due.

    ``global_step`` is popped from ``run_values.results`` and stored on
    ``self._global_step``.  If a ``summary`` result was fetched and the step
    is 1 or a multiple of ``self.summary_freq``, the summary is popped and
    added to ``self._summary_writer``.
    """
    results = run_values.results
    step = results.pop("global_step")[0]
    self._global_step = step
    if "summary" not in results:
        return
    if step % self.summary_freq != 0 and step != 1:
        return
    self._summary_writer.add_summary(results.pop("summary"), step)
    cprnt(tf=True, info="Saved summary for step {}".format(step))
def checkpoints(model, features, labels, spec, params):
    """Attach a ``tf.train.CheckpointSaverHook`` to the training hooks.

    When the ``early_stopping`` addon has been applied, checkpoints follow
    its resolved cadence (``run_every_secs``/``run_every_steps`` renamed to
    ``save_secs``/``save_steps``), and a
    ``DiscardRedundantStopSignalCheckpoint`` listener is attached.
    Otherwise the ``secs`` setting, or a resolved step frequency, is used.
    The chosen setting is recorded under
    ``model.aux_config["_resolved_freqs"]["CHECKPOINTS"]``.
    """
    config = extract_config_subset(
        config_objs=[params, model.aux_config],
        keywords=["summaries", "logging", "checkpoints"],
    )
    # Tolerate a missing/None "applied_addons" entry instead of raising a
    # TypeError from the membership test.
    applied_addons = model.aux_config.get("applied_addons") or []
    if "early_stopping" in applied_addons:
        checkpoint_listeners = [
            DiscardRedundantStopSignalCheckpoint(
                model_dir=model.run_config.model_dir,
                comet=model.comet_experiment,
            )
        ]
        # NOTE(review): presumably reuses the early-stopping cadence so that
        # every early-stopping evaluation has a matching checkpoint.
        ea_freq_setting = model.aux_config["_resolved_freqs"]["EARLYSTOPPING"]
        renames = {
            "run_every_secs": "save_secs",
            "run_every_steps": "save_steps",
        }
        freq_setting = {
            renames.get(key): value
            for key, value in ea_freq_setting.items()
        }
    else:
        checkpoint_listeners = []
        if config.get("secs") is not None:
            freq_setting = {"save_secs": config["secs"]}
        else:
            freq_setting = {
                "save_steps": resolve_summary_step_freq(
                    config=config,
                    epochs=params.get("epochs"),
                    epoch_steps=params["epoch_steps"],
                    default=SAVE_CHECKPOINTS_STEPS,
                )
            }
    model.aux_config["_resolved_freqs"] = {
        **model.aux_config.get("_resolved_freqs", {}),
        "CHECKPOINTS": freq_setting,
    }
    cprnt(
        tf=True,
        info=("CHECKPOINTS every {save_secs} seconds"
              if freq_setting.get("save_secs") is not None else
              "CHECKPOINTS every {save_steps} steps").format_map(freq_setting),
    )
    train_hooks = list(spec.training_hooks)
    train_hooks.append(
        tf.train.CheckpointSaverHook(
            model.run_config.model_dir,
            **freq_setting,
            listeners=checkpoint_listeners,
            scaffold=spec.scaffold,
        ))
    return spec._replace(training_hooks=train_hooks)
def main():
    """Submit training jobs described on the command line or in a jobs file.

    With ``args.jobs_file`` set, each line of the file is handled as follows:

    - a non-blank line that is not a ``#``/``;`` comment is parsed as one
      job;
    - a ``block`` line collects arguments that are appended to every
      following job line (after a ``--defaults`` flag, added if missing);
    - a blank line ends the current block.

    When a number of runs is requested (per line or globally), one job per
    run is created with ``_runNN`` appended to its job id.  A submission
    that raises is reported, and the remaining jobs are still submitted.
    """
    parser = argument_parser()
    args = parser.parse_args()
    nruns = args.nruns
    start = int(args.run_start) if args.run_start else 1
    jobs_args = []

    if args.jobs_file:
        block_args = []
        with open(args.jobs_file, "r") as jobs_file:
            for line in jobs_file:
                if not line.strip():
                    block_args = []
                    continue
                if line.startswith(("#", ";")):
                    continue
                if line.startswith("block"):
                    block_args += line.strip().split()[1:]
                    continue
                line_cmd = line.split()
                if block_args and "--defaults" not in line_cmd:
                    line_cmd += ["--defaults"]
                this_args = parser.parse_args(line_cmd + block_args)
                this_nruns = this_args.nruns or nruns
                this_start = (int(this_args.run_start)
                              if this_args.run_start is not None else start)
                if this_nruns:
                    orig_job_id = this_args.job_id
                    for run_num in range(this_start,
                                         this_start + int(this_nruns)):
                        this_args.job_id = orig_job_id + (
                            "_run{:02}".format(run_num))
                        jobs_args += [deepcopy(this_args)]
                else:
                    jobs_args += [deepcopy(this_args)]
    elif nruns:
        orig_job_id = args.job_id
        for run_num in range(start, start + int(nruns)):
            args.job_id = orig_job_id + ("_run{:02}".format(run_num))
            jobs_args += [deepcopy(args)]
    else:
        jobs_args = [args]

    for job_args in jobs_args:
        try:
            submit_job(job_args)
        except Exception:  # pylint: disable=W0703
            # job_args is an argparse Namespace: " ".join() on it raised a
            # TypeError here and aborted the remaining submissions.
            cprnt(WARN="Encountered exception in job: {}".format(job_args))
            traceback.print_exc()
def end(self, session):
    """In EVAL mode, fetch and print the final metric values.

    Runs the global step together with ``self.tensors``.  If
    ``self.epoch_steps`` is set, ``self.template`` is filled with the
    fetched values plus ``epoch = step / epoch_steps`` and printed.
    """
    if self.mode != ModeKeys.EVAL:
        return
    values = session.run({"step": tf.train.get_global_step(), **self.tensors})
    if self.epoch_steps is None:
        return
    fields = dict(values)
    fields["epoch"] = values.get("step") / self.epoch_steps
    cprnt(tf=True, EVAL=self.template.format_map(fields))
def wrapper(*args, **kw):
    """Call ``func`` and print timestamped banners around it.

    ``func``, ``pre`` and ``post`` are free variables supplied by the
    enclosing decorator, which is not visible here.  Before the call, the
    ``pre`` banner is printed.  After the call, if ``post`` is truthy,
    ``post`` is printed together with the elapsed time.

    Setting the ``TIMEIT`` environment variable to ``off`` (any case)
    bypasses all of this and just calls ``func``.
    """
    if environ.get("TIMEIT", "ON").lower() == "off":
        return func(*args, **kw)

    stamp_format = "%Y-%m-%d %H:%M:%S"
    label = func.__qualname__ + "():"
    cprnt(c=datetime.now().strftime(stamp_format), r=label, g=pre)

    began = time.time()
    result = func(*args, **kw)
    elapsed = timedelta(seconds=(time.time() - began))

    if post:
        cprnt(c=datetime.now().strftime(stamp_format),
              r=label,
              g=post + " in",
              row=str(elapsed))
    return result
def generate_dataset_files(args):
    """Parse a raw dataset and pickle its train/test dicts.

    The dataset is written to ``DATASET_DATA_PATH/<name>``.  ``<name>`` is
    ``args.dataset_name`` or, if unset, the last component of ``args.path``.

    Returns:
        The dataset name.  Returns None, after printing a warning and
        without parsing anything, when the dataset already exists and
        ``args.force`` is not set.
    """
    dataset_name = args.dataset_name or basename(normpath(args.path))
    target_path = join(DATASET_DATA_PATH, dataset_name)

    # Bail out before the (potentially expensive) parsing step.
    if exists(target_path) and not args.force:
        cprnt(warn="Dataset '{}' already exists, use -f to overwrite".format(
            dataset_name))
        return None

    parsing_fn = get_parsing_fn(args.path, args.parser_name)
    ftrain, ftest = get_raw_file_paths(args.path)
    train_dict, test_dict = get_dataset_dicts(ftrain, ftest, parsing_fn)

    # Remove the old dataset only once the new one parsed successfully.
    if exists(target_path):
        cprnt(info="Overwriting previous '{}' dataset".format(dataset_name))
        rmtree(target_path)
    makedirs(target_path)
    pickle_file(join(target_path, "_train_dict.pkl"), train_dict)
    pickle_file(join(target_path, "_test_dict.pkl"), test_dict)
    return dataset_name
def metadata(model, features, labels, spec, params):
    """Attach a ``MetadataHook`` that records session run metadata.

    TRAIN mode records every resolved number of steps, stored under
    ``_resolved_freqs["METADATA"]``, through the model directory's summary
    writer.  EVAL mode records once (``every_n_iter="once"``) through the
    writer for ``<model_dir>/eval``.  Other modes return ``spec`` unchanged.
    """
    if spec.mode == ModeKeys.TRAIN:
        step_freq = resolve_summary_step_freq(
            config_objs=[params, model.aux_config],
            keywords=["summaries", "logging", "metadata"],
            epochs=params.get("epochs"),
            epoch_steps=params["epoch_steps"],
            default=SAVE_SUMMARY_STEPS,
        )
        resolved = dict(model.aux_config.get("_resolved_freqs", {}))
        resolved["METADATA"] = step_freq
        model.aux_config["_resolved_freqs"] = resolved
        cprnt(tf=True, info="METADATA every {} steps".format(step_freq))

        model_dir = model.run_config.model_dir
        makedirs(model_dir, exist_ok=True)
        hooks = list(spec.training_hooks)
        hooks.append(
            MetadataHook(
                mode=ModeKeys.TRAIN,
                summary_writer=tf.summary.FileWriterCache.get(model_dir),
                every_n_iter=step_freq,
                first_step=model.aux_config["chkpt"],
            ))
        return spec._replace(training_hooks=hooks)

    if spec.mode == ModeKeys.EVAL:
        hooks = list(spec.evaluation_hooks)
        makedirs(model.estimator.eval_dir(), exist_ok=True)
        hooks.append(
            MetadataHook(
                mode=ModeKeys.EVAL,
                summary_writer=tf.summary.FileWriterCache.get(
                    join(model.run_config.model_dir, "eval")),
                every_n_iter="once",
            ))
        return spec._replace(evaluation_hooks=hooks)

    return spec
def profiling(model, features, labels, spec, params):
    """Attach a ``tf.train.ProfilerHook`` writing into the model directory.

    The hook runs every resolved number of steps, recorded under
    ``_resolved_freqs["PROFILING"]``, with ``show_memory=True``.
    """
    freq = resolve_summary_step_freq(
        config_objs=[params, model.aux_config],
        keywords=["summaries", "logging", "profiling"],
        epochs=params.get("epochs"),
        epoch_steps=params["epoch_steps"],
        default=SAVE_SUMMARY_STEPS,
    )
    resolved = dict(model.aux_config.get("_resolved_freqs", {}))
    resolved["PROFILING"] = freq
    model.aux_config["_resolved_freqs"] = resolved
    cprnt(tf=True, info="PROFILING every {} steps".format(freq))

    hooks = list(spec.training_hooks)
    output_dir = join(model.run_config.model_dir)
    makedirs(output_dir, exist_ok=True)
    hooks.append(
        tf.train.ProfilerHook(save_steps=freq,
                              output_dir=output_dir,
                              show_memory=True))
    return spec._replace(training_hooks=hooks)
def summaries(model, features, labels, spec, params):
    """Attach a ``SummarySavingHook`` for all merged summaries.

    Summaries are written through the model directory's summary writer
    every resolved number of steps, recorded under
    ``_resolved_freqs["SUMMARIES"]``.
    """
    freq = resolve_summary_step_freq(
        config_objs=[params, model.aux_config],
        keywords=["summaries"],
        epochs=params.get("epochs"),
        epoch_steps=params["epoch_steps"],
        default=SAVE_SUMMARY_STEPS,
    )
    resolved = dict(model.aux_config.get("_resolved_freqs", {}))
    resolved["SUMMARIES"] = freq
    model.aux_config["_resolved_freqs"] = resolved
    cprnt(tf=True, info="SUMMARIES every {} steps".format(freq))

    hooks = list(spec.training_hooks)
    hooks.append(
        SummarySavingHook(
            ops=tf.summary.merge_all(),
            every_n_iter=freq,
            writer=tf.summary.FileWriterCache.get(model.run_config.model_dir),
            first_step=model.aux_config["chkpt"],
        ))
    return spec._replace(training_hooks=hooks)
def export_model(self, overwrite=False, restart_tfserve=False):
    """Export the model under ``EXPORTS_DATA_PATH/<model name>_<contd tag>``.

    Does nothing, apart from printing a notice, when no continue tag is set.
    With ``overwrite``, an existing export directory is removed first.  If
    the export changed the list of exported models, the tfserve config is
    updated.  With ``restart_tfserve``, the tfserve container is restarted
    and its logs are printed.
    """
    if self.contd_tag is None:
        print("No continue tag defined, nothing to export!")
        return

    export_dir = join(EXPORTS_DATA_PATH,
                      self.model.name.lower() + "_" + self.contd_tag)
    if overwrite and exists(export_dir):
        rmtree(export_dir)

    models_before = self.get_exported_models()
    self.model.export(
        directory=export_dir,
        feature_provider=self.feature_provider,
    )
    if models_before != self.get_exported_models():
        cprnt(info="Updating tfserve.conf with new exported model info")
        self._update_export_models_config()

    if restart_tfserve:
        cprnt(info="Restarting tsaplay docker container")
        cprnt(info=restart_tf_serve_container())
def early_stopping(model, features, labels, spec, params):
    """Attach an early-stopping hook to ``spec``'s training hooks.

    When resuming from a checkpoint (``model.aux_config["chkpt"] > 0``),
    early stopping is skipped and removed from
    ``model.aux_config["applied_addons"]``.

    Otherwise a ``stop_if_no_increase_hook``/``stop_if_no_decrease_hook`` is
    attached for the evaluation ``metric`` (default ``"loss"``).
    ``patience`` and ``minimum_iter`` are counted in epochs when ``epochs``
    is set, and in steps otherwise.  The run frequency is recorded under
    ``model.aux_config["_resolved_freqs"]["EARLYSTOPPING"]``.

    Raises:
        ValueError: if the metric's direction cannot be inferred and the
            config does not set ``comparison`` to "increase" or "decrease".
    """
    aux_config = model.aux_config
    if aux_config["chkpt"] > 0:
        cprnt(
            tf=True,
            warn=("Early Stopping DISABLED: " + "Not a new experiment, " +
                  "restoring from STEP {}".format(aux_config["chkpt"])),
        )
        aux_config["applied_addons"].remove("early_stopping")
        return spec

    train_hooks = list(spec.training_hooks)
    eval_dir = model.estimator.eval_dir()
    makedirs(eval_dir, exist_ok=True)
    config = extract_config_subset(
        config_objs=[params, aux_config],
        keywords=["ea", "early_stopping"],
    )

    metric = config.get("metric", "loss")
    if metric == "loss":
        comparison = "decrease"
    elif metric in ["accuracy", "macro-f1", "micro-f1", "weighted-f1"]:
        comparison = "increase"
    else:
        comparison = config.get("comparison")
    # The comparison selects both the hook function and the name of its
    # patience keyword (max_steps_without_<comparison>).  Any other value
    # would only fail later with an "unexpected keyword argument" TypeError.
    if comparison not in ("increase", "decrease"):
        raise ValueError(
            "Early stopping on metric '{}' requires 'comparison' to be "
            "'increase' or 'decrease', got {!r}".format(metric, comparison))

    epochs = params.get("epochs")
    steps = params.get("steps")
    if config.get("secs") is not None:
        freq_setting = {"run_every_secs": config["secs"]}
    else:
        freq_setting = {
            "run_every_steps": resolve_summary_step_freq(
                config=config,
                epochs=epochs,
                epoch_steps=params["epoch_steps"],
                default=SAVE_CHECKPOINTS_STEPS,
            ),
            "run_every_secs": None,
        }
    aux_config["_resolved_freqs"] = {
        **aux_config.get("_resolved_freqs", {}),
        "EARLYSTOPPING": freq_setting,
    }

    patience = config.get("patience", 10 if epochs is not None else 1000)
    minimum_iter = config.get("minimum_iter", 0)
    # Unit used for patience/minimum_iter, and its size in steps.
    iter_unit = "{} epoch(s)" if epochs is not None else "{} steps"
    steps_per_iter = params.get("epoch_steps") if epochs is not None else 1

    if spec.mode == ModeKeys.TRAIN:
        if "run_every_steps" in freq_setting:
            run_every = ("1 epoch ({run_every_steps} steps)"
                         if epochs is not None else
                         "{run_every_steps} steps").format_map(freq_setting)
        else:
            run_every = "{run_every_secs} secs".format_map(freq_setting)
        cprnt(
            tf=True,
            info=("Early Stopping: \n" + "\t".join([
                "metric: {metric}",
                "mode: {comparison}",
                "run every: {run_every}",
                "patience: {patience}",
                "minimum: {minimum_iter}",
                "maximum: {maximum_iter}",
            ])).format(
                metric=metric,
                comparison=comparison,
                run_every=run_every,
                patience=iter_unit.format(patience),
                minimum_iter=(iter_unit.format(minimum_iter)
                              if minimum_iter > 0 else "Indefinite"),
                maximum_iter=iter_unit.format(
                    epochs if epochs is not None else steps),
            ),
        )

    early_stopping_hook_fn = (stop_if_no_increase_hook
                              if comparison == "increase" else
                              stop_if_no_decrease_hook)
    early_stopping_hook_args = {
        "estimator": model.estimator,
        "eval_dir": eval_dir,
        "metric_name": metric,
        "min_steps": minimum_iter * steps_per_iter,
        **freq_setting,
        "max_steps_without_{}".format(comparison): patience * steps_per_iter,
    }
    train_hooks.append(early_stopping_hook_fn(**early_stopping_hook_args))
    return spec._replace(training_hooks=train_hooks)
def run_experiment(args, experiment_index=None):
    """Build the feature provider, model and Experiment from ``args``; run it.

    Args:
        args: parsed command-line arguments for one experiment.
        experiment_index: optional 0-based index, used only for the
            "Running experiment N" banner.

    A continue tag is auto-generated when ``args.comet_api`` is set without
    a tag, or when the tag starts with ``".."``.  It is built from the
    model, optimizer, embedding, datasets, OOV settings and selected
    hyper-parameters, with the text after ``".."`` appended.

    After the train+eval run, TensorBoard is launched if ``args.tb_port`` is
    a valid integer.

    Raises:
        Exception: any error from the experiment run is re-raised after
            being logged to the Comet experiment, if one is attached.
    """
    try:
        tf.logging.set_verbosity(getattr(tf.logging, args.verbosity))
    except (AttributeError, TypeError):
        # Unknown (or missing) verbosity name: fall back to INFO.
        tf.logging.set_verbosity(tf.logging.INFO)
    if experiment_index is not None:
        cprnt(tf=True, y="Running experiment {}".format(experiment_index + 1))
    cprnt(tf=True, y="Args: {}".format(args))

    contd_tag = args.contd_tag
    if args.comet_api and not contd_tag:
        cprnt(tf="WARN", warn="No contd-tag provided, auto-generating...")
        contd_tag = "..autogen"

    feature_provider = make_feature_provider(args)
    params = args_to_dict(args.model_params)
    aux_config = args_to_dict(args.aux_config)
    if args.batch_size:
        params["batch-size"] = args.batch_size
    model = load_model(args.model)(params, aux_config)
    model.params["steps"] = (args.steps if args.steps is not None else
                             model.params.get("steps"))
    model.params["epochs"] = (args.epochs if args.epochs is not None else
                              model.params.get("epochs"))
    epoch_steps, num_training_samples = feature_provider.steps_per_epoch(
        model.params["batch-size"])
    model.params.update({
        "epoch_steps": epoch_steps,
        "shuffle_buffer": num_training_samples,
    })
    # An explicit --steps wins over epochs; otherwise epochs win over steps.
    if args.steps is not None:
        model.params.pop("epochs", None)
    if model.params.get("epochs") is not None:
        model.params.pop("steps", None)

    # Guard against a missing tag (no --contd-tag and no Comet), which
    # previously crashed with AttributeError on None.startswith.
    if contd_tag and contd_tag.startswith(".."):
        oov_tag = "oovt{}b{}".format(params["oov_train"],
                                     params["oov_buckets"])
        dataset_tags = []
        for name, redist in args.datasets:
            if redist is None:
                dataset_tags.append(name)
                continue
            redist_str = ",".join(
                "_".join(map(str, map(int, vals)))
                for vals in redist.values()
                if vals is not None)
            dataset_tags.append("{}-{}".format(name, redist_str))
        gen_tag = "-".join([
            args.model.replace("_", ""),
            model.params["optimizer"].lower(),
            EMBEDDING_TAGS[feature_provider.embedding.name],
        ] + dataset_tags)
        p_tags = [
            CONTD_TAG_PARAMS[p_name].format(str(p_arg).replace(".", ""))
            for p_name, p_arg in model.params.items()
            if p_name in CONTD_TAG_PARAMS
        ]
        p_tag = "-".join(sorted(p_tags))
        contd_tag = "-".join([gen_tag, oov_tag, p_tag, contd_tag[2:]])
        cprnt(tf="INFO", info="CONTD-TAG: {}".format(contd_tag))

    experiment = Experiment(
        feature_provider,
        model,
        run_config=args_to_dict(args.run_config),
        comet_api=args.comet_api,
        comet_workspace=args.comet_workspace,
        contd_tag=contd_tag,
        job_dir=args.job_dir,
    )
    try:
        experiment.run(job="train+eval", steps=args.steps, epochs=args.epochs)
    except Exception:
        comet_experiment = experiment.model.comet_experiment
        if comet_experiment is not None:
            comet_experiment.log_html("<pre>{}</pre>".format(
                traceback.format_exc()))
            comet_experiment.end()
        raise

    if args.tb_port:
        try:
            tb_port = int(args.tb_port)
        except ValueError:
            cprnt(
                tf="FATAL",
                warn="Invalid tensorboard port {}".format(args.tb_port),
            )
            # tb_port is unbound here; launching would raise a NameError.
            return
        experiment.launch_tensorboard(tb_port=tb_port)
def logging(model, features, labels, spec, params):
    """Attach console-logging hooks for the TRAIN and EVAL modes.

    TRAIN mode:

    - logs the step, the loss and a streaming accuracy every resolved
      number of steps, recorded under ``_resolved_freqs["LOGGING"]``;
    - uses the ``train_template`` config value, or a default template;
    - prints the run config, aux config and hyper-parameters once.

    EVAL mode logs the value tensor of every eval metric with a fixed
    template.

    Both templates are prefixed with a short id derived from the
    ``contd_tag`` param, or else from the last component of the model dir.
    Other modes return ``spec`` unchanged.
    """
    config = extract_config_subset(
        config_objs=[params, model.aux_config],
        keywords=["summaries", "logging"],
    )
    contd_tag = params.get("contd_tag")
    id_tag = str_snippet(contd_tag if contd_tag is not None else
                         path_split(model.run_config.model_dir)[1])
    step_header = id_tag + "STEP: {step} \t EPOCH: {epoch:.1f} \t|\t"

    if spec.mode == ModeKeys.TRAIN:
        # tf.metrics.accuracy returns (value, update_op); log the update op.
        _, accuracy_update = tf.metrics.accuracy(
            labels=labels,
            predictions=spec.predictions["class_ids"],
            name="acc_op",
        )
        tensors_to_log = {
            "step": tf.train.get_global_step(),
            "loss": spec.loss,
            "accuracy": accuracy_update,
        }
        default_template = (
            step_header + "acc: {accuracy:.5f} \t loss: {loss:.8f} |\t " +
            "duration: {duration:.2f}s " +
            "sec/step: {sec_per_step:.2f}s step/sec: {step_per_sec:.2f}")
        console_template = config.get("train_template", default_template)

        step_freq = resolve_summary_step_freq(
            config=config,
            epochs=params.get("epochs"),
            epoch_steps=params["epoch_steps"],
            default=SAVE_CHECKPOINTS_STEPS,
        )
        resolved = dict(model.aux_config.get("_resolved_freqs", {}))
        resolved["LOGGING"] = step_freq
        model.aux_config["_resolved_freqs"] = resolved
        cprnt(tf=True, info="LOGGING every {} steps".format(step_freq))
        cprnt(
            tf=True,
            INFO="\n".join([
                "Run Configuration:",
                pformat(model.run_config.__dict__),
                "AUX Configuration:",
                pformat(model.aux_config),
                "Hyper Parameters:",
                pformat(model.params),
            ]),
        )

        hooks = list(spec.training_hooks)
        hooks.append(
            ConsoleLoggerHook(
                mode=ModeKeys.TRAIN,
                epoch_steps=params["epoch_steps"],
                every_n_iter=step_freq,
                tensors=tensors_to_log,
                template=console_template,
            ))
        return spec._replace(training_hooks=hooks)

    if spec.mode == ModeKeys.EVAL:
        hooks = list(spec.evaluation_hooks)
        eval_template = (
            step_header +
            "acc: {accuracy:.5f} \t mpc_acc: {mpc_accuracy:.5f} \t" +
            "macro-f1: {macro-f1:.5f} \t" + "weighted-f1: {weighted-f1:.5f}" +
            "\n{conf-mat}")
        hooks.append(
            ConsoleLoggerHook(
                mode=ModeKeys.EVAL,
                epoch_steps=params["epoch_steps"],
                # Each eval metric op is a (value, update) pair; log values.
                tensors={
                    name: ops[0]
                    for name, ops in spec.eval_metric_ops.items()
                },
                template=eval_template,
            ))
        return spec._replace(evaluation_hooks=hooks)

    return spec