def test_identical_shaped_inputs(self):
    """Aligning two tensors that share identical axes changes nothing.

    Both inputs must come back equal to themselves, and the reported
    broadcast axes must be exactly the shared original axes.
    """
    base_lt = self.original_lt
    shifted_lt = core.LabeledTensor(base_lt.tensor + 1, base_lt.axes)

    aligned_base, aligned_shifted, axes = core.align(base_lt, shifted_lt)

    self.assertLabeledTensorsEqual(aligned_base, base_lt)
    self.assertLabeledTensorsEqual(aligned_shifted, shifted_lt)
    self.assertEqual(axes, base_lt.axes)
  def test_axis_order_scope(self):
    """Checks how core.axis_order_scope controls broadcast-axis ordering.

    Without a scope, the order of the broadcast axes depends on the order
    in which the tensors are passed to core.align. Inside a scope listing
    every axis, that scope's order is used. A scope that omits one of the
    inputs' axes (here 'z') makes core.align raise core.AxisOrderError for
    either argument order.
    """
    lt_xz = core.LabeledTensor(array_ops.ones((2, 3)), ['x', 'z'])
    lt_yz = core.LabeledTensor(array_ops.ones((4, 3)), ['y', 'z'])

    def broadcast_keys(first, second):
      # Only the broadcast axes (third result of core.align) matter here.
      _, _, axes = core.align(first, second)
      return list(axes.keys())

    # Default ordering: depends on which tensor is passed first.
    self.assertEqual(broadcast_keys(lt_xz, lt_yz), ['x', 'y', 'z'])
    self.assertEqual(broadcast_keys(lt_yz, lt_xz), ['y', 'x', 'z'])

    # A scope naming every axis fixes the order.
    with core.axis_order_scope(['x', 'y', 'z']):
      self.assertEqual(broadcast_keys(lt_yz, lt_xz), ['x', 'y', 'z'])

    # A scope missing 'z' is rejected regardless of argument order.
    with core.axis_order_scope(['x', 'y']):
      for pair in ((lt_xz, lt_yz), (lt_yz, lt_xz)):
        with self.assertRaises(core.AxisOrderError):
          core.align(*pair)
  def test_different_inputs(self):
    """Aligning tensors with partially different axes inserts size-1 dims.

    Each aligned output is expected to equal its input reshaped with a
    length-1 dimension for the axis only the other input has, and the
    broadcast axes are (a0, a1, a3).
    """
    # The correct axis ordering is ['x', 'channel', 'probs'].
    x_probs, channel_probs = self.x_probs_lt, self.channel_probs_lt
    aligned_x_probs, aligned_channel_probs, axes = core.align(
        x_probs, channel_probs)

    # x_probs gains a length-1 'channel' dimension in the middle.
    expected_x_probs = core.LabeledTensor(
        array_ops.reshape(x_probs.tensor,
                          [self.x_size, 1, self.probs_size]),
        [self.a0, 'channel', self.a3])
    self.assertLabeledTensorsEqual(aligned_x_probs, expected_x_probs)

    # channel_probs gains a length-1 'x' dimension in front.
    expected_channel_probs = core.LabeledTensor(
        array_ops.reshape(channel_probs.tensor,
                          [1, self.channel_size, self.probs_size]),
        ['x', self.a1, self.a3])
    self.assertLabeledTensorsEqual(aligned_channel_probs,
                                   expected_channel_probs)

    self.assertEqual(axes, core.Axes([self.a0, self.a1, self.a3]))
 def test_invalid_input(self):
   """Axes that share a name but carry different labels cannot be aligned.

   Both tensors have a length-5 axis 'a', labelled 0..4 and 1..5
   respectively, so core.align must raise ValueError.
   """
   mismatched = [
       core.LabeledTensor(array_ops.zeros([5]), [('a', labels)])
       for labels in (range(5), range(1, 6))
   ]
   with self.assertRaises(ValueError):
     core.align(*mismatched)
 def test_name(self):
   """Both aligned outputs carry 'lt_align' in their names.

   The first output's name also contains '/0' and the second's '/1'.
   """
   first, second, _ = core.align(self.original_lt, self.original_lt)
   for suffix, aligned in (('/0', first), ('/1', second)):
     self.assertIn('lt_align', aligned.name)
     self.assertIn(suffix, aligned.name)