import matplotlib
matplotlib.use('Agg')
from dataObj.image import pvpObj
from tf.slp_sparse_code import SLP
import numpy as np
import pdb

# Input paths for each CIFAR split: a .pvp data file plus its ground-truth
# (Gt) list file.
trainFileList, trainGtList = (
    "/home/slundquist/mountData/tfSparseCode/fista_cifar_nf256_eval/fista_train_cifar_256_eval.pvp",
    "/home/slundquist/mountData/datasets/cifar/images/train.txt",
)
testFileList, testGtList = (
    "/home/slundquist/mountData/tfSparseCode/fista_cifar_nf256_eval/fista_test_cifar_256_eval.pvp",
    "/home/slundquist/mountData/datasets/cifar/images/test.txt",
)

# Data objects that TensorFlow pulls its inputs from. Both splits use the
# same (16, 16, 256) shape -- presumably (y, x, features); confirm against
# pvpObj -- with crop resizing, shuffle=True and seed=None.
trainDataObj = pvpObj(
    trainFileList,
    trainGtList,
    (16, 16, 256),
    resizeMethod="crop",
    shuffle=True,
    skip=1,
    seed=None,
    # NOTE(review): presumably restricts training to indices 0..127, in line
    # with the "slp_limited" runDir in params -- confirm against pvpObj.
    rangeIdx=range(128),
)
testDataObj = pvpObj(
    testFileList,
    testGtList,
    (16, 16, 256),
    resizeMethod="crop",
    shuffle=True,
    skip=1,
    seed=None,
)

params = {
    #Base output directory
    'outDir':          "/home/slundquist/mountData/tfSparseCode/",
    #Inner run directory
    'runDir':          "/cifar_fista_slp_limited/",
    'tfDir':           "/tfout",
    #Save parameters
    'ckptDir':         "/checkpoints/",
    'saveFile':        "/save-model",
    'savePeriod':      100, #In terms of displayPeriod
    #output plots directory
    'plotDir':         "plots/",
    'plotPeriod':      100, #With respect to displayPeriod
import matplotlib
matplotlib.use('Agg')
from dataObj.image import pvpObj
from tf.mlp_sparse_code import MLPSP
import numpy as np
import pdb

# Input paths for each CIFAR split: a .pvp data file plus its ground-truth
# (Gt) list file.
trainFileList, trainGtList = (
    "/home/slundquist/mountData/tfSparseCode/fista_cifar_nf256_eval/fista_train_cifar_256_eval.pvp",
    "/home/slundquist/mountData/datasets/cifar/images/train.txt",
)
testFileList, testGtList = (
    "/home/slundquist/mountData/tfSparseCode/fista_cifar_nf256_eval/fista_test_cifar_256_eval.pvp",
    "/home/slundquist/mountData/datasets/cifar/images/test.txt",
)

# Data objects that TensorFlow pulls its inputs from. Both splits use the
# same (16, 16, 256) shape -- presumably (y, x, features); confirm against
# pvpObj -- with crop resizing, shuffle=True and seed=None. No rangeIdx is
# passed here, so pvpObj's default index range applies to both splits.
trainDataObj = pvpObj(
    trainFileList,
    trainGtList,
    (16, 16, 256),
    resizeMethod="crop",
    shuffle=True,
    skip=1,
    seed=None,
)
testDataObj = pvpObj(
    testFileList,
    testGtList,
    (16, 16, 256),
    resizeMethod="crop",
    shuffle=True,
    skip=1,
    seed=None,
)

params = {
    #Base output directory
    'outDir':          "/home/slundquist/mountData/tfSparseCode/",
    #Inner run directory
    'runDir':          "/cifar_fista_mlp_nf256/",
    'tfDir':           "/tfout",
    #Save parameters
    'ckptDir':         "/checkpoints/",
    'saveFile':        "/save-model",
    'savePeriod':      100, #In terms of displayPeriod
    #output plots directory
    'plotDir':         "plots/",
    'plotPeriod':      100, #With respect to displayPeriod
trainGtList = "/home/slundquist/mountData/cifar_pv/cifar_trainset/cifar_train.txt"

# Test-set .pvp file: S1 is active. Other test files previously used here:
#   /home/slundquist/mountData/cifar_pv/cifar_testset/S4.pvp
#   /home/slundquist/mountData/cifar_pv/cifar_testset/S3.pvp
#   /home/slundquist/mountData/cifar_pv/cifar_testset/S2.pvp
testFileList, testGtList = (
    "/home/slundquist/mountData/cifar_pv/cifar_testset/S1.pvp",
    "/home/slundquist/mountData/cifar_pv/cifar_testset/cifar_test.txt",
)

# Data objects that TensorFlow pulls its inputs from. The active shape is
# (16, 16, 24), used together with S1.pvp. Other shapes previously used
# here: (8, 8, 96), (4, 4, 384), (1, 1, 1536) -- presumably paired with
# S2, S3 and S4 respectively; confirm before switching layers.
# NOTE(review): trainFileList is not assigned in this section; it must be
# defined before this point for the call below to work.
trainDataObj = pvpObj(
    trainFileList,
    trainGtList,
    (16, 16, 24),
    resizeMethod="crop",
    shuffle=True,
    skip=1,
    seed=None,
)
testDataObj = pvpObj(
    testFileList,
    testGtList,
    (16, 16, 24),
    resizeMethod="crop",
    shuffle=True,
    skip=1,
    seed=None,
)

params = {
    #Base output directory
    'outDir':          "/home/slundquist/mountData/tfLCA/",
    #Inner run directory
    'runDir':          "/cifar_lca_topdown_s1only/",
    'tfDir':           "/tfout",
    #Save parameters
    'ckptDir':         "/checkpoints/",
    'saveFile':        "/save-model",
    'savePeriod':      100, #In terms of displayPeriod
    #output plots directory
    'plotDir':         "plots/",
    'plotPeriod':      100, #With respect to displayPeriod