コード例 #1
0
 def output_axes(axes):
   """Return the axes of the batched output for the given input axes.

   The result always starts with a ('batch', batch_size) axis. When
   enqueue_many is set, the input's own 'batch' axis is removed before the
   new one is prepended.

   Args:
     axes: Axes of an input tensor.

   Returns:
     A core.Axes whose first entry is ('batch', batch_size).

   Raises:
     ValueError: If enqueue_many is set and 'batch' is not the first axis.
   """
   leading_batch = [('batch', batch_size)]
   if not enqueue_many:
     return core.Axes(leading_batch + list(axes.values()))

   batch_is_first = (
       'batch' in axes and list(axes.keys()).index('batch') == 0)
   if not batch_is_first:
     raise ValueError(
         'When enqueue_many is True, input tensors must have an axis '
         'called "batch" as their first dimension, '
         'but axes were %s' % axes)
   remaining_axes = axes.remove('batch')
   return core.Axes(leading_batch + list(remaining_axes.values()))
コード例 #2
0
  def setUp(self):
    """Build the Axis and Axes fixtures stored on the test instance."""
    dim_seven = tensor_shape.Dimension(7)
    dim_eight = tensor_shape.Dimension(8)
    rgb_labels = ['red', 'green', 'blue']
    seven_range = range(7)

    self.i_8 = core.Axis('8', dim_eight)

    # a0 and a1 are deliberately built from the same specification.
    d7_spec = ('d7', dim_seven)
    self.a0 = core.Axes([d7_spec])
    self.a1 = core.Axes([d7_spec])
    self.a2 = core.Axes([d7_spec, ('rgb', rgb_labels)])
    self.a3 = core.Axes([('8', dim_eight), ('range', seven_range)])
コード例 #3
0
    def test_double(self):
        """Crop 'probs' and 'channel' together and check the result axes."""
        crop_sizes = {'probs': 3, 'channel': 2}
        crop_lt = ops.random_crop(self.original_lt, crop_sizes)

        expected_axes = core.Axes(
            [self.a0, ('channel', 2), self.a2_resolved, ('probs', 3)])
        self.assertEqual(expected_axes, crop_lt.axes)
コード例 #4
0
 def test_unknown_dimension(self):
     """Reshape an axis of unknown size; check the axes and the fed values."""
     unknown_size_tensor = array_ops.placeholder(dtypes.float32, [None])
     orig_lt = core.LabeledTensor(unknown_size_tensor, ['x'])
     reshape_lt = ops.reshape(orig_lt, ['x'], ['y', ('z', 1)])

     expected_axes = core.Axes([('y', None), ('z', 1)])
     self.assertEqual(reshape_lt.axes, expected_axes)

     with self.cached_session() as sess:
         feed = {orig_lt.tensor: [1, 2]}
         result = sess.run(reshape_lt, feed_dict=feed)
         np.testing.assert_array_equal(result, [[1], [2]])
コード例 #5
0
def _gather_1d_on_axis(labeled_tensor, indexer, axis, name=None):
  """Gather entries of a labeled tensor along a single axis.

  The tensor is transposed so that `axis` comes first, gathered with
  `indexer`, and then transposed back to the input's axis order.

  Args:
    labeled_tensor: Input labeled tensor; the axis named `axis.name` is the
      one gathered along.
    indexer: Indices passed to array_ops.gather.
    axis: Axis used to label the gathered dimension of the result.
    name: Optional op name.

  Returns:
    The gathered labeled tensor, with axes in the input's key order.
  """
  with ops.name_scope(name, 'lt_take', [labeled_tensor]) as scope:
    # Put the gathered axis first, followed by every other input axis.
    other_axes = labeled_tensor.axes.remove(axis.name)
    axis_first = core.Axes([axis] + list(other_axes.values()))

    axis_leading_lt = core.transpose(labeled_tensor, axis_first.keys())
    gathered = array_ops.gather(axis_leading_lt.tensor, indexer)
    gathered_lt = core.LabeledTensor(gathered, axis_first)

    # Restore the caller's axis order.
    return core.transpose(
        gathered_lt, labeled_tensor.axes.keys(), name=scope)
コード例 #6
0
 def test_typecheck_error_message(self):
   """core.Axes(None) fails type checking with the expected message."""
   range_type_name = range.__name__
   pattern = ('List(Union(labeled_tensor.Axis, Tuple(..., '
              'Union(Union(numpy.ndarray, %s, list, tuple), '
              'Optional(Union(tensorflow.Dimension, int))))))' %
              range_type_name)
   # Escape the literal text, then let the '...' placeholder match anything.
   escaped = re.escape(pattern)
   regexp = escaped.replace(re.escape('...'), '.*')
   expected_message = 'allowed type ' + regexp
   # NOTE(review): assertRaisesRegexp is a deprecated alias of
   # assertRaisesRegex on Python 3 -- confirm the supported Python versions
   # before switching.
   with self.assertRaisesRegexp(tc.Error, expected_message):
     core.Axes(None)
コード例 #7
0
    def test_bijection_flat(self):
        """Encode channel/mask into a 'depth' axis, then decode it back."""
        coder = sugar.ReshapeCoder(['channel', 'mask'], ['depth'])
        encoded_lt = coder.encode(self.masked_image_lt)

        depth_size = len(self.channels) * len(self.masks)
        golden_axes = core.Axes([
            self.batch_axis,
            self.row_axis,
            self.column_axis,
            ('depth', depth_size),
        ])
        self.assertEqual(encoded_lt.axes, golden_axes)

        # Decoding must give back the original labeled tensor.
        decoded_lt = coder.decode(encoded_lt)
        self.assertLabeledTensorsEqual(decoded_lt, self.masked_image_lt)
コード例 #8
0
    def test_bijection_with_labels(self):
        """Round-trip through a ReshapeCoder whose new axes carry labels."""
        depth_size = len(self.channels) * len(self.masks)
        depth_axis = core.Axis('depth', range(depth_size))
        new_axes = [depth_axis, ('other', ['label'])]
        coder = sugar.ReshapeCoder(['channel', 'mask'], new_axes)

        encoded_lt = coder.encode(self.masked_image_lt)
        golden_axes = core.Axes([
            self.batch_axis,
            self.row_axis,
            self.column_axis,
            depth_axis,
            ('other', ['label']),
        ])
        self.assertEqual(encoded_lt.axes, golden_axes)

        # Decoding must give back the original labeled tensor.
        decoded_lt = coder.decode(encoded_lt)
        self.assertLabeledTensorsEqual(decoded_lt, self.masked_image_lt)
コード例 #9
0
  def test_different_inputs(self):
    """Align two tensors over different axes and check all three results."""
    # The correct axis ordering is ['x', 'channel', 'probs'].
    aligned = core.align(self.x_probs_lt, self.channel_probs_lt)
    align_x_probs_lt, align_channel_probs_lt, broadcast_axes = aligned

    # The golden x_probs has a size-1 'channel' axis between its own axes.
    x_probs_tensor = array_ops.reshape(
        self.x_probs_lt.tensor, [self.x_size, 1, self.probs_size])
    x_probs_golden_lt = core.LabeledTensor(
        x_probs_tensor, [self.a0, 'channel', self.a3])
    self.assertLabeledTensorsEqual(align_x_probs_lt, x_probs_golden_lt)

    # The golden channel_probs has a size-1 'x' axis in front.
    channel_probs_tensor = array_ops.reshape(
        self.channel_probs_lt.tensor,
        [1, self.channel_size, self.probs_size])
    channel_probs_golden_lt = core.LabeledTensor(
        channel_probs_tensor, ['x', self.a1, self.a3])
    self.assertLabeledTensorsEqual(align_channel_probs_lt,
                                   channel_probs_golden_lt)

    expected_broadcast_axes = core.Axes([self.a0, self.a1, self.a3])
    self.assertEqual(broadcast_axes, expected_broadcast_axes)
コード例 #10
0
def placeholder(dtype, axes, name=None):
  """Build a labeled tensor backed by a placeholder.

  For example:

    lt.placeholder(tf.float32, ['batch', ('channel', ['r', 'g', 'b'])])

  See tf.compat.v1.placeholder for more details.

  Args:
    dtype: The type of elements in the tensor to be fed.
    axes: Sequence of strings (denoting axes of unknown size) and/or objects
      convertible to lt.Axis, used to label the result.
    name: Optional op name.

  Returns:
    Placeholder labeled tensor.
  """
  with ops.name_scope(name, 'lt_placeholder', []) as scope:
    axis_specs = []
    for axis in axes:
      # A bare string names an axis whose size is left unknown (None).
      if isinstance(axis, string_types):
        axis_specs.append((axis, None))
      else:
        axis_specs.append(axis)
    labeled_axes = core.Axes(axis_specs)

    shape = [labeled_axis.size for labeled_axis in labeled_axes.values()]
    tensor = array_ops.placeholder(dtype, shape, name=scope)
    return core.LabeledTensor(tensor, labeled_axes)
コード例 #11
0
    def test_size1(self):
        """Crop 'probs' down to size 1 and check the resulting axes."""
        crop_lt = ops.random_crop(self.original_lt, {'probs': 1})

        expected_axes = core.Axes(
            [self.a0, self.a1, self.a2_resolved, ('probs', 1)])
        self.assertEqual(expected_axes, crop_lt.axes)
コード例 #12
0
 def test_unknown_dimension(self):
   """expand_dims keeps the unknown-size 'x' axis and adds 'y' of size 1."""
   unknown_size_tensor = array_ops.placeholder(dtypes.float32, [None])
   orig_lt = core.LabeledTensor(unknown_size_tensor, ['x'])
   expand_lt = core.expand_dims(orig_lt, ['x', 'y'])

   expected_axes = core.Axes([('x', None), ('y', 1)])
   self.assertEqual(expand_lt.axes, expected_axes)
コード例 #13
0
 def test_remove(self):
   """remove() drops a named axis and raises KeyError for an unknown name."""
   without_range = self.a3.remove('range')
   expected = core.Axes([self.i_8])
   self.assertEqual(without_range, expected)

   self.assertRaises(KeyError, self.a3.remove, 'foobar')
コード例 #14
0
 def test(self):
     """A placeholder built from axis specs has the given dtype and axes."""
     axis_specs = ['batch', ('x', ['a', 'b'])]
     placeholder_lt = io_ops.placeholder(dtypes.float32, axis_specs)
     self.assertEqual(placeholder_lt.dtype, dtypes.float32)

     expected_axes = core.Axes([('batch', None), ('x', ['a', 'b'])])
     self.assertEqual(placeholder_lt.axes, expected_axes)