# Evaluation graph setup: input placeholders -> perturbation -> network
# selected by opt.netType -> softmax and per-sample correctness.

# ------ define input data ------
# float32 images of shape [opt.batchSize,opt.H,opt.W] and one int64 class
# index per image
image = tf.placeholder(tf.float32,shape=[opt.batchSize,opt.H,opt.W])
label = tf.placeholder(tf.int64,shape=[opt.batchSize])
# ------ generate perturbation ------
pInit = data.genPerturbations(opt)
pInitMtrx = warp.vec2mtrx(opt,pInit)
# ------ build network ------
# NOTE: `image` is rebound here to the tensor with an extra trailing axis
image = tf.expand_dims(image,axis=-1)
imagePert = warp.transformImage(opt,image,pInitMtrx)
if opt.netType=="CNN":
    output = graph.fullCNN(opt,imagePert)
elif opt.netType=="STN":
    imageWarpAll = graph.STN(opt,imagePert)
    # only the last entry of the returned sequence is classified
    imageWarp = imageWarpAll[-1]
    output = graph.CNN(opt,imageWarp)
elif opt.netType=="IC-STN":
    # IC-STN is given the unperturbed image together with pInit, not imagePert
    imageWarpAll = graph.ICSTN(opt,image,pInit)
    imageWarp = imageWarpAll[-1]
    output = graph.CNN(opt,imageWarp)
else:
    # fail here with a clear message instead of a NameError on `output` below
    raise ValueError("unknown opt.netType: {0!r}".format(opt.netType))
softmax = tf.nn.softmax(output)
labelOnehot = tf.one_hot(label,opt.labelN)
# boolean per sample: True where the highest-scoring class equals the label
prediction = tf.equal(tf.argmax(softmax,1),label)

# load data
print(util.toMagenta("loading MNIST dataset..."))
trainData,validData,testData = data.loadMNIST("data/MNIST.npz")

# prepare model saver/summary writer
saver = tf.train.Saver(max_to_keep=20)

print(util.toYellow("======= EVALUATION START ======="))
timeStart = time.time()
# Training setup (PyTorch): load options, create the model output directory,
# then build the geometric/classifier networks, the loss and the optimizer.
opt = options.set(training=True)

# create directories for model output
util.mkdir("models_{0}".format(opt.group))

print(util.toMagenta("building network..."))
# NOTE(review): loss/optimizer are kept inside the device scope, following the
# dashed section markers; "# load data" is treated as the first statement
# outside it -- confirm against the intended layout.
with torch.cuda.device(0):
    # ------ build network ------
    # every variant yields a (geometric, classifier) pair
    if opt.netType == "CNN":
        geometric = graph.Identity()
        classifier = graph.FullCNN(opt)
    elif opt.netType == "STN":
        geometric = graph.STN(opt)
        classifier = graph.CNN(opt)
    elif opt.netType == "IC-STN":
        geometric = graph.ICSTN(opt)
        classifier = graph.CNN(opt)
    else:
        # fail here with a clear message instead of a NameError on
        # `geometric` when the optimizer groups are built
        raise ValueError("unknown opt.netType: {0!r}".format(opt.netType))
    # ------ define loss ------
    loss = torch.nn.CrossEntropyLoss()
    # ------ optimizer ------
    # one parameter group per network so each gets its own learning rate
    optimList = [
        {"params": geometric.parameters(), "lr": opt.lrGP},
        {"params": classifier.parameters(), "lr": opt.lrC},
    ]
    optim = torch.optim.SGD(optimList)

# load data
print(util.toMagenta("loading MNIST dataset..."))
# Training graph setup: perturbed inputs -> network selected by opt.type ->
# cross-entropy loss, optimizer step and per-sample correctness.

# generate training data on the fly
imageRawBatch = tf.placeholder(tf.float32,shape=[None,28,28],name="image")
pInitBatch = data.genPerturbations(opt)
pInitMtrxBatch = warp.vec2mtrxBatch(pInitBatch,opt)
ImBatch = data.imageWarpIm(imageRawBatch,pInitMtrxBatch,opt,name=None)

# build network
# NOTE(review): the trailing list/float arguments are passed straight through
# to the graph.* builders; their meaning is defined there -- confirm before
# changing them.
if opt.type=="CNN":
    outputBatch = graph.fullCNN(opt,ImBatch,[3,6,9,12,48],0.1)
elif opt.type=="STN":
    # STN operates on the already-perturbed images
    ImWarpBatch,pBatch = graph.STN(opt,ImBatch,pInitBatch,1,[4,8,48],0.01)
    outputBatch = graph.CNN(opt,ImWarpBatch,[3],0.03)
elif opt.type=="cSTN":
    # cSTN/ICSTN are given the raw images together with pInitBatch
    ImWarpBatch,pBatch = graph.cSTN(opt,imageRawBatch,pInitBatch,1,[4,8,48],0.01)
    outputBatch = graph.CNN(opt,ImWarpBatch,[3],0.03)
elif opt.type=="ICSTN":
    ImWarpBatch,pBatch = graph.ICSTN(opt,imageRawBatch,pInitBatch,opt.recurN,[4,8,48],0.01)
    outputBatch = graph.CNN(opt,ImWarpBatch,[3],0.03)
else:
    # fail here with a clear message instead of a NameError on `outputBatch`
    raise ValueError("unknown opt.type: {0!r}".format(opt.type))

# define loss/optimizer/summaries
# merged before lossSummary is created, so it only contains the summaries
# registered up to this point
imageSummaries = tf.summary.merge_all()
# float32 labels with 10 columns; compared below via argmax along axis 1
labelBatch = tf.placeholder(tf.float32,shape=[None,10],name="label")
softmaxLoss = tf.nn.softmax_cross_entropy_with_logits(logits=outputBatch,labels=labelBatch)
loss = tf.reduce_mean(softmaxLoss)
lossSummary = tf.summary.scalar("training loss",loss)
# two-element learning-rate vector supplied at run time
learningRate = tf.placeholder(tf.float32,shape=[2])
trainStep = util.setOptimizer(loss,learningRate,opt)
softmax = tf.nn.softmax(outputBatch)
# boolean per sample: True where predicted and labelled classes agree
prediction = tf.equal(tf.argmax(softmax,1),tf.argmax(labelBatch,1))

print("starting backpropagation...")
trainN = len(trainData["image"])
timeStart = time.time()
# Training graph setup: perturb the normalized input, build the network
# selected by opt.netType, then define loss, optimizer and summary lists.

# placeholders bundled in a list for later use
PH = [imageFull,label]
# ------ generate perturbation ------
pInit = data.genPerturbations(opt)
pInitMtrx = warp.vec2mtrx(opt,pInit)
# ------ build network ------
imagePert = warp.transformCropImage(opt,imageFullNormalize,pInitMtrx)
# presumably undoes the normalization behind imageFullNormalize (scale by
# sqrt(imageVar), shift by imageMean) -- confirm against its definition
imagePertRescale = imagePert*tf.sqrt(imageVar)+imageMean
if opt.netType=="CNN":
    # no imageWarp/imageWarpRescale is defined in this branch
    output = graph.fullCNN(opt,imagePert)
elif opt.netType=="STN":
    imageWarpAll = graph.STN(opt,imagePert)
    # only the last entry of the returned sequence is classified
    imageWarp = imageWarpAll[-1]
    output = graph.CNN(opt,imageWarp)
    imageWarpRescale = imageWarp*tf.sqrt(imageVar)+imageMean
elif opt.netType=="IC-STN":
    # IC-STN is given the unperturbed normalized image together with pInit
    imageWarpAll = graph.ICSTN(opt,imageFullNormalize,pInit)
    imageWarp = imageWarpAll[-1]
    output = graph.CNN(opt,imageWarp)
    imageWarpRescale = imageWarp*tf.sqrt(imageVar)+imageMean
else:
    # fail here with a clear message instead of a NameError on `output` below
    raise ValueError("unknown opt.netType: {0!r}".format(opt.netType))
softmax = tf.nn.softmax(output)
labelOnehot = tf.one_hot(label,opt.labelN)
# boolean per sample: True where the highest-scoring class equals the label
prediction = tf.equal(tf.argmax(softmax,1),label)
# ------ define loss ------
softmaxLoss = tf.nn.softmax_cross_entropy_with_logits(logits=output,labels=labelOnehot)
loss = tf.reduce_mean(softmaxLoss)
# ------ optimizer ------
# two scalar learning-rate placeholders handed to util.setOptimizer
lrGP_PH,lrC_PH = tf.placeholder(tf.float32,shape=[]),tf.placeholder(tf.float32,shape=[])
optim = util.setOptimizer(opt,loss,lrGP_PH,lrC_PH)
# ------ generate summaries ------
summaryImageTrain = []
summaryImageTest = []