def test_save():
    """A DataSet object can be saved to a CSV, Pickle, or Matlab file."""
    a = DataSet(np.random.rand(8, 5))
    a.save('test.mat')
    if SCIPY_AVAILABLE:
        assert exists('test.mat')
        os.remove('test.mat')

def test_fromMatlab():
    """A DataSet object can be constructed from a Matlab file."""
    a = np.random.rand(8, 5)
    DataSet(a).save('test.mat')
    b = DataSet.fromMatlab('test.mat')
    if SCIPY_AVAILABLE:
        assert a.shape == b.shape
        os.remove('test.mat')
    else:
        assert len(b) == 0

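# For context, a minimal sketch of the .mat round-trip the two tests above
# exercise, written directly against scipy.io rather than the DataSet API
# (the 'data' key is an assumption for illustration; DataSet's own .mat
# layout may differ):
def _matlab_roundtrip_sketch():
    from scipy.io import savemat, loadmat
    a = np.random.rand(8, 5)
    savemat('test.mat', {'data': a})  # write a dict of named arrays
    b = loadmat('test.mat')['data']   # read the array back by its key
    os.remove('test.mat')
    return a.shape == b.shape         # shapes survive the round-trip
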
def test_RBM_learn():
    """An RBM can learn to generate samples from a dataset."""
    net = RBM(16, 15)
    trainset = DataSet.fromWhatever('top_left')
    trainer = CDTrainer(net)
    for err in trainer.run(trainset):
        pass  # exhaust the training generator

def test_DBN_error():
    """The reconstruction error of a DBN is at most 1."""
    net = DBN([RBM(16, 15), RBM(15, 10)], 'Test')
    trainset = DataSet.fromWhatever('top_left')
    for train_info in net.learn(trainset):
        pass  # exhaust the training generator
    mean_err = np.sqrt(((trainset[0] - net.evaluate(trainset[0])) ** 2).mean())
    assert mean_err <= 1

def test_RBM_error():
    """The reconstruction error of an RBM is at most 1."""
    net = RBM(16, 15)
    trainset = DataSet.fromWhatever('top_left')
    trainer = CDTrainer(net)
    for err in trainer.run(trainset):
        pass  # exhaust the training generator
    mean_err = np.sqrt(((trainset[0] - net.evaluate(trainset[0])) ** 2).mean())
    assert mean_err <= 1

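# The error metric in the two tests above is a root-mean-square error (RMSE)
# over a single example; a tiny standalone numpy check of the formula, with
# illustrative values only:
def _rmse_sketch():
    x = np.array([1.0, 0.0, 1.0])          # original example
    y = np.array([0.5, 0.5, 0.5])          # reconstruction
    rmse = np.sqrt(((x - y) ** 2).mean())  # sqrt of the mean squared difference
    return rmse == 0.5                     # (0.25 + 0.25 + 0.25) / 3 = 0.25
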
def getInput(request):
    """Return a specific input image of a specific dataset."""
    if not authorized(request.GET['pass']):
        return HttpResponse(status=401)
    dataset_name = request.GET['dataset']
    if dataset_name not in datasets_cache:
        datasets_cache[dataset_name] = DataSet.fromWhatever(dataset_name)
        print('cached', dataset_name, 'dataset')
    dataset = datasets_cache[dataset_name]
    index = int(request.GET['index'])
    if index < 0:  # a negative index means "pick an example at random"
        index = random.randint(0, len(dataset) - 1)
    image = dataset[index].tolist()
    response = heatmap(image)
    json_response = json.dumps(response)
    return HttpResponse(json_response, content_type='application/json')

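# Note: the json.dumps + HttpResponse pair above could also be written with
# Django's built-in JsonResponse shortcut (safe=False permits non-dict
# payloads); shown only as an equivalent alternative:
#
#     from django.http import JsonResponse
#     return JsonResponse(heatmap(image), safe=False)
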
def test_DBN_learn():
    """A DBN can learn to generate samples from a dataset."""
    net = DBN([RBM(16, 15), RBM(15, 10)], 'Test')
    trainset = DataSet.fromWhatever('top_left')
    for train_info in net.learn(trainset):
        pass  # exhaust the training generator

def train(request):
    """Set up a network to be trained according to the parameters in the HTTP request."""
    if not authorized(request.POST['pass']):
        return HttpResponse(status=401)
    trainset_name = request.POST['dataset']
    trainset = DataSet.fromWhatever(trainset_name)
    try:
        num_layers = 1 + int(request.POST['num_hid_layers'])
    except ValueError:
        num_layers = 1
        # return HttpResponse({'error': 'you haven\'t specified [...]'})
    std_dev = float(request.POST['std_dev'])
    net = DBN(name=trainset_name)
    vis_size = int(request.POST['vis_sz'])
    for layer in range(1, num_layers):
        hid_size = int(request.POST['hid_sz_' + str(layer)])
        print('creating a', vis_size, 'x', hid_size, 'RBM...')
        net.append(RBM(vis_size, hid_size, std_dev=std_dev))
        vis_size = hid_size  # for constructing the next RBM
    epochs = request.POST['epochs']
    config = {
        'max_epochs':   int(epochs) if epochs != 'inf' else maxsize,
        'batch_size':   int(request.POST['batch_size']),
        'learn_rate':   float(request.POST['learn_rate']),
        'momentum':     float(request.POST['momentum']),
        'std_dev':      std_dev,
        'spars_target': float(request.POST['spars_target'])
    }
    # sanity check for batch size:
    if len(trainset) % config['batch_size'] != 0:
        print('encountered batch size', config['batch_size'], 'for dataset with',
              len(trainset), 'examples: adjusting batch size to', end=' ')
        while len(trainset) % config['batch_size'] != 0:
            config['batch_size'] -= 1
        print(config['batch_size'])
    random_id = ''.join(random.choice(string.ascii_uppercase + string.digits)
                        for i in range(10))
    training_jobs[random_id] = {
        'birthday': time(),
        'network': net,
        'generator': net.learn(trainset, Configuration(**config))
    }
    # delete the old client job that is being replaced (if any):
    last_job = request.POST['last_job_id']
    if last_job in training_jobs:
        del training_jobs[last_job]
    # delete a random pending job older than five hours:
    random_old_job = random.choice(list(training_jobs.keys()))
    if time() - training_jobs[random_old_job]['birthday'] > 18000:
        print('deleting old job n.', random_old_job)
        del training_jobs[random_old_job]  # risky...
    return HttpResponse(random_id)

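# The batch-size sanity check above finds the largest batch size not
# exceeding the requested one that divides the dataset evenly; the same
# logic in isolation (a sketch, with a hypothetical helper name):
def _adjust_batch_size(n_examples, batch_size):
    while n_examples % batch_size != 0:
        batch_size -= 1  # always terminates: 1 divides everything
    return batch_size

# e.g. _adjust_batch_size(10, 4) == 2, since neither 4 nor 3 divides 10
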
# password checking:
PASSWORD = '******'

def authorized(password):
    return password == PASSWORD

# pending training jobs on the server:
training_jobs = {}

# actual datasets, for caching input examples:
datasets_cache = {}

# datasets available on the server (name -> example size):
datasets_info = {}
datasets_name = sorted(DataSet.allSets(), key=lambda s: s.lower())
for d in datasets_name:
    try:
        datasets_info[d] = DataSet.fromWhatever(d).shape[1]
    except (pickle.UnpicklingError, IndexError):
        pass  # skip unreadable or empty datasets

def index(request):
    """Return the main page."""
    context = {
        'config': Configuration(),   # default configuration
        'datasets': datasets_info    # available datasets
    }
    return render(request, 'DBNtrain/index.html', context)

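# For reference, a minimal urls.py wiring for the three views in this module
# (a sketch with hypothetical URL paths; the project's actual routing may
# differ):
#
#     from django.urls import path
#     from . import views
#
#     urlpatterns = [
#         path('', views.index, name='index'),
#         path('train/', views.train, name='train'),
#         path('input/', views.getInput, name='getInput'),
#     ]
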
def test_fromWhatever():
    """A DataSet object can be constructed from a file."""
    dataset = DataSet.fromWhatever('top_left')
    assert dataset.shape == (10, 16)

def test_fromPickle():
    """A DataSet object can be constructed from a Pickle file."""
    dataset = DataSet.fromPickle(full('top_left.pkl'))
    assert dataset.shape == (10, 16)

def test_fromCSV():
    """A DataSet object can be constructed from a CSV file."""
    dataset = DataSet.fromCSV(full('top_left.csv'))
    assert dataset.shape == (10, 16)

def test_DataSetConstructor():
    """A DataSet object can be constructed from an array."""
    data = [1, 2, 3, 4, 5, 6, 7]
    dataset = DataSet(data)
    assert all(dataset.data[i] == data[i] for i in range(len(data)))
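
# All tests in this module are plain pytest-style functions; assuming pytest
# is installed, they can be run from the repository root with:
#
#     pytest -v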