def _validate_and_setup(common_module, keypoint_profiles_module, models_module,
                        keypoint_distance_config_override, embedder_fn_kwargs):
  """Validates flags and sets up training configurations.

  Besides building the returned configuration dictionary, this function has
  two side effects: it fills in unspecified (`None`) flags from their triplet
  counterparts, and, when `FLAGS.task == 0` and `FLAGS.profile_only` is not
  set, it saves the flags through `pipeline_utils.create_dir_and_save_flags`.

  Args:
    common_module: Module (or module-like object) providing `validate`, the
      `SUPPORTED_*` collections and the constants referenced below.
    keypoint_profiles_module: Module (or module-like object) providing
      `create_keypoint_profile_or_die`.
    models_module: Module (or module-like object) providing `get_embedder`.
    keypoint_distance_config_override: Mapping whose `keypoint_distance_fn`
      and `min_negative_keypoint_distance` entries, if present, override the
      values derived from flags. `None` is treated as an empty mapping.
    embedder_fn_kwargs: Mapping of additional keyword arguments forwarded to
      `models_module.get_embedder`. `None` is treated as an empty mapping.

  Returns:
    A dictionary of training configurations.

  Raises:
    ValueError: If the flags are mutually inconsistent, or if no keypoint
      distance configuration could be determined. `common_module.validate`
      may raise its own errors for unsupported flag values.
  """
  # Treat missing optional mappings as empty so that membership tests and
  # keyword expansion below do not fail with an opaque `TypeError`.
  if keypoint_distance_config_override is None:
    keypoint_distance_config_override = {}
  if embedder_fn_kwargs is None:
    embedder_fn_kwargs = {}

  # Set default values for unspecified flags.
  if FLAGS.use_normalized_embeddings_for_triplet_mining is None:
    FLAGS.use_normalized_embeddings_for_triplet_mining = (
        FLAGS.use_normalized_embeddings_for_triplet_loss)
  if FLAGS.use_normalized_embeddings_for_positive_pairwise_loss is None:
    FLAGS.use_normalized_embeddings_for_positive_pairwise_loss = (
        FLAGS.use_normalized_embeddings_for_triplet_loss)
  if FLAGS.positive_pairwise_distance_type is None:
    FLAGS.positive_pairwise_distance_type = FLAGS.triplet_distance_type
  if FLAGS.positive_pairwise_distance_kernel is None:
    FLAGS.positive_pairwise_distance_kernel = FLAGS.triplet_distance_kernel
  if FLAGS.positive_pairwise_pairwise_reduction is None:
    FLAGS.positive_pairwise_pairwise_reduction = (
        FLAGS.triplet_pairwise_reduction)
  if FLAGS.positive_pairwise_componentwise_reduction is None:
    FLAGS.positive_pairwise_componentwise_reduction = (
        FLAGS.triplet_componentwise_reduction)

  # Validate flags.
  validate_flag = common_module.validate
  validate_flag(FLAGS.model_input_keypoint_type,
                common_module.SUPPORTED_TRAINING_MODEL_INPUT_KEYPOINT_TYPES)
  validate_flag(FLAGS.embedding_type, common_module.SUPPORTED_EMBEDDING_TYPES)
  validate_flag(FLAGS.base_model_type,
                common_module.SUPPORTED_BASE_MODEL_TYPES)
  validate_flag(FLAGS.keypoint_distance_type,
                common_module.SUPPORTED_KEYPOINT_DISTANCE_TYPES)
  validate_flag(FLAGS.triplet_distance_type,
                common_module.SUPPORTED_DISTANCE_TYPES)
  validate_flag(FLAGS.triplet_distance_kernel,
                common_module.SUPPORTED_DISTANCE_KERNELS)
  validate_flag(FLAGS.triplet_pairwise_reduction,
                common_module.SUPPORTED_PAIRWISE_DISTANCE_REDUCTIONS)
  validate_flag(FLAGS.triplet_componentwise_reduction,
                common_module.SUPPORTED_COMPONENTWISE_DISTANCE_REDUCTIONS)
  validate_flag(FLAGS.positive_pairwise_distance_type,
                common_module.SUPPORTED_DISTANCE_TYPES)
  validate_flag(FLAGS.positive_pairwise_distance_kernel,
                common_module.SUPPORTED_DISTANCE_KERNELS)
  validate_flag(FLAGS.positive_pairwise_pairwise_reduction,
                common_module.SUPPORTED_PAIRWISE_DISTANCE_REDUCTIONS)
  validate_flag(FLAGS.positive_pairwise_componentwise_reduction,
                common_module.SUPPORTED_COMPONENTWISE_DISTANCE_REDUCTIONS)

  # Distance types that are rejected for point embeddings and that require a
  # positive `num_embedding_samples`.
  sample_based_distance_types = (
      common_module.DISTANCE_TYPE_SAMPLE,
      common_module.DISTANCE_TYPE_CENTER_AND_SAMPLE,
  )

  if FLAGS.embedding_type == common_module.EMBEDDING_TYPE_POINT:
    if FLAGS.triplet_distance_type in sample_based_distance_types:
      raise ValueError(
          'No support for triplet distance type `%s` for embedding type `%s`.'
          % (FLAGS.triplet_distance_type, FLAGS.embedding_type))
    if FLAGS.kl_regularization_loss_weight > 0.0:
      raise ValueError(
          'No support for KL regularization loss for embedding type `%s`.' %
          FLAGS.embedding_type)

  # Collect the distance types that actually need embedding samples so that
  # the error names the offending one rather than always the triplet type.
  sample_based_types_in_use = [
      distance_type for distance_type in (
          FLAGS.triplet_distance_type,
          FLAGS.positive_pairwise_distance_type,
      ) if distance_type in sample_based_distance_types
  ]
  if sample_based_types_in_use and FLAGS.num_embedding_samples <= 0:
    raise ValueError(
        'Must specify positive `num_embedding_samples` to use `%s` '
        'triplet/positive pairwise distance type.' %
        sample_based_types_in_use[0])

  # These distance kernels and pairwise reductions are only accepted
  # together: using one side without the other is rejected.
  paired_kernels = (
      common_module.DISTANCE_KERNEL_L2_SIGMOID_MATCHING_PROB,
      common_module.DISTANCE_KERNEL_EXPECTED_LIKELIHOOD,
  )
  paired_reductions = (
      common_module.DISTANCE_REDUCTION_NEG_LOG_MEAN,
      common_module.DISTANCE_REDUCTION_LOWER_HALF_NEG_LOG_MEAN,
      common_module.DISTANCE_REDUCTION_ONE_MINUS_MEAN,
  )
  for loss_name, kernel, reduction in (
      ('triplet', FLAGS.triplet_distance_kernel,
       FLAGS.triplet_pairwise_reduction),
      ('positive pairwise', FLAGS.positive_pairwise_distance_kernel,
       FLAGS.positive_pairwise_pairwise_reduction),
  ):
    if (kernel in paired_kernels) != (reduction in paired_reductions):
      raise ValueError(
          'Distance kernels %s and pairwise reducers %s must be used in '
          'pairs, but got %s distance kernel `%s` with pairwise reducer '
          '`%s`.' % (list(paired_kernels), list(paired_reductions),
                     loss_name, kernel, reduction))

  def _create_all_pairs_distance_fn(distance_kernel, pairwise_reduction,
                                    componentwise_reduction):
    """Creates an all-pairs sample distance function with shared settings."""
    return loss_utils.create_sample_distance_fn(
        pair_type=common_module.DISTANCE_PAIR_TYPE_ALL_PAIRS,
        distance_kernel=distance_kernel,
        pairwise_reduction=pairwise_reduction,
        componentwise_reduction=componentwise_reduction,
        # We initialize the sigmoid parameters to avoid model being stuck
        # in a `dead zone` at the beginning of training.
        L2_SIGMOID_MATCHING_PROB_a_initializer=(
            tf.initializers.constant(-0.65)),
        L2_SIGMOID_MATCHING_PROB_b_initializer=(
            tf.initializers.constant(-0.5)),
        EXPECTED_LIKELIHOOD_min_stddev=0.1,
        EXPECTED_LIKELIHOOD_max_squared_mahalanobis_distance=100.0)

  def _degrees_to_radians(degrees):
    """Converts an iterable of angles in degrees to a list of radians."""
    return [float(x) / 180.0 * math.pi for x in degrees]

  keypoint_profile_2d = keypoint_profiles_module.create_keypoint_profile_or_die(
      FLAGS.input_keypoint_profile_name_2d)

  # Set up configurations.
  configs = {
      'keypoint_profile_3d':
          keypoint_profiles_module.create_keypoint_profile_or_die(
              FLAGS.input_keypoint_profile_name_3d),
      'keypoint_profile_2d':
          keypoint_profile_2d,
      'embedder_fn':
          models_module.get_embedder(
              base_model_type=FLAGS.base_model_type,
              embedding_type=FLAGS.embedding_type,
              num_embedding_components=FLAGS.num_embedding_components,
              embedding_size=FLAGS.embedding_size,
              num_embedding_samples=FLAGS.num_embedding_samples,
              is_training=True,
              num_fc_blocks=FLAGS.num_fc_blocks,
              num_fcs_per_block=FLAGS.num_fcs_per_block,
              num_hidden_nodes=FLAGS.num_hidden_nodes,
              num_bottleneck_nodes=FLAGS.num_bottleneck_nodes,
              weight_max_norm=FLAGS.weight_max_norm,
              dropout_rate=FLAGS.dropout_rate,
              **embedder_fn_kwargs),
      'triplet_embedding_keys':
          pipeline_utils.get_embedding_keys(
              FLAGS.triplet_distance_type, common_module=common_module),
      'triplet_mining_embedding_keys':
          pipeline_utils.get_embedding_keys(
              FLAGS.triplet_distance_type,
              replace_samples_with_means=True,
              common_module=common_module),
      'triplet_embedding_sample_distance_fn':
          _create_all_pairs_distance_fn(
              distance_kernel=FLAGS.triplet_distance_kernel,
              pairwise_reduction=FLAGS.triplet_pairwise_reduction,
              componentwise_reduction=FLAGS.triplet_componentwise_reduction),
      'positive_pairwise_embedding_keys':
          pipeline_utils.get_embedding_keys(
              FLAGS.positive_pairwise_distance_type,
              common_module=common_module),
      'positive_pairwise_embedding_sample_distance_fn':
          _create_all_pairs_distance_fn(
              distance_kernel=FLAGS.positive_pairwise_distance_kernel,
              pairwise_reduction=FLAGS.positive_pairwise_pairwise_reduction,
              componentwise_reduction=(
                  FLAGS.positive_pairwise_componentwise_reduction)),
      'summarize_matching_sigmoid_vars':
          FLAGS.triplet_distance_kernel in
          (common_module.DISTANCE_KERNEL_L2_SIGMOID_MATCHING_PROB,),
      'random_projection_azimuth_range':
          _degrees_to_radians(FLAGS.random_projection_azimuth_range),
      'random_projection_elevation_range':
          _degrees_to_radians(FLAGS.random_projection_elevation_range),
      'random_projection_roll_range':
          _degrees_to_radians(FLAGS.random_projection_roll_range),
  }

  if FLAGS.keypoint_distance_type == common_module.KEYPOINT_DISTANCE_TYPE_MPJPE:
    configs.update({
        'keypoint_distance_fn':
            keypoint_utils.compute_procrustes_aligned_mpjpes,
        'min_negative_keypoint_distance':
            FLAGS.min_negative_keypoint_mpjpe
    })
  # We use the following assignments to get around pytype check failures.
  # TODO(liuti): Figure out a better workaround.
  if 'keypoint_distance_fn' in keypoint_distance_config_override:
    configs['keypoint_distance_fn'] = (
        keypoint_distance_config_override['keypoint_distance_fn'])
  if 'min_negative_keypoint_distance' in keypoint_distance_config_override:
    configs['min_negative_keypoint_distance'] = (
        keypoint_distance_config_override['min_negative_keypoint_distance'])
  # Neither the flags nor the override supplied a complete keypoint distance
  # configuration.
  if ('keypoint_distance_fn' not in configs or
      'min_negative_keypoint_distance' not in configs):
    raise ValueError('Invalid keypoint distance config: %s.' % str(configs))

  if FLAGS.task == 0 and not FLAGS.profile_only:
    # Save all key flags.
    pipeline_utils.create_dir_and_save_flags(flags, FLAGS.train_log_dir,
                                             'all_flags.train.json')

  return configs
def _validate_and_setup(common_module, keypoint_profiles_module, models_module,
                        keypoint_distance_config_override,
                        create_model_input_fn_kwargs, embedder_fn_kwargs):
  """Validates flags and sets up training configurations.

  Besides building the returned configuration dictionary, this function has
  two side effects: it fills in unspecified (`None`) flags from their triplet
  counterparts, and, when `FLAGS.task == 0` and `FLAGS.profile_only` is not
  set, it saves the flags through `pipeline_utils.create_dir_and_save_flags`.
  When a sigmoid matching probability kernel is selected, it also obtains the
  sigmoid parameters from `pipeline_utils.get_sigmoid_parameters`.

  Args:
    common_module: Module (or module-like object) providing `validate`, the
      `SUPPORTED_*` collections and the constants referenced below.
    keypoint_profiles_module: Module (or module-like object) providing
      `create_keypoint_profile_or_die`.
    models_module: Module (or module-like object) providing `get_embedder`.
    keypoint_distance_config_override: Mapping whose `keypoint_distance_fn`
      and `min_negative_keypoint_distance` entries, if present, override the
      values derived from flags. `None` is treated as an empty mapping.
    create_model_input_fn_kwargs: Mapping of additional keyword arguments
      bound into the `create_model_input_fn` partial. `None` is treated as an
      empty mapping.
    embedder_fn_kwargs: Mapping of additional keyword arguments forwarded to
      `models_module.get_embedder`. `None` is treated as an empty mapping.

  Returns:
    A dictionary of training configurations.

  Raises:
    ValueError: If the flags are mutually inconsistent, or if no keypoint
      distance configuration could be determined. `common_module.validate`
      may raise its own errors for unsupported flag values.
  """
  # Treat missing optional mappings as empty so that membership tests and
  # keyword expansion below do not fail with an opaque `TypeError`.
  if keypoint_distance_config_override is None:
    keypoint_distance_config_override = {}
  if create_model_input_fn_kwargs is None:
    create_model_input_fn_kwargs = {}
  if embedder_fn_kwargs is None:
    embedder_fn_kwargs = {}

  # Set default values for unspecified flags.
  if FLAGS.use_normalized_embeddings_for_triplet_mining is None:
    FLAGS.use_normalized_embeddings_for_triplet_mining = (
        FLAGS.use_normalized_embeddings_for_triplet_loss)
  if FLAGS.use_normalized_embeddings_for_positive_pairwise_loss is None:
    FLAGS.use_normalized_embeddings_for_positive_pairwise_loss = (
        FLAGS.use_normalized_embeddings_for_triplet_loss)
  if FLAGS.positive_pairwise_distance_type is None:
    FLAGS.positive_pairwise_distance_type = FLAGS.triplet_distance_type
  if FLAGS.positive_pairwise_distance_kernel is None:
    FLAGS.positive_pairwise_distance_kernel = FLAGS.triplet_distance_kernel
  if FLAGS.positive_pairwise_pairwise_reduction is None:
    FLAGS.positive_pairwise_pairwise_reduction = (
        FLAGS.triplet_pairwise_reduction)
  if FLAGS.positive_pairwise_componentwise_reduction is None:
    FLAGS.positive_pairwise_componentwise_reduction = (
        FLAGS.triplet_componentwise_reduction)

  # Validate flags.
  validate_flag = common_module.validate
  validate_flag(FLAGS.model_input_keypoint_type,
                common_module.SUPPORTED_TRAINING_MODEL_INPUT_KEYPOINT_TYPES)
  validate_flag(FLAGS.model_input_keypoint_mask_type,
                common_module.SUPPORTED_MODEL_INPUT_KEYPOINT_MASK_TYPES)
  validate_flag(FLAGS.embedding_type, common_module.SUPPORTED_EMBEDDING_TYPES)
  validate_flag(FLAGS.base_model_type,
                common_module.SUPPORTED_BASE_MODEL_TYPES)
  validate_flag(FLAGS.keypoint_distance_type,
                common_module.SUPPORTED_KEYPOINT_DISTANCE_TYPES)
  validate_flag(FLAGS.triplet_distance_type,
                common_module.SUPPORTED_DISTANCE_TYPES)
  validate_flag(FLAGS.triplet_distance_kernel,
                common_module.SUPPORTED_DISTANCE_KERNELS)
  validate_flag(FLAGS.triplet_pairwise_reduction,
                common_module.SUPPORTED_PAIRWISE_DISTANCE_REDUCTIONS)
  validate_flag(FLAGS.triplet_componentwise_reduction,
                common_module.SUPPORTED_COMPONENTWISE_DISTANCE_REDUCTIONS)
  validate_flag(FLAGS.positive_pairwise_distance_type,
                common_module.SUPPORTED_DISTANCE_TYPES)
  validate_flag(FLAGS.positive_pairwise_distance_kernel,
                common_module.SUPPORTED_DISTANCE_KERNELS)
  validate_flag(FLAGS.positive_pairwise_pairwise_reduction,
                common_module.SUPPORTED_PAIRWISE_DISTANCE_REDUCTIONS)
  validate_flag(FLAGS.positive_pairwise_componentwise_reduction,
                common_module.SUPPORTED_COMPONENTWISE_DISTANCE_REDUCTIONS)

  if FLAGS.embedding_type == common_module.EMBEDDING_TYPE_POINT:
    if FLAGS.triplet_distance_type != common_module.DISTANCE_TYPE_CENTER:
      raise ValueError(
          'No support for triplet distance type `%s` for embedding type `%s`.'
          % (FLAGS.triplet_distance_type, FLAGS.embedding_type))
    if FLAGS.kl_regularization_loss_weight > 0.0:
      raise ValueError(
          'No support for KL regularization loss for embedding type `%s`.' %
          FLAGS.embedding_type)

  # Collect the distance types that actually need embedding samples so that
  # the error names the offending one rather than always the triplet type.
  sample_based_distance_types = (common_module.DISTANCE_TYPE_SAMPLE,)
  sample_based_types_in_use = [
      distance_type for distance_type in (
          FLAGS.triplet_distance_type,
          FLAGS.positive_pairwise_distance_type,
      ) if distance_type in sample_based_distance_types
  ]
  if sample_based_types_in_use and FLAGS.num_embedding_samples <= 0:
    raise ValueError(
        'Must specify positive `num_embedding_samples` to use `%s` '
        'triplet/positive pairwise distance type.' %
        sample_based_types_in_use[0])

  # Kernels whose distance functions take the shared sigmoid parameters.
  sigmoid_kernels = (
      common_module.DISTANCE_KERNEL_L2_SIGMOID_MATCHING_PROB,
      common_module.DISTANCE_KERNEL_SQUARED_L2_SIGMOID_MATCHING_PROB,
  )
  # These distance kernels and pairwise reductions are only accepted
  # together: using one side without the other is rejected.
  paired_kernels = sigmoid_kernels + (
      common_module.DISTANCE_KERNEL_EXPECTED_LIKELIHOOD,)
  paired_reductions = (
      common_module.DISTANCE_REDUCTION_NEG_LOG_MEAN,
      common_module.DISTANCE_REDUCTION_LOWER_HALF_NEG_LOG_MEAN,
      common_module.DISTANCE_REDUCTION_ONE_MINUS_MEAN,
  )
  for loss_name, kernel, reduction in (
      ('triplet', FLAGS.triplet_distance_kernel,
       FLAGS.triplet_pairwise_reduction),
      ('positive pairwise', FLAGS.positive_pairwise_distance_kernel,
       FLAGS.positive_pairwise_pairwise_reduction),
  ):
    if (kernel in paired_kernels) != (reduction in paired_reductions):
      raise ValueError(
          'Distance kernels %s and pairwise reducers %s must be used in '
          'pairs, but got %s distance kernel `%s` with pairwise reducer '
          '`%s`.' % (list(paired_kernels), list(paired_reductions),
                     loss_name, kernel, reduction))

  def _degrees_to_radians(degrees):
    """Converts an iterable of angles in degrees to a list of radians."""
    return [float(x) / 180.0 * math.pi for x in degrees]

  keypoint_profile_2d = keypoint_profiles_module.create_keypoint_profile_or_die(
      FLAGS.input_keypoint_profile_name_2d)

  # Set up configurations.
  configs = {
      'keypoint_profile_3d':
          keypoint_profiles_module.create_keypoint_profile_or_die(
              FLAGS.input_keypoint_profile_name_3d),
      'keypoint_profile_2d':
          keypoint_profile_2d,
      'create_model_input_fn':
          functools.partial(
              input_generator.create_model_input,
              model_input_keypoint_mask_type=(
                  FLAGS.model_input_keypoint_mask_type),
              uniform_keypoint_jittering_max_offset_2d=(
                  FLAGS.uniform_keypoint_jittering_max_offset_2d),
              gaussian_keypoint_jittering_offset_stddev_2d=(
                  FLAGS.gaussian_keypoint_jittering_offset_stddev_2d),
              keypoint_dropout_probs=[
                  float(x) for x in FLAGS.keypoint_dropout_probs
              ],
              set_on_mask_for_non_anchors=FLAGS.set_on_mask_for_non_anchors,
              mix_mask_sub_batches=FLAGS.mix_mask_sub_batches,
              forced_mask_on_part_names=FLAGS.forced_mask_on_part_names,
              forced_mask_off_part_names=FLAGS.forced_mask_off_part_names,
              **create_model_input_fn_kwargs),
      'embedder_fn':
          models_module.get_embedder(
              base_model_type=FLAGS.base_model_type,
              embedding_type=FLAGS.embedding_type,
              num_embedding_components=FLAGS.num_embedding_components,
              embedding_size=FLAGS.embedding_size,
              num_embedding_samples=FLAGS.num_embedding_samples,
              is_training=True,
              num_fc_blocks=FLAGS.num_fc_blocks,
              num_fcs_per_block=FLAGS.num_fcs_per_block,
              num_hidden_nodes=FLAGS.num_hidden_nodes,
              num_bottleneck_nodes=FLAGS.num_bottleneck_nodes,
              weight_max_norm=FLAGS.weight_max_norm,
              dropout_rate=FLAGS.dropout_rate,
              **embedder_fn_kwargs),
      'triplet_embedding_keys':
          pipeline_utils.get_embedding_keys(
              FLAGS.triplet_distance_type, common_module=common_module),
      'triplet_mining_embedding_keys':
          pipeline_utils.get_embedding_keys(
              FLAGS.triplet_distance_type,
              replace_samples_with_means=True,
              common_module=common_module),
      'positive_pairwise_embedding_keys':
          pipeline_utils.get_embedding_keys(
              FLAGS.positive_pairwise_distance_type,
              common_module=common_module),
      'summarize_matching_sigmoid_vars':
          FLAGS.triplet_distance_kernel in sigmoid_kernels,
      'random_projection_azimuth_range':
          _degrees_to_radians(FLAGS.random_projection_azimuth_range),
      'random_projection_elevation_range':
          _degrees_to_radians(FLAGS.random_projection_elevation_range),
      'random_projection_roll_range':
          _degrees_to_radians(FLAGS.random_projection_roll_range),
      'random_projection_camera_depth_range': [
          float(x) for x in FLAGS.random_projection_camera_depth_range
      ],
  }

  # Keyword arguments shared by the triplet and positive pairwise sample
  # distance functions.
  embedding_sample_distance_fn_kwargs = {
      'EXPECTED_LIKELIHOOD_min_stddev': 0.1,
      'EXPECTED_LIKELIHOOD_max_squared_mahalanobis_distance': 100.0,
  }
  if (FLAGS.triplet_distance_kernel in sigmoid_kernels or
      FLAGS.positive_pairwise_distance_kernel in sigmoid_kernels):
    # We only need sigmoid parameters when a related distance kernel is used.
    sigmoid_raw_a, sigmoid_a, sigmoid_b = pipeline_utils.get_sigmoid_parameters(
        name='MatchingSigmoid',
        raw_a_initial_value=FLAGS.sigmoid_raw_a_initial,
        b_initial_value=FLAGS.sigmoid_b_initial,
        a_range=(None, FLAGS.sigmoid_a_max))
    # Both sigmoid kernels receive the same parameters.
    embedding_sample_distance_fn_kwargs.update({
        'L2_SIGMOID_MATCHING_PROB_a': sigmoid_a,
        'L2_SIGMOID_MATCHING_PROB_b': sigmoid_b,
        'SQUARED_L2_SIGMOID_MATCHING_PROB_a': sigmoid_a,
        'SQUARED_L2_SIGMOID_MATCHING_PROB_b': sigmoid_b,
    })
    configs.update({
        'sigmoid_raw_a': sigmoid_raw_a,
        'sigmoid_a': sigmoid_a,
        'sigmoid_b': sigmoid_b,
    })

  configs.update({
      'triplet_embedding_sample_distance_fn':
          loss_utils.create_sample_distance_fn(
              pair_type=common_module.DISTANCE_PAIR_TYPE_ALL_PAIRS,
              distance_kernel=FLAGS.triplet_distance_kernel,
              pairwise_reduction=FLAGS.triplet_pairwise_reduction,
              componentwise_reduction=FLAGS.triplet_componentwise_reduction,
              **embedding_sample_distance_fn_kwargs),
      'positive_pairwise_embedding_sample_distance_fn':
          loss_utils.create_sample_distance_fn(
              pair_type=common_module.DISTANCE_PAIR_TYPE_ALL_PAIRS,
              distance_kernel=FLAGS.positive_pairwise_distance_kernel,
              pairwise_reduction=FLAGS.positive_pairwise_pairwise_reduction,
              componentwise_reduction=(
                  FLAGS.positive_pairwise_componentwise_reduction),
              **embedding_sample_distance_fn_kwargs),
  })

  if FLAGS.keypoint_distance_type == common_module.KEYPOINT_DISTANCE_TYPE_MPJPE:
    configs.update({
        'keypoint_distance_fn':
            keypoint_utils.compute_procrustes_aligned_mpjpes,
        'min_negative_keypoint_distance':
            FLAGS.min_negative_keypoint_mpjpe
    })
  # We use the following assignments to get around pytype check failures.
  # TODO(liuti): Figure out a better workaround.
  if 'keypoint_distance_fn' in keypoint_distance_config_override:
    configs['keypoint_distance_fn'] = (
        keypoint_distance_config_override['keypoint_distance_fn'])
  if 'min_negative_keypoint_distance' in keypoint_distance_config_override:
    configs['min_negative_keypoint_distance'] = (
        keypoint_distance_config_override['min_negative_keypoint_distance'])
  # Neither the flags nor the override supplied a complete keypoint distance
  # configuration.
  if ('keypoint_distance_fn' not in configs or
      'min_negative_keypoint_distance' not in configs):
    raise ValueError('Invalid keypoint distance config: %s.' % str(configs))

  if FLAGS.task == 0 and not FLAGS.profile_only:
    # Save all key flags.
    pipeline_utils.create_dir_and_save_flags(flags, FLAGS.train_log_dir,
                                             'all_flags.train.json')

  return configs