コード例 #1
0
ファイル: core_test.py プロジェクト: tonydeep/tensorflow
 def test_invalid_input(self):
   """Axis rejects malformed labels; concat_axes rejects non-Axis items."""
   # A dict label is rejected with TypeError (presumably because labels must
   # be hashable -- confirm against core.Axis).
   self.assertRaises(TypeError, core.Axis, 'foo', [{}])
   # The label 1 appears twice here; construction must raise ValueError.
   self.assertRaises(ValueError, core.Axis, 'foo', [1, 2, 3, 1])
   red_axis = core.Axis('foo', ['red'])
   # Mixing an Axis with a plain int must fail with tc.Error.
   self.assertRaises(tc.Error, core.concat_axes, [red_axis, 1])
コード例 #2
0
 def test_invalid_input(self):
     """Axis rejects malformed labels; concat_axes rejects non-Axis items."""
     # A dict label is rejected with TypeError (presumably because labels
     # must be hashable -- confirm against core.Axis).
     self.assertRaises(TypeError, core.Axis, 'foo', [{}])
     # The label 1 appears twice here; construction must raise ValueError.
     self.assertRaises(ValueError, core.Axis, 'foo', [1, 2, 3, 1])
     red_axis = core.Axis('foo', ['red'])
     # Mixing an Axis with a plain int must fail with tc.Error.
     self.assertRaises(tc.Error, core.concat_axes, [red_axis, 1])
コード例 #3
0
    def test_concat_many(self):
        """Concatenating single-label axes yields one axis with all labels."""
        colors = ['red', 'green', 'blue']
        singles = [core.Axis('rgb', [color]) for color in colors]
        expected = core.Axis('rgb', ['red', 'green', 'blue'])

        self.assertEqual(core.concat_axes(singles), expected)
コード例 #4
0
ファイル: ops.py プロジェクト: Immexxx/tensorflow
def concat(labeled_tensors, axis_name, name=None):
  """Concatenates LabeledTensors along the axis called `axis_name`.

  Labeled counterpart of tf.concat: the underlying tensors are joined with
  `array_ops.concat` and the output axes are rebuilt to match.

  Args:
    labeled_tensors: A list of input LabeledTensors; each element is first
      passed through `core.convert_to_labeled_tensor`.
    axis_name: The name of the axis along which to concatenate.
    name: Optional op name.

  Returns:
    A LabeledTensor holding the concatenated values. Every axis other than
    `axis_name` is copied from the first input. The `axis_name` axis is the
    result of `core.concat_axes` over that axis of every input, so its
    coordinate labels are concatenated when they are available for every
    tensor.

  Raises:
    ValueError: If no tensors are given, if `axis_name` is not an axis of the
      first tensor, or if any later tensor's axes (ignoring `axis_name`)
      differ from the first tensor's.
  """
  with ops.name_scope(name, 'lt_concat', labeled_tensors) as scope:
    converted = [core.convert_to_labeled_tensor(t) for t in labeled_tensors]

    if not converted:
      raise ValueError('concat expects at least 1 tensor, but received %s' %
                       converted)

    first_axes = converted[0].axes
    names = list(first_axes.keys())
    if axis_name not in names:
      raise ValueError('%s not in %s' % (axis_name, names))

    # Every input must agree with the first one on all non-concatenated axes.
    reference = first_axes.remove(axis_name)
    for other in converted[1:]:
      # NOTE(review): if `other` lacks `axis_name`, the failure is raised by
      # `Axes.remove`; its exception type is not visible here.
      candidate = other.axes.remove(axis_name)
      if candidate != reference:
        # TODO(shoyer): add more specific checks about what went wrong,
        # including raising AxisOrderError when appropriate
        raise ValueError('Mismatched shared axes: the first tensor '
                         'had axes %r but this tensor has axes %r.' %
                         (reference, candidate))

    # Gather the per-input pieces of the concatenation axis and the raw
    # tensors, in input order.
    pieces = [lt.axes[axis_name] for lt in converted]
    tensors = [lt.tensor for lt in converted]
    joined_axis = core.concat_axes(pieces)

    position = names.index(axis_name)
    joined = array_ops.concat(tensors, position, name=scope)
    # Reuse the first tensor's axes, swapping in the concatenated one.
    out_axes = list(first_axes.values())
    out_axes[position] = joined_axis

    return core.LabeledTensor(joined, out_axes)
コード例 #5
0
ファイル: core_test.py プロジェクト: tonydeep/tensorflow
  def test_concat_many(self):
    """Concatenating single-label axes yields one axis with all labels."""
    colors = ['red', 'green', 'blue']
    singles = [core.Axis('rgb', [color]) for color in colors]
    expected = core.Axis('rgb', ['red', 'green', 'blue'])

    self.assertEqual(core.concat_axes(singles), expected)
コード例 #6
0
ファイル: ops.py プロジェクト: jhabikal21/tensorflow
def concat(labeled_tensors, axis_name, name=None):
  """Concatenates LabeledTensors along the axis called `axis_name`.

  Labeled counterpart of tf.concat: the underlying tensors are joined with
  `array_ops.concat` and the output axes are rebuilt to match.

  Args:
    labeled_tensors: A list of input LabeledTensors; each element is first
      passed through `core.convert_to_labeled_tensor`.
    axis_name: The name of the axis along which to concatenate.
    name: Optional op name.

  Returns:
    A LabeledTensor holding the concatenated values. Every axis other than
    `axis_name` is copied from the first input. The `axis_name` axis is the
    result of `core.concat_axes` over that axis of every input, so its
    coordinate labels are concatenated when they are available for every
    tensor.

  Raises:
    ValueError: If no tensors are given, if `axis_name` is not an axis of the
      first tensor, or if any later tensor's axes (ignoring `axis_name`)
      differ from the first tensor's.
  """
  with ops.name_scope(name, 'lt_concat', labeled_tensors) as scope:
    converted = [core.convert_to_labeled_tensor(t) for t in labeled_tensors]

    if not converted:
      raise ValueError('concat expects at least 1 tensor, but received %s' %
                       converted)

    first_axes = converted[0].axes
    names = list(first_axes.keys())
    if axis_name not in names:
      raise ValueError('%s not in %s' % (axis_name, names))

    # Every input must agree with the first one on all non-concatenated axes.
    reference = first_axes.remove(axis_name)
    for other in converted[1:]:
      # NOTE(review): if `other` lacks `axis_name`, the failure is raised by
      # `Axes.remove`; its exception type is not visible here.
      candidate = other.axes.remove(axis_name)
      if candidate != reference:
        # TODO(shoyer): add more specific checks about what went wrong,
        # including raising AxisOrderError when appropriate
        raise ValueError('Mismatched shared axes: the first tensor '
                         'had axes %r but this tensor has axes %r.' %
                         (reference, candidate))

    # Gather the per-input pieces of the concatenation axis and the raw
    # tensors, in input order.
    pieces = [lt.axes[axis_name] for lt in converted]
    tensors = [lt.tensor for lt in converted]
    joined_axis = core.concat_axes(pieces)

    position = names.index(axis_name)
    joined = array_ops.concat(tensors, position, name=scope)
    # Reuse the first tensor's axes, swapping in the concatenated one.
    out_axes = list(first_axes.values())
    out_axes[position] = joined_axis

    return core.LabeledTensor(joined, out_axes)
コード例 #7
0
def pad(labeled_tensor, paddings, mode='CONSTANT', name=None):
    """Pads a tensor.

    See tf.pad.

    Args:
      labeled_tensor: The input tensor; it is passed through
        `core.convert_to_labeled_tensor` first.
      paddings: A mapping where the keys are axis names and the values are
        tuples where the first element is the padding to insert at the
        beginning of the axis and the second is the padding to insert at the
        end of the axis. Each element is handed to `core.Axis` and the pad
        width used is `len()` of the resulting axis (presumably an int size
        or a sequence of labels -- confirm against `core.Axis`).
      mode: One of "CONSTANT", "REFLECT", or "SYMMETRIC"; forwarded to
        `array_ops.pad`.
      name: Optional op name.

    Returns:
      A tensor with the indicated axes padded, optionally with those axes
      extended with the provided labels. Axes not named in `paddings` are
      passed through unchanged.

    Raises:
      ValueError: If the padded axes are not axes in the input tensor.
    """
    with ops.name_scope(name, 'lt_pad', [labeled_tensor]) as scope:
        labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

        if not set(paddings.keys()) <= set(labeled_tensor.axes.keys()):
            raise ValueError(
                'pad axes %r are not contained in the set of axis '
                'names %r on the input labeled tensor' %
                (paddings.keys(), labeled_tensor.axes))

        new_axes = []
        padding_pairs = []
        # The loop variable is `axis_name`, not `name`, so it does not clobber
        # the `name` parameter (the op name).
        for axis_name, axis in labeled_tensor.axes.items():
            if axis_name in paddings:
                padding_before, padding_after = paddings[axis_name]
                axis_before = core.Axis(axis_name, padding_before)
                axis_after = core.Axis(axis_name, padding_after)
                # Wrap the original axis with the padding axes so the output
                # axis describes the padded dimension.
                new_axes.append(
                    core.concat_axes([axis_before, axis, axis_after]))
                # array_ops.pad takes a (before, after) width per dimension.
                padding_pairs.append((len(axis_before), len(axis_after)))
            else:
                new_axes.append(axis)
                padding_pairs.append((0, 0))

        pad_op = array_ops.pad(labeled_tensor.tensor,
                               padding_pairs,
                               mode,
                               name=scope)

        return core.LabeledTensor(pad_op, new_axes)
コード例 #8
0
ファイル: ops.py プロジェクト: Immexxx/tensorflow
def pad(labeled_tensor, paddings, mode='CONSTANT', name=None):
  """Pads a tensor.

  See tf.pad.

  Args:
    labeled_tensor: The input tensor; it is passed through
      `core.convert_to_labeled_tensor` first.
    paddings: A mapping where the keys are axis names and the values are
      tuples where the first element is the padding to insert at the beginning
      of the axis and the second is the padding to insert at the end of the
      axis. Each element is handed to `core.Axis` and the pad width used is
      `len()` of the resulting axis (presumably an int size or a sequence of
      labels -- confirm against `core.Axis`).
    mode: One of "CONSTANT", "REFLECT", or "SYMMETRIC"; forwarded to
      `array_ops.pad`.
    name: Optional op name.

  Returns:
    A tensor with the indicated axes padded, optionally with those axes extended
    with the provided labels. Axes not named in `paddings` are passed through
    unchanged.

  Raises:
    ValueError: If the padded axes are not axes in the input tensor.
  """
  with ops.name_scope(name, 'lt_pad', [labeled_tensor]) as scope:
    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

    if not set(paddings.keys()) <= set(labeled_tensor.axes.keys()):
      raise ValueError('pad axes %r are not contained in the set of axis '
                       'names %r on the input labeled tensor' %
                       (paddings.keys(), labeled_tensor.axes))

    new_axes = []
    padding_pairs = []
    # The loop variable is `axis_name`, not `name`, so it does not clobber the
    # `name` parameter (the op name).
    for axis_name, axis in labeled_tensor.axes.items():
      if axis_name in paddings:
        padding_before, padding_after = paddings[axis_name]
        axis_before = core.Axis(axis_name, padding_before)
        axis_after = core.Axis(axis_name, padding_after)
        # Wrap the original axis with the padding axes so the output axis
        # describes the padded dimension.
        new_axes.append(core.concat_axes([axis_before, axis, axis_after]))
        # array_ops.pad takes a (before, after) width per dimension.
        padding_pairs.append((len(axis_before), len(axis_after)))
      else:
        new_axes.append(axis)
        padding_pairs.append((0, 0))

    pad_op = array_ops.pad(labeled_tensor.tensor,
                           padding_pairs,
                           mode,
                           name=scope)

    return core.LabeledTensor(pad_op, new_axes)
コード例 #9
0
 def test_concat_unknown(self):
     """Concatenating two `None`-labelled axes gives back an equal axis."""
     first, second = core.Axis('rgb', None), core.Axis('rgb', None)
     self.assertEqual(core.concat_axes([first, second]), first)
コード例 #10
0
 def test_concat_different_names(self):
     """concat_axes raises ValueError when the axis names disagree."""
     mismatched = [core.Axis('red', ['red']), core.Axis('green', ['red'])]
     self.assertRaises(ValueError, core.concat_axes, mismatched)
コード例 #11
0
    def test_concat_single(self):
        """Concatenating a lone axis returns an equal axis."""
        only_axis = core.Axis('rgb', ['red'])
        self.assertEqual(core.concat_axes([only_axis]), only_axis)
コード例 #12
0
ファイル: core_test.py プロジェクト: tonydeep/tensorflow
 def test_concat_unknown(self):
   """Concatenating two `None`-labelled axes gives back an equal axis."""
   first, second = core.Axis('rgb', None), core.Axis('rgb', None)
   self.assertEqual(core.concat_axes([first, second]), first)
コード例 #13
0
ファイル: core_test.py プロジェクト: tonydeep/tensorflow
 def test_concat_different_names(self):
   """concat_axes raises ValueError when the axis names disagree."""
   mismatched = [core.Axis('red', ['red']), core.Axis('green', ['red'])]
   self.assertRaises(ValueError, core.concat_axes, mismatched)
コード例 #14
0
ファイル: core_test.py プロジェクト: tonydeep/tensorflow
  def test_concat_single(self):
    """Concatenating a lone axis returns an equal axis."""
    only_axis = core.Axis('rgb', ['red'])
    self.assertEqual(core.concat_axes([only_axis]), only_axis)