Esempio n. 1
0
def _load_image_batch(filenames, width, height):
    """Load one batch of images as a single float32 array.

    Each path is read with ``reader.get_image(path, width, height)`` and
    converted to ``float32``; the per-image arrays are stacked in input
    order. No value scaling is applied, so training and validation batches
    are preprocessed identically.
    """
    return np.array([
        np.array(reader.get_image(filename, width, height), np.float32)
        for filename in filenames
    ])


def main(FLAGS):
    """Build the detector graph and train it at the scale ``FLAGS.scale``.

    The batched file lists returned by ``reader.images`` / ``reader.labels``
    are split at the 90% mark into disjoint training and validation sets.
    Training resumes from the latest checkpoint under
    ``select_things.select_checkpoint(FLAGS.scale)`` when one exists;
    otherwise all variables are freshly initialised.

    Per epoch: every training batch is run through the Adam optimizer, the
    merged summary of the last batch is written to ``logs/``, a checkpoint
    ``scale<N>.ckpt-<step>`` is saved (except on the first epoch of this
    run) and, every 10 epochs, the summed validation loss is printed.

    Args:
        FLAGS: parsed flags; reads ``scale``, ``width``, ``height``,
            ``batch_size``, ``datas_path``, ``labels_path``,
            ``learning_rate`` and ``epoch``.

    Raises:
        ValueError: if ``FLAGS.scale`` is not 1, 2 or 3.
    """
    if FLAGS.scale not in (1, 2, 3):
        raise ValueError(
            'FLAGS.scale must be 1, 2 or 3, got %r' % (FLAGS.scale,))

    scale_width, scale_height = select_things.select_scale(
        FLAGS.scale, FLAGS.width, FLAGS.height)
    # --------Create placeholder--------
    datas, labels, train = net.create_placeholder(FLAGS.batch_size,
                                                  FLAGS.width, FLAGS.height,
                                                  scale_width, scale_height)
    # --------net--------
    pre_scale1, pre_scale2, pre_scale3 = net.feature_extractor(datas, train)
    scale1, scale2, scale3 = net.scales(pre_scale1, pre_scale2, pre_scale3,
                                        train)
    # --------get labels_filenames and datas_filenames--------
    datas_filenames = reader.images(FLAGS.batch_size, FLAGS.datas_path)
    labels_filenames = reader.labels(FLAGS.batch_size, FLAGS.labels_path)
    normalize_labels = extract_labels.labels_normalizer(
        labels_filenames, FLAGS.width, FLAGS.height, scale_width, scale_height)
    # ---------partition the train data and val data--------
    # Both lists are cut at the same 90% index so that the validation set
    # is exactly the batches not used for training (no overlap).
    data_split = int(len(datas_filenames) * 0.9)
    label_split = int(len(normalize_labels) * 0.9)
    train_filenames = datas_filenames[:data_split]
    train_labels = normalize_labels[:label_split]
    val_filenames = datas_filenames[data_split:]
    val_labels = normalize_labels[label_split:]
    # --------calculate loss--------
    heads = {1: scale1, 2: scale2, 3: scale3}
    loss = get_loss.calculate_loss(heads[FLAGS.scale], labels)
    # --------Optimizer--------
    # Make every optimizer step also run the ops registered in UPDATE_OPS.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.AdamOptimizer(
            learning_rate=FLAGS.learning_rate).minimize(loss)

    tf.summary.scalar('loss', loss)
    merged = tf.summary.merge_all()

    init = tf.global_variables_initializer()
    checkpoint_name = 'scale' + str(FLAGS.scale) + '.ckpt'

    with tf.Session() as sess:
        writer = tf.summary.FileWriter("logs/", sess.graph)
        number = 0

        saver = tf.train.Saver(max_to_keep=10)
        save_path = select_things.select_checkpoint(FLAGS.scale)
        last_checkpoint = tf.train.latest_checkpoint(save_path, 'checkpoint')
        if last_checkpoint:
            saver.restore(sess, last_checkpoint)
            # Checkpoints are saved with global_step, so their path ends in
            # "-<step>"; continue numbering right after the restored step.
            number = int(last_checkpoint.rsplit('-', 1)[-1]) + 1
            print('Reuse model form: ', format(last_checkpoint))
        else:
            sess.run(init)

        for epoch in range(FLAGS.epoch):
            step = epoch + number
            # Plain Python floats: accumulating into tf tensors would add new
            # graph ops on every epoch.
            epoch_loss = 0.0
            rs = None
            for i, batch_filenames in enumerate(train_filenames):
                normalize_datas = _load_image_batch(
                    batch_filenames, FLAGS.width, FLAGS.height)

                _, batch_loss, rs = sess.run([optimizer, loss, merged],
                                             feed_dict={
                                                 datas: normalize_datas,
                                                 labels: train_labels[i],
                                                 train: True
                                             })
                print('batch_loss after batch %i: %f' % (i, batch_loss))
                epoch_loss += batch_loss

            # Summary of the last training batch; skipped if there were none.
            if rs is not None:
                writer.add_summary(rs, step)

            # Save a checkpoint every epoch except the first one of this run.
            if epoch != 0:
                print('Cost after epoch %i: %f' % (step, epoch_loss))
                saver.save(sess,
                           os.path.join(save_path, checkpoint_name),
                           global_step=step)

            if epoch % 10 == 0 and epoch != 0:
                val_loss = 0.0
                for i, batch_filenames in enumerate(val_filenames):
                    # Same preprocessing as the training batches so the two
                    # losses are comparable.
                    normalize_datas = _load_image_batch(
                        batch_filenames, FLAGS.width, FLAGS.height)

                    val_loss += sess.run(loss,
                                         feed_dict={
                                             datas: normalize_datas,
                                             labels: val_labels[i],
                                             train: False
                                         })

                print('VAL_Cost after epoch %i: %f' %
                      (step, val_loss))

        # Flush any pending summary events to disk.
        writer.close()
Esempio n. 2
0
def main(FLAGS):
    """Load ``yolov3.weights`` into the graph, save a checkpoint, run one image.

    Steps:
    1. Read the image at ``FLAGS.image_dir`` (resized via ``reader.get_image``).
    2. Build the evaluation graph.
    3. Assign every global variable from ``'yolov3.weights'`` through
       ``load_weights``. This is presumably the Darknet YOLOv3 weight file;
       confirm against ``load_weights``.
    4. Save the result with ``tf.train.Saver`` to the checkpoint path for
       ``FLAGS.scale``.
    5. Run the network once.
    6. Draw the boxes decoded from the selected scale's output.
    7. Write ``<save_dir>/res.jpg`` and display the result with matplotlib.

    Args:
        FLAGS: parsed flags; reads ``save_dir``, ``image_dir``,
            ``image_width``, ``image_height`` and ``scale``.

    Raises:
        ValueError: if ``FLAGS.scale`` is not 1, 2 or 3.
    """
    if FLAGS.scale not in (1, 2, 3):
        raise ValueError(
            'FLAGS.scale must be 1, 2 or 3, got %r' % (FLAGS.scale,))

    if not os.path.exists(FLAGS.save_dir):
        os.makedirs(FLAGS.save_dir)

    input_image = reader.get_image(FLAGS.image_dir, FLAGS.image_width,
                                   FLAGS.image_height)
    output_image = np.copy(input_image)
    # --------Create placeholder--------
    image = net.create_eval_placeholder(FLAGS.image_width, FLAGS.image_height)
    # --------net--------
    pre_scale1, pre_scale2, pre_scale3 = net.feature_extractor(image, False)
    scale1, scale2, scale3 = net.scales(pre_scale1, pre_scale2, pre_scale3,
                                        False)

    with tf.Session() as sess:
        saver = tf.train.Saver()
        save_path = select_things.select_checkpoint(FLAGS.scale)
        # Variables are initialised from the weight file rather than restored
        # from a TF checkpoint; the result is then persisted as a checkpoint.
        print('load-weight-start')
        load_ops = load_weights(tf.global_variables(), 'yolov3.weights')
        sess.run(load_ops)
        print('laod-weights-done')
        saver.save(sess, save_path)
        start_time = time.time()
        scale_outputs = sess.run([scale1, scale2, scale3],
                                 feed_dict={image: [output_image]})

    scale = {1: scale_outputs[0],
             2: scale_outputs[1],
             3: scale_outputs[2]}[FLAGS.scale]

    # Only the first (and only) image of the batch is decoded.
    boxes_labels = eval_uitls.label_extractor(scale[0])

    bdboxes = eval_uitls.get_bdboxes(boxes_labels)

    for bdbox in bdboxes:
        # bdbox[0:2] is the box centre and bdbox[2:4] its width/height, as
        # implied by the corner arithmetic below; bdbox[4] is the label text.
        top_left = (int(bdbox[0] - bdbox[2] / 2), int(bdbox[1] - bdbox[3] / 2))
        bottom_right = (int(bdbox[0] + bdbox[2] / 2),
                        int(bdbox[1] + bdbox[3] / 2))
        output_image = cv2.rectangle(output_image, top_left, bottom_right,
                                     (200, 0, 0), 1)
        output_image = cv2.putText(output_image, bdbox[4], top_left,
                                   cv2.FONT_HERSHEY_SIMPLEX, 0.3,
                                   (0, 255, 0), 1)

    generate_image = os.path.join(FLAGS.save_dir, 'res.jpg')
    # The image is treated as RGB and converted to BGR, the channel order
    # cv2.imwrite expects.
    cv2.imwrite(generate_image, cv2.cvtColor(output_image, cv2.COLOR_RGB2BGR))
    end_time = time.time()

    print('Use time: ', end_time - start_time)

    plt.imshow(output_image)
    plt.show()
Esempio n. 3
0
def main(FLAGS):
    """Restore the latest checkpoint for ``FLAGS.scale`` and run detection.

    Steps:
    1. Read the image at ``FLAGS.image_dir`` (resized via ``reader.get_image``)
       and build the evaluation graph.
    2. Restore the newest checkpoint under
       ``select_things.select_checkpoint(FLAGS.scale)``. If none exists, print
       'Model has not trained' and return without running the graph, whose
       variables would otherwise be uninitialised.
    3. Draw the boxes decoded from the selected scale's output.
    4. Write ``<save_dir>/res.jpg`` and display the result with matplotlib.

    Args:
        FLAGS: parsed flags; reads ``save_dir``, ``image_dir``,
            ``image_width``, ``image_height`` and ``scale``.

    Raises:
        ValueError: if ``FLAGS.scale`` is not 1, 2 or 3.
    """
    if FLAGS.scale not in (1, 2, 3):
        raise ValueError(
            'FLAGS.scale must be 1, 2 or 3, got %r' % (FLAGS.scale,))

    if not os.path.exists(FLAGS.save_dir):
        os.makedirs(FLAGS.save_dir)

    input_image = reader.get_image(FLAGS.image_dir, FLAGS.image_width,
                                   FLAGS.image_height)
    output_image = np.copy(input_image)
    # --------Create placeholder--------
    image = net.create_eval_placeholder(FLAGS.image_width, FLAGS.image_height)
    # --------net--------
    pre_scale1, pre_scale2, pre_scale3 = net.feature_extractor(image)
    scale1, scale2, scale3 = net.scales(pre_scale1, pre_scale2, pre_scale3)

    with tf.Session() as sess:
        saver = tf.train.Saver()
        save_path = select_things.select_checkpoint(FLAGS.scale)
        last_checkpoint = tf.train.latest_checkpoint(save_path, 'checkpoint')
        if not last_checkpoint:
            # Nothing to restore: running the graph now would read
            # uninitialised variables and fail.
            print('Model has not trained')
            return
        saver.restore(sess, last_checkpoint)
        print('Success load model from: ', format(last_checkpoint))

        start_time = time.time()
        scale_outputs = sess.run([scale1, scale2, scale3],
                                 feed_dict={image: [output_image]})

    scale = {1: scale_outputs[0],
             2: scale_outputs[1],
             3: scale_outputs[2]}[FLAGS.scale]

    # Only the first (and only) image of the batch is decoded.
    boxes_labels = eval_uitls.label_extractor(scale[0])

    bdboxes = eval_uitls.get_bdboxes(boxes_labels)

    for bdbox in bdboxes:
        # bdbox[0:2] is the box centre and bdbox[2:4] its width/height, as
        # implied by the corner arithmetic below; bdbox[4] is the label text.
        top_left = (int(bdbox[0] - bdbox[2] / 2), int(bdbox[1] - bdbox[3] / 2))
        bottom_right = (int(bdbox[0] + bdbox[2] / 2),
                        int(bdbox[1] + bdbox[3] / 2))
        output_image = cv2.rectangle(output_image, top_left, bottom_right,
                                     (200, 0, 0), 1)
        output_image = cv2.putText(output_image, bdbox[4], top_left,
                                   cv2.FONT_HERSHEY_SIMPLEX, 0.1,
                                   (0, 255, 0), 1)

    generate_image = os.path.join(FLAGS.save_dir, 'res.jpg')
    # The image is treated as RGB and converted to BGR, the channel order
    # cv2.imwrite expects.
    cv2.imwrite(generate_image, cv2.cvtColor(output_image, cv2.COLOR_RGB2BGR))
    end_time = time.time()

    print('Use time: ', end_time - start_time)

    plt.imshow(output_image)
    plt.show()