def test_weightsHistogram():
    """The histogram of the weights of a RBM is an array of pairs."""
    histogram = RBM(16, 23).weightsHistogram()
    assert all(len(entry) == 2 for entry in histogram)
def test_load():
    """A DBN can be loaded from a Pickle file."""
    net = DBN([RBM(4, 5), RBM(5, 6)], 'Test')
    net.save()
    try:
        net = DBN.load('Test')
        assert len(net) == 2
    finally:
        # remove the pickle even when the assertion fails, so a failed run
        # does not leave a stale 'Test.pkl' behind for later tests
        os.remove(nets.full('Test.pkl'))
def test_generationLength():
    """A DBN can generate a sample whose length is equal to the number of
    visible units of its first RBM."""
    network = DBN([RBM(4, 7), RBM(7, 3)])
    generated = network.evaluate([0, 1, 1, 0.4])
    assert len(generated) == 4
def test_saveAndLoad():
    """A DBN serialised to a file is equivalent to a DBN loaded from the same file."""
    net_1 = DBN([RBM(4, 5), RBM(5, 6)], 'Test')
    net_1.save()
    try:
        net_2 = DBN.load('Test')
        # Bug fix: the original asserted a non-empty list of np.equal arrays,
        # which is always truthy, so the test could never fail.  Compare each
        # weight matrix element-wise instead.
        assert all(np.array_equal(net_1[i].W, net_2[i].W) for i in range(2))
    finally:
        # remove the pickle even when the assertion fails
        os.remove(nets.full('Test.pkl'))
def test_receptiveField():
    """The receptive fields of a DBN have a shape whose area is equal to the
    number of visible units."""
    network = DBN([RBM(49, 15), RBM(15, 26)])
    for layer in (1, 2):
        field = network.receptiveField(layer, 10)
        assert field.shape == (49, )
def test_DBN_error():
    """The reconstruction error of a DBN is less than 1."""
    network = DBN([RBM(16, 15), RBM(15, 10)], 'Test')
    dataset = DataSet.fromWhatever('top_left')
    # exhaust the training generator; the per-epoch info is not needed here
    for _ in network.learn(dataset):
        pass
    residual = dataset[0] - network.evaluate(dataset[0])
    rmse = np.sqrt(np.mean(residual ** 2))
    assert rmse <= 1
def test_RBM_error():
    """The reconstruction error of a RBM is less than 1."""
    network = RBM(16, 15)
    dataset = DataSet.fromWhatever('top_left')
    # exhaust the training generator; the per-epoch error is not needed here
    for _ in CDTrainer(network).run(dataset):
        pass
    residual = dataset[0] - network.evaluate(dataset[0])
    rmse = np.sqrt(np.mean(residual ** 2))
    assert rmse <= 1
def test_RBM_learn():
    """A RBM can learn to generate samples from a dataset."""
    network = RBM(16, 15)
    dataset = DataSet.fromWhatever('top_left')
    # drive the contrastive-divergence training to completion; the test
    # passes as long as no exception is raised
    for _ in CDTrainer(network).run(dataset):
        pass
def test_DBN_learn():
    """A DBN can learn to generate samples from a dataset."""
    network = DBN([RBM(16, 15), RBM(15, 10)], 'Test')
    dataset = DataSet.fromWhatever('top_left')
    # drive the training to completion; the test passes as long as no
    # exception is raised
    for _ in network.learn(dataset):
        pass
def test_init():
    """A DBN can be constructed from a list of RBMs."""
    layers = [RBM(4, 7), RBM(7, 3)]
    assert len(DBN(layers, 'Test')) == 2
def test_append():
    """A RBM can be appended to a DBN."""
    network = DBN()
    for machine in (RBM(3, 4), RBM(4, 7)):
        network.append(machine)
    assert len(network) == 2
def train(request):
    """Set up and register a training job for the network described in the HTTP request.

    Reads the dataset name, layer sizes and hyper-parameters from
    ``request.POST``, builds a DBN, registers a lazy training generator in the
    global ``training_jobs`` table under a fresh random id, evicts the
    caller's previous job and any job older than five hours, and returns the
    new job id in the HTTP response.  Returns HTTP 401 when the request is
    not authorized.
    """
    # Security fix: the original echoed request.POST['pass'] to stdout,
    # leaking the client's password into the server logs.
    if not authorized(request.POST['pass']):
        return HttpResponse(status = 401)

    trainset_name = request.POST['dataset']
    trainset = DataSet.fromWhatever(trainset_name)

    try:
        num_layers = 1 + int(request.POST['num_hid_layers'])
    except ValueError:
        num_layers = 1  # fall back to no hidden layers on a missing/invalid count

    std_dev = float(request.POST['std_dev'])
    net = DBN(name = trainset_name)
    vis_size = int(request.POST['vis_sz'])
    for layer in range(1, num_layers):
        hid_size = int(request.POST['hid_sz_' + str(layer)])
        print('creating a', vis_size, 'x', hid_size, 'RBM...')
        net.append(RBM(vis_size, hid_size, std_dev = std_dev))
        vis_size = hid_size  # the next RBM's visible layer is this hidden layer

    epochs = request.POST['epochs']
    config = {
        'max_epochs'   : int(epochs if (epochs != 'inf') else maxsize),
        'batch_size'   : int(request.POST['batch_size']),
        'learn_rate'   : float(request.POST['learn_rate']),
        'momentum'     : float(request.POST['momentum']),
        'std_dev'      : std_dev,
        'spars_target' : float(request.POST['spars_target'])
    }
    # sanity check: the batch size must evenly divide the dataset, so shrink
    # it until it does (terminates: every length is divisible by 1)
    if len(trainset) % config['batch_size'] != 0:
        print('encountered batch size', config['batch_size'], 'for dataset with',
              len(trainset), 'examples: adjusting batch size to', end = ' ')
        while len(trainset) % config['batch_size'] != 0:
            config['batch_size'] -= 1
        print(config['batch_size'])

    # NOTE(review): job ids are handed back to clients; if they must be
    # unguessable, generate them with the `secrets` module instead of `random`.
    random_id = ''.join(random.choice(string.ascii_uppercase + string.digits)
                        for _ in range(10))
    training_jobs[random_id] = {
        'birthday': time(),
        'network': net,
        'generator': net.learn(trainset, Configuration(**config))
    }

    # delete the old client job that is being replaced (if any):
    last_job = request.POST['last_job_id']
    if last_job in training_jobs:
        del training_jobs[last_job]

    # Garbage-collect every pending job older than five hours.  The original
    # probed a single random job per request, which could leave expired jobs
    # alive indefinitely; sweeping all of them is deterministic and safe
    # (the job just created has age ~0 and is never selected).
    now = time()
    expired = [job_id for job_id, job in training_jobs.items()
               if now - job['birthday'] > 18000]
    for job_id in expired:
        print('deleting old job n.', job_id)
        del training_jobs[job_id]

    return HttpResponse(random_id)