def decode(self, labeled_tensor):
    """Reshape the input back to the original shape.

    This is the inverse of encode, so encode must have been called at least
    once before this method is used.

    Args:
      labeled_tensor: The input tensor.

    Returns:
      The input reshaped to the original shape.

    Raises:
      ValueError: If this method was called before encode was called.
    """
    if self._existing_axes is None:
        raise ValueError('decode called before encode')

    with tf_ops.name_scope(self._name, 'lt_reshape_decode',
                           [labeled_tensor]) as scope:
        labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

        # The axes to be replaced are passed to ops.reshape by name: entries
        # of self._new_axes that are already strings are used as-is, anything
        # else is coerced with core.as_axis and its name is taken.
        new_axis_names = []
        for new_axis in self._new_axes:
            if isinstance(new_axis, string_types):
                new_axis_names.append(new_axis)
            else:
                new_axis_names.append(core.as_axis(new_axis).name)

        return ops.reshape(
            labeled_tensor, new_axis_names, self._existing_axes, name=scope)
def constant(value, dtype=None, axes=None, name=None):
    """Creates a constant tensor.

    If `axes` includes any strings, shape is inferred from `value`. Otherwise,
    the sizes of the given `axes` are used to set `shape` for `tf.constant`.

    See tf.constant for more details.

    Args:
      value: The input tensor.
      dtype: The type of the returned tensor.
      axes: Optional Axes, list of strings or list of objects coercible to
        Axis objects. By default, axes are assumed to be an empty list (i.e.,
        `value` is treated as a scalar).
      name: Optional op name.

    Returns:
      A LabeledTensor wrapping the constant op built from `value`, labeled
      with `axes`.
    """
    with ops.name_scope(name, 'lt_constant', [value]) as scope:
        if axes is None:
            axes = []
        if isinstance(axes, core.Axes):
            axes = axes.values()

        # A bare string names an axis without giving its size, so the shape
        # has to come from `value` itself (shape=None).
        shape = None
        has_unsized_axis = any(isinstance(ax, string_types) for ax in axes)
        if not has_unsized_axis:
            # Every axis is coercible to an Axis, so the axes fix the shape.
            axes = [core.as_axis(a) for a in axes]
            shape = [a.size for a in axes]

        op = array_ops.constant(value, dtype=dtype, shape=shape, name=scope)
        return core.LabeledTensor(op, axes)
def reshape(labeled_tensor, existing_axes, new_axes, name=None):
    """Reshape specific axes of a LabeledTensor.

    Non-indicated axes remain in their original locations.

    Args:
      labeled_tensor: The input tensor.
      existing_axes: List of axis names found on the input tensor. These must
        appear sequentially in the list of axis names on the input. In other
        words, they must be a valid slice of
        `list(labeled_tensor.axes.keys())`. Must not be empty.
      new_axes: List of strings, tuples of (axis_name, axis_value) or Axis
        objects providing new axes with which to replace `existing_axes` in
        the reshaped result. At most one element of `new_axes` may be a
        string, indicating an axis with unknown size. Any iterable is
        accepted; it is copied into a list before use.
      name: Optional op name.

    Returns:
      The reshaped LabeledTensor.

    Raises:
      ValueError: If `existing_axes` is empty or not all of its entries are
        axes on the input, or if more than one of `new_axes` has unknown
        size.
      AxisOrderError: If `existing_axes` are not a slice of axis names on the
        input.
    """
    with ops.name_scope(name, 'lt_reshape', [labeled_tensor]) as scope:
        labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

        original_axis_names = list(labeled_tensor.axes.keys())
        existing_axes = list(existing_axes)
        # new_axes is traversed three times below (validation, shape, result
        # axes); materialize it so single-pass iterables are not silently
        # exhausted after the first traversal.
        new_axes = list(new_axes)

        if not existing_axes:
            # Without this guard the existing_axes[0] lookup below fails with
            # an uninformative IndexError.
            raise ValueError(
                'existing_axes must contain at least one axis name')

        if not set(existing_axes) <= set(original_axis_names):
            raise ValueError(
                'existing_axes %r are not contained in the set of axis '
                'names %r on the input labeled tensor' %
                (existing_axes, original_axis_names))

        start = original_axis_names.index(existing_axes[0])
        stop = original_axis_names.index(existing_axes[-1]) + 1

        if existing_axes != original_axis_names[start:stop]:
            # We could support existing_axes that aren't a slice by using
            # transpose, but that could lead to unpredictable performance
            # consequences because transposes are not free in TensorFlow. If
            # we did transpose automatically, the user might never realize
            # that their data is being produced with the wrong order. (The
            # latter will occur with some frequency because of how
            # broadcasting automatically chooses axis order.)
            # So for now we've taken the strict approach.
            raise core.AxisOrderError(
                'existing_axes %r are not a slice of axis names %r on the '
                'input labeled tensor. Use `transpose` or '
                '`impose_axis_order` to reorder axes on the input explicitly.'
                % (existing_axes, original_axis_names))

        unknown_size_count = sum(
            isinstance(axis, string_types) for axis in new_axes)
        if unknown_size_count > 1:
            # Wrap the operand in a 1-tuple: a bare tuple on the right of `%`
            # would be unpacked as several format arguments and raise
            # TypeError instead of the intended ValueError.
            raise ValueError(
                'at most one axis in new_axes can have unknown size. All '
                'other axes must have an indicated integer size or labels: %r'
                % (new_axes,))

        original_values = list(labeled_tensor.axes.values())

        def axis_size(axis):
            # -1 stands in for a dimension whose size is not known.
            return -1 if axis.size is None else axis.size

        # Target shape: untouched leading axes, then the replacement axes,
        # then untouched trailing axes.
        shape = [axis_size(axis) for axis in original_values[:start]]
        for axis_ref in new_axes:
            if isinstance(axis_ref, string_types):
                # A bare name carries no size information.
                shape.append(-1)
            else:
                shape.append(axis_size(core.as_axis(axis_ref)))
        shape.extend(axis_size(axis) for axis in original_values[stop:])

        reshaped_tensor = array_ops.reshape(
            labeled_tensor.tensor, shape, name=scope)
        axes = original_values[:start] + new_axes + original_values[stop:]

        return core.LabeledTensor(reshaped_tensor, axes)
def test_as_axis(self):
    """core.as_axis maps both ('7', 7) and self.i_7 itself to self.i_7."""
    inputs = (('7', 7), self.i_7)
    for axis_like in inputs:
        self.assertEqual(self.i_7, core.as_axis(axis_like))
def __init__(self, axes, dtype, default_value=None):
    """Record the axes, dtype and default value on the instance.

    Args:
      axes: Iterable of objects accepted by `core.as_axis`; each element is
        coerced with it and the results are stored as a list.
      dtype: Stored unchanged as `self._dtype`.
      default_value: Optional; stored unchanged as `self._default_value`.
    """
    self._axes = list(map(core.as_axis, axes))
    self._dtype = dtype
    self._default_value = default_value