def build_network_from_description(description):
    """Build and configure a network from a description dict.

    Required keys:
        'architecture'   -- passed to build_network_from_architecture.
        'initialization' -- turned into an initializer via
                            create_from_description and applied with
                            net.initialize.
    Optional keys (only applied when present AND truthy):
        '$seed'          -- seed forwarded to build_network_from_architecture
                            (None when absent).
        'regularization' -- created and passed to net.set_regularizers.
        'constraints'    -- created and passed to net.set_constraints.
        'error_function' -- created and assigned to net.error_func.

    Returns the configured network.  Missing required keys raise KeyError.
    """
    seed = description.get('$seed')
    network = build_network_from_architecture(description['architecture'], seed)
    network.initialize(create_from_description(description['initialization']))

    regularization = description.get('regularization')
    if regularization:
        network.set_regularizers(create_from_description(regularization))

    constraints = description.get('constraints')
    if constraints:
        network.set_constraints(create_from_description(constraints))

    error_function = description.get('error_function')
    if error_function:
        network.error_func = create_from_description(error_function)

    return network
def test_create_gaussian_initializer_from_dict(self):
    """A {'@type': 'Gaussian', ...} dict builds a Gaussian with that mean/std."""
    mean, std = 7.0, 23.0
    gaussian = create_from_description({'@type': 'Gaussian', 'std': std, 'mean': mean})

    self.assertIsInstance(gaussian, Gaussian)
    self.assertEqual(gaussian.mean, mean)
    self.assertEqual(gaussian.std, std)
def test_sparse_inputs_initalizer_from_description(self):
    """Nested descriptions: a SparseInputs wrapping an inner Gaussian.

    NOTE: 'initalizer' in the method name is a typo, kept so existing test
    selectors that reference this name keep working.
    """
    inner_description = {'@type': 'Gaussian', 'std': 23.0, 'mean': 7.0}
    sparse = create_from_description({
        '@type': 'SparseInputs',
        'connections': 69,
        'init': inner_description,
    })

    self.assertIsInstance(sparse, SparseInputs)
    self.assertEqual(sparse.connections, 69)

    # The nested 'init' description must itself be turned into an object.
    inner = sparse.init
    self.assertIsInstance(inner, Gaussian)
    self.assertEqual(inner.std, 23.0)
    self.assertEqual(inner.mean, 7.0)
def test_custom_initializer(self):
    """A user-defined Initializer subclass round-trips through its description.

    ``__describe__`` is expected to emit the class name under ``'@type'``
    plus the constructor argument, and ``create_from_description`` is
    expected to rebuild a *separate* instance of the same class carrying
    the same attribute value.
    """
    class Custom(Initializer):
        def __init__(self, foo):
            Initializer.__init__(self)
            self.foo = foo

    c = Custom('bar')
    descr = c.__describe__()
    self.assertDictEqual(descr, {'@type': 'Custom', 'foo': 'bar'})

    c2 = create_from_description(descr)
    # Check identity, not inequality: the intent is "a new object was built".
    # assertNotEqual would wrongly fail a correct round-trip if Initializer
    # ever gained a value-based __eq__.
    self.assertIsNot(c, c2)
    self.assertIsInstance(c2, Custom)
    self.assertEqual(c2.foo, 'bar')
def test_create_initializer_from_list(self):
    """Plain lists come back from create_from_description equal to the input."""
    for expected in ([1], [1, 2, 3], [1.0, -1.0]):
        # Pass a copy so the comparison target is independent of whatever
        # create_from_description does with its argument.
        self.assertListEqual(create_from_description(list(expected)), expected)
def test_create_initializer_from_plain_number(self):
    """Plain numbers (floats and ints) come back from create_from_description as-is."""
    for number in (0.0, 1.0, 5):
        self.assertEqual(create_from_description(number), number)