def get_callbacks(steps_per_epoch, current_rank, cluster_size,
                  learning_rate_schedule_fn):
    """Assemble the common Keras callbacks for a (possibly multi-worker) run.

    Args:
      steps_per_epoch: Training steps per epoch; forwarded to the
        learning-rate scheduler and to the profiler callback factory.
      current_rank: Rank of this worker. The TensorBoard callback is only
        attached when this equals 0.
      cluster_size: Number of workers; forwarded to the learning-rate
        scheduler.
      learning_rate_schedule_fn: Schedule handed to
        `LearningRateBatchScheduler`. If falsy, or if `FLAGS.use_tensor_lr`
        is set, no learning-rate callback is created.

    Returns:
      A list whose first element is a `keras_utils.TimeHistory`, followed by
      the optional learning-rate, TensorBoard and profiler callbacks in that
      order.
    """
    result = [keras_utils.TimeHistory(FLAGS.batch_size, FLAGS.log_steps)]

    if not FLAGS.use_tensor_lr:
        if learning_rate_schedule_fn:
            result.append(
                LearningRateBatchScheduler(
                    learning_rate_schedule_fn,
                    batch_size=FLAGS.batch_size,
                    steps_per_epoch=steps_per_epoch,
                    cluster_size=cluster_size,
                )
            )

    if FLAGS.enable_tensorboard:
        # Only rank 0 gets a TensorBoard callback.
        if current_rank == 0:
            result.append(
                tf.keras.callbacks.TensorBoard(log_dir=FLAGS.model_dir)
            )

    if FLAGS.profile_steps:
        result.append(
            keras_utils.get_profiler_callback(
                FLAGS.model_dir,
                FLAGS.profile_steps,
                FLAGS.enable_tensorboard,
                steps_per_epoch,
            )
        )

    return result
def get_callbacks(learning_rate_schedule_fn, num_images):
    """Returns common callbacks.

    Args:
      learning_rate_schedule_fn: Schedule handed to
        `LearningRateBatchScheduler`. If falsy (e.g. None), or if
        `FLAGS.use_tensor_lr` is set, no learning-rate callback is added.
      num_images: Forwarded to `LearningRateBatchScheduler` as `num_images`.

    Returns:
      A list starting with a `keras_utils.TimeHistory` callback, followed by
      the learning-rate scheduler (see above), a TensorBoard callback when
      `FLAGS.enable_tensorboard` is set, and a profiler callback when
      `FLAGS.profile_steps` is set, in that order.
    """
    time_callback = keras_utils.TimeHistory(FLAGS.batch_size, FLAGS.log_steps)
    callbacks = [time_callback]

    # Also require an actual schedule function. Without this guard the
    # scheduler would be constructed around None and would presumably only
    # fail later, once it tries to call the schedule during training.
    if not FLAGS.use_tensor_lr and learning_rate_schedule_fn:
        lr_callback = LearningRateBatchScheduler(
            learning_rate_schedule_fn,
            batch_size=FLAGS.batch_size,
            num_images=num_images)
        callbacks.append(lr_callback)

    if FLAGS.enable_tensorboard:
        tensorboard_callback = tf.keras.callbacks.TensorBoard(
            log_dir=FLAGS.model_dir)
        callbacks.append(tensorboard_callback)

    if FLAGS.profile_steps:
        profiler_callback = keras_utils.get_profiler_callback(
            FLAGS.model_dir,
            FLAGS.profile_steps,
            FLAGS.enable_tensorboard)
        callbacks.append(profiler_callback)

    return callbacks
def get_callbacks(
        steps_per_epoch,
        learning_rate_schedule_fn=None,
        pruning_method=None,
        enable_checkpoint_and_export=False,
        model_dir=None):
    """Build the shared list of Keras training callbacks.

    Args:
      steps_per_epoch: Steps per epoch; forwarded to the learning-rate
        scheduler and to the profiler callback factory.
      learning_rate_schedule_fn: Optional schedule for
        `LearningRateBatchScheduler`. The scheduler is skipped when this is
        falsy or when `FLAGS.use_tensor_lr` is set.
      pruning_method: When not None, an `UpdatePruningStep` callback is
        added, plus `PruningSummaries` if `model_dir` is also given.
      enable_checkpoint_and_export: When true and `model_dir` is given, adds
        a weights-only `ModelCheckpoint` writing
        `model_dir/model.ckpt-{epoch:04d}`.
      model_dir: Directory used by the pruning-summary and checkpoint
        callbacks.

    Returns:
      A list beginning with a `keras_utils.TimeHistory`, followed by the
      enabled learning-rate, TensorBoard, profiler, pruning and checkpoint
      callbacks, in that order.
    """
    # NOTE(review): TimeHistory/TensorBoard/profiler write under
    # FLAGS.model_dir while pruning/checkpoint use the `model_dir` argument;
    # confirm this split is intended.
    callbacks = [
        keras_utils.TimeHistory(
            FLAGS.batch_size,
            FLAGS.log_steps,
            logdir=FLAGS.model_dir if FLAGS.enable_tensorboard else None),
    ]

    if not FLAGS.use_tensor_lr:
        if learning_rate_schedule_fn:
            callbacks.append(
                LearningRateBatchScheduler(
                    learning_rate_schedule_fn,
                    batch_size=FLAGS.batch_size,
                    steps_per_epoch=steps_per_epoch))

    if FLAGS.enable_tensorboard:
        callbacks.append(
            tf.keras.callbacks.TensorBoard(log_dir=FLAGS.model_dir))

    if FLAGS.profile_steps:
        callbacks.append(
            keras_utils.get_profiler_callback(
                FLAGS.model_dir,
                FLAGS.profile_steps,
                FLAGS.enable_tensorboard,
                steps_per_epoch))

    if pruning_method is not None:
        callbacks.append(tfmot.sparsity.keras.UpdatePruningStep())
        if model_dir is not None:
            callbacks.append(
                tfmot.sparsity.keras.PruningSummaries(
                    log_dir=model_dir, profile_batch=0))

    if enable_checkpoint_and_export and model_dir is not None:
        checkpoint_pattern = os.path.join(model_dir, 'model.ckpt-{epoch:04d}')
        callbacks.append(
            tf.keras.callbacks.ModelCheckpoint(
                checkpoint_pattern, save_weights_only=True))

    return callbacks
def get_callbacks():
    """Collect the Keras callbacks switched on by command-line flags.

    The result contains, in order and only when the matching flag is set:
    a `keras_utils.TimeHistory` (`FLAGS.enable_time_history`), a TensorBoard
    callback logging to `FLAGS.model_dir` (`FLAGS.enable_tensorboard`), and
    a profiler callback (`FLAGS.profile_steps`). The list may be empty.
    """
    def _enabled_callbacks():
        # Yields each enabled callback in the fixed order described above.
        if FLAGS.enable_time_history:
            yield keras_utils.TimeHistory(FLAGS.batch_size, FLAGS.log_steps)
        if FLAGS.enable_tensorboard:
            yield tf.keras.callbacks.TensorBoard(log_dir=FLAGS.model_dir)
        if FLAGS.profile_steps:
            yield keras_utils.get_profiler_callback(
                FLAGS.model_dir, FLAGS.profile_steps, FLAGS.enable_tensorboard)

    return list(_enabled_callbacks())