Example #1
def main():
    import torch
    from arguments.qgen_args import qgen_arguments
    from data_provider.qgen_baseline_dataset import prepare_dataset
    from process_data.tokenizer import GWTokenizer
    # QGenNetwork is assumed to be importable from the project's model
    # package; its exact module path is not shown in this example.

    # Parse the question-generator arguments and convert them to a dict.
    parser = qgen_arguments()
    args, _ = parser.parse_known_args()
    args = vars(args)

    # Build the tokenizer and a test-split data loader.
    tokenizer = GWTokenizer('./../data/dict.json')
    loader = prepare_dataset("./../data/", "test", args, tokenizer)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = QGenNetwork(args, tokenizer, device).to(device)
    optimizer = torch.optim.Adam(model.parameters(), args["lr"])

    # Run a few training steps to check that the loss decreases.
    data_iter = iter(loader)
    model.train()
    for i in range(5):
        batch = next(data_iter)
        optimizer.zero_grad()
        _, loss = model(batch)
        loss.backward()
        _ = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           args["clip_val"])
        optimizer.step()
        print("loss: {:.4f}".format(loss.item()))

    # Switch to evaluation mode and decode one generated question.
    model.eval()
    batch = next(data_iter)
    result, _ = model.generate(batch)
    print("generate")
    print(tokenizer.decode(result[0]))
Example #2
def main():
    from arguments.qgen_args import qgen_arguments
    from process_data.tokenizer import GWTokenizer
    from torch.utils.data import DataLoader
    # QuestionDataset and question_collate are assumed to come from the
    # project's data_provider package; the exact module path is not shown here.

    data_dir = "./../data/"
    tokenizer = GWTokenizer('./../data/dict.json')

    # Parse the question-generator arguments and convert them to a dict.
    parser = qgen_arguments()
    args = parser.parse_args()
    args = vars(args)

    # Build the test-split dataset and wrap it in a DataLoader.
    dataset = QuestionDataset(data_dir, 'test', args, tokenizer=tokenizer)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True,
                            num_workers=8, collate_fn=question_collate)
    print(len(dataset), len(dataloader))

    # Draw a single batch to verify that the collate function works.
    dataiter = iter(dataloader)
    for i in range(1):
        batch = next(dataiter)
Example #3
            args["object"] = True
            args["image_arch"] = "rcnn"
            args["image_dim"] = 2048
        with open(save_path.format("args.json"), mode="w") as f:
            json.dump(args, f, indent=2, ensure_ascii=False)
        logger.info(args)
        model = QGenNetwork(args, tokenizer, device).to(device)
        train_loader, val_loader = prepare_dataset(data_dir, "train", args,
                                                   tokenizer)
        train(model, args, train_loader, val_loader, param_file)
    else:
        with open(save_path.format("args.json"), mode="r") as f:
            saved_args = json.load(f)
            saved_args["option"] = "test"
        args = saved_args
        logger.info(args)
        model = QGenNetwork(args, tokenizer, device).to(device)
        testloader = prepare_dataset(data_dir, "test", args, tokenizer)
        test(model, args, testloader, param_file)


if __name__ == "__main__":
    # Parse CLI flags, prepare the output directory and logger, then run main().
    parser = qgen_arguments()
    flags, unknown = parser.parse_known_args()
    flags = vars(flags)
    model_dir = "./../out/qgen/" + flags["name"]
    os.makedirs(model_dir, exist_ok=True)
    save_path = model_dir + "/{}"
    logger = create_logger(save_path.format('train.log'), "w")
    main(flags)