def test_create_dataset_from_tables(self):
    testdata_dir = 'poem/testdata'  # Assume $PWD == "google_research/".
    dataset = pipelines.create_dataset_from_tables(
        [os.path.join(FLAGS.test_srcdir, testdata_dir, 'tfe-2.tfrecords')],
        batch_sizes=[4],
        num_instances_per_record=2,
        shuffle=True,
        num_epochs=None,
        keypoint_names_3d=keypoint_profiles.create_keypoint_profile_or_die(
            'LEGACY_3DH36M17').keypoint_names,
        keypoint_names_2d=keypoint_profiles.create_keypoint_profile_or_die(
            'LEGACY_2DCOCO13').keypoint_names,
        seed=0)

    inputs = list(dataset.take(1))[0]
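    # Each batch should contain exactly these keys, with tensors shaped
    # [batch_size, num_instances, ...].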
    self.assertCountEqual(inputs.keys(), [
        'image_sizes', 'keypoints_2d', 'keypoint_scores_2d', 'keypoints_3d'
    ])
    self.assertEqual(inputs['image_sizes'].shape, [4, 2, 2])
    self.assertEqual(inputs['keypoints_2d'].shape, [4, 2, 13, 2])
    self.assertEqual(inputs['keypoint_scores_2d'].shape, [4, 2, 13])
    self.assertEqual(inputs['keypoints_3d'].shape, [4, 2, 17, 3])


def run(input_dataset_class, common_module, keypoint_profiles_module,
        input_example_parser_creator, keypoint_preprocessor_3d):
  """Runs training pipeline.

  Args:
    input_dataset_class: An input dataset class that matches the input table
      type.
    common_module: A Python module that defines common flags and constants.
    keypoint_profiles_module: A Python module that defines keypoint profiles.
    input_example_parser_creator: A function handle for creating the data
      parser function. If None, uses the default parser creator.
    keypoint_preprocessor_3d: A function handle for preprocessing raw 3D
      keypoints.
  """
  _validate(common_module)

  log_dir_path = FLAGS.log_dir_path
  pipeline_utils.create_dir_and_save_flags(flags, log_dir_path,
                                           'all_flags.train.json')

  # Setup summary writer.
  summary_writer = tf.summary.create_file_writer(
      os.path.join(log_dir_path, 'train_logs'), flush_millis=10000)

  # Setup configuration.
  keypoint_profile_2d = keypoint_profiles_module.create_keypoint_profile_or_die(
      FLAGS.keypoint_profile_name_2d)
  keypoint_profile_3d = keypoint_profiles_module.create_keypoint_profile_or_die(
      FLAGS.keypoint_profile_name_3d)

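  # Setup model, optimizers, and checkpoint manager.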
  model = algorithms.get_algorithm(
      algorithm_type=FLAGS.algorithm_type,
      pose_embedding_dim=FLAGS.pose_embedding_dim,
      view_embedding_dim=FLAGS.view_embedding_dim,
      fusion_op_type=FLAGS.fusion_op_type,
      view_loss_weight=FLAGS.view_loss_weight,
      regularization_loss_weight=FLAGS.regularization_loss_weight,
      disentangle_loss_weight=FLAGS.disentangle_loss_weight,
      embedder_type=FLAGS.embedder_type)
  optimizers = algorithms.get_optimizers(
      algorithm_type=FLAGS.algorithm_type, learning_rate=FLAGS.learning_rate)
  global_step = optimizers['encoder_optimizer'].iterations
  ckpt_manager, _, _ = utils.create_checkpoint(
      log_dir_path, **optimizers, model=model, global_step=global_step)

  # Setup the training dataset.
  dataset = pipelines.create_dataset_from_tables(
      [FLAGS.input_table], [FLAGS.batch_size],
      num_instances_per_record=2,
      shuffle=True,
      num_epochs=None,
      drop_remainder=True,
      keypoint_names_2d=keypoint_profile_2d.keypoint_names,
      keypoint_names_3d=keypoint_profile_3d.keypoint_names,
      shuffle_buffer_size=FLAGS.shuffle_buffer_size,
      dataset_class=input_dataset_class,
      input_example_parser_creator=input_example_parser_creator)

  def train_one_iteration(inputs):
    """Trains the model for one iteration.

    Args:
      inputs: A dictionary for training inputs.

    Returns:
      A dictionary of training losses for this iteration.
    """
    _, side_outputs = pipelines.create_model_input(
        inputs, FLAGS.model_input_keypoint_type, keypoint_profile_2d,
        keypoint_profile_3d)

    keypoints_2d = side_outputs[common_module.KEY_PREPROCESSED_KEYPOINTS_2D]
    keypoints_3d, _ = keypoint_preprocessor_3d(
        inputs[common_module.KEY_KEYPOINTS_3D],
        keypoint_profile_3d,
        normalize_keypoints_3d=True)
    keypoints_2d, keypoints_3d = data_utils.shuffle_batches(
        [keypoints_2d, keypoints_3d])

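    # model.train returns losses keyed first by sub-model (e.g., 'encoder') and
    # then by loss name.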
    return model.train((keypoints_2d, keypoints_3d), **optimizers)

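  # Optionally wrap the training step in a tf.function for graph execution.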
  if FLAGS.compile:
    train_one_iteration = tf.function(train_one_iteration)

  record_every_n_steps = min(100, FLAGS.num_iterations)
  save_ckpt_every_n_steps = min(10000, FLAGS.num_iterations)

  with summary_writer.as_default():
    with tf.summary.record_if(global_step % record_every_n_steps == 0):
      start = time.time()
      for inputs in dataset:
        if global_step >= FLAGS.num_iterations:
          break

        model_losses = train_one_iteration(inputs)
        duration = time.time() - start
        start = time.time()

        for tag, losses in model_losses.items():
          for name, loss in losses.items():
            tf.summary.scalar(
                'train/{}/{}'.format(tag, name), loss, step=global_step)

        for tag, optimizer in optimizers.items():
          tf.summary.scalar(
              'train/{}_learning_rate'.format(tag),
              optimizer.lr,
              step=global_step)

        tf.summary.scalar('train/batch_time', duration, step=global_step)
        tf.summary.scalar('global_step/sec', 1 / duration, step=global_step)

        if global_step % record_every_n_steps == 0:
          logging.info('Iter[{}/{}], {:.6f}s/iter, loss: {:.4f}'.format(
              global_step.numpy(), FLAGS.num_iterations, duration,
              model_losses['encoder']['total_loss'].numpy()))

        # Save checkpoint.
        if global_step % save_ckpt_every_n_steps == 0:
          ckpt_manager.save(checkpoint_number=global_step)
          logging.info('Checkpoint saved at step %d.', global_step.numpy())
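
# A minimal sketch of how the `run` function above might be wired into an absl
# binary. This is an illustration only: the `poem_*` names are hypothetical
# placeholders for the dataset class, common/keypoint-profile modules, and 3D
# keypoint preprocessor that the surrounding project actually provides.
from absl import app


def hypothetical_main(_):
  run(input_dataset_class=poem_input_dataset_class,  # Hypothetical placeholder.
      common_module=poem_common_module,  # Hypothetical placeholder.
      keypoint_profiles_module=poem_keypoint_profiles_module,  # Hypothetical.
      input_example_parser_creator=None,  # None uses the default parser creator.
      keypoint_preprocessor_3d=poem_keypoint_preprocessor_3d)  # Hypothetical.


if __name__ == '__main__':
  app.run(hypothetical_main)
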
def run(input_dataset_class, common_module, keypoint_profiles_module,
        input_example_parser_creator):
  """Runs training pipeline.

  Args:
    input_dataset_class: An input dataset class that matches input table type.
    common_module: A Python module that defines common flags and constants.
    keypoint_profiles_module: A Python module that defines keypoint profiles.
    input_example_parser_creator: A function handle for creating data parser
      function. If None, uses the default parser creator.
  """
  log_dir_path = FLAGS.log_dir_path
  if not tf.io.gfile.exists(log_dir_path):
    raise ValueError(
        'The directory {} does not exist. Please provide a new log_dir_path.'
        .format(log_dir_path))
  eval_log_dir = os.path.join(log_dir_path, FLAGS.eval_name)
  pipeline_utils.create_dir_and_save_flags(flags, eval_log_dir,
                                           'all_flags.eval_with_features.json')

  # Setup summary writer.
  summary_writer = tf.summary.create_file_writer(
      eval_log_dir, flush_millis=10000)

  # Setup configuration.
  keypoint_profile_2d = keypoint_profiles_module.create_keypoint_profile_or_die(
      FLAGS.keypoint_profile_name_2d)

  # Setup model.
  input_length = math.ceil(FLAGS.num_frames / FLAGS.downsample_rate)
  if FLAGS.input_features_dim > 0:
    feature_dim = FLAGS.input_features_dim
    input_shape = (input_length, feature_dim)
  else:
    feature_dim = None
    input_shape = (input_length, 13 * 2)
  classifier = models.get_temporal_classifier(
      FLAGS.classifier_type,
      input_shape=input_shape,
      num_classes=FLAGS.num_classes)
  ema_classifier = tf.keras.models.clone_model(classifier)
  global_step = tf.Variable(0, dtype=tf.int64, trainable=False)
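  # Setup the evaluation dataset: a single in-order pass that keeps partial
  # batches.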
  dataset = pipelines.create_dataset_from_tables(
      FLAGS.input_tables, [int(x) for x in FLAGS.batch_sizes],
      num_instances_per_record=1,
      shuffle=False,
      num_epochs=1,
      drop_remainder=False,
      keypoint_names_2d=keypoint_profile_2d.keypoint_names,
      feature_dim=feature_dim,
      num_classes=FLAGS.num_classes,
      num_frames=FLAGS.num_frames,
      common_module=common_module,
      dataset_class=input_dataset_class,
      input_example_parser_creator=input_example_parser_creator)

  if FLAGS.compile:
    classifier.call = tf.function(classifier.call)
    ema_classifier.call = tf.function(ema_classifier.call)

  top_1_best_accuracy = None
  top_5_best_accuracy = None
  evaluated_last_ckpt = False

  def timeout_fn():
    """Timeout function to stop the evaluation."""
    return evaluated_last_ckpt

  def evaluate_once():
    """Evaluates the model for one time."""
    _, status, _ = utils.create_checkpoint(
        log_dir_path,
        model=classifier,
        ema_model=ema_classifier,
        global_step=global_step)
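    # The training checkpoint also contains optimizer state that is not
    # restored here, hence expect_partial().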
    status.expect_partial()
    logging.info('Last checkpoint [iteration: %d] restored at %s.',
                 global_step.numpy(), log_dir_path)

    if global_step.numpy() >= FLAGS.max_iteration:
      nonlocal evaluated_last_ckpt
      evaluated_last_ckpt = True

    top_1_accuracy = tf.keras.metrics.CategoricalAccuracy()
    top_5_accuracy = tf.keras.metrics.TopKCategoricalAccuracy(k=5)
    for inputs in dataset:
      if FLAGS.input_features_dim > 0:
        features = inputs[common_module.KEY_FEATURES]
      else:
        features, _ = pipelines.create_model_input(
            inputs, common_module.MODEL_INPUT_KEYPOINT_TYPE_2D_INPUT,
            keypoint_profile_2d)
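      # Drop the singleton instance dimension and temporally downsample to
      # input_length frames.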
      features = tf.squeeze(features, axis=1)
      features = features[:, ::FLAGS.downsample_rate, :]
      labels = inputs[common_module.KEY_CLASS_TARGETS]
      labels = tf.squeeze(labels, axis=1)

      if FLAGS.use_moving_average:
        outputs = ema_classifier(features, training=False)
      else:
        outputs = classifier(features, training=False)
      top_1_accuracy.update_state(y_true=labels, y_pred=outputs)
      top_5_accuracy.update_state(y_true=labels, y_pred=outputs)

    nonlocal top_1_best_accuracy
    if (top_1_best_accuracy is None or
        top_1_accuracy.result().numpy() > top_1_best_accuracy):
      top_1_best_accuracy = top_1_accuracy.result().numpy()

    nonlocal top_5_best_accuracy
    if (top_5_best_accuracy is None or
        top_5_accuracy.result().numpy() > top_5_best_accuracy):
      top_5_best_accuracy = top_5_accuracy.result().numpy()

    tf.summary.scalar(
        'eval/Basic/Top1_Accuracy',
        top_1_accuracy.result(),
        step=global_step.numpy())
    tf.summary.scalar(
        'eval/Best/Top1_Accuracy',
        top_1_best_accuracy,
        step=global_step.numpy())
    tf.summary.scalar(
        'eval/Basic/Top5_Accuracy',
        top_5_accuracy.result(),
        step=global_step.numpy())
    tf.summary.scalar(
        'eval/Best/Top5_Accuracy',
        top_5_best_accuracy,
        step=global_step.numpy())
    logging.info('Top-1 accuracy: {:.2f}'.format(
        top_1_accuracy.result().numpy()))

  with summary_writer.as_default():
    with tf.summary.record_if(True):
      if FLAGS.continuous_eval:
        for _ in tf.train.checkpoints_iterator(log_dir_path, timeout=1,
                                               timeout_fn=timeout_fn):
          evaluate_once()
      else:
        evaluate_once()


def run(input_dataset_class, common_module, keypoint_profiles_module,
        input_example_parser_creator):
  """Runs classifier training pipeline.

  Args:
    input_dataset_class: An input dataset class that matches the input table
      type.
    common_module: A Python module that defines common flags and constants.
    keypoint_profiles_module: A Python module that defines keypoint profiles.
    input_example_parser_creator: A function handle for creating the data
      parser function. If None, uses the default parser creator.
  """
  log_dir_path = FLAGS.log_dir_path
  pipeline_utils.create_dir_and_save_flags(
      flags, log_dir_path, 'all_flags.train_with_features.json')

  # Setup summary writer.
  summary_writer = tf.summary.create_file_writer(
      os.path.join(log_dir_path, 'train_logs'), flush_millis=10000)

  # Setup configuration.
  keypoint_profile_2d = keypoint_profiles_module.create_keypoint_profile_or_die(
      FLAGS.keypoint_profile_name_2d)

  # Setup model.
  input_length = math.ceil(FLAGS.num_frames / FLAGS.downsample_rate)
  if FLAGS.input_features_dim > 0:
    feature_dim = FLAGS.input_features_dim
    input_shape = (input_length, feature_dim)
  else:
    feature_dim = None
    input_shape = (input_length, 13 * 2)
  classifier = models.get_temporal_classifier(
      FLAGS.classifier_type,
      input_shape=input_shape,
      num_classes=FLAGS.num_classes)
  ema_classifier = tf.keras.models.clone_model(classifier)
  optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.learning_rate)
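  # Wrap the optimizer so it also tracks an exponential moving average (EMA) of
  # the weights; the EMA weights are copied into ema_classifier before saving.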
  optimizer = tfa_optimizers.MovingAverage(optimizer)
  global_step = optimizer.iterations
  ckpt_manager, _, _ = utils.create_checkpoint(
      log_dir_path,
      optimizer=optimizer,
      model=classifier,
      ema_model=ema_classifier,
      global_step=global_step)

  # Setup the training dataset.
  dataset = pipelines.create_dataset_from_tables(
      FLAGS.input_tables, [int(x) for x in FLAGS.batch_sizes],
      num_instances_per_record=1,
      shuffle=True,
      drop_remainder=True,
      num_epochs=None,
      keypoint_names_2d=keypoint_profile_2d.keypoint_names,
      feature_dim=feature_dim,
      num_classes=FLAGS.num_classes,
      num_frames=FLAGS.num_frames,
      shuffle_buffer_size=FLAGS.shuffle_buffer_size,
      common_module=common_module,
      dataset_class=input_dataset_class,
      input_example_parser_creator=input_example_parser_creator)
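  # The classifier outputs raw logits, so the loss applies softmax internally.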
  loss_object = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

  def train_one_iteration(inputs):
    """Trains the model for one iteration.

    Args:
      inputs: A dictionary for training inputs.

    Returns:
      A dictionary of training losses for this iteration.
    """
    if FLAGS.input_features_dim > 0:
      features = inputs[common_module.KEY_FEATURES]
    else:
      features, _ = pipelines.create_model_input(
          inputs, common_module.MODEL_INPUT_KEYPOINT_TYPE_2D_INPUT,
          keypoint_profile_2d)

    features = tf.squeeze(features, axis=1)
    features = features[:, ::FLAGS.downsample_rate, :]
    labels = inputs[common_module.KEY_CLASS_TARGETS]
    labels = tf.squeeze(labels, axis=1)

    with tf.GradientTape() as tape:
      outputs = classifier(features, training=True)
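      # Keras collects layer regularization penalties in classifier.losses.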
      regularization_loss = sum(classifier.losses)
      crossentropy_loss = loss_object(labels, outputs)
      total_loss = crossentropy_loss + regularization_loss

    trainable_variables = classifier.trainable_variables
    grads = tape.gradient(total_loss, trainable_variables)
    optimizer.apply_gradients(zip(grads, trainable_variables))

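    # Log the gradient norm of each trainable variable for monitoring.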
    for grad, trainable_variable in zip(grads, trainable_variables):
      tf.summary.scalar(
          'summarize_grads/' + trainable_variable.name,
          tf.linalg.norm(grad),
          step=global_step)

    return dict(
        total_loss=total_loss,
        crossentropy_loss=crossentropy_loss,
        regularization_loss=regularization_loss)

  if FLAGS.compile:
    train_one_iteration = tf.function(train_one_iteration)

  record_every_n_steps = min(5, FLAGS.num_iterations)
  save_ckpt_every_n_steps = min(500, FLAGS.num_iterations)

  with summary_writer.as_default():
    with tf.summary.record_if(global_step % record_every_n_steps == 0):
      start = time.time()
      for inputs in dataset:
        if global_step >= FLAGS.num_iterations:
          break

        model_losses = train_one_iteration(inputs)
        duration = time.time() - start
        start = time.time()

        for name, loss in model_losses.items():
          tf.summary.scalar('train/' + name, loss, step=global_step)

        tf.summary.scalar('train/learning_rate', optimizer.lr, step=global_step)
        tf.summary.scalar('train/batch_time', duration, step=global_step)
        tf.summary.scalar('global_step/sec', 1 / duration, step=global_step)

        if global_step % record_every_n_steps == 0:
          logging.info('Iter[{}/{}], {:.6f}s/iter, loss: {:.4f}'.format(
              global_step.numpy(), FLAGS.num_iterations, duration,
              model_losses['total_loss'].numpy()))

        # Save checkpoint.
        if global_step % save_ckpt_every_n_steps == 0:
          utils.assign_moving_average_vars(classifier, ema_classifier,
                                           optimizer)
          ckpt_manager.save(checkpoint_number=global_step)
          logging.info('Checkpoint saved at step %d.', global_step.numpy())