def test_expand_last_dim_return_tensor_with_more_dims(tmp_path):
    """Check that ExpandLastDim adds a dimension to a 1-D tensor.

    A two-element float32 constant is passed through the layer and the
    resulting static shape is expected to have exactly two entries.

    Args:
        tmp_path: Accepted in the signature but not used by the body.
    """
    expand_layer = layer_module.ExpandLastDim()
    vector = tf.constant([0.1, 0.2], dtype=tf.float32)

    expanded = expand_layer(vector)

    rank = len(expanded.shape.as_list())
    assert rank == 2
def build(self, hp, inputs=None):
    """Build through the parent class and return a single output node.

    The parent ``build`` result is flattened and only its first element
    is kept. If that node's shape has exactly three dimensions, it is
    passed through ``ExpandLastDim`` before being returned; otherwise it
    is returned unchanged.

    Args:
        hp: Forwarded as-is to the parent ``build``.
        inputs: Forwarded as-is to the parent ``build``. Defaults to None.

    Returns:
        The first flattened node, run through ``ExpandLastDim`` when its
        shape has length 3.
    """
    built = super().build(hp, inputs)
    first_node = nest.flatten(built)[0]

    if len(first_node.shape) != 3:
        return first_node

    # NOTE(review): a length-3 shape presumably means the input lacks a
    # trailing (channel) axis -- confirm against callers.
    return keras_layers.ExpandLastDim()(first_node)