Code Example #1
File: train.py  Project: gbyy422990/unet_tensorflow
def main(flags):
    current_time = time.strftime("%m/%d/%H/%M/%S")
    train_logdir = os.path.join(flags.logdir, "image", current_time)
    test_logdir = os.path.join(flags.logdir, "test", current_time)

    train = pd.read_csv(flags.data_dir)


    num_train = train.shape[0]

    test = pd.read_csv(flags.test_dir)
    num_test = test.shape[0]

    tf.reset_default_graph()
    X = tf.placeholder(tf.float32, shape=[None, h, w, c_image], name='X')
    y = tf.placeholder(tf.float32, shape=[None, h, w, c_label], name='y')
    mode = tf.placeholder(tf.bool, name='mode')

    pred = model.unet(X,mode)


    if flags.is_cross_entropy:
        CE_op = loss_CE(pred, y)
        loss = CE_op
        tf.summary.scalar("CE", CE_op)
    else:
        IOU_op = loss_IOU(pred, y)
        loss = -IOU_op
        tf.summary.scalar("IOU", IOU_op)

    global_step = tf.Variable(0, dtype=tf.int64, trainable=False, name='global_step')

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    learning_rate = tf.train.exponential_decay(flags.learning_rate, global_step,
                                               tf.cast(num_train / flags.batch_size * flags.decay_step, tf.int32),
                                               flags.decay_rate, staircase=True)

    with tf.control_dependencies(update_ops):
        training_op = train_op(loss,learning_rate)


    train_csv = tf.train.string_input_producer(['data_image.csv'])
    test_csv = tf.train.string_input_producer(['data_test.csv'])

    train_image, train_label = read_csv(train_csv,augmentation=True)
    test_image, test_label = read_csv(test_csv,augmentation=False)

    # batch_size is the number of samples in each returned batch; capacity is the size of the underlying queue
    X_train_batch_op, y_train_batch_op = tf.train.shuffle_batch(
        [train_image, train_label],
        batch_size=flags.batch_size,
        capacity=flags.batch_size * 5,
        min_after_dequeue=flags.batch_size * 2,
        allow_smaller_final_batch=True)

    X_test_batch_op, y_test_batch_op = tf.train.batch(
        [test_image, test_label],
        batch_size=flags.batch_size,
        capacity=flags.batch_size * 2,
        allow_smaller_final_batch=True)



    print('Shuffle batch done')
    #tf.summary.scalar('loss/Cross_entropy', CE_op)

    tf.add_to_collection('inputs', X)
    tf.add_to_collection('inputs', mode)
    tf.add_to_collection('pred', pred)

    tf.summary.image('Input Image:', X)
    tf.summary.image('Label:', y)
    tf.summary.image('Predicted Image:', pred)

    tf.summary.scalar("learning_rate", learning_rate)

    # Summarize a Tensor of arbitrary shape by recording the distribution of its values
    tf.summary.histogram('Predicted Image:', pred)


    # Add a single op that runs every summary op, so each one does not have to be run by hand
    summary_op = tf.summary.merge_all()

    with tf.Session() as sess:
        train_writer = tf.summary.FileWriter(train_logdir, sess.graph)
        test_writer = tf.summary.FileWriter(test_logdir)

        init = tf.global_variables_initializer()
        sess.run(init)


        saver = tf.train.Saver()
        if os.path.exists(flags.model_dir) and tf.train.checkpoint_exists(flags.model_dir):
            latest_check_point = tf.train.latest_checkpoint(flags.model_dir)
            saver.restore(sess, latest_check_point)

        else:
            print('No model')
            try:
                os.rmdir(flags.model_dir)
            except Exception as e:
                print(e)
            os.mkdir(flags.model_dir)

        try:
            #global_step = tf.train.get_global_step(sess.graph)

            # tf.train.string_input_producer(..., shuffle=False) adds a QueueRunner to the default graph,
            # so its threads must be launched with tf.train.start_queue_runners(sess=sess); otherwise the
            # first sess.run on the batch ops blocks forever. tf.train.Coordinator() is used to synchronize
            # and cleanly stop those threads. (A standalone sketch of this pattern follows this example.)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)

            for epoch in range(flags.epochs):
                for step in range(0,num_train,flags.batch_size):
                    X_train, y_train = sess.run([X_train_batch_op,y_train_batch_op])

                    if flags.is_cross_entropy:

                        _, step_ce, step_summary, global_step_value = sess.run(
                            [training_op, CE_op, summary_op, global_step],
                            feed_dict={X: X_train, y: y_train, mode: True})

                        train_writer.add_summary(step_summary,global_step_value)
                        print('epoch:{} step:{} loss_CE:{}'.format(epoch+1, global_step_value, step_ce))

                    else:
                        _, step_iou, step_summary, global_step_value = sess.run(
                            [training_op, IOU_op, summary_op, global_step],
                            feed_dict={X: X_train, y: y_train, mode: True})

                        train_writer.add_summary(step_summary, global_step_value)
                        print('epoch:{} step:{} loss_IOU:{}'.format(epoch + 1, global_step_value, step_iou))

                for step in range(0,num_test,flags.batch_size):
                    if flags.is_cross_entropy:
                        X_test, y_test = sess.run([X_test_batch_op,y_test_batch_op])
                        step_ce, step_summary = sess.run(
                            [CE_op, summary_op],
                            feed_dict={X: X_test, y: y_test, mode: False})

                        test_writer.add_summary(step_summary,epoch * (num_train // flags.batch_size) + step // flags.batch_size * num_train // num_test)
                        print('Test loss_CE:{}'.format(step_ce))

                    else:
                        X_test, y_test = sess.run([X_test_batch_op, y_test_batch_op])
                        step_iou, step_summary = sess.run(
                            [IOU_op, summary_op],
                            feed_dict={X: X_test, y: y_test, mode: False})

                        test_writer.add_summary(step_summary,epoch * (num_train // flags.batch_size) + step // flags.batch_size * num_train // num_test)
                        print('Test loss_IOU:{}'.format(step_iou))

            saver.save(sess, '{}/model.ckpt'.format(flags.model_dir))

        finally:
            coord.request_stop()
            coord.join(threads)
            saver.save(sess, "{}/model.ckpt".format(flags.model_dir))
Code Example #2
def train():
    if not os.path.isfile(train_data_pickle):
        # training data
        train_features, train_labels = features(['fold0', 'fold1', 'fold2'])
        traindata = TrainData(train_features, train_labels)
        with open(train_data_pickle, mode='wb') as f:
            pickle.dump(traindata, f)
    else:
        print("loading: %s" % (train_data_pickle))
        with open(train_data_pickle, mode='rb') as f:
            traindata = pickle.load(f)
            train_features = traindata.train_inputs
            train_labels = traindata.train_targets

    if not os.path.isfile(test_data_pickle):
        test_features, test_labels = features(['fold3'])
        testdata = TestData(test_features, test_labels)
        with open(test_data_pickle, mode='wb') as f:
            pickle.dump(testdata, f)
    else:
        print("loading: %s" % (test_data_pickle))
        with open(test_data_pickle, mode='rb') as f:
            testdata = pickle.load(f)
            test_features = testdata.test_inputs
            test_labels = testdata.test_targets

    # TODO change to use train and test
    train_labels = one_hot_encode(train_labels)
    test_labels = one_hot_encode(test_labels)

    # random train and test sets.
    train_test_split = np.random.rand(len(train_features)) < 0.70
    train_x = train_features[train_test_split]
    train_y = train_labels[train_test_split]
    test_x = train_features[~train_test_split]
    test_y = train_labels[~train_test_split]

    n_dim = train_features.shape[1]
    print("input dim: %s" % (n_dim))

    # create placeholder
    X = tf.placeholder(tf.float32, [None, n_dim])
    Y = tf.placeholder(tf.float32, [None, FLAGS.num_classes])
    # build graph
    logits = model.inference(X, n_dim)

    weights = tf.all_variables()
    saver = tf.train.Saver(weights)

    # create loss
    loss = model.loss(logits, Y)
    tf.scalar_summary('loss', loss)

    accracy = model.accuracy(logits, Y)
    tf.scalar_summary('test accuracy', accracy)

    # train operation
    train_op = model.train_op(loss)

    # variable initializer
    init = tf.initialize_all_variables()

    # get Session
    sess = tf.Session()

    # summary merge and writer
    merged = tf.merge_all_summaries()
    train_writer = tf.train.SummaryWriter(FLAGS.summaries_dir)

    # initialize
    sess.run(init)

    for step in xrange(MAX_STEPS):

        t_pred = sess.run(tf.argmax(logits, 1), feed_dict={X: train_features})
        t_true = sess.run(tf.argmax(train_labels, 1))
        print("train samples pred: %s" % t_pred[:30])
        print("train samples target: %s" % t_true[:30])
        print('Train accuracy: ',
              sess.run(accracy, feed_dict={
                  X: train_x,
                  Y: train_y
              }))
        for epoch in xrange(training_epochs):
            summary, logits_val, _, loss_val = sess.run(
                [merged, logits, train_op, loss],
                feed_dict={
                    X: train_x,
                    Y: train_y
                })
        train_writer.add_summary(summary, step)

        print("step:%d, loss: %s" % (step, loss_val))
        y_pred = sess.run(tf.argmax(logits, 1), feed_dict={X: test_x})
        y_true = sess.run(tf.argmax(test_y, 1))
        print("test samples pred: %s" % y_pred[:10])
        print("test samples target: %s" % y_true[:10])
        accracy_val = sess.run([accracy], feed_dict={X: test_x, Y: test_y})
        # print('Test accuracy: ', accracy_val)
        # train_writer.add_summary(accracy_val, step)
        p, r, f, s = precision_recall_fscore_support(y_true,
                                                     y_pred,
                                                     average='micro')
        print("F-score: %s" % f)

        if step % 1000 == 0:
            saver.save(sess, FLAGS.ckpt_dir, global_step=step)
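
Example #2 above and examples #3, #5, and #7 below were written against the pre-1.0 TensorFlow API, which is why their summary and initialization calls look different from examples #1, #4, and #6. The old names map one-to-one onto the TF 1.x names (standard renames, listed here for reference):

# Pre-1.0 name                       TF 1.x equivalent
# tf.scalar_summary(tag, x)       -> tf.summary.scalar(tag, x)
# tf.histogram_summary(tag, x)    -> tf.summary.histogram(tag, x)
# tf.image_summary(tag, x)        -> tf.summary.image(tag, x)
# tf.merge_all_summaries()        -> tf.summary.merge_all()
# tf.train.SummaryWriter(logdir)  -> tf.summary.FileWriter(logdir)
# tf.initialize_all_variables()   -> tf.global_variables_initializer()
# tf.all_variables()              -> tf.global_variables()
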
Code Example #3
 FLAGS.batch_size = 8
 FLAGS.is_training = True
 FLAGS.minimal_summaries = True
 FLAGS.initial_learning_rate = 1e-4
 FLAGS.stddev = 5e-2
 FLAGS.weight_decay = 5e-5
 # global step
 global_step = tf.Variable(0, trainable=False)
 # get training batch
 images, labels = model.get_train_input()
 # inference
 outputs = model.inference_resnet(images)
 # calculate total loss
 loss = model.loss(outputs, labels)
 # train operation
 train_op = model.train_op(loss, global_step)
 # initialize
 init = tf.initialize_all_variables()
 # Start running operations on the Graph.
 config = tf.ConfigProto(log_device_placement=False)
 config.gpu_options.allow_growth = True
 sess = tf.InteractiveSession(config=config)
 sess.run(init)
 # resnet saver
 # saver_resnet = tf.train.Saver(tf.trainable_variables())
 saver_resnet = tf.train.Saver(
     [v for v in tf.trainable_variables() if not "fc" in v.name])
 saver_resnet.restore(sess, FLAGS.resnet_param)
 # start queue runner
 tf.train.start_queue_runners(sess=sess)
 # pass once
Code Example #4
def main(flags):
    current_time = time.strftime("%m/%d/%H/%M/%S")
    train_logdir = os.path.join(flags.logdir, "pig", current_time)
    test_logdir = os.path.join(flags.logdir, "test", current_time)

    train = pd.read_csv(flags.data_dir)

    num_train = train.shape[0]

    test = pd.read_csv(flags.test_dir)
    num_test = test.shape[0]

    tf.reset_default_graph()
    X = tf.placeholder(tf.float32,
                       shape=[flags.batch_size, h, w, c_image],
                       name='X')
    y = tf.placeholder(tf.float32,
                       shape=[flags.batch_size, h, w, c_label],
                       name='y')
    mode = tf.placeholder(tf.bool, name='mode')

    score_dsn6_up, score_dsn5_up, score_dsn4_up, score_dsn3_up, score_dsn2_up, score_dsn1_up, upscore_fuse = model.unet(
        X, mode)

    #print(score_dsn6_up.get_shape().as_list())

    loss6 = loss_CE(score_dsn6_up, y)
    loss5 = loss_CE(score_dsn5_up, y)
    loss4 = loss_CE(score_dsn4_up, y)
    loss3 = loss_CE(score_dsn3_up, y)
    loss2 = loss_CE(score_dsn2_up, y)
    loss1 = loss_CE(score_dsn1_up, y)
    loss_fuse = loss_CE(upscore_fuse, y)
    tf.summary.scalar("CE6", loss6)
    tf.summary.scalar("CE5", loss5)
    tf.summary.scalar("CE4", loss4)
    tf.summary.scalar("CE3", loss3)
    tf.summary.scalar("CE2", loss2)
    tf.summary.scalar("CE1", loss1)
    tf.summary.scalar("CE_fuse", loss_fuse)

    Loss = loss6 + loss5 + loss4 + loss3 + loss2 + 2 * loss1 + loss_fuse
    tf.summary.scalar("CE_total", Loss)

    global_step = tf.Variable(0,
                              dtype=tf.int64,
                              trainable=False,
                              name='global_step')

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    learning_rate = tf.train.exponential_decay(flags.learning_rate,
                                               global_step,
                                               decay_steps=flags.decay_step,
                                               decay_rate=flags.decay_rate,
                                               staircase=True)

    with tf.control_dependencies(update_ops):
        training_op = train_op(Loss, learning_rate)

    train_csv = tf.train.string_input_producer(['pig1.csv'])
    test_csv = tf.train.string_input_producer(['pigtest1.csv'])

    train_image, train_label = read_csv(train_csv, augmentation=True)
    test_image, test_label = read_csv(test_csv, augmentation=False)

    # batch_size is the number of samples in each returned batch; capacity is the size of the underlying queue
    X_train_batch_op, y_train_batch_op = tf.train.shuffle_batch(
        [train_image, train_label],
        batch_size=flags.batch_size,
        capacity=flags.batch_size * 5,
        min_after_dequeue=flags.batch_size * 2,
        allow_smaller_final_batch=True)

    X_test_batch_op, y_test_batch_op = tf.train.batch(
        [test_image, test_label],
        batch_size=flags.batch_size,
        capacity=flags.batch_size * 2,
        allow_smaller_final_batch=True)

    print('Shuffle batch done')
    #tf.summary.scalar('loss/Cross_entropy', CE_op)
    score_dsn6_up = tf.nn.sigmoid(score_dsn6_up)
    score_dsn5_up = tf.nn.sigmoid(score_dsn5_up)
    score_dsn4_up = tf.nn.sigmoid(score_dsn4_up)
    score_dsn3_up = tf.nn.sigmoid(score_dsn3_up)
    score_dsn2_up = tf.nn.sigmoid(score_dsn2_up)
    score_dsn1_up = tf.nn.sigmoid(score_dsn1_up)
    upscore_fuse = tf.nn.sigmoid(upscore_fuse)
    print(upscore_fuse.get_shape().as_list())

    tf.add_to_collection('inputs', X)
    tf.add_to_collection('inputs', mode)
    tf.add_to_collection('score_dsn6_up', score_dsn6_up)
    tf.add_to_collection('score_dsn5_up', score_dsn5_up)
    tf.add_to_collection('score_dsn4_up', score_dsn4_up)
    tf.add_to_collection('score_dsn3_up', score_dsn3_up)
    tf.add_to_collection('score_dsn2_up', score_dsn2_up)
    tf.add_to_collection('score_dsn1_up', score_dsn1_up)
    tf.add_to_collection('upscore_fuse', upscore_fuse)

    tf.summary.image('Input Image:', X)
    tf.summary.image('Label:', y)
    tf.summary.image('score_dsn6_up:', score_dsn6_up)
    tf.summary.image('score_dsn5_up:', score_dsn5_up)
    tf.summary.image('score_dsn4_up:', score_dsn4_up)
    tf.summary.image('score_dsn3_up:', score_dsn3_up)
    tf.summary.image('score_dsn2_up:', score_dsn2_up)
    tf.summary.image('score_dsn1_up:', score_dsn1_up)
    tf.summary.image('upscore_fuse:', upscore_fuse)

    tf.summary.scalar("learning_rate", learning_rate)

    # Summarize a Tensor of arbitrary shape by recording the distribution of its values
    tf.summary.histogram('score_dsn1_up:', score_dsn1_up)
    tf.summary.histogram('score_dsn2_up:', score_dsn2_up)
    tf.summary.histogram('score_dsn3_up:', score_dsn3_up)
    tf.summary.histogram('score_dsn4_up:', score_dsn4_up)
    tf.summary.histogram('score_dsn5_up:', score_dsn5_up)
    tf.summary.histogram('score_dsn6_up:', score_dsn6_up)
    tf.summary.histogram('upscore_fuse:', upscore_fuse)

    # Add a single op that runs every summary op, so each one does not have to be run by hand
    summary_op = tf.summary.merge_all()

    with tf.Session() as sess:
        train_writer = tf.summary.FileWriter(train_logdir, sess.graph)
        test_writer = tf.summary.FileWriter(test_logdir)

        init = tf.global_variables_initializer()
        sess.run(init)

        saver = tf.train.Saver()
        # if not os.listdir(flags.model_dir):
        #     print('No model')
        #     try:
        #         os.rmdir(flags.model_dir)
        #     except Exception as e:
        #         print(e)
        #     os.mkdir(flags.model_dir)
        # else:
        #     latest_check_point = tf.train.latest_checkpoint(flags.model_dir)
        #     saver.restore(sess, latest_check_point)
        if os.path.exists(flags.model_dir) and tf.train.checkpoint_exists(
                flags.model_dir):
            latest_check_point = tf.train.latest_checkpoint(flags.model_dir)
            saver.restore(sess, latest_check_point)

        else:
            print('No model')
            try:
                os.rmdir(flags.model_dir)
            except Exception as e:
                print(e)
            os.mkdir(flags.model_dir)

        try:
            #global_step = tf.train.get_global_step(sess.graph)

            # tf.train.string_input_producer(..., shuffle=False) adds a QueueRunner to the default graph,
            # so its threads must be launched with tf.train.start_queue_runners(sess=sess); otherwise the
            # first sess.run on the batch ops blocks forever. tf.train.Coordinator() is used to synchronize
            # and cleanly stop those threads.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)

            for epoch in range(flags.epochs):
                for step in range(0, num_train, flags.batch_size):
                    X_train, y_train = sess.run(
                        [X_train_batch_op, y_train_batch_op])
                    _, step_ce, step_summary, global_step_value = sess.run(
                        [training_op, Loss, summary_op, global_step],
                        feed_dict={
                            X: X_train,
                            y: y_train,
                            mode: True
                        })

                    train_writer.add_summary(step_summary, global_step_value)
                    print('epoch:{} step:{} loss_CE:{}'.format(
                        epoch + 1, global_step_value, step_ce))
                for step in range(0, num_test, flags.batch_size):
                    X_test, y_test = sess.run(
                        [X_test_batch_op, y_test_batch_op])
                    step_ce, step_summary = sess.run([Loss, summary_op],
                                                     feed_dict={
                                                         X: X_test,
                                                         y: y_test,
                                                         mode: False
                                                     })

                    test_writer.add_summary(
                        step_summary,
                        epoch * (num_train // flags.batch_size) +
                        step // flags.batch_size * num_train // num_test)
                    print('Test loss_CE:{}'.format(step_ce))
                saver.save(sess, '{}/model.ckpt'.format(flags.model_dir))

        finally:
            coord.request_stop()
            coord.join(threads)
            saver.save(sess, "{}/model.ckpt".format(flags.model_dir))
Code Example #5
def train_nvidia():
    with tf.Graph().as_default():
        FLAGS.batch_size = 512
        FLAGS.minimal_summaries = False
        FLAGS.initial_learning_rate = 1e-3
        FLAGS.stddev = 0.1
        FLAGS.weight_decay = 1e-5
        # global step
        global_step = tf.Variable(0, trainable=False)
        with tf.device("/gpu:" + FLAGS.gpu_id):
            # train net
            train_images, train_labels = model.get_train_input()
            outputs_train = model.inference_nvidianet2(train_images)
            loss_train = model.loss(outputs_train, train_labels)
            # validation net
            val_images, val_labels = model.get_val_input()
            tf.get_variable_scope().reuse_variables()
            outputs_val = model.inference_nvidianet(val_images)
            loss_val = model.loss(outputs_val, val_labels)
            # train operation
            train_op = model.train_op(loss_train, global_step)
        # saver
        saver = tf.train.Saver(tf.all_variables())
        # summarize
        if not FLAGS.minimal_summaries:
            tf.image_summary('images', train_images)
            for var in tf.trainable_variables():
                tf.histogram_summary(var.op.name, var)
        summary_op = tf.merge_all_summaries()
        # initialize
        init = tf.initialize_all_variables()
        # Start running operations on the Graph.
        config = tf.ConfigProto(log_device_placement=False,
                                allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)
        sess.run(init)
        print('network initialized')
        # start queue runner
        tf.train.start_queue_runners(sess=sess)
        # write summary
        summary_writer = tf.train.SummaryWriter(FLAGS.logdir, sess.graph)
        max_iter = int(FLAGS.max_epoch * FLAGS.num_examples_train /
                       FLAGS.batch_size)
        print('total iteration:', str(max_iter))
        for step in xrange(max_iter):
            start_time = time.time()
            _, loss_value = sess.run([train_op, loss_train])
            # loss_value = sess.run(loss) # test inference time only
            duration = time.time() - start_time
            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            if step % 200 == 0:
                val_iter = 16
                val_losses = np.zeros((val_iter))
                for ival in range(val_iter):
                    val_loss_value = sess.run(loss_val)
                    val_losses[ival] = val_loss_value
                print("mean validation loss:", np.mean(val_losses))

            if step % 10 == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)
                format_str = ('%s: step %d, loss = %.2f'
                              ' (%.1f examples/sec; %.3f sec/batch)')
                print(format_str % (datetime.now(), step, loss_value,
                                    examples_per_sec, sec_per_batch))

            if step % 200 == 0:
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

            # Save the model checkpoint periodically.
            if step % 1000 == 0 or (step + 1) == max_iter:
                checkpoint_path = os.path.join(FLAGS.logdir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
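
Example #5 builds a validation tower that shares weights with the training tower by calling tf.get_variable_scope().reuse_variables() between the two inference calls. A minimal, self-contained sketch of that sharing mechanism (the tower function and shapes are illustrative, not from the project):

# TF 1.x weight sharing via reuse_variables(); illustrative sketch only.
import tensorflow as tf

def tower(x):
    w = tf.get_variable('w', shape=[4, 1], initializer=tf.zeros_initializer())
    return tf.matmul(x, w)

with tf.variable_scope('net'):
    train_out = tower(tf.ones([2, 4]))          # creates net/w
    tf.get_variable_scope().reuse_variables()   # later get_variable calls must reuse existing variables
    val_out = tower(tf.zeros([2, 4]))           # reuses net/w instead of creating a second weight

assert len(tf.global_variables()) == 1          # only one weight tensor exists in the graph
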
Code Example #6
File: run.py  Project: mochrielab/ecoli_segmentation
                        path+'data/dataset3_train.tfrecords',
                        # path+'data/dataset4_train.tfrecords',
                        ], num_epochs=100000000)
        val_filename_queue = tf.train.string_input_producer(\
                        [path+'data/dataset1_test.tfrecords',
                        # path+'data/dataset2_test.tfrecords',
                        path+'data/dataset3_test.tfrecords',
                        # path+'data/dataset4_test.tfrecords',
                        ], num_epochs=1000000000)
        train_data_sets = dl.dataloader(train_filename_queue, batchsize=32)
        # val_data_sets = dl.dataloader(val_filename_queue, batchsize=8)
        image, label = train_data_sets.get_batch()
    logits = unet.inference(image, True, num_classes=3)
    loss = unet.loss(logits, label)
    # iou = unet.iou(logits, label)
    train_op = unet.train_op(loss)
    loss_sumary = tf.summary.scalar('loss', loss)
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    saver = tf.train.Saver()
    #    print('trainable variables:---------------------------------------')
    #    for var in tf.trainable_variables():
    #        print(var.name, var.shape)

    with tf.Session() as sess:
        summary_writer = tf.summary.FileWriter(path + 'logs', sess.graph)
        sess.run(init_op)
        checkpoint_file = os.path.join(path, 'logs', 'model.ckpt')
        saver.save(sess, checkpoint_file, global_step=0)

        coord = tf.train.Coordinator()
Code Example #7
def train_resnet():
    with tf.Graph().as_default():
        # set flag to training
        FLAGS.batch_size = 8
        FLAGS.is_training = True
        FLAGS.minimal_summaries = False
        FLAGS.initial_learning_rate = 1e-3
        FLAGS.stddev = 5e-2
        FLAGS.weight_decay = 1e-6
        # global step
        global_step = tf.Variable(0, trainable=False)
        # get training batch
        images, labels = model.get_train_input()
        # inference
        outputs = model.inference_resnet(images)
        # calculate total loss
        loss = model.loss(outputs, labels)
        # train operation
        train_op = model.train_op(loss, global_step)
        # saver
        saver = tf.train.Saver(tf.all_variables())
        # summarize
        if not FLAGS.minimal_summaries:
            tf.image_summary('images', images)
            for var in tf.trainable_variables():
                tf.histogram_summary(var.op.name, var)
        summary_op = tf.merge_all_summaries()
        # initialize
        init = tf.initialize_all_variables()
        # Start running operations on the Graph.
        config = tf.ConfigProto(log_device_placement=False)
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)
        sess.run(init)
        print('network initialized')
        # saver_resnet = tf.train.Saver(tf.trainable_variables())
        saver_resnet = tf.train.Saver(
            [v for v in tf.trainable_variables() if not "fc" in v.name])
        saver_resnet.restore(sess, FLAGS.resnet_param)
        # start queue runner
        tf.train.start_queue_runners(sess=sess)
        # write summary
        summary_writer = tf.train.SummaryWriter(FLAGS.logdir, sess.graph)
        #
        max_iter = int(FLAGS.max_epoch * FLAGS.num_examples_train /
                       FLAGS.batch_size)
        print('total iteration:', str(max_iter))
        for step in xrange(max_iter):
            start_time = time.time()
            _, loss_value = sess.run([train_op, loss])
            # loss_value = sess.run(loss) # test inference time only
            duration = time.time() - start_time
            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'
            if step % 10 == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)
                format_str = ('%s: step %d, loss = %.2f'
                              ' (%.1f examples/sec; %.3f sec/batch)')
                print(format_str % (datetime.now(), step, loss_value,
                                    examples_per_sec, sec_per_batch))
            if step % 200 == 0:
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

            # Save the model checkpoint periodically.
            if step % 1000 == 0 or (step + 1) == max_iter:
                checkpoint_path = os.path.join(FLAGS.logdir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
Code Example #8
File: train.py  Project: anhtu95/prmu2017
def train():
    sess = tf.InteractiveSession()
    global_step = tf.contrib.framework.get_or_create_global_step()

    #Load data
    train_sets = PRMUDataSet("1_train_0.9")
    train_sets.load_data_target()
    n_train_samples = train_sets.get_n_types_target()
    print("n_train_samples = " + str(n_train_samples))

    if not os.path.isfile('save/current/model.ckpt.index'):
        print('Create new model')
        x, y_ = model.input()
        y_conv = model.inference(x)
        loss = model.loss(y_conv=y_conv, y_=y_)
        train_step = model.train_op(loss, global_step)
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
    else:
        print('Load exist model')
        saver = tf.train.import_meta_graph('save/current/model.ckpt.meta')
        saver.restore(sess, 'save/current/model.ckpt')

    learning_rate = tf.get_collection('learning_rate')[0]
    cross_entropy_loss = tf.get_collection('cross_entropy_loss')[0]

    train_step = tf.get_collection('train_step')[0]

    keep_prob_fc1 = tf.get_collection('keep_prob_fc1')[0]
    keep_prob_fc2 = tf.get_collection('keep_prob_fc2')[0]
    x = tf.get_collection('x')[0]
    y_ = tf.get_collection('y_')[0]
    y_conv = tf.get_collection('y_conv')[0]

    correct_prediction = tf.equal(tf.arg_max(y_conv, 1), tf.arg_max(y_, 1))
    true_pred = tf.reduce_sum(tf.cast(correct_prediction, dtype=tf.float32))

    for epoch in range(nb_epochs):
        print("Epoch: %d" % epoch)
        print("Learning rate: " + str(learning_rate.eval()))

        avg_ttl = []
        nb_true_pred = 0

        # shuffle data
        perm = np.random.permutation(n_train_samples)
        print('x_train = ' + str(perm))
        if epoch % 10 == 0:
            saver.save(sess, "save/current/model.ckpt")

        for i in range(0, n_train_samples, batch_size):

            x_batch = train_sets.data[perm[i:(i + batch_size)]]
            #print('x_batch['+str(i)+'] = '+str(perm[i:(i+batch_size)]))

            batch_target = np.asarray(train_sets.target[perm[i:i +
                                                             batch_size]])
            y_batch = np.zeros((len(x_batch), 46), dtype=np.float32)
            y_batch[np.arange(len(x_batch)), batch_target] = 1.0
            #print('batch_target = '+str(batch_target))

            ttl, _ = sess.run(
                [cross_entropy_loss, train_step],
                feed_dict={
                    x: x_batch,
                    y_: y_batch,
                    keep_prob_fc1: (1 - drop_out_prob),
                    keep_prob_fc2: (1 - drop_out_prob)
                })

            avg_ttl.append(ttl * len(x_batch))

            nb_true_pred += true_pred.eval(feed_dict={
                x: x_batch,
                y_: y_batch,
                keep_prob_fc1: 1,
                keep_prob_fc2: 1
            })
            print('Batch ' + str(i) + ' : Number of true prediction: ' +
                  str(nb_true_pred))
        print("Average total loss: " + str(np.sum(avg_ttl) / n_train_samples))
        print("Train accuracy: " + str(nb_true_pred * 1.0 / n_train_samples))