def test_save_network_5_4_3_2_1(db_session):
    """Save a 5-4-3-2-1 perceptron and check the stored ``Network`` row.

    Each weight matrix below has one row per unit of the following layer
    and one column per unit of the preceding layer plus one (presumably a
    bias column -- TODO confirm), e.g. 4 x 6 for the first transition.
    The row fetched back by name is compared with ``assert_network_equal``.
    """
    network_name = '5_4_3_2_1'

    first_layer = numpy.array([
        [11, 12, 13, 14, 15, 16],
        [17, 18, 19, 20, 21, 22],
        [23, 24, 25, 26, 27, 28],
        [29, 30, 31, 32, 33, 34],
    ])
    second_layer = numpy.array([
        [-35, 36, -37, 38, -39],
        [-40, 41, -42, 43, -44],
        [-45, 46, -47, 48, -49],
    ])
    third_layer = numpy.array([
        [50.0, -50.1, 50.2, -50.3],
        [50.4, -50.5, 50.6, -50.7],
    ])
    fourth_layer = numpy.array([
        [50.8, 50.9, 51.0],
    ])

    attributes = dict(
        alpha=0.25,
        momentum=0.5,
        out_sigmoided=False,
        weights=[first_layer, second_layer, third_layer, fourth_layer],
    )

    db_save_network(db_session, MLMCPerceptron(**attributes), network_name)

    stored = db_session.query(Network).filter_by(name=network_name).one()
    assert_network_equal(stored, attributes)
def test_save_network_2_2_1(db_session):
    """Save a 2-2-1 perceptron and check the stored ``Network`` row.

    Same shape convention as the other save tests: rows follow the next
    layer's size, columns the previous layer's size plus one (presumably a
    bias column -- TODO confirm). The network name is passed as the
    ``name`` keyword argument of ``db_save_network``.
    """
    network_name = '2_2_1'

    hidden_weights = numpy.array([
        [1, 2, 3],
        [4, 5, 6],
    ])
    output_weights = numpy.array([
        [7, 8, 9],
    ])

    attributes = dict(
        alpha=0.1,
        momentum=0.8,
        out_sigmoided=True,
        weights=[hidden_weights, output_weights],
    )

    db_save_network(db_session, MLMCPerceptron(**attributes), name=network_name)

    stored = db_session.query(Network).filter_by(name=network_name).one()
    assert_network_equal(stored, attributes)
def test_update_network_2_2(db_session):
    """Save a 2-2 network, update it, and check the reloaded perceptron.

    Steps:
      1. save a perceptron under ``'2_2'`` and verify the stored row;
      2. pass a perceptron with different attributes to
         ``db_update_network`` under the same name;
      3. reload via ``db_load_network``, rebuild a perceptron from the
         result and compare it with the step-2 attributes.
    """
    network_name = '2_2'

    initial = {
        'alpha': 0.1,
        'momentum': 0.8,
        'out_sigmoided': True,
        'weights': [
            numpy.array([
                [50.1, -50.2, 50.3],
                [50.4, -50.5, 50.6],
            ]),
        ],
    }
    db_save_network(db_session, MLMCPerceptron(**initial), network_name)
    saved = db_session.query(Network).filter_by(name=network_name).one()
    assert_network_equal(saved, initial)

    replacement = {
        'alpha': 0.3,
        'momentum': 0.1,
        'out_sigmoided': False,
        'weights': [
            numpy.array([
                [500.1, -500.2, 500.3],
                [500.4, -500.5, 500.6],
            ]),
        ],
    }
    db_update_network(db_session, MLMCPerceptron(**replacement), network_name)

    reloaded = MLMCPerceptron(**db_load_network(db_session, network_name))
    assert_perceptron_equal(reloaded, replacement)
def load_from_db(self, db_session): try: network_attributes = db_load_network(db_session, self.db_name) self.perceptron = MLMCPerceptron(**network_attributes) except NoResultFound: message_fmt = 'Network {name!r} not found in db. Creating new...' print message_fmt.format(name=self.db_name), assert self.SIZES_NAME_RE.match(self.db_name) sizes = [int(size) for size in self.db_name.split('_')] self.perceptron = MLMCPerceptron( sizes, alpha=0.01, exploration_probability=0.5, ) name = self.db_name network = db_save_network(db_session, self.perceptron, name) db_session.commit() print 'created'