def _structured_tensor_prensor_map(
    st: structured_tensor.StructuredTensor,
    default_field_name: path.Step) -> Mapping[path.Step, prensor.Prensor]:
  """Converts every field of a StructuredTensor into a prensor.

  The resulting mapping is intended to be placed in a child or root prensor.

  Args:
    st: the StructuredTensor whose fields are converted.
    default_field_name: forwarded unchanged to
      `_structured_tensor_field_to_prensor` for each field.

  Returns:
    A dict keyed by each name from `st.field_names()`, in that iteration
    order, whose value is the prensor built from `st.field_value(name)`.
  """
  field_map = {}
  for field_name in st.field_names():
    field_value = st.field_value(field_name)
    field_map[field_name] = _structured_tensor_field_to_prensor(
        field_value, default_field_name)
  return field_map
def _expand_dims_scalar(st: structured_tensor.StructuredTensor):
  """Applies `_expand_dims` to a scalar StructuredTensor.

  Each field of `st` is passed through `_expand_dims(value, 0)`. The
  expanded fields are then reassembled into a new StructuredTensor whose
  shape is given by the int64 constant `[1]`.

  NOTE(review): the shape is supplied as a tf.constant rather than a Python
  list. Confirm that `StructuredTensor.from_fields` converts a tensor-valued
  `shape` correctly outside eager mode.

  Args:
    st: a StructuredTensor. The caller is presumably expected to pass a
      scalar (rank-0) one, per this helper's name.

  Returns:
    A new StructuredTensor built from the expanded fields with shape `[1]`.
  """
  leading_shape = tf.constant([1], dtype=tf.int64)
  expanded_fields = {}
  for name in st.field_names():
    expanded_fields[name] = _expand_dims(st.field_value(name), 0)
  return structured_tensor.StructuredTensor.from_fields(
      expanded_fields, shape=leading_shape)