-
Notifications
You must be signed in to change notification settings - Fork 9
/
fista_cifar_train.py
67 lines (61 loc) · 1.97 KB
/
fista_cifar_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import matplotlib
matplotlib.use('Agg')
from dataObj.image import cifarObj
from tf.fista import FISTA
#from plot.roc import makeRocCurve
import numpy as np
import pdb
# List file for the CIFAR training images (presumably one image path per
# line; confirm against cifarObj).
trainImageLists = "/home/slundquist/mountData/datasets/cifar/images/train.txt"
# Seed forwarded to cifarObj's shuffling; None means no explicit seed is set.
randImageSeed = None

# Training data source the FISTA model pulls batches from, with images
# resized via the "pad" method and shuffled.
trainDataObj = cifarObj(
    trainImageLists,
    resizeMethod="pad",
    shuffle=True,
    seed=randImageSeed,
)

# Test-set counterpart, currently disabled:
# testImageLists = "/home/slundquist/mountData/datasets/cifar/images/test.txt"
# testDataObj = cifarObj(testImageLists, resizeMethod="pad")
# Configuration handed to the FISTA constructor. Directory entries are plain
# strings; how FISTA combines them (e.g. outDir + runDir + ...) is not visible
# here -- TODO confirm before changing leading/trailing slashes.
params = dict(
    # ---- Output locations ----
    outDir="/home/slundquist/mountData/tfSparseCode/",  # base output directory
    runDir="/fista_cifar_nf256/",  # per-run subdirectory
    tfDir="/tfout",
    # ---- Checkpoint saving ----
    ckptDir="/checkpoints/",
    saveFile="/save-model",
    savePeriod=100,  # measured in units of displayPeriod
    # ---- Plot output ----
    plotDir="plots/",
    plotPeriod=100,  # measured in units of displayPeriod
    # ---- Progress / TensorBoard ----
    progress=100,  # progress reporting step
    writeStep=200,  # how often to write out to TensorBoard
    # ---- Weight loading ----
    load=True,  # load weights from the checkpoint at loadFile
    loadFile="/home/slundquist/mountData/tfSparseCode/saved/fista_cifar_nf256.ckpt",
    # ---- Execution device ----
    device="/gpu:0",
    # ---- FISTA optimisation ----
    numIterations=100000,
    displayPeriod=200,
    batchSize=8,
    learningRateA=0.01,  # optimizer learning rate (A)
    learningRateW=1,  # optimizer learning rate (W)
    thresh=0.005,  # lambda in the energy function
    # ---- V1 layer geometry ----
    numV=256,  # number of V1 features
    VStrideY=2,  # V1 stride, y
    VStrideX=2,  # V1 stride, x
    patchSizeY=12,  # patch size, y
    patchSizeX=12,  # patch size, x
)
# Build the FISTA model from the configuration and training data, run it,
# then close its session.
tfObj = FISTA(params, trainDataObj)
print("Done init")
try:
    tfObj.runModel()
    print("Done run")
finally:
    # Close the session even when runModel() raises (including a
    # KeyboardInterrupt during a long run); previously the session was
    # leaked on any failure. Presumably closeSess() releases the underlying
    # TensorFlow session -- confirm in tf.fista.
    tfObj.closeSess()