def main_test(args):
    voc, char2id, id2char = get_vocabulary(voc_type=args.voc_type)

    input_images = tf.placeholder(dtype=tf.float32, shape=[1, args.height, None, 3], name="input_images")
    input_images_width = tf.placeholder(dtype=tf.float32, shape=[1], name="input_images_width")
    input_labels = tf.placeholder(dtype=tf.int32, shape=[1, args.max_len], name="input_labels")

    sar_model = SARModel(num_classes=len(voc),
                         encoder_dim=args.encoder_sdim,
                         encoder_layer=args.encoder_layers,
                         decoder_dim=args.decoder_sdim,
                         decoder_layer=args.decoder_layers,
                         decoder_embed_dim=args.decoder_edim,
                         seq_len=args.max_len,
                         is_training=False)
    model_infer, attention_weights, pred = sar_model(input_images,
                                                     input_labels,
                                                     input_images_width,
                                                     batch_size=1,
                                                     reuse=False)

    global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0),
                                  trainable=False, dtype=tf.int32)
    variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
    # Restore the EMA shadow weights rather than the raw training weights.
    saver = tf.train.Saver(variable_averages.variables_to_restore())

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        ckpt_state = tf.train.get_checkpoint_state(args.checkpoints)
        model_path = os.path.join(args.checkpoints, os.path.basename(ckpt_state.model_checkpoint_path))
        print('Restore from {}'.format(model_path))
        saver.restore(sess, model_path)

        images_path, labels = get_data(args)
        predicts = []
        for img_path, label in zip(images_path, labels):
            # cv2.imread does not raise on a bad path; it returns None.
            img = cv2.imread(img_path)
            if img is None:
                print("{} error: failed to read image".format(img_path))
                continue
            img, la, width = data_preprocess(img, label, char2id, args)

            pred_value, attention_weights_value = sess.run(
                [pred, attention_weights],
                feed_dict={input_images: [img],
                           input_labels: [la],
                           input_images_width: [width]})
            pred_value_str = idx2label(pred_value, id2char, char2id)[0]
            print("predict: {} label: {}".format(pred_value_str, label))
            predicts.append(pred_value_str)
            pred_value_str += '$'  # append the EOS marker for visualization

            if args.vis_dir is not None and args.vis_dir != "":
                os.makedirs(args.vis_dir, exist_ok=True)
                assert len(img.shape) == 3
                att_maps = attention_weights_value.reshape(
                    [-1, attention_weights_value.shape[2], attention_weights_value.shape[3], 1])  # T * H * W * 1
                for i, att_map in enumerate(att_maps):
                    if i >= len(pred_value_str):
                        break
                    att_map = cv2.resize(att_map, (img.shape[1], img.shape[0]))
                    _att_map = np.zeros(dtype=np.uint8, shape=[img.shape[0], img.shape[1], 3])
                    _att_map[:, :, -1] = (att_map * 255).astype(np.uint8)
                    show_attention = cv2.addWeighted(img, 0.5, _att_map, 2, 0)
                    cv2.imwrite(
                        os.path.join(args.vis_dir,
                                     os.path.basename(img_path).split('.')[0] + "_" + str(i) + "_" + pred_value_str[i] + ".jpg"),
                        show_attention)

        acc_rate = accuracy(predicts, labels)
        print("Done, Accuracy: {}".format(acc_rate))
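# A minimal sketch of the preprocessing this script relies on: resize to the
# fixed input height while keeping aspect ratio, pad to a maximum width, and
# encode the label into fixed-length ids. The real data_preprocess() lives
# elsewhere in this repo; the padding value, 'EOS'/'UNK' vocabulary keys, and
# the fixed-width padding (the variable-width test graph may skip it) are
# assumptions, not the repo's exact behavior.
def data_preprocess_sketch(img, label, char2id, args):
    h, w = img.shape[:2]
    new_w = min(args.width, max(1, int(round(w * args.height / float(h)))))
    resized = cv2.resize(img, (new_w, args.height))
    padded = np.zeros([args.height, args.width, 3], dtype=resized.dtype)
    padded[:, :new_w, :] = resized
    # Encode the label: one id per character, padded with EOS.
    la = np.full([args.max_len], char2id.get('EOS', 0), dtype=np.int32)
    for i, c in enumerate(label[:args.max_len]):
        la[i] = char2id.get(c, char2id.get('UNK', 0))
    return padded, la, float(new_w)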
def main_test_with_lexicon(args):
    voc, char2id, id2char = get_vocabulary(voc_type=args.voc_type)

    input_images = tf.placeholder(dtype=tf.float32, shape=[1, args.height, args.width, 3], name="input_images")
    input_images_width = tf.placeholder(dtype=tf.float32, shape=[1], name="input_images_width")
    input_labels = tf.placeholder(dtype=tf.int32, shape=[1, args.max_len], name="input_labels")

    sar_model = SARModel(num_classes=len(voc),
                         encoder_dim=args.encoder_sdim,
                         encoder_layer=args.encoder_layers,
                         decoder_dim=args.decoder_sdim,
                         decoder_layer=args.decoder_layers,
                         decoder_embed_dim=args.decoder_edim,
                         seq_len=args.max_len,
                         is_training=False)
    model_infer, attention_weights, pred, _ = sar_model(input_images,
                                                        input_labels,
                                                        input_images_width,
                                                        batch_size=1,
                                                        reuse=False)

    global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0),
                                  trainable=False, dtype=tf.int32)
    variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
    # Restore the EMA shadow weights, not the raw training weights.
    saver = tf.train.Saver(variable_averages.variables_to_restore())

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        ckpt_state = tf.train.get_checkpoint_state(args.checkpoints)
        model_path = os.path.join(args.checkpoints, os.path.basename(ckpt_state.model_checkpoint_path))
        print('Restore from {}'.format(model_path))
        saver.restore(sess, model_path)
        print("Checkpoints step: {}".format(global_step.eval(session=sess)))

        images_path, labels, lexicons = get_data_lexicon(args)
        predicts = []
        for i in tqdm.tqdm(range(len(images_path))):
            img_path = images_path[i]
            label = labels[i]
            # cv2.imread does not raise on a bad path; it returns None.
            img = cv2.imread(img_path)
            if img is None:
                print("{} error: failed to read image".format(img_path))
                continue
            img, la, width = data_preprocess(img, label, char2id, args)

            pred_value, attention_weights_value = sess.run(
                [pred, attention_weights],
                feed_dict={input_images: [img],
                           input_labels: [la],
                           input_images_width: [width]})
            pred_value_str = idx2label(pred_value, id2char, char2id)[0]
            predicts.append(pred_value_str)

            if args.vis_dir is not None and args.vis_dir != "":
                os.makedirs(args.vis_dir, exist_ok=True)
                os.makedirs(os.path.join(args.vis_dir, "errors"), exist_ok=True)
                _ = line_visualize(img, attention_weights_value, pred_value_str, args.vis_dir, img_path)
                if pred_value_str.lower() != label.lower():
                    _ = line_visualize(img, attention_weights_value, pred_value_str,
                                       os.path.join(args.vis_dir, "errors"), img_path)

        acc_rate = calc_metrics_lexicon(predicts, labels, lexicons)
        print("Done, Accuracy: {}".format(acc_rate))
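# calc_metrics_lexicon() is defined elsewhere in the repo. A common
# lexicon-constrained scoring scheme (an assumption here, not necessarily this
# repo's exact rule) snaps each raw prediction to the closest lexicon word by
# edit distance before the case-insensitive comparison with the ground truth.
def lexicon_correct_sketch(pred, lexicon):
    def edit_distance(a, b):
        # Standard one-row dynamic-programming Levenshtein distance.
        dp = list(range(len(b) + 1))
        for i, ca in enumerate(a, 1):
            prev, dp[0] = dp[0], i
            for j, cb in enumerate(b, 1):
                prev, dp[j] = dp[j], min(dp[j] + 1, dp[j - 1] + 1, prev + (ca != cb))
        return dp[-1]
    return min(lexicon, key=lambda w: edit_distance(pred.lower(), w.lower()))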
def test(args, cpks):
    assert isinstance(cpks, str)
    voc, char2id, id2char = get_vocabulary("ALLCASES_SYMBOLS")
    test_data = build_dataloader(args.val_data_cfg)
    print("test data: {}".format(len(test_data)))

    model = MultiInstanceRecognition(args.model_cfg).cuda()
    model.load_state_dict(torch.load(cpks))
    model.eval()

    pred_strs = []
    gt_strs = []
    for i, batch_data in enumerate(test_data):
        torch.cuda.empty_cache()
        batch_imgs, batch_imgs_path, batch_rectangles, \
            batch_text_labels, batch_text_labels_mask, batch_words = batch_data
        if batch_imgs is None:
            continue
        batch_imgs = batch_imgs.cuda()
        batch_rectangles = batch_rectangles.cuda()
        batch_text_labels = batch_text_labels.cuda()
        with torch.no_grad():
            loss, decoder_logits = model(batch_imgs, batch_text_labels,
                                         batch_rectangles, batch_text_labels_mask)
        pred_labels = decoder_logits.argmax(dim=2).cpu().numpy()
        pred_value_str = idx2label(pred_labels, id2char, char2id)
        gt_str = batch_words
        # Use a separate index here so the outer batch index is not shadowed.
        for j in range(len(gt_str[0])):
            print("predict: {} label: {}".format(pred_value_str[j], gt_str[0][j]))
            pred_strs.append(pred_value_str[j])
            gt_strs.append(gt_str[0][j])

    val_dec_metrics_result = calc_metrics(pred_strs, gt_strs, metrics_type="accuracy")
    print("test accuracy = {:.3f}".format(val_dec_metrics_result))
    print('---------')
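# idx2label() is defined elsewhere in the repo. A minimal sketch of the
# assumed decoding rule: map each id back to its character and stop at the
# first EOS. The 'EOS' vocabulary key is an assumption.
def idx2label_sketch(batch_ids, id2char, char2id):
    eos = char2id['EOS']
    strings = []
    for ids in batch_ids:
        chars = []
        for idx in ids:
            if idx == eos:
                break
            chars.append(id2char[int(idx)])
        strings.append(''.join(chars))
    return strings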
def train(cfg, args):
    logger = logging.getLogger('model training')
    train_data = build_dataloader(cfg.train_data_cfg, args.distributed)
    logger.info("train data: {}".format(len(train_data)))
    val_data = build_dataloader(cfg.val_data_cfg, args.distributed)
    logger.info("val data: {}".format(len(val_data)))

    model = MultiInstanceRecognition(cfg.model_cfg).cuda()
    if cfg.resume_from is not None:
        logger.info('loading pretrained model from {}'.format(cfg.resume_from))
        model.load_state_dict(torch.load(cfg.resume_from))
    if args.distributed:
        model = DistributedDataParallel(model, device_ids=[args.local_rank],
                                        output_device=args.local_rank)

    voc, char2id, id2char = get_vocabulary("ALLCASES_SYMBOLS")

    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    logger.info('Trainable params num: {}'.format(sum(params_num)))

    optimizer = optim.Adam(filtered_parameters, lr=cfg.lr, betas=(0.9, 0.999))
    lrScheduler = lr_scheduler.MultiStepLR(optimizer, [1, 2, 3], gamma=0.1)

    max_iters = cfg.max_iters
    start_iter = 0
    if cfg.resume_from is not None:
        start_iter = int(cfg.resume_from.split('_')[-1].split('.')[0])
        logger.info('continue to train, start_iter: {}'.format(start_iter))

    train_data_iter = iter(train_data)
    val_data_iter = iter(val_data)
    start_time = time.time()

    for i in range(start_iter, max_iters):
        model.train()
        try:
            batch_data = next(train_data_iter)
        except StopIteration:
            train_data_iter = iter(train_data)
            batch_data = next(train_data_iter)

        data_time_s = time.time()
        batch_imgs, batch_imgs_path, batch_rectangles, \
            batch_text_labels, batch_text_labels_mask, batch_words = batch_data
        # Skip collate failures, which come back as a batch of Nones.
        while batch_imgs is None:
            batch_data = next(train_data_iter)
            batch_imgs, batch_imgs_path, batch_rectangles, \
                batch_text_labels, batch_text_labels_mask, batch_words = batch_data
        batch_imgs = batch_imgs.cuda(non_blocking=True)
        batch_rectangles = batch_rectangles.cuda(non_blocking=True)
        batch_text_labels = batch_text_labels.cuda(non_blocking=True)
        data_time = time.time() - data_time_s

        loss, decoder_logits = model(batch_imgs, batch_text_labels,
                                     batch_rectangles, batch_text_labels_mask)
        del batch_data
        loss = loss.mean()

        if i % cfg.train_verbose == 0:
            this_time = time.time() - start_time
            log_loss = loss.detach().clone()
            if args.distributed:
                # dist.reduce sums in place onto rank 0 and returns None, so
                # reduce a detached copy that is used only for logging.
                dist.reduce(log_loss, dst=0)
            log_info = "train iter: {}, time: {:.2f}, data_time: {:.2f}, Loss: {:.3f}".format(
                i, this_time, data_time, log_loss.item())
            logger.info(log_info)
            torch.cuda.empty_cache()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        del loss

        if i % cfg.val_iter == 0:
            print("--------Val iteration---------")
            model.eval()
            try:
                val_batch = next(val_data_iter)
            except StopIteration:
                val_data_iter = iter(val_data)
                val_batch = next(val_data_iter)
            batch_imgs, batch_imgs_path, batch_rectangles, \
                batch_text_labels, batch_text_labels_mask, batch_words = val_batch
            while batch_imgs is None:
                val_batch = next(val_data_iter)
                batch_imgs, batch_imgs_path, batch_rectangles, \
                    batch_text_labels, batch_text_labels_mask, batch_words = val_batch
            del val_batch
            batch_imgs = batch_imgs.cuda(non_blocking=True)
            batch_rectangles = batch_rectangles.cuda(non_blocking=True)
            batch_text_labels = batch_text_labels.cuda(non_blocking=True)
            with torch.no_grad():
                val_loss, val_pred_logits = model(batch_imgs, batch_text_labels,
                                                  batch_rectangles, batch_text_labels_mask)
            pred_labels = val_pred_logits.argmax(dim=2).cpu().numpy()
            pred_value_str = idx2label(pred_labels, id2char, char2id)
            gt_str = []
            for words in batch_words:
                gt_str = gt_str + words
            val_dec_metrics_result = calc_metrics(pred_value_str, gt_str,
                                                  metrics_type="accuracy")
            this_time = time.time() - start_time
            if args.distributed:
                # In place, as above; val_loss itself holds the reduced value.
                dist.reduce(val_loss, dst=0)
            log_info = "val iter: {}, time: {:.2f} Loss: {:.3f}, acc: {:.2f}".format(
                i, this_time, val_loss.mean().data, val_dec_metrics_result)
            logger.info(log_info)
            del val_loss

        if (i + 1) % cfg.save_iter == 0:
            torch.save(model.state_dict(), cfg.save_name + '_{}.pth'.format(i + 1))

        if i > 0 and i % cfg.lr_step == 0:
            # adjust the learning rate
            lrScheduler.step()
            logger.info("lr step")

    print('end the training')
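# The distributed branch above assumes the process group was already
# initialized by the launcher. A minimal sketch of the setup this entry point
# expects (the NCCL backend and env:// rendezvous are assumptions), launched
# for example with:
#
#   python -m torch.distributed.launch --nproc_per_node=NGPUS train.py ...
import torch
import torch.distributed as dist

def init_distributed_sketch(local_rank):
    # Pin this process to its GPU before creating the process group.
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')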
def main_test_lmdb(args):
    voc, char2id, id2char = get_vocabulary(voc_type=args.voc_type)

    input_images = tf.placeholder(dtype=tf.float32, shape=[1, args.height, args.width, 3], name="input_images")
    input_images_width = tf.placeholder(dtype=tf.float32, shape=[1], name="input_images_width")
    input_labels = tf.placeholder(dtype=tf.int32, shape=[1, args.max_len], name="input_labels")

    sar_model = SARModel(num_classes=len(voc),
                         encoder_dim=args.encoder_sdim,
                         encoder_layer=args.encoder_layers,
                         decoder_dim=args.decoder_sdim,
                         decoder_layer=args.decoder_layers,
                         decoder_embed_dim=args.decoder_edim,
                         seq_len=args.max_len,
                         is_training=False)
    model_infer, attention_weights, pred, _ = sar_model(input_images,
                                                        input_labels,
                                                        input_images_width,
                                                        batch_size=1,
                                                        reuse=False)

    global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0),
                                  trainable=False, dtype=tf.int32)
    variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
    saver = tf.train.Saver(variable_averages.variables_to_restore())

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        ckpt_state = tf.train.get_checkpoint_state(args.checkpoints)
        model_path = os.path.join(args.checkpoints, os.path.basename(ckpt_state.model_checkpoint_path))
        print('Restore from {}'.format(model_path))
        saver.restore(sess, model_path)
        print("Checkpoints step: {}".format(global_step.eval(session=sess)))

        env = lmdb.open(args.test_data_dir, readonly=True)
        txn = env.begin()
        num_samples = int(txn.get(b"num-samples").decode())

        predicts = []
        labels = []
        # LMDB keys are 1-indexed: image-%09d / label-%09d.
        for i in tqdm.tqdm(range(1, num_samples + 1)):
            image_key = b'image-%09d' % i
            label_key = b'label-%09d' % i
            imgbuf = txn.get(image_key)
            buf = six.BytesIO()
            buf.write(imgbuf)
            buf.seek(0)
            img_pil = Image.open(buf).convert('RGB')
            img = np.array(img_pil)
            label = txn.get(label_key).decode()
            labels.append(label)

            img, la, width = data_preprocess(img, label, char2id, args)
            pred_value, attention_weights_value = sess.run(
                [pred, attention_weights],
                feed_dict={input_images: [img],
                           input_labels: [la],
                           input_images_width: [width]})
            pred_value_str = idx2label(pred_value, id2char, char2id)[0]
            predicts.append(pred_value_str)

            img_vis = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
            if args.vis_dir is not None and args.vis_dir != "":
                os.makedirs(args.vis_dir, exist_ok=True)
                os.makedirs(os.path.join(args.vis_dir, "errors"), exist_ok=True)
                # line_visualize and mask_visualize are alternative renderers.
                _ = heatmap_visualize(img_vis, attention_weights_value, pred_value_str,
                                      args.vis_dir, "{}.jpg".format(i))
                if pred_value_str.lower() != label.lower():
                    _ = heatmap_visualize(img_vis, attention_weights_value, pred_value_str,
                                          os.path.join(args.vis_dir, "errors"), "{}.jpg".format(i))

        acc_rate = calc_metrics(predicts, labels)
        print("Done, Accuracy: {}".format(acc_rate))
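# The reader above assumes the standard scene-text LMDB layout: a
# "num-samples" key plus "image-%09d"/"label-%09d" pairs indexed from 1.
# A minimal sketch of a writer that produces that layout (the key names match
# the reader exactly; map_size and the input format are assumptions):
import lmdb

def write_lmdb_sketch(out_dir, samples):
    # samples: iterable of (jpeg_bytes, label_str) pairs
    env = lmdb.open(out_dir, map_size=1 << 40)
    with env.begin(write=True) as txn:
        n = 0
        for n, (img_bytes, label) in enumerate(samples, 1):
            txn.put(b'image-%09d' % n, img_bytes)
            txn.put(b'label-%09d' % n, label.encode())
        txn.put(b'num-samples', str(n).encode())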
def main_train(args):
    voc, char2id, id2char = get_vocabulary(voc_type=args.voc_type)
    tf.set_random_seed(1)

    # Build graph
    input_train_images = tf.placeholder(dtype=tf.float32,
                                        shape=[args.train_batch_size, args.height, args.width, 3],
                                        name="input_train_images")
    input_train_images_width = tf.placeholder(dtype=tf.float32,
                                              shape=[args.train_batch_size],
                                              name="input_train_width")
    input_train_labels = tf.placeholder(dtype=tf.int32,
                                        shape=[args.train_batch_size, args.max_len],
                                        name="input_train_labels")
    input_train_gauss_labels = tf.placeholder(dtype=tf.float32,
                                              shape=[args.train_batch_size, args.max_len, 6, 40],
                                              name="input_train_gauss_labels")  # better way wanted!!!
    input_train_gauss_tags = tf.placeholder(dtype=tf.float32,
                                            shape=[args.train_batch_size, args.max_len],
                                            name="input_train_gauss_tags")
    input_train_gauss_params = tf.placeholder(dtype=tf.float32,
                                              shape=[args.train_batch_size, args.max_len, 4],
                                              name="input_train_gauss_params")
    input_train_labels_mask = tf.placeholder(dtype=tf.int32,
                                             shape=[args.train_batch_size, args.max_len],
                                             name="input_train_labels_mask")
    input_train_char_sizes = tf.placeholder(dtype=tf.float32,
                                            shape=[args.train_batch_size, args.max_len, 2],
                                            name="input_train_char_sizes")
    input_train_gauss_mask = tf.placeholder(dtype=tf.float32,
                                            shape=[args.train_batch_size, args.max_len, 6, 40],
                                            name="input_train_gauss_mask")

    input_val_images = tf.placeholder(dtype=tf.float32,
                                      shape=[args.val_batch_size, args.height, args.width, 3],
                                      name="input_val_images")
    input_val_images_width = tf.placeholder(dtype=tf.float32,
                                            shape=[args.val_batch_size],
                                            name="input_val_width")
    input_val_labels = tf.placeholder(dtype=tf.int32,
                                      shape=[args.val_batch_size, args.max_len],
                                      name="input_val_labels")
    input_val_labels_mask = tf.placeholder(dtype=tf.int32,
                                           shape=[args.val_batch_size, args.max_len],
                                           name="input_val_labels_mask")

    sar_model = SARModel(num_classes=len(voc),
                         encoder_dim=args.encoder_sdim,
                         encoder_layer=args.encoder_layers,
                         decoder_dim=args.decoder_sdim,
                         decoder_layer=args.decoder_layers,
                         decoder_embed_dim=args.decoder_edim,
                         seq_len=args.max_len,
                         is_training=True,
                         att_loss_type=args.att_loss_type,
                         att_loss_weight=args.att_loss_weight)
    sar_model_val = SARModel(num_classes=len(voc),
                             encoder_dim=args.encoder_sdim,
                             encoder_layer=args.encoder_layers,
                             decoder_dim=args.decoder_sdim,
                             decoder_layer=args.decoder_layers,
                             decoder_embed_dim=args.decoder_edim,
                             seq_len=args.max_len,
                             is_training=False)

    train_model_infer, train_attention_weights, train_pred, train_attention_params = sar_model(
        input_train_images, input_train_labels, input_train_images_width,
        batch_size=args.train_batch_size, reuse=False)

    if args.att_loss_type == "kldiv":
        train_loss, train_recog_loss, train_att_loss = sar_model.loss(
            train_model_infer, train_attention_weights, input_train_labels,
            input_train_gauss_labels, input_train_labels_mask, input_train_gauss_tags)
    elif args.att_loss_type in ("l1", "l2"):
        # An alternative form supervises the Gaussian parameters directly:
        # sar_model.loss(train_model_infer, train_attention_params, input_train_labels,
        #                input_train_gauss_params, input_train_labels_mask,
        #                input_train_gauss_tags, input_train_char_sizes)
        train_loss, train_recog_loss, train_att_loss = sar_model.loss(
            train_model_infer, train_attention_weights, input_train_labels,
            input_train_gauss_labels, input_train_labels_mask, input_train_gauss_tags)
    elif args.att_loss_type == 'ce':
        train_loss, train_recog_loss, train_att_loss = sar_model.loss(
            train_model_infer, train_attention_weights, input_train_labels,
            input_train_gauss_labels, input_train_labels_mask, input_train_gauss_tags,
            input_train_gauss_mask)
    elif args.att_loss_type == 'gausskldiv':
        train_loss, train_recog_loss, train_att_loss = sar_model.loss(
            train_model_infer, train_attention_params, input_train_labels,
            input_train_gauss_params, input_train_labels_mask, input_train_gauss_tags)
    else:
        print("Unimplemented loss type {}".format(args.att_loss_type))
        exit(-1)

    val_model_infer, val_attention_weights, val_pred, _ = sar_model_val(
        input_val_images, input_val_labels, input_val_images_width,
        batch_size=args.val_batch_size, reuse=True)

    train_data_list = get_data(args.train_data_dir, args.voc_type, args.max_len,
                               args.height, args.width, args.train_batch_size,
                               args.workers, args.keep_ratio, with_aug=args.aug)
    val_data_gen = evaluator_data.Evaluator(lmdb_data_dir=args.test_data_dir,
                                            batch_size=args.val_batch_size,
                                            height=args.height,
                                            width=args.width,
                                            max_len=args.max_len,
                                            keep_ratio=args.keep_ratio,
                                            voc_type=args.voc_type)
    val_data_gen.reset()

    global_step = tf.get_variable(name='global_step', initializer=tf.constant(0), trainable=False)
    learning_rate = tf.train.piecewise_constant(global_step, args.decay_bound, args.lr_stage)
    batch_norm_updates_op = tf.group(*tf.get_collection(tf.GraphKeys.UPDATE_OPS))

    # Save summary
    os.makedirs(args.checkpoints, exist_ok=True)
    tf.summary.scalar(name='train_loss', tensor=train_loss)
    tf.summary.scalar(name='train_recog_loss', tensor=train_recog_loss)
    tf.summary.scalar(name='train_att_loss', tensor=train_att_loss)
    tf.summary.scalar(name='learning_rate', tensor=learning_rate)
    merge_summary_op = tf.summary.merge_all()

    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    model_name = 'sar_{:s}.ckpt'.format(str(train_start_time))
    model_save_path = os.path.join(args.checkpoints, model_name)
    best_model_save_path = os.path.join(args.checkpoints, 'best_model', model_name)

    variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())
    with tf.control_dependencies([variables_averages_op, batch_norm_updates_op]):
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        grads = optimizer.compute_gradients(train_loss)
        if args.grad_clip > 0:
            print("With gradients clipped!")
            for idx, (grad, var) in enumerate(grads):
                grads[idx] = (tf.clip_by_norm(grad, args.grad_clip), var)
        train_op = optimizer.apply_gradients(grads, global_step=global_step)

    saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
    best_saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
    summary_writer = tf.summary.FileWriter(args.checkpoints)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    log_file = open(os.path.join(args.checkpoints, args.checkpoints + ".log"), "w")
    with tf.Session(config=config) as sess:
        summary_writer.add_graph(sess.graph)
        start_iter = 0
        if args.resume and args.pretrained != '':
            print('Restore model from {:s}'.format(args.pretrained))
            ckpt_state = tf.train.get_checkpoint_state(args.pretrained)
            model_path = os.path.join(args.pretrained,
                                      os.path.basename(ckpt_state.model_checkpoint_path))
            saver.restore(sess=sess, save_path=model_path)
            start_iter = sess.run(tf.train.get_global_step())
        elif not args.resume and args.pretrained != '':
            print('Restore pretrained model from {:s}'.format(args.pretrained))
            ckpt_state = tf.train.get_checkpoint_state(args.pretrained)
            model_path = os.path.join(args.pretrained,
                                      os.path.basename(ckpt_state.model_checkpoint_path))
            saver.restore(sess=sess, save_path=model_path)
            sess.run(tf.assign(global_step, 0))
        else:
            print('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)

        # Evaluate the model first
        val_pred_value_all = []
        val_labels = []
        for eval_iter in range(val_data_gen.num_samples // args.val_batch_size):
            val_data = val_data_gen.get_batch()
            if val_data is None:
                break
            print("Evaluation: [{} / {}]".format(
                eval_iter, val_data_gen.num_samples // args.val_batch_size))
            val_pred_value = sess.run(val_pred, feed_dict={
                input_val_images: val_data[0],
                input_val_labels: val_data[1],
                input_val_images_width: val_data[5],
                input_val_labels_mask: val_data[2]})
            val_pred_value_all.extend(val_pred_value)
            val_labels.extend(val_data[4])
        val_data_gen.reset()
        val_metrics_result = calc_metrics(idx2label(np.array(val_pred_value_all)),
                                          val_labels, metrics_type="accuracy")
        print("Evaluation before training: test accuracy {:.3f}".format(val_metrics_result))
        val_best_acc = val_metrics_result

        while start_iter < args.iters:
            start_iter += 1
            train_data = get_batch_data(train_data_list, args.train_batch_size)
            _, train_loss_value, train_recog_loss_value, train_att_loss_value, train_pred_value = sess.run(
                [train_op, train_loss, train_recog_loss, train_att_loss, train_pred],
                feed_dict={
                    input_train_images: train_data[0],
                    input_train_labels: train_data[1],
                    input_train_gauss_labels: train_data[2],
                    input_train_gauss_params: train_data[7],
                    input_train_labels_mask: train_data[3],
                    input_train_images_width: train_data[5],
                    input_train_gauss_tags: train_data[6],
                    input_train_char_sizes: train_data[8],
                    input_train_gauss_mask: train_data[9]})

            if start_iter % args.log_iter == 0:
                log_info = "Iter {} train loss= {:.3f} (recog loss= {:.3f} att loss= {:.3f})".format(
                    start_iter, train_loss_value, train_recog_loss_value, train_att_loss_value)
                print(log_info)
                log_file.write(log_info + "\n")

            if start_iter % args.summary_iter == 0:
                merge_summary_value = sess.run(merge_summary_op, feed_dict={
                    input_train_images: train_data[0],
                    input_train_labels: train_data[1],
                    input_train_gauss_labels: train_data[2],
                    input_train_gauss_params: train_data[7],
                    input_train_labels_mask: train_data[3],
                    input_train_images_width: train_data[5],
                    input_train_gauss_tags: train_data[6],
                    input_train_char_sizes: train_data[8],
                    input_train_gauss_mask: train_data[9]})
                summary_writer.add_summary(summary=merge_summary_value, global_step=start_iter)

            if start_iter % args.eval_iter == 0:
                val_pred_value_all = []
                val_labels = []
                for eval_iter in range(val_data_gen.num_samples // args.val_batch_size):
                    val_data = val_data_gen.get_batch()
                    if val_data is None:
                        break
                    print("Evaluation: [{} / {}]".format(
                        eval_iter, val_data_gen.num_samples // args.val_batch_size))
                    val_pred_value = sess.run(val_pred, feed_dict={
                        input_val_images: val_data[0],
                        input_val_labels: val_data[1],
                        input_val_labels_mask: val_data[2],
                        input_val_images_width: val_data[5]})
                    val_pred_value_all.extend(val_pred_value)
                    val_labels.extend(val_data[4])
                val_data_gen.reset()
                train_metrics_result = calc_metrics(idx2label(train_pred_value),
                                                    train_data[4], metrics_type="accuracy")
                val_metrics_result = calc_metrics(idx2label(np.array(val_pred_value_all)),
                                                  val_labels, metrics_type="accuracy")
                log_info = "Evaluation Iter {} train accuracy: {:.3f} test accuracy: {:.3f}".format(
                    start_iter, train_metrics_result, val_metrics_result)
                print(log_info)
                log_file.write(log_info + "\n")
                if val_metrics_result >= val_best_acc:
                    print("Better results! Save checkpoints to {}".format(best_model_save_path))
                    val_best_acc = val_metrics_result
                    best_saver.save(sess, best_model_save_path, global_step=global_step)

            if start_iter % args.save_iter == 0:
                print("Iter {} save to checkpoint".format(start_iter))
                saver.save(sess, model_save_path, global_step=global_step)

    log_file.close()
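# Why the test scripts build their Saver from variables_to_restore(): training
# maintains an ExponentialMovingAverage over the weights, so evaluation should
# load the shadow (averaged) variables rather than the raw ones. A minimal
# self-contained illustration of that pairing (the variable name is
# illustrative only):
import tensorflow as tf

def ema_save_restore_sketch():
    w = tf.get_variable('w', [], initializer=tf.zeros_initializer())
    ema = tf.train.ExponentialMovingAverage(0.997)
    maintain_op = ema.apply([w])  # run after each optimizer step
    # Maps 'w/ExponentialMovingAverage' -> w, so restoring loads the shadow.
    restore_map = ema.variables_to_restore()
    return maintain_op, tf.train.Saver(restore_map)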
def main_train(args):
    voc, char2id, id2char = get_vocabulary(voc_type=args.voc_type)

    # Build graph
    input_train_images = tf.placeholder(dtype=tf.float32,
                                        shape=[args.train_batch_size, args.height, args.width, 3],
                                        name="input_train_images")
    input_train_images_width = tf.placeholder(dtype=tf.float32,
                                              shape=[args.train_batch_size],
                                              name="input_train_width")
    input_train_labels = tf.placeholder(dtype=tf.int32,
                                        shape=[args.train_batch_size, args.max_len],
                                        name="input_train_labels")
    input_train_labels_mask = tf.placeholder(dtype=tf.int32,
                                             shape=[args.train_batch_size, args.max_len],
                                             name="input_train_labels_mask")
    input_val_images = tf.placeholder(dtype=tf.float32,
                                      shape=[args.val_batch_size, args.height, args.width, 3],
                                      name="input_val_images")
    input_val_images_width = tf.placeholder(dtype=tf.float32,
                                            shape=[args.val_batch_size],
                                            name="input_val_width")
    input_val_labels = tf.placeholder(dtype=tf.int32,
                                      shape=[args.val_batch_size, args.max_len],
                                      name="input_val_labels")
    input_val_labels_mask = tf.placeholder(dtype=tf.int32,
                                           shape=[args.val_batch_size, args.max_len],
                                           name="input_val_labels_mask")

    sar_model = SARModel(num_classes=len(voc),
                         encoder_dim=args.encoder_sdim,
                         encoder_layer=args.encoder_layers,
                         decoder_dim=args.decoder_sdim,
                         decoder_layer=args.decoder_layers,
                         decoder_embed_dim=args.decoder_edim,
                         seq_len=args.max_len,
                         is_training=True)
    sar_model_val = SARModel(num_classes=len(voc),
                             encoder_dim=args.encoder_sdim,
                             encoder_layer=args.encoder_layers,
                             decoder_dim=args.decoder_sdim,
                             decoder_layer=args.decoder_layers,
                             decoder_embed_dim=args.decoder_edim,
                             seq_len=args.max_len,
                             is_training=False)

    train_model_infer, train_attention_weights, train_pred = sar_model(
        input_train_images, input_train_labels, input_train_images_width,
        batch_size=args.train_batch_size, reuse=False)
    train_loss = sar_model.loss(train_model_infer, input_train_labels, input_train_labels_mask)
    val_model_infer, val_attention_weights, val_pred = sar_model_val(
        input_val_images, input_val_labels, input_val_images_width,
        batch_size=args.val_batch_size, reuse=True)
    val_loss = sar_model_val.loss(val_model_infer, input_val_labels, input_val_labels_mask)

    train_data_list = get_data(args.train_data_dir, args.train_data_gt, args.voc_type,
                               args.max_len, args.num_train, args.height, args.width,
                               args.train_batch_size, args.workers, args.keep_ratio,
                               with_aug=args.aug)
    val_data_list = get_data(args.test_data_dir, args.test_data_gt, args.voc_type,
                             args.max_len, args.num_train, args.height, args.width,
                             args.val_batch_size, args.workers, args.keep_ratio,
                             with_aug=False)

    global_step = tf.get_variable(name='global_step', initializer=tf.constant(0), trainable=False)
    learning_rate = tf.train.exponential_decay(learning_rate=args.lr,
                                               global_step=global_step,
                                               decay_steps=args.decay_iter,
                                               decay_rate=args.weight_decay,
                                               staircase=True)
    batch_norm_updates_op = tf.group(*tf.get_collection(tf.GraphKeys.UPDATE_OPS))

    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    grads = optimizer.compute_gradients(train_loss)
    apply_gradient_op = optimizer.apply_gradients(grads, global_step=global_step)

    # Save summary
    os.makedirs(args.checkpoints, exist_ok=True)
    tf.summary.scalar(name='train_loss', tensor=train_loss)
    tf.summary.scalar(name='val_loss', tensor=val_loss)
    tf.summary.scalar(name='learning_rate', tensor=learning_rate)
    merge_summary_op = tf.summary.merge_all()

    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    model_name = 'sar_{:s}.ckpt'.format(str(train_start_time))
    model_save_path = os.path.join(args.checkpoints, model_name)

    variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())
    with tf.control_dependencies([variables_averages_op, apply_gradient_op, batch_norm_updates_op]):
        train_op = tf.no_op(name='train_op')

    saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
    summary_writer = tf.summary.FileWriter(args.checkpoints)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        summary_writer.add_graph(sess.graph)
        start_iter = 0
        if args.resume and args.pretrained != '':
            print('Restore model from {:s}'.format(args.pretrained))
            ckpt_state = tf.train.get_checkpoint_state(args.pretrained)
            model_path = os.path.join(args.pretrained,
                                      os.path.basename(ckpt_state.model_checkpoint_path))
            saver.restore(sess=sess, save_path=model_path)
            start_iter = sess.run(tf.train.get_global_step())
        else:
            print('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)

        while start_iter < args.iters:
            start_iter += 1
            train_data = get_batch_data(train_data_list, args.train_batch_size)
            _, train_loss_value, train_pred_value = sess.run(
                [train_op, train_loss, train_pred],
                feed_dict={
                    input_train_images: train_data[0],
                    input_train_labels: train_data[1],
                    input_train_labels_mask: train_data[2],
                    input_train_images_width: train_data[4]})

            if start_iter % args.log_iter == 0:
                print("Iter {} train loss= {:.3f}".format(start_iter, train_loss_value))

            if start_iter % args.summary_iter == 0:
                val_data = get_batch_data(val_data_list, args.val_batch_size)
                merge_summary_value, val_pred_value, val_loss_value = sess.run(
                    [merge_summary_op, val_pred, val_loss],
                    feed_dict={
                        input_train_images: train_data[0],
                        input_train_labels: train_data[1],
                        input_train_labels_mask: train_data[2],
                        input_train_images_width: train_data[4],
                        input_val_images: val_data[0],
                        input_val_labels: val_data[1],
                        input_val_labels_mask: val_data[2],
                        input_val_images_width: val_data[4]})
                summary_writer.add_summary(summary=merge_summary_value, global_step=start_iter)

            if start_iter % args.eval_iter == 0:
                # Assumes args.eval_iter is a multiple of args.summary_iter,
                # so val_data / val_pred_value from the summary step exist.
                print("#" * 80)
                print("train prediction \t train labels ")
                for result, gt in zip(idx2label(train_pred_value), train_data[3]):
                    print("{} \t {}".format(result, gt))
                print("#" * 80)
                print("test prediction \t test labels ")
                for result, gt in zip(idx2label(val_pred_value)[:32], val_data[3][:32]):
                    print("{} \t {}".format(result, gt))
                print("#" * 80)
                train_metrics_result = calc_metrics(idx2label(train_pred_value),
                                                    train_data[3], metrics_type="accuracy")
                val_metrics_result = calc_metrics(idx2label(val_pred_value),
                                                  val_data[3], metrics_type="accuracy")
                print("Evaluation Iter {} test loss: {:.3f} train accuracy: {:.3f} test accuracy: {:.3f}".format(
                    start_iter, val_loss_value, train_metrics_result, val_metrics_result))

            if start_iter % args.save_iter == 0:
                print("Iter {} save to checkpoint".format(start_iter))
                saver.save(sess, model_save_path, global_step=global_step)
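# calc_metrics(..., metrics_type="accuracy") is defined elsewhere in the repo;
# the usual scene-text convention (assumed here, not verified against the
# repo's implementation) is case-insensitive exact match over the full string.
def accuracy_sketch(predictions, labels):
    assert len(predictions) == len(labels)
    correct = sum(1 for p, g in zip(predictions, labels) if p.lower() == g.lower())
    return correct / float(max(1, len(labels)))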