Exemplo n.º 1
0
def _get_student_callbacks(log_dir, save_dir, current_iteration):
    """Builds the Keras callbacks used while training the student model.

    Logs and checkpoints are written to a per-iteration sub-directory
    ``student/iteration_NNNN`` (iteration zero-padded to four digits) below
    the respective root directory; those sub-directories are created on
    demand with ``os.makedirs(..., exist_ok=True)``.

    Args:
        log_dir: Root directory for TensorBoard and learning-rate logs. When
            falsy, the two logging callbacks are omitted.
        save_dir: Root directory under which model checkpoints are written.
        current_iteration: Index of the current training iteration; used to
            name the per-iteration sub-directories.

    Returns:
        A list of callbacks in this order: (optionally) TensorBoard and the
        learning-rate logger, then early stopping, learning-rate reduction on
        plateau, and finally a checkpoint callback that saves the full model
        whenever ``val_loss`` improves.
    """
    iteration_subdir = f'iteration_{current_iteration:04d}'
    callbacks = []

    # Logging is optional; every callback after this branch is always added.
    if log_dir:
        tensorboard_dir = os.path.join(log_dir, 'student', iteration_subdir)
        os.makedirs(tensorboard_dir, exist_ok=True)
        callbacks.extend([
            tf.keras.callbacks.TensorBoard(log_dir=tensorboard_dir,
                                           histogram_freq=1),
            utils.LearningRateLogger(log_dir=tensorboard_dir,
                                     name='student_learning_rate'),
        ])

    # Both schedules track validation loss (lower is better).
    # NOTE(review): the custom classes presumably mirror tf.keras
    # EarlyStopping / ReduceLROnPlateau semantics -- confirm in utils.
    callbacks.extend([
        utils.CustomEarlyStopping(
            monitor='val_loss',
            min_delta=0,
            patience=60,
            verbose=1,
            mode='min',
            restore_best_weights=True,
        ),
        utils.CustomReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=20,
            verbose=1,
            mode='min',
            min_delta=0.0001,
            cooldown=0,
            min_lr=0.0000001,
        ),
    ])

    checkpoint_dir = os.path.join(save_dir, 'student', iteration_subdir)
    os.makedirs(checkpoint_dir, exist_ok=True)
    # Deliberately NOT an f-string: Keras fills in {epoch} when it saves.
    checkpoint_pattern = os.path.join(checkpoint_dir,
                                      'weights.{epoch:04d}.hdf5')
    callbacks.append(
        tf.keras.callbacks.ModelCheckpoint(
            filepath=checkpoint_pattern,
            monitor='val_loss',
            verbose=1,
            save_best_only=True,
            save_weights_only=False,
            mode='min',
            save_freq='epoch',
        ))
    return callbacks
Exemplo n.º 2
0
def _get_mentor_callbacks(log_dir, save_dir, current_iteration):
  """Builds the Keras callbacks used while training the mentor model.

  Logs and checkpoints go to a per-iteration sub-directory
  ``mentor/iteration_NNNN`` (iteration zero-padded to four digits) below the
  respective root directory; sub-directories are created on demand.

  Args:
    log_dir: Root directory for TensorBoard and learning-rate logs. When
      falsy, the two logging callbacks are omitted.
    save_dir: Root directory under which model checkpoints are written.
    current_iteration: Index of the current training iteration; used to name
      the per-iteration sub-directories.

  Returns:
    A list of callbacks in this order: (optionally) TensorBoard and the
    learning-rate logger, then early stopping, learning-rate reduction on
    plateau, and finally a checkpoint callback that saves the full model
    whenever ``val_loss`` improves.
  """
  iteration_name = f'iteration_{current_iteration:04d}'

  def _iteration_dir(root):
    # <root>/mentor/iteration_NNNN
    return os.path.join(root, 'mentor', iteration_name)

  result = []

  if log_dir:
    tb_dir = _iteration_dir(log_dir)
    os.makedirs(tb_dir, exist_ok=True)
    result += [
        tf.keras.callbacks.TensorBoard(log_dir=tb_dir, histogram_freq=1),
        utils.LearningRateLogger(log_dir=tb_dir, name='mentor_learning_rate'),
    ]

  # Both schedules track validation loss (lower is better).
  # NOTE(review): the custom classes presumably mirror tf.keras
  # EarlyStopping / ReduceLROnPlateau semantics -- confirm in utils.
  early_stopping = utils.CustomEarlyStopping(
      monitor='val_loss',
      min_delta=0,
      patience=100,
      verbose=1,
      mode='min',
      restore_best_weights=True)
  result.append(early_stopping)

  lr_on_plateau = utils.CustomReduceLROnPlateau(
      monitor='val_loss',
      factor=0.5,
      patience=20,
      verbose=1,
      mode='min',
      min_delta=0,
      cooldown=0,
      min_lr=0.0000001)
  result.append(lr_on_plateau)

  ckpt_dir = _iteration_dir(save_dir)
  os.makedirs(ckpt_dir, exist_ok=True)
  # Deliberately NOT an f-string: Keras fills in {epoch} when it saves.
  best_model_saver = tf.keras.callbacks.ModelCheckpoint(
      filepath=os.path.join(ckpt_dir, 'weights.{epoch:04d}.hdf5'),
      monitor='val_loss',
      verbose=1,
      save_best_only=True,
      save_weights_only=False,
      mode='min',
      save_freq='epoch')
  result.append(best_model_saver)

  return result