Esempio n. 1
0
    def decode(self, labeled_tensor):
        """Reshape the input back to its original (pre-encode) axes.

        This is the inverse of `encode`. The axes named by `self._new_axes`
        are handed to `ops.reshape` as the axes to replace, with
        `self._existing_axes` as their replacement. `encode` must have been
        called at least once before this method, otherwise the original
        axes are unknown.

        Args:
          labeled_tensor: The input tensor; anything accepted by
            `core.convert_to_labeled_tensor`.

        Returns:
          The input reshaped to the original shape.

        Raises:
          ValueError: If this method was called before `encode` was called.
        """
        if self._existing_axes is None:
            raise ValueError('decode called before encode')

        def _axis_name(axis_ref):
            # Entries of `_new_axes` are either bare axis names or objects
            # that `core.as_axis` can coerce to an Axis.
            if isinstance(axis_ref, string_types):
                return axis_ref
            return core.as_axis(axis_ref).name

        with tf_ops.name_scope(self._name, 'lt_reshape_decode',
                               [labeled_tensor]) as decode_scope:
            tensor = core.convert_to_labeled_tensor(labeled_tensor)
            encoded_names = [_axis_name(ref) for ref in self._new_axes]
            return ops.reshape(tensor, encoded_names, self._existing_axes,
                               name=decode_scope)
Esempio n. 2
0
def constant(value, dtype=None, axes=None, name=None):
  """Creates a constant LabeledTensor.

  If `axes` includes any strings, shape is inferred from `value`. Otherwise,
  the sizes of the given `axes` are used to set `shape` for `tf.constant`.

  See tf.constant for more details.

  Args:
    value: The constant value passed to `tf.constant`.
    dtype: The type of the returned tensor.
    axes: Optional Axes, or an iterable of strings or objects coercible to
      Axis objects. By default, axes are assumed to be an empty list (i.e.,
      `value` is treated as a scalar).
    name: Optional op name.

  Returns:
    A LabeledTensor wrapping the constant tensor built from `value`, labeled
    with `axes`.
  """
  with ops.name_scope(name, 'lt_constant', [value]) as scope:

    if axes is None:
      axes = []
    elif isinstance(axes, core.Axes):
      axes = axes.values()

    # Materialize once: `axes` is scanned by `any()` below and then used
    # again, so a one-shot iterable (e.g. a generator) would otherwise be
    # partially consumed before it reaches LabeledTensor.
    axes = list(axes)

    if any(isinstance(axis, string_types) for axis in axes):
      # At least one axis is a bare name without a size: leave `shape`
      # unset so it is inferred from `value`.
      shape = None
    else:
      # No bare names: coerce every entry to an Axis and use the Axis sizes
      # as the requested shape.
      axes = [core.as_axis(axis) for axis in axes]
      shape = [axis.size for axis in axes]

    op = array_ops.constant(value, dtype=dtype, shape=shape, name=scope)
    return core.LabeledTensor(op, axes)
Esempio n. 3
0
def reshape(labeled_tensor, existing_axes, new_axes, name=None):
  """Reshape specific axes of a LabeledTensor.

  The contiguous run of axes named by `existing_axes` is replaced by
  `new_axes`; non-indicated axes remain in their original locations.

  Args:
    labeled_tensor: The input tensor.
    existing_axes: Non-empty list of axis names found on the input tensor.
      These must appear sequentially in the list of axis names on the input.
      In other words, they must be a valid slice of
      `list(labeled_tensor.axes.keys())`.
    new_axes: Iterable of strings, tuples of (axis_name, axis_value) or Axis
      objects providing new axes with which to replace `existing_axes` in the
      reshaped result. At most one element of `new_axes` may be a string,
      indicating an axis with unknown size.
    name: Optional op name.

  Returns:
    The reshaped LabeledTensor.

  Raises:
    ValueError: If `existing_axes` is empty, if `existing_axes` are not all
      axes on the input, or if more than one of `new_axes` has unknown size.
    AxisOrderError: If `existing_axes` are not a slice of axis names on the
      input.
  """
  with ops.name_scope(name, 'lt_reshape', [labeled_tensor]) as scope:
    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

    original_axis_names = list(labeled_tensor.axes.keys())
    existing_axes = list(existing_axes)
    # `new_axes` is traversed several times below (validation, shape and
    # result axes), so materialize it once. As a list it also can no longer
    # be unpacked as format arguments by `%` in the error message below.
    new_axes = list(new_axes)

    if not set(existing_axes) <= set(original_axis_names):
      raise ValueError('existing_axes %r are not contained in the set of axis '
                       'names %r on the input labeled tensor' %
                       (existing_axes, original_axis_names))
    if not existing_axes:
      # An empty list passes the subset check above but leaves no position
      # at which to splice in `new_axes` (and would fail on [0] below).
      raise ValueError('existing_axes must name at least one axis on the '
                       'input labeled tensor, got %r' % (existing_axes,))

    start = original_axis_names.index(existing_axes[0])
    stop = original_axis_names.index(existing_axes[-1]) + 1

    if existing_axes != original_axis_names[start:stop]:
      # We could support existing_axes that aren't a slice by using transpose,
      # but that could lead to unpredictable performance consequences because
      # transposes are not free in TensorFlow. If we did transpose
      # automatically, the user might never realize that their data is being
      # produced with the wrong order. (The latter will occur with some
      # frequency because of how broadcasting automatically chooses axis
      # order.) So for now we've taken the strict approach.
      raise core.AxisOrderError(
          'existing_axes %r are not a slice of axis names %r on the input '
          'labeled tensor. Use `transpose` or `impose_axis_order` to reorder '
          'axes on the input explicitly.' %
          (existing_axes, original_axis_names))

    if sum(isinstance(axis, string_types) for axis in new_axes) > 1:
      raise ValueError(
          'at most one axis in new_axes can have unknown size. All other '
          'axes must have an indicated integer size or labels: %r' %
          (new_axes,))

    def _dim_size(axis):
      # -1 marks a dimension whose size tf.reshape should infer.
      return -1 if axis.size is None else axis.size

    original_values = list(labeled_tensor.axes.values())
    shape = [_dim_size(axis) for axis in original_values[:start]]
    for axis_ref in new_axes:
      if isinstance(axis_ref, string_types):
        # A bare axis name carries no size.
        shape.append(-1)
      else:
        shape.append(_dim_size(core.as_axis(axis_ref)))
    shape.extend(_dim_size(axis) for axis in original_values[stop:])

    reshaped_tensor = array_ops.reshape(
        labeled_tensor.tensor, shape, name=scope)
    axes = original_values[:start] + new_axes + original_values[stop:]
    return core.LabeledTensor(reshaped_tensor, axes)
 def test_as_axis(self):
   """core.as_axis maps both ('7', 7) and self.i_7 to an Axis equal to i_7."""
   expected = self.i_7
   for axis_like in (('7', 7), expected):
     self.assertEqual(expected, core.as_axis(axis_like))
 def __init__(self, axes, dtype, default_value=None):
   """Record the axes, dtype and optional default value.

   Args:
     axes: Iterable of objects accepted by `core.as_axis`; each entry is
       coerced and the results are stored as a list in `self._axes`.
     dtype: Stored unchanged in `self._dtype`.
     default_value: Stored unchanged in `self._default_value`; defaults to
       None.
   """
   self._axes = list(map(core.as_axis, axes))
   self._dtype = dtype
   self._default_value = default_value