# -*- coding: utf-8 -*- # @Time : 2018/4/11 10:08 # @Author : jiaopan # @Email : [email protected] from __future__ import absolute_import from __future__ import division from __future__ import print_function import re import tensorflow as tf import utils FLAGS = tf.app.flags.FLAGS tf.app.flags.DEFINE_integer( 'batch_size', int(utils.configUtil("global.conf", "dataset", "batch_size")), """每个batch样本总数""") IMAGE_SIZE = int(utils.configUtil("global.conf", "dataset", "image_mat_size")) NUM_CLASSES = int(utils.configUtil("global.conf", "dataset", "num_class")) NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = int( utils.configUtil("global.conf", "train", "train_data_count")) NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = int( utils.configUtil("global.conf", "eval", "eval_data_count")) # 超参数设置 MOVING_AVERAGE_DECAY = 0.9999 NUM_EPOCHS_PER_DECAY = 350.0 LEARNING_RATE_DECAY_FACTOR = 0.1 INITIAL_LEARNING_RATE = 0.1 TOWER_NAME = 'tower' def activation_summary(x):
# Quick sanity check: print the configured raw training-data directory.
# train_data_dir is looked up in the [dataset] section of global.conf, which
# is where the other configUtil lookups of this key read it from; the former
# "data" section name did not match.
import utils

print(utils.configUtil("global.conf", "dataset", "train_data_dir"))
# Evaluation script setup: Python 2/3 compatibility imports, dependencies,
# and the command-line flags that configure evaluation. Most flag defaults
# are read from "global.conf" through utils.configUtil(file, section, option).
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import math
import time

from tensorflow.python.platform import gfile
import numpy as np
import tensorflow as tf

import network
import tfrecord
import utils

FLAGS = tf.app.flags.FLAGS

# Evaluation log directory; default from [eval] eval_log_dir.
tf.app.flags.DEFINE_string(
    'eval_dir', utils.configUtil("global.conf", "eval", "eval_log_dir"),
    """验证日志目录.""")
# Evaluation dataset directory (TFRecord); default from [eval] eval_tfrecord_dir.
tf.app.flags.DEFINE_string(
    'eval_data', utils.configUtil("global.conf", "eval", "eval_tfrecord_dir"),
    """验证数据集目录""")
# Directory of the saved model; default from [model] model_dir.
tf.app.flags.DEFINE_string(
    'checkpoint_dir', utils.configUtil("global.conf", "model", "model_dir"),
    """保存的模型.""")
# How often to run an evaluation, in seconds per the flag name; default 180 (3 min).
tf.app.flags.DEFINE_integer(
    'eval_interval_secs', 60 * 3,
    """设置每隔多长时间做一次评估""")
# Total number of samples in the evaluation set; default from [eval] eval_data_count.
tf.app.flags.DEFINE_integer(
    'num_examples',
    int(utils.configUtil("global.conf", "eval", "eval_data_count")),
    """验证数据集样本总数""")
# When True, evaluate only once.
tf.app.flags.DEFINE_boolean('run_once', False, """仅验证一次.""")
# 训练模块 from __future__ import absolute_import from __future__ import division from __future__ import print_function from datetime import datetime import os.path import time import numpy as np from six.moves import xrange import tensorflow as tf import network import tfrecord import utils FLAGS = tf.app.flags.FLAGS tf.app.flags.DEFINE_string('train_dir', utils.configUtil("global.conf","model","model_dir"), """模型保存目录""" """检查点存储目录.(tensorboard查看)""") tf.app.flags.DEFINE_integer('max_steps', int(utils.configUtil("global.conf","train","max_steps")), """最大训练/迭代次数.""") tf.app.flags.DEFINE_boolean('log_device_placement', False,"""""") tf.app.flags.DEFINE_string('train_data',utils.configUtil("global.conf","train","train_tfrecord_dir"), '训练集目录(tfrecord)') tf.app.flags.DEFINE_integer('train_num',int(utils.configUtil("global.conf","train","train_data_count")), '训练集样本总数') def train(): with tf.Graph().as_default():
# -*- coding: utf-8 -*- # @Time : 2018/4/15 2:14 # @Author : jiaopan # @Email : [email protected] from PIL import Image import tensorflow as tf import network import utils FLAGS = tf.app.flags.FLAGS tf.app.flags.DEFINE_string( 'checkpoint_dir', utils.configUtil("global.conf", "model", "model_dir"), """保存的模型.""") def inputs(input, count=1, batch_size=1): network.FLAGS.batch_size = batch_size img = Image.open(input) img = img.resize((32, 32)) img = img.tobytes() img = tf.decode_raw(img, tf.uint8) img = tf.reshape(img, [3, 32, 32]) img = tf.transpose(img, [1, 2, 0]) img = tf.cast(img, tf.float32) float_image = tf.image.per_image_standardization(img) capacity = int(count * 0.4 + 3 * batch_size) min_after_dequeue = int(batch_size * 0.4) images, label_batch = tf.train.shuffle_batch( [float_image, '?'], batch_size=batch_size, capacity=capacity, min_after_dequeue=min_after_dequeue,
#coding:utf-8 # -*- coding: utf-8 -*- # @Time : 2018/4/11 10:08 # @Author : jiaopan # @Email : [email protected] import tensorflow as tf import os,sys,time from PIL import Image import utils IMAGE_SIZE = int(utils.configUtil("global.conf","dataset","resize_image_size")) # 裁剪大小 IMAGE_MAT_SIZE = int(utils.configUtil("global.conf","dataset","image_mat_size")) # reshape参数 CHANNELS = int(utils.configUtil("global.conf","dataset","chnnels")) # 通道 TRAIN_DATASET = utils.configUtil("global.conf","dataset","train_data_dir") # 训练原始数据 EVAL_DATASET = utils.configUtil("global.conf","dataset","eval_data_dir") # 验证集原始数据 BATCH_SIZE = int(utils.configUtil("global.conf","dataset","batch_size")) def create(dataset_dir,tfrecord_path,tfrecord_name="train_tfrecord",width=IMAGE_SIZE,height=IMAGE_SIZE): """ #构建图片TFrecord文件 param dataset_dir:原始图片的根目录,目录下包含多个子目录,每个子目录下为同一类别的图片 param tfrecord_name:存储的TFreord文件名 param tfrecord_path:存储的TFreord文件的目录路径 param width:图片裁剪宽度 param height:图片裁剪高度 param channels:通道 """ if not os.path.exists(dataset_dir): print('创建TFRECORD文件时出错,文件目录或文件不存在,请检查路径名..\n') exit()