def run_test():
    """Evaluate the latest checkpoint of ``HSI_branch`` on the validation set.

    Builds the validation input pipeline from ``tiny_val.tfrecord`` and the
    ``HSI_branch`` network with dropout disabled (``keep_prob=1``). It then
    restores the most recent checkpoint in ``checkpoint_dir``, runs
    ``do_eval`` once and logs the elapsed wall time. The session is always
    closed, even if evaluation raises.

    Raises:
        ValueError: if ``checkpoint_dir`` contains no checkpoint (the same
            exception type ``Saver.restore`` raises for a ``None`` path, but
            with a clearer message).
    """
    with tf.Graph().as_default():
        valid_reader = Reader('tiny_val.tfrecord',
                              name='valid_data',
                              min_queue_examples=50000,
                              batch_size=BATCH_SIZE,
                              num_threads=13)
        HSIs_valid, labels_valid = valid_reader.feed(train_data=False)

        HSI = HSI_branch()
        # Only the logits are used for evaluation; the other outputs are
        # discarded. Original note gave the shape as (500, 2432) - unverified.
        logits1, _, _ = HSI(HSIs_valid, keep_prob=1)

        saver = tf.train.Saver()
        # Fail early with a clear message instead of restore(sess, None).
        ckpt_path = tf.train.latest_checkpoint(checkpoint_dir)
        if ckpt_path is None:
            raise ValueError('No checkpoint found in %s' % checkpoint_dir)

        sess = tf.Session()
        try:
            saver.restore(sess, ckpt_path)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                start_time = time.time()
                do_eval(sess, HSIs_valid, labels_valid, logits1)
                duration = time.time() - start_time
                logging.info('>>Take time %.3f(sec)' % duration)
            except KeyboardInterrupt:
                print('INTERRUPTED')
            finally:
                # Stop and join the queue-runner threads before closing.
                coord.request_stop()
                coord.join(threads)
        finally:
            # Release the session even when evaluation raised.
            sess.close()
示例#2
0
def run_train(start_step=5144, max_step=50000):
    """Resume training the multi-branch ``HSI_branch`` model from a checkpoint.

    Each step:

    - pulls one batch from every ``../HSI/data/*.tfrecord`` reader;
    - pools the batches and shuffles them;
    - keeps the first ``BATCH_SIZE`` examples;
    - runs ``train_op`` on them through the ``hsi_pl``/``labels_pl``
      placeholders.

    A progress bar is shown every step. Every ``eval_every`` (500) steps a
    checkpoint is saved and ``do_eval`` runs on the validation reader. On
    exit (normal end, Ctrl-C or error), the last step reached is
    checkpointed and the session and summary writer are closed.

    Args:
        start_step: first step number of this run. The default is the
            previously hard-coded resume point.
        max_step: exclusive upper bound of the step counter.

    Raises:
        ValueError: if no training TFRecords match the glob, or if no
            checkpoint exists in ``checkpoint_dir``.
    """
    eval_every = 500  # checkpoint / validation / progress-bar period

    with tf.Graph().as_default():
        records = glob.glob('../HSI/data/*.tfrecord')
        if not records:
            # glob silently returns [] on a wrong path; fail loudly instead.
            raise ValueError('No training records match ../HSI/data/*.tfrecord')
        train_readers = [
            Reader(record,
                   name='train_data' + os.path.basename(record).split('.')[0],
                   batch_size=40)
            for record in records
        ]

        valid_reader = Reader('tiny_val.tfrecord',
                              name='valid_data',
                              batch_size=BATCH_SIZE)
        train_imgs_and_labels = [
            reader.feed(train_data=True) for reader in train_readers
        ]

        HSIs_valid, labels_valid = valid_reader.feed(train_data=False)
        hsi_pl, labels_pl = placeholder_inputs(BATCH_SIZE)

        HSI = HSI_branch()
        logits1, variables1, logits2, variables2, logits, _, variables = HSI(
            hsi_pl, keep_prob=0.5)

        predicts1 = model_predict(logits1)
        predicts2 = model_predict(logits2)
        predicts = model_predict(logits)

        loss1, loss2, loss = model_loss(logits1, logits2, logits, labels_pl)
        train_op = model_training(variables1, loss1, variables2, loss2,
                                  variables, loss)
        summary = tf.summary.merge_all()
        saver = tf.train.Saver(max_to_keep=50)

        # Fail early with a clear message instead of restore(sess, None).
        ckpt_path = tf.train.latest_checkpoint(checkpoint_dir)
        if ckpt_path is None:
            raise ValueError('No checkpoint found in %s' % checkpoint_dir)

        sess = tf.Session()
        summary_writer = tf.summary.FileWriter(train_dir, sess.graph)
        try:
            saver.restore(sess, ckpt_path)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            step = None  # last step entered; stays None if the loop never runs
            try:
                for step in range(start_step, max_step):
                    if coord.should_stop():
                        break
                    start_time = time.time()
                    images_and_labels = sess.run(train_imgs_and_labels)
                    # Pool the per-record batches along the example axis.
                    pooled_hsis = np.concatenate(
                        [hsi for hsi, _ in images_and_labels])
                    pooled_labels = np.concatenate(
                        [label for _, label in images_and_labels])
                    # Random subset of BATCH_SIZE examples; index once
                    # instead of copying the whole shuffled array.
                    keep = np.random.permutation(len(pooled_labels))[:BATCH_SIZE]
                    HSIs = pooled_hsis[keep]
                    labels = pooled_labels[keep]

                    (_, loss_value1, loss_value2, loss_value, summary_str,
                     predicts_value) = sess.run(
                         [train_op, loss1, loss2, loss, summary, predicts],
                         feed_dict={hsi_pl: HSIs, labels_pl: labels})

                    train_precision1 = np.mean(predicts_value == labels)
                    summary_writer.add_summary(summary_str, step)
                    summary_writer.flush()
                    duration = time.time() - start_time

                    # Position inside the current eval window, 1..eval_every.
                    count = (step % eval_every) or eval_every
                    message = (
                        '>>Step: %d loss1 = %.4f loss2 = %.4f loss = %.4f acc = %.3f(%.3f sec) ETA = %.3f'
                        % (step, loss_value1, loss_value2, loss_value,
                           train_precision1, duration,
                           (eval_every - count) * duration))
                    view_bar(message, count, eval_every)

                    if step % eval_every == 0:
                        logging.info('>>%s Saving in %s' %
                                     (datetime.now(), checkpoint_dir))
                        saver.save(sess, checkpoint_file, global_step=step)

                        logging.info('Valid Data Eval:')
                        do_eval(sess, step, HSIs_valid, labels_valid, hsi_pl,
                                labels_pl, predicts1, predicts2, predicts)
            except KeyboardInterrupt:
                print('INTERRUPTED')
            finally:
                if step is not None:
                    saver.save(sess, checkpoint_file, global_step=step)
                    print('Model saved in file :%s' % checkpoint_dir)
                coord.request_stop()
                coord.join(threads)
        finally:
            # Always release the event file and the session.
            summary_writer.close()
            sess.close()
def run_train(start_step=390380, max_step=500000):
    """Resume training the two-input descriptor model from a checkpoint.

    Setup:

    - loads the evaluation data with ``dataset.read_data_sets`` from an
      ``.npz`` file;
    - builds a reader over the training TFRecord yielding left image,
      right image and label batches;
    - feeds both images through a shared ``model.get_features`` network
      (``reuse=True`` on the second call).

    Each step fetches a batch to the host and feeds it back through the
    placeholders to run ``model.training``. Loss is logged every 10 steps.
    Every 1000 steps a checkpoint is written and ``do_eval`` runs on the
    test data. On exit the last step reached is checkpointed and the
    session and summary writer are closed.

    Args:
        start_step: first step number of this run. The default is the
            previously hard-coded resume point.
        max_step: exclusive upper bound of the step counter.

    Raises:
        ValueError: if no checkpoint exists in ``checkpoint_dir``.
    """
    test_data = dataset.read_data_sets(
        dataset_dir='/home/sw/Documents/rgb-nir2/qd_fang2_9_8/field_2ch.npz')
    with tf.Graph().as_default():
        train_reader = Reader(
            '/home/sw/Documents/rgb-nir2/qd_fang2_9_8/country_2ch.tfrecord',
            name='train_data',
            batch_size=BATCH_SIZE)
        # Original note: [64,128] - shape not verified here.
        leftIMG, rightIMG, labels_op = train_reader.feed()

        images_pl1, images_pl2, labels_pl = placeholder_inputs(BATCH_SIZE)
        conv_features1, features1 = model.get_features(images_pl1, reuse=False)
        conv_features2, features2 = model.get_features(images_pl2, reuse=True)
        # Per-row Euclidean distance between the two feature vectors.
        predicts = tf.sqrt(
            tf.reduce_sum(tf.square(features1 - features2), axis=1))

        total_loss = model.caculate_loss(conv_features1, conv_features2,
                                         features1, features2)
        tf.summary.scalar('sum_loss', total_loss)
        train_op = model.training(total_loss)

        summary = tf.summary.merge_all()
        saver = tf.train.Saver(max_to_keep=50)

        # Fail early with a clear message instead of restore(sess, None).
        ckpt_path = tf.train.latest_checkpoint(checkpoint_dir)
        if ckpt_path is None:
            raise ValueError('No checkpoint found in %s' % checkpoint_dir)

        sess = tf.Session()
        summary_writer = tf.summary.FileWriter(train_dir, sess.graph)
        try:
            saver.restore(sess, ckpt_path)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            step = None  # last step entered; stays None if the loop never runs
            try:
                for step in range(start_step, max_step):
                    if coord.should_stop():
                        break
                    start_time = time.time()
                    lefts, rights, batch_labels = sess.run(
                        [leftIMG, rightIMG, labels_op])
                    _, summary_str, loss_value = sess.run(
                        [train_op, summary, total_loss],
                        feed_dict={images_pl1: lefts,
                                   images_pl2: rights,
                                   labels_pl: batch_labels})

                    summary_writer.add_summary(summary_str, step)
                    summary_writer.flush()
                    duration = time.time() - start_time
                    if step % 10 == 0:
                        logging.info(
                            '>> Step %d run_train: loss = %.4f  (%.3f sec)' %
                            (step, loss_value, duration))
                    if step % 1000 == 0:
                        logging.info('>> %s Saving in %s' %
                                     (datetime.now(), checkpoint_dir))
                        saver.save(sess, checkpoint_file, global_step=step)

                        logging.info('Test Data Eval:')
                        do_eval(sess, step, predicts, images_pl1, images_pl2,
                                labels_pl, test_data, name='notredame')
            except KeyboardInterrupt:
                print('INTERRUPTED')
            finally:
                if step is not None:
                    saver.save(sess, checkpoint_file, global_step=step)
                    print('\rModel saved in file :%s' % checkpoint_dir)
                coord.request_stop()
                coord.join(threads)
        finally:
            # Always release the event file and the session.
            summary_writer.close()
            sess.close()
示例#4
0
def run_train(start_step=1, max_step=100000):
    """Train the two-input descriptor model from freshly initialized weights.

    Left image, right image and label batches come straight from the
    training TFRecord queue; no placeholders are used. Both images pass
    through a shared ``model.get_features`` network (``reuse=True`` on the
    second call).

    Summaries are written every step: the summed loss plus tp / fp / tpr /
    precision from ``model.evaluation``. Loss is logged every 10 steps and a
    checkpoint is written every 1000 steps. On exit the last step reached
    is checkpointed and the session and summary writer are closed.

    Args:
        start_step: first step number (default 1, as before).
        max_step: exclusive upper bound of the step counter.
    """
    with tf.Graph().as_default():
        train_reader = Reader(
            '/home/sw/Documents/rgb-nir2/qd_fang2_9_8/country_2ch.tfrecord',
            name='train_data',
            batch_size=BATCH_SIZE)
        # Original note: [64,128] - shape not verified here.
        leftIMG, rightIMG, labels_op = train_reader.feed()

        conv_features1, features1 = model.get_features(leftIMG, reuse=False)
        conv_features2, features2 = model.get_features(rightIMG, reuse=True)

        total_loss = model.caculate_loss(conv_features1, conv_features2,
                                         features1, features2)
        eval_all = model.evaluation(features1, features2, labels_op, 1)  # train

        tf.summary.scalar('sum_loss', total_loss)
        tf.summary.scalar('roc/tp', eval_all['tp'])
        tf.summary.scalar('roc/fp', eval_all['fp'])
        # NOTE(review): yields NaN/inf when tp + fn == 0 for a batch.
        tf.summary.scalar('roc/tpr',
                          eval_all['tp'] / (eval_all['tp'] + eval_all['fn']))
        tf.summary.scalar('roc/precision', eval_all['precision'])
        train_op = model.training(total_loss)

        summary = tf.summary.merge_all()
        saver = tf.train.Saver(max_to_keep=50)
        init_op = tf.global_variables_initializer()

        sess = tf.Session()
        summary_writer = tf.summary.FileWriter(train_dir, sess.graph)
        try:
            sess.run(init_op)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            step = None  # last step entered; stays None if the loop never runs
            try:
                for step in range(start_step, max_step):
                    if coord.should_stop():
                        break
                    start_time = time.time()
                    # Fetch only what is used; the original also pulled the
                    # distances and both feature tensors to the host every
                    # step and discarded them.
                    _, summary_str, loss_value = sess.run(
                        [train_op, summary, total_loss])
                    summary_writer.add_summary(summary_str, step)
                    summary_writer.flush()
                    duration = time.time() - start_time
                    if step % 10 == 0:
                        logging.info(
                            '\r>> Step %d run_train: loss = %.4f  (%.3f sec)' %
                            (step, loss_value, duration))
                    if step % 1000 == 0:
                        logging.info('>> %s Saving in %s' %
                                     (datetime.now(), checkpoint_dir))
                        saver.save(sess, checkpoint_file, global_step=step)
            except KeyboardInterrupt:
                print('INTERRUPTED')
            finally:
                if step is not None:
                    saver.save(sess, checkpoint_file, global_step=step)
                    print('\rModel saved in file :%s' % checkpoint_dir)
                coord.request_stop()
                coord.join(threads)
        finally:
            # Always release the event file and the session.
            summary_writer.close()
            sess.close()