def model_fn(input_features, output_sizes):
  """Applies model to input features and produces output of given sizes."""
  if is_training:
    # Flatten all the model-irrelevant dimensions, i.e., dimensions that
    # precede the sequence / feature channel dimensions. Note that we only do
    # this for training, for which the batch size is known.
    num_last_dims_to_keep = 2 if sequential_inputs else 1
    flattened_input_features = data_utils.flatten_first_dims(
        input_features, num_last_dims_to_keep=num_last_dims_to_keep)
    flattened_shape = data_utils.get_shape_by_first_dims(
        input_features, num_last_dims=num_last_dims_to_keep)
    outputs, activations = base_model_fn(flattened_input_features,
                                         output_sizes)
    # Restore the model-irrelevant dimensions on every output and activation.
    for key, output in outputs.items():
      outputs[key] = data_utils.unflatten_first_dim(
          output, shape_to_unflatten=flattened_shape)
    for key, activation in activations.items():
      activations[key] = data_utils.unflatten_first_dim(
          activation, shape_to_unflatten=flattened_shape)
  else:
    outputs, activations = base_model_fn(input_features, output_sizes)
  return outputs, activations
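For context, the data_utils helpers used above behave like thin reshape wrappers: flatten_first_dims collapses every leading (model-irrelevant) dimension into one, get_shape_by_first_dims records those leading dimensions so they can be restored, and unflatten_first_dim re-expands them. The bodies below are a minimal sketch under that assumption; the actual library implementations may differ.

import tensorflow as tf


def flatten_first_dims(x, num_last_dims_to_keep):
  """Collapses all leading dims of `x` into one, keeping the last dims.

  Sketch only; assumes a plain tf.reshape suffices.
  """
  kept_dims = tf.shape(x)[-num_last_dims_to_keep:]
  return tf.reshape(x, tf.concat([[-1], kept_dims], axis=0))


def get_shape_by_first_dims(x, num_last_dims):
  """Returns the shape of the leading (model-irrelevant) dims of `x`."""
  return tf.shape(x)[:-num_last_dims]


def unflatten_first_dim(x, shape_to_unflatten):
  """Expands the flattened first dim of `x` back to `shape_to_unflatten`."""
  return tf.reshape(
      x, tf.concat([shape_to_unflatten, tf.shape(x)[1:]], axis=0))

Under this reading, a sequential input of shape [batch, clips, frames, channels] flattened with num_last_dims_to_keep=2 becomes [batch * clips, frames, channels], which is the layout base_model_fn consumes.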
def test_unflatten_first_dim(self):
  # Shape = [6, 2].
  x = tf.constant([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
  unflattened_x = data_utils.unflatten_first_dim(
      x, shape_to_unflatten=tf.constant([2, 3]))
  self.assertAllEqual(unflattened_x,
                      [[[1, 2], [3, 4], [5, 6]],
                       [[7, 8], [9, 10], [11, 12]]])
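Assuming the helper semantics sketched earlier, a companion round-trip test (hypothetical, not from the original suite) could check that flattening followed by unflattening reproduces the input exactly:

def test_flatten_then_unflatten_round_trip(self):
  # Shape = [2, 3, 2].
  x = tf.constant([[[1, 2], [3, 4], [5, 6]],
                   [[7, 8], [9, 10], [11, 12]]])
  # Collapse the leading [2, 3] dims into a single dim of size 6.
  flattened_x = data_utils.flatten_first_dims(x, num_last_dims_to_keep=1)
  first_dims = data_utils.get_shape_by_first_dims(x, num_last_dims=1)
  # Restoring the leading dims should reproduce the original tensor.
  unflattened_x = data_utils.unflatten_first_dim(
      flattened_x, shape_to_unflatten=first_dims)
  self.assertAllEqual(unflattened_x, x)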