def __init__(self, path=dataPath, isTrain=True, nTrain=0, nTest=0):
    """Collect the MNIST JPEG file list, exporting the images first if needed.

    Images are looked up in ``path + "train/"`` (training split) or
    ``path + "test/"`` (test split). If that directory contains no ``*.jpg``
    files, the raw MNIST data under ``path`` is loaded with
    ``loadMnistData.MnistData`` and exported through its ``saveImage`` method
    before the directory is globbed.

    Args:
        path: Data root directory. It is joined by plain string
            concatenation, so it is expected to end with a path separator
            (e.g. ``"./"``).
        isTrain: True selects the training split, False the test split.
        nTrain: Number of training images to export when exporting is
            needed; 0 or negative means all 60000. Ignored if JPEGs already
            exist.
        nTest: Number of test images to export when exporting is needed;
            0 or negative means all 10000. Ignored if JPEGs already exist.

    Sets:
        self.data: List of JPEG file paths as returned by ``glob`` (order is
            whatever ``glob`` yields).
    """
    if isTrain:
        subDir, nAll, nWanted = "train/", 60000, nTrain  # 60000 images in total
    else:
        subDir, nAll, nWanted = "test/", 10000, nTest  # 10000 images in total
    imageDir = path + subDir
    pattern = imageDir + "*.jpg"
    # Check the directory we actually read from (built from ``path``), not
    # the module-level default ``dataPath``: they differ whenever a caller
    # passes a custom path.
    if not glob(pattern):
        mnist = loadMnistData.MnistData(path, isOneHot=False)
        # A non-positive count means "export the whole split".
        mnist.saveImage(nWanted if nWanted > 0 else nAll, imageDir, bool(isTrain))
    self.data = glob(pattern)
    # Add bias b4 to h13 (both defined above, outside this view). h14 is used
    # below as the logits for both the loss and the predictions.
    h14 = h13 + b4

    # Softmax cross-entropy between the one-hot labels y_ and the logits h14.
    loss = tf.compat.v1.losses.softmax_cross_entropy(onehot_labels=y_,
                                                     logits=h14)
    # Streaming accuracy of argmax(logits) against argmax(labels). Index [1]
    # selects the metric's update op, whose value is the accuracy accumulated
    # in local variables over every run since they were last initialized.
    # NOTE(review): accuracy read from `acc` is therefore cumulative across
    # all evaluations, not per-batch, unless the local variables are reset.
    acc = tf.compat.v1.metrics.accuracy(labels=tf.argmax(y_, axis=1),
                                        predictions=tf.argmax(h14, axis=1))[1]

    # Rewrite the default graph in place for quantization-aware training:
    # simulated-quantization ops are inserted with symmetric ranges, using
    # quantize/dequantize ("QDQ") ops because use_qdq=True, and quantization
    # becomes active after quant_delay=800 steps. This is called before the
    # optimizer below so gradient ops are built on the rewritten graph.
    # NOTE(review): tf.contrib exists only in TensorFlow 1.x.
    tf.contrib.quantize.experimental_create_training_graph(
        tf.compat.v1.get_default_graph(),
        symmetric=True,
        use_qdq=True,
        quant_delay=800)

    # Adam optimizer with learning rate 0.001, minimizing the loss above.
    trainStep = tf.compat.v1.train.AdamOptimizer(0.001).minimize(loss)

# MNIST loader for training; isOneHot=True requests one-hot labels, which are
# fed to y_ below. dataPath is defined outside this view.
mnist = loadMnistData.MnistData(dataPath, isOneHot=True)
# NOTE(review): tf.Session is the TF1-only spelling; elsewhere the code uses
# tf.compat.v1.*, where the equivalent is tf.compat.v1.Session.
with tf.Session(graph=g1) as sess:
    # Initialize both global and local variables (local variables hold e.g.
    # the counters of tf.compat.v1.metrics.* streaming metrics).
    sess.run(
        tf.group(tf.compat.v1.global_variables_initializer(),
                 tf.compat.v1.local_variables_initializer()))
    # 1000 training iterations.
    for i in range(1000):
        # nTrainbatchSize is defined outside this view; the second argument
        # presumably selects the training split (True) -- TODO confirm in
        # loadMnistData.MnistData.getBatch.
        xSample, ySample = mnist.getBatch(nTrainbatchSize, True)
        # Each batch is reshaped to (N, 1, 28, 28), presumably NCHW with a
        # single channel -- TODO confirm against the placeholder x.
        trainStep.run(session=sess,
                      feed_dict={
                          x: xSample.reshape(-1, 1, 28, 28),
                          y_: ySample
                      })
        # Every 100 iterations, evaluate on a 100-sample batch; the second
        # argument False presumably selects the test split -- TODO confirm.
        if (i % 100 == 0):
            xSample, ySample = mnist.getBatch(100, False)
            test_accuracy = sess.run(acc, {
                x: xSample.reshape(-1, 1, 28, 28),
#
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys
import loadMnistData

def _countArg(index, default):
    """Return ``int(sys.argv[index])`` if that argument exists and consists
    only of digit characters; otherwise return *default*."""
    if len(sys.argv) > index and sys.argv[index].isdigit():
        return int(sys.argv[index])
    return default


# Optional command-line arguments: how many training / test images to export.
nTrain = _countArg(1, 600)
nTest = _countArg(2, 100)

mnist = loadMnistData.MnistData("./", isOneHot=False)
# The third saveImage argument is True for the training export and False for
# the test export.
mnist.saveImage(nTrain, "./train/", True)  # 60000 images in total
mnist.saveImage(nTest, "./test/", False)  # 10000 images in total