def __init__(self, path=dataPath, isTrain=True, nTrain=0, nTest=0):
    """Collect the JPEG paths of one MNIST split, extracting the images first if none exist.

    Args:
        path: Directory prefix of the MNIST data. Images are looked up (and,
            if missing, written) under ``path + "train/"`` or
            ``path + "test/"``. The prefix is concatenated directly, so it
            should end with a path separator.
        isTrain: Use the training split when truthy, the test split otherwise.
        nTrain: Number of training images to extract when none exist yet;
            any value <= 0 means all 60000.
        nTest: Number of test images to extract when none exist yet;
            any value <= 0 means all 10000.

    Sets:
        self.data: List of ``*.jpg`` paths found in the selected split directory.
    """
    # Split-specific settings; the rest of the logic is shared by both splits.
    if isTrain:
        subDir, nRequested, nTotal, isTrainSplit = "train/", nTrain, 60000, True  # 60000 images in total
    else:
        subDir, nRequested, nTotal, isTrainSplit = "test/", nTest, 10000, False  # 10000 images in total
    pattern = path + subDir + "*.jpg"

    # Check the directory that is actually read below (``path``), not the module
    # default ``dataPath``; otherwise a caller-supplied ``path`` is never populated.
    if len(glob(pattern)) == 0:
        mnist = loadMnistData.MnistData(path, isOneHot=False)
        # A non-positive request means "extract the whole split".
        nImage = nRequested if nRequested > 0 else nTotal
        # NOTE(review): third argument presumably selects the train/test source
        # inside MnistData.saveImage — confirm against loadMnistData.
        mnist.saveImage(nImage, path + subDir, isTrainSplit)

    # Re-glob so freshly extracted images are included.
    self.data = glob(pattern)
h14 = h13 + b4 loss = tf.compat.v1.losses.softmax_cross_entropy(onehot_labels=y_, logits=h14) acc = tf.compat.v1.metrics.accuracy(labels=tf.argmax(y_, axis=1), predictions=tf.argmax(h14, axis=1))[1] tf.contrib.quantize.experimental_create_training_graph( tf.compat.v1.get_default_graph(), symmetric=True, use_qdq=True, quant_delay=800) trainStep = tf.compat.v1.train.AdamOptimizer(0.001).minimize(loss) mnist = loadMnistData.MnistData(dataPath, isOneHot=True) with tf.Session(graph=g1) as sess: sess.run( tf.group(tf.compat.v1.global_variables_initializer(), tf.compat.v1.local_variables_initializer())) for i in range(1000): xSample, ySample = mnist.getBatch(nTrainbatchSize, True) trainStep.run(session=sess, feed_dict={ x: xSample.reshape(-1, 1, 28, 28), y_: ySample }) if (i % 100 == 0): xSample, ySample = mnist.getBatch(100, False) test_accuracy = sess.run(acc, { x: xSample.reshape(-1, 1, 28, 28),
#
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Export MNIST samples as image files into ./train/ and ./test/.

Usage: python <this script> [nTrain] [nTest]

nTrain defaults to 600 and nTest to 100. An argument that is not a plain
decimal integer falls back to its default.
"""
import sys

import loadMnistData


def _countArg(index, default):
    """Return ``sys.argv[index]`` as an int, or *default* if it is absent or not decimal.

    ``str.isdecimal()`` is used rather than ``str.isdigit()``: ``isdigit()``
    also accepts characters such as superscript digits ("²") that ``int()``
    rejects, so it does not reliably guard the conversion.
    """
    if len(sys.argv) > index and sys.argv[index].isdecimal():
        return int(sys.argv[index])
    return default


nTrain = _countArg(1, 600)  # number of training images to export
nTest = _countArg(2, 100)  # number of test images to export

mnist = loadMnistData.MnistData("./", isOneHot=False)
mnist.saveImage(nTrain, "./train/", True)  # 60000 images in total
mnist.saveImage(nTest, "./test/", False)  # 10000 images in total