Ejemplo n.º 1
0
    def model_fn(input_features, output_sizes):
        """Runs `base_model_fn`, collapsing the leading dimensions when training.

        Args:
            input_features: Input handed to `base_model_fn`. When training, it
                is first passed through `data_utils.flatten_first_dims`.
            output_sizes: Forwarded unchanged to `base_model_fn`.

        Returns:
            The `(outputs, activations)` pair produced by `base_model_fn`. When
            training, every value in both mappings has been replaced in place
            by its unflattened counterpart.
        """
        if not is_training:
            # Outside of training the features go to the base model untouched.
            outputs, activations = base_model_fn(input_features, output_sizes)
            return outputs, activations

        # Collapse every model-irrelevant dimension, i.e., the dimensions that
        # precede the sequence / feature channel dimensions. This is only done
        # for training, for which the batch size is known.
        if sequential_inputs:
            trailing_dims = 2
        else:
            trailing_dims = 1
        flat_features = data_utils.flatten_first_dims(
            input_features, num_last_dims_to_keep=trailing_dims)
        leading_shape = data_utils.get_shape_by_first_dims(
            input_features, num_last_dims=trailing_dims)

        outputs, activations = base_model_fn(flat_features, output_sizes)

        # Restore the model-irrelevant dimensions, updating each mapping in
        # place (outputs first, then activations).
        for tensors in (outputs, activations):
            for key, tensor in tensors.items():
                tensors[key] = data_utils.unflatten_first_dim(
                    tensor, shape_to_unflatten=leading_shape)

        return outputs, activations
Ejemplo n.º 2
0
 def test_get_shape_by_first_dims(self):
     """Expects [1, 2, 3] for a [1, 2, 3, 4, 5] tensor with `num_last_dims=2`."""
     # Rank-5 all-zeros input with shape [1, 2, 3, 4, 5].
     inputs = tf.zeros([1, 2, 3, 4, 5])
     expected_shape = [1, 2, 3]
     leading_shape = data_utils.get_shape_by_first_dims(
         inputs, num_last_dims=2)
     self.assertAllEqual(leading_shape, expected_shape)