예제 #1
0
 def test_unknown_size(self):
   """Collapsing channel/z/probs into a bare-named axis matches a -1 reshape."""
   collapsed_names = ['channel', 'z', 'probs']
   actual_lt = ops.reshape(self.original_lt, collapsed_names, ['new_dim'])
   expected_tensor = array_ops.reshape(self.original_lt.tensor,
                                       [self.x_size, -1])
   expected_axes = [self.original_lt.axes['x'], 'new_dim']
   expected_lt = core.LabeledTensor(expected_tensor, expected_axes)
   self.assertLabeledTensorsEqual(actual_lt, expected_lt)
예제 #2
0
 def test_unknown_size(self):
     """Collapsing channel/z/probs into a bare-named axis matches a -1 reshape."""
     collapsed_names = ['channel', 'z', 'probs']
     actual_lt = ops.reshape(self.original_lt, collapsed_names, ['new_dim'])
     expected_tensor = array_ops.reshape(self.original_lt.tensor,
                                         [self.x_size, -1])
     expected_axes = [self.original_lt.axes['x'], 'new_dim']
     expected_lt = core.LabeledTensor(expected_tensor, expected_axes)
     self.assertLabeledTensorsEqual(actual_lt, expected_lt)
예제 #3
0
  def decode(self, labeled_tensor):
    """Undo `encode`, reshaping the collapsed axes back to the originals.

    `encode` must have run at least once before this method, because `encode`
    is what records the original axes (in `self._existing_axes`) that this
    method reshapes back to.

    Args:
      labeled_tensor: The tensor to reshape; it is passed through
        `core.convert_to_labeled_tensor` first.

    Returns:
      `labeled_tensor` reshaped back to the original shape.

    Raises:
      ValueError: If `encode` has not been called yet.
    """
    original_axes = self._existing_axes
    if original_axes is None:
      raise ValueError('decode called before encode')

    with tf_ops.name_scope(self._name, 'lt_reshape_decode',
                           [labeled_tensor]) as scope:
      labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

      def _axis_name(axis):
        # self._new_axes may mix plain string names with axis-like values.
        if isinstance(axis, string_types):
          return axis
        return core.as_axis(axis).name

      collapsed_names = [_axis_name(axis) for axis in self._new_axes]
      return ops.reshape(labeled_tensor, collapsed_names, original_axes,
                         name=scope)
예제 #4
0
  def encode(self, labeled_tensor):
    """Collapse the axes named in existing_axis_names into the new axes.

    The axes found under `existing_axis_names` are remembered in
    `self._existing_axes`. Every later call must see identical axes there.

    Args:
      labeled_tensor: The tensor to reshape; it is passed through
        `core.convert_to_labeled_tensor` first.

    Returns:
      `labeled_tensor` reshaped to the target shape.

    Raises:
      ValueError: If the axes named in existing_axis_names differ from the
        axes recorded by a previous call.
    """
    with tf_ops.name_scope(self._name, 'lt_reshape_encode',
                           [labeled_tensor]) as scope:
      labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

      # The reshape is built before the axis lookup below. Keep this order so
      # ops.reshape's own validation of the axis names runs first.
      encoded = ops.reshape(labeled_tensor, self._existing_axis_names,
                            self._new_axes, name=scope)

      current_axes = [labeled_tensor.axes[name]
                      for name in self._existing_axis_names]
      previous_axes = self._existing_axes
      if previous_axes is not None and previous_axes != current_axes:
        raise ValueError(
            'input axes %r do not match axes from previous method call %r' %
            (current_axes, previous_axes))
      self._existing_axes = current_axes

      return encoded
예제 #5
0
    def decode(self, labeled_tensor):
        """Undo `encode`, reshaping the collapsed axes back to the originals.

        `encode` must have run at least once before this method, because
        `encode` is what records the original axes (in `self._existing_axes`)
        that this method reshapes back to.

        Args:
          labeled_tensor: The tensor to reshape; it is passed through
            `core.convert_to_labeled_tensor` first.

        Returns:
          `labeled_tensor` reshaped back to the original shape.

        Raises:
          ValueError: If `encode` has not been called yet.
        """
        original_axes = self._existing_axes
        if original_axes is None:
            raise ValueError('decode called before encode')

        with tf_ops.name_scope(self._name, 'lt_reshape_decode',
                               [labeled_tensor]) as scope:
            labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

            # self._new_axes may mix plain string names with axis-like values;
            # reduce each entry to its name.
            collapsed_names = []
            for axis in self._new_axes:
                if isinstance(axis, string_types):
                    collapsed_names.append(axis)
                else:
                    collapsed_names.append(core.as_axis(axis).name)

            return ops.reshape(labeled_tensor,
                               collapsed_names,
                               original_axes,
                               name=scope)
예제 #6
0
    def encode(self, labeled_tensor):
        """Collapse the axes named in existing_axis_names into the new axes.

        The axes found under `existing_axis_names` are remembered in
        `self._existing_axes`. Every later call must see identical axes
        there.

        Args:
          labeled_tensor: The tensor to reshape; it is passed through
            `core.convert_to_labeled_tensor` first.

        Returns:
          `labeled_tensor` reshaped to the target shape.

        Raises:
          ValueError: If the axes named in existing_axis_names differ from the
            axes recorded by a previous call.
        """
        with tf_ops.name_scope(self._name, 'lt_reshape_encode',
                               [labeled_tensor]) as scope:
            labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

            # The reshape is built before the axis lookup below. Keep this
            # order so ops.reshape's own validation of the axis names runs
            # first.
            encoded = ops.reshape(labeled_tensor,
                                  self._existing_axis_names,
                                  self._new_axes,
                                  name=scope)

            current_axes = []
            for name in self._existing_axis_names:
                current_axes.append(labeled_tensor.axes[name])

            previous_axes = self._existing_axes
            mismatch = (previous_axes is not None and
                        previous_axes != current_axes)
            if mismatch:
                raise ValueError(
                    'input axes %r do not match axes from previous method call %r'
                    % (current_axes, previous_axes))
            self._existing_axes = current_axes

            return encoded
예제 #7
0
 def test_unknown_dimension(self):
   """Splitting an axis of unknown size gives a new axis of unknown size."""
   placeholder = tf.placeholder(tf.float32, [None])
   orig_lt = core.LabeledTensor(placeholder, ['x'])
   reshape_lt = ops.reshape(orig_lt, ['x'], ['y', ('z', 1)])
   expected_axes = core.Axes([('y', None), ('z', 1)])
   self.assertEqual(reshape_lt.axes, expected_axes)
   with self.test_session() as sess:
     feed = {orig_lt.tensor: [1, 2]}
     result = sess.run(reshape_lt, feed_dict=feed)
     np.testing.assert_array_equal(result, [[1], [2]])
예제 #8
0
 def test_with_labels(self):
     """Collapsing into a labeled new axis keeps the supplied labels."""
     flat_size = self.channel_size * self.z_size * self.probs_size
     reshape_lt = ops.reshape(self.original_lt, ['channel', 'z', 'probs'],
                              [('new_dim', range(flat_size))])
     expected_tensor = array_ops.reshape(self.original_lt.tensor,
                                         [self.x_size, -1])
     expected_axes = [self.original_lt.axes['x'],
                      ('new_dim', range(flat_size))]
     expected_lt = core.LabeledTensor(expected_tensor, expected_axes)
     self.assertLabeledTensorsEqual(reshape_lt, expected_lt)
예제 #9
0
 def test_with_labels(self):
   """Collapsing into a labeled new axis keeps the supplied labels."""
   flat_size = self.channel_size * self.z_size * self.probs_size
   reshape_lt = ops.reshape(self.original_lt, ['channel', 'z', 'probs'],
                            [('new_dim', range(flat_size))])
   expected_tensor = array_ops.reshape(self.original_lt.tensor,
                                       [self.x_size, -1])
   expected_axes = [self.original_lt.axes['x'],
                    ('new_dim', range(flat_size))]
   expected_lt = core.LabeledTensor(expected_tensor, expected_axes)
   self.assertLabeledTensorsEqual(reshape_lt, expected_lt)
예제 #10
0
 def test_known_size(self):
   """Giving the new axis an explicit size matches a -1 reshape."""
   flat_size = self.channel_size * self.z_size * self.probs_size
   new_axes = [('new_dim', flat_size)]
   reshape_lt = ops.reshape(self.original_lt, ['channel', 'z', 'probs'],
                            new_axes)
   expected_tensor = tf.reshape(self.original_lt.tensor, [self.x_size, -1])
   expected_lt = core.LabeledTensor(
       expected_tensor, [self.original_lt.axes['x'], 'new_dim'])
   self.assertLabeledTensorsEqual(reshape_lt, expected_lt)
예제 #11
0
 def test_invalid_input(self):
   """Check the errors ops.reshape raises for invalid axis arguments."""
   cases = [
       (ValueError, 'not contained in the set', ['foo'], ['bar']),
       (core.AxisOrderError, 'not a slice of axis names', ['probs', 'z'],
        ['bar']),
       (ValueError, 'at most one axis in new_axes', ['probs'],
        ['foo', 'bar']),
   ]
   for error_type, pattern, existing_names, new_axes in cases:
     with self.assertRaisesRegexp(error_type, pattern):
       ops.reshape(self.original_lt, existing_names, new_axes)
예제 #12
0
 def test_invalid_input(self):
   """Check the errors ops.reshape raises for invalid axis arguments."""
   expectations = (
       ((ValueError, 'not contained in the set'), (['foo'], ['bar'])),
       ((core.AxisOrderError, 'not a slice of axis names'),
        (['probs', 'z'], ['bar'])),
       ((ValueError, 'at most one axis in new_axes'),
        (['probs'], ['foo', 'bar'])),
   )
   for (error_type, pattern), reshape_args in expectations:
     with self.assertRaisesRegexp(error_type, pattern):
       ops.reshape(self.original_lt, *reshape_args)
예제 #13
0
 def test_identity(self):
     """Reshaping every axis into itself leaves the tensor unchanged."""
     axes = self.original_lt.axes
     reshape_lt = ops.reshape(self.original_lt, axes.keys(), axes.values())
     self.assertLabeledTensorsEqual(reshape_lt, self.original_lt)
예제 #14
0
 def test_name(self):
     """Without an explicit name, the result's name contains 'lt_reshape'."""
     result = ops.reshape(self.original_lt, ['channel'], ['foo'])
     op_name = result.name
     self.assertIn('lt_reshape', op_name)
예제 #15
0
 def test_identity(self):
   """Reshaping every axis into itself leaves the tensor unchanged."""
   axes = self.original_lt.axes
   reshape_lt = ops.reshape(self.original_lt, axes.keys(), axes.values())
   self.assertLabeledTensorsEqual(reshape_lt, self.original_lt)
예제 #16
0
 def test_name(self):
   """Without an explicit name, the result's name contains 'lt_reshape'."""
   result = ops.reshape(self.original_lt, ['channel'], ['foo'])
   op_name = result.name
   self.assertIn('lt_reshape', op_name)