Пример #1
0
 def test_forward_declaration_orphaned_nested(self):
   """Nested forward declarations referenced before being resolved.

   Both declarations are called (referenced) before either ``resolve_to``
   runs, and ``seed_decl`` is only reachable through ``fold_decl``'s
   resolution. ``fold_decl`` resolves to a Fold of tf.add seeded by
   ``seed_decl`` (a ones([]) tensor), so for input [3, 4] the negated result
   is expected to be -(1 + 3 + 4) == -8.
   """
   seed_decl = tdb.ForwardDeclaration(tdt.VoidType(), tdt.TensorType([]))
   fold_decl = tdb.ForwardDeclaration(
       tdt.SequenceType(tdt.TensorType([])), tdt.TensorType([]))
   # Take the reference to fold_decl while it is still unresolved.
   summed = tdb.Map(tdb.Scalar()) >> fold_decl()
   block = summed >> tdb.Function(tf.negative)
   # Resolution deliberately happens only after the references exist.
   fold_decl.resolve_to(tdb.Fold(tdb.Function(tf.add), seed_decl()))
   seed_decl.resolve_to(tdb.FromTensor(tf.ones([])))
   self.assertBuilds(-8.0, block, [3, 4], max_depth=3)
Пример #2
0
 def test_composition_raises_cycle(self):
   """Validating a Composition whose connections form a cycle fails.

   The two child blocks are connected to each other in both directions, so
   ``_validate`` is expected to raise ValueError with a literal message.
   """
   times_two = times_scalar_block(2.0)
   times_three = times_scalar_block(3.0)
   comp = tdb.Composition([times_three, times_two])
   comp = comp.set_input_type(tdt.VoidType())
   comp.connect(times_two, times_three)
   comp.connect(times_three, comp.output)
   # This edge closes the loop between times_two and times_three.
   comp.connect(times_three, times_two)
   self.assertRaisesWithLiteralMatch(
       ValueError, 'Composition cannot have cycles.', comp._validate, None)
Пример #3
0
 def __init__(self, metric_name):
     """Initialize the metric, naming the block after ``metric_name``.

     The block's output type is VoidType.

     Args:
       metric_name: Name of the metric; must be a string
         (``six.string_types``).

     Raises:
       TypeError: If ``metric_name`` is not a string.
     """
     # Reject non-string names up front rather than silently str()-ing them.
     # This matches the validated variant of this constructor.
     if not isinstance(metric_name, six.string_types):
         raise TypeError('metric_name must be a string: %s' % (metric_name,))
     self._metric_name = metric_name
     super(Metric, self).__init__(name=str(metric_name),
                                  output_type=tdt.VoidType())
Пример #4
0
 def test_conversion(self):
     """The repr of a fresh VoidType is the string 'VoidType()'."""
     self.assertEqual(repr(tdt.VoidType()), 'VoidType()')
Пример #5
0
 def test_terminal_types(self):
     """terminal_types() yields only the leaf types, in order.

     For Sequence(Tuple(Tensor, Void, Tuple(PyObject))), the expected leaves
     are the TensorType and the PyObjectType nested in the inner tuple. The
     VoidType and the inner TupleType itself are not expected.
     """
     tuple_type = tdt.TupleType(tdt.TensorType([]), tdt.VoidType(),
                                tdt.TupleType(tdt.PyObjectType()))
     seq_type = tdt.SequenceType(tuple_type)
     actual = list(seq_type.terminal_types())
     elem = seq_type.element_type
     expected = [elem[0], elem[2][0]]
     self.assertEqual(actual, expected)
Пример #6
0
 def test_eval_void(self):
   """Identity with a VoidType input builds to the constant None."""
   identity = tdb.Identity()
   identity = identity.set_input_type(tdt.VoidType())
   self.assertBuildsConst(None, identity, None)
Пример #7
0
 def test_zeros_void(self):
   """Zeros of (VoidType, scalar TensorType) builds to (None, 0.0)."""
   zero_type = tdt.TupleType(tdt.VoidType(), tdt.TensorType(()))
   zeros = tdb.Zeros(zero_type)
   expected = (None, 0.0)
   self.assertBuildsConst(expected, zeros, None)
Пример #8
0
 def test_forward_declaration_orphaned(self):
   """Two references taken before resolution both use the resolved block.

   Both references are created before ``resolve_to``. The declaration
   resolves to a ones([]) tensor, so their sum is expected to be 2.
   """
   decl = tdb.ForwardDeclaration(tdt.VoidType(), tdt.TensorType([]))
   first_ref = decl()
   second_ref = decl()
   block = tdb.AllOf(first_ref, second_ref) >> tdb.Sum()
   # Resolve only after both references already exist.
   decl.resolve_to(tdb.FromTensor(tf.ones([])))
   self.assertBuilds(2.0, block, None)
Пример #9
0
 def test_composition_backward_type_inference(self):
   """Forcing VoidType onto the pipeline's output raises TypeError.

   The pipeline is Map(Identity) >> Identity >> Identity. Setting its output
   type to VoidType is expected to fail with 'bad output type VoidType'.
   """
   pipeline = tdb.Map(tdb.Identity()) >> tdb.Identity() >> tdb.Identity()
   # Fetch the bound method outside the `with`, so only the call itself is
   # checked for the exception.
   set_output_type = pipeline.output.set_output_type
   with six.assertRaisesRegex(self, TypeError, 'bad output type VoidType'):
     set_output_type(tdt.VoidType())
Пример #10
0
 def test_composition_slice(self):
   """Slicing a tuple-valued block inside a Composition scope.

   AllOf over the constants 0..4 yields a 5-tuple. The slice [1:-1:2]
   selects the elements 1 and 3, which tf.add combines into the expected 4.
   """
   comp = tdb.Composition().set_input_type(tdt.VoidType())
   with comp.scope():
     # Use a distinct comprehension variable. Reusing `t` here shadowed the
     # name being assigned, and it leaks into the method scope on Python 2.
     items = tdb.AllOf(*[np.array(i) for i in range(5)]).reads(comp.input)
     comp.output.reads(tdb.Function(tf.add).reads(items[1:-1:2]))
   self.assertBuilds(4, comp, None, max_depth=1)
Пример #11
0
 def __init__(self, metric_name):
   """Initialize the metric, naming the block after ``metric_name``.

   The block's output type is VoidType.

   Args:
     metric_name: Name of the metric; must be a string
       (``six.string_types``).

   Raises:
     TypeError: If ``metric_name`` is not a string.
   """
   is_string = isinstance(metric_name, six.string_types)
   if not is_string:
     raise TypeError('metric_name must be a string: %s' % (metric_name,))
   self._metric_name = metric_name
   block_name = str(metric_name)
   super(Metric, self).__init__(name=block_name, output_type=tdt.VoidType())