def test_concat_many(self):
    """Concatenating several single-label axes joins their labels in order."""
    pieces = [core.Axis('rgb', [color]) for color in ('red', 'green', 'blue')]
    expected = core.Axis('rgb', ['red', 'green', 'blue'])

    self.assertEqual(core.concat_axes(pieces), expected)
 def test_invalid_input(self):
   """Bad Axis values and non-Axis concat inputs raise the expected errors."""
   # (expected exception, labels passed to core.Axis)
   bad_label_cases = [
       (TypeError, [{}]),
       (ValueError, [1, 2, 3, 1]),
   ]
   for expected_error, labels in bad_label_cases:
     with self.assertRaises(expected_error):
       core.Axis('foo', labels)

   red_axis = core.Axis('foo', ['red'])
   with self.assertRaises(tc.Error):
     core.concat_axes([red_axis, 1])
  def setUp(self):
    # Axis fixtures built from a Dimension, a label list, a range and None.
    seven = tensor_shape.Dimension(7)
    rgb_labels = ['red', 'green', 'blue']

    self.i_7 = core.Axis('7', seven)
    self.i_7p = core.Axis('7prime', seven)
    self.i_rgb = core.Axis('rgb', rgb_labels)
    self.i_range = core.Axis('range', range(7))
    self.i_unknown = core.Axis('unknown', None)
示例#4
0
def pad(labeled_tensor, paddings, mode='CONSTANT', name=None):
  """Pads a tensor.

  See tf.pad.

  Args:
    labeled_tensor: The input tensor.
    paddings: A mapping where the keys are axis names and the values are
      `(before, after)` tuples. Each element is passed to `core.Axis` as the
      value of the padding axis inserted at the beginning (`before`) or end
      (`after`) of that axis. Presumably this is a size or a list of labels,
      i.e. anything `core.Axis` accepts.
    mode: One of "CONSTANT", "REFLECT", or "SYMMETRIC".
    name: Optional op name.

  Returns:
    A tensor with the indicated axes padded, optionally with those axes extended
    with the provided labels.

  Raises:
    ValueError: If the padded axes are not axes in the input tensor.
  """
  with ops.name_scope(name, 'lt_pad', [labeled_tensor]) as scope:
    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

    axis_names = list(labeled_tensor.axes.keys())
    if not set(paddings.keys()) <= set(axis_names):
      # Report axis *names* (not the full Axes repr), matching the message text
      # and the other ops in this module.
      raise ValueError('pad axes %r are not contained in the set of axis '
                       'names %r on the input labeled tensor' %
                       (list(paddings.keys()), axis_names))

    new_axes = []
    padding_pairs = []
    # Use `axis_name` rather than `name` so the `name` argument is not shadowed.
    for axis_name, axis in labeled_tensor.axes.items():
      if axis_name in paddings:
        padding_before, padding_after = paddings[axis_name]
        axis_before = core.Axis(axis_name, padding_before)
        axis_after = core.Axis(axis_name, padding_after)
        new_axes.append(core.concat_axes([axis_before, axis, axis_after]))
        # array_ops.pad takes the number of entries to add on each side.
        padding_pairs.append((len(axis_before), len(axis_after)))
      else:
        new_axes.append(axis)
        padding_pairs.append((0, 0))

    pad_op = array_ops.pad(labeled_tensor.tensor,
                           padding_pairs,
                           mode,
                           name=scope)

    return core.LabeledTensor(pad_op, new_axes)
  def setUp(self):
    # Axes fixtures; a0 and a1 are deliberately built from identical specs.
    dim_seven = tensor_shape.Dimension(7)
    dim_eight = tensor_shape.Dimension(8)
    rgb_labels = ['red', 'green', 'blue']
    range_labels = range(7)

    self.i_8 = core.Axis('8', dim_eight)

    self.a0 = core.Axes([('d7', dim_seven)])
    self.a1 = core.Axes([('d7', dim_seven)])
    self.a2 = core.Axes([('d7', dim_seven), ('rgb', rgb_labels)])
    self.a3 = core.Axes([('8', dim_eight), ('range', range_labels)])
    def test_bijection_with_labels(self):
        """Encoding then decoding with a labeled extra axis is lossless."""
        # 'channel' x 'mask' is flattened into one 'depth' axis whose size is
        # the product of the two, and a labeled 'other' axis is appended.
        depth_size = len(self.channels) * len(self.masks)
        depth_axis = core.Axis('depth', range(depth_size))
        coder = sugar.ReshapeCoder(['channel', 'mask'],
                                   [depth_axis, ('other', ['label'])])

        encoded = coder.encode(self.masked_image_lt)
        expected_axes = core.Axes([
            self.batch_axis,
            self.row_axis,
            self.column_axis,
            depth_axis,
            ('other', ['label']),
        ])
        self.assertEqual(encoded.axes, expected_axes)

        self.assertLabeledTensorsEqual(coder.decode(encoded),
                                       self.masked_image_lt)
示例#7
0
def rename_axis(labeled_tensor, existing_name, new_name, name=None):
  """Rename an axis of LabeledTensor.

  Args:
    labeled_tensor: The input tensor. It is passed through
      `core.convert_to_labeled_tensor` first, as `pad` and `select` do.
    existing_name: Name for an existing axis on the input.
    new_name: Desired replacement name.
    name: Optional op name.

  Returns:
    LabeledTensor with renamed axis.

  Raises:
    ValueError: If `existing_name` is not an axis on the input.
  """
  with ops.name_scope(name, 'lt_rename_axis', [labeled_tensor]) as scope:
    # Normalize LabeledTensor-like inputs so `.axes` below is always available;
    # previously a non-LabeledTensor input failed with an AttributeError.
    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

    if existing_name not in labeled_tensor.axes:
      raise ValueError('existing_name %r are not contained in the set of axis '
                       'names %r on the input labeled tensor' %
                       (existing_name, labeled_tensor.axes.keys()))
    # Reuse the existing axis' value so only the name changes.
    new_axis = core.Axis(new_name, labeled_tensor.axes[existing_name].value)
    return reshape(labeled_tensor, [existing_name], [new_axis], name=scope)
示例#8
0
def select(labeled_tensor, selection, name=None):
  """Slice out a subset of the tensor by axis labels.

  Args:
    labeled_tensor: The input tensor.
    selection: A dictionary mapping an axis name to a scalar label, a slice of
      labels or a list of labels to select. Two kinds of selection are
      currently supported:
        (a) Any number of scalar and/or slice selections.
        (b) Exactly one list selection, without any scalars or slices.
      Labeled slices include both their start and stop labels (the pandas
      convention); slice steps are not supported.
    name: Optional op name.

  Returns:
    The selection as a `LabeledTensor`.

  Raises:
    ValueError: If the tensor doesn't have an axis in the selection or if
      that axis lacks labels.
    KeyError: If any labels in a selection are not found in the original axis.
    NotImplementedError: If you attempt to combine a list selection with
      scalar selection or another list selection, or use a slice with a step.
    TypeError: If a selection value is not a slice, hashable label or list.
  """
  with ops.name_scope(name, 'lt_select', [labeled_tensor]) as scope:
    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

    # Positional selections (ints / index slices) handled by slice_function.
    slice_specs = {}
    # At most one list selection, handled by a 1-D gather.
    gather_specs = {}

    for axis_name, value in selection.items():
      if axis_name not in labeled_tensor.axes:
        raise ValueError(
            'The tensor does not have an axis named %s. Its axes are: %r' %
            (axis_name, labeled_tensor.axes.keys()))
      axis = labeled_tensor.axes[axis_name]
      if axis.labels is None:
        raise ValueError(
            'The axis named %s does not have labels. The axis is: %r' %
            (axis_name, axis))

      if isinstance(value, slice):
        # TODO(shoyer): consider deprecating using slices in favor of lists
        start = None if value.start is None else axis.index(value.start)
        # The +1 makes the stop label inclusive, following pandas.
        stop = None if value.stop is None else axis.index(value.stop) + 1
        if value.step is not None:
          raise NotImplementedError('slicing with a step is not yet supported')
        slice_specs[axis_name] = slice(start, stop)
      elif isinstance(value, collections_abc.Hashable):
        # Checked only after the slice branch: slice objects may claim to be
        # Hashable even though hashing them fails.
        slice_specs[axis_name] = axis.index(value)
      elif isinstance(value, list):
        if gather_specs:
          raise NotImplementedError(
              'select does not yet support more than one list selection at '
              'the same time')
        positions = [axis.index(label) for label in value]
        gather_specs[axis_name] = ops.convert_to_tensor(
            positions, dtype=dtypes.int64)
      else:
        # Unreachable when the input type checking works as intended.
        raise TypeError('cannot handle arbitrary types')

    if gather_specs and slice_specs:
      raise NotImplementedError(
          'select does not yet support combined scalar and list selection')

    if not gather_specs:
      return core.slice_function(labeled_tensor, slice_specs, name=scope)

    # List selections go through a separate gather path because tf.gather_nd
    # does not support gradients yet; gather_nd could later merge both paths.
    (gather_axis_name, indexer), = gather_specs.items()
    gathered_axis = core.Axis(gather_axis_name, selection[gather_axis_name])
    return _gather_1d_on_axis(labeled_tensor, indexer, gathered_axis,
                              name=scope)
 def test_concat_unknown(self):
   """Concatenating two axes built from None yields an axis equal to either."""
   first, second = core.Axis('rgb', None), core.Axis('rgb', None)
   self.assertEqual(core.concat_axes([first, second]), first)
 def test_concat_different_names(self):
   """Concatenating axes named 'red' and 'green' is expected to fail."""
   mismatched = [core.Axis('red', ['red']), core.Axis('green', ['red'])]
   with self.assertRaises(ValueError):
     core.concat_axes(mismatched)
  def test_concat_single(self):
    """Concatenating a one-element list returns an equal axis."""
    only_axis = core.Axis('rgb', ['red'])
    self.assertEqual(core.concat_axes([only_axis]), only_axis)
 def test_axis_value_input(self):
   """A range, a list and a NumPy array of 0..6 all build the same axis."""
   expected = self.i_range
   equivalent_values = (range(7), list(range(7)), np.arange(7))
   for candidate in equivalent_values:
     self.assertEqual(expected, core.Axis(expected.name, candidate))
 def test_axis_input(self):
   """Rebuilding each fixture from its (name, value) yields an equal Axis."""
   fixtures = (self.i_7, self.i_7p, self.i_rgb, self.i_range, self.i_unknown)
   for fixture in fixtures:
     self.assertEqual(fixture, core.Axis(fixture.name, fixture.value))