def test_identical_shaped_inputs(self):
    """Check `core.align` on two labeled tensors built over the same axes.

    Both aligned outputs are expected to equal their inputs, and the
    returned broadcast axes are expected to equal the shared input axes.
    """
    base_lt = self.original_lt
    shifted_lt = core.LabeledTensor(base_lt.tensor + 1, base_lt.axes)

    aligned_base, aligned_shifted, result_axes = core.align(
        base_lt, shifted_lt)

    self.assertLabeledTensorsEqual(aligned_base, base_lt)
    self.assertLabeledTensorsEqual(aligned_shifted, shifted_lt)
    self.assertEqual(result_axes, base_lt.axes)
def test_axis_order_scope(self):
    """Check the broadcast axis order from `core.align` with and without
    an enclosing `core.axis_order_scope`.
    """
    xz_lt = core.LabeledTensor(array_ops.ones((2, 3)), ['x', 'z'])
    yz_lt = core.LabeledTensor(array_ops.ones((4, 3)), ['y', 'z'])

    def broadcast_axis_names(first, second):
        # Only the third element returned by align (the axes) is inspected.
        _, _, axes = core.align(first, second)
        return list(axes.keys())

    # With no scope active, the expected order depends on argument order.
    self.assertEqual(broadcast_axis_names(xz_lt, yz_lt), ['x', 'y', 'z'])
    self.assertEqual(broadcast_axis_names(yz_lt, xz_lt), ['y', 'x', 'z'])

    # Inside a scope naming all three axes, the scope's order is expected.
    with core.axis_order_scope(['x', 'y', 'z']):
        self.assertEqual(
            broadcast_axis_names(yz_lt, xz_lt), ['x', 'y', 'z'])

    # Inside a scope that omits 'z', both argument orders must raise.
    with core.axis_order_scope(['x', 'y']):
        for first, second in ((xz_lt, yz_lt), (yz_lt, xz_lt)):
            with self.assertRaises(core.AxisOrderError):
                core.align(first, second)
def test_different_inputs(self):
    """Check `core.align` on two labeled tensors with differing axes.

    Each aligned output is compared against a golden tensor built by
    reshaping the input with a size-1 dimension in the position of the
    axis it lacks.
    """
    # The correct axis ordering is ['x', 'channel', 'probs'].
    aligned_x_probs, aligned_channel_probs, result_axes = core.align(
        self.x_probs_lt, self.channel_probs_lt)

    # Golden for the first input: size-1 'channel' dimension in the middle.
    x_probs_reshaped = array_ops.reshape(
        self.x_probs_lt.tensor, [self.x_size, 1, self.probs_size])
    expected_x_probs = core.LabeledTensor(
        x_probs_reshaped, [self.a0, 'channel', self.a3])
    self.assertLabeledTensorsEqual(aligned_x_probs, expected_x_probs)

    # Golden for the second input: size-1 'x' dimension in front.
    channel_probs_reshaped = array_ops.reshape(
        self.channel_probs_lt.tensor,
        [1, self.channel_size, self.probs_size])
    expected_channel_probs = core.LabeledTensor(
        channel_probs_reshaped, ['x', self.a1, self.a3])
    self.assertLabeledTensorsEqual(
        aligned_channel_probs, expected_channel_probs)

    expected_axes = core.Axes([self.a0, self.a1, self.a3])
    self.assertEqual(result_axes, expected_axes)
def test_invalid_input(self):
    """Check that `core.align` raises ValueError when the two inputs have
    an axis 'a' of the same size but with different labels.
    """
    # Labels 0..4 versus 1..5 on the same-named axis.
    lt_0, lt_1 = (
        core.LabeledTensor(array_ops.zeros([5]), [('a', labels)])
        for labels in (range(5), range(1, 6))
    )

    with self.assertRaises(ValueError):
        core.align(lt_0, lt_1)
def test_name(self):
    """Check that the names of the tensors returned by `core.align`
    contain 'lt_align' and a per-output index suffix.
    """
    first_aligned, second_aligned, _ = core.align(
        self.original_lt, self.original_lt)

    expected_suffixes = (('/0', first_aligned), ('/1', second_aligned))
    for suffix, aligned_lt in expected_suffixes:
        self.assertIn('lt_align', aligned_lt.name)
        self.assertIn(suffix, aligned_lt.name)