def run_test():
    """Evaluate the latest checkpoint of ``HSI_branch`` on 'tiny_val.tfrecord'.

    Builds a fresh graph that reads validation batches from a queue-based
    ``Reader``, runs them through ``HSI_branch`` with ``keep_prob=1``, restores
    the newest checkpoint in ``checkpoint_dir`` and calls ``do_eval`` once,
    logging the wall-clock time taken. The session is closed, and the queue
    runner threads stopped and joined, on every exit path (including Ctrl-C).
    """
    with tf.Graph().as_default():
        valid_reader = Reader('tiny_val.tfrecord', name='valid_data',
                              min_queue_examples=50000,
                              batch_size=BATCH_SIZE, num_threads=13)
        HSIs_valid, labels_valid = valid_reader.feed(train_data=False)
        HSI = HSI_branch()
        # keep_prob=1 presumably disables dropout for evaluation -- confirm in
        # HSI_branch. Original author's note on the output: (500, 2432).
        logits1, _, _ = HSI(HSIs_valid, keep_prob=1)
        saver = tf.train.Saver()

        # The context manager guarantees the session is closed even when
        # restore or queue-runner startup raises (previously it leaked).
        with tf.Session() as sess:
            saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir))
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                start_time = time.time()
                do_eval(sess, HSIs_valid, labels_valid, logits1)
                duration = time.time() - start_time
                logging.info('>>Take time %.3f(sec)', duration)
            except KeyboardInterrupt:
                print('INTERRUPTED')
            finally:
                # Stop and join the queue threads before the session closes.
                coord.request_stop()
                coord.join(threads)
def run_train(): """Train CAPTCHA for a number of steps.""" with tf.Graph().as_default(): records = glob.glob('../HSI/data/*.tfrecord') train_readers = [] for record in records: record_name = 'train_data' + os.path.basename(record).split('.')[0] train_reader = Reader(record, name=record_name, batch_size=40) train_readers.append(train_reader) valid_reader = Reader('tiny_val.tfrecord', name='valid_data', batch_size=BATCH_SIZE) train_imgs_and_labels = [ train_reader_.feed(train_data=True) for train_reader_ in train_readers ] HSIs_valid, labels_valid = valid_reader.feed(train_data=False) hsi_pl, labels_pl = placeholder_inputs(BATCH_SIZE) HSI = HSI_branch() logits1, variables1, logits2, variables2, logits, outputs, variables = HSI( hsi_pl, keep_prob=0.5) predicts1 = model_predict(logits1) predicts2 = model_predict(logits2) predicts = model_predict(logits) loss1, loss2, loss = model_loss(logits1, logits2, logits, labels_pl) train_op = model_training(variables1, loss1, variables2, loss2, variables, loss) summary = tf.summary.merge_all() saver = tf.train.Saver(max_to_keep=50) # init_op = tf.global_variables_initializer() sess = tf.Session() summary_writer = tf.summary.FileWriter(train_dir, sess.graph) # sess.run(init_op) saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir)) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) try: max_step = 50000 for step in range(5144, max_step): start_time = time.time() images_and_labels = sess.run(train_imgs_and_labels) HSIs, labels = [], [] for hsi, label in images_and_labels: HSIs.extend(hsi) labels.extend(label) shuffle = np.random.permutation(range(len(labels))) HSIs = np.array(HSIs) labels = np.array(labels) HSIs = HSIs[shuffle][:BATCH_SIZE] labels = labels[shuffle][:BATCH_SIZE] _, loss_value1,loss_value2,loss_value, summary_str, predicts_value = sess.run([train_op, loss1,loss2,loss, summary, predicts],\ feed_dict={hsi_pl : HSIs, labels_pl : labels} ) train_precision1 = 
np.mean(predicts_value == labels) summary_writer.add_summary(summary_str, step) summary_writer.flush() duration = time.time() - start_time count = (step % 500) or 500 message = ( '>>Step: %d loss1 = %.4f loss2 = %.4f loss = %.4f acc = %.3f(%.3f sec) ETA = %.3f' % (step, loss_value1, loss_value2, loss_value, train_precision1, duration, (500 - count) * duration)) view_bar(message, count, 500) #------------------------------- if step % 500 == 0: logging.info('>>%s Saving in %s' % (datetime.now(), checkpoint_dir)) saver.save(sess, checkpoint_file, global_step=step) logging.info('Valid Data Eval:') do_eval( sess, step, HSIs_valid, labels_valid, hsi_pl, labels_pl, predicts1, predicts2, predicts, ) except KeyboardInterrupt: print('INTERRUPTED') coord.request_stop() finally: saver.save(sess, checkpoint_file, global_step=step) print('Model saved in file :%s' % checkpoint_dir) coord.request_stop() coord.join(threads) sess.close()
def run_train(start_step=390380, max_step=500000):
    """Resume training the siamese feature model from the latest checkpoint.

    Left/right/label batches are dequeued from the country_2ch TFRecord by a
    ``Reader`` and fed through placeholders into ``model.get_features``. The
    second call uses ``reuse=True``, so both branches share variables. Loss
    is logged every 10 steps. Every 1000 steps a checkpoint is written and
    ``do_eval`` runs on the field_2ch test set. A final checkpoint is written
    whenever the loop exits, provided at least one step started.

    Args:
      start_step: first step index. It should match the restored checkpoint;
        the default keeps the previously hard-coded value.
      max_step: exclusive upper bound on the step index.
    """
    test_data = dataset.read_data_sets(
        dataset_dir='/home/sw/Documents/rgb-nir2/qd_fang2_9_8/field_2ch.npz')
    with tf.Graph().as_default():
        train_reader = Reader(
            '/home/sw/Documents/rgb-nir2/qd_fang2_9_8/country_2ch.tfrecord',
            name='train_data', batch_size=BATCH_SIZE)
        # Original author's note: [64,128] -- presumably the patch shape.
        leftIMG, rightIMG, labels_op = train_reader.feed()
        images_pl1, images_pl2, labels_pl = placeholder_inputs(BATCH_SIZE)
        conv_features1, features1 = model.get_features(images_pl1, reuse=False)
        conv_features2, features2 = model.get_features(images_pl2, reuse=True)
        # Euclidean distance between paired embeddings, reduced over axis 1.
        predicts = tf.sqrt(
            tf.reduce_sum(tf.square(features1 - features2), axis=1))
        # NOTE(review): the loss does not take labels_pl; confirm it is meant
        # to be label-free.
        total_loss = model.caculate_loss(conv_features1, conv_features2,
                                         features1, features2)
        tf.summary.scalar('sum_loss', total_loss)
        train_op = model.training(total_loss)
        summary = tf.summary.merge_all()
        saver = tf.train.Saver(max_to_keep=50)

        with tf.Session() as sess:
            summary_writer = tf.summary.FileWriter(train_dir, sess.graph)
            # Resumes from the newest checkpoint; to start from scratch, run
            # tf.global_variables_initializer() instead.
            saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir))
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            # Bound before the loop so `finally` never hits a NameError.
            step = None
            try:
                for step in range(start_step, max_step):
                    start_time = time.time()
                    lefts, rights, batch_labels = sess.run(
                        [leftIMG, rightIMG, labels_op])
                    _, summary_str, loss_value = sess.run(
                        [train_op, summary, total_loss],
                        feed_dict={images_pl1: lefts, images_pl2: rights,
                                   labels_pl: batch_labels})
                    summary_writer.add_summary(summary_str, step)
                    summary_writer.flush()
                    duration = time.time() - start_time
                    if step % 10 == 0:
                        logging.info(
                            '>> Step %d run_train: loss = %.4f (%.3f sec)',
                            step, loss_value, duration)
                    if step % 1000 == 0:
                        logging.info('>> %s Saving in %s', datetime.now(),
                                     checkpoint_dir)
                        saver.save(sess, checkpoint_file, global_step=step)
                        logging.info('Test Data Eval:')
                        do_eval(sess, step, predicts, images_pl1, images_pl2,
                                labels_pl, test_data, name='notredame')
            except KeyboardInterrupt:
                print('INTERRUPTED')
            finally:
                try:
                    if step is not None:
                        saver.save(sess, checkpoint_file, global_step=step)
                        print('\rModel saved in file :%s' % checkpoint_dir)
                finally:
                    # Always stop/join queue threads and close the writer.
                    coord.request_stop()
                    coord.join(threads)
                    summary_writer.close()
def run_train(start_step=1, max_step=100000):
    """Train the siamese feature model from freshly initialized variables.

    Left/right/label batches come straight from the country_2ch TFRecord
    queue. No placeholders are used. Besides the loss, ROC-style scalars
    from ``model.evaluation`` are written to TensorBoard. Loss is logged
    every 10 steps and a checkpoint is written every 1000 steps. A final
    checkpoint is written whenever the loop exits, provided at least one
    step started.

    Args:
      start_step: first step index (default keeps the previous value 1).
      max_step: exclusive upper bound on the step index.
    """
    with tf.Graph().as_default():
        train_reader = Reader(
            '/home/sw/Documents/rgb-nir2/qd_fang2_9_8/country_2ch.tfrecord',
            name='train_data', batch_size=BATCH_SIZE)
        # Original author's note: [64,128] -- presumably the patch shape.
        leftIMG, rightIMG, labels_op = train_reader.feed()
        conv_features1, features1 = model.get_features(leftIMG, reuse=False)
        conv_features2, features2 = model.get_features(rightIMG, reuse=True)
        total_loss = model.caculate_loss(conv_features1, conv_features2,
                                         features1, features2)
        eval_all = model.evaluation(features1, features2, labels_op, 1)
        tf.summary.scalar('sum_loss', total_loss)
        tf.summary.scalar('roc/tp', eval_all['tp'])
        tf.summary.scalar('roc/fp', eval_all['fp'])
        # NOTE(review): presumably float counts; this is NaN/inf for a batch
        # with tp + fn == 0.
        tf.summary.scalar('roc/tpr',
                          eval_all['tp'] / (eval_all['tp'] + eval_all['fn']))
        tf.summary.scalar('roc/precision', eval_all['precision'])
        train_op = model.training(total_loss)
        summary = tf.summary.merge_all()
        saver = tf.train.Saver(max_to_keep=50)
        init_op = tf.global_variables_initializer()

        with tf.Session() as sess:
            summary_writer = tf.summary.FileWriter(train_dir, sess.graph)
            sess.run(init_op)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            # Bound before the loop so `finally` never hits a NameError.
            step = None
            try:
                for step in range(start_step, max_step):
                    start_time = time.time()
                    # Only fetch what is used. The distance and both feature
                    # tensors were fetched every step before and then thrown away.
                    _, summary_str, loss_value = sess.run(
                        [train_op, summary, total_loss])
                    summary_writer.add_summary(summary_str, step)
                    summary_writer.flush()
                    duration = time.time() - start_time
                    if step % 10 == 0:
                        logging.info(
                            '\r>> Step %d run_train: loss = %.4f (%.3f sec)',
                            step, loss_value, duration)
                    if step % 1000 == 0:
                        logging.info('>> %s Saving in %s', datetime.now(),
                                     checkpoint_dir)
                        saver.save(sess, checkpoint_file, global_step=step)
            except KeyboardInterrupt:
                print('INTERRUPTED')
            finally:
                try:
                    if step is not None:
                        saver.save(sess, checkpoint_file, global_step=step)
                        print('\rModel saved in file :%s' % checkpoint_dir)
                finally:
                    # Always stop/join queue threads and close the writer.
                    coord.request_stop()
                    coord.join(threads)
                    summary_writer.close()