Example #1
0
def train_model(batchsize, floder_log, floder_model):
    """Construct the CIFAR-10 training and test input pipelines.

    Calls ``readcifar10.read`` twice: once for the ``"train"`` split with
    ``aug_data=True`` and once for the ``"test"`` split with
    ``aug_data=False``. Each result is unpacked into an (inputs, labels)
    pair.

    NOTE(review): ``floder_log`` and ``floder_model`` are accepted but never
    used, the four tensors are discarded, and the function returns None.
    This looks like a truncated excerpt of a longer training routine;
    confirm against the full source.
    """
    # Training split, with data augmentation enabled.
    train_pair = readcifar10.read(batchsize=batchsize, type="train", aug_data=True)
    X_train, y_train = train_pair

    # Test split, without augmentation.
    test_pair = readcifar10.read(batchsize=batchsize, type="test", aug_data=False)
    X_test, y_test = test_pair
Example #2
0
def train():
    """Set up output directories, input pipelines, the ResNet graph and its summaries.

    Creates ``logdirs-resnet`` and ``model-resnet`` if they are missing.
    Builds the train and test input tensors via ``readcifar10.read``, the
    input placeholders, the ResNet logits, the loss, and a batch-accuracy
    op. It then registers scalar summaries for these in separate train and
    test summary sets.

    NOTE(review): this version stops after registering the training-accuracy
    summary. It creates no session, trains nothing, and returns None. It
    looks like a truncated copy of a longer ``train()``; confirm.
    """
    # Batch size for both input pipelines.
    batchsize = 64
    # Directory for TensorBoard event files.
    floder_log = 'logdirs-resnet'
    # Directory for model checkpoints.
    floder_model = 'model-resnet'

    if not os.path.exists(floder_log):
        os.mkdir(floder_log)

    if not os.path.exists(floder_model):
        os.mkdir(floder_model)

    # Summaries emitted on training batches.
    tr_summary = set()
    # Summaries emitted on test batches.
    te_summary = set()

    ##data
    # NOTE(review): the positional flags presumably select the split
    # (0 = train, 1 = test) and toggle augmentation; confirm against
    # readcifar10.read.
    tr_im, tr_label = readcifar10.read(batchsize, 0, 1)
    te_im, te_label = readcifar10.read(batchsize, 1, 0)

    ##net
    # Input placeholders: NHWC 32x32 RGB images, int64 class ids, the dropout
    # keep probability, and the train/eval mode switch.
    input_data = tf.placeholder(tf.float32, shape=[None, 32, 32, 3],
                                name='input_data')
    input_label = tf.placeholder(tf.int64, shape=[None],
                                 name='input_label')
    keep_prob = tf.placeholder(tf.float32, shape=None,
                               name='keep_prob')
    is_training = tf.placeholder(tf.bool, shape=None,
                                 name='is_training')
    logits = resnet.model_resnet(input_data, keep_prob=keep_prob,
                                 is_training=is_training)

    ##loss
    total_loss, l2_loss = loss(logits, input_label)

    # Loss summaries. Each set uses its own "train"/"test" prefix so the
    # curves are not mislabelled in TensorBoard.
    tr_summary.add(tf.summary.scalar('train total loss', total_loss))
    tr_summary.add(tf.summary.scalar('train l2_loss', l2_loss))

    te_summary.add(tf.summary.scalar('test total loss', total_loss))
    te_summary.add(tf.summary.scalar('test l2_loss', l2_loss))

    ##accuracy
    # Index of the highest-scoring class for each example.
    pred_max = tf.argmax(logits, 1)
    # True where the predicted class matches the label.
    correct = tf.equal(pred_max, input_label)
    # Fraction of correct predictions in the batch.
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
    # Training-accuracy summary.
    tr_summary.add(tf.summary.scalar('train accurancy', accuracy))
Example #3
0
def train():
    """Train the ResNet CIFAR-10 model with periodic evaluation and checkpoints.

    Creates ``logdirs-resnet`` and ``model-resnet`` if they are missing.
    Builds the input pipelines, network, loss, accuracy and optimizer ops,
    and restores the latest checkpoint in ``model-resnet`` when one exists.
    It then runs ``epoch_val`` epochs of training, where one epoch is
    ``50000 // batchsize`` steps.

    During the run:
      * train summaries are written every step;
      * progress is printed every 100 steps;
      * at the start of every epoch the test set is evaluated and the mean
        loss and accuracy are printed;
      * a checkpoint is saved every 1000 steps.

    The queue-runner threads and the summary writer are shut down before
    the session closes, even if training raises.
    """
    batchsize = 64
    floder_log = 'logdirs-resnet'
    floder_model = 'model-resnet'
    # CIFAR-10 split sizes; they determine how many batches make one pass.
    num_train_examples = 50000
    num_test_examples = 10000

    if not os.path.exists(floder_log):
        os.mkdir(floder_log)

    if not os.path.exists(floder_model):
        os.mkdir(floder_model)

    tr_summary = set()
    te_summary = set()

    ##data
    # NOTE(review): the positional flags presumably select the split
    # (0 = train, 1 = test) and toggle augmentation; confirm against
    # readcifar10.read.
    tr_im, tr_label = readcifar10.read(batchsize, 0, 1)
    te_im, te_label = readcifar10.read(batchsize, 1, 0)

    ##net
    input_data = tf.placeholder(tf.float32,
                                shape=[None, 32, 32, 3],
                                name='input_data')
    input_label = tf.placeholder(tf.int64, shape=[None], name='input_label')
    keep_prob = tf.placeholder(tf.float32, shape=None, name='keep_prob')
    is_training = tf.placeholder(tf.bool, shape=None, name='is_training')
    logits = resnet.model_resnet(input_data,
                                 keep_prob=keep_prob,
                                 is_training=is_training)

    ##loss
    total_loss, l2_loss = loss(logits, input_label)

    # Each summary set uses its own "train"/"test" prefix. The original
    # logged 'test l2_loss' in the train set and 'train total loss' in the
    # test set.
    tr_summary.add(tf.summary.scalar('train total loss', total_loss))
    tr_summary.add(tf.summary.scalar('train l2_loss', l2_loss))

    te_summary.add(tf.summary.scalar('test total loss', total_loss))
    te_summary.add(tf.summary.scalar('test l2_loss', l2_loss))

    ##accuracy
    pred_max = tf.argmax(logits, 1)
    correct = tf.equal(pred_max, input_label)
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
    tr_summary.add(tf.summary.scalar('train accurancy', accuracy))
    te_summary.add(tf.summary.scalar('test accurancy', accuracy))

    ##op
    global_step, op, lr = func_optimal(batchsize, total_loss)
    tr_summary.add(tf.summary.scalar('train lr', lr))
    te_summary.add(tf.summary.scalar('test lr', lr))

    # NOTE(review): `* 128 + 128` presumably undoes the input normalisation
    # done by readcifar10 so the images display sensibly; confirm.
    tr_summary.add(tf.summary.image('train image', input_data * 128 + 128))
    te_summary.add(tf.summary.image('test image', input_data * 128 + 128))

    with tf.Session() as sess:
        sess.run(
            tf.group(tf.global_variables_initializer(),
                     tf.local_variables_initializer()))

        # Keep the coordinator and thread handles so the input threads can be
        # stopped and joined before the session is closed.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)

        ckpt = tf.train.latest_checkpoint(floder_model)
        if ckpt:
            saver.restore(sess, ckpt)

        epoch_val = 100
        # One epoch is one pass over the training set in `batchsize` steps.
        steps_per_epoch = num_train_examples // batchsize
        num_test_batches = num_test_examples // batchsize

        tr_summary_op = tf.summary.merge(list(tr_summary))
        te_summary_op = tf.summary.merge(list(te_summary))

        summary_writer = tf.summary.FileWriter(floder_log, sess.graph)

        try:
            # The original looped over range(50000 * epoch_val), which is
            # `batchsize` times more steps than epoch_val epochs.
            for i in range(steps_per_epoch * epoch_val):
                train_im_batch, train_label_batch = \
                    sess.run([tr_im, tr_label])
                feed_dict = {
                    input_data: train_im_batch,
                    input_label: train_label_batch,
                    keep_prob: 0.8,
                    is_training: True,
                }

                (_, global_step_val, lr_val, total_loss_val,
                 accuracy_val, tr_summary_str) = sess.run(
                     [op, global_step, lr, total_loss, accuracy,
                      tr_summary_op],
                     feed_dict=feed_dict)

                summary_writer.add_summary(tr_summary_str, global_step_val)

                if i % 100 == 0:
                    print("{},{},{},{}".format(global_step_val, lr_val,
                                               total_loss_val, accuracy_val))

                # Evaluate on the test set once per training epoch.
                if i % steps_per_epoch == 0:
                    test_loss = 0
                    test_acc = 0
                    for _ in range(num_test_batches):
                        test_im_batch, test_label_batch = \
                            sess.run([te_im, te_label])
                        feed_dict = {
                            input_data: test_im_batch,
                            input_label: test_label_batch,
                            keep_prob: 1.0,
                            is_training: False,
                        }

                        # Separate names so the training loss/accuracy values
                        # are not overwritten by test values.
                        test_loss_val, global_step_val, \
                            test_acc_val, te_summary_str = sess.run(
                                [total_loss, global_step, accuracy,
                                 te_summary_op],
                                feed_dict=feed_dict)

                        summary_writer.add_summary(te_summary_str,
                                                   global_step_val)

                        test_loss += test_loss_val
                        test_acc += test_acc_val

                    # Mean over the batches actually evaluated. The original
                    # scaled by batchsize / 10000, which only approximates it.
                    print('test:', test_loss / num_test_batches,
                          test_acc / num_test_batches)

                if i % 1000 == 0:
                    saver.save(
                        sess, "{}/model.ckpt{}".format(floder_model,
                                                       str(global_step_val)))
        finally:
            # Flush pending events, then stop and join the queue-runner
            # threads while the session is still open.
            summary_writer.close()
            coord.request_stop()
            coord.join(threads)
    return