Code example #1
  def test_create_sample_distance_fn_case_2(self):
    # Shape = [1, 2, 2].
    lhs = tf.constant([[[1.0, 2.0], [3.0, 4.0]]])
    # Shape = [1, 3, 2].
    rhs = tf.constant([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]])
    distances = loss_utils.create_sample_distance_fn(
        pair_type=common.DISTANCE_PAIR_TYPE_ALL_PAIRS,
        distance_kernel=common.DISTANCE_KERNEL_SQUARED_L2,
        pairwise_reduction=common.DISTANCE_REDUCTION_MEAN,
        componentwise_reduction=tf.identity)(lhs, rhs)

    self.assertAllClose(distances, [56.0 / 6.0])
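For reference, the expected value can be reproduced by hand: with ALL_PAIRS pairing and the squared-L2 kernel, the 2 lhs points and 3 rhs points form 6 pairs with distances 0, 8, 32, 8, 0, 8, and the MEAN reduction gives 56 / 6. A minimal NumPy sketch of that arithmetic (independent of loss_utils, for illustration only):

import numpy as np

lhs = np.array([[1.0, 2.0], [3.0, 4.0]])               # [2, 2]
rhs = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])   # [3, 2]
# All-pairs squared-L2 distances, shape [2, 3].
pairwise = ((lhs[:, None, :] - rhs[None, :, :]) ** 2).sum(axis=-1)
print(pairwise.mean())  # 56 / 6 ≈ 9.333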
Code example #2
  def test_create_sample_distance_fn_case_1(self):
    # Shape = [2, 3, 2].
    lhs = tf.constant([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                       [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]])
    # Shape = [2, 4, 2].
    rhs = tf.constant([[[16.0, 15.0], [14.0, 13.0], [12.0, 11.0], [10.0, 9.0]],
                       [[8.0, 7.0], [6.0, 5.0], [4.0, 3.0], [2.0, 1.0]]])
    distances = loss_utils.create_sample_distance_fn(
        pair_type=common.DISTANCE_PAIR_TYPE_ALL_PAIRS,
        distance_kernel=common.DISTANCE_KERNEL_SQUARED_L2,
        pairwise_reduction=functools.partial(tf.math.reduce_min, axis=[-2, -1]),
        componentwise_reduction=tf.identity)(lhs, rhs)

    self.assertAllClose(distances, [34.0, 2.0])
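The expected values can again be checked by hand: for each batch element, the min reduction over the 3 x 4 all-pairs squared-L2 matrix picks the closest lhs/rhs pair, namely (5, 6)-(10, 9) = 34 for the first element and (7, 8)-(8, 7) = 2 for the second. A minimal NumPy sketch (independent of loss_utils, for illustration only):

import numpy as np

lhs = np.array([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]])           # [2, 3, 2]
rhs = np.array([[[16.0, 15.0], [14.0, 13.0], [12.0, 11.0], [10.0, 9.0]],
                [[8.0, 7.0], [6.0, 5.0], [4.0, 3.0], [2.0, 1.0]]])  # [2, 4, 2]
# All-pairs squared-L2 distances per batch element, shape [2, 3, 4].
pairwise = ((lhs[:, :, None, :] - rhs[:, None, :, :]) ** 2).sum(axis=-1)
print(pairwise.min(axis=(-2, -1)))  # [34.  2.]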
Code example #3
    def test_create_sample_distance_fn_case_2(self):
        # Shape = [1, 2, 2].
        lhs = tf.constant([[[1.0, 2.0], [3.0, 4.0]]])
        # Shape = [1, 3, 2].
        rhs = tf.constant([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]])
        distances = loss_utils.create_sample_distance_fn(
            pair_type=common.DISTANCE_PAIR_TYPE_ALL_PAIRS,
            distance_kernel=common.DISTANCE_KERNEL_SQUARED_L2,
            pairwise_reduction=common.DISTANCE_REDUCTION_MEAN,
            componentwise_reduction=tf.identity)(lhs, rhs)

        with self.session() as sess:
            sess.run(tf.global_variables_initializer())
            distances_result = sess.run(distances)

        self.assertAllClose(distances_result, [56.0 / 6.0])
Code example #4
    def test_compute_keypoint_triplet_losses_with_sample_mining_embeddings(
            self):
        # Shape = [3, 3, 1, 2].
        anchor_embeddings = tf.constant([
            [[[1.0, 2.0]], [[1.0, 2.0]], [[1.0, 2.0]]],
            [[[3.0, 4.0]], [[3.0, 4.0]], [[3.0, 4.0]]],
            [[[5.0, 6.0]], [[5.0, 6.0]], [[5.0, 6.0]]],
        ])
        # Shape = [3, 3, 1, 2].
        positive_embeddings = tf.constant([
            [[[2.0, 1.0]], [[2.0, 1.0]], [[2.0, 1.0]]],
            [[[6.0, 5.0]], [[6.0, 5.0]], [[6.0, 5.0]]],
            [[[7.0, 6.0]], [[7.0, 6.0]], [[7.0, 6.0]]],
        ])
        # Shape = [4, 3, 2, 2].
        match_embeddings = tf.constant([
            [[[3.0, 2.0], [3.0, 2.0]], [[3.0, 2.0], [3.0, 2.0]],
             [[3.0, 2.0], [3.0, 2.0]]],
            [[[4.0, 3.0], [4.0, 3.0]], [[4.0, 3.0], [4.0, 3.0]],
             [[4.0, 3.0], [4.0, 3.0]]],
            [[[6.0, 5.0], [6.0, 5.0]], [[6.0, 5.0], [6.0, 5.0]],
             [[6.0, 5.0], [6.0, 5.0]]],
            [[[8.0, 7.0], [8.0, 7.0]], [[8.0, 7.0], [8.0, 7.0]],
             [[8.0, 7.0], [8.0, 7.0]]],
        ])
        # Shape = [3, 1].
        anchor_keypoints = tf.constant([[1], [2], [3]])
        # Shape = [4, 1].
        match_keypoints = tf.constant([[1], [2], [3], [4]])

        def mock_keypoint_distance_fn(unused_lhs, unused_rhs):
            return tf.constant([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 0.0, 1.0],
                                [1.0, 0.0, 1.0, 0.0]])

        # Shape = [3, 3, 1, 2].
        anchor_mining_embeddings = tf.constant([
            [[[1.0, 2.0]], [[1.0, 2.0]], [[1.0, 2.0]]],
            [[[3.0, 4.0]], [[3.0, 4.0]], [[3.0, 4.0]]],
            [[[5.0, 6.0]], [[5.0, 6.0]], [[5.0, 6.0]]],
        ])
        # Shape = [3, 3, 1, 2].
        positive_mining_embeddings = tf.constant([
            [[[1.0, 2.0]], [[1.0, 2.0]], [[1.0, 2.0]]],
            [[[5.0, 6.0]], [[5.0, 6.0]], [[5.0, 6.0]]],
            [[[6.0, 7.0]], [[6.0, 7.0]], [[6.0, 7.0]]],
        ])
        # Shape = [4, 3, 1, 2].
        match_mining_embeddings = tf.constant([
            [[[2.0, 3.0]], [[2.0, 3.0]], [[2.0, 3.0]]],
            [[[3.0, 4.0]], [[3.0, 4.0]], [[3.0, 4.0]]],
            [[[5.0, 6.0]], [[5.0, 6.0]], [[5.0, 6.0]]],
            [[[7.0, 8.0]], [[7.0, 8.0]], [[7.0, 8.0]]],
        ])

        loss, summaries = loss_utils.compute_keypoint_triplet_losses(
            anchor_embeddings,
            positive_embeddings,
            match_embeddings,
            anchor_keypoints,
            match_keypoints,
            margin=120.0,
            min_negative_keypoint_distance=0.5,
            use_semi_hard=True,
            exclude_inactive_triplet_loss=True,
            embedding_sample_distance_fn=loss_utils.create_sample_distance_fn(
                pairwise_reduction=functools.partial(tf.math.reduce_sum,
                                                     axis=[-2, -1]),
                componentwise_reduction=functools.partial(tf.math.reduce_sum,
                                                          axis=[-1])),
            keypoint_distance_fn=mock_keypoint_distance_fn,
            anchor_mining_embeddings=anchor_mining_embeddings,
            positive_mining_embeddings=positive_mining_embeddings,
            match_mining_embeddings=match_mining_embeddings)

        with self.session() as sess:
            loss_result, summaries_result = sess.run([loss, summaries])

        self.assertAlmostEqual(loss_result, 57.0)

        expected_summaries_result = {
            'triplet_loss/Margin': 120.0,
            'triplet_loss/Anchor/Positive/Distance/Mean': 48.0 / 3,
            'triplet_loss/Anchor/Positive/Distance/Median': 12.0,
            'triplet_loss/Anchor/HardNegative/Distance/Mean': 48.0 / 3,
            'triplet_loss/Anchor/HardNegative/Distance/Median': 12.0,
            'triplet_loss/Anchor/SemiHardNegative/Distance/Mean': 348.0 / 3,
            'triplet_loss/Anchor/SemiHardNegative/Distance/Median': 120.0,
            'triplet_loss/HardNegative/Loss/All': 360.0 / 3,
            'triplet_loss/HardNegative/Loss/Active': 360.0 / 3,
            'triplet_loss/HardNegative/ActiveTripletNum': 3,
            'triplet_loss/HardNegative/ActiveTripletRatio': 1.0,
            'triplet_loss/SemiHardNegative/Loss/All': 114.0 / 3,
            'triplet_loss/SemiHardNegative/Loss/Active': 114.0 / 2,
            'triplet_loss/SemiHardNegative/ActiveTripletNum': 2,
            'triplet_loss/SemiHardNegative/ActiveTripletRatio': 2.0 / 3,
            'triplet_mining/Anchor/Positive/Distance/Mean': 30.0 / 3,
            'triplet_mining/Anchor/Positive/Distance/Median': 6.0,
            'triplet_mining/Anchor/HardNegative/Distance/Mean': 6.0 / 3,
            'triplet_mining/Anchor/HardNegative/Distance/Median': 0.0,
            'triplet_mining/Anchor/SemiHardNegative/Distance/Mean': 156.0 / 3,
            'triplet_mining/Anchor/SemiHardNegative/Distance/Median': 54.0,
            'triplet_mining/HardNegative/Loss/All': 384.0 / 3,
            'triplet_mining/HardNegative/Loss/Active': 384.0 / 3,
            'triplet_mining/HardNegative/ActiveTripletNum': 3,
            'triplet_mining/HardNegative/ActiveTripletRatio': 1.0,
            'triplet_mining/SemiHardNegative/Loss/All': 234.0 / 3,
            'triplet_mining/SemiHardNegative/Loss/Active': 234.0 / 3,
            'triplet_mining/SemiHardNegative/ActiveTripletNum': 3,
            'triplet_mining/SemiHardNegative/ActiveTripletRatio': 1.0,
        }
        self._assert_dict_equal_or_almost_equal(summaries_result,
                                                expected_summaries_result)
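Two of the expected summaries above can be spot-checked by hand, assuming the default squared-L2 kernel: each embedding here carries a single sample per keypoint, so the anchor/positive sample distance reduces to the per-keypoint squared-L2 distance summed over the 3 (identical) keypoints. A small NumPy sketch of that check (independent of loss_utils, for illustration only):

import numpy as np

# One distinct 2-D vector per batch element (repeated over the 3 keypoints).
anchors = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
positives = np.array([[2.0, 1.0], [6.0, 5.0], [7.0, 6.0]])
# Squared-L2 per element, summed over the 3 identical keypoints.
distances = 3 * ((anchors - positives) ** 2).sum(axis=-1)  # [6., 30., 12.]
print(distances.mean(), np.median(distances))  # 16.0 (= 48 / 3) and 12.0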
Code example #5
def _validate_and_setup(common_module, keypoint_profiles_module, models_module,
                        keypoint_distance_config_override, embedder_fn_kwargs):
    """Validates and sets up training configurations."""
    # Set default values for unspecified flags.
    if FLAGS.use_normalized_embeddings_for_triplet_mining is None:
        FLAGS.use_normalized_embeddings_for_triplet_mining = (
            FLAGS.use_normalized_embeddings_for_triplet_loss)
    if FLAGS.use_normalized_embeddings_for_positive_pairwise_loss is None:
        FLAGS.use_normalized_embeddings_for_positive_pairwise_loss = (
            FLAGS.use_normalized_embeddings_for_triplet_loss)
    if FLAGS.positive_pairwise_distance_type is None:
        FLAGS.positive_pairwise_distance_type = FLAGS.triplet_distance_type
    if FLAGS.positive_pairwise_distance_kernel is None:
        FLAGS.positive_pairwise_distance_kernel = FLAGS.triplet_distance_kernel
    if FLAGS.positive_pairwise_pairwise_reduction is None:
        FLAGS.positive_pairwise_pairwise_reduction = (
            FLAGS.triplet_pairwise_reduction)
    if FLAGS.positive_pairwise_componentwise_reduction is None:
        FLAGS.positive_pairwise_componentwise_reduction = (
            FLAGS.triplet_componentwise_reduction)

    # Validate flags.
    validate_flag = common_module.validate
    validate_flag(FLAGS.model_input_keypoint_type,
                  common_module.SUPPORTED_TRAINING_MODEL_INPUT_KEYPOINT_TYPES)
    validate_flag(FLAGS.embedding_type,
                  common_module.SUPPORTED_EMBEDDING_TYPES)
    validate_flag(FLAGS.base_model_type,
                  common_module.SUPPORTED_BASE_MODEL_TYPES)
    validate_flag(FLAGS.keypoint_distance_type,
                  common_module.SUPPORTED_KEYPOINT_DISTANCE_TYPES)
    validate_flag(FLAGS.triplet_distance_type,
                  common_module.SUPPORTED_DISTANCE_TYPES)
    validate_flag(FLAGS.triplet_distance_kernel,
                  common_module.SUPPORTED_DISTANCE_KERNELS)
    validate_flag(FLAGS.triplet_pairwise_reduction,
                  common_module.SUPPORTED_PAIRWISE_DISTANCE_REDUCTIONS)
    validate_flag(FLAGS.triplet_componentwise_reduction,
                  common_module.SUPPORTED_COMPONENTWISE_DISTANCE_REDUCTIONS)
    validate_flag(FLAGS.positive_pairwise_distance_type,
                  common_module.SUPPORTED_DISTANCE_TYPES)
    validate_flag(FLAGS.positive_pairwise_distance_kernel,
                  common_module.SUPPORTED_DISTANCE_KERNELS)
    validate_flag(FLAGS.positive_pairwise_pairwise_reduction,
                  common_module.SUPPORTED_PAIRWISE_DISTANCE_REDUCTIONS)
    validate_flag(FLAGS.positive_pairwise_componentwise_reduction,
                  common_module.SUPPORTED_COMPONENTWISE_DISTANCE_REDUCTIONS)

    if FLAGS.embedding_type == common_module.EMBEDDING_TYPE_POINT:
        if FLAGS.triplet_distance_type in [
                common_module.DISTANCE_TYPE_SAMPLE,
                common_module.DISTANCE_TYPE_CENTER_AND_SAMPLE
        ]:
            raise ValueError(
                'No support for triplet distance type `%s` for embedding type `%s`.'
                % (FLAGS.triplet_distance_type, FLAGS.embedding_type))
        if FLAGS.kl_regularization_loss_weight > 0.0:
            raise ValueError(
                'No support for KL regularization loss for embedding type `%s`.'
                % FLAGS.embedding_type)

    if ((FLAGS.triplet_distance_type in [
            common_module.DISTANCE_TYPE_SAMPLE,
            common_module.DISTANCE_TYPE_CENTER_AND_SAMPLE
    ] or FLAGS.positive_pairwise_distance_type in [
            common_module.DISTANCE_TYPE_SAMPLE,
            common_module.DISTANCE_TYPE_CENTER_AND_SAMPLE
    ]) and FLAGS.num_embedding_samples <= 0):
        raise ValueError(
            'Must specify positive `num_embedding_samples` to use `%s` '
            'triplet/positive pairwise distance type.' %
            FLAGS.triplet_distance_type)

    if (((FLAGS.triplet_distance_kernel in [
            common_module.DISTANCE_KERNEL_L2_SIGMOID_MATCHING_PROB,
            common_module.DISTANCE_KERNEL_EXPECTED_LIKELIHOOD
    ]) != (FLAGS.triplet_pairwise_reduction in [
            common_module.DISTANCE_REDUCTION_NEG_LOG_MEAN,
            common_module.DISTANCE_REDUCTION_LOWER_HALF_NEG_LOG_MEAN,
            common_module.DISTANCE_REDUCTION_ONE_MINUS_MEAN
    ])) or ((FLAGS.positive_pairwise_distance_kernel in [
            common_module.DISTANCE_KERNEL_L2_SIGMOID_MATCHING_PROB,
            common_module.DISTANCE_KERNEL_EXPECTED_LIKELIHOOD
    ]) != (FLAGS.positive_pairwise_pairwise_reduction in [
            common_module.DISTANCE_REDUCTION_NEG_LOG_MEAN,
            common_module.DISTANCE_REDUCTION_LOWER_HALF_NEG_LOG_MEAN,
            common_module.DISTANCE_REDUCTION_ONE_MINUS_MEAN
    ]))):
        raise ValueError(
            'Must use `L2_SIGMOID_MATCHING_PROB` or `EXPECTED_LIKELIHOOD` distance '
            'kernel and `NEG_LOG_MEAN` or `LOWER_HALF_NEG_LOG_MEAN` pairwise reducer'
            ' in pairs.')

    keypoint_profile_2d = keypoint_profiles_module.create_keypoint_profile_or_die(
        FLAGS.input_keypoint_profile_name_2d)

    # Set up configurations.
    configs = {
        'keypoint_profile_3d':
        keypoint_profiles_module.create_keypoint_profile_or_die(
            FLAGS.input_keypoint_profile_name_3d),
        'keypoint_profile_2d':
        keypoint_profile_2d,
        'embedder_fn':
        models_module.get_embedder(
            base_model_type=FLAGS.base_model_type,
            embedding_type=FLAGS.embedding_type,
            num_embedding_components=FLAGS.num_embedding_components,
            embedding_size=FLAGS.embedding_size,
            num_embedding_samples=FLAGS.num_embedding_samples,
            is_training=True,
            num_fc_blocks=FLAGS.num_fc_blocks,
            num_fcs_per_block=FLAGS.num_fcs_per_block,
            num_hidden_nodes=FLAGS.num_hidden_nodes,
            num_bottleneck_nodes=FLAGS.num_bottleneck_nodes,
            weight_max_norm=FLAGS.weight_max_norm,
            dropout_rate=FLAGS.dropout_rate,
            **embedder_fn_kwargs),
        'triplet_embedding_keys':
        pipeline_utils.get_embedding_keys(FLAGS.triplet_distance_type,
                                          common_module=common_module),
        'triplet_mining_embedding_keys':
        pipeline_utils.get_embedding_keys(FLAGS.triplet_distance_type,
                                          replace_samples_with_means=True,
                                          common_module=common_module),
        'triplet_embedding_sample_distance_fn':
        loss_utils.create_sample_distance_fn(
            pair_type=common_module.DISTANCE_PAIR_TYPE_ALL_PAIRS,
            distance_kernel=FLAGS.triplet_distance_kernel,
            pairwise_reduction=FLAGS.triplet_pairwise_reduction,
            componentwise_reduction=FLAGS.triplet_componentwise_reduction,
            # We initialize the sigmoid parameters to avoid the model getting
            # stuck in a `dead zone` at the beginning of training.
            L2_SIGMOID_MATCHING_PROB_a_initializer=(
                tf.initializers.constant(-0.65)),
            L2_SIGMOID_MATCHING_PROB_b_initializer=(
                tf.initializers.constant(-0.5)),
            EXPECTED_LIKELIHOOD_min_stddev=0.1,
            EXPECTED_LIKELIHOOD_max_squared_mahalanobis_distance=100.0),
        'positive_pairwise_embedding_keys':
        pipeline_utils.get_embedding_keys(
            FLAGS.positive_pairwise_distance_type,
            common_module=common_module),
        'positive_pairwise_embedding_sample_distance_fn':
        loss_utils.create_sample_distance_fn(
            pair_type=common_module.DISTANCE_PAIR_TYPE_ALL_PAIRS,
            distance_kernel=FLAGS.positive_pairwise_distance_kernel,
            pairwise_reduction=FLAGS.positive_pairwise_pairwise_reduction,
            componentwise_reduction=(
                FLAGS.positive_pairwise_componentwise_reduction),
            # We initialize the sigmoid parameters to avoid the model getting
            # stuck in a `dead zone` at the beginning of training.
            L2_SIGMOID_MATCHING_PROB_a_initializer=(
                tf.initializers.constant(-0.65)),
            L2_SIGMOID_MATCHING_PROB_b_initializer=(
                tf.initializers.constant(-0.5)),
            EXPECTED_LIKELIHOOD_min_stddev=0.1,
            EXPECTED_LIKELIHOOD_max_squared_mahalanobis_distance=100.0),
        'summarize_matching_sigmoid_vars':
        FLAGS.triplet_distance_kernel
        in [common_module.DISTANCE_KERNEL_L2_SIGMOID_MATCHING_PROB],
        'random_projection_azimuth_range': [
            float(x) / 180.0 * math.pi
            for x in FLAGS.random_projection_azimuth_range
        ],
        'random_projection_elevation_range': [
            float(x) / 180.0 * math.pi
            for x in FLAGS.random_projection_elevation_range
        ],
        'random_projection_roll_range': [
            float(x) / 180.0 * math.pi
            for x in FLAGS.random_projection_roll_range
        ],
    }

    if FLAGS.keypoint_distance_type == common_module.KEYPOINT_DISTANCE_TYPE_MPJPE:
        configs.update({
            'keypoint_distance_fn':
            keypoint_utils.compute_procrustes_aligned_mpjpes,
            'min_negative_keypoint_distance':
            FLAGS.min_negative_keypoint_mpjpe
        })
    # We use the following assignments to get around pytype check failures.
    # TODO(liuti): Figure out a better workaround.
    if 'keypoint_distance_fn' in keypoint_distance_config_override:
        configs['keypoint_distance_fn'] = (
            keypoint_distance_config_override['keypoint_distance_fn'])
    if 'min_negative_keypoint_distance' in keypoint_distance_config_override:
        configs['min_negative_keypoint_distance'] = (
            keypoint_distance_config_override['min_negative_keypoint_distance']
        )
    if ('keypoint_distance_fn' not in configs
            or 'min_negative_keypoint_distance' not in configs):
        raise ValueError('Invalid keypoint distance config: %s.' %
                         str(configs))

    if FLAGS.task == 0 and not FLAGS.profile_only:
        # Save all key flags.
        pipeline_utils.create_dir_and_save_flags(flags, FLAGS.train_log_dir,
                                                 'all_flags.train.json')

    return configs
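The `!=` between the two membership tests in the kernel/reduction validation above acts as an exclusive-or: a matching-probability or expected-likelihood kernel must be paired with one of the `NEG_LOG_MEAN`-family reducers, and a plain distance kernel must not be. A small standalone sketch of that rule, using hypothetical string constants in place of the `common_module` ones:

MATCHING_KERNELS = {'L2_SIGMOID_MATCHING_PROB', 'EXPECTED_LIKELIHOOD'}
MATCHING_REDUCERS = {
    'NEG_LOG_MEAN', 'LOWER_HALF_NEG_LOG_MEAN', 'ONE_MINUS_MEAN'
}


def pairing_is_valid(kernel, reducer):
    # Both sides in the matching family, or neither; `==` here mirrors the
    # `!=` (exclusive-or) check in _validate_and_setup.
    return (kernel in MATCHING_KERNELS) == (reducer in MATCHING_REDUCERS)


assert pairing_is_valid('L2_SIGMOID_MATCHING_PROB', 'NEG_LOG_MEAN')
assert pairing_is_valid('SQUARED_L2', 'MEAN')
assert not pairing_is_valid('SQUARED_L2', 'NEG_LOG_MEAN')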
Code example #6
def _validate_and_setup(common_module, keypoint_profiles_module, models_module,
                        keypoint_distance_config_override,
                        create_model_input_fn_kwargs, embedder_fn_kwargs):
    """Validates and sets up training configurations."""
    # Set default values for unspecified flags.
    if FLAGS.use_normalized_embeddings_for_triplet_mining is None:
        FLAGS.use_normalized_embeddings_for_triplet_mining = (
            FLAGS.use_normalized_embeddings_for_triplet_loss)
    if FLAGS.use_normalized_embeddings_for_positive_pairwise_loss is None:
        FLAGS.use_normalized_embeddings_for_positive_pairwise_loss = (
            FLAGS.use_normalized_embeddings_for_triplet_loss)
    if FLAGS.positive_pairwise_distance_type is None:
        FLAGS.positive_pairwise_distance_type = FLAGS.triplet_distance_type
    if FLAGS.positive_pairwise_distance_kernel is None:
        FLAGS.positive_pairwise_distance_kernel = FLAGS.triplet_distance_kernel
    if FLAGS.positive_pairwise_pairwise_reduction is None:
        FLAGS.positive_pairwise_pairwise_reduction = (
            FLAGS.triplet_pairwise_reduction)
    if FLAGS.positive_pairwise_componentwise_reduction is None:
        FLAGS.positive_pairwise_componentwise_reduction = (
            FLAGS.triplet_componentwise_reduction)

    # Validate flags.
    validate_flag = common_module.validate
    validate_flag(FLAGS.model_input_keypoint_type,
                  common_module.SUPPORTED_TRAINING_MODEL_INPUT_KEYPOINT_TYPES)
    validate_flag(FLAGS.model_input_keypoint_mask_type,
                  common_module.SUPPORTED_MODEL_INPUT_KEYPOINT_MASK_TYPES)
    validate_flag(FLAGS.embedding_type,
                  common_module.SUPPORTED_EMBEDDING_TYPES)
    validate_flag(FLAGS.base_model_type,
                  common_module.SUPPORTED_BASE_MODEL_TYPES)
    validate_flag(FLAGS.keypoint_distance_type,
                  common_module.SUPPORTED_KEYPOINT_DISTANCE_TYPES)
    validate_flag(FLAGS.triplet_distance_type,
                  common_module.SUPPORTED_DISTANCE_TYPES)
    validate_flag(FLAGS.triplet_distance_kernel,
                  common_module.SUPPORTED_DISTANCE_KERNELS)
    validate_flag(FLAGS.triplet_pairwise_reduction,
                  common_module.SUPPORTED_PAIRWISE_DISTANCE_REDUCTIONS)
    validate_flag(FLAGS.triplet_componentwise_reduction,
                  common_module.SUPPORTED_COMPONENTWISE_DISTANCE_REDUCTIONS)
    validate_flag(FLAGS.positive_pairwise_distance_type,
                  common_module.SUPPORTED_DISTANCE_TYPES)
    validate_flag(FLAGS.positive_pairwise_distance_kernel,
                  common_module.SUPPORTED_DISTANCE_KERNELS)
    validate_flag(FLAGS.positive_pairwise_pairwise_reduction,
                  common_module.SUPPORTED_PAIRWISE_DISTANCE_REDUCTIONS)
    validate_flag(FLAGS.positive_pairwise_componentwise_reduction,
                  common_module.SUPPORTED_COMPONENTWISE_DISTANCE_REDUCTIONS)

    if FLAGS.embedding_type == common_module.EMBEDDING_TYPE_POINT:
        if FLAGS.triplet_distance_type != common_module.DISTANCE_TYPE_CENTER:
            raise ValueError(
                'No support for triplet distance type `%s` for embedding type `%s`.'
                % (FLAGS.triplet_distance_type, FLAGS.embedding_type))
        if FLAGS.kl_regularization_loss_weight > 0.0:
            raise ValueError(
                'No support for KL regularization loss for embedding type `%s`.'
                % FLAGS.embedding_type)

    if ((FLAGS.triplet_distance_type in [
            common_module.DISTANCE_TYPE_SAMPLE,
    ] or FLAGS.positive_pairwise_distance_type in [
            common_module.DISTANCE_TYPE_SAMPLE,
    ]) and FLAGS.num_embedding_samples <= 0):
        raise ValueError(
            'Must specify positive `num_embedding_samples` to use `%s` '
            'triplet/positive pairwise distance type.' %
            FLAGS.triplet_distance_type)

    if (((FLAGS.triplet_distance_kernel in [
            common_module.DISTANCE_KERNEL_L2_SIGMOID_MATCHING_PROB,
            common_module.DISTANCE_KERNEL_SQUARED_L2_SIGMOID_MATCHING_PROB,
            common_module.DISTANCE_KERNEL_EXPECTED_LIKELIHOOD,
    ]) != (FLAGS.triplet_pairwise_reduction in [
            common_module.DISTANCE_REDUCTION_NEG_LOG_MEAN,
            common_module.DISTANCE_REDUCTION_LOWER_HALF_NEG_LOG_MEAN,
            common_module.DISTANCE_REDUCTION_ONE_MINUS_MEAN
    ])) or ((FLAGS.positive_pairwise_distance_kernel in [
            common_module.DISTANCE_KERNEL_L2_SIGMOID_MATCHING_PROB,
            common_module.DISTANCE_KERNEL_SQUARED_L2_SIGMOID_MATCHING_PROB,
            common_module.DISTANCE_KERNEL_EXPECTED_LIKELIHOOD,
    ]) != (FLAGS.positive_pairwise_pairwise_reduction in [
            common_module.DISTANCE_REDUCTION_NEG_LOG_MEAN,
            common_module.DISTANCE_REDUCTION_LOWER_HALF_NEG_LOG_MEAN,
            common_module.DISTANCE_REDUCTION_ONE_MINUS_MEAN
    ]))):
        raise ValueError(
            'Must use `L2_SIGMOID_MATCHING_PROB` or `EXPECTED_LIKELIHOOD` distance '
            'kernel and `NEG_LOG_MEAN` or `LOWER_HALF_NEG_LOG_MEAN` pairwise reducer'
            ' in pairs.')

    keypoint_profile_2d = keypoint_profiles_module.create_keypoint_profile_or_die(
        FLAGS.input_keypoint_profile_name_2d)

    # Set up configurations.
    configs = {
        'keypoint_profile_3d':
        keypoint_profiles_module.create_keypoint_profile_or_die(
            FLAGS.input_keypoint_profile_name_3d),
        'keypoint_profile_2d':
        keypoint_profile_2d,
        'create_model_input_fn':
        functools.partial(
            input_generator.create_model_input,
            model_input_keypoint_mask_type=(
                FLAGS.model_input_keypoint_mask_type),
            uniform_keypoint_jittering_max_offset_2d=(
                FLAGS.uniform_keypoint_jittering_max_offset_2d),
            gaussian_keypoint_jittering_offset_stddev_2d=(
                FLAGS.gaussian_keypoint_jittering_offset_stddev_2d),
            keypoint_dropout_probs=[
                float(x) for x in FLAGS.keypoint_dropout_probs
            ],
            set_on_mask_for_non_anchors=FLAGS.set_on_mask_for_non_anchors,
            mix_mask_sub_batches=FLAGS.mix_mask_sub_batches,
            forced_mask_on_part_names=FLAGS.forced_mask_on_part_names,
            forced_mask_off_part_names=FLAGS.forced_mask_off_part_names,
            **create_model_input_fn_kwargs),
        'embedder_fn':
        models_module.get_embedder(
            base_model_type=FLAGS.base_model_type,
            embedding_type=FLAGS.embedding_type,
            num_embedding_components=FLAGS.num_embedding_components,
            embedding_size=FLAGS.embedding_size,
            num_embedding_samples=FLAGS.num_embedding_samples,
            is_training=True,
            num_fc_blocks=FLAGS.num_fc_blocks,
            num_fcs_per_block=FLAGS.num_fcs_per_block,
            num_hidden_nodes=FLAGS.num_hidden_nodes,
            num_bottleneck_nodes=FLAGS.num_bottleneck_nodes,
            weight_max_norm=FLAGS.weight_max_norm,
            dropout_rate=FLAGS.dropout_rate,
            **embedder_fn_kwargs),
        'triplet_embedding_keys':
        pipeline_utils.get_embedding_keys(FLAGS.triplet_distance_type,
                                          common_module=common_module),
        'triplet_mining_embedding_keys':
        pipeline_utils.get_embedding_keys(FLAGS.triplet_distance_type,
                                          replace_samples_with_means=True,
                                          common_module=common_module),
        'positive_pairwise_embedding_keys':
        pipeline_utils.get_embedding_keys(
            FLAGS.positive_pairwise_distance_type,
            common_module=common_module),
        'summarize_matching_sigmoid_vars':
        FLAGS.triplet_distance_kernel in [
            common_module.DISTANCE_KERNEL_L2_SIGMOID_MATCHING_PROB,
            common_module.DISTANCE_KERNEL_SQUARED_L2_SIGMOID_MATCHING_PROB
        ],
        'random_projection_azimuth_range': [
            float(x) / 180.0 * math.pi
            for x in FLAGS.random_projection_azimuth_range
        ],
        'random_projection_elevation_range': [
            float(x) / 180.0 * math.pi
            for x in FLAGS.random_projection_elevation_range
        ],
        'random_projection_roll_range': [
            float(x) / 180.0 * math.pi
            for x in FLAGS.random_projection_roll_range
        ],
        'random_projection_camera_depth_range':
        [float(x) for x in FLAGS.random_projection_camera_depth_range],
    }

    embedding_sample_distance_fn_kwargs = {
        'EXPECTED_LIKELIHOOD_min_stddev': 0.1,
        'EXPECTED_LIKELIHOOD_max_squared_mahalanobis_distance': 100.0,
    }
    if FLAGS.triplet_distance_kernel in [
            common_module.DISTANCE_KERNEL_L2_SIGMOID_MATCHING_PROB,
            common_module.DISTANCE_KERNEL_SQUARED_L2_SIGMOID_MATCHING_PROB
    ] or FLAGS.positive_pairwise_distance_kernel in [
            common_module.DISTANCE_KERNEL_L2_SIGMOID_MATCHING_PROB,
            common_module.DISTANCE_KERNEL_SQUARED_L2_SIGMOID_MATCHING_PROB
    ]:
        # We only need sigmoid parameters when a related distance kernel is used.
        sigmoid_raw_a, sigmoid_a, sigmoid_b = pipeline_utils.get_sigmoid_parameters(
            name='MatchingSigmoid',
            raw_a_initial_value=FLAGS.sigmoid_raw_a_initial,
            b_initial_value=FLAGS.sigmoid_b_initial,
            a_range=(None, FLAGS.sigmoid_a_max))
        embedding_sample_distance_fn_kwargs.update({
            'L2_SIGMOID_MATCHING_PROB_a':
            sigmoid_a,
            'L2_SIGMOID_MATCHING_PROB_b':
            sigmoid_b,
            'SQUARED_L2_SIGMOID_MATCHING_PROB_a':
            sigmoid_a,
            'SQUARED_L2_SIGMOID_MATCHING_PROB_b':
            sigmoid_b,
        })
        configs.update({
            'sigmoid_raw_a': sigmoid_raw_a,
            'sigmoid_a': sigmoid_a,
            'sigmoid_b': sigmoid_b,
        })

    configs.update({
        'triplet_embedding_sample_distance_fn':
        loss_utils.create_sample_distance_fn(
            pair_type=common_module.DISTANCE_PAIR_TYPE_ALL_PAIRS,
            distance_kernel=FLAGS.triplet_distance_kernel,
            pairwise_reduction=FLAGS.triplet_pairwise_reduction,
            componentwise_reduction=FLAGS.triplet_componentwise_reduction,
            **embedding_sample_distance_fn_kwargs),
        'positive_pairwise_embedding_sample_distance_fn':
        loss_utils.create_sample_distance_fn(
            pair_type=common_module.DISTANCE_PAIR_TYPE_ALL_PAIRS,
            distance_kernel=FLAGS.positive_pairwise_distance_kernel,
            pairwise_reduction=FLAGS.positive_pairwise_pairwise_reduction,
            componentwise_reduction=(
                FLAGS.positive_pairwise_componentwise_reduction),
            **embedding_sample_distance_fn_kwargs),
    })

    if FLAGS.keypoint_distance_type == common_module.KEYPOINT_DISTANCE_TYPE_MPJPE:
        configs.update({
            'keypoint_distance_fn':
            keypoint_utils.compute_procrustes_aligned_mpjpes,
            'min_negative_keypoint_distance':
            FLAGS.min_negative_keypoint_mpjpe
        })
    # We use the following assignments to get around pytype check failures.
    # TODO(liuti): Figure out a better workaround.
    if 'keypoint_distance_fn' in keypoint_distance_config_override:
        configs['keypoint_distance_fn'] = (
            keypoint_distance_config_override['keypoint_distance_fn'])
    if 'min_negative_keypoint_distance' in keypoint_distance_config_override:
        configs['min_negative_keypoint_distance'] = (
            keypoint_distance_config_override['min_negative_keypoint_distance']
        )
    if ('keypoint_distance_fn' not in configs
            or 'min_negative_keypoint_distance' not in configs):
        raise ValueError('Invalid keypoint distance config: %s.' %
                         str(configs))

    if FLAGS.task == 0 and not FLAGS.profile_only:
        # Save all key flags.
        pipeline_utils.create_dir_and_save_flags(flags, FLAGS.train_log_dir,
                                                 'all_flags.train.json')

    return configs