Example #1
0
 def output_axes(axes):
   """Compute the axes of a batch built from inputs labeled with `axes`.

   A ('batch', batch_size) axis is always placed first. `enqueue_many` and
   `batch_size` are read from the enclosing scope. When `enqueue_many` is true,
   the inputs must already start with a 'batch' axis; that axis is removed
   before the new batch axis is prepended.

   Args:
     axes: core.Axes of the input tensors.

   Returns:
     core.Axes for the batched output.

   Raises:
     ValueError: if `enqueue_many` is true and 'batch' is not the first axis.
   """
   if not enqueue_many:
     return core.Axes([('batch', batch_size)] + list(axes.values()))

   # De Morgan form of "batch missing, or present but not first".
   # The index lookup only runs when 'batch' is present.
   leads_with_batch = ('batch' in axes and
                       list(axes.keys()).index('batch') == 0)
   if not leads_with_batch:
     raise ValueError(
         'When enqueue_many is True, input tensors must have an axis '
         'called "batch" as their first dimension, '
         'but axes were %s' % axes)

   remaining_axes = axes.remove('batch')
   return core.Axes([('batch', batch_size)] + list(remaining_axes.values()))
Example #2
0
    def setUp(self):
        """Build the dimensions, axis labels and Axes fixtures for the tests."""
        dim7 = tensor_shape.Dimension(7)
        dim8 = tensor_shape.Dimension(8)
        rgb_labels = ['red', 'green', 'blue']
        range_labels = range(7)

        self.i_8 = core.Axis('8', dim8)

        # a0 and a1 are built as separate Axes objects with identical contents.
        d7_entry = ('d7', dim7)
        self.a0 = core.Axes([d7_entry])
        self.a1 = core.Axes([d7_entry])
        self.a2 = core.Axes([d7_entry, ('rgb', rgb_labels)])
        self.a3 = core.Axes([('8', dim8), ('range', range_labels)])
Example #3
0
    def test_double(self):
        """Cropping 'probs' to 3 and 'channel' to 2 yields the expected axes."""
        crop_sizes = {'probs': 3, 'channel': 2}
        cropped = ops.random_crop(self.original_lt, crop_sizes)

        expected_axes = core.Axes(
            [self.a0, ('channel', 2), self.a2_resolved, ('probs', 3)])
        self.assertEqual(expected_axes, cropped.axes)
def placeholder(dtype, axes, name=None):
    """Create a placeholder for a labeled tensor.

    For example:

      lt.placeholder(tf.float32, ['batch', ('channel', ['r', 'g', 'b'])])

    See tf.compat.v1.placeholder for more details.

    Args:
      dtype: The type of elements in the tensor to be fed.
      axes: Sequence of strings (each naming an axis of unknown size) and/or
        objects convertible to lt.Axis, used to label the result.
      name: Optional op name.

    Returns:
      Placeholder labeled tensor.
    """
    with ops.name_scope(name, 'lt_placeholder', []) as scope:
        axis_specs = []
        for spec in axes:
            # A bare string names an axis whose size is left unspecified (None).
            if isinstance(spec, string_types):
                spec = (spec, None)
            axis_specs.append(spec)
        labeled_axes = core.Axes(axis_specs)

        dims = [labeled_axis.size for labeled_axis in labeled_axes.values()]
        placeholder_tensor = array_ops.placeholder(dtype, dims, name=scope)
        return core.LabeledTensor(placeholder_tensor, labeled_axes)
Example #5
0
 def test_unknown_dimension(self):
   """Reshape works on an input whose single axis has an unknown (None) size."""
   input_lt = core.LabeledTensor(tf.placeholder(tf.float32, [None]), ['x'])
   reshaped = ops.reshape(input_lt, ['x'], ['y', ('z', 1)])
   self.assertEqual(reshaped.axes, core.Axes([('y', None), ('z', 1)]))

   with self.test_session() as session:
     feed = {input_lt.tensor: [1, 2]}
     actual = session.run(reshaped, feed_dict=feed)
     np.testing.assert_array_equal(actual, [[1], [2]])
Example #6
0
def _gather_1d_on_axis(labeled_tensor, indexer, axis, name=None):
    """Gather entries of `labeled_tensor` along one axis.

    The axis named `axis.name` is moved to the front, and `array_ops.gather`
    is applied with `indexer` on that leading dimension. The result is then
    transposed back to the original axis order of `labeled_tensor`.

    Args:
      labeled_tensor: Input labeled tensor.
      indexer: Indices passed to `array_ops.gather`.
      axis: Axis that labels the gathered dimension in the result. Its name
        must be an axis of `labeled_tensor`; otherwise `Axes.remove` raises
        KeyError.
      name: Optional op name.

    Returns:
      Labeled tensor with the original axis order, where the `axis.name`
      dimension is labeled by `axis`.
    """
    with ops.name_scope(name, 'lt_take', [labeled_tensor]) as scope:
        original_axes = labeled_tensor.axes
        other_axes = original_axes.remove(axis.name)
        axis_first = core.Axes([axis] + list(other_axes.values()))

        # Put the target axis first so gather indexes along dimension 0.
        moved = core.transpose(labeled_tensor, axis_first.keys())
        gathered = array_ops.gather(moved.tensor, indexer)
        picked = core.LabeledTensor(gathered, axis_first)

        return core.transpose(picked, original_axes.keys(), name=scope)
Example #7
0
 def test_typecheck_error_message(self):
     """core.Axes(None) fails type checking with the expected message."""
     # `range.__name__` is interpolated so the expected text uses whatever
     # name the running interpreter reports for `range`.
     pattern = ('List(Union(labeled_tensor.Axis, Tuple(..., '
                'Union(Union(numpy.ndarray, %s, list, tuple), '
                'Optional(Union(tensorflow.Dimension, int))))))' %
                range.__name__)
     # Match the message literally, except that each '...' matches any text.
     regexp = re.escape(pattern).replace(re.escape('...'), '.*')
     # assertRaisesRegexp is a deprecated alias that was removed in Python 3.12.
     with self.assertRaisesRegex(tc.Error, 'allowed type ' + regexp):
         core.Axes(None)
Example #8
0
  def test_bijection_flat(self):
    """Encoding 'channel' x 'mask' into one 'depth' axis round-trips."""
    coder = sugar.ReshapeCoder(['channel', 'mask'], ['depth'])

    encoded = coder.encode(self.masked_image_lt)
    flat_depth = len(self.channels) * len(self.masks)
    expected_axes = core.Axes([
        self.batch_axis, self.row_axis, self.column_axis,
        ('depth', flat_depth),
    ])
    self.assertEqual(encoded.axes, expected_axes)

    decoded = coder.decode(encoded)
    self.assertLabeledTensorsEqual(decoded, self.masked_image_lt)
Example #9
0
  def test_bijection_with_labels(self):
    """Encoding into labeled 'depth' and 'other' axes round-trips."""
    flat_depth = len(self.channels) * len(self.masks)
    depth_axis = core.Axis('depth', range(flat_depth))
    coder = sugar.ReshapeCoder(
        ['channel', 'mask'], [depth_axis, ('other', ['label'])])

    encoded = coder.encode(self.masked_image_lt)
    expected_axes = core.Axes([
        self.batch_axis, self.row_axis, self.column_axis, depth_axis,
        ('other', ['label']),
    ])
    self.assertEqual(encoded.axes, expected_axes)

    decoded = coder.decode(encoded)
    self.assertLabeledTensorsEqual(decoded, self.masked_image_lt)
Example #10
0
  def test_different_inputs(self):
    """align() brings x/probs and channel/probs inputs to a shared axis order."""
    # The correct axis ordering is ['x', 'channel', 'probs'].
    aligned_x, aligned_channel, shared_axes = core.align(
        self.x_probs_lt, self.channel_probs_lt)

    # The x/probs input gains a size-1 'channel' dimension in the middle.
    x_shape = [self.x_size, 1, self.probs_size]
    expected_x = core.LabeledTensor(
        array_ops.reshape(self.x_probs_lt.tensor, x_shape),
        [self.a0, 'channel', self.a3])
    self.assertLabeledTensorsEqual(aligned_x, expected_x)

    # The channel/probs input gains a size-1 'x' dimension in front.
    channel_shape = [1, self.channel_size, self.probs_size]
    expected_channel = core.LabeledTensor(
        array_ops.reshape(self.channel_probs_lt.tensor, channel_shape),
        ['x', self.a1, self.a3])
    self.assertLabeledTensorsEqual(aligned_channel, expected_channel)

    self.assertEqual(shared_axes, core.Axes([self.a0, self.a1, self.a3]))
Example #11
0
    def test_size1(self):
        """A random crop of 'probs' down to size 1 yields the expected axes."""
        cropped = ops.random_crop(self.original_lt, {'probs': 1})

        expected_axes = core.Axes(
            [self.a0, self.a1, self.a2_resolved, ('probs', 1)])
        self.assertEqual(expected_axes, cropped.axes)
Example #12
0
 def test_unknown_dimension(self):
     """expand_dims keeps an unknown 'x' size and adds a size-1 'y' axis."""
     unknown_size = array_ops.placeholder(dtypes.float32, [None])
     orig_lt = core.LabeledTensor(unknown_size, ['x'])

     expanded = core.expand_dims(orig_lt, ['x', 'y'])
     expected_axes = core.Axes([('x', None), ('y', 1)])
     self.assertEqual(expanded.axes, expected_axes)
Example #13
0
 def test_remove(self):
     """Axes.remove drops the named axis and raises KeyError for unknown names."""
     remaining = self.a3.remove('range')
     self.assertEqual(remaining, core.Axes([self.i_8]))

     # Removing an axis name that is not present must raise.
     self.assertRaises(KeyError, self.a3.remove, 'foobar')
Example #14
0
 def test(self):
     """io_ops.placeholder keeps the dtype and gives string axes size None."""
     axes_spec = ['batch', ('x', ['a', 'b'])]
     placeholder_lt = io_ops.placeholder(dtypes.float32, axes_spec)

     self.assertEqual(placeholder_lt.dtype, dtypes.float32)
     expected_axes = core.Axes([('batch', None), ('x', ['a', 'b'])])
     self.assertEqual(placeholder_lt.axes, expected_axes)