def model_fn(input_features, output_sizes):
  """Applies model to input features and produces output of given sizes.

  At training time, all leading (model-irrelevant) dimensions of the input
  are folded into a single batch dimension before calling the base model,
  and unfolded again on every output/activation tensor afterwards. At
  inference time the base model is applied directly.

  Args:
    input_features: A tensor for input features.
    output_sizes: A dictionary for output sizes in the format
      {output_name: output_size}.

  Returns:
    outputs: A dictionary of output tensors.
    activations: A dictionary of activation tensors.
  """
  if not is_training:
    # Batch size may be unknown outside training; skip flattening.
    return base_model_fn(input_features, output_sizes)

  # Fold every dimension before the sequence / feature-channel dimension(s)
  # into one batch dimension. Sequential inputs keep two trailing dims
  # (sequence, channels); non-sequential inputs keep one (channels).
  keep_dims = 2 if sequential_inputs else 1
  folded_features = data_utils.flatten_first_dims(
      input_features, num_last_dims_to_keep=keep_dims)
  leading_shape = data_utils.get_shape_by_first_dims(
      input_features, num_last_dims=keep_dims)

  outputs, activations = base_model_fn(folded_features, output_sizes)

  # Restore the original leading dimensions on every result tensor.
  outputs = {
      key: data_utils.unflatten_first_dim(
          tensor, shape_to_unflatten=leading_shape)
      for key, tensor in outputs.items()
  }
  activations = {
      key: data_utils.unflatten_first_dim(
          tensor, shape_to_unflatten=leading_shape)
      for key, tensor in activations.items()
  }
  return outputs, activations
def test_flatten_first_dims(self):
  """Verifies flattening a [1, 2, 3, 4, 1] tensor down to [6, 4, 1]."""
  # Input shape = [1, 2, 3, 4, 1].
  input_tensor = tf.constant([[[[[1], [2], [3], [4]],
                                [[11], [12], [13], [14]],
                                [[21], [22], [23], [24]]],
                               [[[31], [32], [33], [34]],
                                [[41], [42], [43], [44]],
                                [[51], [52], [53], [54]]]]])
  # Keeping the last 2 dims collapses the leading [1, 2, 3] into 6.
  flattened = data_utils.flatten_first_dims(
      input_tensor, num_last_dims_to_keep=2)
  expected = [[[1], [2], [3], [4]],
              [[11], [12], [13], [14]],
              [[21], [22], [23], [24]],
              [[31], [32], [33], [34]],
              [[41], [42], [43], [44]],
              [[51], [52], [53], [54]]]
  self.assertAllEqual(flattened, expected)
def simple_model_late_fuse(input_features,
                           output_sizes,
                           is_training,
                           name='SimpleModelLateFuse',
                           num_late_fusion_preprojection_nodes=0,
                           late_fusion_preprojection_activation_fn=None,
                           num_bottleneck_nodes=0,
                           **kwargs):
  """Implements `simple baseline` model base architecture on sequential inputs.

  Runs `simple_base` on each individual frame of keypoints independently, then
  concatenates the per-frame features and fuses them with fully-connected
  layers (late fusion).

  Args:
    input_features: A tensor for input features. Shape = [...,
      sequence_length, feature_dim].
    output_sizes: A dictionary for output sizes in the format {output_name:
      output_size}, where `output_size` can be an integer or a list.
    is_training: A boolean for whether it is in training mode.
    name: A string for the name scope.
    num_late_fusion_preprojection_nodes: An integer for the dimension to
      project each frame's features to before late fusion. No preprojection
      will be added if non-positive.
    late_fusion_preprojection_activation_fn: A string for the activation
      function of the preprojection layer. If None or 'NONE', no activation
      function is used.
    num_bottleneck_nodes: An integer for size of the bottleneck layer to be
      added before the output layer(s). No bottleneck layer will be added if
      non-positive.
    **kwargs: A dictionary of additional arguments passed to the underlying
      layer builders.

  Returns:
    outputs: A dictionary of output tensors keyed by output name.
    activations: A dictionary of intermediate activation tensors, always
      including 'base_activations' and, when a bottleneck is added,
      'bottleneck_activations'.
  """
  # Fold the temporal axis into the batch axis so every frame is processed
  # independently by the per-frame base network.
  per_frame_features = data_utils.flatten_first_dims(
      input_features, num_last_dims_to_keep=1)
  net = simple_base(
      per_frame_features,
      sequential_inputs=False,
      is_training=is_training,
      name=name,
      **kwargs)

  if num_late_fusion_preprojection_nodes > 0:
    # Optionally shrink each frame's features before fusing.
    preprojection_kwargs = dict(kwargs)
    preprojection_kwargs.update({
        'num_hidden_nodes': num_late_fusion_preprojection_nodes,
        'activation_fn': late_fusion_preprojection_activation_fn,
    })
    net = fully_connected(
        net,
        is_training=is_training,
        name=name + '/LateFusePreProject',
        **preprojection_kwargs)

  # Recover the temporal axis by concatenating the per-frame features along
  # the channel dimension.
  sequence_length = input_features.shape.as_list()[-2]
  feature_length = net.shape.as_list()[-1]
  net = tf.reshape(net, [-1, sequence_length * feature_length])

  # Late fusion across frames.
  net = fully_connected(
      net, is_training=is_training, name=name + '/LateFuseProject', **kwargs)
  net = fully_connected_block(
      net, is_training=is_training, name=name + '/LateFuseBlock', **kwargs)
  activations = {'base_activations': net}

  if num_bottleneck_nodes > 0:
    net = linear(
        net,
        output_size=num_bottleneck_nodes,
        weight_max_norm=kwargs.get('weight_max_norm', 0.0),
        weight_initializer=kwargs.get('weight_initializer',
                                      tf.initializers.he_normal()),
        bias_initializer=kwargs.get('bias_initializer',
                                    tf.initializers.he_normal()),
        name=name + '/BottleneckLogits')
    activations['bottleneck_activations'] = net

  outputs = multi_head_logits(
      net, output_sizes=output_sizes, name=name, **kwargs)
  return outputs, activations
def run(master, input_dataset_class, common_module, keypoint_profiles_module,
        models_module, input_example_parser_creator, keypoint_preprocessor_3d,
        keypoint_distance_config_override, create_model_input_fn_kwargs,
        embedder_fn_kwargs):
  """Runs training pipeline.

  Builds the full training graph (input pipeline, embedder, losses,
  summaries, optimizer) and launches tf_slim training, or only profiles the
  graph when `FLAGS.profile_only` is set.

  Args:
    master: BNS name of the TensorFlow master to use.
    input_dataset_class: An input dataset class that matches input table type.
    common_module: A Python module that defines common flags and constants.
    keypoint_profiles_module: A Python module that defines keypoint profiles.
    models_module: A Python module that defines base model architectures.
    input_example_parser_creator: A function handle for creating data parser
      function. If None, uses the default parser creator.
    keypoint_preprocessor_3d: A function handle for preprocessing raw 3D
      keypoints.
    keypoint_distance_config_override: A dictionary for keypoint distance
      configuration to override the defaults. Ignored if empty.
    create_model_input_fn_kwargs: A dictionary of additional kwargs for
      creating the model input creator function.
    embedder_fn_kwargs: A dictionary of additional kwargs for creating the
      embedder function.
  """
  g = tf.Graph()
  with g.as_default():
    with tf.device(tf.train.replica_device_setter(FLAGS.num_ps_tasks)):
      # Validate flags and assemble all derived configuration in one dict.
      configs = _validate_and_setup(
          common_module=common_module,
          keypoint_profiles_module=keypoint_profiles_module,
          models_module=models_module,
          keypoint_distance_config_override=
          keypoint_distance_config_override,
          create_model_input_fn_kwargs=create_model_input_fn_kwargs,
          embedder_fn_kwargs=embedder_fn_kwargs)

      def create_inputs():
        """Creates pipeline and model inputs."""
        # Each record carries 2 instances (anchor/positive pair).
        inputs = pipeline_utils.read_batch_from_dataset_tables(
            FLAGS.input_table,
            batch_sizes=[int(x) for x in FLAGS.batch_size],
            num_instances_per_record=2,
            shuffle=True,
            num_epochs=None,
            keypoint_names_3d=configs['keypoint_profile_3d'].keypoint_names,
            keypoint_names_2d=configs['keypoint_profile_2d'].keypoint_names,
            min_keypoint_score_2d=FLAGS.min_input_keypoint_score_2d,
            shuffle_buffer_size=FLAGS.input_shuffle_buffer_size,
            common_module=common_module,
            dataset_class=input_dataset_class,
            input_example_parser_creator=input_example_parser_creator)
        # Normalize raw 3D keypoints in place and keep preprocessing side
        # outputs alongside the other inputs.
        (inputs[common_module.KEY_KEYPOINTS_3D],
         keypoint_preprocessor_side_outputs_3d
        ) = keypoint_preprocessor_3d(
            inputs[common_module.KEY_KEYPOINTS_3D],
            keypoint_profile_3d=configs['keypoint_profile_3d'],
            normalize_keypoints_3d=True)
        inputs.update(keypoint_preprocessor_side_outputs_3d)
        # Build the actual model inputs (e.g. 2D keypoints and/or random
        # projections of 3D keypoints) from the preprocessed data.
        inputs['model_inputs'], side_inputs = configs[
            'create_model_input_fn'](
                inputs[common_module.KEY_KEYPOINTS_2D],
                inputs[common_module.KEY_KEYPOINT_MASKS_2D],
                inputs[common_module.KEY_PREPROCESSED_KEYPOINTS_3D],
                model_input_keypoint_type=FLAGS.model_input_keypoint_type,
                normalize_keypoints_2d=True,
                keypoint_profile_2d=configs['keypoint_profile_2d'],
                keypoint_profile_3d=configs['keypoint_profile_3d'],
                azimuth_range=configs['random_projection_azimuth_range'],
                elevation_range=configs['random_projection_elevation_range'],
                roll_range=configs['random_projection_roll_range'],
                normalized_camera_depth_range=(
                    configs['random_projection_camera_depth_range']))
        data_utils.merge_dict(side_inputs, inputs)
        return inputs

      inputs = create_inputs()
      outputs, _ = configs['embedder_fn'](inputs['model_inputs'])

      summaries = {
          'train/batch_size':
              tf.shape(outputs[common_module.KEY_EMBEDDING_MEANS])[0]
      }

      def add_triplet_loss():
        """Adds triplet loss."""
        # Split the instance axis into anchor / positive halves.
        anchor_keypoints_3d, positive_keypoints_3d = tf.unstack(
            inputs[common_module.KEY_KEYPOINTS_3D], num=2, axis=1)

        anchor_keypoint_masks_3d, positive_keypoint_masks_3d = None, None
        if FLAGS.use_inferred_keypoint_masks_for_triplet_label:
          # Derive 3D keypoint masks from the (inferred) 2D masks so the
          # triplet label distance can ignore occluded keypoints.
          anchor_keypoint_masks_2d, positive_keypoint_masks_2d = tf.unstack(
              inputs[common_module.KEY_PREPROCESSED_KEYPOINT_MASKS_2D],
              num=2,
              axis=1)
          anchor_keypoint_masks_3d = keypoint_utils.transfer_keypoint_masks(
              anchor_keypoint_masks_2d,
              input_keypoint_profile=configs['keypoint_profile_2d'],
              output_keypoint_profile=configs['keypoint_profile_3d'],
              enforce_surjectivity=True)
          positive_keypoint_masks_3d = keypoint_utils.transfer_keypoint_masks(
              positive_keypoint_masks_2d,
              input_keypoint_profile=configs['keypoint_profile_2d'],
              output_keypoint_profile=configs['keypoint_profile_3d'],
              enforce_surjectivity=True)

        # NOTE(review): unlike `add_positive_pairwise_loss`, the two
        # `stack_embeddings` calls below do not pass `common_module` —
        # confirm the helper's default is the intended behavior here.
        triplet_anchor_embeddings, triplet_positive_embeddings = tf.unstack(
            pipeline_utils.stack_embeddings(
                outputs, configs['triplet_embedding_keys']),
            axis=1)
        if FLAGS.use_normalized_embeddings_for_triplet_loss:
          triplet_anchor_embeddings = tf.math.l2_normalize(
              triplet_anchor_embeddings, axis=-1)
          triplet_positive_embeddings = tf.math.l2_normalize(
              triplet_positive_embeddings, axis=-1)

        # Mining may use a different (e.g. deterministic) set of embeddings
        # than the loss itself.
        triplet_anchor_mining_embeddings, triplet_positive_mining_embeddings = (
            tf.unstack(
                pipeline_utils.stack_embeddings(
                    outputs, configs['triplet_mining_embedding_keys']),
                axis=1))
        if FLAGS.use_normalized_embeddings_for_triplet_mining:
          triplet_anchor_mining_embeddings = tf.math.l2_normalize(
              triplet_anchor_mining_embeddings, axis=-1)
          triplet_positive_mining_embeddings = tf.math.l2_normalize(
              triplet_positive_mining_embeddings, axis=-1)

        # Positives double as in-batch match candidates for negative mining.
        triplet_loss, triplet_loss_summaries = (
            loss_utils.compute_keypoint_triplet_losses(
                anchor_embeddings=triplet_anchor_embeddings,
                positive_embeddings=triplet_positive_embeddings,
                match_embeddings=triplet_positive_embeddings,
                anchor_keypoints=anchor_keypoints_3d,
                match_keypoints=positive_keypoints_3d,
                margin=FLAGS.triplet_loss_margin,
                min_negative_keypoint_distance=(
                    configs['min_negative_keypoint_distance']),
                use_semi_hard=FLAGS.use_semi_hard_triplet_negatives,
                exclude_inactive_triplet_loss=(
                    FLAGS.exclude_inactive_triplet_loss),
                anchor_keypoint_masks=anchor_keypoint_masks_3d,
                match_keypoint_masks=positive_keypoint_masks_3d,
                embedding_sample_distance_fn=(
                    configs['triplet_embedding_sample_distance_fn']),
                keypoint_distance_fn=configs['keypoint_distance_fn'],
                anchor_mining_embeddings=triplet_anchor_mining_embeddings,
                positive_mining_embeddings=
                triplet_positive_mining_embeddings,
                match_mining_embeddings=triplet_positive_mining_embeddings,
                summarize_percentiles=FLAGS.summarize_percentiles))
        tf.losses.add_loss(triplet_loss, loss_collection=tf.GraphKeys.LOSSES)
        summaries.update(triplet_loss_summaries)
        summaries['train/triplet_loss'] = triplet_loss

      def add_kl_regularization_loss():
        """Adds KL regularization loss."""
        kl_regularization_loss, kl_regularization_loss_summaries = (
            loss_utils.compute_kl_regularization_loss(
                outputs[common_module.KEY_EMBEDDING_MEANS],
                stddevs=outputs[common_module.KEY_EMBEDDING_STDDEVS],
                prior_stddev=FLAGS.kl_regularization_prior_stddev,
                loss_weight=FLAGS.kl_regularization_loss_weight))
        # Registered as a regularization loss so it folds into
        # `tf.losses.get_total_loss()` below.
        tf.losses.add_loss(
            kl_regularization_loss,
            loss_collection=tf.GraphKeys.REGULARIZATION_LOSSES)
        summaries.update(kl_regularization_loss_summaries)
        summaries['train/kl_regularization_loss'] = kl_regularization_loss

      def add_positive_pairwise_loss():
        """Adds positive pairwise loss."""
        (positive_pairwise_anchor_embeddings,
         positive_pairwise_positive_embeddings) = tf.unstack(
             pipeline_utils.stack_embeddings(
                 outputs,
                 configs['positive_pairwise_embedding_keys'],
                 common_module=common_module),
             axis=1)
        if FLAGS.use_normalized_embeddings_for_positive_pairwise_loss:
          positive_pairwise_anchor_embeddings = tf.math.l2_normalize(
              positive_pairwise_anchor_embeddings, axis=-1)
          positive_pairwise_positive_embeddings = tf.math.l2_normalize(
              positive_pairwise_positive_embeddings, axis=-1)
        positive_pairwise_loss, positive_pairwise_loss_summaries = (
            loss_utils.compute_positive_pairwise_loss(
                positive_pairwise_anchor_embeddings,
                positive_pairwise_positive_embeddings,
                loss_weight=FLAGS.positive_pairwise_loss_weight,
                distance_fn=configs[
                    'positive_pairwise_embedding_sample_distance_fn']))
        tf.losses.add_loss(
            positive_pairwise_loss, loss_collection=tf.GraphKeys.LOSSES)
        summaries.update(positive_pairwise_loss_summaries)
        summaries['train/positive_pairwise_loss'] = positive_pairwise_loss

      # Triplet loss is always on; the other losses only when their weight
      # flags are positive.
      add_triplet_loss()
      if FLAGS.kl_regularization_loss_weight > 0.0:
        add_kl_regularization_loss()
      if FLAGS.positive_pairwise_loss_weight > 0.0:
        add_positive_pairwise_loss()
      total_loss = tf.losses.get_total_loss()
      summaries['train/total_loss'] = total_loss

      if configs['summarize_matching_sigmoid_vars']:
        # Summarize variables used in matching sigmoid.
        # TODO(liuti): Currently the variable for `raw_a` is named `a` in
        # checkpoints, and true `a` may be referred to as `a_plus` for
        # historic reasons. Consolidate the naming.
        summaries.update({
            'train/MatchingSigmoid/a': configs['sigmoid_raw_a'],
            'train/MatchingSigmoid/a_plus': configs['sigmoid_a'],
            'train/MatchingSigmoid/b': configs['sigmoid_b'],
        })

      if FLAGS.use_moving_average:
        pipeline_utils.add_moving_average(FLAGS.moving_average_decay)

      learning_rate = FLAGS.learning_rate
      optimizer = pipeline_utils.get_optimizer(
          FLAGS.optimizer.upper(), learning_rate=learning_rate)
      # Optionally warm-start from a checkpoint if the train dir is empty.
      init_fn = pipeline_utils.get_init_fn(
          train_dir=FLAGS.train_log_dir,
          model_checkpoint=FLAGS.init_model_checkpoint)
      train_op = tf_slim.learning.create_train_op(
          total_loss,
          optimizer,
          clip_gradient_norm=FLAGS.gradient_clip_norm,
          summarize_gradients=FLAGS.summarize_gradients)
      saver = tf.train.Saver(
          keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
          pad_step_number=True)
      summaries['train/learning_rate'] = learning_rate

      image_summary = {}
      if FLAGS.summarize_inputs:
        # Render the anchor/positive 2D pose pairs for visual inspection.
        image_summary.update({
            'poses_2d/AnchorPositivePair':
                visualization_utils.tf_draw_poses_2d(
                    data_utils.flatten_first_dims(
                        inputs[common_module.KEY_PREPROCESSED_KEYPOINTS_2D],
                        num_last_dims_to_keep=2),
                    keypoint_profile_2d=configs['keypoint_profile_2d'],
                    num_cols=2),
        })
      pipeline_utils.add_summary(
          scalars_to_summarize=summaries, images_to_summarize=image_summary)

      if FLAGS.profile_only:
        # Profile the graph and exit without training.
        pipeline_utils.profile()
        return

      tf_slim.learning.train(
          train_op,
          logdir=FLAGS.train_log_dir,
          log_every_n_steps=FLAGS.log_every_n_steps,
          master=master,
          is_chief=FLAGS.task == 0,
          number_of_steps=FLAGS.num_steps,
          init_fn=init_fn,
          save_summaries_secs=FLAGS.save_summaries_secs,
          startup_delay_steps=FLAGS.startup_delay_steps * FLAGS.task,
          saver=saver,
          save_interval_secs=FLAGS.save_interval_secs,
          session_config=tf.ConfigProto(
              allow_soft_placement=True, log_device_placement=False))