Example #1
0
    def model_fn(input_features, output_sizes):
        """Runs `base_model_fn`, collapsing leading dims while training.

        In training mode, every leading dimension that precedes the feature
        channel dimension (and the sequence dimension, when
        `sequential_inputs` is set) is collapsed into one before calling
        `base_model_fn`. Afterwards it is expanded back in every tensor of
        both returned dicts. Outside of training the features are forwarded
        unchanged.

        Args:
          input_features: Input features passed to `base_model_fn`.
          output_sizes: Forwarded to `base_model_fn` as-is.

        Returns:
          A tuple `(outputs, activations)` unpacked from `base_model_fn`'s
          result. In training mode, the collapsed leading dimensions are
          restored in place on every entry of both dicts.
        """
        if not is_training:
            # Unpack and re-pack so the result is always a 2-tuple.
            outputs, activations = base_model_fn(input_features, output_sizes)
            return outputs, activations

        # Keep the trailing sequence / feature channel dims (only the feature
        # channel dim for non-sequential inputs) and collapse everything in
        # front of them. Per the original note, this is restricted to
        # training because that is where the batch size is known.
        kept_dims = 2 if sequential_inputs else 1
        flat_features = data_utils.flatten_first_dims(
            input_features, num_last_dims_to_keep=kept_dims)
        leading_shape = data_utils.get_shape_by_first_dims(
            input_features, num_last_dims=kept_dims)

        outputs, activations = base_model_fn(flat_features, output_sizes)

        # Expand the collapsed leading dims back, overwriting values in place.
        for tensor_dict in (outputs, activations):
            for name in list(tensor_dict):
                tensor_dict[name] = data_utils.unflatten_first_dim(
                    tensor_dict[name], shape_to_unflatten=leading_shape)

        return outputs, activations
Example #2
0
 def test_unflatten_first_dim(self):
   """Unflattening a [6, 2] tensor with shape [2, 3] gives [2, 3, 2]."""
   # Six rows of two consecutive integers: shape [6, 2].
   flat = tf.constant([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
   first_dims = tf.constant([2, 3])
   expected = [
       [[1, 2], [3, 4], [5, 6]],
       [[7, 8], [9, 10], [11, 12]],
   ]
   result = data_utils.unflatten_first_dim(flat, shape_to_unflatten=first_dims)
   self.assertAllEqual(result, expected)