Exemplo n.º 1
0
    def setUp(self):
        """Build the Axis / Axes fixtures shared by the tests in this case."""
        dim_seven = tensor_shape.Dimension(7)
        dim_eight = tensor_shape.Dimension(8)
        rgb_labels = ['red', 'green', 'blue']
        range_labels = range(7)

        self.i_8 = core.Axis('8', dim_eight)

        # a0 and a1 are built separately from the same spec, presumably so
        # tests can compare two distinct-but-equal Axes objects.
        self.a0 = core.Axes([('d7', dim_seven)])
        self.a1 = core.Axes([('d7', dim_seven)])
        self.a2 = core.Axes([('d7', dim_seven), ('rgb', rgb_labels)])
        self.a3 = core.Axes([('8', dim_eight), ('range', range_labels)])
Exemplo n.º 2
0
  def test_bijection_with_labels(self):
    """ReshapeCoder encode/decode round-trips when an extra labeled axis is added."""
    depth_size = len(self.channels) * len(self.masks)
    depth_axis = core.Axis('depth', range(depth_size))
    coder = sugar.ReshapeCoder(['channel', 'mask'],
                               [depth_axis, ('other', ['label'])])

    # The 'channel' and 'mask' axes are replaced by 'depth' plus 'other'.
    encoded = coder.encode(self.masked_image_lt)
    expected_axes = core.Axes([
        self.batch_axis, self.row_axis, self.column_axis, depth_axis,
        ('other', ['label'])
    ])
    self.assertEqual(encoded.axes, expected_axes)

    # Decoding must recover the original labeled tensor exactly.
    decoded = coder.decode(encoded)
    self.assertLabeledTensorsEqual(decoded, self.masked_image_lt)
Exemplo n.º 3
0
def rename_axis(labeled_tensor, existing_name, new_name, name=None):
  """Rename an axis of LabeledTensor.

  The renamed axis keeps the `value` of the original axis; only its name
  changes. The rename is carried out with `reshape`, replacing the single
  axis `existing_name` by a new axis called `new_name`.

  Args:
    labeled_tensor: The input tensor.
    existing_name: Name for an existing axis on the input.
    new_name: Desired replacement name.
    name: Optional op name.

  Returns:
    LabeledTensor with renamed axis.

  Raises:
    ValueError: If `existing_name` is not an axis on the input.
  """
  with ops.name_scope(name, 'lt_rename_axis', [labeled_tensor]) as scope:
    # Normalize the input the same way `select` does before reading `.axes`.
    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
    if existing_name not in labeled_tensor.axes:
      raise ValueError('existing_name %r is not contained in the set of axis '
                       'names %r on the input labeled tensor' %
                       (existing_name, labeled_tensor.axes.keys()))
    old_axis = labeled_tensor.axes[existing_name]
    new_axis = core.Axis(new_name, old_axis.value)
    return reshape(labeled_tensor, [existing_name], [new_axis], name=scope)
Exemplo n.º 4
0
def select(labeled_tensor, selection, name=None):
    """Slice out a subset of the tensor.

    Args:
      labeled_tensor: The input tensor.
      selection: A dictionary mapping an axis name to a scalar, slice or list of
        values to select. Currently supports two types of selections:
          (a) Any number of scalar and/or slice selections.
          (b) Exactly one list selection, without any scalars or slices.
        Slice bounds are axis labels and are inclusive of both ends (the
        pandas convention); slices with a step are not supported.
      name: Optional op name.

    Returns:
      The selection as a `LabeledTensor`.

    Raises:
      ValueError: If the tensor doesn't have an axis in the selection or if
        that axis lacks labels.
      KeyError: If any labels in a selection are not found in the original axis.
      NotImplementedError: If you attempt to combine a list selection with
        scalar selection or another list selection, or use a slice with a step.
      TypeError: If a selection value is not a slice, a hashable label or a
        list.
    """
    # `collections.Hashable` was removed in Python 3.10; the ABC has lived in
    # `collections.abc` since Python 3.3.
    try:
        from collections.abc import Hashable
    except ImportError:  # Python 2 fallback.
        from collections import Hashable

    with ops.name_scope(name, 'lt_select', [labeled_tensor]) as scope:
        labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

        # Scalar and slice selections are collected for `core.slice_function`;
        # a list selection becomes an int64 index tensor for a 1-D gather.
        slices = {}
        indexers = {}
        for axis_name, value in selection.items():
            if axis_name not in labeled_tensor.axes:
                raise ValueError(
                    'The tensor does not have an axis named %s. Its axes are: %r'
                    % (axis_name, labeled_tensor.axes.keys()))
            axis = labeled_tensor.axes[axis_name]
            if axis.labels is None:
                raise ValueError(
                    'The axis named %s does not have labels. The axis is: %r' %
                    (axis_name, axis))

            if isinstance(value, slice):
                # TODO(shoyer): consider deprecating using slices in favor of lists
                # Reject steps before resolving labels, so an unsupported step is
                # reported as such rather than as a failed label lookup.
                if value.step is not None:
                    raise NotImplementedError(
                        'slicing with a step is not yet supported')

                start = None if value.start is None else axis.index(value.start)
                # For now, follow the pandas convention of making labeled slices
                # inclusive of both bounds.
                stop = None if value.stop is None else axis.index(value.stop) + 1

                slices[axis_name] = slice(start, stop)

            # Needs to be after checking for slices: depending on the Python
            # version, slice objects may pass the Hashable check (and on older
            # versions hash() on them fails), but they are not scalar labels.
            elif isinstance(value, Hashable):
                slices[axis_name] = axis.index(value)

            elif isinstance(value, list):
                if indexers:
                    raise NotImplementedError(
                        'select does not yet support more than one list selection at '
                        'the same time')
                indexer = [axis.index(v) for v in value]
                indexers[axis_name] = ops.convert_to_tensor(indexer,
                                                            dtype=dtypes.int64)

            else:
                # If type checking is working properly, this shouldn't be possible.
                raise TypeError('cannot handle arbitrary types')

        if indexers and slices:
            raise NotImplementedError(
                'select does not yet support combined scalar and list selection'
            )

        # For now, handle array selection separately, because tf.gather_nd does
        # not support gradients yet. Later, using gather_nd will let us combine
        # these paths.
        if indexers:
            (axis_name, indexer), = indexers.items()
            # The resulting axis is labeled by the selected values themselves.
            axis = core.Axis(axis_name, selection[axis_name])
            return _gather_1d_on_axis(labeled_tensor,
                                      indexer,
                                      axis,
                                      name=scope)
        return core.slice_function(labeled_tensor, slices, name=scope)
Exemplo n.º 5
0
 def test_concat_unknown(self):
     """Concatenating two same-named axes whose value is None yields the first."""
     first = core.Axis('rgb', None)
     second = core.Axis('rgb', None)
     joined = core.concat_axes([first, second])
     self.assertEqual(joined, first)
Exemplo n.º 6
0
 def test_concat_different_names(self):
     """concat_axes raises ValueError when the axes do not share a name."""
     axes_to_join = [core.Axis('red', ['red']), core.Axis('green', ['red'])]
     with self.assertRaises(ValueError):
         core.concat_axes(axes_to_join)
Exemplo n.º 7
0
    def test_concat_single(self):
        """Concatenating a one-element list of axes returns an equal axis."""
        only_axis = core.Axis('rgb', ['red'])
        result = core.concat_axes([only_axis])
        self.assertEqual(result, only_axis)
Exemplo n.º 8
0
 def test_axis_value_input(self):
     """self.i_range equals an Axis rebuilt from several equivalent value forms."""
     expected = self.i_range
     candidates = (range(7), list(range(7)), np.arange(7))
     for candidate in candidates:
         rebuilt = core.Axis(expected.name, candidate)
         self.assertEqual(expected, rebuilt)
Exemplo n.º 9
0
 def test_axis_input(self):
     """Each fixture axis equals an Axis rebuilt from its own name and value."""
     fixtures = (self.i_7, self.i_7p, self.i_rgb, self.i_range, self.i_unknown)
     for fixture in fixtures:
         rebuilt = core.Axis(fixture.name, fixture.value)
         self.assertEqual(fixture, rebuilt)