Пример #1
0
 def test_dense_param_sharing(self):
     # Compares a Serial of two distinct Dense layers against a Serial that
     # reuses one Dense instance twice, checking how parameters are laid out.
     model1 = combinators.Serial(core.Dense(32), core.Dense(32))
     layer = core.Dense(32)
     model2 = combinators.Serial(layer, layer)
     rng = random.PRNGKey(0)
     params1, _ = model1.initialize((1, 32), onp.float32, rng)
     params2, _ = model2.initialize((1, 32), onp.float32, rng)
     # The first parameters have 2 kernels of size (32, 32).
     self.assertEqual((32, 32), params1[0][0].shape)
     self.assertEqual((32, 32), params1[1][0].shape)
     # The second parameters have 1 kernel of size (32, 32) and an empty
     # tuple in the slot for the second use of `layer` (no separate params).
     self.assertEqual((32, 32), params2[0][0].shape)
     self.assertEqual((), params2[1])
Пример #2
0
 def test_serial_dup_dup(self):
     # Chaining two Dup layers should yield three outputs, each shaped like
     # the single (3, 2) input.
     dup_twice = cb.Serial(cb.Dup(), cb.Dup())
     in_shape = (3, 2)
     want = (in_shape,) * 3
     got = base.check_shape_agreement(dup_twice, in_shape)
     self.assertEqual(got, want)
Пример #3
0
 def test_serial_div_div(self):
     # Two successive Div layers are elementwise, so the output shape must
     # match the (3, 2) input shape unchanged.
     halve_then_fifth = cb.Serial(core.Div(divisor=2.0), core.Div(divisor=5.0))
     in_shape = (3, 2)
     want = in_shape
     got = base.check_shape_agreement(halve_then_fifth, in_shape)
     self.assertEqual(got, want)
Пример #4
0
 def test_serial_no_op_list(self):
     # A Serial built from an empty list should pass both inputs through,
     # leaving their shapes untouched.
     empty_serial = cb.Serial([])
     in_shapes = ((3, 2), (4, 7))
     want = in_shapes
     got = base.check_shape_agreement(empty_serial, in_shapes)
     self.assertEqual(got, want)