def setUp(self):
        """Build the shared 4-D labeled fixture and a few derived slices."""
        super(Base, self).setUp()

        # Extent of each of the four axes of the fixture tensor.
        self.x_size = 7
        self.channel_size = 3
        self.z_size = 4
        self.probs_size = 11

        # Plain Python int product (not numpy) so range() dtype inference is
        # unaffected.
        total = self.x_size * self.channel_size * self.z_size * self.probs_size
        shape = [self.x_size, self.channel_size, self.z_size, self.probs_size]
        # Consecutive integers 0..total-1 in `shape`, so every element is
        # distinct and selections can be checked against positional indexing.
        self.tensor = array_ops.reshape(math_ops.range(0, total), shape)

        # Axis specs: (name, labels) tuples, or a bare name for 'z', which is
        # given no labels (its resolved form is recorded in a2_resolved).
        self.a0 = ('x', range(self.x_size))
        self.a1 = ('channel', ['red', 'green', 'blue'])
        self.a2 = 'z'
        self.a2_resolved = ('z', self.z_size)
        self.a3 = ('probs', np.linspace(0.0, 1.0, self.probs_size))

        self.original_lt = core.LabeledTensor(
            self.tensor, [self.a0, self.a1, self.a2, self.a3])

        # z fixed at 0, then channel fixed at 'red'; named for the axes that
        # presumably remain (x, probs).
        z_zero_lt = core.slice_function(self.original_lt, {'z': 0})
        self.x_probs_lt = ops.select(z_zero_lt, {'channel': 'red'})
        # x fixed at 3 and z fixed at 0; named for the axes that presumably
        # remain (channel, probs).
        self.channel_probs_lt = core.slice_function(self.original_lt,
                                                    {'x': 3, 'z': 0})
 def test_slice(self):
     """A label slice 'red':'green' keeps both endpoint channels."""
     actual_lt = ops.select(self.original_lt,
                            {'channel': slice('red', 'green')})
     expected_channel = ('channel', ['red', 'green'])
     expected_axes = [self.a0, expected_channel, self.a2, self.a3]
     expected_lt = core.LabeledTensor(self.tensor[:, :2, :, :], expected_axes)
     self.assertLabeledTensorsEqual(actual_lt, expected_lt)
 def test_invalid_input(self):
     """ReshapeCoder misuse is expected to raise ValueError."""
     # decode() on masked_image_lt is expected to fail (presumably because
     # nothing has been encoded with this coder yet -- confirm).
     with self.assertRaises(ValueError):
         coder = sugar.ReshapeCoder(['channel', 'mask'], ['depth'])
         coder.decode(self.masked_image_lt)
     # NOTE(review): both encode() calls share one assertRaises, so this passes
     # if either raises. Presumably the first encode succeeds and the second
     # (channel axis removed by a scalar select) is the one meant to fail --
     # confirm, and consider moving the first call outside the block.
     with self.assertRaises(ValueError):
         coder = sugar.ReshapeCoder(['channel', 'mask'], ['depth'])
         coder.encode(self.masked_image_lt)
         coder.encode(ops.select(self.masked_image_lt, {'channel': 'red'}))
    def test_slices(self):
        """Slicing two axes at once; label slices include the stop label."""
        selection = {
            'x': slice(1, 4),
            'channel': slice('green', None),
        }
        select_lt = ops.select(self.original_lt, selection)

        # x labels 1..4 inclusive correspond to positional indices 1:5.
        x_sliced = ('x', range(1, 5))
        # From 'green' through the last channel.
        channel_sliced = ('channel', ['green', 'blue'])
        expected_axes = [x_sliced, channel_sliced, self.a2, self.a3]
        expected_lt = core.LabeledTensor(self.tensor[1:5, 1:, :, :],
                                         expected_axes)
        self.assertLabeledTensorsEqual(select_lt, expected_lt)
 def test_scalar(self):
     """Selecting a single channel label drops the channel axis."""
     green_lt = ops.select(self.original_lt, {'channel': 'green'})
     remaining_axes = [self.a0, self.a2, self.a3]
     expected_lt = core.LabeledTensor(self.tensor[:, 1, :, :], remaining_axes)
     self.assertLabeledTensorsEqual(green_lt, expected_lt)
 def test_name(self):
     """The result of select carries 'lt_select' in its name."""
     result_lt = ops.select(self.original_lt, {'channel': 'green'})
     self.assertIn('lt_select', result_lt.name)
    def test(self):
        """Concatenating red and green along 'channel' equals selecting both."""
        pieces = [self.red_lt, self.green_lt]
        concat_lt = ops.concat(pieces, 'channel')
        expected_lt = ops.select(self.original_lt,
                                 {'channel': ['red', 'green']})
        self.assertLabeledTensorsEqual(concat_lt, expected_lt)
    def setUp(self):
        """Split the fixture into one single-channel piece per colour."""
        super(ConcatTest, self).setUp()

        # Selecting with a one-element list (not a scalar) keeps the channel
        # axis, so each piece can be concatenated back along 'channel'.
        self.red_lt, self.green_lt, self.blue_lt = [
            ops.select(self.original_lt, {'channel': [colour]})
            for colour in ('red', 'green', 'blue')
        ]
 def test_invalid_input(self):
     """Invalid selections raise the expected exception types, in order."""
     bad_cases = [
         # Axis name that does not exist.
         (ValueError, {'foo': 1}),
         # 'z' is declared without labels (self.a2 == 'z').
         (ValueError, {'z': 1}),
         # Unknown channel label, alone and inside a list.
         (KeyError, {'channel': 'purple'}),
         (KeyError, {'channel': ['red', 'purple']}),
         # A list selection combined with another axis is not supported.
         (NotImplementedError, {'channel': ['red'], 'x': [1]}),
         (NotImplementedError, {'channel': ['red'], 'x': 1}),
         # Slices with a step are not supported.
         (NotImplementedError, {'channel': slice('red', 'green', 2)}),
     ]
     for expected_error, selection in bad_cases:
         with self.assertRaises(expected_error):
             ops.select(self.original_lt, selection)
 def test_tuple(self):
     """A tuple-valued label is selected by passing the tuple itself."""
     tuple_labels = [(1, 2), (3, 4)]
     original_lt = core.LabeledTensor(constant_op.constant([5, 6]),
                                      [('x', tuple_labels)])
     select_lt = ops.select(original_lt, {'x': (1, 2)})
     # (1, 2) labels the first element, so the result is scalar 5, no axes.
     expected_lt = core.LabeledTensor(constant_op.constant(5), [])
     self.assertLabeledTensorsEqual(select_lt, expected_lt)
 def test_list_zero_items(self):
     """An empty label list keeps an unlabeled, zero-length channel axis."""
     empty_lt = ops.select(self.original_lt, {'channel': []})
     expected_axes = [self.a0, 'channel', self.a2, self.a3]
     expected_lt = core.LabeledTensor(self.tensor[:, :0, :, :], expected_axes)
     self.assertLabeledTensorsEqual(empty_lt, expected_lt)
 def test_list_one_item(self):
     """A one-element label list keeps the channel axis with length one."""
     red_lt = ops.select(self.original_lt, {'channel': ['red']})
     expected_axes = [self.a0, ('channel', ['red']), self.a2, self.a3]
     expected_lt = core.LabeledTensor(self.tensor[:, :1, :, :], expected_axes)
     self.assertLabeledTensorsEqual(red_lt, expected_lt)