Example #1
import json

import dill
import numpy as np
import torch

# `data`, `model_utils`, and `Trainer` are provided by the surrounding
# project.


def train(args, config):
    # load data
    print("Loading data")
    train_dataset = data.get_loader(config.train_file, config.batch_size,
                                    config.version)
    dev_dataset = data.get_loader(config.dev_file, config.val_batch_size,
                                  config.version)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)

    print("loading embeddings")
    with open(config.word_emb_file, "rb") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "rb") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)

    # create model
    print("creating model")
    if config.model_type == "model6":
        Model = model_utils.get_model_func(config.model_type, config.version)
        # Load the pretrained QANet checkpoint referenced in the config.
        model_dict_ = torch.load(config.pretrained_model,
                                 pickle_module=dill)["model"]
        # The checkpoint was saved from a DataParallel-wrapped model, so
        # strip the 7-character "module." prefix from every parameter name.
        model_dict = {}
        for key in model_dict_:
            model_dict[key[7:]] = model_dict_[key]

        from models import QANet

        pretrained_model = QANet(word_mat, char_mat, config)
        # Overwrite the freshly initialized weights with the pretrained state.
        model_data = pretrained_model.state_dict()
        model_data.update(model_dict)
        pretrained_model.load_state_dict(model_data)
        model = Model(pretrained_model, config).to(config.device)
        model = torch.nn.DataParallel(model)

    else:
        Model = model_utils.get_model_func(config.model_type, config.version)
        model = Model(word_mat, char_mat, config).to(config.device)
        model = torch.nn.DataParallel(model)
    print("Training Model")
    trainer = Trainer(model, train_dataset, dev_dataset, dev_eval_file, config)
    if args.model_file is not None:
        # Resume from a saved checkpoint and restore its EMA shadow weights.
        trainer.load(args.model_file)
        trainer.ema.resume(trainer.model)
    trainer.train()
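
Example #1 expects an `args` namespace with an optional `model_file` attribute and a `config` object carrying file paths and hyperparameters. A minimal, hypothetical driver sketch follows; the flag name and every config field shown here are illustrative assumptions, not part of the original project:

import argparse
from types import SimpleNamespace

import torch


def main():
    parser = argparse.ArgumentParser()
    # Optional checkpoint; when given, training resumes from it.
    parser.add_argument("--model_file", default=None)
    args = parser.parse_args()

    # Hypothetical config values; a real project would load these from a
    # config file or module.
    config = SimpleNamespace(
        train_file="data/train.npz",
        dev_file="data/dev.npz",
        dev_eval_file="data/dev_eval.json",
        word_emb_file="data/word_emb.json",
        char_emb_file="data/char_emb.json",
        batch_size=32,
        val_batch_size=32,
        version="v1",
        model_type="model6",
        pretrained_model="checkpoints/qanet.pt",
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    train(args, config)


if __name__ == "__main__":
    main()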
Example #2
import json
import math
import os

import numpy as np
import torch
import torch.optim as optim

# `get_loader`, `EMA`, `train`, `test`, and the global `device` are
# provided by the surrounding project.


def train_entry(config):
    from models import QANet

    with open(config.word_emb_file, "rb") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "rb") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)

    print("Building model...")

    train_dataset = get_loader(config.train_record_file, config.batch_size)
    dev_dataset = get_loader(config.dev_record_file, config.batch_size)

    lr = config.learning_rate
    base_lr = 1
    lr_warm_up_num = config.lr_warm_up_num

    model = QANet(word_mat, char_mat).to(device)
    if torch.cuda.device_count() > 1:
        print("Using multiple GPUs")
        model = torch.nn.DataParallel(model, device_ids=[0, 1])
    # Resume from a previously saved state dict; the keys must match the
    # DataParallel wrapping in effect when the checkpoint was written.
    model.load_state_dict(
        torch.load('/home/cn/AI/QANet-pytorch-/model_state_dict.pt'))
    # Register every trainable parameter with the exponential moving average.
    ema = EMA(config.decay)
    for name, param in model.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)

    parameters = filter(lambda param: param.requires_grad, model.parameters())
    optimizer = optim.Adam(lr=base_lr,
                           betas=(0.9, 0.999),
                           eps=1e-7,
                           weight_decay=5e-8,
                           params=parameters)
    # Logarithmic warm-up: ramp the learning rate from 0 to `lr` over the
    # first `lr_warm_up_num` steps, then hold it constant. Because base_lr
    # is 1, the lambda's return value is the effective learning rate.
    cr = lr / math.log2(lr_warm_up_num)
    scheduler = optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda ee: cr * math.log2(ee + 1)
        if ee < lr_warm_up_num else lr)
    best_f1 = 0
    best_em = 0
    patience = 0
    for epoch in range(config.num_epoch):
        train(model, optimizer, scheduler, train_dataset, dev_dataset,
              dev_eval_file, epoch, ema)
        print("Finished epoch {}".format(epoch))
        # Evaluate with the EMA shadow weights; they are swapped back out
        # by ema.resume() after checkpointing below.
        ema.assign(model)
        metrics = test(model, dev_dataset, dev_eval_file,
                       (epoch + 1) * len(train_dataset))
        dev_f1 = metrics["f1"]
        dev_em = metrics["exact_match"]
        # Count the epoch against patience only when neither metric improves.
        if dev_f1 < best_f1 and dev_em < best_em:
            patience += 1
            if patience > config.early_stop:
                break
        else:
            patience = 0
            best_f1 = max(best_f1, dev_f1)
            best_em = max(best_em, dev_em)

        # Checkpoint while the EMA weights are assigned, then restore the
        # raw training weights.
        fn = os.path.join(config.save_dir, "model.pt")
        torch.save(model, fn)
        torch.save(model.state_dict(), 'model_state_dict.pt')
        ema.resume(model)
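
Both examples call into an `EMA` helper whose implementation isn't shown. A minimal sketch consistent with the `register`/`assign`/`resume` calls above (plus the `update` step a training loop would typically make per optimizer step); the details here are assumptions, and the real project's class may differ:

class EMA:
    """Exponential moving average of model parameters."""

    def __init__(self, decay):
        self.decay = decay
        self.shadow = {}   # EMA copies of the parameters
        self.backup = {}   # raw weights stashed during assign()

    def register(self, name, data):
        # Start the moving average from the parameter's initial value.
        self.shadow[name] = data.clone()

    def update(self, name, data):
        # shadow = decay * shadow + (1 - decay) * current
        self.shadow[name] = (self.decay * self.shadow[name]
                             + (1.0 - self.decay) * data)

    def assign(self, model):
        # Swap the EMA weights in for evaluation, keeping the raw weights.
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.backup[name] = param.data.clone()
                param.data.copy_(self.shadow[name])

    def resume(self, model):
        # Restore the raw training weights saved by assign().
        for name, param in model.named_parameters():
            if param.requires_grad:
                param.data.copy_(self.backup[name])
        self.backup = {}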