Example #1
def run(input_dataset_class, common_module, keypoint_profiles_module,
        input_example_parser_creator, keypoint_preprocessor_3d):
  """Runs training pipeline.

  Args:
    input_dataset_class: An input dataset class that matches input table type.
    common_module: A Python module that defines common flags and constants.
    keypoint_profiles_module: A Python module that defines keypoint profiles.
    input_example_parser_creator: A function handle for creating a data
      parser function. If None, uses the default parser creator.
    keypoint_preprocessor_3d: A function handle for preprocessing raw 3D
      keypoints.
  """
  _validate(common_module)

  log_dir_path = FLAGS.log_dir_path
  pipeline_utils.create_dir_and_save_flags(flags, log_dir_path,
                                           'all_flags.train.json')

  # Set up the summary writer.
  summary_writer = tf.summary.create_file_writer(
      os.path.join(log_dir_path, 'train_logs'), flush_millis=10000)

  # Set up the configuration.
  keypoint_profile_2d = keypoint_profiles_module.create_keypoint_profile_or_die(
      FLAGS.keypoint_profile_name_2d)
  keypoint_profile_3d = keypoint_profiles_module.create_keypoint_profile_or_die(
      FLAGS.keypoint_profile_name_3d)

  model = algorithms.get_algorithm(
      algorithm_type=FLAGS.algorithm_type,
      pose_embedding_dim=FLAGS.pose_embedding_dim,
      view_embedding_dim=FLAGS.view_embedding_dim,
      fusion_op_type=FLAGS.fusion_op_type,
      view_loss_weight=FLAGS.view_loss_weight,
      regularization_loss_weight=FLAGS.regularization_loss_weight,
      disentangle_loss_weight=FLAGS.disentangle_loss_weight,
      embedder_type=FLAGS.embedder_type)
  optimizers = algorithms.get_optimizers(
      algorithm_type=FLAGS.algorithm_type, learning_rate=FLAGS.learning_rate)
  global_step = optimizers['encoder_optimizer'].iterations
  ckpt_manager, _, _ = utils.create_checkpoint(
      log_dir_path, **optimizers, model=model, global_step=global_step)

  # Set up the training dataset.
  dataset = pipelines.create_dataset_from_tables(
      [FLAGS.input_table], [FLAGS.batch_size],
      num_instances_per_record=2,
      shuffle=True,
      num_epochs=None,
      drop_remainder=True,
      keypoint_names_2d=keypoint_profile_2d.keypoint_names,
      keypoint_names_3d=keypoint_profile_3d.keypoint_names,
      shuffle_buffer_size=FLAGS.shuffle_buffer_size,
      dataset_class=input_dataset_class,
      input_example_parser_creator=input_example_parser_creator)

  def train_one_iteration(inputs):
    """Trains the model for one iteration.

    Args:
      inputs: A dictionary for training inputs.

    Returns:
      A dictionary of training losses for this iteration.
    """
    _, side_outputs = pipelines.create_model_input(
        inputs, FLAGS.model_input_keypoint_type, keypoint_profile_2d,
        keypoint_profile_3d)

    keypoints_2d = side_outputs[common_module.KEY_PREPROCESSED_KEYPOINTS_2D]
    keypoints_3d, _ = keypoint_preprocessor_3d(
        inputs[common_module.KEY_KEYPOINTS_3D],
        keypoint_profile_3d,
        normalize_keypoints_3d=True)
    keypoints_2d, keypoints_3d = data_utils.shuffle_batches(
        [keypoints_2d, keypoints_3d])

    return model.train((keypoints_2d, keypoints_3d), **optimizers)

  if FLAGS.compile:
    train_one_iteration = tf.function(train_one_iteration)

  record_every_n_steps = min(100, FLAGS.num_iterations)
  save_ckpt_every_n_steps = min(10000, FLAGS.num_iterations)

  with summary_writer.as_default():
    # Use a callable so the condition is re-evaluated at every summary call;
    # a bare tensor would be computed only once under eager execution.
    with tf.summary.record_if(lambda: global_step % record_every_n_steps == 0):
      start = time.time()
      for inputs in dataset:
        if global_step >= FLAGS.num_iterations:
          break

        model_losses = train_one_iteration(inputs)
        duration = time.time() - start
        start = time.time()

        for tag, losses in model_losses.items():
          for name, loss in losses.items():
            tf.summary.scalar(
                'train/{}/{}'.format(tag, name), loss, step=global_step)

        for tag, optimizer in optimizers.items():
          tf.summary.scalar(
              'train/{}_learning_rate'.format(tag),
              optimizer.lr,
              step=global_step)

        tf.summary.scalar('train/batch_time', duration, step=global_step)
        tf.summary.scalar('global_step/sec', 1 / duration, step=global_step)

        if global_step % record_every_n_steps == 0:
          logging.info('Iter[{}/{}], {:.6f}s/iter, loss: {:.4f}'.format(
              global_step.numpy(), FLAGS.num_iterations, duration,
              model_losses['encoder']['total_loss'].numpy()))

        # Save checkpoint.
        if global_step % save_ckpt_every_n_steps == 0:
          ckpt_manager.save(checkpoint_number=global_step)
          logging.info('Checkpoint saved at step %d.', global_step.numpy())
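
As a usage note, here is a minimal sketch of how a binary might invoke the `run` function above via absl. The `common`, `keypoint_profiles`, and `keypoint_utils` module names and the `preprocess_keypoints_3d` handle are assumptions for illustration, not confirmed entry points.

# Hypothetical launcher for the pipeline above; the commented import and the
# 3D preprocessor handle are illustrative assumptions.
import tensorflow as tf
from absl import app

# from some_package import common, keypoint_profiles, keypoint_utils  # Assumed.


def main(_):
  run(input_dataset_class=tf.data.TFRecordDataset,
      common_module=common,
      keypoint_profiles_module=keypoint_profiles,
      input_example_parser_creator=None,  # Use the default parser creator.
      keypoint_preprocessor_3d=keypoint_utils.preprocess_keypoints_3d)


if __name__ == '__main__':
  app.run(main)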
Example #2
def _validate_and_setup(common_module, keypoint_profiles_module, models_module,
                        keypoint_distance_config_override, embedder_fn_kwargs):
    """Validates and sets up training configurations."""
    # Set default values for unspecified flags.
    if FLAGS.use_normalized_embeddings_for_triplet_mining is None:
        FLAGS.use_normalized_embeddings_for_triplet_mining = (
            FLAGS.use_normalized_embeddings_for_triplet_loss)
    if FLAGS.use_normalized_embeddings_for_positive_pairwise_loss is None:
        FLAGS.use_normalized_embeddings_for_positive_pairwise_loss = (
            FLAGS.use_normalized_embeddings_for_triplet_loss)
    if FLAGS.positive_pairwise_distance_type is None:
        FLAGS.positive_pairwise_distance_type = FLAGS.triplet_distance_type
    if FLAGS.positive_pairwise_distance_kernel is None:
        FLAGS.positive_pairwise_distance_kernel = FLAGS.triplet_distance_kernel
    if FLAGS.positive_pairwise_pairwise_reduction is None:
        FLAGS.positive_pairwise_pairwise_reduction = (
            FLAGS.triplet_pairwise_reduction)
    if FLAGS.positive_pairwise_componentwise_reduction is None:
        FLAGS.positive_pairwise_componentwise_reduction = (
            FLAGS.triplet_componentwise_reduction)

    # Validate flags.
    validate_flag = common_module.validate
    validate_flag(FLAGS.model_input_keypoint_type,
                  common_module.SUPPORTED_TRAINING_MODEL_INPUT_KEYPOINT_TYPES)
    validate_flag(FLAGS.embedding_type,
                  common_module.SUPPORTED_EMBEDDING_TYPES)
    validate_flag(FLAGS.base_model_type,
                  common_module.SUPPORTED_BASE_MODEL_TYPES)
    validate_flag(FLAGS.keypoint_distance_type,
                  common_module.SUPPORTED_KEYPOINT_DISTANCE_TYPES)
    validate_flag(FLAGS.triplet_distance_type,
                  common_module.SUPPORTED_DISTANCE_TYPES)
    validate_flag(FLAGS.triplet_distance_kernel,
                  common_module.SUPPORTED_DISTANCE_KERNELS)
    validate_flag(FLAGS.triplet_pairwise_reduction,
                  common_module.SUPPORTED_PAIRWISE_DISTANCE_REDUCTIONS)
    validate_flag(FLAGS.triplet_componentwise_reduction,
                  common_module.SUPPORTED_COMPONENTWISE_DISTANCE_REDUCTIONS)
    validate_flag(FLAGS.positive_pairwise_distance_type,
                  common_module.SUPPORTED_DISTANCE_TYPES)
    validate_flag(FLAGS.positive_pairwise_distance_kernel,
                  common_module.SUPPORTED_DISTANCE_KERNELS)
    validate_flag(FLAGS.positive_pairwise_pairwise_reduction,
                  common_module.SUPPORTED_PAIRWISE_DISTANCE_REDUCTIONS)
    validate_flag(FLAGS.positive_pairwise_componentwise_reduction,
                  common_module.SUPPORTED_COMPONENTWISE_DISTANCE_REDUCTIONS)

    if FLAGS.embedding_type == common_module.EMBEDDING_TYPE_POINT:
        if FLAGS.triplet_distance_type in [
                common_module.DISTANCE_TYPE_SAMPLE,
                common_module.DISTANCE_TYPE_CENTER_AND_SAMPLE
        ]:
            raise ValueError(
                'No support for triplet distance type `%s` for embedding type `%s`.'
                % (FLAGS.triplet_distance_type, FLAGS.embedding_type))
        if FLAGS.kl_regularization_loss_weight > 0.0:
            raise ValueError(
                'No support for KL regularization loss for embedding type `%s`.'
                % FLAGS.embedding_type)

    if ((FLAGS.triplet_distance_type in [
            common_module.DISTANCE_TYPE_SAMPLE,
            common_module.DISTANCE_TYPE_CENTER_AND_SAMPLE
    ] or FLAGS.positive_pairwise_distance_type in [
            common_module.DISTANCE_TYPE_SAMPLE,
            common_module.DISTANCE_TYPE_CENTER_AND_SAMPLE
    ]) and FLAGS.num_embedding_samples <= 0):
        raise ValueError(
            'Must specify positive `num_embedding_samples` to use `%s` '
            'triplet/positive pairwise distance type.' %
            FLAGS.triplet_distance_type)

    if (((FLAGS.triplet_distance_kernel in [
            common_module.DISTANCE_KERNEL_L2_SIGMOID_MATCHING_PROB,
            common_module.DISTANCE_KERNEL_EXPECTED_LIKELIHOOD
    ]) != (FLAGS.triplet_pairwise_reduction in [
            common_module.DISTANCE_REDUCTION_NEG_LOG_MEAN,
            common_module.DISTANCE_REDUCTION_LOWER_HALF_NEG_LOG_MEAN,
            common_module.DISTANCE_REDUCTION_ONE_MINUS_MEAN
    ])) or ((FLAGS.positive_pairwise_distance_kernel in [
            common_module.DISTANCE_KERNEL_L2_SIGMOID_MATCHING_PROB,
            common_module.DISTANCE_KERNEL_EXPECTED_LIKELIHOOD
    ]) != (FLAGS.positive_pairwise_pairwise_reduction in [
            common_module.DISTANCE_REDUCTION_NEG_LOG_MEAN,
            common_module.DISTANCE_REDUCTION_LOWER_HALF_NEG_LOG_MEAN,
            common_module.DISTANCE_REDUCTION_ONE_MINUS_MEAN
    ]))):
        raise ValueError(
            'Must use `L2_SIGMOID_MATCHING_PROB` or `EXPECTED_LIKELIHOOD` distance '
            'kernel and `NEG_LOG_MEAN` or `LOWER_HALF_NEG_LOG_MEAN` pairwise '
            'reducer in pairs.')

    keypoint_profile_2d = keypoint_profiles_module.create_keypoint_profile_or_die(
        FLAGS.input_keypoint_profile_name_2d)

    # Set up configurations.
    configs = {
        'keypoint_profile_3d':
        keypoint_profiles_module.create_keypoint_profile_or_die(
            FLAGS.input_keypoint_profile_name_3d),
        'keypoint_profile_2d':
        keypoint_profile_2d,
        'embedder_fn':
        models_module.get_embedder(
            base_model_type=FLAGS.base_model_type,
            embedding_type=FLAGS.embedding_type,
            num_embedding_components=FLAGS.num_embedding_components,
            embedding_size=FLAGS.embedding_size,
            num_embedding_samples=FLAGS.num_embedding_samples,
            is_training=True,
            num_fc_blocks=FLAGS.num_fc_blocks,
            num_fcs_per_block=FLAGS.num_fcs_per_block,
            num_hidden_nodes=FLAGS.num_hidden_nodes,
            num_bottleneck_nodes=FLAGS.num_bottleneck_nodes,
            weight_max_norm=FLAGS.weight_max_norm,
            dropout_rate=FLAGS.dropout_rate,
            **embedder_fn_kwargs),
        'triplet_embedding_keys':
        pipeline_utils.get_embedding_keys(FLAGS.triplet_distance_type,
                                          common_module=common_module),
        'triplet_mining_embedding_keys':
        pipeline_utils.get_embedding_keys(FLAGS.triplet_distance_type,
                                          replace_samples_with_means=True,
                                          common_module=common_module),
        'triplet_embedding_sample_distance_fn':
        loss_utils.create_sample_distance_fn(
            pair_type=common_module.DISTANCE_PAIR_TYPE_ALL_PAIRS,
            distance_kernel=FLAGS.triplet_distance_kernel,
            pairwise_reduction=FLAGS.triplet_pairwise_reduction,
            componentwise_reduction=FLAGS.triplet_componentwise_reduction,
            # We initialize the sigmoid parameters to avoid the model being
            # stuck in a `dead zone` at the beginning of training.
            L2_SIGMOID_MATCHING_PROB_a_initializer=(
                tf.initializers.constant(-0.65)),
            L2_SIGMOID_MATCHING_PROB_b_initializer=(
                tf.initializers.constant(-0.5)),
            EXPECTED_LIKELIHOOD_min_stddev=0.1,
            EXPECTED_LIKELIHOOD_max_squared_mahalanobis_distance=100.0),
        'positive_pairwise_embedding_keys':
        pipeline_utils.get_embedding_keys(
            FLAGS.positive_pairwise_distance_type,
            common_module=common_module),
        'positive_pairwise_embedding_sample_distance_fn':
        loss_utils.create_sample_distance_fn(
            pair_type=common_module.DISTANCE_PAIR_TYPE_ALL_PAIRS,
            distance_kernel=FLAGS.positive_pairwise_distance_kernel,
            pairwise_reduction=FLAGS.positive_pairwise_pairwise_reduction,
            componentwise_reduction=(
                FLAGS.positive_pairwise_componentwise_reduction),
            # We initialize the sigmoid parameters to avoid the model being
            # stuck in a `dead zone` at the beginning of training.
            L2_SIGMOID_MATCHING_PROB_a_initializer=(
                tf.initializers.constant(-0.65)),
            L2_SIGMOID_MATCHING_PROB_b_initializer=(
                tf.initializers.constant(-0.5)),
            EXPECTED_LIKELIHOOD_min_stddev=0.1,
            EXPECTED_LIKELIHOOD_max_squared_mahalanobis_distance=100.0),
        'summarize_matching_sigmoid_vars':
        FLAGS.triplet_distance_kernel
        in [common_module.DISTANCE_KERNEL_L2_SIGMOID_MATCHING_PROB],
        'random_projection_azimuth_range': [
            float(x) / 180.0 * math.pi
            for x in FLAGS.random_projection_azimuth_range
        ],
        'random_projection_elevation_range': [
            float(x) / 180.0 * math.pi
            for x in FLAGS.random_projection_elevation_range
        ],
        'random_projection_roll_range': [
            float(x) / 180.0 * math.pi
            for x in FLAGS.random_projection_roll_range
        ],
    }

    if FLAGS.keypoint_distance_type == common_module.KEYPOINT_DISTANCE_TYPE_MPJPE:
        configs.update({
            'keypoint_distance_fn':
            keypoint_utils.compute_procrustes_aligned_mpjpes,
            'min_negative_keypoint_distance':
            FLAGS.min_negative_keypoint_mpjpe
        })
    # We use the following assignments to get around pytype check failures.
    # TODO(liuti): Figure out a better workaround.
    if 'keypoint_distance_fn' in keypoint_distance_config_override:
        configs['keypoint_distance_fn'] = (
            keypoint_distance_config_override['keypoint_distance_fn'])
    if 'min_negative_keypoint_distance' in keypoint_distance_config_override:
        configs['min_negative_keypoint_distance'] = (
            keypoint_distance_config_override['min_negative_keypoint_distance']
        )
    if ('keypoint_distance_fn' not in configs
            or 'min_negative_keypoint_distance' not in configs):
        raise ValueError('Invalid keypoint distance config: %s.' %
                         str(configs))

    if FLAGS.task == 0 and not FLAGS.profile_only:
        # Save all key flags.
        pipeline_utils.create_dir_and_save_flags(flags, FLAGS.train_log_dir,
                                                 'all_flags.train.json')

    return configs
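
Note on the validation above: the `!=` between the two membership tests acts as an exclusive-or, so the ValueError fires whenever a matching-probability kernel is used without a log-style pairwise reducer, or vice versa. A self-contained sketch of the pattern, with placeholder constant names:

# XOR-style pairing check in isolation; constant names are placeholders.
PROB_KERNELS = {'L2_SIGMOID_MATCHING_PROB', 'EXPECTED_LIKELIHOOD'}
LOG_REDUCERS = {'NEG_LOG_MEAN', 'LOWER_HALF_NEG_LOG_MEAN', 'ONE_MINUS_MEAN'}


def check_kernel_reducer_pair(kernel, reducer):
  # `!=` on two booleans is exclusive-or: error out when exactly one side of
  # the kernel/reducer pair is from the matching-probability family.
  if (kernel in PROB_KERNELS) != (reducer in LOG_REDUCERS):
    raise ValueError('Kernel `%s` and reducer `%s` must be used in pairs.' %
                     (kernel, reducer))


check_kernel_reducer_pair('L2_SIGMOID_MATCHING_PROB', 'NEG_LOG_MEAN')  # OK.
check_kernel_reducer_pair('L2', 'MEAN')  # Also OK: neither side is set.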
Example #3
def run(input_dataset_class, common_module, keypoint_profiles_module,
        input_example_parser_creator):
  """Runs training pipeline.

  Args:
    input_dataset_class: An input dataset class that matches input table type.
    common_module: A Python module that defines common flags and constants.
    keypoint_profiles_module: A Python module that defines keypoint profiles.
    input_example_parser_creator: A function handle for creating a data
      parser function. If None, uses the default parser creator.
  """
  log_dir_path = FLAGS.log_dir_path
  pipeline_utils.create_dir_and_save_flags(
      flags, log_dir_path, 'all_flags.train_with_features.json')

  # Set up the summary writer.
  summary_writer = tf.summary.create_file_writer(
      os.path.join(log_dir_path, 'train_logs'), flush_millis=10000)

  # Set up the configuration.
  keypoint_profile_2d = keypoint_profiles_module.create_keypoint_profile_or_die(
      FLAGS.keypoint_profile_name_2d)

  # Set up the model.
  input_length = math.ceil(FLAGS.num_frames / FLAGS.downsample_rate)
  if FLAGS.input_features_dim > 0:
    feature_dim = FLAGS.input_features_dim
    input_shape = (input_length, feature_dim)
  else:
    feature_dim = None
    # 2D keypoint input: 13 keypoints x 2 (x, y) coordinates per frame.
    input_shape = (input_length, 13 * 2)
  classifier = models.get_temporal_classifier(
      FLAGS.classifier_type,
      input_shape=input_shape,
      num_classes=FLAGS.num_classes)
  ema_classifier = tf.keras.models.clone_model(classifier)
  optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.learning_rate)
  optimizer = tfa_optimizers.MovingAverage(optimizer)
  global_step = optimizer.iterations
  ckpt_manager, _, _ = utils.create_checkpoint(
      log_dir_path,
      optimizer=optimizer,
      model=classifier,
      ema_model=ema_classifier,
      global_step=global_step)

  # Set up the training dataset.
  dataset = pipelines.create_dataset_from_tables(
      FLAGS.input_tables, [int(x) for x in FLAGS.batch_sizes],
      num_instances_per_record=1,
      shuffle=True,
      drop_remainder=True,
      num_epochs=None,
      keypoint_names_2d=keypoint_profile_2d.keypoint_names,
      feature_dim=feature_dim,
      num_classes=FLAGS.num_classes,
      num_frames=FLAGS.num_frames,
      shuffle_buffer_size=FLAGS.shuffle_buffer_size,
      common_module=common_module,
      dataset_class=input_dataset_class,
      input_example_parser_creator=input_example_parser_creator)
  loss_object = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

  def train_one_iteration(inputs):
    """Trains the model for one iteration.

    Args:
      inputs: A dictionary for training inputs.

    Returns:
      A dictionary of training losses for this iteration.
    """
    if FLAGS.input_features_dim > 0:
      features = inputs[common_module.KEY_FEATURES]
    else:
      features, _ = pipelines.create_model_input(
          inputs, common_module.MODEL_INPUT_KEYPOINT_TYPE_2D_INPUT,
          keypoint_profile_2d)

    features = tf.squeeze(features, axis=1)
    features = features[:, ::FLAGS.downsample_rate, :]
    labels = inputs[common_module.KEY_CLASS_TARGETS]
    labels = tf.squeeze(labels, axis=1)

    with tf.GradientTape() as tape:
      outputs = classifier(features, training=True)
      regularization_loss = sum(classifier.losses)
      crossentropy_loss = loss_object(labels, outputs)
      total_loss = crossentropy_loss + regularization_loss

    trainable_variables = classifier.trainable_variables
    grads = tape.gradient(total_loss, trainable_variables)
    optimizer.apply_gradients(zip(grads, trainable_variables))

    for grad, trainable_variable in zip(grads, trainable_variables):
      tf.summary.scalar(
          'summarize_grads/' + trainable_variable.name,
          tf.linalg.norm(grad),
          step=global_step)

    return dict(
        total_loss=total_loss,
        crossentropy_loss=crossentropy_loss,
        regularization_loss=regularization_loss)

  if FLAGS.compile:
    train_one_iteration = tf.function(train_one_iteration)

  record_every_n_steps = min(5, FLAGS.num_iterations)
  save_ckpt_every_n_steps = min(500, FLAGS.num_iterations)

  with summary_writer.as_default():
    # Use a callable so the condition is re-evaluated at every summary call;
    # a bare tensor would be computed only once under eager execution.
    with tf.summary.record_if(lambda: global_step % record_every_n_steps == 0):
      start = time.time()
      for inputs in dataset:
        if global_step >= FLAGS.num_iterations:
          break

        model_losses = train_one_iteration(inputs)
        duration = time.time() - start
        start = time.time()

        for name, loss in model_losses.items():
          tf.summary.scalar('train/' + name, loss, step=global_step)

        tf.summary.scalar('train/learning_rate', optimizer.lr, step=global_step)
        tf.summary.scalar('train/batch_time', duration, step=global_step)
        tf.summary.scalar('global_step/sec', 1 / duration, step=global_step)

        if global_step % record_every_n_steps == 0:
          logging.info('Iter[{}/{}], {:.6f}s/iter, loss: {:.4f}'.format(
              global_step.numpy(), FLAGS.num_iterations, duration,
              model_losses['total_loss'].numpy()))

        # Save checkpoint.
        if global_step % save_ckpt_every_n_steps == 0:
          utils.assign_moving_average_vars(classifier, ema_classifier,
                                           optimizer)
          ckpt_manager.save(checkpoint_number=global_step)
          logging.info('Checkpoint saved at step %d.', global_step.numpy())
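
Before each save above, the moving-average weights are copied into `ema_classifier` through the repo's `utils.assign_moving_average_vars` helper. A rough equivalent using only public TensorFlow Addons APIs might look like the following sketch; the tiny model, dummy training step, and checkpoint path are placeholders.

# EMA-checkpoint pattern with public APIs only; everything here is a
# placeholder stand-in for the classifier training loop above.
import tensorflow as tf
import tensorflow_addons as tfa

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
ema_model = tf.keras.models.clone_model(model)
optimizer = tfa.optimizers.MovingAverage(
    tf.keras.optimizers.Adam(learning_rate=1e-3))

# One dummy training step so the optimizer creates its `average` slots.
with tf.GradientTape() as tape:
  loss = tf.reduce_mean(tf.square(model(tf.random.normal([2, 8]))))
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))

# Swap the averaged weights into the trained model, copy them to the clone,
# then restore the raw weights so training can continue.
raw_weights = model.get_weights()
optimizer.assign_average_vars(model.variables)
ema_model.set_weights(model.get_weights())
model.set_weights(raw_weights)
tf.train.Checkpoint(model=model, ema_model=ema_model).save('/tmp/ckpt/train')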
Example #4
def run(input_dataset_class, common_module, keypoint_profiles_module,
        input_example_parser_creator):
  """Runs training pipeline.

  Args:
    input_dataset_class: An input dataset class that matches input table type.
    common_module: A Python module that defines common flags and constants.
    keypoint_profiles_module: A Python module that defines keypoint profiles.
    input_example_parser_creator: A function handle for creating a data
      parser function. If None, uses the default parser creator.
  """
  log_dir_path = FLAGS.log_dir_path
  if not tf.io.gfile.exists(log_dir_path):
    raise ValueError(
        'The directory {} does not exist. Please provide a new log_dir_path.'
        .format(log_dir_path))
  eval_log_dir = os.path.join(log_dir_path, FLAGS.eval_name)
  pipeline_utils.create_dir_and_save_flags(flags, eval_log_dir,
                                           'all_flags.eval_with_features.json')

  # Set up the summary writer.
  summary_writer = tf.summary.create_file_writer(
      eval_log_dir, flush_millis=10000)

  # Set up the configuration.
  keypoint_profile_2d = keypoint_profiles_module.create_keypoint_profile_or_die(
      FLAGS.keypoint_profile_name_2d)

  # Set up the model.
  input_length = math.ceil(FLAGS.num_frames / FLAGS.downsample_rate)
  if FLAGS.input_features_dim > 0:
    feature_dim = FLAGS.input_features_dim
    input_shape = (input_length, feature_dim)
  else:
    feature_dim = None
    # 2D keypoint input: 13 keypoints x 2 (x, y) coordinates per frame.
    input_shape = (input_length, 13 * 2)
  classifier = models.get_temporal_classifier(
      FLAGS.classifier_type,
      input_shape=input_shape,
      num_classes=FLAGS.num_classes)
  ema_classifier = tf.keras.models.clone_model(classifier)
  global_step = tf.Variable(0, dtype=tf.int64, trainable=False)
  dataset = pipelines.create_dataset_from_tables(
      FLAGS.input_tables, [int(x) for x in FLAGS.batch_sizes],
      num_instances_per_record=1,
      shuffle=False,
      num_epochs=1,
      drop_remainder=False,
      keypoint_names_2d=keypoint_profile_2d.keypoint_names,
      feature_dim=feature_dim,
      num_classes=FLAGS.num_classes,
      num_frames=FLAGS.num_frames,
      common_module=common_module,
      dataset_class=input_dataset_class,
      input_example_parser_creator=input_example_parser_creator)

  if FLAGS.compile:
    classifier.call = tf.function(classifier.call)
    ema_classifier.call = tf.function(ema_classifier.call)

  top_1_best_accuracy = None
  top_5_best_accuracy = None
  evaluated_last_ckpt = False

  def timeout_fn():
    """Returns True to stop waiting once the last checkpoint is evaluated."""
    return evaluated_last_ckpt

  def evaluate_once():
    """Evaluates the model for one time."""
    _, status, _ = utils.create_checkpoint(
        log_dir_path,
        model=classifier,
        ema_model=ema_classifier,
        global_step=global_step)
    status.expect_partial()
    logging.info('Last checkpoint [iteration: %d] restored at %s.',
                 global_step.numpy(), log_dir_path)

    if global_step.numpy() >= FLAGS.max_iteration:
      nonlocal evaluated_last_ckpt
      evaluated_last_ckpt = True

    top_1_accuracy = tf.keras.metrics.CategoricalAccuracy()
    top_5_accuracy = tf.keras.metrics.TopKCategoricalAccuracy(k=5)
    for inputs in dataset:
      if FLAGS.input_features_dim > 0:
        features = inputs[common_module.KEY_FEATURES]
      else:
        features, _ = pipelines.create_model_input(
            inputs, common_module.MODEL_INPUT_KEYPOINT_TYPE_2D_INPUT,
            keypoint_profile_2d)
      features = tf.squeeze(features, axis=1)
      features = features[:, ::FLAGS.downsample_rate, :]
      labels = inputs[common_module.KEY_CLASS_TARGETS]
      labels = tf.squeeze(labels, axis=1)

      if FLAGS.use_moving_average:
        outputs = ema_classifier(features, training=False)
      else:
        outputs = classifier(features, training=False)
      top_1_accuracy.update_state(y_true=labels, y_pred=outputs)
      top_5_accuracy.update_state(y_true=labels, y_pred=outputs)

    nonlocal top_1_best_accuracy
    if (top_1_best_accuracy is None or
        top_1_accuracy.result().numpy() > top_1_best_accuracy):
      top_1_best_accuracy = top_1_accuracy.result().numpy()

    nonlocal top_5_best_accuracy
    if (top_5_best_accuracy is None or
        top_5_accuracy.result().numpy() > top_5_best_accuracy):
      top_5_best_accuracy = top_5_accuracy.result().numpy()

    tf.summary.scalar(
        'eval/Basic/Top1_Accuracy',
        top_1_accuracy.result(),
        step=global_step.numpy())
    tf.summary.scalar(
        'eval/Best/Top1_Accuracy',
        top_1_best_accuracy,
        step=global_step.numpy())
    tf.summary.scalar(
        'eval/Basic/Top5_Accuracy',
        top_5_accuracy.result(),
        step=global_step.numpy())
    tf.summary.scalar(
        'eval/Best/Top5_Accuracy',
        top_5_best_accuracy,
        step=global_step.numpy())
    logging.info('Top-1 accuracy: {:.2f}'.format(
        top_1_accuracy.result().numpy()))

  with summary_writer.as_default():
    with tf.summary.record_if(True):
      if FLAGS.continuous_eval:
        for _ in tf.train.checkpoints_iterator(log_dir_path, timeout=1,
                                               timeout_fn=timeout_fn):
          evaluate_once()
      else:
        evaluate_once()
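
The continuous-evaluation branch above relies on `tf.train.checkpoints_iterator`, which blocks until a new checkpoint appears and consults `timeout_fn` after each wait; returning True ends the iteration. A standalone sketch of that control flow, with a placeholder directory and stopping condition:

# Continuous-eval control flow in isolation; the directory and the stopping
# logic are placeholders.
import tensorflow as tf

evaluated_last_ckpt = False


def timeout_fn():
  # Called after each `timeout`-second wait with no new checkpoint;
  # returning True ends the iteration.
  return evaluated_last_ckpt


for ckpt_path in tf.train.checkpoints_iterator(
    '/tmp/train_logs', timeout=1, timeout_fn=timeout_fn):
  # ... restore `ckpt_path`, run one eval pass, and flip the flag once the
  # restored step reaches the final training iteration ...
  evaluated_last_ckpt = True  # Placeholder: stop after one checkpoint.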
Example #5
def _validate_and_setup(common_module, keypoint_profiles_module, models_module,
                        keypoint_distance_config_override,
                        create_model_input_fn_kwargs, embedder_fn_kwargs):
    """Validates and sets up training configurations."""
    # Set default values for unspecified flags.
    if FLAGS.use_normalized_embeddings_for_triplet_mining is None:
        FLAGS.use_normalized_embeddings_for_triplet_mining = (
            FLAGS.use_normalized_embeddings_for_triplet_loss)
    if FLAGS.use_normalized_embeddings_for_positive_pairwise_loss is None:
        FLAGS.use_normalized_embeddings_for_positive_pairwise_loss = (
            FLAGS.use_normalized_embeddings_for_triplet_loss)
    if FLAGS.positive_pairwise_distance_type is None:
        FLAGS.positive_pairwise_distance_type = FLAGS.triplet_distance_type
    if FLAGS.positive_pairwise_distance_kernel is None:
        FLAGS.positive_pairwise_distance_kernel = FLAGS.triplet_distance_kernel
    if FLAGS.positive_pairwise_pairwise_reduction is None:
        FLAGS.positive_pairwise_pairwise_reduction = (
            FLAGS.triplet_pairwise_reduction)
    if FLAGS.positive_pairwise_componentwise_reduction is None:
        FLAGS.positive_pairwise_componentwise_reduction = (
            FLAGS.triplet_componentwise_reduction)

    # Validate flags.
    validate_flag = common_module.validate
    validate_flag(FLAGS.model_input_keypoint_type,
                  common_module.SUPPORTED_TRAINING_MODEL_INPUT_KEYPOINT_TYPES)
    validate_flag(FLAGS.model_input_keypoint_mask_type,
                  common_module.SUPPORTED_MODEL_INPUT_KEYPOINT_MASK_TYPES)
    validate_flag(FLAGS.embedding_type,
                  common_module.SUPPORTED_EMBEDDING_TYPES)
    validate_flag(FLAGS.base_model_type,
                  common_module.SUPPORTED_BASE_MODEL_TYPES)
    validate_flag(FLAGS.keypoint_distance_type,
                  common_module.SUPPORTED_KEYPOINT_DISTANCE_TYPES)
    validate_flag(FLAGS.triplet_distance_type,
                  common_module.SUPPORTED_DISTANCE_TYPES)
    validate_flag(FLAGS.triplet_distance_kernel,
                  common_module.SUPPORTED_DISTANCE_KERNELS)
    validate_flag(FLAGS.triplet_pairwise_reduction,
                  common_module.SUPPORTED_PAIRWISE_DISTANCE_REDUCTIONS)
    validate_flag(FLAGS.triplet_componentwise_reduction,
                  common_module.SUPPORTED_COMPONENTWISE_DISTANCE_REDUCTIONS)
    validate_flag(FLAGS.positive_pairwise_distance_type,
                  common_module.SUPPORTED_DISTANCE_TYPES)
    validate_flag(FLAGS.positive_pairwise_distance_kernel,
                  common_module.SUPPORTED_DISTANCE_KERNELS)
    validate_flag(FLAGS.positive_pairwise_pairwise_reduction,
                  common_module.SUPPORTED_PAIRWISE_DISTANCE_REDUCTIONS)
    validate_flag(FLAGS.positive_pairwise_componentwise_reduction,
                  common_module.SUPPORTED_COMPONENTWISE_DISTANCE_REDUCTIONS)

    if FLAGS.embedding_type == common_module.EMBEDDING_TYPE_POINT:
        if FLAGS.triplet_distance_type != common_module.DISTANCE_TYPE_CENTER:
            raise ValueError(
                'No support for triplet distance type `%s` for embedding type `%s`.'
                % (FLAGS.triplet_distance_type, FLAGS.embedding_type))
        if FLAGS.kl_regularization_loss_weight > 0.0:
            raise ValueError(
                'No support for KL regularization loss for embedding type `%s`.'
                % FLAGS.embedding_type)

    if ((FLAGS.triplet_distance_type in [
            common_module.DISTANCE_TYPE_SAMPLE,
    ] or FLAGS.positive_pairwise_distance_type in [
            common_module.DISTANCE_TYPE_SAMPLE,
    ]) and FLAGS.num_embedding_samples <= 0):
        raise ValueError(
            'Must specify positive `num_embedding_samples` to use `%s` '
            'triplet/positive pairwise distance type.' %
            FLAGS.triplet_distance_type)

    if (((FLAGS.triplet_distance_kernel in [
            common_module.DISTANCE_KERNEL_L2_SIGMOID_MATCHING_PROB,
            common_module.DISTANCE_KERNEL_SQUARED_L2_SIGMOID_MATCHING_PROB,
            common_module.DISTANCE_KERNEL_EXPECTED_LIKELIHOOD,
    ]) != (FLAGS.triplet_pairwise_reduction in [
            common_module.DISTANCE_REDUCTION_NEG_LOG_MEAN,
            common_module.DISTANCE_REDUCTION_LOWER_HALF_NEG_LOG_MEAN,
            common_module.DISTANCE_REDUCTION_ONE_MINUS_MEAN
    ])) or ((FLAGS.positive_pairwise_distance_kernel in [
            common_module.DISTANCE_KERNEL_L2_SIGMOID_MATCHING_PROB,
            common_module.DISTANCE_KERNEL_SQUARED_L2_SIGMOID_MATCHING_PROB,
            common_module.DISTANCE_KERNEL_EXPECTED_LIKELIHOOD,
    ]) != (FLAGS.positive_pairwise_pairwise_reduction in [
            common_module.DISTANCE_REDUCTION_NEG_LOG_MEAN,
            common_module.DISTANCE_REDUCTION_LOWER_HALF_NEG_LOG_MEAN,
            common_module.DISTANCE_REDUCTION_ONE_MINUS_MEAN
    ]))):
        raise ValueError(
            'Must use `L2_SIGMOID_MATCHING_PROB` or `EXPECTED_LIKELIHOOD` distance '
            'kernel and `NEG_LOG_MEAN` or `LOWER_HALF_NEG_LOG_MEAN` pairwise '
            'reducer in pairs.')

    keypoint_profile_2d = keypoint_profiles_module.create_keypoint_profile_or_die(
        FLAGS.input_keypoint_profile_name_2d)

    # Set up configurations.
    configs = {
        'keypoint_profile_3d':
        keypoint_profiles_module.create_keypoint_profile_or_die(
            FLAGS.input_keypoint_profile_name_3d),
        'keypoint_profile_2d':
        keypoint_profile_2d,
        'create_model_input_fn':
        functools.partial(
            input_generator.create_model_input,
            model_input_keypoint_mask_type=(
                FLAGS.model_input_keypoint_mask_type),
            uniform_keypoint_jittering_max_offset_2d=(
                FLAGS.uniform_keypoint_jittering_max_offset_2d),
            gaussian_keypoint_jittering_offset_stddev_2d=(
                FLAGS.gaussian_keypoint_jittering_offset_stddev_2d),
            keypoint_dropout_probs=[
                float(x) for x in FLAGS.keypoint_dropout_probs
            ],
            set_on_mask_for_non_anchors=FLAGS.set_on_mask_for_non_anchors,
            mix_mask_sub_batches=FLAGS.mix_mask_sub_batches,
            forced_mask_on_part_names=FLAGS.forced_mask_on_part_names,
            forced_mask_off_part_names=FLAGS.forced_mask_off_part_names,
            **create_model_input_fn_kwargs),
        'embedder_fn':
        models_module.get_embedder(
            base_model_type=FLAGS.base_model_type,
            embedding_type=FLAGS.embedding_type,
            num_embedding_components=FLAGS.num_embedding_components,
            embedding_size=FLAGS.embedding_size,
            num_embedding_samples=FLAGS.num_embedding_samples,
            is_training=True,
            num_fc_blocks=FLAGS.num_fc_blocks,
            num_fcs_per_block=FLAGS.num_fcs_per_block,
            num_hidden_nodes=FLAGS.num_hidden_nodes,
            num_bottleneck_nodes=FLAGS.num_bottleneck_nodes,
            weight_max_norm=FLAGS.weight_max_norm,
            dropout_rate=FLAGS.dropout_rate,
            **embedder_fn_kwargs),
        'triplet_embedding_keys':
        pipeline_utils.get_embedding_keys(FLAGS.triplet_distance_type,
                                          common_module=common_module),
        'triplet_mining_embedding_keys':
        pipeline_utils.get_embedding_keys(FLAGS.triplet_distance_type,
                                          replace_samples_with_means=True,
                                          common_module=common_module),
        'positive_pairwise_embedding_keys':
        pipeline_utils.get_embedding_keys(
            FLAGS.positive_pairwise_distance_type,
            common_module=common_module),
        'summarize_matching_sigmoid_vars':
        FLAGS.triplet_distance_kernel in [
            common_module.DISTANCE_KERNEL_L2_SIGMOID_MATCHING_PROB,
            common_module.DISTANCE_KERNEL_SQUARED_L2_SIGMOID_MATCHING_PROB
        ],
        'random_projection_azimuth_range': [
            float(x) / 180.0 * math.pi
            for x in FLAGS.random_projection_azimuth_range
        ],
        'random_projection_elevation_range': [
            float(x) / 180.0 * math.pi
            for x in FLAGS.random_projection_elevation_range
        ],
        'random_projection_roll_range': [
            float(x) / 180.0 * math.pi
            for x in FLAGS.random_projection_roll_range
        ],
        'random_projection_camera_depth_range':
        [float(x) for x in FLAGS.random_projection_camera_depth_range],
    }

    embedding_sample_distance_fn_kwargs = {
        'EXPECTED_LIKELIHOOD_min_stddev': 0.1,
        'EXPECTED_LIKELIHOOD_max_squared_mahalanobis_distance': 100.0,
    }
    if FLAGS.triplet_distance_kernel in [
            common_module.DISTANCE_KERNEL_L2_SIGMOID_MATCHING_PROB,
            common_module.DISTANCE_KERNEL_SQUARED_L2_SIGMOID_MATCHING_PROB
    ] or FLAGS.positive_pairwise_distance_kernel in [
            common_module.DISTANCE_KERNEL_L2_SIGMOID_MATCHING_PROB,
            common_module.DISTANCE_KERNEL_SQUARED_L2_SIGMOID_MATCHING_PROB
    ]:
        # We only need sigmoid parameters when a related distance kernel is used.
        sigmoid_raw_a, sigmoid_a, sigmoid_b = pipeline_utils.get_sigmoid_parameters(
            name='MatchingSigmoid',
            raw_a_initial_value=FLAGS.sigmoid_raw_a_initial,
            b_initial_value=FLAGS.sigmoid_b_initial,
            a_range=(None, FLAGS.sigmoid_a_max))
        embedding_sample_distance_fn_kwargs.update({
            'L2_SIGMOID_MATCHING_PROB_a':
            sigmoid_a,
            'L2_SIGMOID_MATCHING_PROB_b':
            sigmoid_b,
            'SQUARED_L2_SIGMOID_MATCHING_PROB_a':
            sigmoid_a,
            'SQUARED_L2_SIGMOID_MATCHING_PROB_b':
            sigmoid_b,
        })
        configs.update({
            'sigmoid_raw_a': sigmoid_raw_a,
            'sigmoid_a': sigmoid_a,
            'sigmoid_b': sigmoid_b,
        })

    configs.update({
        'triplet_embedding_sample_distance_fn':
        loss_utils.create_sample_distance_fn(
            pair_type=common_module.DISTANCE_PAIR_TYPE_ALL_PAIRS,
            distance_kernel=FLAGS.triplet_distance_kernel,
            pairwise_reduction=FLAGS.triplet_pairwise_reduction,
            componentwise_reduction=FLAGS.triplet_componentwise_reduction,
            **embedding_sample_distance_fn_kwargs),
        'positive_pairwise_embedding_sample_distance_fn':
        loss_utils.create_sample_distance_fn(
            pair_type=common_module.DISTANCE_PAIR_TYPE_ALL_PAIRS,
            distance_kernel=FLAGS.positive_pairwise_distance_kernel,
            pairwise_reduction=FLAGS.positive_pairwise_pairwise_reduction,
            componentwise_reduction=(
                FLAGS.positive_pairwise_componentwise_reduction),
            **embedding_sample_distance_fn_kwargs),
    })

    if FLAGS.keypoint_distance_type == common_module.KEYPOINT_DISTANCE_TYPE_MPJPE:
        configs.update({
            'keypoint_distance_fn':
            keypoint_utils.compute_procrustes_aligned_mpjpes,
            'min_negative_keypoint_distance':
            FLAGS.min_negative_keypoint_mpjpe
        })
    # We use the following assignments to get around pytype check failures.
    # TODO(liuti): Figure out a better workaround.
    if 'keypoint_distance_fn' in keypoint_distance_config_override:
        configs['keypoint_distance_fn'] = (
            keypoint_distance_config_override['keypoint_distance_fn'])
    if 'min_negative_keypoint_distance' in keypoint_distance_config_override:
        configs['min_negative_keypoint_distance'] = (
            keypoint_distance_config_override['min_negative_keypoint_distance']
        )
    if ('keypoint_distance_fn' not in configs
            or 'min_negative_keypoint_distance' not in configs):
        raise ValueError('Invalid keypoint distance config: %s.' %
                         str(configs))

    if FLAGS.task == 0 and not FLAGS.profile_only:
        # Save all key flags.
        pipeline_utils.create_dir_and_save_flags(flags, FLAGS.train_log_dir,
                                                 'all_flags.train.json')

    return configs
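
The shared `embedding_sample_distance_fn_kwargs` dict above illustrates a conditional-kwargs pattern: optional sigmoid parameters are created only when a matching-probability kernel needs them, then splatted into both distance-function factories. A minimal sketch with placeholder names (not the real `loss_utils` API):

# Conditional-kwargs pattern in isolation; the factory and parameter names
# are placeholders.
def create_distance_fn(kernel, **kwargs):
  def distance_fn(x, y):
    # A real kernel would consume `kwargs` (e.g. the sigmoid parameters).
    return abs(x - y)
  return distance_fn


distance_fn_kwargs = {'min_stddev': 0.1}
kernel = 'L2_SIGMOID_MATCHING_PROB'
if kernel.endswith('SIGMOID_MATCHING_PROB'):
  # Create sigmoid parameters only when a matching-probability kernel is
  # used, then share them across every factory call below.
  distance_fn_kwargs.update({'sigmoid_a': -0.65, 'sigmoid_b': -0.5})

triplet_fn = create_distance_fn(kernel, **distance_fn_kwargs)
positive_pairwise_fn = create_distance_fn(kernel, **distance_fn_kwargs)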