# Standard-library and PyTorch imports used by these examples; the project-specific
# helpers (Im2LatexModel, Im2LatexDataset, LatexProducer, Trainer, load_vocab,
# collate_fn, get_checkpoint, score_files) are assumed to be imported from the
# repo's own modules, which are not shown in these excerpts.
import argparse
import json
from functools import partial
from os.path import join

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from tqdm import tqdm


def main():

    parser = argparse.ArgumentParser(description="Im2Latex Evaluating Program")
    parser.add_argument('--model_path', required=True,
                        help='path of the evaluated model')

    # model args
    parser.add_argument("--data_path", type=str,
                        default="./data/", help="The dataset's dir")
    parser.add_argument("--cuda", action='store_true',
                        default=True, help="Use cuda or not")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--beam_size", type=int, default=5)
    parser.add_argument("--result_path", type=str,
                        default="./results/result.txt", help="The file to store result")
    parser.add_argument("--ref_path", type=str,
                        default="./results/ref.txt", help="The file to store reference")
    parser.add_argument("--max_len", type=int,
                        default=64, help="Max step of decoding")
    parser.add_argument("--split", type=str,
                        default="test", help="The data split to decode")
    args = parser.parse_args()
    # load the model checkpoint
    checkpoint = torch.load(join(args.model_path))
    model_args = checkpoint['args']
    # load the vocabulary and set the remaining parameters
    vocab = load_vocab(args.data_path)
    use_cuda = True if args.cuda and torch.cuda.is_available() else False
    # load the test set
    data_loader = DataLoader(
        Im2LatexDataset(args.data_path, args.split, args.max_len),
        batch_size=args.batch_size,
        collate_fn=partial(collate_fn, vocab.token2idx),
        pin_memory=True if use_cuda else False,
        num_workers=4
    )
    model = Im2LatexModel(
        len(vocab), model_args.emb_dim, model_args.dec_rnn_h,
        add_pos_feat=model_args.add_position_features,
        dropout=model_args.dropout
    )
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    result_file = open(args.result_path, 'w')
    ref_file = open(args.ref_path, 'w')
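    # LatexProducer runs the model's decoding (beam search, see --beam_size) and
    # converts predicted token indices back into LaTeX strings.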
    latex_producer = LatexProducer(
        model, vocab, max_len=args.max_len,
        use_cuda=use_cuda, beam_size=args.beam_size)
    for imgs, tgt4training, tgt4cal_loss in tqdm(data_loader):
        try:
            reference = latex_producer._idx2formulas(tgt4cal_loss)
            results = latex_producer(imgs)
        except RuntimeError:
            break
        # terminate each batch with a newline so consecutive batches do not run together
        result_file.write('\n'.join(results) + '\n')
        ref_file.write('\n'.join(reference) + '\n')
    result_file.close()
    ref_file.close()
    score = score_files(args.result_path, args.ref_path)
    print("beam search result:", score)
Example #2
def main():
    # get args
    parser = argparse.ArgumentParser(description="Im2Latex Training Program")
    # parser.add_argument('--path', required=True, help='root of the model')

    # model args
    parser.add_argument("--emb_dim",
                        type=int,
                        default=80,
                        help="Embedding size")
    parser.add_argument("--dec_rnn_h",
                        type=int,
                        default=512,
                        help="The hidden state of the decoder RNN")
    parser.add_argument("--data_path",
                        type=str,
                        default="./data/",
                        help="The dataset's dir")
    parser.add_argument("--add_position_features",
                        action='store_true',
                        default=False,
                        help="Use position embeddings or not")
    # training args
    parser.add_argument("--max_len",
                        type=int,
                        default=150,
                        help="Max size of formula")
    parser.add_argument("--dropout",
                        type=float,
                        default=0.4,
                        help="Dropout probility")
    parser.add_argument("--cuda",
                        action='store_true',
                        default=True,
                        help="Use cuda or not")
    parser.add_argument("--batch_size", type=int, default=16)  # 指定batch_size
    parser.add_argument("--epoches", type=int, default=15)
    parser.add_argument("--lr", type=float, default=3e-4, help="Learning Rate")
    parser.add_argument("--min_lr",
                        type=float,
                        default=3e-5,
                        help="Learning Rate")
    parser.add_argument("--sample_method",
                        type=str,
                        default="teacher_forcing",
                        choices=('teacher_forcing', 'exp', 'inv_sigmoid'),
                        help="The method to schedule sampling")
    parser.add_argument(
        "--decay_k",
        type=float,
        default=1.,
        help="Base of Exponential decay for Schedule Sampling. "
        "When sample method is Exponential deca;"
        "Or a constant in Inverse sigmoid decay Equation. "
        "See details in https://arxiv.org/pdf/1506.03099.pdf")

    parser.add_argument("--lr_decay",
                        type=float,
                        default=0.5,
                        help="Learning Rate Decay Rate")
    parser.add_argument("--lr_patience",
                        type=int,
                        default=3,
                        help="Learning Rate Decay Patience")
    parser.add_argument("--clip",
                        type=float,
                        default=2.0,
                        help="The max gradient norm")
    parser.add_argument("--save_dir",
                        type=str,
                        default="./ckpts",
                        help="The dir to save checkpoints")
    parser.add_argument("--print_freq",
                        type=int,
                        default=100,
                        help="The frequency to print message")
    parser.add_argument("--seed",
                        type=int,
                        default=2020,
                        help="The random seed for reproducing ")
    parser.add_argument("--from_check_point",
                        action='store_true',
                        default=False,
                        help="Training from checkpoint or not")  # 是否finetune
    parser.add_argument("--exp", default="")  # 实验名称,ckpt的名称

    args = parser.parse_args()
    max_epoch = args.epoches
    from_check_point = args.from_check_point
    if from_check_point:
        checkpoint_path = get_checkpoint(args.save_dir)
        checkpoint = torch.load(checkpoint_path)
        args = checkpoint['args']
    print("Training args:", args)

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    # Building vocab
    print("Load vocab...")
    vocab = load_vocab(args.data_path)

    use_cuda = True if args.cuda and torch.cuda.is_available() else False
    device = torch.device("cuda" if use_cuda else "cpu")

    # data loader
    print("Construct data loader...")
    train_loader = DataLoader(
        Im2LatexDataset(args.data_path, 'train', args.max_len),  # occasionally switch to the 'test' split for quick experiments
        # Im2LatexDataset(args.data_path, 'test', args.max_len),
        batch_size=args.batch_size,
        collate_fn=partial(collate_fn, vocab.sign2id),
        pin_memory=True
        if use_cuda else False,  # pinned (page-locked) memory keeps batches in RAM for faster host-to-GPU transfer, but demands more of the machine
        num_workers=4)
    val_loader = DataLoader(Im2LatexDataset(args.data_path, 'validate',
                                            args.max_len),
                            batch_size=args.batch_size,
                            collate_fn=partial(collate_fn, vocab.sign2id),
                            pin_memory=True if use_cuda else False,
                            num_workers=4)

    # construct model
    print("Construct model")
    vocab_size = len(vocab)
    model = Im2LatexModel(vocab_size,
                          args.emb_dim,
                          args.dec_rnn_h,
                          add_pos_feat=args.add_position_features,
                          dropout=args.dropout)
    model = model.to(device)
    print("Model Settings:")
    print(model)

    # construct optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    lr_scheduler = ReduceLROnPlateau(optimizer,
                                     "min",
                                     factor=args.lr_decay,
                                     patience=args.lr_patience,
                                     verbose=True,
                                     min_lr=args.min_lr)
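    # ReduceLROnPlateau above cuts the learning rate by `factor` whenever the monitored
    # validation loss fails to improve for `patience` epochs, never going below `min_lr`.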

    if from_check_point:
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        lr_scheduler.load_state_dict(checkpoint['lr_sche'])
        # init trainer from checkpoint
        max_epoch = epoch + max_epoch  # bug fix: when resuming, train for max_epoch additional epochs
        print('From %s To %s...' % (epoch, max_epoch))
        trainer = Trainer(optimizer,
                          model,
                          lr_scheduler,
                          train_loader,
                          val_loader,
                          args,
                          use_cuda=use_cuda,
                          init_epoch=epoch,
                          last_epoch=max_epoch)
    else:
        trainer = Trainer(optimizer,
                          model,
                          lr_scheduler,
                          train_loader,
                          val_loader,
                          args,
                          use_cuda=use_cuda,
                          init_epoch=1,
                          last_epoch=args.epoches,
                          exp=args.exp)
    # begin training
    trainer.train()
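# Trainer itself is not shown in these examples. Below is a minimal, self-contained sketch
# (an illustration, not the project's Trainer) of how the pieces configured above
# (the Adam optimizer, the --clip gradient norm, and ReduceLROnPlateau) typically fit
# together in a training loop:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

def sketch_training_loop(model, train_batches, val_batches, epochs=15, clip=2.0):
    """train_batches / val_batches: lists of (input, target) tensor pairs."""
    criterion = nn.CrossEntropyLoss()  # placeholder loss for this sketch
    optimizer = optim.Adam(model.parameters(), lr=3e-4)
    scheduler = ReduceLROnPlateau(optimizer, "min", factor=0.5, patience=3, min_lr=3e-5)
    for _ in range(epochs):
        model.train()
        for x, y in train_batches:
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            # clip the global gradient norm, mirroring the --clip argument
            nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()
        model.eval()
        with torch.no_grad():
            val_loss = sum(criterion(model(x), y).item() for x, y in val_batches) / len(val_batches)
        # ReduceLROnPlateau lowers the learning rate once val_loss stops improving
        scheduler.step(val_loss)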
Example #3

beamSize = 5
maxLen = 64
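# `args` (with fields such as info, model_path, vocab_path) is assumed to be parsed
# earlier in the original script; only this excerpt is shown here.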

if args.info != "":
    try:
        with open(args.info, 'r') as file:
            params = json.load(file)
        for (k, v) in params.items():
            setattr(args, k, v)
    except (OSError, json.JSONDecodeError):
        # fall back to the command-line arguments if the info file is missing or invalid
        pass

# load the vocabulary
vocab = load_vocab(args.vocab_path)
use_cuda = True if torch.cuda.is_available() else False

# load the model checkpoint (map to CPU when CUDA is unavailable)
if not use_cuda:
    checkpoint = torch.load(join(args.model_path), map_location=torch.device('cpu'))
else:
    checkpoint = torch.load(join(args.model_path))
model_args = checkpoint['args']


model = Im2LatexModel(
    len(vocab), model_args.emb_dim, model_args.dec_rnn_h,
    add_pos_feat=model_args.add_position_features,
    dropout=model_args.dropout
)
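# The follow-up steps, loading the trained weights and building a LatexProducer for
# decoding, appear in Example #5 below.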
def main():

    # get args
    parser = argparse.ArgumentParser(description="Im2Latex Training Program")
    # parser.add_argument('--path', required=True, help='root of the model')

    # model args
    parser.add_argument("--emb_dim",
                        type=int,
                        default=80,
                        help="Embedding size")
    parser.add_argument("--dec_rnn_h",
                        type=int,
                        default=512,
                        help="The hidden state of the decoder RNN")
    parser.add_argument("--data_path",
                        type=str,
                        default="/root/private/im2latex/data/",
                        help="The dataset's dir")
    parser.add_argument("--add_position_features",
                        action='store_true',
                        default=False,
                        help="Use position embeddings or not")
    # training args
    parser.add_argument("--max_len",
                        type=int,
                        default=150,
                        help="Max size of formula")
    parser.add_argument("--dropout",
                        type=float,
                        default=0.,
                        help="Dropout probility")
    parser.add_argument("--cuda",
                        action='store_true',
                        default=True,
                        help="Use cuda or not")
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--epoches", type=int, default=200)
    parser.add_argument("--lr", type=float, default=3e-4, help="Learning Rate")
    parser.add_argument("--min_lr",
                        type=float,
                        default=3e-5,
                        help="Learning Rate")
    parser.add_argument("--sample_method",
                        type=str,
                        default="teacher_forcing",
                        choices=('teacher_forcing', 'exp', 'inv_sigmoid'),
                        help="The method to schedule sampling")
    parser.add_argument("--decay_k", type=float, default=1.)

    parser.add_argument("--lr_decay",
                        type=float,
                        default=0.5,
                        help="Learning Rate Decay Rate")
    parser.add_argument("--lr_patience",
                        type=int,
                        default=3,
                        help="Learning Rate Decay Patience")
    parser.add_argument("--clip",
                        type=float,
                        default=2.0,
                        help="The max gradient norm")
    parser.add_argument("--save_dir",
                        type=str,
                        default="./ckpts",
                        help="The dir to save checkpoints")
    parser.add_argument("--print_freq",
                        type=int,
                        default=100,
                        help="The frequency to print message")
    parser.add_argument("--seed",
                        type=int,
                        default=2020,
                        help="The random seed for reproducing ")
    parser.add_argument("--from_check_point",
                        action='store_true',
                        default=False,
                        help="Training from checkpoint or not")
    parser.add_argument("--batch_size_per_gpu", type=int, default=16)
    parser.add_argument("--gpu_num", type=int, default=4)
    device_ids = [0, 1, 2, 3]

    args = parser.parse_args()
    max_epoch = args.epoches
    from_check_point = args.from_check_point
    if from_check_point:
        checkpoint_path = get_checkpoint(args.save_dir)
        checkpoint = torch.load(checkpoint_path)
        args = checkpoint['args']
    print("Training args:", args)

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Building vocab
    print("Load vocab...")
    vocab = load_vocab(args.data_path)

    use_cuda = True if args.cuda and torch.cuda.is_available() else False
    print("use_cuda:", use_cuda)
    device = torch.device("cuda" if use_cuda else "cpu")

    # data loader
    print("Construct data loader...")
    # train_loader = DataLoader(
    #     Im2LatexDataset(args.data_path, 'train', args.max_len),
    #     batch_size=args.batch_size,
    #     collate_fn=partial(collate_fn, vocab.token2idx),
    #     pin_memory=True if use_cuda else False,
    #     num_workers=4)
    train_loader = DataLoader(
        Im2LatexDataset(args.data_path, 'train', args.max_len),
        batch_size=args.batch_size_per_gpu * args.gpu_num,
        collate_fn=partial(collate_fn, vocab.token2idx),
        pin_memory=True if use_cuda else False,
        num_workers=2)
    # val_loader = DataLoader(
    #     Im2LatexDataset(args.data_path, 'validate', args.max_len),
    #     batch_size=args.batch_size,
    #     collate_fn=partial(collate_fn, vocab.token2idx),
    #     pin_memory=True if use_cuda else False,
    #     num_workers=4)
    val_loader = DataLoader(Im2LatexDataset(args.data_path, 'validate',
                                            args.max_len),
                            batch_size=args.batch_size_per_gpu * args.gpu_num,
                            collate_fn=partial(collate_fn, vocab.token2idx),
                            pin_memory=True if use_cuda else False,
                            num_workers=2)

    # construct model
    print("Construct model")
    vocab_size = len(vocab)
    model = Im2LatexModel(vocab_size,
                          args.emb_dim,
                          args.dec_rnn_h,
                          add_pos_feat=args.add_position_features,
                          dropout=args.dropout)
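    # DataParallel replicates the model on every id in device_ids and splits each batch
    # across them, which is why the loaders above use batch_size_per_gpu * gpu_num.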
    model = nn.DataParallel(model, device_ids=device_ids)
    model = model.cuda()
    print("Model Settings:")
    print(model)

    # construct optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    lr_scheduler = ReduceLROnPlateau(optimizer,
                                     "min",
                                     factor=args.lr_decay,
                                     patience=args.lr_patience,
                                     verbose=True,
                                     min_lr=args.min_lr)

    if from_check_point:
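        # nn.DataParallel registers parameters under a "module." prefix, so this strict
        # load expects a checkpoint that was saved from a DataParallel-wrapped model.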
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        lr_scheduler.load_state_dict(checkpoint['lr_sche'])
        # init trainer from checkpoint
        trainer = Trainer(optimizer,
                          model,
                          lr_scheduler,
                          train_loader,
                          val_loader,
                          args,
                          use_cuda=use_cuda,
                          init_epoch=epoch,
                          last_epoch=max_epoch)
    else:
        trainer = Trainer(optimizer,
                          model,
                          lr_scheduler,
                          train_loader,
                          val_loader,
                          args,
                          use_cuda=use_cuda,
                          init_epoch=1,
                          last_epoch=args.epoches)
    # begin training
    trainer.train()
Example #5
try:
    with open(args.info, 'r') as file:
        params = json.load(file)
    for (k, v) in params.items():
        setattr(args, k, v)
except (OSError, json.JSONDecodeError):
    # fall back to the existing arguments if the info file is missing or invalid
    pass

# `cuda`, `modelPath`, `vocabPath`, `maxLen`, and `beamSize` are assumed to be defined
# earlier in the original script; only this excerpt is shown here.
if not cuda:
    checkpoint = torch.load(join(modelPath), map_location=torch.device('cpu'))
else:
    checkpoint = torch.load(join(modelPath))
model_args = checkpoint['args']

# load the vocabulary
vocab = load_vocab(vocabPath)

model = Im2LatexModel(len(vocab),
                      model_args.emb_dim,
                      model_args.dec_rnn_h,
                      add_pos_feat=model_args.add_position_features,
                      dropout=model_args.dropout)

model.load_state_dict(checkpoint['model_state_dict'])

latex_producer = LatexProducer(model,
                               vocab,
                               max_len=maxLen,
                               use_cuda=cuda,
                               beam_size=beamSize)
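# With the producer built, decoding a preprocessed image batch mirrors Example #1:
#     results = latex_producer(imgs)  # list of predicted LaTeX strings, one per image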