Esempio n. 1
0
 def _create_callbacks(self, cur_log_dir, params):
   """Builds the callbacks for a training run.

   Starts from whatever `misc.get_callbacks()` returns and, when
   `params["enable_checkpointing"]` is truthy, appends a
   `tf.keras.callbacks.ModelCheckpoint` whose file path is
   `cur_log_dir` joined with "cp-{epoch:04d}.ckpt".

   Args:
     cur_log_dir: Directory joined with the checkpoint file name pattern.
     params: Mapping read for "enable_checkpointing" and, only when
       checkpointing is enabled, "save_weights_only".

   Returns:
     The callbacks obtained from `misc.get_callbacks()`, with the checkpoint
     callback appended when checkpointing is enabled.
   """
   callback_list = misc.get_callbacks()

   # Nothing more to add when checkpointing is switched off.
   if not params["enable_checkpointing"]:
     return callback_list

   checkpoint_path = os.path.join(cur_log_dir, "cp-{epoch:04d}.ckpt")
   checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
       checkpoint_path, save_weights_only=params["save_weights_only"])
   callback_list.append(checkpoint_cb)
   return callback_list
 def _create_callbacks(self, cur_log_dir, init_steps, params):
   """Builds the training callbacks: LR scheduling plus checkpointing.

   Args:
     cur_log_dir: Directory joined with the checkpoint file name pattern
       "cp-{epoch:04d}.ckpt".
     init_steps: Forwarded as the second argument to
       `optimizer.LearningRateScheduler`.
     params: Mapping read for "learning_rate", "hidden_size",
       "learning_rate_warmup_steps" and "steps_between_evals".

   Returns:
     The callbacks from `misc.get_callbacks(params["steps_between_evals"])`,
     followed by the learning rate scheduler callback and a
     `tf.keras.callbacks.ModelCheckpoint` created with
     `save_weights_only=True`.
   """
   # Learning rate schedule, built from three entries of `params`.
   lr_fn = optimizer.LearningRateFn(
       params["learning_rate"],
       params["hidden_size"],
       params["learning_rate_warmup_steps"])
   lr_scheduler = optimizer.LearningRateScheduler(lr_fn, init_steps)

   callback_list = misc.get_callbacks(params["steps_between_evals"])
   callback_list.append(lr_scheduler)

   # The checkpoint callback is always added here, weights only.
   checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
       os.path.join(cur_log_dir, "cp-{epoch:04d}.ckpt"),
       save_weights_only=True)
   callback_list.append(checkpoint_cb)
   return callback_list