示例#1
0
文件: gan11.py 项目: abiaozsh/MyCode
            def save(idx, gSaver, dSaver):
                """Write the generator and discriminator savers to two binary files.

                Produces "gan11g<idx>.bin" from ``gSaver`` and "gan11d<idx>.bin"
                from ``dSaver``. Every element of each iterable is called with
                the open file handle returned by ``ConvNet.openBinaryFileW``.
                Presumably each callable serializes one layer's parameters
                (NOTE(review): confirm against the getSaver implementation).

                Args:
                    idx: Suffix embedded in both file names (converted with str()).
                    gSaver: Iterable of writer callables for the generator file.
                    dSaver: Iterable of writer callables for the discriminator file.
                """
                def _write_all(path, savers):
                    # Open *path*, run every saver against it, then flush and close it.
                    out = ConvNet.openBinaryFileW(path)
                    try:
                        for saver in savers:
                            saver(out)
                        out.flush()
                    finally:
                        # Close even when a saver raises, so the handle is never leaked.
                        out.close()

                print("start save")
                _write_all("gan11g" + str(idx) + ".bin", gSaver)
                _write_all("gan11d" + str(idx) + ".bin", dSaver)
                print("end save")
示例#2
0
def _evaluate_accuracy(sess, num_samples):
    """Return the fraction of ``num_samples`` single test samples predicted correctly.

    Each sample comes from ``MNISTDataLargeSet.extract_data(1)``. It is run
    through the ``_test`` op via the ``testOne`` placeholder. A prediction
    counts as correct when its argmax along axis 1 equals that of the one-hot
    label. Returns 0.0 when ``num_samples`` is 0.
    """
    correct = 0
    for _ in range(num_samples):
        test_data = MNISTDataLargeSet.extract_data(1)
        test_label = MNISTDataLargeSet.extract_label(1, onehot=True)
        prediction = sess.run(_test, feed_dict={testOne: test_data})
        if np.argmax(prediction, 1) == np.argmax(test_label, 1):
            correct += 1
    if not num_samples:
        return 0.0
    # Count hits and divide once: repeatedly adding 0.01 accumulates float
    # rounding error. float() avoids integer division under Python 2.
    return float(correct) / num_samples


def _run_training_steps(sess, num_steps):
    """Run ``optimizer`` for ``num_steps`` fresh batches and return the summed loss."""
    total_loss = 0.0
    for _ in range(num_steps):
        train_data = MNISTDataLargeSet.extract_data(BATCH_SIZE)
        train_label = MNISTDataLargeSet.extract_label(BATCH_SIZE, onehot=True)
        _, step_loss = sess.run(
            [optimizer, loss],
            feed_dict={labels_node: train_label, inputlayer: train_data},
        )
        total_loss = total_loss + step_loss
    return total_loss


def _save_network(sess, path):
    """Write every entry of ``plist`` to *path* via its ``getSaver(sess, True)`` callable."""
    savers = [entry.getSaver(sess, True) for entry in plist]
    out = ConvNet.openBinaryFileW(path)
    try:
        for saver in savers:
            saver(out)
        out.flush()
    finally:
        # Close even when a saver raises, so the handle is never leaked.
        out.close()


def train(epochs=10, test_samples=100, train_steps=100):
    """Train the network and save it to "MNISTLargeSet.bin" after every epoch.

    Each epoch does three things in order:
    1. Measures accuracy on ``test_samples`` single samples. The value printed
       for epoch j is therefore measured *before* that epoch's training.
    2. Runs ``train_steps`` optimizer steps of ``BATCH_SIZE`` samples each.
    3. Prints the epoch index, the accuracy and the summed loss, then
       overwrites the saved network file.

    Relies on module-level graph objects defined elsewhere in the file
    (``_test``, ``testOne``, ``optimizer``, ``loss``, ``labels_node``,
    ``inputlayer``, ``plist``) and on the TF1 ``tf.Session`` API.

    Args:
        epochs: Number of evaluate/train/save rounds (default 10).
        test_samples: Samples used to estimate accuracy per epoch (default 100).
        train_steps: Optimizer steps per epoch (default 100).
    """
    # To force CPU-only execution, create the session with
    # tf.Session(config=tf.ConfigProto(device_count={'GPU': 0})) instead.
    with tf.Session() as sess:
        # Initialize all graph variables.
        sess.run(tf.global_variables_initializer())

        for j in range(epochs):
            accurate = _evaluate_accuracy(sess, test_samples)
            totalLoss = _run_training_steps(sess, train_steps)
            print(j, "accu:", accurate, "loss:", totalLoss)
            # Persist the trained network (overwritten every epoch).
            _save_network(sess, "MNISTLargeSet.bin")