def model_fn(input_features, output_sizes):
        """Applies the base model to input features, producing outputs of given sizes.

        During training (when the batch size is known), all leading
        "model-irrelevant" dimensions are folded into a single batch dimension
        before the base model runs, and restored on every output/activation
        tensor afterwards. Outside training the base model is applied as-is.
        """
        if not is_training:
            # Batch size may be unknown outside training, so no flattening.
            return base_model_fn(input_features, output_sizes)

        # Keep the feature channel dimension, plus the sequence dimension for
        # sequential inputs; fold everything before them into one leading dim.
        keep_dims = 2 if sequential_inputs else 1
        flat_features = data_utils.flatten_first_dims(
            input_features, num_last_dims_to_keep=keep_dims)
        leading_shape = data_utils.get_shape_by_first_dims(
            input_features, num_last_dims=keep_dims)

        outputs, activations = base_model_fn(flat_features, output_sizes)

        # Restore the folded leading dimensions on every result tensor.
        outputs = {
            key: data_utils.unflatten_first_dim(
                tensor, shape_to_unflatten=leading_shape)
            for key, tensor in outputs.items()
        }
        activations = {
            key: data_utils.unflatten_first_dim(
                tensor, shape_to_unflatten=leading_shape)
            for key, tensor in activations.items()
        }
        return outputs, activations
# Example #2 (listing separator from the scraping source; the "0" was a vote count, not code).
 def test_flatten_first_dims(self):
     """Checks that leading dims are merged while the last two are kept."""
     # Input shape = [1, 2, 3, 4, 1]; flattening with num_last_dims_to_keep=2
     # should yield shape [6, 4, 1].
     values = [[[[[1], [2], [3], [4]],
                 [[11], [12], [13], [14]],
                 [[21], [22], [23], [24]]],
                [[[31], [32], [33], [34]],
                 [[41], [42], [43], [44]],
                 [[51], [52], [53], [54]]]]]
     result = data_utils.flatten_first_dims(
         tf.constant(values), num_last_dims_to_keep=2)
     expected = [[[1], [2], [3], [4]],
                 [[11], [12], [13], [14]],
                 [[21], [22], [23], [24]],
                 [[31], [32], [33], [34]],
                 [[41], [42], [43], [44]],
                 [[51], [52], [53], [54]]]
     self.assertAllEqual(result, expected)
def simple_model_late_fuse(input_features,
                           output_sizes,
                           is_training,
                           name='SimpleModelLateFuse',
                           num_late_fusion_preprojection_nodes=0,
                           late_fusion_preprojection_activation_fn=None,
                           num_bottleneck_nodes=0,
                           **kwargs):
    """Implements `simple baseline` model base architecture on sequential inputs.

  The model first runs the simple model on each individual set of keypoints and
  then performs late fusion across the sequence.

  Args:
    input_features: A tensor for input features. Shape = [..., sequence_length,
      feature_dim].
    output_sizes: A dictionary for output sizes in the format {output_name:
      output_size}, where `output_size` can be an integer or a list.
    is_training: A boolean for whether it is in training mode.
    name: A string for the name scope.
    num_late_fusion_preprojection_nodes: An integer for the dimension to project
      each frame features to before late fusion. No preprojection will be added
      if non-positive.
    late_fusion_preprojection_activation_fn: A string for the activation
      function of the preprojection layer. If None or 'NONE', no activation
      function is used.
    num_bottleneck_nodes: An integer for size of the bottleneck layer to be
      added before the output layer(s). No bottleneck layer will be added if
      non-positive.
    **kwargs: A dictionary of additional arguments passed to
      `simple_base_late_fuse`.

  Returns:
    A tensor for output activations. Shape = [..., output_dim].
  """
    # Fold the temporal axis into the batch axis, then run the per-frame base
    # model on every pose independently.
    per_frame_features = data_utils.flatten_first_dims(
        input_features, num_last_dims_to_keep=1)
    net = simple_base(per_frame_features,
                      sequential_inputs=False,
                      is_training=is_training,
                      name=name,
                      **kwargs)

    if num_late_fusion_preprojection_nodes > 0:
        # Project each frame's features down before fusing. The projection
        # layer reuses kwargs, overriding only the node count and activation.
        preprojection_kwargs = dict(
            kwargs,
            num_hidden_nodes=num_late_fusion_preprojection_nodes,
            activation_fn=late_fusion_preprojection_activation_fn)
        net = fully_connected(net,
                              is_training=is_training,
                              name=name + '/LateFusePreProject',
                              **preprojection_kwargs)

    # Undo the temporal flattening by concatenating all per-frame features of a
    # sequence along the feature dimension.
    sequence_length = input_features.shape.as_list()[-2]
    feature_length = net.shape.as_list()[-1]
    net = tf.reshape(net, [-1, sequence_length * feature_length])

    # Late fusion: project the concatenated features, then run the FC block.
    net = fully_connected(net,
                          is_training=is_training,
                          name=name + '/LateFuseProject',
                          **kwargs)
    net = fully_connected_block(net,
                                is_training=is_training,
                                name=name + '/LateFuseBlock',
                                **kwargs)
    activations = {'base_activations': net}

    if num_bottleneck_nodes > 0:
        # Optional linear bottleneck before the output heads.
        he_normal = tf.initializers.he_normal
        net = linear(net,
                     output_size=num_bottleneck_nodes,
                     weight_max_norm=kwargs.get('weight_max_norm', 0.0),
                     weight_initializer=kwargs.get('weight_initializer',
                                                   he_normal()),
                     bias_initializer=kwargs.get('bias_initializer',
                                                 he_normal()),
                     name=name + '/BottleneckLogits')
        activations['bottleneck_activations'] = net

    outputs = multi_head_logits(net,
                                output_sizes=output_sizes,
                                name=name,
                                **kwargs)
    return outputs, activations
# Example #4 (listing separator from the scraping source; the "0" was a vote count, not code).
def run(master, input_dataset_class, common_module, keypoint_profiles_module,
        models_module, input_example_parser_creator, keypoint_preprocessor_3d,
        keypoint_distance_config_override, create_model_input_fn_kwargs,
        embedder_fn_kwargs):
    """Runs training pipeline.

  Args:
    master: BNS name of the TensorFlow master to use.
    input_dataset_class: An input dataset class that matches input table type.
    common_module: A Python module that defines common flags and constants.
    keypoint_profiles_module: A Python module that defines keypoint profiles.
    models_module: A Python module that defines base model architectures.
    input_example_parser_creator: A function handle for creating data parser
      function. If None, uses the default parser creator.
    keypoint_preprocessor_3d: A function handle for preprocessing raw 3D
      keypoints.
    keypoint_distance_config_override: A dictionary for keypoint distance
      configuration to override the defaults. Ignored if empty.
    create_model_input_fn_kwargs: A dictionary of addition kwargs for create the
      model input creator function.
    embedder_fn_kwargs: A dictionary of additional kwargs for creating the
      embedder function.
  """
    # TF1-style pipeline: build the entire graph first, then hand it off to
    # tf_slim.learning.train at the bottom of this function.
    g = tf.Graph()
    with g.as_default():
        with tf.device(tf.train.replica_device_setter(FLAGS.num_ps_tasks)):
            # Validate flags and resolve all derived configuration up front.
            configs = _validate_and_setup(
                common_module=common_module,
                keypoint_profiles_module=keypoint_profiles_module,
                models_module=models_module,
                keypoint_distance_config_override=
                keypoint_distance_config_override,
                create_model_input_fn_kwargs=create_model_input_fn_kwargs,
                embedder_fn_kwargs=embedder_fn_kwargs)

            def create_inputs():
                """Creates pipeline and model inputs."""
                inputs = pipeline_utils.read_batch_from_dataset_tables(
                    FLAGS.input_table,
                    batch_sizes=[int(x) for x in FLAGS.batch_size],
                    num_instances_per_record=2,
                    shuffle=True,
                    num_epochs=None,
                    keypoint_names_3d=configs['keypoint_profile_3d'].
                    keypoint_names,
                    keypoint_names_2d=configs['keypoint_profile_2d'].
                    keypoint_names,
                    min_keypoint_score_2d=FLAGS.min_input_keypoint_score_2d,
                    shuffle_buffer_size=FLAGS.input_shuffle_buffer_size,
                    common_module=common_module,
                    dataset_class=input_dataset_class,
                    input_example_parser_creator=input_example_parser_creator)

                # Preprocess (and normalize) the raw 3D keypoints; any side
                # outputs the preprocessor returns are merged into `inputs`.
                (inputs[common_module.KEY_KEYPOINTS_3D],
                 keypoint_preprocessor_side_outputs_3d
                 ) = keypoint_preprocessor_3d(
                     inputs[common_module.KEY_KEYPOINTS_3D],
                     keypoint_profile_3d=configs['keypoint_profile_3d'],
                     normalize_keypoints_3d=True)
                inputs.update(keypoint_preprocessor_side_outputs_3d)

                # Build the actual model inputs via the configured creator,
                # which receives the random-projection camera ranges.
                inputs['model_inputs'], side_inputs = configs[
                    'create_model_input_fn'](
                        inputs[common_module.KEY_KEYPOINTS_2D],
                        inputs[common_module.KEY_KEYPOINT_MASKS_2D],
                        inputs[common_module.KEY_PREPROCESSED_KEYPOINTS_3D],
                        model_input_keypoint_type=FLAGS.
                        model_input_keypoint_type,
                        normalize_keypoints_2d=True,
                        keypoint_profile_2d=configs['keypoint_profile_2d'],
                        keypoint_profile_3d=configs['keypoint_profile_3d'],
                        azimuth_range=configs[
                            'random_projection_azimuth_range'],
                        elevation_range=configs[
                            'random_projection_elevation_range'],
                        roll_range=configs['random_projection_roll_range'],
                        normalized_camera_depth_range=(
                            configs['random_projection_camera_depth_range']))
                data_utils.merge_dict(side_inputs, inputs)
                return inputs

            inputs = create_inputs()
            outputs, _ = configs['embedder_fn'](inputs['model_inputs'])
            # `summaries` is shared mutable state: the add_*_loss closures
            # below append their loss scalars to it.
            summaries = {
                'train/batch_size':
                tf.shape(outputs[common_module.KEY_EMBEDDING_MEANS])[0]
            }

            def add_triplet_loss():
                """Adds triplet loss."""
                anchor_keypoints_3d, positive_keypoints_3d = tf.unstack(
                    inputs[common_module.KEY_KEYPOINTS_3D], num=2, axis=1)

                # Optionally transfer inferred 2D keypoint masks onto the 3D
                # keypoint profile for use in triplet labeling.
                anchor_keypoint_masks_3d, positive_keypoint_masks_3d = None, None
                if FLAGS.use_inferred_keypoint_masks_for_triplet_label:
                    anchor_keypoint_masks_2d, positive_keypoint_masks_2d = tf.unstack(
                        inputs[
                            common_module.KEY_PREPROCESSED_KEYPOINT_MASKS_2D],
                        num=2,
                        axis=1)
                    anchor_keypoint_masks_3d = keypoint_utils.transfer_keypoint_masks(
                        anchor_keypoint_masks_2d,
                        input_keypoint_profile=configs['keypoint_profile_2d'],
                        output_keypoint_profile=configs['keypoint_profile_3d'],
                        enforce_surjectivity=True)
                    positive_keypoint_masks_3d = keypoint_utils.transfer_keypoint_masks(
                        positive_keypoint_masks_2d,
                        input_keypoint_profile=configs['keypoint_profile_2d'],
                        output_keypoint_profile=configs['keypoint_profile_3d'],
                        enforce_surjectivity=True)

                triplet_anchor_embeddings, triplet_positive_embeddings = tf.unstack(
                    pipeline_utils.stack_embeddings(
                        outputs, configs['triplet_embedding_keys']),
                    axis=1)
                if FLAGS.use_normalized_embeddings_for_triplet_loss:
                    triplet_anchor_embeddings = tf.math.l2_normalize(
                        triplet_anchor_embeddings, axis=-1)
                    triplet_positive_embeddings = tf.math.l2_normalize(
                        triplet_positive_embeddings, axis=-1)

                # Mining may use a different (possibly normalized) set of
                # embeddings than the loss itself.
                triplet_anchor_mining_embeddings, triplet_positive_mining_embeddings = (
                    tf.unstack(pipeline_utils.stack_embeddings(
                        outputs, configs['triplet_mining_embedding_keys']),
                               axis=1))
                if FLAGS.use_normalized_embeddings_for_triplet_mining:
                    triplet_anchor_mining_embeddings = tf.math.l2_normalize(
                        triplet_anchor_mining_embeddings, axis=-1)
                    triplet_positive_mining_embeddings = tf.math.l2_normalize(
                        triplet_positive_mining_embeddings, axis=-1)

                # NOTE: positives double as matches here (match_* arguments are
                # fed the positive tensors).
                triplet_loss, triplet_loss_summaries = (
                    loss_utils.compute_keypoint_triplet_losses(
                        anchor_embeddings=triplet_anchor_embeddings,
                        positive_embeddings=triplet_positive_embeddings,
                        match_embeddings=triplet_positive_embeddings,
                        anchor_keypoints=anchor_keypoints_3d,
                        match_keypoints=positive_keypoints_3d,
                        margin=FLAGS.triplet_loss_margin,
                        min_negative_keypoint_distance=(
                            configs['min_negative_keypoint_distance']),
                        use_semi_hard=FLAGS.use_semi_hard_triplet_negatives,
                        exclude_inactive_triplet_loss=(
                            FLAGS.exclude_inactive_triplet_loss),
                        anchor_keypoint_masks=anchor_keypoint_masks_3d,
                        match_keypoint_masks=positive_keypoint_masks_3d,
                        embedding_sample_distance_fn=(
                            configs['triplet_embedding_sample_distance_fn']),
                        keypoint_distance_fn=configs['keypoint_distance_fn'],
                        anchor_mining_embeddings=
                        triplet_anchor_mining_embeddings,
                        positive_mining_embeddings=
                        triplet_positive_mining_embeddings,
                        match_mining_embeddings=
                        triplet_positive_mining_embeddings,
                        summarize_percentiles=FLAGS.summarize_percentiles))
                tf.losses.add_loss(triplet_loss,
                                   loss_collection=tf.GraphKeys.LOSSES)
                summaries.update(triplet_loss_summaries)
                summaries['train/triplet_loss'] = triplet_loss

            def add_kl_regularization_loss():
                """Adds KL regularization loss."""
                kl_regularization_loss, kl_regularization_loss_summaries = (
                    loss_utils.compute_kl_regularization_loss(
                        outputs[common_module.KEY_EMBEDDING_MEANS],
                        stddevs=outputs[common_module.KEY_EMBEDDING_STDDEVS],
                        prior_stddev=FLAGS.kl_regularization_prior_stddev,
                        loss_weight=FLAGS.kl_regularization_loss_weight))
                # Goes into the REGULARIZATION_LOSSES collection, unlike the
                # other two losses.
                tf.losses.add_loss(
                    kl_regularization_loss,
                    loss_collection=tf.GraphKeys.REGULARIZATION_LOSSES)
                summaries.update(kl_regularization_loss_summaries)
                summaries[
                    'train/kl_regularization_loss'] = kl_regularization_loss

            def add_positive_pairwise_loss():
                """Adds positive pairwise loss."""
                (positive_pairwise_anchor_embeddings,
                 positive_pairwise_positive_embeddings) = tf.unstack(
                     pipeline_utils.stack_embeddings(
                         outputs,
                         configs['positive_pairwise_embedding_keys'],
                         common_module=common_module),
                     axis=1)
                if FLAGS.use_normalized_embeddings_for_positive_pairwise_loss:
                    positive_pairwise_anchor_embeddings = tf.math.l2_normalize(
                        positive_pairwise_anchor_embeddings, axis=-1)
                    positive_pairwise_positive_embeddings = tf.math.l2_normalize(
                        positive_pairwise_positive_embeddings, axis=-1)
                positive_pairwise_loss, positive_pairwise_loss_summaries = (
                    loss_utils.compute_positive_pairwise_loss(
                        positive_pairwise_anchor_embeddings,
                        positive_pairwise_positive_embeddings,
                        loss_weight=FLAGS.positive_pairwise_loss_weight,
                        distance_fn=configs[
                            'positive_pairwise_embedding_sample_distance_fn']))
                tf.losses.add_loss(positive_pairwise_loss,
                                   loss_collection=tf.GraphKeys.LOSSES)
                summaries.update(positive_pairwise_loss_summaries)
                summaries[
                    'train/positive_pairwise_loss'] = positive_pairwise_loss

            # Triplet loss is always applied; KL regularization and positive
            # pairwise losses are only added when their weights are positive.
            add_triplet_loss()
            if FLAGS.kl_regularization_loss_weight > 0.0:
                add_kl_regularization_loss()
            if FLAGS.positive_pairwise_loss_weight > 0.0:
                add_positive_pairwise_loss()
            total_loss = tf.losses.get_total_loss()
            summaries['train/total_loss'] = total_loss

            if configs['summarize_matching_sigmoid_vars']:
                # Summarize variables used in matching sigmoid.
                # TODO(liuti): Currently the variable for `raw_a` is named `a` in
                # checkpoints, and true `a` may be referred to as `a_plus` for historic
                # reasons. Consolidate the naming.
                summaries.update({
                    'train/MatchingSigmoid/a':
                    configs['sigmoid_raw_a'],
                    'train/MatchingSigmoid/a_plus':
                    configs['sigmoid_a'],
                    'train/MatchingSigmoid/b':
                    configs['sigmoid_b'],
                })

            if FLAGS.use_moving_average:
                pipeline_utils.add_moving_average(FLAGS.moving_average_decay)

            learning_rate = FLAGS.learning_rate
            optimizer = pipeline_utils.get_optimizer(
                FLAGS.optimizer.upper(), learning_rate=learning_rate)
            init_fn = pipeline_utils.get_init_fn(
                train_dir=FLAGS.train_log_dir,
                model_checkpoint=FLAGS.init_model_checkpoint)
            train_op = tf_slim.learning.create_train_op(
                total_loss,
                optimizer,
                clip_gradient_norm=FLAGS.gradient_clip_norm,
                summarize_gradients=FLAGS.summarize_gradients)
            saver = tf.train.Saver(keep_checkpoint_every_n_hours=FLAGS.
                                   keep_checkpoint_every_n_hours,
                                   pad_step_number=True)
            summaries['train/learning_rate'] = learning_rate

            # Optionally render anchor/positive 2D pose pairs as image
            # summaries for input inspection.
            image_summary = {}
            if FLAGS.summarize_inputs:
                image_summary.update({
                    'poses_2d/AnchorPositivePair':
                    visualization_utils.tf_draw_poses_2d(
                        data_utils.flatten_first_dims(inputs[
                            common_module.KEY_PREPROCESSED_KEYPOINTS_2D],
                                                      num_last_dims_to_keep=2),
                        keypoint_profile_2d=configs['keypoint_profile_2d'],
                        num_cols=2),
                })
            pipeline_utils.add_summary(scalars_to_summarize=summaries,
                                       images_to_summarize=image_summary)

            # Profile-only mode builds the graph, profiles it, and exits
            # without training.
            if FLAGS.profile_only:
                pipeline_utils.profile()
                return

            tf_slim.learning.train(
                train_op,
                logdir=FLAGS.train_log_dir,
                log_every_n_steps=FLAGS.log_every_n_steps,
                master=master,
                is_chief=FLAGS.task == 0,
                number_of_steps=FLAGS.num_steps,
                init_fn=init_fn,
                save_summaries_secs=FLAGS.save_summaries_secs,
                startup_delay_steps=FLAGS.startup_delay_steps * FLAGS.task,
                saver=saver,
                save_interval_secs=FLAGS.save_interval_secs,
                session_config=tf.ConfigProto(allow_soft_placement=True,
                                              log_device_placement=False))