Example #1
# Assumed imports for this excerpt; learning_rate_schedule() is defined
# elsewhere in the original module.
from absl import logging
import tensorflow as tf


def get_optimizer(params):
    """Get optimizer."""
    learning_rate = learning_rate_schedule(params)
    momentum = params['momentum']
    if params['optimizer'].lower() == 'sgd':
        logging.info('Use SGD optimizer')
        optimizer = tf.keras.optimizers.SGD(learning_rate, momentum=momentum)
    elif params['optimizer'].lower() == 'adam':
        logging.info('Use Adam optimizer')
        optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=momentum)
    else:
        raise ValueError('optimizer should be "adam" or "sgd"')

    moving_average_decay = params['moving_average_decay']
    if moving_average_decay:
        # TODO(tanmingxing): potentially add dynamic_decay for new tfa release.
        from tensorflow_addons import optimizers as tfa_optimizers  # pylint: disable=g-import-not-at-top
        optimizer = tfa_optimizers.MovingAverage(
            optimizer, average_decay=moving_average_decay, dynamic_decay=True)
    if params['mixed_precision']:
        optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
            optimizer,
            loss_scale=tf.mixed_precision.experimental.DynamicLossScale(
                params['loss_scale']))
    return optimizer
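
A minimal usage sketch for the function above (hypothetical values; the real
config and the keys read by learning_rate_schedule(params) come from the
original module):

example_params = {
    'optimizer': 'sgd',
    'momentum': 0.9,
    'moving_average_decay': 0.0,  # Set > 0 to enable the tfa MovingAverage wrapper.
    'mixed_precision': False,     # True also wraps the optimizer for loss scaling.
    'loss_scale': None,
    # ...plus whatever keys learning_rate_schedule(params) expects.
}
optimizer = get_optimizer(example_params)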
Example #2
# Assumed imports for this excerpt; learning_rate_schedule() and utils come
# from the surrounding repository.
from absl import logging
import tensorflow as tf


def get_optimizer(params):
  """Get optimizer."""
  learning_rate = learning_rate_schedule(params)
  momentum = params['momentum']
  if params['optimizer'].lower() == 'sgd':
    logging.info('Use SGD optimizer')
    optimizer = tf.keras.optimizers.SGD(learning_rate, momentum=momentum)
  elif params['optimizer'].lower() == 'adam':
    logging.info('Use Adam optimizer')
    optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=momentum)
  else:
    raise ValueError('optimizer should be "adam" or "sgd"')

  moving_average_decay = params['moving_average_decay']
  if moving_average_decay:
    from tensorflow_addons import optimizers as tfa_optimizers  # pylint: disable=g-import-not-at-top
    optimizer = tfa_optimizers.MovingAverage(
        optimizer, average_decay=moving_average_decay, dynamic_decay=True)
  precision = utils.get_precision(params['strategy'], params['mixed_precision'])
  if precision == 'mixed_float16' and params['loss_scale']:
    optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
        optimizer,
        initial_scale=params['loss_scale'])
  return optimizer
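
For reference, a minimal standalone sketch of the mixed-precision wiring that
Example #2 relies on (assumes TF 2.4+; the policy, learning rate, and initial
loss scale below are illustrative, not taken from the repository):

import tensorflow as tf

# Run layer computations in float16 while keeping variables in float32.
tf.keras.mixed_precision.set_global_policy('mixed_float16')

base_optimizer = tf.keras.optimizers.SGD(0.08, momentum=0.9)
# Dynamic loss scaling guards against float16 gradient underflow; initial_scale
# only seeds the scale, which is then adjusted automatically during training.
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
    base_optimizer, initial_scale=2 ** 15)
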
# Relies on module-level imports and flags from the original file: absl flags
# and logging, math, os, time, tensorflow as tf, tensorflow_addons.optimizers
# as tfa_optimizers, and the repository's models, pipelines, pipeline_utils
# and utils modules.
def run(input_dataset_class, common_module, keypoint_profiles_module,
        input_example_parser_creator):
  """Runs training pipeline.

  Args:
    input_dataset_class: An input dataset class that matches input table type.
    common_module: A Python module that defines common flags and constants.
    keypoint_profiles_module: A Python module that defines keypoint profiles.
    input_example_parser_creator: A function handle for creating the data
      parser function. If None, uses the default parser creator.
  """
  log_dir_path = FLAGS.log_dir_path
  pipeline_utils.create_dir_and_save_flags(
      flags, log_dir_path, 'all_flags.train_with_features.json')

  # Setup summary writer.
  summary_writer = tf.summary.create_file_writer(
      os.path.join(log_dir_path, 'train_logs'), flush_millis=10000)

  # Setup configuration.
  keypoint_profile_2d = keypoint_profiles_module.create_keypoint_profile_or_die(
      FLAGS.keypoint_profile_name_2d)

  # Setup model.
  input_length = math.ceil(FLAGS.num_frames / FLAGS.downsample_rate)
  if FLAGS.input_features_dim > 0:
    feature_dim = FLAGS.input_features_dim
    input_shape = (input_length, feature_dim)
  else:
    feature_dim = None
    input_shape = (input_length, 13 * 2)
  classifier = models.get_temporal_classifier(
      FLAGS.classifier_type,
      input_shape=input_shape,
      num_classes=FLAGS.num_classes)
  ema_classifier = tf.keras.models.clone_model(classifier)
  optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.learning_rate)
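  # Keep exponential moving averages of the model weights; they are copied
  # into ema_classifier right before each checkpoint is saved below.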
  optimizer = tfa_optimizers.MovingAverage(optimizer)
  global_step = optimizer.iterations
  ckpt_manager, _, _ = utils.create_checkpoint(
      log_dir_path,
      optimizer=optimizer,
      model=classifier,
      ema_model=ema_classifier,
      global_step=global_step)

  # Setup the training dataset.
  dataset = pipelines.create_dataset_from_tables(
      FLAGS.input_tables, [int(x) for x in FLAGS.batch_sizes],
      num_instances_per_record=1,
      shuffle=True,
      drop_remainder=True,
      num_epochs=None,
      keypoint_names_2d=keypoint_profile_2d.keypoint_names,
      feature_dim=feature_dim,
      num_classes=FLAGS.num_classes,
      num_frames=FLAGS.num_frames,
      shuffle_buffer_size=FLAGS.shuffle_buffer_size,
      common_module=common_module,
      dataset_class=input_dataset_class,
      input_example_parser_creator=input_example_parser_creator)
  loss_object = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

  def train_one_iteration(inputs):
    """Trains the model for one iteration.

    Args:
      inputs: A dictionary for training inputs.

    Returns:
      loss: The training loss for this iteration.
    """
    if FLAGS.input_features_dim > 0:
      features = inputs[common_module.KEY_FEATURES]
    else:
      features, _ = pipelines.create_model_input(
          inputs, common_module.MODEL_INPUT_KEYPOINT_TYPE_2D_INPUT,
          keypoint_profile_2d)

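    # Drop the per-record instance dimension (num_instances_per_record=1) and
    # temporally downsample the sequence to match the model's input length.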
    features = tf.squeeze(features, axis=1)
    features = features[:, ::FLAGS.downsample_rate, :]
    labels = inputs[common_module.KEY_CLASS_TARGETS]
    labels = tf.squeeze(labels, axis=1)

    with tf.GradientTape() as tape:
      outputs = classifier(features, training=True)
      regularization_loss = sum(classifier.losses)
      crossentropy_loss = loss_object(labels, outputs)
      total_loss = crossentropy_loss + regularization_loss

    trainable_variables = classifier.trainable_variables
    grads = tape.gradient(total_loss, trainable_variables)
    optimizer.apply_gradients(zip(grads, trainable_variables))

    for grad, trainable_variable in zip(grads, trainable_variables):
      tf.summary.scalar(
          'summarize_grads/' + trainable_variable.name,
          tf.linalg.norm(grad),
          step=global_step)

    return dict(
        total_loss=total_loss,
        crossentropy_loss=crossentropy_loss,
        regularization_loss=regularization_loss)

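  # Optionally compile the per-iteration step into a graph with tf.function.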
  if FLAGS.compile:
    train_one_iteration = tf.function(train_one_iteration)

  record_every_n_steps = min(5, FLAGS.num_iterations)
  save_ckpt_every_n_steps = min(500, FLAGS.num_iterations)

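  # Main training loop: summary recording is gated to every
  # record_every_n_steps steps, and checkpoints are written every
  # save_ckpt_every_n_steps steps.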
  with summary_writer.as_default():
    with tf.summary.record_if(global_step % record_every_n_steps == 0):
      start = time.time()
      for inputs in dataset:
        if global_step >= FLAGS.num_iterations:
          break

        model_losses = train_one_iteration(inputs)
        duration = time.time() - start
        start = time.time()

        for name, loss in model_losses.items():
          tf.summary.scalar('train/' + name, loss, step=global_step)

        tf.summary.scalar('train/learning_rate', optimizer.lr, step=global_step)
        tf.summary.scalar('train/batch_time', duration, step=global_step)
        tf.summary.scalar('global_step/sec', 1 / duration, step=global_step)

        if global_step % record_every_n_steps == 0:
          logging.info('Iter[{}/{}], {:.6f}s/iter, loss: {:.4f}'.format(
              global_step.numpy(), FLAGS.num_iterations, duration,
              model_losses['total_loss'].numpy()))

        # Save checkpoint.
        if global_step % save_ckpt_every_n_steps == 0:
          utils.assign_moving_average_vars(classifier, ema_classifier,
                                           optimizer)
          ckpt_manager.save(checkpoint_number=global_step)
          logging.info('Checkpoint saved at step %d.', global_step.numpy())
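
A hedged sketch of what the utils.assign_moving_average_vars call above does
conceptually (assumption: tfa.optimizers.MovingAverage keeps its shadow weights
in an 'average' slot per trained variable; the repository's helper may differ
in detail):

def swap_in_ema_weights(model, ema_model, moving_avg_optimizer):
  """Copies EMA shadow weights from the optimizer into the cloned model."""
  # Non-trainable variables (e.g. batch-norm statistics) have no EMA slot, so
  # start from a plain copy of the trained model's weights.
  ema_model.set_weights(model.get_weights())
  # Overwrite each trainable variable with its moving average.
  for var, ema_var in zip(model.trainable_variables,
                          ema_model.trainable_variables):
    ema_var.assign(moving_avg_optimizer.get_slot(var, 'average'))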