Code Example #1
def test_save():
    """A DataSet object can be saved to a CSV, Pickle, or Matlab file."""
    a = DataSet(np.random.rand(8, 5))
    a.save('test.mat')
    if SCIPY_AVAILABLE:
        assert exists('test.mat')
        os.remove('test.mat')
Code Example #2
def test_fromMatlab():
    """A DataSet object can be constructed from a Matlab file."""
    a = np.random.rand(8, 5)
    DataSet(a).save('test.mat')
    b = DataSet.fromMatlab('test.mat')
    if SCIPY_AVAILABLE:
        assert a.shape == b.shape
        os.remove('test.mat')
    else:
        assert len(b) == 0
Code Example #3
def test_RBM_learn():
    """A RBM can learn to generate samples from a dataset."""
    net = RBM(16, 15)
    trainset = DataSet.fromWhatever('top_left')
    trainer = CDTrainer(net)
    for err in trainer.run(trainset):
        pass
Code Example #4
def test_DBN_error():
    """The reconstruction error of a DBN is less than 1."""
    net = DBN([RBM(16, 15), RBM(15, 10)], 'Test')
    trainset = DataSet.fromWhatever('top_left')
    for train_info in net.learn(trainset):
        pass
    mean_err = np.sqrt(((trainset[0] - net.evaluate(trainset[0]))**2).mean())
    assert mean_err <= 1
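For reference, the mean_err expression used in this test (and in the RBM version below) is the root-mean-square reconstruction error of the first training example. Written as a formula, with $x$ standing for trainset[0] and $\hat{x}$ for net.evaluate(trainset[0]), both vectors of length $n$ (here $n = 16$, the visible size of the first RBM):

$$\mathrm{mean\_err} = \sqrt{\frac{1}{n}\sum_{i=1}^{n}\bigl(x_i - \hat{x}_i\bigr)^2} \le 1$$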
Code Example #5
def test_RBM_error():
    """The reconstruction error of a RBM is less than 1."""
    net = RBM(16, 15)
    trainset = DataSet.fromWhatever('top_left')
    trainer = CDTrainer(net)
    for err in trainer.run(trainset):
        pass
    mean_err = np.sqrt(((trainset[0] - net.evaluate(trainset[0]))**2).mean())
    assert mean_err <= 1
Code Example #6
def getInput(request):
    """Return a specific input image of a specific dataset."""
    if not authorized(request.GET['pass']):
        return HttpResponse(status = 401)

    dataset_name = request.GET['dataset']
    if dataset_name not in datasets_cache:
        datasets_cache[dataset_name] = DataSet.fromWhatever(dataset_name)
        print('cached', dataset_name, 'dataset')
    dataset = datasets_cache[dataset_name]
    index = int(request.GET['index'])
    if index < 0:
        index = random.randint(0, len(dataset) - 1)

    image = dataset[index].tolist()
    response = heatmap(image)
    json_response = json.dumps(response)
    return HttpResponse(json_response, content_type = 'application/json')
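As a usage note, a client calls this view with the access password, the dataset name, and an example index (a negative index picks a random example) as query-string parameters, and gets the heatmap JSON back. Below is a minimal client-side sketch; the '/getInput' route and host are assumptions, since the URL configuration is not part of this excerpt, and only the parameter names come from the view itself.

import requests

# Hypothetical call to the getInput view above; the route is an assumption.
resp = requests.get('http://localhost:8000/getInput', params = {
    'pass': 'secret',       # placeholder, the real password is not shown here
    'dataset': 'top_left',  # dataset name, cached server-side
    'index': -1             # negative index -> random example
})
heatmap_json = resp.json()  # the JSON built from heatmap(image)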
Code Example #7
def test_DBN_learn():
    """A DBN can learn to generate samples from a dataset."""
    net = DBN([RBM(16, 15), RBM(15, 10)], 'Test')
    trainset = DataSet.fromWhatever('top_left')
    for train_info in net.learn(trainset):
        pass
Code Example #8
def train(request):
    """Set up a network to be trained according to
    the parameters in the HTTP request."""
    print(request.POST['pass'])
    if not authorized(request.POST['pass']):
        return HttpResponse(status = 401)

    trainset_name = request.POST['dataset']
    trainset = DataSet.fromWhatever(trainset_name)

    try:
        num_layers = 1 + int(request.POST['num_hid_layers'])
    except ValueError:
        num_layers = 1
        # return HttpResponse({'error': 'you haven\'t specified [...]'})

    std_dev = float(request.POST['std_dev'])
    net = DBN(name = trainset_name)
    vis_size = int(request.POST['vis_sz'])
    for layer in range(1, num_layers):
        hid_size = int(request.POST['hid_sz_' + str(layer)])
        print('creating a', vis_size, 'x', hid_size, 'RBM...')
        net.append(RBM(vis_size, hid_size, std_dev = std_dev))
        vis_size = hid_size # for constructing the next RBM

    epochs = request.POST['epochs']
    config = {
        'max_epochs'   : int(epochs if (epochs != 'inf') else maxsize),
        'batch_size'   : int(request.POST['batch_size']),
        'learn_rate'   : float(request.POST['learn_rate']),
        'momentum'     : float(request.POST['momentum']),
        'std_dev'      : std_dev,
        'spars_target' : float(request.POST['spars_target'])
    }

    # sanity check for batch size:
    if len(trainset) % config['batch_size'] != 0:
        print('encountered batch size', config['batch_size'],
              'for dataset with', len(trainset),
              'examples: adjusting batch size to', end = ' ')
        while len(trainset) % config['batch_size'] != 0:
            config['batch_size'] -= 1
        print(config['batch_size'])

    random_id = ''.join(random.choice(string.ascii_uppercase + string.digits) for i in range(10))
    training_jobs[random_id] = {
        'birthday': time(),
        'network': net,
        'generator': net.learn(trainset, Configuration(**config))
    }

    # delete the old client job that is being replaced (if any):
    last_job = request.POST['last_job_id']
    if last_job in training_jobs:
        del training_jobs[last_job]

    # delete a random pending job older than five hours:
    random_old_job = random.choice(list(training_jobs.keys()))
    if time() - training_jobs[random_old_job]['birthday'] > 18000:
        print('deleting old job n.', random_old_job)
        del training_jobs[random_old_job] # risky...

    return HttpResponse(random_id)
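To show how this view is meant to be driven, here is a hedged client-side sketch of a training request. The '/train' route, host, and hyperparameter values are assumptions (the URL configuration is not part of this excerpt); the form field names, however, are exactly the request.POST keys read above, and the response body is the job id the client later sends back as last_job_id.

import requests

# Hypothetical call to the train view above; route and values are assumptions.
form = {
    'pass': 'secret',        # placeholder for the real password
    'dataset': 'top_left',
    'num_hid_layers': 2,     # the view builds one RBM per hidden layer
    'vis_sz': 16,            # visible size of the first RBM
    'hid_sz_1': 15,          # hidden size of layer 1
    'hid_sz_2': 10,          # hidden size of layer 2
    'std_dev': 0.1,
    'epochs': 10,            # or 'inf' for an unbounded run
    'batch_size': 5,         # divides the 10 examples of 'top_left' evenly
    'learn_rate': 0.1,
    'momentum': 0.9,
    'spars_target': 0.05,
    'last_job_id': ''        # no previous job to replace
}
job_id = requests.post('http://localhost:8000/train', data = form).text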
Code Example #9

# password checking:
PASSWORD = '******'
def authorized(password):
    return password == PASSWORD

# pending training jobs on the server:
training_jobs = {}

# actual datasets, for caching input examples:
datasets_cache = {}

# datasets available on the server:
datasets_info = {}  # input size (shape[1]) of each dataset
datasets_name = sorted(DataSet.allSets(), key = lambda s: s.lower())
for d in datasets_name:
    try:
        datasets_info[d] = DataSet.fromWhatever(d).shape[1]
    except (pickle.UnpicklingError, IndexError):
        pass



def index(request):
    """Return the main page."""
    context = {
        'config': Configuration(), # default configuration
        'datasets': datasets_info  # available datasets
    }
    return render(request, 'DBNtrain/index.html', context)
Code Example #10
def test_fromWhatever():
    """A DataSet object can be constructed from a file."""
    dataset = DataSet.fromWhatever('top_left')
    assert dataset.shape == (10, 16)
Code Example #11
def test_fromPickle():
    """A DataSet object can be constructed from a Pickle file."""
    dataset = DataSet.fromPickle(full('top_left.pkl'))
    assert dataset.shape == (10, 16)
Code Example #12
def test_fromCSV():
    """A DataSet object can be constructed from a CSV file."""
    dataset = DataSet.fromCSV(full('top_left.csv'))
    assert dataset.shape == (10, 16)
Code Example #13
def test_DataSetConstructor():
    """A DataSet object can be constructed from an array."""
    data = [1, 2, 3, 4, 5, 6, 7]
    dataset = DataSet(data)
    assert all(dataset.data[i] == data[i] for i in range(len(data)))