def test_invalid_input(self):
  """expand_dims raises core.AxisOrderError for orders it cannot honor.

  NOTE(review): both orders presumably conflict with the axes of
  self.original_lt, which is built outside this view -- confirm in setUp.
  """
  # The name 'not_x' occupies the slot where 'x' would otherwise be listed.
  order_with_unknown_axis = [
      'foo', 'not_x', 'bar', 'channel', 'z', 'probs', 'grok'
  ]
  with self.assertRaises(core.AxisOrderError):
    core.expand_dims(self.original_lt, order_with_unknown_axis)

  # 'z' is listed before 'x' here.
  order_with_swapped_axes = [
      'foo', 'z', 'bar', 'channel', 'x', 'probs', 'grok'
  ]
  with self.assertRaises(core.AxisOrderError):
    core.expand_dims(self.original_lt, order_with_swapped_axes)
Beispiel #2
0
 def test_invalid_input(self):
   """Each listed axis order must make expand_dims raise AxisOrderError.

   The orders are checked in sequence. The first one that fails to raise
   aborts the test, exactly as two consecutive assertRaises blocks would.
   """
   bad_orders = (
       # 'not_x' stands where 'x' would otherwise be listed.
       ['foo', 'not_x', 'bar', 'channel', 'z', 'probs', 'grok'],
       # 'z' is listed before 'x'.
       ['foo', 'z', 'bar', 'channel', 'x', 'probs', 'grok'],
   )
   for order in bad_orders:
     with self.assertRaises(core.AxisOrderError):
       core.expand_dims(self.original_lt, order)
  def test(self):
    """New axis names in the target order become size-1 axes in place."""
    target_order = ['foo', 'x', 'bar', 'channel', 'z', 'probs', 'grok']
    expand_lt = core.expand_dims(self.original_lt, target_order)

    # Expected: the raw tensor reshaped with a 1 at every newly named
    # position ('foo', 'bar', 'grok'). NOTE(review): self.a0..self.a3 are
    # presumably the original axes for x/channel/z/probs -- confirm in setUp.
    expected_shape = [
        1, self.x_size, 1, self.channel_size, self.z_size, self.probs_size, 1
    ]
    expected_axes = ['foo', self.a0, 'bar', self.a1, self.a2, self.a3, 'grok']
    golden_lt = core.LabeledTensor(
        array_ops.reshape(self.tensor, expected_shape), expected_axes)

    self.assertLabeledTensorsEqual(expand_lt, golden_lt)
Beispiel #4
0
  def test(self):
    """Names absent from the original tensor are inserted as size-1 axes."""
    expand_lt = core.expand_dims(
        self.original_lt,
        ['foo', 'x', 'bar', 'channel', 'z', 'probs', 'grok'])

    # Build the reference by reshaping the raw tensor directly, with a
    # singleton dimension wherever 'foo', 'bar' or 'grok' was requested.
    reshaped = tf.reshape(
        self.tensor,
        [1, self.x_size, 1, self.channel_size,
         self.z_size, self.probs_size, 1])
    golden_lt = core.LabeledTensor(
        reshaped, ['foo', self.a0, 'bar', self.a1, self.a2, self.a3, 'grok'])

    self.assertLabeledTensorsEqual(expand_lt, golden_lt)
Beispiel #5
0
  def test_label(self):
    """A (name, label) pair adds a size-1 axis carrying that single label."""
    new_labeled_axis = ('foo', 'bar')
    expand_lt = core.expand_dims(
        self.original_lt,
        ['x', 'channel', new_labeled_axis, 'z', 'probs'])

    # The expected axis spells its labels as a list: ('foo', ['bar']).
    reshaped = tf.reshape(
        self.tensor,
        [self.x_size, self.channel_size, 1, self.z_size, self.probs_size])
    expected_axes = [self.a0, self.a1, ('foo', ['bar']), self.a2, self.a3]
    golden_lt = core.LabeledTensor(reshaped, expected_axes)

    self.assertLabeledTensorsEqual(expand_lt, golden_lt)
Beispiel #6
0
    def test_label(self):
        """Inserting ('foo', 'bar') yields a size-1 'foo' axis labeled 'bar'."""
        axis_spec = ['x', 'channel', ('foo', 'bar'), 'z', 'probs']
        expand_lt = core.expand_dims(self.original_lt, axis_spec)

        # Reference: a 1 is spliced into the shape at the new axis position.
        golden_shape = [
            self.x_size, self.channel_size, 1, self.z_size, self.probs_size
        ]
        golden_axes = [self.a0, self.a1, ('foo', ['bar']), self.a2, self.a3]
        golden_lt = core.LabeledTensor(
            array_ops.reshape(self.tensor, golden_shape), golden_axes)

        self.assertLabeledTensorsEqual(expand_lt, golden_lt)
Beispiel #7
0
 def test_unknown_dimension(self):
     """An axis of unknown (None) size keeps None; the new axis 'y' is 1."""
     unsized_input = array_ops.placeholder(dtypes.float32, [None])
     orig_lt = core.LabeledTensor(unsized_input, ['x'])
     expand_lt = core.expand_dims(orig_lt, ['x', 'y'])
     expected_axes = core.Axes([('x', None), ('y', 1)])
     self.assertEqual(expand_lt.axes, expected_axes)
Beispiel #8
0
    def test_identity(self):
        """Expanding onto exactly the existing axis names is a no-op."""
        existing_names = self.original_lt.axes.keys()
        self.assertLabeledTensorsEqual(
            core.expand_dims(self.original_lt, existing_names),
            self.original_lt)
Beispiel #9
0
 def test_name(self):
     """The tensor produced by expand_dims has 'lt_expand' in its name."""
     current_axis_names = self.original_lt.axes.keys()
     result_lt = core.expand_dims(self.original_lt, current_axis_names)
     self.assertIn('lt_expand', result_lt.name)
Beispiel #10
0
 def test_unknown_dimension(self):
   """A None-sized axis stays None after a size-1 axis 'y' is appended."""
   source_lt = core.LabeledTensor(
       tf.placeholder(tf.float32, [None]), ['x'])
   widened_lt = core.expand_dims(source_lt, ['x', 'y'])
   self.assertEqual(widened_lt.axes,
                    core.Axes([('x', None), ('y', 1)]))
Beispiel #11
0
  def test_identity(self):
    """Requesting the tensor's own axis names returns an equal tensor."""
    own_names = self.original_lt.axes.keys()
    unchanged_lt = core.expand_dims(self.original_lt, own_names)
    self.assertLabeledTensorsEqual(unchanged_lt, self.original_lt)
Beispiel #12
0
 def test_name(self):
   """expand_dims output carries 'lt_expand' somewhere in its name."""
   names = self.original_lt.axes.keys()
   self.assertIn('lt_expand',
                 core.expand_dims(self.original_lt, names).name)