Example #1
def main_test(args):
    voc, char2id, id2char = get_vocabulary(voc_type=args.voc_type)

    input_images = tf.placeholder(dtype=tf.float32, shape=[1, args.height, None, 3], name="input_images")
    input_images_width = tf.placeholder(dtype=tf.float32, shape=[1], name="input_images_width")
    input_labels = tf.placeholder(dtype=tf.int32, shape=[1, args.max_len], name="input_labels")
    sar_model = SARModel(num_classes=len(voc),
                         encoder_dim=args.encoder_sdim,
                         encoder_layer=args.encoder_layers,
                         decoder_dim=args.decoder_sdim,
                         decoder_layer=args.decoder_layers,
                         decoder_embed_dim=args.decoder_edim,
                         seq_len=args.max_len,
                         is_training=False)

    model_infer, attention_weights, pred = sar_model(input_images, input_labels, input_images_width, batch_size=1, reuse=False)
    global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False, dtype=tf.int32)
    variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
    saver = tf.train.Saver(variable_averages.variables_to_restore())
    
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        ckpt_state = tf.train.get_checkpoint_state(args.checkpoints)
        model_path = os.path.join(args.checkpoints, os.path.basename(ckpt_state.model_checkpoint_path))
        print('Restore from {}'.format(model_path))
        saver.restore(sess, model_path)

        images_path, labels = get_data(args)
        predicts = []
        for img_path, label in zip(images_path, labels):
            img = cv2.imread(img_path)
            if img is None:  # cv2.imread returns None on failure instead of raising
                print("{} error: failed to read image".format(img_path))
                continue

            img, la, width = data_preprocess(img, label, char2id, args)

            pred_value, attention_weights_value = sess.run([pred, attention_weights], feed_dict={input_images: [img],
                                                                                                 input_labels: [la],
                                                                                                 input_images_width: [width]})
            pred_value_str = idx2label(pred_value, id2char, char2id)[0]
            print("predict: {} label: {}".format(pred_value_str, label))
            predicts.append(pred_value_str)
            pred_value_str += '$'
            if args.vis_dir:
                os.makedirs(args.vis_dir, exist_ok=True)
                assert len(img.shape) == 3
                att_maps = attention_weights_value.reshape([-1, attention_weights_value.shape[2], attention_weights_value.shape[3], 1]) # T * H * W * 1
                for i, att_map in enumerate(att_maps):
                    if i >= len(pred_value_str):
                        break
                    att_map = cv2.resize(att_map, (img.shape[1], img.shape[0]))
                    _att_map = np.zeros(dtype=np.uint8, shape=[img.shape[0], img.shape[1], 3])
                    _att_map[:, :, -1] = (att_map * 255).astype(np.uint8)

                    show_attention = cv2.addWeighted(img, 0.5, _att_map, 2, 0)
                    cv2.imwrite(os.path.join(args.vis_dir, os.path.basename(img_path).split('.')[0] + "_" + str(i) + "_" + pred_value_str[i] + ".jpg"), show_attention)

    acc_rate = accuracy(predicts, labels)
    print("Done, Accuracy: {}".format(acc_rate))
Example #2
def main_test_with_lexicon(args):
    voc, char2id, id2char = get_vocabulary(voc_type=args.voc_type)

    input_images = tf.placeholder(dtype=tf.float32, shape=[1, args.height, args.width, 3], name="input_images")
    input_images_width = tf.placeholder(dtype=tf.float32, shape=[1], name="input_images_width")
    input_labels = tf.placeholder(dtype=tf.int32, shape=[1, args.max_len], name="input_labels")
    sar_model = SARModel(num_classes=len(voc),
                         encoder_dim=args.encoder_sdim,
                         encoder_layer=args.encoder_layers,
                         decoder_dim=args.decoder_sdim,
                         decoder_layer=args.decoder_layers,
                         decoder_embed_dim=args.decoder_edim,
                         seq_len=args.max_len,
                         is_training=False)

    model_infer, attention_weights, pred, _ = sar_model(input_images, input_labels, input_images_width, batch_size=1, reuse=False)
    global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False, dtype=tf.int32)
    variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
    saver = tf.train.Saver(variable_averages.variables_to_restore())
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        ckpt_state = tf.train.get_checkpoint_state(args.checkpoints)
        model_path = os.path.join(args.checkpoints, os.path.basename(ckpt_state.model_checkpoint_path))
        print('Restore from {}'.format(model_path))
        saver.restore(sess, model_path)
        print("Checkpoints step: {}".format(global_step.eval(session=sess)))
        images_path, labels, lexicons = get_data_lexicon(args)
        predicts = []
        for i in tqdm.tqdm(range(len(images_path))):
            img_path = images_path[i]
            label = labels[i]
            img = cv2.imread(img_path)
            if img is None:  # cv2.imread returns None on failure instead of raising
                print("{} error: failed to read image".format(img_path))
                continue

            img, la, width = data_preprocess(img, label, char2id, args)

            pred_value, attention_weights_value = sess.run([pred, attention_weights], feed_dict={input_images: [img],
                                                                                                 input_labels: [la],
                                                                                                 input_images_width: [width]})
            pred_value_str = idx2label(pred_value, id2char, char2id)[0]
            # print("predict: {} label: {}".format(pred_value_str, label))
            predicts.append(pred_value_str)
            if args.vis_dir:
                os.makedirs(args.vis_dir, exist_ok=True)
                os.makedirs(os.path.join(args.vis_dir, "errors"), exist_ok=True)
                _ = line_visualize(img, attention_weights_value, pred_value_str, args.vis_dir, img_path)
                if pred_value_str.lower() != label.lower():
                    _ = line_visualize(img, attention_weights_value, pred_value_str, os.path.join(args.vis_dir, "errors"), img_path)
    acc_rate = calc_metrics_lexicon(predicts, labels, lexicons)
    print("Done, Accuracy: {}".format(acc_rate))
Example #3
def test(args, cpks):

    assert isinstance(cpks, str)

    voc, char2id, id2char = get_vocabulary("ALLCASES_SYMBOLS")

    test_data = build_dataloader(args.val_data_cfg)
    print("test data: {}".format(len(test_data)))

    model = MultiInstanceRecognition(args.model_cfg).cuda()
    # model = MMDataParallel(model).cuda()
    model.load_state_dict(torch.load(cpks))
    model.eval()
    pred_strs = []
    gt_strs = []
    for i, batch_data in enumerate(test_data):
        torch.cuda.empty_cache()
        batch_imgs, batch_imgs_path, batch_rectangles, \
        batch_text_labels, batch_text_labels_mask, batch_words = \
            batch_data
        if batch_imgs is None:
            continue
        batch_imgs = batch_imgs.cuda()
        batch_rectangles = batch_rectangles.cuda()
        batch_text_labels = batch_text_labels.cuda()
        with torch.no_grad():
            loss, decoder_logits = model(batch_imgs, batch_text_labels,
                                         batch_rectangles,
                                         batch_text_labels_mask)

        pred_labels = decoder_logits.argmax(dim=2).cpu().numpy()
        pred_value_str = idx2label(pred_labels, id2char, char2id)
        gt_str = batch_words

        # use j so the outer enumerate index i is not shadowed
        for j in range(len(gt_str[0])):
            print("predict: {} label: {}".format(pred_value_str[j],
                                                 gt_str[0][j]))
            pred_strs.append(pred_value_str[j])
            gt_strs.append(gt_str[0][j])

        val_dec_metrics_result = calc_metrics(pred_strs,
                                              gt_strs,
                                              metrics_type="accuracy")

        print("test accuracy= {:3f}".format(val_dec_metrics_result))
        print('---------')
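idx2label is used throughout these examples but not defined here. A plausible sketch of the decoding it performs, assuming the vocabulary reserves an 'EOS' entry (that key name is an assumption):

def idx2label_sketch(pred_ids, id2char, char2id):
    # Map each row of predicted ids back to a string, stopping at EOS.
    eos_id = char2id.get('EOS')
    strings = []
    for row in pred_ids:            # pred_ids: [batch, seq_len] integer array
        chars = []
        for idx in row:
            if idx == eos_id:
                break
            chars.append(id2char.get(int(idx), ''))
        strings.append(''.join(chars))
    return strings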
Example #4
def train(cfg, args):
    logger = logging.getLogger('model training')
    train_data = build_dataloader(cfg.train_data_cfg, args.distributed)
    logger.info("train data: {}".format(len(train_data)))
    val_data = build_dataloader(cfg.val_data_cfg, args.distributed)
    logger.info("val data: {}".format(len(val_data)))

    model = MultiInstanceRecognition(cfg.model_cfg).cuda()
    if cfg.resume_from is not None:
        logger.info('loading pretrained model from {}'.format(cfg.resume_from))
        model.load_state_dict(torch.load(cfg.resume_from))
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)
    voc, char2id, id2char = get_vocabulary("ALLCASES_SYMBOLS")

    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    logger.info('Trainable params num: {}'.format(sum(params_num)))
    optimizer = optim.Adam(filtered_parameters, lr=cfg.lr, betas=(0.9, 0.999))
    lrScheduler = lr_scheduler.MultiStepLR(optimizer, [1, 2, 3], gamma=0.1)

    max_iters = cfg.max_iters
    start_iter = 0
    if cfg.resume_from is not None:
        start_iter = int(cfg.resume_from.split('_')[-1].split('.')[0])
        logger.info('continue to train, start_iter: {}'.format(start_iter))

    train_data_iter = iter(train_data)
    val_data_iter = iter(val_data)
    start_time = time.time()
    for i in range(start_iter, max_iters):
        model.train()
        try:
            batch_data = next(train_data_iter)
        except StopIteration:
            train_data_iter = iter(train_data)
            batch_data = next(train_data_iter)
        data_time_s = time.time()
        batch_imgs, batch_imgs_path, batch_rectangles, \
        batch_text_labels, batch_text_labels_mask, batch_words = \
            batch_data
        while batch_imgs is None:
            batch_data = next(train_data_iter)
            batch_imgs, batch_imgs_path, batch_rectangles, \
            batch_text_labels, batch_text_labels_mask, batch_words = \
                batch_data

        batch_imgs = batch_imgs.cuda(non_blocking=True)
        batch_rectangles = batch_rectangles.cuda(non_blocking=True)
        batch_text_labels = batch_text_labels.cuda(non_blocking=True)
        data_time = time.time() - data_time_s
        loss, decoder_logits = model(batch_imgs, batch_text_labels,
                                     batch_rectangles, batch_text_labels_mask)
        del batch_data

        loss = loss.mean()

        if i % cfg.train_verbose == 0:
            this_time = time.time() - start_time
            log_loss = loss.detach().clone()
            if args.distributed:
                # dist.reduce sums in place and returns None; reducing a copy
                # keeps `loss` and its autograd graph intact for backward()
                dist.reduce(log_loss, 0)
            log_info = "train iter: {}, time: {:.2f}, data_time: {:.2f}, loss: {:.3f}".format(
                i, this_time, data_time, log_loss.item())
            logger.info(log_info)
            torch.cuda.empty_cache()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        del loss
        if i % cfg.val_iter == 0:
            print("--------Val iteration---------")
            model.eval()

            try:
                val_batch = next(val_data_iter)
            except StopIteration:
                val_data_iter = iter(val_data)
                val_batch = next(val_data_iter)

            batch_imgs, batch_imgs_path, batch_rectangles, \
            batch_text_labels, batch_text_labels_mask, batch_words = \
                val_batch
            while batch_imgs is None:
                val_batch = next(val_data_iter)
                batch_imgs, batch_imgs_path, batch_rectangles, \
                batch_text_labels, batch_text_labels_mask, batch_words = \
                    val_batch
            del val_batch
            batch_imgs = batch_imgs.cuda(non_blocking=True)
            batch_rectangles = batch_rectangles.cuda(non_blocking=True)
            batch_text_labels = batch_text_labels.cuda(non_blocking=True)
            with torch.no_grad():
                val_loss, val_pred_logits = model(batch_imgs,
                                                  batch_text_labels,
                                                  batch_rectangles,
                                                  batch_text_labels_mask)
            pred_labels = val_pred_logits.argmax(dim=2).cpu().numpy()
            pred_value_str = idx2label(pred_labels, id2char, char2id)
            gt_str = []
            for words in batch_words:
                gt_str.extend(words)
            val_dec_metrics_result = calc_metrics(pred_value_str,
                                                  gt_str,
                                                  metrics_type="accuracy")
            this_time = time.time() - start_time
            if args.distributed:
                # reduce in place; dist.reduce returns None, so do not rebind
                dist.reduce(val_loss, 0)
            log_info = "val iter: {}, time: {:.2f}, loss: {:.3f}, acc: {:.2f}".format(
                i, this_time,
                val_loss.mean().item(), val_dec_metrics_result)
            logger.info(log_info)
            del val_loss
        if (i + 1) % cfg.save_iter == 0:
            torch.save(model.state_dict(),
                       cfg.save_name + '_{}.pth'.format(i + 1))
        if i > 0 and i % cfg.lr_step == 0:  # step the learning-rate schedule
            lrScheduler.step()
            logger.info("lr step")
    print('end of training')
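The original code rebound loss to the return value of dist.reduce, which is None in synchronous mode; the fixes above reduce a copy instead. A minimal sketch of that logging pattern, assuming the process group is already initialized:

import torch.distributed as dist

def loss_for_logging(loss, world_size, dst=0):
    # dist.reduce sums in place and returns None, so reduce a detached
    # copy; the autograd graph behind `loss` stays intact.
    logged = loss.detach().clone()
    dist.reduce(logged, dst)
    if dist.get_rank() == dst:
        return (logged / world_size).item()   # mean across ranks
    return None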
Example #5
def main_test_lmdb(args):
    voc, char2id, id2char = get_vocabulary(voc_type=args.voc_type)

    input_images = tf.placeholder(dtype=tf.float32, shape=[1, args.height, args.width, 3], name="input_images")
    input_images_width = tf.placeholder(dtype=tf.float32, shape=[1], name="input_images_width")
    input_labels = tf.placeholder(dtype=tf.int32, shape=[1, args.max_len], name="input_labels")
    sar_model = SARModel(num_classes=len(voc),
                         encoder_dim=args.encoder_sdim,
                         encoder_layer=args.encoder_layers,
                         decoder_dim=args.decoder_sdim,
                         decoder_layer=args.decoder_layers,
                         decoder_embed_dim=args.decoder_edim,
                         seq_len=args.max_len,
                         is_training=False)

    model_infer, attention_weights, pred, _ = sar_model(input_images, input_labels, input_images_width, batch_size=1, reuse=False)
    global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False, dtype=tf.int32)
    variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
    saver = tf.train.Saver(variable_averages.variables_to_restore())
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        ckpt_state = tf.train.get_checkpoint_state(args.checkpoints)
        model_path = os.path.join(args.checkpoints, os.path.basename(ckpt_state.model_checkpoint_path))
        print('Restore from {}'.format(model_path))
        saver.restore(sess, model_path)
        print("Checkpoints step: {}".format(global_step.eval(session=sess)))

        env = lmdb.open(args.test_data_dir, readonly=True)
        txn = env.begin()
        num_samples = int(txn.get(b"num-samples").decode())
        predicts = []
        labels = []
        for i in tqdm.tqdm(range(1, num_samples+1)):
            image_key = b'image-%09d' % i
            label_key = b'label-%09d' % i

            imgbuf = txn.get(image_key)
            buf = six.BytesIO()
            buf.write(imgbuf)
            buf.seek(0)

            img_pil = Image.open(buf).convert('RGB')
            img = np.array(img_pil)
            label = txn.get(label_key).decode()
            labels.append(label)
            img, la, width = data_preprocess(img, label, char2id, args)

            pred_value, attention_weights_value = sess.run(
                [pred, attention_weights],
                feed_dict={input_images: [img],
                           input_labels: [la],
                           input_images_width: [width]})
            pred_value_str = idx2label(pred_value, id2char, char2id)[0]
            # print("predict: {} label: {}".format(pred_value_str, label))
            predicts.append(pred_value_str)
            img_vis = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
            if args.vis_dir:
                os.makedirs(args.vis_dir, exist_ok=True)
                os.makedirs(os.path.join(args.vis_dir, "errors"), exist_ok=True)
                # _ = line_visualize(img, attention_weights_value, pred_value_str, args.vis_dir, "{}.jpg".format(i))
                _ = heatmap_visualize(img_vis, attention_weights_value, pred_value_str, args.vis_dir, "{}.jpg".format(i))
                # _ = mask_visualize(img, attention_weights_value, pred_value_str, args.vis_dir, img_path)
                if pred_value_str.lower() != label.lower():
                    _ = heatmap_visualize(img_vis, attention_weights_value, pred_value_str, os.path.join(args.vis_dir, "errors"), "{}.jpg".format(i))
        acc_rate = calc_metrics(predicts, labels)
        print("Done, Accuracy: {}".format(acc_rate))
Example #6
def main_train(args):
    voc, char2id, id2char = get_vocabulary(voc_type=args.voc_type)
    tf.set_random_seed(1)
    # Build graph
    input_train_images = tf.placeholder(
        dtype=tf.float32,
        shape=[args.train_batch_size, args.height, args.width, 3],
        name="input_train_images")
    input_train_images_width = tf.placeholder(dtype=tf.float32,
                                              shape=[args.train_batch_size],
                                              name="input_train_width")
    input_train_labels = tf.placeholder(
        dtype=tf.int32,
        shape=[args.train_batch_size, args.max_len],
        name="input_train_labels")
    input_train_gauss_labels = tf.placeholder(
        dtype=tf.float32,
        shape=[args.train_batch_size, args.max_len, 6, 40],
        name="input_train_gauss_labels")  # better way wanted!!!
    input_train_gauss_tags = tf.placeholder(
        dtype=tf.float32,
        shape=[args.train_batch_size, args.max_len],
        name="input_train_gauss_tags")
    input_train_gauss_params = tf.placeholder(
        dtype=tf.float32,
        shape=[args.train_batch_size, args.max_len, 4],
        name="input_train_gauss_params")
    input_train_labels_mask = tf.placeholder(
        dtype=tf.int32,
        shape=[args.train_batch_size, args.max_len],
        name="input_train_labels_mask")
    input_train_char_sizes = tf.placeholder(
        dtype=tf.float32,
        shape=[args.train_batch_size, args.max_len, 2],
        name="input_train_char_sizes")
    input_train_gauss_mask = tf.placeholder(
        dtype=tf.float32,
        shape=[args.train_batch_size, args.max_len, 6, 40],
        name="input_train_gauss_mask")

    input_val_images = tf.placeholder(
        dtype=tf.float32,
        shape=[args.val_batch_size, args.height, args.width, 3],
        name="input_val_images")
    input_val_images_width = tf.placeholder(dtype=tf.float32,
                                            shape=[args.val_batch_size],
                                            name="input_val_width")
    input_val_labels = tf.placeholder(
        dtype=tf.int32,
        shape=[args.val_batch_size, args.max_len],
        name="input_val_labels")
    # input_val_gauss_labels = tf.placeholder(dtype=tf.float32, shape=[args.val_batch_size, args.max_len, args.height, args.width], name="input_val_gauss_labels")
    input_val_labels_mask = tf.placeholder(
        dtype=tf.int32,
        shape=[args.val_batch_size, args.max_len],
        name="input_val_labels_mask")

    sar_model = SARModel(num_classes=len(voc),
                         encoder_dim=args.encoder_sdim,
                         encoder_layer=args.encoder_layers,
                         decoder_dim=args.decoder_sdim,
                         decoder_layer=args.decoder_layers,
                         decoder_embed_dim=args.decoder_edim,
                         seq_len=args.max_len,
                         is_training=True,
                         att_loss_type=args.att_loss_type,
                         att_loss_weight=args.att_loss_weight)
    sar_model_val = SARModel(num_classes=len(voc),
                             encoder_dim=args.encoder_sdim,
                             encoder_layer=args.encoder_layers,
                             decoder_dim=args.decoder_sdim,
                             decoder_layer=args.decoder_layers,
                             decoder_embed_dim=args.decoder_edim,
                             seq_len=args.max_len,
                             is_training=False)
    train_model_infer, train_attention_weights, train_pred, train_attention_params = sar_model(
        input_train_images,
        input_train_labels,
        input_train_images_width,
        batch_size=args.train_batch_size,
        reuse=False)
    if args.att_loss_type == "kldiv":
        train_loss, train_recog_loss, train_att_loss = sar_model.loss(
            train_model_infer, train_attention_weights, input_train_labels,
            input_train_gauss_labels, input_train_labels_mask,
            input_train_gauss_tags)
    elif args.att_loss_type == "l1" or args.att_loss_type == "l2":
        # train_loss, train_recog_loss, train_att_loss = sar_model.loss(train_model_infer, train_attention_params, input_train_labels, input_train_gauss_params, input_train_labels_mask, input_train_gauss_tags, input_train_char_sizes)
        train_loss, train_recog_loss, train_att_loss = sar_model.loss(
            train_model_infer, train_attention_weights, input_train_labels,
            input_train_gauss_labels, input_train_labels_mask,
            input_train_gauss_tags)
    elif args.att_loss_type == 'ce':
        train_loss, train_recog_loss, train_att_loss = sar_model.loss(
            train_model_infer, train_attention_weights, input_train_labels,
            input_train_gauss_labels, input_train_labels_mask,
            input_train_gauss_tags, input_train_gauss_mask)
    elif args.att_loss_type == 'gausskldiv':
        train_loss, train_recog_loss, train_att_loss = sar_model.loss(
            train_model_infer, train_attention_params, input_train_labels,
            input_train_gauss_params, input_train_labels_mask,
            input_train_gauss_tags)
    else:
        print("Unimplemented loss type {}".format(args.att_loss_type))
        exit(-1)
    val_model_infer, val_attention_weights, val_pred, _ = sar_model_val(
        input_val_images,
        input_val_labels,
        input_val_images_width,
        batch_size=args.val_batch_size,
        reuse=True)

    train_data_list = get_data(args.train_data_dir,
                               args.voc_type,
                               args.max_len,
                               args.height,
                               args.width,
                               args.train_batch_size,
                               args.workers,
                               args.keep_ratio,
                               with_aug=args.aug)

    val_data_gen = evaluator_data.Evaluator(lmdb_data_dir=args.test_data_dir,
                                            batch_size=args.val_batch_size,
                                            height=args.height,
                                            width=args.width,
                                            max_len=args.max_len,
                                            keep_ratio=args.keep_ratio,
                                            voc_type=args.voc_type)
    val_data_gen.reset()

    global_step = tf.get_variable(name='global_step',
                                  initializer=tf.constant(0),
                                  trainable=False)

    learning_rate = tf.train.piecewise_constant(global_step, args.decay_bound,
                                                args.lr_stage)
    batch_norm_updates_op = tf.group(
        *tf.get_collection(tf.GraphKeys.UPDATE_OPS))

    # Save summary
    os.makedirs(args.checkpoints, exist_ok=True)
    tf.summary.scalar(name='train_loss', tensor=train_loss)
    tf.summary.scalar(name='train_recog_loss', tensor=train_recog_loss)
    tf.summary.scalar(name='train_att_loss', tensor=train_att_loss)
    # tf.summary.scalar(name='val_att_loss', tensor=val_att_loss)
    tf.summary.scalar(name='learning_rate', tensor=learning_rate)

    merge_summary_op = tf.summary.merge_all()

    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = 'sar_{:s}.ckpt'.format(str(train_start_time))
    model_save_path = os.path.join(args.checkpoints, model_name)
    best_model_save_path = os.path.join(args.checkpoints, 'best_model',
                                        model_name)
    variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())

    with tf.control_dependencies(
        [variables_averages_op, batch_norm_updates_op]):
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        grads = optimizer.compute_gradients(train_loss)
        if args.grad_clip > 0:
            print("With Gradients clipped!")
            for idx, (grad, var) in enumerate(grads):
                grads[idx] = (tf.clip_by_norm(grad, args.grad_clip), var)
        train_op = optimizer.apply_gradients(grads, global_step=global_step)

    saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
    best_saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
    summary_writer = tf.summary.FileWriter(args.checkpoints)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    log_file = open(os.path.join(args.checkpoints, args.checkpoints + ".log"),
                    "w")
    with tf.Session(config=config) as sess:
        summary_writer.add_graph(sess.graph)
        start_iter = 0
        if args.resume and args.pretrained != '':
            print('Restore model from {:s}'.format(args.pretrained))
            ckpt_state = tf.train.get_checkpoint_state(args.pretrained)
            model_path = os.path.join(
                args.pretrained,
                os.path.basename(ckpt_state.model_checkpoint_path))
            saver.restore(sess=sess, save_path=model_path)
            start_iter = sess.run(tf.train.get_global_step())
        elif not args.resume and args.pretrained != '':
            print('Restore pretrained model from {:s}'.format(args.pretrained))
            ckpt_state = tf.train.get_checkpoint_state(args.pretrained)
            model_path = os.path.join(
                args.pretrained,
                os.path.basename(ckpt_state.model_checkpoint_path))
            saver.restore(sess=sess, save_path=model_path)
            sess.run(tf.assign(global_step, 0))
        else:
            print('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)

        # Evaluate the model first
        val_pred_value_all = []
        val_labels = []
        for eval_iter in range(val_data_gen.num_samples //
                               args.val_batch_size):

            val_data = val_data_gen.get_batch()
            if val_data is None:
                break
            print("Evaluation: [{} / {}]".format(
                eval_iter, (val_data_gen.num_samples // args.val_batch_size)))
            val_pred_value = sess.run(val_pred,
                                      feed_dict={
                                          input_val_images: val_data[0],
                                          input_val_labels: val_data[1],
                                          input_val_images_width: val_data[5],
                                          input_val_labels_mask: val_data[2]
                                      })
            val_pred_value_all.extend(val_pred_value)
            val_labels.extend(val_data[4])

        val_data_gen.reset()
        val_metrics_result = calc_metrics(idx2label(
            np.array(val_pred_value_all)),
                                          val_labels,
                                          metrics_type="accuracy")
        print("Evaluation Before training: Test accuracy {:3f}".format(
            val_metrics_result))
        val_best_acc = val_metrics_result

        while start_iter < args.iters:
            start_iter += 1
            train_data = get_batch_data(train_data_list, args.train_batch_size)
            _, train_loss_value, train_recog_loss_value, train_att_loss_value, train_pred_value = sess.run(
                [
                    train_op, train_loss, train_recog_loss, train_att_loss,
                    train_pred
                ],
                feed_dict={
                    input_train_images: train_data[0],
                    input_train_labels: train_data[1],
                    input_train_gauss_labels: train_data[2],
                    input_train_gauss_params: train_data[7],
                    input_train_labels_mask: train_data[3],
                    input_train_images_width: train_data[5],
                    input_train_gauss_tags: train_data[6],
                    input_train_char_sizes: train_data[8],
                    input_train_gauss_mask: train_data[9]
                })

            if start_iter % args.log_iter == 0:
                print(
                    "Iter {} train loss= {:.3f} (recog loss= {:.3f} att loss= {:.3f})"
                    .format(start_iter, train_loss_value,
                            train_recog_loss_value, train_att_loss_value))
                log_file.write(
                    "Iter {} train loss= {:.3f} (recog loss= {:.3f} att loss= {:.3f})\n"
                    .format(start_iter, train_loss_value,
                            train_recog_loss_value, train_att_loss_value))
            if start_iter % args.summary_iter == 0:
                merge_summary_value = sess.run(merge_summary_op,
                                               feed_dict={
                                                   input_train_images:
                                                   train_data[0],
                                                   input_train_labels:
                                                   train_data[1],
                                                   input_train_gauss_labels:
                                                   train_data[2],
                                                   input_train_gauss_params:
                                                   train_data[7],
                                                   input_train_labels_mask:
                                                   train_data[3],
                                                   input_train_images_width:
                                                   train_data[5],
                                                   input_train_gauss_tags:
                                                   train_data[6],
                                                   input_train_char_sizes:
                                                   train_data[8],
                                                   input_train_gauss_mask:
                                                   train_data[9]
                                               })

                summary_writer.add_summary(summary=merge_summary_value,
                                           global_step=start_iter)
                if start_iter % args.eval_iter == 0:
                    val_pred_value_all = []
                    val_labels = []
                    for eval_iter in range(val_data_gen.num_samples //
                                           args.val_batch_size):
                        val_data = val_data_gen.get_batch()
                        if val_data is None:
                            break
                        print("Evaluation: [{} / {}]".format(
                            eval_iter,
                            (val_data_gen.num_samples // args.val_batch_size)))
                        val_pred_value = sess.run(val_pred,
                                                  feed_dict={
                                                      input_val_images:
                                                      val_data[0],
                                                      input_val_labels:
                                                      val_data[1],
                                                      input_val_labels_mask:
                                                      val_data[2],
                                                      input_val_images_width:
                                                      val_data[5]
                                                  })
                        val_pred_value_all.extend(val_pred_value)
                        val_labels.extend(val_data[4])

                    val_data_gen.reset()
                    train_metrics_result = calc_metrics(
                        idx2label(train_pred_value),
                        train_data[4],
                        metrics_type="accuracy")
                    val_metrics_result = calc_metrics(idx2label(
                        np.array(val_pred_value_all)),
                                                      val_labels,
                                                      metrics_type="accuracy")
                    print(
                        "Evaluation Iter {} train accuracy: {:.3f} test accuracy: {:.3f}"
                        .format(start_iter, train_metrics_result,
                                val_metrics_result))
                    log_file.write(
                        "Evaluation Iter {} train accuracy: {:.3f} test accuracy: {:.3f}\n"
                        .format(start_iter, train_metrics_result,
                                val_metrics_result))

                    if val_metrics_result >= val_best_acc:
                        print("Better results! Save checkpoints to {}".format(
                            best_model_save_path))
                        val_best_acc = val_metrics_result
                        best_saver.save(sess,
                                        best_model_save_path,
                                        global_step=global_step)

            if start_iter % args.save_iter == 0:
                print("Iter {} save to checkpoint".format(start_iter))
                saver.save(sess, model_save_path, global_step=global_step)
    log_file.close()
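The schedule above comes from tf.train.piecewise_constant, where args.lr_stage must hold exactly one more value than args.decay_bound has boundaries. A minimal standalone sketch with placeholder numbers (TF 1.x):

import tensorflow as tf

global_step = tf.train.get_or_create_global_step()
boundaries = [100000, 200000]        # steps where the learning rate drops
values = [1e-3, 1e-4, 1e-5]          # rate before, between, and after them
learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)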
Example #7
def main_train(args):
    voc, char2id, id2char = get_vocabulary(voc_type=args.voc_type)

    # Build graph
    input_train_images = tf.placeholder(
        dtype=tf.float32,
        shape=[args.train_batch_size, args.height, args.width, 3],
        name="input_train_images")
    input_train_images_width = tf.placeholder(dtype=tf.float32,
                                              shape=[args.train_batch_size],
                                              name="input_train_width")
    input_train_labels = tf.placeholder(
        dtype=tf.int32,
        shape=[args.train_batch_size, args.max_len],
        name="input_train_labels")
    input_train_labels_mask = tf.placeholder(
        dtype=tf.int32,
        shape=[args.train_batch_size, args.max_len],
        name="input_train_labels_mask")

    input_val_images = tf.placeholder(
        dtype=tf.float32,
        shape=[args.val_batch_size, args.height, args.width, 3],
        name="input_val_images")
    input_val_images_width = tf.placeholder(dtype=tf.float32,
                                            shape=[args.val_batch_size],
                                            name="input_val_width")
    input_val_labels = tf.placeholder(
        dtype=tf.int32,
        shape=[args.val_batch_size, args.max_len],
        name="input_val_labels")
    input_val_labels_mask = tf.placeholder(
        dtype=tf.int32,
        shape=[args.val_batch_size, args.max_len],
        name="input_val_labels_mask")

    sar_model = SARModel(num_classes=len(voc),
                         encoder_dim=args.encoder_sdim,
                         encoder_layer=args.encoder_layers,
                         decoder_dim=args.decoder_sdim,
                         decoder_layer=args.decoder_layers,
                         decoder_embed_dim=args.decoder_edim,
                         seq_len=args.max_len,
                         is_training=True)
    sar_model_val = SARModel(num_classes=len(voc),
                             encoder_dim=args.encoder_sdim,
                             encoder_layer=args.encoder_layers,
                             decoder_dim=args.decoder_sdim,
                             decoder_layer=args.decoder_layers,
                             decoder_embed_dim=args.decoder_edim,
                             seq_len=args.max_len,
                             is_training=False)
    train_model_infer, train_attention_weights, train_pred = sar_model(
        input_train_images,
        input_train_labels,
        input_train_images_width,
        batch_size=args.train_batch_size,
        reuse=False)
    train_loss = sar_model.loss(train_model_infer, input_train_labels,
                                input_train_labels_mask)

    val_model_infer, val_attention_weights, val_pred = sar_model_val(
        input_val_images,
        input_val_labels,
        input_val_images_width,
        batch_size=args.val_batch_size,
        reuse=True)
    val_loss = sar_model_val.loss(val_model_infer, input_val_labels,
                                  input_val_labels_mask)

    train_data_list = get_data(args.train_data_dir,
                               args.train_data_gt,
                               args.voc_type,
                               args.max_len,
                               args.num_train,
                               args.height,
                               args.width,
                               args.train_batch_size,
                               args.workers,
                               args.keep_ratio,
                               with_aug=args.aug)

    val_data_list = get_data(args.test_data_dir,
                             args.test_data_gt,
                             args.voc_type,
                             args.max_len,
                             args.num_train,
                             args.height,
                             args.width,
                             args.val_batch_size,
                             args.workers,
                             args.keep_ratio,
                             with_aug=False)

    global_step = tf.get_variable(name='global_step',
                                  initializer=tf.constant(0),
                                  trainable=False)
    learning_rate = tf.train.exponential_decay(learning_rate=args.lr,
                                               global_step=global_step,
                                               decay_steps=args.decay_iter,
                                               decay_rate=args.weight_decay,
                                               staircase=True)

    batch_norm_updates_op = tf.group(
        *tf.get_collection(tf.GraphKeys.UPDATE_OPS))
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    grads = optimizer.compute_gradients(train_loss)
    apply_gradient_op = optimizer.apply_gradients(grads,
                                                  global_step=global_step)

    # Save summary
    os.makedirs(args.checkpoints, exist_ok=True)
    tf.summary.scalar(name='train_loss', tensor=train_loss)
    tf.summary.scalar(name='val_loss', tensor=val_loss)
    tf.summary.scalar(name='learning_rate', tensor=learning_rate)

    merge_summary_op = tf.summary.merge_all()

    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = 'sar_{:s}.ckpt'.format(str(train_start_time))
    model_save_path = os.path.join(args.checkpoints, model_name)
    variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())

    with tf.control_dependencies(
        [variables_averages_op, apply_gradient_op, batch_norm_updates_op]):
        train_op = tf.no_op(name='train_op')

    saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
    summary_writer = tf.summary.FileWriter(args.checkpoints)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        summary_writer.add_graph(sess.graph)
        start_iter = 0
        if args.resume and args.pretrained != '':
            print('Restore model from {:s}'.format(args.pretrained))
            ckpt_state = tf.train.get_checkpoint_state(args.pretrained)
            model_path = os.path.join(
                args.pretrained,
                os.path.basename(ckpt_state.model_checkpoint_path))
            saver.restore(sess=sess, save_path=model_path)
            start_iter = sess.run(tf.train.get_global_step())
        else:
            print('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)

        while start_iter < args.iters:
            start_iter += 1
            train_data = get_batch_data(train_data_list, args.train_batch_size)
            _, train_loss_value, train_pred_value = sess.run(
                [train_op, train_loss, train_pred],
                feed_dict={
                    input_train_images: train_data[0],
                    input_train_labels: train_data[1],
                    input_train_labels_mask: train_data[2],
                    input_train_images_width: train_data[4]
                })

            if start_iter % args.log_iter == 0:
                print("Iter {} train loss= {:3f}".format(
                    start_iter, train_loss_value))

            if start_iter % args.summary_iter == 0:
                val_data = get_batch_data(val_data_list, args.val_batch_size)

                merge_summary_value, val_pred_value, val_loss_value = sess.run(
                    [merge_summary_op, val_pred, val_loss],
                    feed_dict={
                        input_train_images: train_data[0],
                        input_train_labels: train_data[1],
                        input_train_labels_mask: train_data[2],
                        input_train_images_width: train_data[4],
                        input_val_images: val_data[0],
                        input_val_labels: val_data[1],
                        input_val_labels_mask: val_data[2],
                        input_val_images_width: val_data[4]
                    })

                summary_writer.add_summary(summary=merge_summary_value,
                                           global_step=start_iter)
                if start_iter % args.eval_iter == 0:
                    print("#" * 80)
                    print("train prediction \t train labels ")
                    for result, gt in zip(idx2label(train_pred_value),
                                          train_data[3]):
                        print("{} \t {}".format(result, gt))
                    print("#" * 80)
                    print("test prediction \t test labels ")
                    for result, gt in zip(
                            idx2label(val_pred_value)[:32], val_data[3][:32]):
                        print("{} \t {}".format(result, gt))
                    print("#" * 80)

                    train_metrics_result = calc_metrics(
                        idx2label(train_pred_value),
                        train_data[3],
                        metrics_type="accuracy")
                    val_metrics_result = calc_metrics(
                        idx2label(val_pred_value),
                        val_data[3],
                        metrics_type="accuracy")
                    print(
                        "Evaluation Iter {} test loss: {:.3f} train accuracy: {:.3f} test accuracy: {:.3f}"
                        .format(start_iter, val_loss_value,
                                train_metrics_result, val_metrics_result))
            if start_iter % args.save_iter == 0:
                print("Iter {} save to checkpoint".format(start_iter))
                saver.save(sess, model_save_path, global_step=global_step)
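With staircase=True, tf.train.exponential_decay steps the rate down by a factor of decay_rate every decay_steps iterations. The arithmetic, as a small self-contained check (the numbers are placeholders):

def staircase_lr(base_lr, decay_rate, decay_steps, step):
    # lr(step) = base_lr * decay_rate ** (step // decay_steps)
    return base_lr * decay_rate ** (step // decay_steps)

# e.g. base_lr=1e-3, decay_rate=0.9, decay_steps=10000:
# steps 0..9999 -> 1e-3, 10000..19999 -> 9e-4, 20000 -> 8.1e-4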