def setUp(self):
    """Build the shared 4-D labeled tensor fixture.

    Creates a range tensor reshaped to (x, channel, z, probs) =
    (7, 3, 4, 11), wraps it in a ``core.LabeledTensor`` and stores the raw
    tensor, each axis spec and a few derived labeled tensors on ``self``.
    """
    super(Base, self).setUp()

    self.x_size = 7
    self.channel_size = 3
    self.z_size = 4
    self.probs_size = 11

    shape = [self.x_size, self.channel_size, self.z_size, self.probs_size]
    num_elements = (
        self.x_size * self.channel_size * self.z_size * self.probs_size)
    self.tensor = array_ops.reshape(math_ops.range(0, num_elements), shape)

    # 'z' is given by name only; the other axes carry explicit labels.
    self.a0 = ('x', range(self.x_size))
    self.a1 = ('channel', ['red', 'green', 'blue'])
    self.a2 = 'z'
    self.a2_resolved = ('z', self.z_size)
    self.a3 = ('probs', np.linspace(0.0, 1.0, self.probs_size))

    self.original_lt = core.LabeledTensor(
        self.tensor, [self.a0, self.a1, self.a2, self.a3])

    # x_probs_lt: slice with z = 0 first, then select channel 'red'.
    z0_lt = core.slice_function(self.original_lt, {'z': 0})
    self.x_probs_lt = ops.select(z0_lt, {'channel': 'red'})
    self.channel_probs_lt = core.slice_function(
        self.original_lt, {'x': 3, 'z': 0})
def test_slice(self):
    """Check selecting 'channel' with a label slice from 'red' to 'green'.

    The expected result keeps both endpoint labels, i.e. the first two
    entries along the channel dimension of the underlying tensor.
    """
    channel_range = slice('red', 'green')
    select_lt = ops.select(self.original_lt, {'channel': channel_range})

    expected_axes = [self.a0, ('channel', ['red', 'green']), self.a2, self.a3]
    golden_lt = core.LabeledTensor(self.tensor[:, :2, :, :], expected_axes)
    self.assertLabeledTensorsEqual(select_lt, golden_lt)
def test_invalid_input(self):
    """Check that misuse of ``sugar.ReshapeCoder`` raises ``ValueError``.

    Two scenarios are exercised: calling ``decode`` on a freshly built
    coder, and calling ``encode`` on the full masked image followed by
    ``encode`` on its 'red' channel selection.
    """
    # NOTE(review): every statement inside an assertRaises block can satisfy
    # the assertion; presumably only the last call in each block is meant to
    # raise -- confirm before tightening the scope.
    with self.assertRaises(ValueError):
        coder = sugar.ReshapeCoder(['channel', 'mask'], ['depth'])
        coder.decode(self.masked_image_lt)

    with self.assertRaises(ValueError):
        coder = sugar.ReshapeCoder(['channel', 'mask'], ['depth'])
        coder.encode(self.masked_image_lt)
        red_lt = ops.select(self.masked_image_lt, {'channel': 'red'})
        coder.encode(red_lt)
def test_slices(self):
    """Check selecting on two axes at once, each with a slice.

    'x' is selected with ``slice(1, 4)`` and 'channel' with an open-ended
    slice starting at 'green'. The golden value expects x labels
    ``range(1, 5)`` (the stop label is included) and channels
    ['green', 'blue'].
    """
    selection = {'x': slice(1, 4), 'channel': slice('green', None)}
    select_lt = ops.select(self.original_lt, selection)

    x_axis = ('x', range(1, 5))
    channel_axis = ('channel', ['green', 'blue'])
    golden_lt = core.LabeledTensor(
        self.tensor[1:5, 1:, :, :],
        [x_axis, channel_axis, self.a2, self.a3])
    self.assertLabeledTensorsEqual(select_lt, golden_lt)
def test_scalar(self):
    """Check selecting a single 'channel' label given as a bare string.

    The expected result is index 1 along the second dimension of the
    underlying tensor, with 'channel' absent from the axes list.
    """
    select_lt = ops.select(self.original_lt, {'channel': 'green'})

    green_slab = self.tensor[:, 1, :, :]
    remaining_axes = [self.a0, self.a2, self.a3]
    golden_lt = core.LabeledTensor(green_slab, remaining_axes)
    self.assertLabeledTensorsEqual(select_lt, golden_lt)
def test_name(self):
    """Check that the selected tensor's name contains 'lt_select'."""
    select_lt = ops.select(self.original_lt, {'channel': 'green'})
    op_name = select_lt.name
    self.assertIn('lt_select', op_name)
def test(self):
    """Check concatenating the red and green tensors along 'channel'.

    The result is compared against selecting ['red', 'green'] directly
    from the original labeled tensor.
    """
    parts = [self.red_lt, self.green_lt]
    concat_lt = ops.concat(parts, 'channel')

    golden_lt = ops.select(self.original_lt, {'channel': ['red', 'green']})
    self.assertLabeledTensorsEqual(concat_lt, golden_lt)
def setUp(self):
    """Run the parent setUp, then add one single-channel tensor per color."""
    super(ConcatTest, self).setUp()

    def single_channel(color):
        # A one-element list (not a bare label) is passed to select.
        return ops.select(self.original_lt, {'channel': [color]})

    self.red_lt = single_channel('red')
    self.green_lt = single_channel('green')
    self.blue_lt = single_channel('blue')
def test_invalid_input(self):
    """Check the exception type ``ops.select`` raises for bad selections.

    Each case is kept in its own ``assertRaises`` block so that a failure
    points at the specific selection that did not raise.
    """
    lt = self.original_lt

    # Selections expected to be rejected with ValueError.
    with self.assertRaises(ValueError):
        ops.select(lt, {'foo': 1})
    with self.assertRaises(ValueError):
        ops.select(lt, {'z': 1})

    # Labels that are expected to fail lookup with KeyError.
    with self.assertRaises(KeyError):
        ops.select(lt, {'channel': 'purple'})
    with self.assertRaises(KeyError):
        ops.select(lt, {'channel': ['red', 'purple']})

    # Selection forms expected to raise NotImplementedError.
    with self.assertRaises(NotImplementedError):
        ops.select(lt, {'channel': ['red'], 'x': [1]})
    with self.assertRaises(NotImplementedError):
        ops.select(lt, {'channel': ['red'], 'x': 1})
    with self.assertRaises(NotImplementedError):
        ops.select(lt, {'channel': slice('red', 'green', 2)})
def test_tuple(self):
    """Check selecting with a tuple-valued label.

    The 'x' axis is labeled with the tuples (1, 2) and (3, 4); selecting
    (1, 2) is expected to give the scalar 5 with no remaining axes.
    """
    values = constant_op.constant([5, 6])
    x_axis = ('x', [(1, 2), (3, 4)])
    original_lt = core.LabeledTensor(values, [x_axis])

    select_lt = ops.select(original_lt, {'x': (1, 2)})

    golden_lt = core.LabeledTensor(constant_op.constant(5), [])
    self.assertLabeledTensorsEqual(select_lt, golden_lt)
def test_list_zero_items(self):
    """Check selecting 'channel' with an empty list of labels.

    The golden tensor has a zero-length channel dimension and names the
    'channel' axis without giving any labels.
    """
    select_lt = ops.select(self.original_lt, {'channel': []})

    empty_channels = self.tensor[:, :0, :, :]
    golden_axes = [self.a0, 'channel', self.a2, self.a3]
    golden_lt = core.LabeledTensor(empty_channels, golden_axes)
    self.assertLabeledTensorsEqual(select_lt, golden_lt)
def test_list_one_item(self):
    """Check selecting 'channel' with a one-element list of labels.

    The golden tensor keeps a channel dimension of length one, labeled
    ['red'].
    """
    select_lt = ops.select(self.original_lt, {'channel': ['red']})

    first_channel = self.tensor[:, :1, :, :]
    golden_axes = [self.a0, ('channel', ['red']), self.a2, self.a3]
    golden_lt = core.LabeledTensor(first_channel, golden_axes)
    self.assertLabeledTensorsEqual(select_lt, golden_lt)