def main(argv=None):
    """Build the joint detection/recognition training graph and train it.

    Reads all configuration from the module-level ``FLAGS`` object: GPU list,
    checkpoint directory, restore / pretrained-model options, batch size,
    learning rate, step counts and data locations.

    Args:
        argv: Unused; accepted so the function can serve as an app entry point.

    Side effects:
        * Sets ``CUDA_VISIBLE_DEVICES``.
        * Creates ``FLAGS.checkpoint_path``; when ``FLAGS.restore`` is false an
          existing directory is deleted and recreated.
        * Writes checkpoints and summaries into ``FLAGS.checkpoint_path``.
    """
    import os

    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list

    # Start from an empty checkpoint directory unless we are resuming.
    if not tf.gfile.Exists(FLAGS.checkpoint_path):
        tf.gfile.MkDir(FLAGS.checkpoint_path)
    elif not FLAGS.restore:
        tf.gfile.DeleteRecursively(FLAGS.checkpoint_path)
        tf.gfile.MkDir(FLAGS.checkpoint_path)

    # ------------------------------------------------------------------
    # Graph inputs
    # ------------------------------------------------------------------
    input_images = tf.placeholder(
        tf.float32, shape=[None, None, None, 3], name='input_images')
    input_score_maps = tf.placeholder(
        tf.float32, shape=[None, None, None, 1], name='input_score_maps')
    input_geo_maps = tf.placeholder(
        tf.float32, shape=[None, None, None, 5], name='input_geo_maps')
    input_training_masks = tf.placeholder(
        tf.float32, shape=[None, None, None, 1], name='input_training_masks')
    input_transcription = tf.sparse_placeholder(
        tf.int32, name='input_transcription')
    input_transform_matrix = tf.placeholder(
        tf.float32, shape=[None, 6], name='input_transform_matrix')
    # The name is rebound to the stop_gradient output on purpose: that tensor
    # is both what the graph consumes and what is fed in the training loop.
    input_transform_matrix = tf.stop_gradient(input_transform_matrix)
    input_box_widths = tf.placeholder(
        tf.int32, shape=[None], name='input_box_widths')
    # Every box gets the same sequence length: the largest width in the batch.
    input_seq_len = input_box_widths[tf.argmax(
        input_box_widths, 0)] * tf.ones_like(input_box_widths)
    # One box-mask placeholder per image in the batch.
    input_box_masks = [
        tf.placeholder(tf.int32, shape=[None], name='input_box_masks_' + str(i))
        for i in range(FLAGS.batch_size_per_gpu)
    ]

    # ------------------------------------------------------------------
    # Model, losses and optimizer
    # ------------------------------------------------------------------
    f_score, f_geometry, recognition_logits = build_graph(
        input_images, input_transform_matrix, input_box_masks,
        input_box_widths, input_seq_len)

    global_step = tf.get_variable(
        'global_step', [], initializer=tf.constant_initializer(0),
        trainable=False)
    learning_rate = FLAGS.learning_rate  # constant; no decay schedule
    tf.summary.scalar('learning_rate', learning_rate)
    opt = tf.train.AdamOptimizer(learning_rate)

    d_loss, r_loss, model_loss = compute_loss(
        f_score, f_geometry, recognition_logits, input_score_maps,
        input_geo_maps, input_training_masks, input_transcription,
        input_box_widths)
    tf.summary.scalar('total_loss', model_loss)
    total_loss = tf.add_n(
        [model_loss] + tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))

    batch_norm_updates_op = tf.group(
        *tf.get_collection(tf.GraphKeys.UPDATE_OPS))

    if FLAGS.train_stage == 1:
        print("Train recognition branch only!")
    # NOTE(review): gradients are computed for every trainable variable in
    # both stages; the stage-1 message above does not currently restrict the
    # update to the 'recog' scope. Behaviour is kept as-is.
    grads = opt.compute_gradients(total_loss)

    # Clip each gradient to unit norm; variables without a gradient pass through.
    grads = [
        (tf.clip_by_norm(g, 1.0), v) if g is not None else (g, v)
        for g, v in grads
    ]
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

    summary_op = tf.summary.merge_all()

    # Maintain moving averages of the trainable variables.
    variable_averages = tf.train.ExponentialMovingAverage(
        FLAGS.moving_average_decay, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())

    # A single train_op triggers the gradient step, the moving-average update
    # and the batch-norm statistic updates.
    with tf.control_dependencies(
            [variables_averages_op, apply_gradient_op, batch_norm_updates_op]):
        train_op = tf.no_op(name='train_op')

    saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
    summary_writer = tf.summary.FileWriter(FLAGS.checkpoint_path,
                                           tf.get_default_graph())
    init = tf.global_variables_initializer()

    if FLAGS.pretrained_model_path is not None:
        if os.path.isdir(FLAGS.pretrained_model_path):
            # A directory: take its most recent checkpoint.
            print("Restore pretrained model from other datasets")
            ckpt = tf.train.latest_checkpoint(FLAGS.pretrained_model_path)
            variable_restore_op = slim.assign_from_checkpoint_fn(
                ckpt, slim.get_trainable_variables(),
                ignore_missing_vars=True)
        else:
            # A single *.ckpt file.
            print("Restore pretrained model from imagenet")
            variable_restore_op = slim.assign_from_checkpoint_fn(
                FLAGS.pretrained_model_path, slim.get_trainable_variables(),
                ignore_missing_vars=True)

    # Join with os.path so the checkpoint lands inside the directory even when
    # FLAGS.checkpoint_path has no trailing separator.
    checkpoint_prefix = os.path.join(FLAGS.checkpoint_path, 'model.ckpt')

    # ------------------------------------------------------------------
    # Training loop
    # ------------------------------------------------------------------
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        if FLAGS.restore:
            print('continue training from previous checkpoint')
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
            saver.restore(sess, ckpt)
        else:
            sess.run(init)
            if FLAGS.pretrained_model_path is not None:
                variable_restore_op(sess)

        dg = data_generator.get_batch(
            input_images_dir=FLAGS.training_data_dir,
            input_gt_dir=FLAGS.training_gt_data_dir,
            num_workers=FLAGS.num_readers,
            input_size=FLAGS.input_size,
            batch_size=FLAGS.batch_size_per_gpu)

        start = time.time()
        for step in range(FLAGS.max_steps):
            data = next(dg)
            inp_dict = {
                input_images: data[0],
                input_score_maps: data[2],
                input_geo_maps: data[3],
                input_training_masks: data[4],
                input_transform_matrix: data[5],
                input_box_widths: data[7],
                input_transcription: data[8],
            }
            for i, box_mask in enumerate(input_box_masks):
                inp_dict[box_mask] = data[6][i]

            # Fetch the merged summaries in the same run as the training step,
            # so a summary step does not apply a second update on this batch.
            write_summary = step % FLAGS.save_summary_steps == 0
            fetches = [d_loss, r_loss, total_loss, train_op]
            if write_summary:
                fetches.append(summary_op)
            results = sess.run(fetches, feed_dict=inp_dict)
            dl, rl, tl = results[:3]

            if np.isnan(tl):
                print('Loss diverged, stop training')
                break

            if step % 10 == 0:
                elapsed = time.time() - start
                # Only one step has run when the first line is logged.
                steps_in_window = 1 if step == 0 else 10
                avg_time_per_step = elapsed / steps_in_window
                avg_examples_per_second = (
                    steps_in_window * FLAGS.batch_size_per_gpu) / elapsed
                start = time.time()
                print(
                    'Step {:06d}, detect_loss {:.4f}, recognize_loss {:.4f}, total loss {:.4f}, {:.2f} seconds/step, {:.2f} examples/second'
                    .format(step, dl, rl, tl, avg_time_per_step,
                            avg_examples_per_second))

            if step % FLAGS.save_checkpoint_steps == 0:
                saver.save(sess, checkpoint_prefix, global_step=global_step)

            if write_summary:
                summary_writer.add_summary(results[4], global_step=step)

        # Flush any buffered summary events before exiting.
        summary_writer.close()
def get_data(image_dir, gt_path, voc_type, max_len, num_samples, height, width,
             batch_size, workers, keep_ratio, with_aug):
    """Create one batch source per dataset and return them as a list.

    The backend is chosen from ``gt_path``: when it is ``None`` (or, for list
    inputs, contains ``None``) every dataset is read through
    ``lmdb_data_generator.get_batch``; otherwise ``data_generator.get_batch``
    is used with the matching ground-truth path.

    Args:
        image_dir: A single dataset location, or a list of them.
        gt_path: Ground-truth path(s) matching ``image_dir``, or ``None`` /
            a list containing ``None`` to select the LMDB backend.
        voc_type, max_len, keep_ratio, with_aug: Forwarded unchanged to
            ``get_batch``.
        num_samples: Unused by this function; kept for interface
            compatibility.
        height, width: Forwarded as ``input_height`` / ``input_width``.
        batch_size: Total batch size. With more than one dataset it is split
            evenly between them.
        workers: Forwarded as the first positional argument of ``get_batch``.

    Returns:
        A list holding whatever ``get_batch`` returned for each dataset (one
        entry per dataset; a single entry for non-list or one-element input).

    Raises:
        AssertionError: If several datasets are given and ``batch_size`` is
            not a multiple of their count.
    """
    multi = isinstance(image_dir, list) and len(image_dir) > 1
    if multi and batch_size % len(image_dir) != 0:
        # Raised explicitly (not via ``assert``) so the check survives ``-O``;
        # the exception type is unchanged for existing callers.
        raise AssertionError(
            "batch size must be divisible by the number of datasets")

    shared = dict(input_height=height, input_width=width, max_len=max_len,
                  voc_type=voc_type, keep_ratio=keep_ratio, with_aug=with_aug)

    def _from_lmdb(lmdb_dir, size):
        # Build a batch source backed by an LMDB directory.
        return lmdb_data_generator.get_batch(
            workers, lmdb_dir=lmdb_dir, batch_size=size, **shared)

    def _from_files(directory, gt, size):
        # Build a batch source backed by an image directory + ground truth.
        return data_generator.get_batch(
            workers, image_dir=directory, gt_path=gt, batch_size=size,
            **shared)

    if not isinstance(image_dir, list):
        if gt_path is None:
            return [_from_lmdb(image_dir, batch_size)]
        return [_from_files(image_dir, gt_path, batch_size)]

    # List input: a missing gt_path, or any ``None`` entry, selects LMDB for
    # every dataset. Treating a bare ``None`` this way matches the scalar
    # case above.
    use_lmdb = gt_path is None or None in gt_path

    if multi:
        per_batch_size = batch_size // len(image_dir)
        if use_lmdb:
            return [_from_lmdb(d, per_batch_size) for d in image_dir]
        # NOTE(review): zip() stops at the shorter list, so a gt_path shorter
        # than image_dir silently drops the extra datasets -- confirm callers
        # always pass matching lengths.
        return [_from_files(d, g, per_batch_size)
                for d, g in zip(image_dir, gt_path)]

    # Zero- or one-element list: the full batch goes to the first dataset.
    if use_lmdb:
        return [_from_lmdb(image_dir[0], batch_size)]
    return [_from_files(image_dir[0], gt_path[0], batch_size)]