コード例 #1
0
ファイル: test_interface.py プロジェクト: baylabs/hdl
# NOTE: rebinds `lena` from the image-loader function to the scaled image array.
lena = lena() / 256.0
height, width = lena.shape

# Extract all clean patches from the whole image (this script adds no noise),
# flatten each patch into a row vector, then standardize every pixel position
# to zero mean and unit variance across patches.
print('Extracting clean patches...')
t0 = time()
patch_size = (6, 6)
data = extract_patches_2d(lena, patch_size)
data = data.reshape(data.shape[0], -1)
data -= np.mean(data, axis=0)
data /= np.std(data, axis=0)
print('done in %.2fs.' % (time() - t0))

###############################################################################
# Learn the dictionary from clean patches

print('Learning the dictionary SparseCoding... ')
t0 = time()
# Seed the global NumPy RNG so the run is repeatable; presumably SparseCoding
# draws its initialization from this RNG -- the exact reference sums asserted
# at the end depend on it.
np.random.seed(42)
dico = SparseCoding(n_atoms=36, alpha=.1, max_iter=200)
V = dico.fit(data).components_
print('done in %.2fs.' % (time() - t0))

print('Transform the data...')
t0 = time()
# Encode only the first 10 patches against the learned dictionary.
coef = dico.transform(data[:10, :])
print('done in %.2fs.' % (time() - t0))

# Compare against hard-coded reference sums (presumably recorded from an
# earlier run); assert_almost_equal checks to 7 decimal places by default.
np.testing.assert_almost_equal(coef[:10, :].sum(), 3.23775956895)
np.testing.assert_almost_equal(V[:, :10].sum(), 39.9447109218)
コード例 #2
0
ファイル: demo_interface.py プロジェクト: baylabs/hdl
# Extract all clean patches from the left half of the image, flatten each
# patch into a row vector, then standardize every pixel position to zero mean
# and unit variance across patches.
print('Extracting clean patches...')
t0 = time()
patch_size = (7, 7)
# `//` keeps the slice bound an integer on both Python 2 and Python 3; with
# `/`, Python 3 produces a float, which is not a valid slice index.
# NOTE(review): the column bound uses `height`, which equals "the left half"
# only for a square image -- confirm it matches the region left un-noised
# where `distorted` is built.
data = extract_patches_2d(distorted[:, :height // 2], patch_size)
data = data.reshape(data.shape[0], -1)
data -= np.mean(data, axis=0)
data /= np.std(data, axis=0)
print('done in %.2fs.' % (time() - t0))

###############################################################################
# Learn the dictionary from clean patches (new way)

print('Learning the dictionary SparseCoding... ')
t0 = time()
dico = SparseCoding(n_atoms=100, alpha=1., max_iter=10000)
V = dico.fit(data).components_
# Keep the training time; it is reported both on stdout and in the figure title.
dt = time() - t0
print('done in %.2fs.' % dt)

pl.figure(figsize=(4.2, 4))
# Draw (up to) the first 100 learned atoms as patch-sized tiles in a 10x10
# grid, using a reversed grayscale colormap and no axis ticks.
for i, comp in enumerate(V[:100]):
    pl.subplot(10, 10, i + 1)
    pl.imshow(comp.reshape(patch_size), cmap=pl.cm.gray_r,
              interpolation='nearest')
    pl.xticks(())
    pl.yticks(())
# `%` binds tighter than `+`, so only the second literal is formatted before
# the two strings are concatenated.
pl.suptitle('Dictionary learned from Lena patches\n' +
            'Train time %.1fs on %d patches' % (dt, len(data)),
            fontsize=16)
pl.subplots_adjust(0.08, 0.02, 0.92, 0.85, 0.08, 0.23)