def run(input_dataset_class, common_module, keypoint_profiles_module,
        input_example_parser_creator, keypoint_preprocessor_3d):
  """Runs training pipeline.

  Args:
    input_dataset_class: An input dataset class that matches input table type.
    common_module: A Python module that defines common flags and constants.
    keypoint_profiles_module: A Python module that defines keypoint profiles.
    input_example_parser_creator: A function handle for creating data parser
      function. If None, uses the default parser creator.
    keypoint_preprocessor_3d: A function handle for preprocessing raw 3D
      keypoints.
  """
  _validate(common_module)

  log_dir_path = FLAGS.log_dir_path
  pipeline_utils.create_dir_and_save_flags(flags, log_dir_path,
                                           'all_flags.train.json')

  # Setup summary writer.
  summary_writer = tf.summary.create_file_writer(
      os.path.join(log_dir_path, 'train_logs'), flush_millis=10000)

  # Setup configuration.
  keypoint_profile_2d = keypoint_profiles_module.create_keypoint_profile_or_die(
      FLAGS.keypoint_profile_name_2d)
  keypoint_profile_3d = keypoint_profiles_module.create_keypoint_profile_or_die(
      FLAGS.keypoint_profile_name_3d)

  model = algorithms.get_algorithm(
      algorithm_type=FLAGS.algorithm_type,
      pose_embedding_dim=FLAGS.pose_embedding_dim,
      view_embedding_dim=FLAGS.view_embedding_dim,
      fusion_op_type=FLAGS.fusion_op_type,
      view_loss_weight=FLAGS.view_loss_weight,
      regularization_loss_weight=FLAGS.regularization_loss_weight,
      disentangle_loss_weight=FLAGS.disentangle_loss_weight,
      embedder_type=FLAGS.embedder_type)
  optimizers = algorithms.get_optimizers(
      algorithm_type=FLAGS.algorithm_type, learning_rate=FLAGS.learning_rate)
  # The encoder optimizer's iteration counter is used as the global step.
  global_step = optimizers['encoder_optimizer'].iterations
  ckpt_manager, _, _ = utils.create_checkpoint(
      log_dir_path, **optimizers, model=model, global_step=global_step)

  # Setup the training dataset.
  dataset = pipelines.create_dataset_from_tables(
      [FLAGS.input_table], [FLAGS.batch_size],
      num_instances_per_record=2,
      shuffle=True,
      num_epochs=None,
      drop_remainder=True,
      keypoint_names_2d=keypoint_profile_2d.keypoint_names,
      keypoint_names_3d=keypoint_profile_3d.keypoint_names,
      shuffle_buffer_size=FLAGS.shuffle_buffer_size,
      dataset_class=input_dataset_class,
      input_example_parser_creator=input_example_parser_creator)

  def train_one_iteration(inputs):
    """Trains the model for one iteration.

    Args:
      inputs: A dictionary for training inputs.

    Returns:
      The training losses for this iteration, as returned by `model.train`.
    """
    _, side_outputs = pipelines.create_model_input(
        inputs, FLAGS.model_input_keypoint_type, keypoint_profile_2d,
        keypoint_profile_3d)

    keypoints_2d = side_outputs[common_module.KEY_PREPROCESSED_KEYPOINTS_2D]
    keypoints_3d, _ = keypoint_preprocessor_3d(
        inputs[common_module.KEY_KEYPOINTS_3D],
        keypoint_profile_3d,
        normalize_keypoints_3d=True)
    keypoints_2d, keypoints_3d = data_utils.shuffle_batches(
        [keypoints_2d, keypoints_3d])

    return model.train((keypoints_2d, keypoints_3d), **optimizers)

  if FLAGS.compile:
    train_one_iteration = tf.function(train_one_iteration)

  # Clamp intervals to at least 1 so that `num_iterations == 0` cannot trigger
  # a modulo by zero below.
  record_every_n_steps = max(1, min(100, FLAGS.num_iterations))
  save_ckpt_every_n_steps = max(1, min(10000, FLAGS.num_iterations))

  with summary_writer.as_default():
    # A callable condition is re-evaluated against the live `global_step`
    # every time a summary is written (also when traced inside `tf.function`).
    # Passing a plain tensor would freeze the condition at its value from
    # before the first training step.
    with tf.summary.record_if(
        lambda: tf.equal(global_step % record_every_n_steps, 0)):
      start = time.time()
      for inputs in dataset:
        if global_step >= FLAGS.num_iterations:
          break

        model_losses = train_one_iteration(inputs)
        duration = time.time() - start
        start = time.time()

        for tag, losses in model_losses.items():
          for name, loss in losses.items():
            tf.summary.scalar(
                'train/{}/{}'.format(tag, name), loss, step=global_step)

        for tag, optimizer in optimizers.items():
          tf.summary.scalar(
              'train/{}_learning_rate'.format(tag),
              optimizer.lr,
              step=global_step)

        tf.summary.scalar('train/batch_time', duration, step=global_step)
        tf.summary.scalar('global_step/sec', 1 / duration, step=global_step)

        step = int(global_step.numpy())
        if step % record_every_n_steps == 0:
          logging.info('Iter[{}/{}], {:.6f}s/iter, loss: {:.4f}'.format(
              step, FLAGS.num_iterations, duration,
              model_losses['encoder']['total_loss'].numpy()))

        # Save checkpoint periodically, and always at the final step so the
        # last trained weights are kept even when `num_iterations` is not a
        # multiple of the save interval.
        if (step % save_ckpt_every_n_steps == 0 or
            step >= FLAGS.num_iterations):
          ckpt_manager.save(checkpoint_number=global_step)
          logging.info('Checkpoint saved at step %d.', step)
def _validate_and_setup(common_module, keypoint_profiles_module, models_module,
                        keypoint_distance_config_override, embedder_fn_kwargs):
  """Validates and sets up training configurations.

  Fills in unspecified positive-pairwise / triplet-mining flags from their
  triplet-loss counterparts, validates flag values and combinations, and
  builds the configuration dictionary used for training.

  Args:
    common_module: A Python module that defines common flags and constants.
    keypoint_profiles_module: A Python module that defines keypoint profiles.
    models_module: A Python module providing `get_embedder`.
    keypoint_distance_config_override: A dictionary that may contain
      `keypoint_distance_fn` and/or `min_negative_keypoint_distance`; present
      entries override the values derived from `FLAGS.keypoint_distance_type`.
    embedder_fn_kwargs: Extra keyword arguments forwarded to
      `models_module.get_embedder`.

  Returns:
    A dictionary of training configurations.

  Raises:
    ValueError: If a flag value or flag combination is unsupported, or if no
      keypoint distance function/threshold can be determined.
  """
  # Set default values for unspecified flags: each flag on the left that is
  # left as None inherits the value of the flag on the right.
  for target_flag, source_flag in (
      ('use_normalized_embeddings_for_triplet_mining',
       'use_normalized_embeddings_for_triplet_loss'),
      ('use_normalized_embeddings_for_positive_pairwise_loss',
       'use_normalized_embeddings_for_triplet_loss'),
      ('positive_pairwise_distance_type', 'triplet_distance_type'),
      ('positive_pairwise_distance_kernel', 'triplet_distance_kernel'),
      ('positive_pairwise_pairwise_reduction', 'triplet_pairwise_reduction'),
      ('positive_pairwise_componentwise_reduction',
       'triplet_componentwise_reduction'),
  ):
    if getattr(FLAGS, target_flag) is None:
      setattr(FLAGS, target_flag, getattr(FLAGS, source_flag))

  # Validate flags.
  validate_flag = common_module.validate
  validate_flag(FLAGS.model_input_keypoint_type,
                common_module.SUPPORTED_TRAINING_MODEL_INPUT_KEYPOINT_TYPES)
  validate_flag(FLAGS.embedding_type, common_module.SUPPORTED_EMBEDDING_TYPES)
  validate_flag(FLAGS.base_model_type, common_module.SUPPORTED_BASE_MODEL_TYPES)
  validate_flag(FLAGS.keypoint_distance_type,
                common_module.SUPPORTED_KEYPOINT_DISTANCE_TYPES)
  validate_flag(FLAGS.triplet_distance_type,
                common_module.SUPPORTED_DISTANCE_TYPES)
  validate_flag(FLAGS.triplet_distance_kernel,
                common_module.SUPPORTED_DISTANCE_KERNELS)
  validate_flag(FLAGS.triplet_pairwise_reduction,
                common_module.SUPPORTED_PAIRWISE_DISTANCE_REDUCTIONS)
  validate_flag(FLAGS.triplet_componentwise_reduction,
                common_module.SUPPORTED_COMPONENTWISE_DISTANCE_REDUCTIONS)
  validate_flag(FLAGS.positive_pairwise_distance_type,
                common_module.SUPPORTED_DISTANCE_TYPES)
  validate_flag(FLAGS.positive_pairwise_distance_kernel,
                common_module.SUPPORTED_DISTANCE_KERNELS)
  validate_flag(FLAGS.positive_pairwise_pairwise_reduction,
                common_module.SUPPORTED_PAIRWISE_DISTANCE_REDUCTIONS)
  validate_flag(FLAGS.positive_pairwise_componentwise_reduction,
                common_module.SUPPORTED_COMPONENTWISE_DISTANCE_REDUCTIONS)

  # Distance types that require sampled embeddings.
  sample_distance_types = (common_module.DISTANCE_TYPE_SAMPLE,
                           common_module.DISTANCE_TYPE_CENTER_AND_SAMPLE)
  # A kernel in `probability_kernels` must be used if and only if the pairwise
  # reduction is in `probability_reductions` (checked below).
  probability_kernels = (
      common_module.DISTANCE_KERNEL_L2_SIGMOID_MATCHING_PROB,
      common_module.DISTANCE_KERNEL_EXPECTED_LIKELIHOOD)
  probability_reductions = (
      common_module.DISTANCE_REDUCTION_NEG_LOG_MEAN,
      common_module.DISTANCE_REDUCTION_LOWER_HALF_NEG_LOG_MEAN,
      common_module.DISTANCE_REDUCTION_ONE_MINUS_MEAN)

  if FLAGS.embedding_type == common_module.EMBEDDING_TYPE_POINT:
    if FLAGS.triplet_distance_type in sample_distance_types:
      raise ValueError(
          'No support for triplet distance type `%s` for embedding type `%s`.' %
          (FLAGS.triplet_distance_type, FLAGS.embedding_type))
    if FLAGS.kl_regularization_loss_weight > 0.0:
      raise ValueError(
          'No support for KL regularization loss for embedding type `%s`.' %
          FLAGS.embedding_type)

  if ((FLAGS.triplet_distance_type in sample_distance_types or
       FLAGS.positive_pairwise_distance_type in sample_distance_types) and
      FLAGS.num_embedding_samples <= 0):
    # Report both types: either one may be the offending sample-based type.
    raise ValueError(
        'Must specify positive `num_embedding_samples` to use `%s` triplet '
        'distance type and `%s` positive pairwise distance type.' %
        (FLAGS.triplet_distance_type, FLAGS.positive_pairwise_distance_type))

  def _kernel_matches_reduction(kernel, reduction):
    """Returns whether `kernel` and `reduction` are a supported pairing."""
    return (kernel in probability_kernels) == (
        reduction in probability_reductions)

  if not (_kernel_matches_reduction(FLAGS.triplet_distance_kernel,
                                    FLAGS.triplet_pairwise_reduction) and
          _kernel_matches_reduction(FLAGS.positive_pairwise_distance_kernel,
                                    FLAGS.positive_pairwise_pairwise_reduction)):
    raise ValueError(
        'Must use `L2_SIGMOID_MATCHING_PROB` or `EXPECTED_LIKELIHOOD` distance '
        'kernel and `NEG_LOG_MEAN`, `LOWER_HALF_NEG_LOG_MEAN` or '
        '`ONE_MINUS_MEAN` pairwise reducer in pairs.')

  keypoint_profile_2d = keypoint_profiles_module.create_keypoint_profile_or_die(
      FLAGS.input_keypoint_profile_name_2d)

  def _degrees_to_radians(degrees):
    """Converts a sequence of angles in degrees to a list of radians."""
    return [float(x) / 180.0 * math.pi for x in degrees]

  def _sample_distance_fn_kwargs():
    """Returns fresh keyword arguments shared by both sample distance fns."""
    return dict(
        # We initialize the sigmoid parameters to avoid model being stuck in a
        # `dead zone` at the beginning of training.
        L2_SIGMOID_MATCHING_PROB_a_initializer=tf.initializers.constant(-0.65),
        L2_SIGMOID_MATCHING_PROB_b_initializer=tf.initializers.constant(-0.5),
        EXPECTED_LIKELIHOOD_min_stddev=0.1,
        EXPECTED_LIKELIHOOD_max_squared_mahalanobis_distance=100.0)

  # Set up configurations.
  configs = {
      'keypoint_profile_3d':
          keypoint_profiles_module.create_keypoint_profile_or_die(
              FLAGS.input_keypoint_profile_name_3d),
      'keypoint_profile_2d':
          keypoint_profile_2d,
      'embedder_fn':
          models_module.get_embedder(
              base_model_type=FLAGS.base_model_type,
              embedding_type=FLAGS.embedding_type,
              num_embedding_components=FLAGS.num_embedding_components,
              embedding_size=FLAGS.embedding_size,
              num_embedding_samples=FLAGS.num_embedding_samples,
              is_training=True,
              num_fc_blocks=FLAGS.num_fc_blocks,
              num_fcs_per_block=FLAGS.num_fcs_per_block,
              num_hidden_nodes=FLAGS.num_hidden_nodes,
              num_bottleneck_nodes=FLAGS.num_bottleneck_nodes,
              weight_max_norm=FLAGS.weight_max_norm,
              dropout_rate=FLAGS.dropout_rate,
              **embedder_fn_kwargs),
      'triplet_embedding_keys':
          pipeline_utils.get_embedding_keys(
              FLAGS.triplet_distance_type, common_module=common_module),
      'triplet_mining_embedding_keys':
          pipeline_utils.get_embedding_keys(
              FLAGS.triplet_distance_type,
              replace_samples_with_means=True,
              common_module=common_module),
      'triplet_embedding_sample_distance_fn':
          loss_utils.create_sample_distance_fn(
              pair_type=common_module.DISTANCE_PAIR_TYPE_ALL_PAIRS,
              distance_kernel=FLAGS.triplet_distance_kernel,
              pairwise_reduction=FLAGS.triplet_pairwise_reduction,
              componentwise_reduction=FLAGS.triplet_componentwise_reduction,
              **_sample_distance_fn_kwargs()),
      'positive_pairwise_embedding_keys':
          pipeline_utils.get_embedding_keys(
              FLAGS.positive_pairwise_distance_type,
              common_module=common_module),
      'positive_pairwise_embedding_sample_distance_fn':
          loss_utils.create_sample_distance_fn(
              pair_type=common_module.DISTANCE_PAIR_TYPE_ALL_PAIRS,
              distance_kernel=FLAGS.positive_pairwise_distance_kernel,
              pairwise_reduction=FLAGS.positive_pairwise_pairwise_reduction,
              componentwise_reduction=(
                  FLAGS.positive_pairwise_componentwise_reduction),
              **_sample_distance_fn_kwargs()),
      'summarize_matching_sigmoid_vars':
          FLAGS.triplet_distance_kernel in
          [common_module.DISTANCE_KERNEL_L2_SIGMOID_MATCHING_PROB],
      'random_projection_azimuth_range':
          _degrees_to_radians(FLAGS.random_projection_azimuth_range),
      'random_projection_elevation_range':
          _degrees_to_radians(FLAGS.random_projection_elevation_range),
      'random_projection_roll_range':
          _degrees_to_radians(FLAGS.random_projection_roll_range),
  }

  if FLAGS.keypoint_distance_type == common_module.KEYPOINT_DISTANCE_TYPE_MPJPE:
    configs.update({
        'keypoint_distance_fn':
            keypoint_utils.compute_procrustes_aligned_mpjpes,
        'min_negative_keypoint_distance':
            FLAGS.min_negative_keypoint_mpjpe
    })

  # We use the following assignments to get around pytype check failures.
  # TODO(liuti): Figure out a better workaround.
  if 'keypoint_distance_fn' in keypoint_distance_config_override:
    configs['keypoint_distance_fn'] = (
        keypoint_distance_config_override['keypoint_distance_fn'])
  if 'min_negative_keypoint_distance' in keypoint_distance_config_override:
    configs['min_negative_keypoint_distance'] = (
        keypoint_distance_config_override['min_negative_keypoint_distance'])
  if ('keypoint_distance_fn' not in configs or
      'min_negative_keypoint_distance' not in configs):
    raise ValueError('Invalid keypoint distance config: %s.' % str(configs))

  if FLAGS.task == 0 and not FLAGS.profile_only:
    # Save all key flags.
    pipeline_utils.create_dir_and_save_flags(flags, FLAGS.train_log_dir,
                                             'all_flags.train.json')

  return configs
def run(input_dataset_class, common_module, keypoint_profiles_module,
        input_example_parser_creator):
  """Runs training pipeline.

  Args:
    input_dataset_class: An input dataset class that matches input table type.
    common_module: A Python module that defines common flags and constants.
    keypoint_profiles_module: A Python module that defines keypoint profiles.
    input_example_parser_creator: A function handle for creating data parser
      function. If None, uses the default parser creator.
  """
  log_dir_path = FLAGS.log_dir_path
  pipeline_utils.create_dir_and_save_flags(
      flags, log_dir_path, 'all_flags.train_with_features.json')

  # Setup summary writer.
  summary_writer = tf.summary.create_file_writer(
      os.path.join(log_dir_path, 'train_logs'), flush_millis=10000)

  # Setup configuration.
  keypoint_profile_2d = keypoint_profiles_module.create_keypoint_profile_or_die(
      FLAGS.keypoint_profile_name_2d)

  # Setup model.
  input_length = math.ceil(FLAGS.num_frames / FLAGS.downsample_rate)
  if FLAGS.input_features_dim > 0:
    feature_dim = FLAGS.input_features_dim
    input_shape = (input_length, feature_dim)
  else:
    feature_dim = None
    # NOTE(review): `13 * 2` hard-codes the per-frame 2D keypoint input size;
    # presumably it must match `keypoint_profile_2d` -- confirm.
    input_shape = (input_length, 13 * 2)
  classifier = models.get_temporal_classifier(
      FLAGS.classifier_type,
      input_shape=input_shape,
      num_classes=FLAGS.num_classes)
  ema_classifier = tf.keras.models.clone_model(classifier)
  optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.learning_rate)
  # Wrap the optimizer so it also maintains moving averages of the weights;
  # they are copied into `ema_classifier` before each checkpoint save.
  optimizer = tfa_optimizers.MovingAverage(optimizer)
  global_step = optimizer.iterations
  ckpt_manager, _, _ = utils.create_checkpoint(
      log_dir_path,
      optimizer=optimizer,
      model=classifier,
      ema_model=ema_classifier,
      global_step=global_step)

  # Setup the training dataset.
  dataset = pipelines.create_dataset_from_tables(
      FLAGS.input_tables, [int(x) for x in FLAGS.batch_sizes],
      num_instances_per_record=1,
      shuffle=True,
      drop_remainder=True,
      num_epochs=None,
      keypoint_names_2d=keypoint_profile_2d.keypoint_names,
      feature_dim=feature_dim,
      num_classes=FLAGS.num_classes,
      num_frames=FLAGS.num_frames,
      shuffle_buffer_size=FLAGS.shuffle_buffer_size,
      common_module=common_module,
      dataset_class=input_dataset_class,
      input_example_parser_creator=input_example_parser_creator)

  loss_object = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

  def train_one_iteration(inputs):
    """Trains the model for one iteration.

    Args:
      inputs: A dictionary for training inputs.

    Returns:
      A dictionary with `total_loss`, `crossentropy_loss` and
      `regularization_loss` for this iteration.
    """
    if FLAGS.input_features_dim > 0:
      features = inputs[common_module.KEY_FEATURES]
    else:
      features, _ = pipelines.create_model_input(
          inputs, common_module.MODEL_INPUT_KEYPOINT_TYPE_2D_INPUT,
          keypoint_profile_2d)

    features = tf.squeeze(features, axis=1)
    # Temporal downsampling along axis 1.
    features = features[:, ::FLAGS.downsample_rate, :]
    labels = inputs[common_module.KEY_CLASS_TARGETS]
    labels = tf.squeeze(labels, axis=1)

    with tf.GradientTape() as tape:
      outputs = classifier(features, training=True)
      regularization_loss = sum(classifier.losses)
      crossentropy_loss = loss_object(labels, outputs)
      total_loss = crossentropy_loss + regularization_loss

    trainable_variables = classifier.trainable_variables
    grads = tape.gradient(total_loss, trainable_variables)
    optimizer.apply_gradients(zip(grads, trainable_variables))

    for grad, trainable_variable in zip(grads, trainable_variables):
      tf.summary.scalar(
          'summarize_grads/' + trainable_variable.name,
          tf.linalg.norm(grad),
          step=global_step)

    return dict(
        total_loss=total_loss,
        crossentropy_loss=crossentropy_loss,
        regularization_loss=regularization_loss)

  if FLAGS.compile:
    train_one_iteration = tf.function(train_one_iteration)

  # Clamp intervals to at least 1 so that `num_iterations == 0` cannot trigger
  # a modulo by zero below.
  record_every_n_steps = max(1, min(5, FLAGS.num_iterations))
  save_ckpt_every_n_steps = max(1, min(500, FLAGS.num_iterations))

  with summary_writer.as_default():
    # A callable condition is re-evaluated against the live `global_step`
    # every time a summary is written (also when traced inside `tf.function`).
    # Passing a plain tensor would freeze the condition at its value from
    # before the first training step.
    with tf.summary.record_if(
        lambda: tf.equal(global_step % record_every_n_steps, 0)):
      start = time.time()
      for inputs in dataset:
        if global_step >= FLAGS.num_iterations:
          break

        model_losses = train_one_iteration(inputs)
        duration = time.time() - start
        start = time.time()

        for name, loss in model_losses.items():
          tf.summary.scalar('train/' + name, loss, step=global_step)

        tf.summary.scalar('train/learning_rate', optimizer.lr, step=global_step)
        tf.summary.scalar('train/batch_time', duration, step=global_step)
        tf.summary.scalar('global_step/sec', 1 / duration, step=global_step)

        step = int(global_step.numpy())
        if step % record_every_n_steps == 0:
          logging.info('Iter[{}/{}], {:.6f}s/iter, loss: {:.4f}'.format(
              step, FLAGS.num_iterations, duration,
              model_losses['total_loss'].numpy()))

        # Save checkpoint periodically, and always at the final step so the
        # last trained weights are kept even when `num_iterations` is not a
        # multiple of the save interval.
        if (step % save_ckpt_every_n_steps == 0 or
            step >= FLAGS.num_iterations):
          utils.assign_moving_average_vars(classifier, ema_classifier,
                                           optimizer)
          ckpt_manager.save(checkpoint_number=global_step)
          logging.info('Checkpoint saved at step %d.', step)
def run(input_dataset_class, common_module, keypoint_profiles_module,
        input_example_parser_creator):
  """Runs the evaluation pipeline.

  Restores `classifier` and `ema_classifier` from the latest checkpoint under
  `FLAGS.log_dir_path` and measures top-1 and top-5 categorical accuracy over
  the input tables. Both the latest and the best-so-far accuracies are written
  as summaries into `<log_dir_path>/<FLAGS.eval_name>`. With
  `FLAGS.continuous_eval`, checkpoints are evaluated as they appear. Polling
  stops once a checkpoint whose global step is at least `FLAGS.max_iteration`
  has been evaluated.

  Args:
    input_dataset_class: An input dataset class that matches input table type.
    common_module: A Python module that defines common flags and constants.
    keypoint_profiles_module: A Python module that defines keypoint profiles.
    input_example_parser_creator: A function handle for creating data parser
      function. If None, uses the default parser creator.

  Raises:
    ValueError: If `FLAGS.log_dir_path` does not exist.
  """
  log_dir_path = FLAGS.log_dir_path
  if not tf.io.gfile.exists(log_dir_path):
    raise ValueError(
        'The directory {} does not exist. Please provide a new log_dir_path.'
        .format(log_dir_path))

  eval_log_dir = os.path.join(log_dir_path, FLAGS.eval_name)
  pipeline_utils.create_dir_and_save_flags(flags, eval_log_dir,
                                           'all_flags.eval_with_features.json')
  writer = tf.summary.create_file_writer(eval_log_dir, flush_millis=10000)

  profile_2d = keypoint_profiles_module.create_keypoint_profile_or_die(
      FLAGS.keypoint_profile_name_2d)

  # Model input: `num_steps` frames after temporal downsampling.
  num_steps = math.ceil(FLAGS.num_frames / FLAGS.downsample_rate)
  if FLAGS.input_features_dim > 0:
    feature_dim = FLAGS.input_features_dim
    per_step_size = feature_dim
  else:
    feature_dim = None
    # NOTE(review): hard-coded 2D keypoint input size; presumably must match
    # `profile_2d` -- confirm.
    per_step_size = 13 * 2

  classifier = models.get_temporal_classifier(
      FLAGS.classifier_type,
      input_shape=(num_steps, per_step_size),
      num_classes=FLAGS.num_classes)
  ema_classifier = tf.keras.models.clone_model(classifier)
  global_step = tf.Variable(0, dtype=tf.int64, trainable=False)

  dataset = pipelines.create_dataset_from_tables(
      FLAGS.input_tables, [int(size) for size in FLAGS.batch_sizes],
      num_instances_per_record=1,
      shuffle=False,
      num_epochs=1,
      drop_remainder=False,
      keypoint_names_2d=profile_2d.keypoint_names,
      feature_dim=feature_dim,
      num_classes=FLAGS.num_classes,
      num_frames=FLAGS.num_frames,
      common_module=common_module,
      dataset_class=input_dataset_class,
      input_example_parser_creator=input_example_parser_creator)

  if FLAGS.compile:
    classifier.call = tf.function(classifier.call)
    ema_classifier.call = tf.function(ema_classifier.call)

  # Mutable state shared by `evaluate_once` and `timeout_fn` across rounds.
  state = {
      'best_top_1': None,
      'best_top_5': None,
      'reached_last_ckpt': False,
  }

  def timeout_fn():
    """Stops checkpoint polling once the final checkpoint has been evaluated."""
    return state['reached_last_ckpt']

  def prepare_batch(inputs):
    """Extracts the (features, labels) tensors for one evaluation batch."""
    if FLAGS.input_features_dim > 0:
      batch_features = inputs[common_module.KEY_FEATURES]
    else:
      batch_features, _ = pipelines.create_model_input(
          inputs, common_module.MODEL_INPUT_KEYPOINT_TYPE_2D_INPUT, profile_2d)
    batch_features = tf.squeeze(batch_features, axis=1)
    batch_features = batch_features[:, ::FLAGS.downsample_rate, :]
    batch_labels = tf.squeeze(inputs[common_module.KEY_CLASS_TARGETS], axis=1)
    return batch_features, batch_labels

  def evaluate_once():
    """Restores the latest checkpoint and evaluates it on the whole dataset."""
    _, status, _ = utils.create_checkpoint(
        log_dir_path,
        model=classifier,
        ema_model=ema_classifier,
        global_step=global_step)
    status.expect_partial()
    logging.info('Last checkpoint [iteration: %d] restored at %s.',
                 global_step.numpy(), log_dir_path)

    if global_step.numpy() >= FLAGS.max_iteration:
      state['reached_last_ckpt'] = True

    top_1 = tf.keras.metrics.CategoricalAccuracy()
    top_5 = tf.keras.metrics.TopKCategoricalAccuracy(k=5)
    eval_model = ema_classifier if FLAGS.use_moving_average else classifier

    for inputs in dataset:
      features, labels = prepare_batch(inputs)
      predictions = eval_model(features, training=False)
      top_1.update_state(y_true=labels, y_pred=predictions)
      top_5.update_state(y_true=labels, y_pred=predictions)

    top_1_value = top_1.result().numpy()
    if state['best_top_1'] is None or top_1_value > state['best_top_1']:
      state['best_top_1'] = top_1_value

    top_5_value = top_5.result().numpy()
    if state['best_top_5'] is None or top_5_value > state['best_top_5']:
      state['best_top_5'] = top_5_value

    step = global_step.numpy()
    tf.summary.scalar('eval/Basic/Top1_Accuracy', top_1.result(), step=step)
    tf.summary.scalar('eval/Best/Top1_Accuracy', state['best_top_1'], step=step)
    tf.summary.scalar('eval/Basic/Top5_Accuracy', top_5.result(), step=step)
    tf.summary.scalar('eval/Best/Top5_Accuracy', state['best_top_5'], step=step)
    logging.info('Accuracy: {:.2f}'.format(top_1_value))

  with writer.as_default(), tf.summary.record_if(True):
    if not FLAGS.continuous_eval:
      evaluate_once()
      return
    for _ in tf.train.checkpoints_iterator(
        log_dir_path, timeout=1, timeout_fn=timeout_fn):
      evaluate_once()
def _validate_and_setup(common_module, keypoint_profiles_module, models_module,
                        keypoint_distance_config_override,
                        create_model_input_fn_kwargs, embedder_fn_kwargs):
  """Validates and sets up training configurations.

  Fills in unspecified positive-pairwise / triplet-mining flags from their
  triplet-loss counterparts, validates flag values and combinations, and
  builds the configuration dictionary used for training.

  Args:
    common_module: A Python module that defines common flags and constants.
    keypoint_profiles_module: A Python module that defines keypoint profiles.
    models_module: A Python module providing `get_embedder`.
    keypoint_distance_config_override: A dictionary that may contain
      `keypoint_distance_fn` and/or `min_negative_keypoint_distance`; present
      entries override the values derived from `FLAGS.keypoint_distance_type`.
    create_model_input_fn_kwargs: Extra keyword arguments bound into
      `configs['create_model_input_fn']`.
    embedder_fn_kwargs: Extra keyword arguments forwarded to
      `models_module.get_embedder`.

  Returns:
    A dictionary of training configurations. When a sigmoid matching kernel is
    used by the triplet or the positive pairwise loss, it also contains
    `sigmoid_raw_a`, `sigmoid_a` and `sigmoid_b`.

  Raises:
    ValueError: If a flag value or flag combination is unsupported, or if no
      keypoint distance function/threshold can be determined.
  """
  # Set default values for unspecified flags: each flag on the left that is
  # left as None inherits the value of the flag on the right.
  for target_flag, source_flag in (
      ('use_normalized_embeddings_for_triplet_mining',
       'use_normalized_embeddings_for_triplet_loss'),
      ('use_normalized_embeddings_for_positive_pairwise_loss',
       'use_normalized_embeddings_for_triplet_loss'),
      ('positive_pairwise_distance_type', 'triplet_distance_type'),
      ('positive_pairwise_distance_kernel', 'triplet_distance_kernel'),
      ('positive_pairwise_pairwise_reduction', 'triplet_pairwise_reduction'),
      ('positive_pairwise_componentwise_reduction',
       'triplet_componentwise_reduction'),
  ):
    if getattr(FLAGS, target_flag) is None:
      setattr(FLAGS, target_flag, getattr(FLAGS, source_flag))

  # Validate flags.
  validate_flag = common_module.validate
  validate_flag(FLAGS.model_input_keypoint_type,
                common_module.SUPPORTED_TRAINING_MODEL_INPUT_KEYPOINT_TYPES)
  validate_flag(FLAGS.model_input_keypoint_mask_type,
                common_module.SUPPORTED_MODEL_INPUT_KEYPOINT_MASK_TYPES)
  validate_flag(FLAGS.embedding_type, common_module.SUPPORTED_EMBEDDING_TYPES)
  validate_flag(FLAGS.base_model_type, common_module.SUPPORTED_BASE_MODEL_TYPES)
  validate_flag(FLAGS.keypoint_distance_type,
                common_module.SUPPORTED_KEYPOINT_DISTANCE_TYPES)
  validate_flag(FLAGS.triplet_distance_type,
                common_module.SUPPORTED_DISTANCE_TYPES)
  validate_flag(FLAGS.triplet_distance_kernel,
                common_module.SUPPORTED_DISTANCE_KERNELS)
  validate_flag(FLAGS.triplet_pairwise_reduction,
                common_module.SUPPORTED_PAIRWISE_DISTANCE_REDUCTIONS)
  validate_flag(FLAGS.triplet_componentwise_reduction,
                common_module.SUPPORTED_COMPONENTWISE_DISTANCE_REDUCTIONS)
  validate_flag(FLAGS.positive_pairwise_distance_type,
                common_module.SUPPORTED_DISTANCE_TYPES)
  validate_flag(FLAGS.positive_pairwise_distance_kernel,
                common_module.SUPPORTED_DISTANCE_KERNELS)
  validate_flag(FLAGS.positive_pairwise_pairwise_reduction,
                common_module.SUPPORTED_PAIRWISE_DISTANCE_REDUCTIONS)
  validate_flag(FLAGS.positive_pairwise_componentwise_reduction,
                common_module.SUPPORTED_COMPONENTWISE_DISTANCE_REDUCTIONS)

  # Kernels that need the shared matching-sigmoid parameters created below.
  sigmoid_kernels = (
      common_module.DISTANCE_KERNEL_L2_SIGMOID_MATCHING_PROB,
      common_module.DISTANCE_KERNEL_SQUARED_L2_SIGMOID_MATCHING_PROB)
  # A kernel in `probability_kernels` must be used if and only if the pairwise
  # reduction is in `probability_reductions` (checked below).
  probability_kernels = sigmoid_kernels + (
      common_module.DISTANCE_KERNEL_EXPECTED_LIKELIHOOD,)
  probability_reductions = (
      common_module.DISTANCE_REDUCTION_NEG_LOG_MEAN,
      common_module.DISTANCE_REDUCTION_LOWER_HALF_NEG_LOG_MEAN,
      common_module.DISTANCE_REDUCTION_ONE_MINUS_MEAN)

  if FLAGS.embedding_type == common_module.EMBEDDING_TYPE_POINT:
    if FLAGS.triplet_distance_type != common_module.DISTANCE_TYPE_CENTER:
      raise ValueError(
          'No support for triplet distance type `%s` for embedding type `%s`.' %
          (FLAGS.triplet_distance_type, FLAGS.embedding_type))
    if FLAGS.kl_regularization_loss_weight > 0.0:
      raise ValueError(
          'No support for KL regularization loss for embedding type `%s`.' %
          FLAGS.embedding_type)

  if ((FLAGS.triplet_distance_type == common_module.DISTANCE_TYPE_SAMPLE or
       FLAGS.positive_pairwise_distance_type ==
       common_module.DISTANCE_TYPE_SAMPLE) and
      FLAGS.num_embedding_samples <= 0):
    # Report both types: either one may be the offending sample-based type.
    raise ValueError(
        'Must specify positive `num_embedding_samples` to use `%s` triplet '
        'distance type and `%s` positive pairwise distance type.' %
        (FLAGS.triplet_distance_type, FLAGS.positive_pairwise_distance_type))

  def _kernel_matches_reduction(kernel, reduction):
    """Returns whether `kernel` and `reduction` are a supported pairing."""
    return (kernel in probability_kernels) == (
        reduction in probability_reductions)

  if not (_kernel_matches_reduction(FLAGS.triplet_distance_kernel,
                                    FLAGS.triplet_pairwise_reduction) and
          _kernel_matches_reduction(FLAGS.positive_pairwise_distance_kernel,
                                    FLAGS.positive_pairwise_pairwise_reduction)):
    raise ValueError(
        'Must use `L2_SIGMOID_MATCHING_PROB`, '
        '`SQUARED_L2_SIGMOID_MATCHING_PROB` or `EXPECTED_LIKELIHOOD` distance '
        'kernel and `NEG_LOG_MEAN`, `LOWER_HALF_NEG_LOG_MEAN` or '
        '`ONE_MINUS_MEAN` pairwise reducer in pairs.')

  keypoint_profile_2d = keypoint_profiles_module.create_keypoint_profile_or_die(
      FLAGS.input_keypoint_profile_name_2d)

  def _degrees_to_radians(degrees):
    """Converts a sequence of angles in degrees to a list of radians."""
    return [float(x) / 180.0 * math.pi for x in degrees]

  # Set up configurations.
  configs = {
      'keypoint_profile_3d':
          keypoint_profiles_module.create_keypoint_profile_or_die(
              FLAGS.input_keypoint_profile_name_3d),
      'keypoint_profile_2d':
          keypoint_profile_2d,
      'create_model_input_fn':
          functools.partial(
              input_generator.create_model_input,
              model_input_keypoint_mask_type=(
                  FLAGS.model_input_keypoint_mask_type),
              uniform_keypoint_jittering_max_offset_2d=(
                  FLAGS.uniform_keypoint_jittering_max_offset_2d),
              gaussian_keypoint_jittering_offset_stddev_2d=(
                  FLAGS.gaussian_keypoint_jittering_offset_stddev_2d),
              keypoint_dropout_probs=[
                  float(x) for x in FLAGS.keypoint_dropout_probs
              ],
              set_on_mask_for_non_anchors=FLAGS.set_on_mask_for_non_anchors,
              mix_mask_sub_batches=FLAGS.mix_mask_sub_batches,
              forced_mask_on_part_names=FLAGS.forced_mask_on_part_names,
              forced_mask_off_part_names=FLAGS.forced_mask_off_part_names,
              **create_model_input_fn_kwargs),
      'embedder_fn':
          models_module.get_embedder(
              base_model_type=FLAGS.base_model_type,
              embedding_type=FLAGS.embedding_type,
              num_embedding_components=FLAGS.num_embedding_components,
              embedding_size=FLAGS.embedding_size,
              num_embedding_samples=FLAGS.num_embedding_samples,
              is_training=True,
              num_fc_blocks=FLAGS.num_fc_blocks,
              num_fcs_per_block=FLAGS.num_fcs_per_block,
              num_hidden_nodes=FLAGS.num_hidden_nodes,
              num_bottleneck_nodes=FLAGS.num_bottleneck_nodes,
              weight_max_norm=FLAGS.weight_max_norm,
              dropout_rate=FLAGS.dropout_rate,
              **embedder_fn_kwargs),
      'triplet_embedding_keys':
          pipeline_utils.get_embedding_keys(
              FLAGS.triplet_distance_type, common_module=common_module),
      'triplet_mining_embedding_keys':
          pipeline_utils.get_embedding_keys(
              FLAGS.triplet_distance_type,
              replace_samples_with_means=True,
              common_module=common_module),
      'positive_pairwise_embedding_keys':
          pipeline_utils.get_embedding_keys(
              FLAGS.positive_pairwise_distance_type,
              common_module=common_module),
      # NOTE(review): only the triplet kernel is checked here, although the
      # sigmoid parameters below are also created when only the positive
      # pairwise kernel uses a sigmoid -- confirm this is intended.
      'summarize_matching_sigmoid_vars':
          FLAGS.triplet_distance_kernel in sigmoid_kernels,
      'random_projection_azimuth_range':
          _degrees_to_radians(FLAGS.random_projection_azimuth_range),
      'random_projection_elevation_range':
          _degrees_to_radians(FLAGS.random_projection_elevation_range),
      'random_projection_roll_range':
          _degrees_to_radians(FLAGS.random_projection_roll_range),
      'random_projection_camera_depth_range': [
          float(x) for x in FLAGS.random_projection_camera_depth_range
      ],
  }

  embedding_sample_distance_fn_kwargs = {
      'EXPECTED_LIKELIHOOD_min_stddev': 0.1,
      'EXPECTED_LIKELIHOOD_max_squared_mahalanobis_distance': 100.0,
  }
  if (FLAGS.triplet_distance_kernel in sigmoid_kernels or
      FLAGS.positive_pairwise_distance_kernel in sigmoid_kernels):
    # We only need sigmoid parameters when a related distance kernel is used.
    sigmoid_raw_a, sigmoid_a, sigmoid_b = pipeline_utils.get_sigmoid_parameters(
        name='MatchingSigmoid',
        raw_a_initial_value=FLAGS.sigmoid_raw_a_initial,
        b_initial_value=FLAGS.sigmoid_b_initial,
        a_range=(None, FLAGS.sigmoid_a_max))
    # Both sigmoid kernels share the same parameter variables.
    embedding_sample_distance_fn_kwargs.update({
        'L2_SIGMOID_MATCHING_PROB_a': sigmoid_a,
        'L2_SIGMOID_MATCHING_PROB_b': sigmoid_b,
        'SQUARED_L2_SIGMOID_MATCHING_PROB_a': sigmoid_a,
        'SQUARED_L2_SIGMOID_MATCHING_PROB_b': sigmoid_b,
    })
    configs.update({
        'sigmoid_raw_a': sigmoid_raw_a,
        'sigmoid_a': sigmoid_a,
        'sigmoid_b': sigmoid_b,
    })

  configs.update({
      'triplet_embedding_sample_distance_fn':
          loss_utils.create_sample_distance_fn(
              pair_type=common_module.DISTANCE_PAIR_TYPE_ALL_PAIRS,
              distance_kernel=FLAGS.triplet_distance_kernel,
              pairwise_reduction=FLAGS.triplet_pairwise_reduction,
              componentwise_reduction=FLAGS.triplet_componentwise_reduction,
              **embedding_sample_distance_fn_kwargs),
      'positive_pairwise_embedding_sample_distance_fn':
          loss_utils.create_sample_distance_fn(
              pair_type=common_module.DISTANCE_PAIR_TYPE_ALL_PAIRS,
              distance_kernel=FLAGS.positive_pairwise_distance_kernel,
              pairwise_reduction=FLAGS.positive_pairwise_pairwise_reduction,
              componentwise_reduction=(
                  FLAGS.positive_pairwise_componentwise_reduction),
              **embedding_sample_distance_fn_kwargs),
  })

  if FLAGS.keypoint_distance_type == common_module.KEYPOINT_DISTANCE_TYPE_MPJPE:
    configs.update({
        'keypoint_distance_fn':
            keypoint_utils.compute_procrustes_aligned_mpjpes,
        'min_negative_keypoint_distance':
            FLAGS.min_negative_keypoint_mpjpe
    })

  # We use the following assignments to get around pytype check failures.
  # TODO(liuti): Figure out a better workaround.
  if 'keypoint_distance_fn' in keypoint_distance_config_override:
    configs['keypoint_distance_fn'] = (
        keypoint_distance_config_override['keypoint_distance_fn'])
  if 'min_negative_keypoint_distance' in keypoint_distance_config_override:
    configs['min_negative_keypoint_distance'] = (
        keypoint_distance_config_override['min_negative_keypoint_distance'])
  if ('keypoint_distance_fn' not in configs or
      'min_negative_keypoint_distance' not in configs):
    raise ValueError('Invalid keypoint distance config: %s.' % str(configs))

  if FLAGS.task == 0 and not FLAGS.profile_only:
    # Save all key flags.
    pipeline_utils.create_dir_and_save_flags(flags, FLAGS.train_log_dir,
                                             'all_flags.train.json')

  return configs