Пример #1
0
 def test_invalid_input(self):
   """Invalid Axis labels and non-Axis concat inputs must be rejected."""
   # A dict used as a label value is expected to raise TypeError
   # (presumably because labels must be hashable -- confirm in core.Axis).
   self.assertRaises(TypeError, core.Axis, 'foo', [{}])
   # A repeated label (1 appears twice) is expected to raise ValueError.
   self.assertRaises(ValueError, core.Axis, 'foo', [1, 2, 3, 1])
   # Mixing an Axis with a plain int in concat_axes is expected to raise
   # tc.Error.
   red_axis = core.Axis('foo', ['red'])
   self.assertRaises(tc.Error, core.concat_axes, [red_axis, 1])
Пример #2
0
 def test_invalid_input(self):
     """Invalid Axis labels and non-Axis concat inputs must be rejected."""
     # Each case pairs the exception core.Axis is expected to raise with the
     # offending label list: a dict value, then a repeated label.
     bad_label_cases = ((TypeError, [{}]), (ValueError, [1, 2, 3, 1]))
     for expected_error, labels in bad_label_cases:
         with self.assertRaises(expected_error):
             core.Axis('foo', labels)
     # concat_axes must refuse a list that mixes an Axis with a plain int.
     red_axis = core.Axis('foo', ['red'])
     with self.assertRaises(tc.Error):
         core.concat_axes([red_axis, 1])
Пример #3
0
    def test_concat_many(self):
        """Concatenating single-label axes in order yields one combined axis."""
        colors = ('red', 'green', 'blue')
        single_label_axes = [core.Axis('rgb', [color]) for color in colors]
        expected = core.Axis('rgb', list(colors))

        actual = core.concat_axes(single_label_axes)
        self.assertEqual(actual, expected)
Пример #4
0
def concat(labeled_tensors, axis_name, name=None):
  """Concatenates LabeledTensors along a named axis.

  This is the labeled counterpart of tf.concat.

  Args:
    labeled_tensors: A list of input LabeledTensors (each element is passed
      through core.convert_to_labeled_tensor first).
    axis_name: The name of the axis along which to concatenate.
    name: Optional op name.

  Returns:
    A LabeledTensor holding the concatenated data. Its `axis_name` axis is
    the result of core.concat_axes over that axis of every input. All other
    axes are copied from the first input.

  Raises:
    ValueError: If no tensors are provided, if `axis_name` is not an axis of
      the first tensor, or if any later tensor's axes (with `axis_name`
      removed) differ from the first tensor's. Errors raised by
      core.concat_axes propagate unchanged.
  """
  with ops.name_scope(name, 'lt_concat', labeled_tensors) as scope:
    inputs = [core.convert_to_labeled_tensor(item) for item in labeled_tensors]

    if not inputs:
      raise ValueError('concat expects at least 1 tensor, but received %s' %
                       inputs)

    # The first input defines the axis layout every other input must match.
    first_axes = inputs[0].axes
    axis_names = list(first_axes.keys())
    if axis_name not in axis_names:
      raise ValueError('%s not in %s' % (axis_name, axis_names))

    # Every input must agree with the first on all axes except the one being
    # concatenated.
    expected_shared = first_axes.remove(axis_name)
    for other in inputs[1:]:
      other_shared = other.axes.remove(axis_name)
      if other_shared != expected_shared:
        # TODO(shoyer): add more specific checks about what went wrong,
        # including raising AxisOrderError when appropriate
        raise ValueError('Mismatched shared axes: the first tensor '
                         'had axes %r but this tensor has axes %r.' %
                         (expected_shared, other_shared))

    tensors = [item.tensor for item in inputs]
    concat_axis = core.concat_axes([item.axes[axis_name] for item in inputs])
    concat_dimension = axis_names.index(axis_name)
    concat_tensor = array_ops.concat(tensors, concat_dimension, name=scope)

    # Reuse the first input's axes, swapping in the concatenated axis.
    output_axes = list(first_axes.values())
    output_axes[concat_dimension] = concat_axis
    return core.LabeledTensor(concat_tensor, output_axes)
Пример #5
0
  def test_concat_many(self):
    """Concatenating single-label axes in order yields one combined axis."""
    colors = ('red', 'green', 'blue')
    single_label_axes = [core.Axis('rgb', [color]) for color in colors]
    expected = core.Axis('rgb', list(colors))

    actual = core.concat_axes(single_label_axes)
    self.assertEqual(actual, expected)
Пример #6
0
def concat(labeled_tensors, axis_name, name=None):
  """Concatenates LabeledTensors along a named axis.

  This is the labeled counterpart of tf.concat.

  Args:
    labeled_tensors: A list of input LabeledTensors (each element is passed
      through core.convert_to_labeled_tensor first).
    axis_name: The name of the axis along which to concatenate.
    name: Optional op name.

  Returns:
    A LabeledTensor holding the concatenated data. Its `axis_name` axis is
    the result of core.concat_axes over that axis of every input. All other
    axes are copied from the first input.

  Raises:
    ValueError: If no tensors are provided, if `axis_name` is not an axis of
      the first tensor, or if any later tensor's axes (with `axis_name`
      removed) differ from the first tensor's. Errors raised by
      core.concat_axes propagate unchanged.
  """
  with ops.name_scope(name, 'lt_concat', labeled_tensors) as scope:
    inputs = [core.convert_to_labeled_tensor(item) for item in labeled_tensors]

    if not inputs:
      raise ValueError('concat expects at least 1 tensor, but received %s' %
                       inputs)

    # The first input defines the axis layout every other input must match.
    first_axes = inputs[0].axes
    axis_names = list(first_axes.keys())
    if axis_name not in axis_names:
      raise ValueError('%s not in %s' % (axis_name, axis_names))

    # Every input must agree with the first on all axes except the one being
    # concatenated.
    expected_shared = first_axes.remove(axis_name)
    for other in inputs[1:]:
      other_shared = other.axes.remove(axis_name)
      if other_shared != expected_shared:
        # TODO(shoyer): add more specific checks about what went wrong,
        # including raising AxisOrderError when appropriate
        raise ValueError('Mismatched shared axes: the first tensor '
                         'had axes %r but this tensor has axes %r.' %
                         (expected_shared, other_shared))

    tensors = [item.tensor for item in inputs]
    concat_axis = core.concat_axes([item.axes[axis_name] for item in inputs])
    concat_dimension = axis_names.index(axis_name)
    concat_tensor = array_ops.concat(tensors, concat_dimension, name=scope)

    # Reuse the first input's axes, swapping in the concatenated axis.
    output_axes = list(first_axes.values())
    output_axes[concat_dimension] = concat_axis
    return core.LabeledTensor(concat_tensor, output_axes)
Пример #7
0
def pad(labeled_tensor, paddings, mode='CONSTANT', name=None):
    """Pads a tensor.

    See tf.pad.

    Args:
      labeled_tensor: The input tensor.
      paddings: A mapping where the keys are axis names and the values are
        tuples where the first element is the padding to insert at the
        beginning of the axis and the second is the padding to insert at the
        end of the axis. Each element is used to build a core.Axis for that
        padding segment (presumably explicit labels or a size -- confirm
        against core.Axis), and the pad amount handed to tf.pad is the
        length of that Axis.
      mode: One of "CONSTANT", "REFLECT", or "SYMMETRIC".
      name: Optional op name.

    Returns:
      A tensor with the indicated axes padded, optionally with those axes
      extended with the provided labels.

    Raises:
      ValueError: If the padded axes are not axes in the input tensor.
    """
    with ops.name_scope(name, 'lt_pad', [labeled_tensor]) as scope:
        labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

        if not set(paddings.keys()) <= set(labeled_tensor.axes.keys()):
            raise ValueError(
                'pad axes %r are not contained in the set of axis '
                'names %r on the input labeled tensor' %
                (paddings.keys(), labeled_tensor.axes))

        new_axes = []
        padding_pairs = []
        # The loop variable is `axis_name`, not `name`, so it does not
        # clobber the `name` (op name) parameter of this function.
        for axis_name, axis in labeled_tensor.axes.items():
            if axis_name in paddings:
                padding_before, padding_after = paddings[axis_name]
                axis_before = core.Axis(axis_name, padding_before)
                axis_after = core.Axis(axis_name, padding_after)
                new_axes.append(
                    core.concat_axes([axis_before, axis, axis_after]))
                padding_pairs.append((len(axis_before), len(axis_after)))
            else:
                # Unpadded axes are kept as-is with zero padding.
                new_axes.append(axis)
                padding_pairs.append((0, 0))

        pad_op = array_ops.pad(labeled_tensor.tensor,
                               padding_pairs,
                               mode,
                               name=scope)

        return core.LabeledTensor(pad_op, new_axes)
Пример #8
0
def pad(labeled_tensor, paddings, mode='CONSTANT', name=None):
  """Pads a tensor.

  See tf.pad.

  Args:
    labeled_tensor: The input tensor.
    paddings: A mapping where the keys are axis names and the values are
      tuples where the first element is the padding to insert at the beginning
      of the axis and the second is the padding to insert at the end of the
      axis. Each element is used to build a core.Axis for that padding segment
      (presumably explicit labels or a size -- confirm against core.Axis), and
      the pad amount handed to tf.pad is the length of that Axis.
    mode: One of "CONSTANT", "REFLECT", or "SYMMETRIC".
    name: Optional op name.

  Returns:
    A tensor with the indicated axes padded, optionally with those axes extended
    with the provided labels.

  Raises:
    ValueError: If the padded axes are not axes in the input tensor.
  """
  with ops.name_scope(name, 'lt_pad', [labeled_tensor]) as scope:
    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

    if not set(paddings.keys()) <= set(labeled_tensor.axes.keys()):
      raise ValueError('pad axes %r are not contained in the set of axis '
                       'names %r on the input labeled tensor' %
                       (paddings.keys(), labeled_tensor.axes))

    new_axes = []
    padding_pairs = []
    # The loop variable is `axis_name`, not `name`, so it does not clobber the
    # `name` (op name) parameter of this function.
    for axis_name, axis in labeled_tensor.axes.items():
      if axis_name in paddings:
        padding_before, padding_after = paddings[axis_name]
        axis_before = core.Axis(axis_name, padding_before)
        axis_after = core.Axis(axis_name, padding_after)
        new_axes.append(core.concat_axes([axis_before, axis, axis_after]))
        padding_pairs.append((len(axis_before), len(axis_after)))
      else:
        # Unpadded axes are kept as-is with zero padding.
        new_axes.append(axis)
        padding_pairs.append((0, 0))

    pad_op = array_ops.pad(labeled_tensor.tensor,
                           padding_pairs,
                           mode,
                           name=scope)

    return core.LabeledTensor(pad_op, new_axes)
Пример #9
0
 def test_concat_unknown(self):
     """Concatenating two label-less axes yields an axis equal to the first."""
     first, second = core.Axis('rgb', None), core.Axis('rgb', None)
     combined = core.concat_axes([first, second])
     self.assertEqual(combined, first)
Пример #10
0
 def test_concat_different_names(self):
     """Axes with different names cannot be concatenated, even with equal labels."""
     mismatched = [core.Axis('red', ['red']), core.Axis('green', ['red'])]
     self.assertRaises(ValueError, core.concat_axes, mismatched)
Пример #11
0
    def test_concat_single(self):
        """Concatenating a one-element list returns an equal axis."""
        only_axis = core.Axis('rgb', ['red'])
        result = core.concat_axes([only_axis])
        self.assertEqual(result, only_axis)
Пример #12
0
 def test_concat_unknown(self):
   """Concatenating two label-less axes yields an axis equal to the first."""
   first, second = core.Axis('rgb', None), core.Axis('rgb', None)
   combined = core.concat_axes([first, second])
   self.assertEqual(combined, first)
Пример #13
0
 def test_concat_different_names(self):
   """Axes with different names cannot be concatenated, even with equal labels."""
   mismatched = [core.Axis('red', ['red']), core.Axis('green', ['red'])]
   self.assertRaises(ValueError, core.concat_axes, mismatched)
Пример #14
0
  def test_concat_single(self):
    """Concatenating a one-element list returns an equal axis."""
    only_axis = core.Axis('rgb', ['red'])
    result = core.concat_axes([only_axis])
    self.assertEqual(result, only_axis)