Exemplo n.º 1
0
  def test_concat_raises(self):
    # Every case hands a set of inputs plus a td.Concat block to td.Pipe and
    # expects a TypeError carrying a specific message.

    # A scalar, a vector and a [3, 3] tensor together.
    mixed_shapes = {'a': tdb.Scalar(),
                    'b': tdb.Vector(4),
                    'c': tdb.Tensor([3, 3])}
    concat = tdb.Concat()
    with six.assertRaisesRegex(
        self, TypeError, 'Shapes for concat don\'t match:'):
      tdb.Pipe(mixed_shapes, concat)

    # Two vectors, concatenated with concat_dim=1.
    vectors = {'a': tdb.Vector(2),
               'b': tdb.Vector(4)}
    concat_dim1 = tdb.Concat(concat_dim=1)
    with six.assertRaisesRegex(
        self, TypeError, 'Concat argument.*has rank less than 2.'):
      tdb.Pipe(vectors, concat_dim1)

    # An int32 vector next to a vector with the default dtype.
    mixed_dtypes = {'a': tdb.Vector(2, dtype='int32'),
                    'b': tdb.Vector(4)}
    concat = tdb.Concat()
    with six.assertRaisesRegex(
        self, TypeError,
        'Cannot concat tensors of different dtypes: int32 vs. float32'):
      tdb.Pipe(mixed_dtypes, concat)

    # A tuple whose members are themselves tuples.
    nested = ((tdb.Scalar(),), (tdb.Scalar(),))
    concat = tdb.Concat()
    with six.assertRaisesRegex(self, TypeError, 'contains nested tuples'):
      tdb.Pipe(nested, concat)

    # No inputs at all; this message is matched literally, not as a regex.
    no_inputs = ()
    self.assertRaisesWithLiteralMatch(
        TypeError, 'Concat requires at least one tensor as input',
        tdb.Pipe, no_inputs, tdb.Concat())
Exemplo n.º 2
0
  def test_optional(self):
    # With no explicit default, a None input is expected to build to zeros.
    optional_vector = tdb.Optional(tdb.Vector(4))
    vector_cases = (
        ([1.0, 2.0, 3.0, 4.0], [1, 2, 3, 4]),
        ([0.0, 0.0, 0.0, 0.0], None),
    )
    for expected, value in vector_cases:
      self.assertBuildsConst(expected, optional_vector, value)

    # With an explicit default (42.0), a None input is expected to build to
    # that default instead.
    default_value = np.array(42.0, dtype='float32')
    optional_scalar = tdb.Optional(tdb.Scalar(), default_value)
    scalar_cases = (
        (6.0, 6),
        (42.0, None),
    )
    for expected, value in scalar_cases:
      self.assertBuildsConst(expected, optional_scalar, value)
Exemplo n.º 3
0
  def test_output_type_inference(self):
    # Identity blocks and compositions of them derive their output types from
    # the types of their inputs.
    scalar = tdb.Scalar()
    identity_pair = tdb.Identity() >> tdb.Identity()
    block = scalar >> identity_pair
    self.assertBuildsConst(42., block, 42)

    # The same holds for a record piped through several identities, including
    # a nested identity pipeline.
    record = {'a': tdb.Scalar(), 'b': tdb.Vector(2) >> tdb.Identity()}
    block = record >> tdb.Identity()
    block = block >> (tdb.Identity() >> tdb.Identity())
    block = block >> tdb.Identity()
    expected = (42., [5., 1.])
    self.assertBuildsConst(expected, block, {'a': 42, 'b': [5, 1]})
Exemplo n.º 4
0
 def test_one_of_raises(self):
   # Constructing td.OneOf with invalid arguments is expected to raise.

   # key_fn is the integer 42, which is not callable.
   single_case = (tdb.Scalar(),)
   with six.assertRaisesRegex(
       self, TypeError, 'key_fn is not callable: 42'):
     tdb.OneOf(42, single_case)

   # An empty mapping of case blocks; this message is matched literally.
   self.assertRaisesWithLiteralMatch(
       ValueError, 'case_blocks must be non-empty', tdb.OneOf, lambda x: x, {})

   # One scalar case and one vector case.
   mismatched_cases = {0: tdb.Scalar(), 1: tdb.Vector(2)}
   with six.assertRaisesRegex(
       self, TypeError, 'Type mismatch between output type'):
     tdb.OneOf(lambda x: x, mismatched_cases)
Exemplo n.º 5
0
  def test_rnn(self):
    # x needs expand_dims so that it broadcasts over the batch when it is
    # added to the state.
    def f(x, st):
      squared = tf.multiply(x, x)
      next_state = tf.add(st, tf.expand_dims(x, 1))
      return (squared, next_state)

    sequence_and_state = (tdb.Map(tdb.Scalar()), tdb.Vector(2))
    cell = tdb.RNN(tdb.Function(f), initial_state_from_input=True)
    block = sequence_and_state >> cell

    # Empty sequence: no outputs, and the state equals the initial state.
    self.assertBuilds(([], [0.0, 0.0]), block,
                      ([], [0.0, 0.0]), max_depth=0)

    # Four inputs: outputs are the squares and each state entry is 1+2+3+4.
    # The identical case is checked twice against the same block.
    # NOTE(review): presumably this guards building one block repeatedly --
    # confirm before removing the second pass.
    for _ in range(2):
      self.assertBuilds(([1.0, 4.0, 9.0, 16.0], [10.0, 10.0]), block,
                        ([1.0, 2.0, 3.0, 4.0], [0.0, 0.0]), max_depth=4)
Exemplo n.º 6
0
 def test_rshift(self):
     # Chains two FC layers with >>; their initializers are the constants 2.0
     # and 3.0, and an input of 1.0 is expected to come out as 6.0.
     # NOTE(review): presumably FC(1, None, ...) means one unit and no
     # activation -- confirm against the layer's signature.
     first_layer = tdl.FC(1, None, tf.constant_initializer(2.0))
     second_layer = tdl.FC(1, None, tf.constant_initializer(3.0))
     block = first_layer >> second_layer
     with self.test_session():
         expected = [6.0]
         pipeline = tdb.Vector(1) >> block
         self.assertEqual(expected, pipeline.eval([1.0], tolist=True))
Exemplo n.º 7
0
 def test_vector(self):
   # A size-3 vector block fed integers is expected to build to floats.
   expected = [1., 2., 3.]
   vector_block = tdb.Vector(3)
   self.assertBuildsConst(expected, vector_block, [1, 2, 3])
Exemplo n.º 8
0
 def test_concat_scalar(self):
   # Scalars and a vector are concatenated together; the expected output
   # lists the values in key order a, b, c.
   inputs = {'a': tdb.Scalar(),
             'b': tdb.Vector(4),
             'c': tdb.Scalar()}
   block = inputs >> tdb.Concat()
   expected = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
   self.assertBuildsConst(
       expected, block, {'a': 1.0, 'b': [2.0, 3.0, 4.0, 5.0], 'c': 6.0})
Exemplo n.º 9
0
 def test_concat(self):
   # Three vectors of sizes 1, 4 and 1 are concatenated into one vector of
   # six values; the expected output follows key order a, b, c.
   inputs = {'a': tdb.Vector(1),
             'b': tdb.Vector(4),
             'c': tdb.Vector(1)}
   block = inputs >> tdb.Concat()
   expected = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
   self.assertBuildsConst(
       expected, block, {'a': [1.0], 'b': [2.0, 3.0, 4.0, 5.0], 'c': [6.0]})
Exemplo n.º 10
0
 def test_optional_default_none(self):
   # td.Optional around a nested record, with no explicit default given.
   record = {'a': tdb.Map({'b': tdb.Scalar(), 'c': tdb.Scalar()}),
             'd': tdb.Vector(3)}
   block = tdb.Optional(record)

   # A present value is built through the wrapped record.
   present = {'a': [{'b': 0, 'c': 1}], 'd': [2, 3, 4]}
   self.assertBuildsConst(([(0., 1.)], [2., 3., 4.]), block, present)

   # None is expected to build to an empty sequence and a zero vector.
   self.assertBuildsConst(([], [0., 0., 0.]), block, None)
Exemplo n.º 11
0
 def test_mean_2vector(self):
   # Element-wise mean over a sequence of three 2-vectors.
   rows = [[1.0, 5.0], [2.0, 3.0], [0.0, 4.0]]
   mean_block = tdb.Map(tdb.Vector(2)) >> tdb.Mean()
   self.assertBuilds([1.0, 4.0], mean_block, rows, max_depth=3)
Exemplo n.º 12
0
 def test_max_2vectors(self):
   # Element-wise maximum over a sequence of three 2-vectors.
   rows = [[1.0, 5.0], [2.0, 3.0], [0.0, 4.0]]
   max_block = tdb.Map(tdb.Vector(2)) >> tdb.Max()
   self.assertBuilds([2.0, 5.0], max_block, rows, max_depth=2)
Exemplo n.º 13
0
 def test_sum_2vectors(self):
   # Element-wise sum over a sequence of three 2-vectors.
   rows = [[1.0, 5.0], [2.0, 3.0], [0.0, 4.0]]
   sum_block = tdb.Map(tdb.Vector(2)) >> tdb.Sum()
   self.assertBuilds([3.0, 12.0], sum_block, rows, max_depth=2)
Exemplo n.º 14
0
 def test_map_pyobject_type_inference(self):
   # A Map over Identity is piped straight into a size-2 vector block; the
   # list [1, 2] is expected to build to the float vector [1., 2.].
   identity_map = tdb.Map(tdb.Identity())
   b = identity_map >> tdb.Vector(2)
   expected = [1., 2.]
   self.assertBuildsConst(expected, b, [1, 2])
Exemplo n.º 15
0
 def test_function_tuple_in_out(self):
   # The wrapped function rearranges nested tuples:
   #   f(a, (b, c)) := ((c, a), b)
   inputs = (tdb.Vector(1), (tdb.Vector(2), tdb.Vector(3)))
   rearrange = tdb.Function(lambda x, y: ((y[1], x), y[0]))
   b = inputs >> rearrange
   expected = (([4., 5., 6.], [1.]), [2., 3.])
   self.assertBuilds(expected, b, ([1], ([2, 3], [4, 5, 6])))
Exemplo n.º 16
0
  def test_repr(self):
    # Golden table: each key is a freshly constructed block and each value is
    # the exact string repr() is expected to return for it.
    #
    # NOTE(review): the expected tensor names ('zeros:0', 'ones:0') presumably
    # depend on these being the first tf.zeros / tf.ones ops created in the
    # graph -- confirm before reordering or adding entries that create ops.
    goldens = {
        # Tensor-like input blocks.
        tdb.Tensor([]): '<td.Tensor dtype=\'float32\' shape=()>',
        tdb.Tensor([1, 2], 'int32', name='foo'):
        '<td.Tensor \'foo\' dtype=\'int32\' shape=(1, 2)>',

        tdb.Scalar('int64'): '<td.Scalar dtype=\'int64\'>',

        tdb.Vector(42): '<td.Vector dtype=\'float32\' size=42>',

        tdb.FromTensor(tf.zeros(3)): '<td.FromTensor \'zeros:0\'>',

        tdb.Function(tf.negative,
                     name='foo'): '<td.Function \'foo\' tf_fn=\'negative\'>',

        tdb.Identity(): '<td.Identity>',
        tdb.Identity('foo'): '<td.Identity \'foo\'>',

        # Blocks wrapping Python-side functions and indexing.
        tdb.InputTransform(ord): '<td.InputTransform py_fn=\'ord\'>',

        tdb.SerializedMessageToTree('foo'):
        '<td.SerializedMessageToTree \'foo\' '
        'py_fn=\'serialized_message_to_tree\'>',

        tdb.GetItem(3, 'mu'): '<td.GetItem \'mu\' key=3>',

        tdb.Length(): '<td.Length dtype=\'float32\'>',

        tdb.Slice(stop=2): '<td.Slice key=slice(None, 2, None)>',
        tdb.Slice(stop=2, name='x'):
        '<td.Slice \'x\' key=slice(None, 2, None)>',

        tdb.ForwardDeclaration(name='foo')():
        '<td.ForwardDeclaration() \'foo\'>',

        # Composite blocks; several unnamed variants share one expected repr.
        tdb.Composition(name='x').input: '<td.Composition.input \'x\'>',
        tdb.Composition(name='x').output: '<td.Composition.output \'x\'>',
        tdb.Composition(name='x'): '<td.Composition \'x\'>',

        tdb.Pipe(): '<td.Pipe>',
        tdb.Pipe(tdb.Scalar(), tdb.Identity()): '<td.Pipe>',

        tdb.Record({}, name='x'): '<td.Record \'x\' ordered=False>',
        tdb.Record((), name='x'): '<td.Record \'x\' ordered=True>',

        tdb.AllOf(): '<td.AllOf>',
        tdb.AllOf(tdb.Identity()): '<td.AllOf>',
        tdb.AllOf(tdb.Identity(), tdb.Identity()): '<td.AllOf>',

        tdb.AllOf(name='x'): '<td.AllOf \'x\'>',
        tdb.AllOf(tdb.Identity(), name='x'): '<td.AllOf \'x\'>',
        tdb.AllOf(tdb.Identity(), tdb.Identity(), name='x'): '<td.AllOf \'x\'>',

        # Sequence blocks; nested blocks appear inside the outer repr.
        tdb.Map(tdb.Scalar(), name='x'):
        '<td.Map \'x\' element_block=<td.Scalar dtype=\'float32\'>>',

        tdb.Fold(tdb.Function(tf.add), tf.ones([]), name='x'):
        '<td.Fold \'x\' combine_block=<td.Function tf_fn=\'add\'> '
        'start_block=<td.FromTensor \'ones:0\'>>',

        # The expected RNN repr shows only the optional name, with or without
        # an initial_state argument.
        tdb.RNN(tdl.ScopedLayer(tf.contrib.rnn.GRUCell(num_units=8))):
        '<td.RNN>',
        tdb.RNN(tdl.ScopedLayer(tf.contrib.rnn.GRUCell(num_units=8)), name='x'):
        '<td.RNN \'x\'>',
        tdb.RNN(tdl.ScopedLayer(tf.contrib.rnn.GRUCell(num_units=8)),
                initial_state=tf.ones(8)):
        '<td.RNN>',
        tdb.RNN(tdl.ScopedLayer(tf.contrib.rnn.GRUCell(num_units=8)),
                initial_state=tf.ones(8), name='x'):
        '<td.RNN \'x\'>',

        # Reductions.
        tdb.Reduce(tdb.Function(tf.add), name='x'):
        '<td.Reduce \'x\' combine_block=<td.Function tf_fn=\'add\'>>',

        tdb.Sum(name='foo'):
        '<td.Sum \'foo\' combine_block=<td.Function tf_fn=\'add\'>>',

        tdb.Min(name='foo'):
        '<td.Min \'foo\' combine_block=<td.Function tf_fn=\'minimum\'>>',

        tdb.Max(name='foo'):
        '<td.Max \'foo\' combine_block=<td.Function tf_fn=\'maximum\'>>',

        tdb.Mean(name='foo'): '<td.Mean \'foo\'>',

        # Remaining block types.
        tdb.OneOf(ord, (tdb.Scalar(), tdb.Scalar()), name='x'):
        '<td.OneOf \'x\'>',

        tdb.Optional(tdb.Scalar(), name='foo'):
        '<td.Optional \'foo\' some_case_block=<td.Scalar dtype=\'float32\'>>',

        tdb.Concat(1, True, 'x'):
        '<td.Concat \'x\' concat_dim=1 flatten=True>',

        tdb.Broadcast(name='x'): '<td.Broadcast \'x\'>',

        tdb.Zip(name='x'): '<td.Zip \'x\'>',

        tdb.NGrams(n=42, name='x'): '<td.NGrams \'x\' n=42>',

        tdb.OneHot(2, 3, name='x'):
        '<td.OneHot \'x\' dtype=\'float32\' start=2 stop=3>',
        tdb.OneHot(3): '<td.OneHot dtype=\'float32\' start=0 stop=3>',

        tdb.OneHotFromList(['a', 'b']): '<td.OneHotFromList>',
        tdb.OneHotFromList(['a', 'b'], name='foo'):
        '<td.OneHotFromList \'foo\'>',

        tdb.Nth(name='x'): '<td.Nth \'x\'>',

        tdb.Zeros([], 'x'): '<td.Zeros \'x\'>',

        tdb.Void(): '<td.Void>',
        tdb.Void('foo'): '<td.Void \'foo\'>',

        tdm.Metric('foo'): '<td.Metric \'foo\'>'}
    # Check entries in order of their expected string. The sort key is the
    # string alone, so the block objects themselves are never compared, even
    # where several entries share the same expected repr.
    for block, expected_repr in sorted(six.iteritems(goldens),
                                       key=lambda kv: kv[1]):
      self.assertEqual(repr(block), expected_repr)