def test_forward_declaration_orphaned_nested(self):
    """A declaration used only inside another declaration's body still resolves.

    `init_decl` never appears in the top-level block; it is only called from
    the body that `fold_decl` is resolved to. The asserted -8 corresponds to
    negating a fold of `tf.add` over the inputs [3, 4] seeded by `init_decl`'s
    resolution (a tensor of ones).
    """
    init_decl = tdb.ForwardDeclaration(tdt.VoidType(), tdt.TensorType([]))
    fold_decl = tdb.ForwardDeclaration(
        tdt.SequenceType(tdt.TensorType([])), tdt.TensorType([]))

    # Build the pipeline before either declaration is resolved.
    pipeline = tdb.Map(tdb.Scalar()) >> fold_decl()
    pipeline = pipeline >> tdb.Function(tf.negative)

    # Resolve the outer declaration first; its body references the inner one.
    fold_body = tdb.Fold(tdb.Function(tf.add), init_decl())
    fold_decl.resolve_to(fold_body)
    init_decl.resolve_to(tdb.FromTensor(tf.ones([])))

    self.assertBuilds(-8.0, pipeline, [3, 4], max_depth=3)
def test_composition_raises_cycle(self):
    """Validating a composition whose wiring contains a cycle raises ValueError."""
    doubler = times_scalar_block(2.0)
    tripler = times_scalar_block(3.0)
    comp = tdb.Composition([tripler, doubler])
    comp = comp.set_input_type(tdt.VoidType())

    comp.connect(doubler, tripler)
    comp.connect(tripler, comp.output)
    # Closing the loop: doubler -> tripler -> doubler.
    comp.connect(tripler, doubler)

    expected_message = 'Composition cannot have cycles.'
    self.assertRaisesWithLiteralMatch(
        ValueError, expected_message, comp._validate, None)
def __init__(self, metric_name):
    """Create a metric block with a VoidType output.

    Args:
      metric_name: stored unchanged as ``self._metric_name``; its ``str()``
        form is used as the block name.

    NOTE(review): metric_name is not type-checked here, so a non-string value
    is kept as-is in ``_metric_name`` while the block name is its string form.
    """
    self._metric_name = metric_name
    block_name = str(metric_name)
    void_output = tdt.VoidType()
    super(Metric, self).__init__(output_type=void_output, name=block_name)
def test_conversion(self):
    """The repr of a VoidType instance is 'VoidType()'."""
    expected = 'VoidType()'
    self.assertEqual(repr(tdt.VoidType()), expected)
def test_terminal_types(self):
    """terminal_types() yields the leaf types of a sequence's element type.

    The expected list omits the VoidType member and includes the PyObjectType
    nested one tuple level down.
    """
    nested_tuple = tdt.TupleType(tdt.PyObjectType())
    element = tdt.TupleType(tdt.TensorType([]), tdt.VoidType(), nested_tuple)
    seq_type = tdt.SequenceType(element)

    elem = seq_type.element_type
    expected = [elem[0], elem[2][0]]
    self.assertEqual(list(seq_type.terminal_types()), expected)
def test_eval_void(self):
    """An Identity block with VoidType input evaluates to None."""
    identity = tdb.Identity()
    void_identity = identity.set_input_type(tdt.VoidType())
    self.assertBuildsConst(None, void_identity, None)
def test_zeros_void(self):
    """Zeros of (VoidType, scalar tensor) evaluates to (None, 0.0)."""
    zero_type = tdt.TupleType(tdt.VoidType(), tdt.TensorType(()))
    zeros = tdb.Zeros(zero_type)
    self.assertBuildsConst((None, 0.0), zeros, None)
def test_forward_declaration_orphaned(self):
    """A forward declaration may be called before it is resolved.

    The declaration is invoked twice, then resolved to a tensor of ones; the
    sum of the two results is asserted to be 2.
    """
    decl = tdb.ForwardDeclaration(tdt.VoidType(), tdt.TensorType([]))
    both = tdb.AllOf(decl(), decl())
    summed = both >> tdb.Sum()

    decl.resolve_to(tdb.FromTensor(tf.ones([])))
    self.assertBuilds(2.0, summed, None)
def test_composition_backward_type_inference(self):
    """Setting VoidType as the output type of a Map pipeline raises TypeError.

    NOTE(review): presumably rejected because the type set on the output is
    propagated backwards through the Identity blocks to the Map; confirm.
    """
    pipeline = tdb.Map(tdb.Identity()) >> tdb.Identity()
    pipeline = pipeline >> tdb.Identity()

    set_output = pipeline.output.set_output_type
    six.assertRaisesRegex(
        self, TypeError, 'bad output type VoidType', set_output, tdt.VoidType())
def test_composition_slice(self):
    """Slicing a block's tuple output inside a Composition selects elements.

    The tuple holds the constants 0..4. The slice ``[1:-1:2]`` picks indices
    1 and 3, and adding them gives the asserted 4.
    """
    c1 = tdb.Composition().set_input_type(tdt.VoidType())
    with c1.scope():
        # Use a distinct comprehension variable. Reusing `t` shadowed the
        # assignment target, and on Python 2 the comprehension variable leaks
        # into the enclosing scope, briefly rebinding `t` to an int.
        t = tdb.AllOf(*[np.array(i) for i in range(5)]).reads(c1.input)
        c1.output.reads(tdb.Function(tf.add).reads(t[1:-1:2]))
    self.assertBuilds(4, c1, None, max_depth=1)
def __init__(self, metric_name):
    """Create a metric block named ``metric_name`` with a VoidType output.

    Args:
      metric_name: the metric's name; must be an instance of
        ``six.string_types``.

    Raises:
      TypeError: if ``metric_name`` is not a string.
    """
    if isinstance(metric_name, six.string_types):
        self._metric_name = metric_name
        block_name = str(metric_name)
        super(Metric, self).__init__(
            name=block_name, output_type=tdt.VoidType())
    else:
        raise TypeError('metric_name must be a string: %s' % (metric_name,))