Пример #1
0
def _validate_and_setup(common_module, keypoint_profiles_module, models_module,
                        keypoint_distance_config_override, embedder_fn_kwargs):
    """Validates training flags and builds the training configuration dict.

    Side effects:
      * Flags left as `None` (triplet-mining / positive-pairwise settings) are
        filled in on the global `FLAGS` from their triplet-loss counterparts.
      * When `FLAGS.task == 0` and `FLAGS.profile_only` is false, the flags are
        saved under `FLAGS.train_log_dir` via
        `pipeline_utils.create_dir_and_save_flags`.

    Args:
        common_module: Module providing `validate`, the `SUPPORTED_*`
            collections and the `EMBEDDING_TYPE_*`, `DISTANCE_*` and
            `KEYPOINT_DISTANCE_TYPE_*` constants referenced below.
        keypoint_profiles_module: Module providing
            `create_keypoint_profile_or_die`.
        models_module: Module providing `get_embedder`.
        keypoint_distance_config_override: Container supporting `in` and item
            lookup. If it holds `keypoint_distance_fn` and/or
            `min_negative_keypoint_distance`, those values take precedence
            over the flag-derived ones.
        embedder_fn_kwargs: Mapping of extra keyword arguments forwarded to
            `models_module.get_embedder`.

    Returns:
        A dictionary of training configurations: keypoint profiles, embedder
        function, embedding keys, sample distance functions, random projection
        ranges (flag values scaled by pi / 180), and the keypoint distance
        function with its minimum negative distance.

    Raises:
        ValueError: If the flag combination is unsupported or if no keypoint
            distance configuration could be determined. `common_module.validate`
            is presumably also responsible for rejecting unsupported flag
            values; its behavior is not visible here.
    """
    # Set default values for unspecified flags.
    if FLAGS.use_normalized_embeddings_for_triplet_mining is None:
        FLAGS.use_normalized_embeddings_for_triplet_mining = (
            FLAGS.use_normalized_embeddings_for_triplet_loss)
    if FLAGS.use_normalized_embeddings_for_positive_pairwise_loss is None:
        FLAGS.use_normalized_embeddings_for_positive_pairwise_loss = (
            FLAGS.use_normalized_embeddings_for_triplet_loss)
    if FLAGS.positive_pairwise_distance_type is None:
        FLAGS.positive_pairwise_distance_type = FLAGS.triplet_distance_type
    if FLAGS.positive_pairwise_distance_kernel is None:
        FLAGS.positive_pairwise_distance_kernel = FLAGS.triplet_distance_kernel
    if FLAGS.positive_pairwise_pairwise_reduction is None:
        FLAGS.positive_pairwise_pairwise_reduction = (
            FLAGS.triplet_pairwise_reduction)
    if FLAGS.positive_pairwise_componentwise_reduction is None:
        FLAGS.positive_pairwise_componentwise_reduction = (
            FLAGS.triplet_componentwise_reduction)

    # Validate flags.
    validate_flag = common_module.validate
    validate_flag(FLAGS.model_input_keypoint_type,
                  common_module.SUPPORTED_TRAINING_MODEL_INPUT_KEYPOINT_TYPES)
    validate_flag(FLAGS.embedding_type,
                  common_module.SUPPORTED_EMBEDDING_TYPES)
    validate_flag(FLAGS.base_model_type,
                  common_module.SUPPORTED_BASE_MODEL_TYPES)
    validate_flag(FLAGS.keypoint_distance_type,
                  common_module.SUPPORTED_KEYPOINT_DISTANCE_TYPES)
    validate_flag(FLAGS.triplet_distance_type,
                  common_module.SUPPORTED_DISTANCE_TYPES)
    validate_flag(FLAGS.triplet_distance_kernel,
                  common_module.SUPPORTED_DISTANCE_KERNELS)
    validate_flag(FLAGS.triplet_pairwise_reduction,
                  common_module.SUPPORTED_PAIRWISE_DISTANCE_REDUCTIONS)
    validate_flag(FLAGS.triplet_componentwise_reduction,
                  common_module.SUPPORTED_COMPONENTWISE_DISTANCE_REDUCTIONS)
    validate_flag(FLAGS.positive_pairwise_distance_type,
                  common_module.SUPPORTED_DISTANCE_TYPES)
    validate_flag(FLAGS.positive_pairwise_distance_kernel,
                  common_module.SUPPORTED_DISTANCE_KERNELS)
    validate_flag(FLAGS.positive_pairwise_pairwise_reduction,
                  common_module.SUPPORTED_PAIRWISE_DISTANCE_REDUCTIONS)
    validate_flag(FLAGS.positive_pairwise_componentwise_reduction,
                  common_module.SUPPORTED_COMPONENTWISE_DISTANCE_REDUCTIONS)

    # Distance types that require `num_embedding_samples` to be positive.
    sample_based_distance_types = (
        common_module.DISTANCE_TYPE_SAMPLE,
        common_module.DISTANCE_TYPE_CENTER_AND_SAMPLE,
    )

    if FLAGS.embedding_type == common_module.EMBEDDING_TYPE_POINT:
        if FLAGS.triplet_distance_type in sample_based_distance_types:
            raise ValueError(
                'No support for triplet distance type `%s` for embedding type `%s`.'
                % (FLAGS.triplet_distance_type, FLAGS.embedding_type))
        if FLAGS.kl_regularization_loss_weight > 0.0:
            raise ValueError(
                'No support for KL regularization loss for embedding type `%s`.'
                % FLAGS.embedding_type)

    triplet_uses_samples = (
        FLAGS.triplet_distance_type in sample_based_distance_types)
    positive_pairwise_uses_samples = (
        FLAGS.positive_pairwise_distance_type in sample_based_distance_types)
    if ((triplet_uses_samples or positive_pairwise_uses_samples)
            and FLAGS.num_embedding_samples <= 0):
        # Report the distance type that actually triggered the check; the
        # triplet type may be a non-sample type when only the positive
        # pairwise type is sample-based.
        offending_distance_type = (
            FLAGS.triplet_distance_type if triplet_uses_samples else
            FLAGS.positive_pairwise_distance_type)
        raise ValueError(
            'Must specify positive `num_embedding_samples` to use `%s` '
            'triplet/positive pairwise distance type.' %
            offending_distance_type)

    # These kernels and pairwise reductions must be used together: picking one
    # from either set without one from the other is rejected.
    paired_distance_kernels = (
        common_module.DISTANCE_KERNEL_L2_SIGMOID_MATCHING_PROB,
        common_module.DISTANCE_KERNEL_EXPECTED_LIKELIHOOD,
    )
    paired_pairwise_reductions = (
        common_module.DISTANCE_REDUCTION_NEG_LOG_MEAN,
        common_module.DISTANCE_REDUCTION_LOWER_HALF_NEG_LOG_MEAN,
        common_module.DISTANCE_REDUCTION_ONE_MINUS_MEAN,
    )

    def _is_mismatched(distance_kernel, pairwise_reduction):
        return ((distance_kernel in paired_distance_kernels) !=
                (pairwise_reduction in paired_pairwise_reductions))

    if (_is_mismatched(FLAGS.triplet_distance_kernel,
                       FLAGS.triplet_pairwise_reduction)
            or _is_mismatched(FLAGS.positive_pairwise_distance_kernel,
                              FLAGS.positive_pairwise_pairwise_reduction)):
        raise ValueError(
            'Must use `L2_SIGMOID_MATCHING_PROB` or `EXPECTED_LIKELIHOOD` distance '
            'kernel and `NEG_LOG_MEAN`, `LOWER_HALF_NEG_LOG_MEAN` or '
            '`ONE_MINUS_MEAN` pairwise reducer in pairs.')

    def _scale_by_pi_over_180(values):
        # Same arithmetic as the per-flag conversion (degrees to radians).
        return [float(x) / 180.0 * math.pi for x in values]

    keypoint_profile_2d = keypoint_profiles_module.create_keypoint_profile_or_die(
        FLAGS.input_keypoint_profile_name_2d)

    # Set up configurations.
    configs = {
        'keypoint_profile_3d':
        keypoint_profiles_module.create_keypoint_profile_or_die(
            FLAGS.input_keypoint_profile_name_3d),
        'keypoint_profile_2d':
        keypoint_profile_2d,
        'embedder_fn':
        models_module.get_embedder(
            base_model_type=FLAGS.base_model_type,
            embedding_type=FLAGS.embedding_type,
            num_embedding_components=FLAGS.num_embedding_components,
            embedding_size=FLAGS.embedding_size,
            num_embedding_samples=FLAGS.num_embedding_samples,
            is_training=True,
            num_fc_blocks=FLAGS.num_fc_blocks,
            num_fcs_per_block=FLAGS.num_fcs_per_block,
            num_hidden_nodes=FLAGS.num_hidden_nodes,
            num_bottleneck_nodes=FLAGS.num_bottleneck_nodes,
            weight_max_norm=FLAGS.weight_max_norm,
            dropout_rate=FLAGS.dropout_rate,
            **embedder_fn_kwargs),
        'triplet_embedding_keys':
        pipeline_utils.get_embedding_keys(FLAGS.triplet_distance_type,
                                          common_module=common_module),
        'triplet_mining_embedding_keys':
        pipeline_utils.get_embedding_keys(FLAGS.triplet_distance_type,
                                          replace_samples_with_means=True,
                                          common_module=common_module),
        'triplet_embedding_sample_distance_fn':
        loss_utils.create_sample_distance_fn(
            pair_type=common_module.DISTANCE_PAIR_TYPE_ALL_PAIRS,
            distance_kernel=FLAGS.triplet_distance_kernel,
            pairwise_reduction=FLAGS.triplet_pairwise_reduction,
            componentwise_reduction=FLAGS.triplet_componentwise_reduction,
            # We initialize the sigmoid parameters to avoid model being stuck
            # in a `dead zone` at the beginning of training.
            L2_SIGMOID_MATCHING_PROB_a_initializer=(
                tf.initializers.constant(-0.65)),
            L2_SIGMOID_MATCHING_PROB_b_initializer=(
                tf.initializers.constant(-0.5)),
            EXPECTED_LIKELIHOOD_min_stddev=0.1,
            EXPECTED_LIKELIHOOD_max_squared_mahalanobis_distance=100.0),
        'positive_pairwise_embedding_keys':
        pipeline_utils.get_embedding_keys(
            FLAGS.positive_pairwise_distance_type,
            common_module=common_module),
        'positive_pairwise_embedding_sample_distance_fn':
        loss_utils.create_sample_distance_fn(
            pair_type=common_module.DISTANCE_PAIR_TYPE_ALL_PAIRS,
            distance_kernel=FLAGS.positive_pairwise_distance_kernel,
            pairwise_reduction=FLAGS.positive_pairwise_pairwise_reduction,
            componentwise_reduction=(
                FLAGS.positive_pairwise_componentwise_reduction),
            # We initialize the sigmoid parameters to avoid model being stuck
            # in a `dead zone` at the beginning of training.
            L2_SIGMOID_MATCHING_PROB_a_initializer=(
                tf.initializers.constant(-0.65)),
            L2_SIGMOID_MATCHING_PROB_b_initializer=(
                tf.initializers.constant(-0.5)),
            EXPECTED_LIKELIHOOD_min_stddev=0.1,
            EXPECTED_LIKELIHOOD_max_squared_mahalanobis_distance=100.0),
        'summarize_matching_sigmoid_vars':
        FLAGS.triplet_distance_kernel
        in [common_module.DISTANCE_KERNEL_L2_SIGMOID_MATCHING_PROB],
        'random_projection_azimuth_range':
        _scale_by_pi_over_180(FLAGS.random_projection_azimuth_range),
        'random_projection_elevation_range':
        _scale_by_pi_over_180(FLAGS.random_projection_elevation_range),
        'random_projection_roll_range':
        _scale_by_pi_over_180(FLAGS.random_projection_roll_range),
    }

    if FLAGS.keypoint_distance_type == common_module.KEYPOINT_DISTANCE_TYPE_MPJPE:
        configs.update({
            'keypoint_distance_fn':
            keypoint_utils.compute_procrustes_aligned_mpjpes,
            'min_negative_keypoint_distance':
            FLAGS.min_negative_keypoint_mpjpe
        })
    # We use the following assignments to get around pytype check failures.
    # TODO(liuti): Figure out a better workaround.
    if 'keypoint_distance_fn' in keypoint_distance_config_override:
        configs['keypoint_distance_fn'] = (
            keypoint_distance_config_override['keypoint_distance_fn'])
    if 'min_negative_keypoint_distance' in keypoint_distance_config_override:
        configs['min_negative_keypoint_distance'] = (
            keypoint_distance_config_override['min_negative_keypoint_distance']
        )
    # Neither the flags nor the override supplied a complete keypoint distance
    # configuration.
    if ('keypoint_distance_fn' not in configs
            or 'min_negative_keypoint_distance' not in configs):
        raise ValueError('Invalid keypoint distance config: %s.' %
                         str(configs))

    if FLAGS.task == 0 and not FLAGS.profile_only:
        # Save all key flags.
        pipeline_utils.create_dir_and_save_flags(flags, FLAGS.train_log_dir,
                                                 'all_flags.train.json')

    return configs
Пример #2
0
def _validate_and_setup(common_module, keypoint_profiles_module, models_module,
                        keypoint_distance_config_override,
                        create_model_input_fn_kwargs, embedder_fn_kwargs):
    """Validates training flags and builds the training configuration dict.

    Side effects:
      * Flags left as `None` (triplet-mining / positive-pairwise settings) are
        filled in on the global `FLAGS` from their triplet-loss counterparts.
      * When a sigmoid matching-probability kernel is selected, sigmoid
        parameters are obtained via `pipeline_utils.get_sigmoid_parameters`.
      * When `FLAGS.task == 0` and `FLAGS.profile_only` is false, the flags are
        saved under `FLAGS.train_log_dir` via
        `pipeline_utils.create_dir_and_save_flags`.

    Args:
        common_module: Module providing `validate`, the `SUPPORTED_*`
            collections and the `EMBEDDING_TYPE_*`, `DISTANCE_*` and
            `KEYPOINT_DISTANCE_TYPE_*` constants referenced below.
        keypoint_profiles_module: Module providing
            `create_keypoint_profile_or_die`.
        models_module: Module providing `get_embedder`.
        keypoint_distance_config_override: Container supporting `in` and item
            lookup. If it holds `keypoint_distance_fn` and/or
            `min_negative_keypoint_distance`, those values take precedence
            over the flag-derived ones.
        create_model_input_fn_kwargs: Mapping of extra keyword arguments bound
            into the `input_generator.create_model_input` partial.
        embedder_fn_kwargs: Mapping of extra keyword arguments forwarded to
            `models_module.get_embedder`.

    Returns:
        A dictionary of training configurations: keypoint profiles, model
        input function, embedder function, embedding keys, sample distance
        functions, random projection ranges (angular flag values scaled by
        pi / 180), the keypoint distance function with its minimum negative
        distance, and, when a sigmoid kernel is used, the sigmoid parameters.

    Raises:
        ValueError: If the flag combination is unsupported or if no keypoint
            distance configuration could be determined. `common_module.validate`
            is presumably also responsible for rejecting unsupported flag
            values; its behavior is not visible here.
    """
    # Set default values for unspecified flags.
    if FLAGS.use_normalized_embeddings_for_triplet_mining is None:
        FLAGS.use_normalized_embeddings_for_triplet_mining = (
            FLAGS.use_normalized_embeddings_for_triplet_loss)
    if FLAGS.use_normalized_embeddings_for_positive_pairwise_loss is None:
        FLAGS.use_normalized_embeddings_for_positive_pairwise_loss = (
            FLAGS.use_normalized_embeddings_for_triplet_loss)
    if FLAGS.positive_pairwise_distance_type is None:
        FLAGS.positive_pairwise_distance_type = FLAGS.triplet_distance_type
    if FLAGS.positive_pairwise_distance_kernel is None:
        FLAGS.positive_pairwise_distance_kernel = FLAGS.triplet_distance_kernel
    if FLAGS.positive_pairwise_pairwise_reduction is None:
        FLAGS.positive_pairwise_pairwise_reduction = (
            FLAGS.triplet_pairwise_reduction)
    if FLAGS.positive_pairwise_componentwise_reduction is None:
        FLAGS.positive_pairwise_componentwise_reduction = (
            FLAGS.triplet_componentwise_reduction)

    # Validate flags.
    validate_flag = common_module.validate
    validate_flag(FLAGS.model_input_keypoint_type,
                  common_module.SUPPORTED_TRAINING_MODEL_INPUT_KEYPOINT_TYPES)
    validate_flag(FLAGS.model_input_keypoint_mask_type,
                  common_module.SUPPORTED_MODEL_INPUT_KEYPOINT_MASK_TYPES)
    validate_flag(FLAGS.embedding_type,
                  common_module.SUPPORTED_EMBEDDING_TYPES)
    validate_flag(FLAGS.base_model_type,
                  common_module.SUPPORTED_BASE_MODEL_TYPES)
    validate_flag(FLAGS.keypoint_distance_type,
                  common_module.SUPPORTED_KEYPOINT_DISTANCE_TYPES)
    validate_flag(FLAGS.triplet_distance_type,
                  common_module.SUPPORTED_DISTANCE_TYPES)
    validate_flag(FLAGS.triplet_distance_kernel,
                  common_module.SUPPORTED_DISTANCE_KERNELS)
    validate_flag(FLAGS.triplet_pairwise_reduction,
                  common_module.SUPPORTED_PAIRWISE_DISTANCE_REDUCTIONS)
    validate_flag(FLAGS.triplet_componentwise_reduction,
                  common_module.SUPPORTED_COMPONENTWISE_DISTANCE_REDUCTIONS)
    validate_flag(FLAGS.positive_pairwise_distance_type,
                  common_module.SUPPORTED_DISTANCE_TYPES)
    validate_flag(FLAGS.positive_pairwise_distance_kernel,
                  common_module.SUPPORTED_DISTANCE_KERNELS)
    validate_flag(FLAGS.positive_pairwise_pairwise_reduction,
                  common_module.SUPPORTED_PAIRWISE_DISTANCE_REDUCTIONS)
    validate_flag(FLAGS.positive_pairwise_componentwise_reduction,
                  common_module.SUPPORTED_COMPONENTWISE_DISTANCE_REDUCTIONS)

    if FLAGS.embedding_type == common_module.EMBEDDING_TYPE_POINT:
        if FLAGS.triplet_distance_type != common_module.DISTANCE_TYPE_CENTER:
            raise ValueError(
                'No support for triplet distance type `%s` for embedding type `%s`.'
                % (FLAGS.triplet_distance_type, FLAGS.embedding_type))
        if FLAGS.kl_regularization_loss_weight > 0.0:
            raise ValueError(
                'No support for KL regularization loss for embedding type `%s`.'
                % FLAGS.embedding_type)

    # Distance types that require `num_embedding_samples` to be positive.
    sample_based_distance_types = (common_module.DISTANCE_TYPE_SAMPLE,)
    triplet_uses_samples = (
        FLAGS.triplet_distance_type in sample_based_distance_types)
    positive_pairwise_uses_samples = (
        FLAGS.positive_pairwise_distance_type in sample_based_distance_types)
    if ((triplet_uses_samples or positive_pairwise_uses_samples)
            and FLAGS.num_embedding_samples <= 0):
        # Report the distance type that actually triggered the check; the
        # triplet type may be a non-sample type when only the positive
        # pairwise type is sample-based.
        offending_distance_type = (
            FLAGS.triplet_distance_type if triplet_uses_samples else
            FLAGS.positive_pairwise_distance_type)
        raise ValueError(
            'Must specify positive `num_embedding_samples` to use `%s` '
            'triplet/positive pairwise distance type.' %
            offending_distance_type)

    # Kernels that rely on the shared matching sigmoid parameters.
    sigmoid_distance_kernels = (
        common_module.DISTANCE_KERNEL_L2_SIGMOID_MATCHING_PROB,
        common_module.DISTANCE_KERNEL_SQUARED_L2_SIGMOID_MATCHING_PROB,
    )
    # These kernels and pairwise reductions must be used together: picking one
    # from either set without one from the other is rejected.
    paired_distance_kernels = sigmoid_distance_kernels + (
        common_module.DISTANCE_KERNEL_EXPECTED_LIKELIHOOD,)
    paired_pairwise_reductions = (
        common_module.DISTANCE_REDUCTION_NEG_LOG_MEAN,
        common_module.DISTANCE_REDUCTION_LOWER_HALF_NEG_LOG_MEAN,
        common_module.DISTANCE_REDUCTION_ONE_MINUS_MEAN,
    )

    def _is_mismatched(distance_kernel, pairwise_reduction):
        return ((distance_kernel in paired_distance_kernels) !=
                (pairwise_reduction in paired_pairwise_reductions))

    if (_is_mismatched(FLAGS.triplet_distance_kernel,
                       FLAGS.triplet_pairwise_reduction)
            or _is_mismatched(FLAGS.positive_pairwise_distance_kernel,
                              FLAGS.positive_pairwise_pairwise_reduction)):
        raise ValueError(
            'Must use `L2_SIGMOID_MATCHING_PROB`, '
            '`SQUARED_L2_SIGMOID_MATCHING_PROB` or `EXPECTED_LIKELIHOOD` '
            'distance kernel and `NEG_LOG_MEAN`, `LOWER_HALF_NEG_LOG_MEAN` or '
            '`ONE_MINUS_MEAN` pairwise reducer in pairs.')

    def _scale_by_pi_over_180(values):
        # Same arithmetic as the per-flag conversion (degrees to radians).
        return [float(x) / 180.0 * math.pi for x in values]

    keypoint_profile_2d = keypoint_profiles_module.create_keypoint_profile_or_die(
        FLAGS.input_keypoint_profile_name_2d)

    # Set up configurations.
    configs = {
        'keypoint_profile_3d':
        keypoint_profiles_module.create_keypoint_profile_or_die(
            FLAGS.input_keypoint_profile_name_3d),
        'keypoint_profile_2d':
        keypoint_profile_2d,
        'create_model_input_fn':
        functools.partial(
            input_generator.create_model_input,
            model_input_keypoint_mask_type=(
                FLAGS.model_input_keypoint_mask_type),
            uniform_keypoint_jittering_max_offset_2d=(
                FLAGS.uniform_keypoint_jittering_max_offset_2d),
            gaussian_keypoint_jittering_offset_stddev_2d=(
                FLAGS.gaussian_keypoint_jittering_offset_stddev_2d),
            keypoint_dropout_probs=[
                float(x) for x in FLAGS.keypoint_dropout_probs
            ],
            set_on_mask_for_non_anchors=FLAGS.set_on_mask_for_non_anchors,
            mix_mask_sub_batches=FLAGS.mix_mask_sub_batches,
            forced_mask_on_part_names=FLAGS.forced_mask_on_part_names,
            forced_mask_off_part_names=FLAGS.forced_mask_off_part_names,
            **create_model_input_fn_kwargs),
        'embedder_fn':
        models_module.get_embedder(
            base_model_type=FLAGS.base_model_type,
            embedding_type=FLAGS.embedding_type,
            num_embedding_components=FLAGS.num_embedding_components,
            embedding_size=FLAGS.embedding_size,
            num_embedding_samples=FLAGS.num_embedding_samples,
            is_training=True,
            num_fc_blocks=FLAGS.num_fc_blocks,
            num_fcs_per_block=FLAGS.num_fcs_per_block,
            num_hidden_nodes=FLAGS.num_hidden_nodes,
            num_bottleneck_nodes=FLAGS.num_bottleneck_nodes,
            weight_max_norm=FLAGS.weight_max_norm,
            dropout_rate=FLAGS.dropout_rate,
            **embedder_fn_kwargs),
        'triplet_embedding_keys':
        pipeline_utils.get_embedding_keys(FLAGS.triplet_distance_type,
                                          common_module=common_module),
        'triplet_mining_embedding_keys':
        pipeline_utils.get_embedding_keys(FLAGS.triplet_distance_type,
                                          replace_samples_with_means=True,
                                          common_module=common_module),
        'positive_pairwise_embedding_keys':
        pipeline_utils.get_embedding_keys(
            FLAGS.positive_pairwise_distance_type,
            common_module=common_module),
        'summarize_matching_sigmoid_vars':
        FLAGS.triplet_distance_kernel in sigmoid_distance_kernels,
        'random_projection_azimuth_range':
        _scale_by_pi_over_180(FLAGS.random_projection_azimuth_range),
        'random_projection_elevation_range':
        _scale_by_pi_over_180(FLAGS.random_projection_elevation_range),
        'random_projection_roll_range':
        _scale_by_pi_over_180(FLAGS.random_projection_roll_range),
        'random_projection_camera_depth_range':
        [float(x) for x in FLAGS.random_projection_camera_depth_range],
    }

    embedding_sample_distance_fn_kwargs = {
        'EXPECTED_LIKELIHOOD_min_stddev': 0.1,
        'EXPECTED_LIKELIHOOD_max_squared_mahalanobis_distance': 100.0,
    }
    if (FLAGS.triplet_distance_kernel in sigmoid_distance_kernels
            or FLAGS.positive_pairwise_distance_kernel in
            sigmoid_distance_kernels):
        # We only need sigmoid parameters when a related distance kernel is used.
        sigmoid_raw_a, sigmoid_a, sigmoid_b = pipeline_utils.get_sigmoid_parameters(
            name='MatchingSigmoid',
            raw_a_initial_value=FLAGS.sigmoid_raw_a_initial,
            b_initial_value=FLAGS.sigmoid_b_initial,
            a_range=(None, FLAGS.sigmoid_a_max))
        # Both sigmoid kernels share the same `a` / `b` parameters.
        embedding_sample_distance_fn_kwargs.update({
            'L2_SIGMOID_MATCHING_PROB_a':
            sigmoid_a,
            'L2_SIGMOID_MATCHING_PROB_b':
            sigmoid_b,
            'SQUARED_L2_SIGMOID_MATCHING_PROB_a':
            sigmoid_a,
            'SQUARED_L2_SIGMOID_MATCHING_PROB_b':
            sigmoid_b,
        })
        configs.update({
            'sigmoid_raw_a': sigmoid_raw_a,
            'sigmoid_a': sigmoid_a,
            'sigmoid_b': sigmoid_b,
        })

    configs.update({
        'triplet_embedding_sample_distance_fn':
        loss_utils.create_sample_distance_fn(
            pair_type=common_module.DISTANCE_PAIR_TYPE_ALL_PAIRS,
            distance_kernel=FLAGS.triplet_distance_kernel,
            pairwise_reduction=FLAGS.triplet_pairwise_reduction,
            componentwise_reduction=FLAGS.triplet_componentwise_reduction,
            **embedding_sample_distance_fn_kwargs),
        'positive_pairwise_embedding_sample_distance_fn':
        loss_utils.create_sample_distance_fn(
            pair_type=common_module.DISTANCE_PAIR_TYPE_ALL_PAIRS,
            distance_kernel=FLAGS.positive_pairwise_distance_kernel,
            pairwise_reduction=FLAGS.positive_pairwise_pairwise_reduction,
            componentwise_reduction=(
                FLAGS.positive_pairwise_componentwise_reduction),
            **embedding_sample_distance_fn_kwargs),
    })

    if FLAGS.keypoint_distance_type == common_module.KEYPOINT_DISTANCE_TYPE_MPJPE:
        configs.update({
            'keypoint_distance_fn':
            keypoint_utils.compute_procrustes_aligned_mpjpes,
            'min_negative_keypoint_distance':
            FLAGS.min_negative_keypoint_mpjpe
        })
    # We use the following assignments to get around pytype check failures.
    # TODO(liuti): Figure out a better workaround.
    if 'keypoint_distance_fn' in keypoint_distance_config_override:
        configs['keypoint_distance_fn'] = (
            keypoint_distance_config_override['keypoint_distance_fn'])
    if 'min_negative_keypoint_distance' in keypoint_distance_config_override:
        configs['min_negative_keypoint_distance'] = (
            keypoint_distance_config_override['min_negative_keypoint_distance']
        )
    # Neither the flags nor the override supplied a complete keypoint distance
    # configuration.
    if ('keypoint_distance_fn' not in configs
            or 'min_negative_keypoint_distance' not in configs):
        raise ValueError('Invalid keypoint distance config: %s.' %
                         str(configs))

    if FLAGS.task == 0 and not FLAGS.profile_only:
        # Save all key flags.
        pipeline_utils.create_dir_and_save_flags(flags, FLAGS.train_log_dir,
                                                 'all_flags.train.json')

    return configs