コード例 #1
0
ファイル: train.py プロジェクト: askintution/BiDAF
# Training setup: load the pre-trained embeddings and datasets, build the
# BiDAF model, and configure Adam with a logarithmic learning-rate warm-up.
print('prepare data')
# config = Config()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Only the first element of the loaded payload is used as the embedding table.
pre_trained_ = load('pre_data/embed_pre.json')
pre_trained = torch.Tensor(pre_trained_[0])
del pre_trained_  # drop the raw payload once the tensor copy exists
print('loading train_dataset')
train_dataset = SQuADData('pre_data/input/train')
dev_dataset = SQuADData('pre_data/input/dev')

# define model
print('define model')
model = BiDAF(pre_trained)
# model = BiDAF(pre_trained, 128)
# model = torch.load('model/model.pt')
model = model.to(device)

# Learning-rate schedule.
# LambdaLR sets each group's rate to (initial optimizer lr) * lr_lambda(step).
# The lambda below already returns the *absolute* rate (a log2 ramp that
# reaches `lr` after `warm_up` steps, then stays at `lr`), so the optimizer
# must be created with the neutral base rate 1.0. Creating it with
# `config.learning_rate` would make the effective rate learning_rate ** 2.
lr = config.learning_rate
base_lr = 1.0
warm_up = config.lr_warm_up_num
# NOTE: log2(warm_up) is 0 when warm_up == 1, so warm_up must be > 1.
cr = lr / log2(warm_up)
optimizer = torch.optim.Adam(params=model.parameters(),
                             lr=base_lr,
                             betas=(config.beta1, config.beta2),
                             eps=config.eps,
                             weight_decay=3e-7)
scheduler = optim.lr_scheduler.LambdaLR(optimizer,
                                        lr_lambda=lambda ee: cr * log2(ee + 1)
                                        if ee < warm_up else lr)

print('begin train')
# The log handle is left open here; presumably the training loop that follows
# writes to it and closes it -- verify against the rest of the script.
f = open('log/log.txt', 'w')
コード例 #2
0
# Validation loader, model construction (optionally resumed from the best
# checkpoint), loss/optimizer definition, and best-loss bookkeeping.
valid_dataloader = DataLoader(valid_dataset,
                              shuffle=True,
                              batch_size=hyper_params["batch_size"],
                              num_workers=4)

print("Length of training data loader is:", len(train_dataloader))
print("Length of valid data loader is:", len(valid_dataloader))

# load the model
model = BiDAF(word_vectors=word_embedding_matrix,
              char_vectors=char_embedding_matrix,
              hidden_size=hyper_params["hidden_size"],
              drop_prob=hyper_params["drop_prob"])
# When resuming, deserialize the best-model checkpoint a single time and reuse
# it for both the weights and the recorded best validation loss.
# map_location=device lets a checkpoint saved on GPU be resumed on a CPU host.
best_checkpoint = None
if hyper_params["pretrained"]:
    best_checkpoint = torch.load(os.path.join(experiment_path, "model.pkl"),
                                 map_location=device)
    model.load_state_dict(best_checkpoint["state_dict"])
model.to(device)

# define loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(model.parameters(), hyper_params["learning_rate"], weight_decay=1e-4)

# best loss so far
if best_checkpoint is not None:
    best_valid_loss = best_checkpoint["best_valid_loss"]
    # The epoch counter is stored in the separate "last" checkpoint file.
    epoch_checkpoint = torch.load(os.path.join(experiment_path, "model_last_checkpoint.pkl"),
                                  map_location=device)["epoch"]
    print("Best validation loss obtained after {} epochs is: {}".format(epoch_checkpoint, best_valid_loss))
else:
    # Fresh run: start from a sentinel loss and epoch 0.
    best_valid_loss = 100
    epoch_checkpoint = 0
# The checkpoint dict is no longer needed; release its tensors.
del best_checkpoint

# train the Model
コード例 #3
0
def _encode_tokens(tokens, word2idx, char2idx, max_tokens, max_chars, unk_idx=1):
    """Turn a token list into fixed-size, zero-padded word-id and char-id arrays.

    Tokens beyond ``max_tokens`` and characters beyond ``max_chars`` are
    dropped so the fixed-size arrays are never indexed out of bounds.
    Words/characters missing from the vocabularies get ``unk_idx``.
    """
    word_ids = np.zeros([max_tokens], dtype=np.int32)
    char_ids = np.zeros([max_tokens, max_chars], dtype=np.int32)
    for j, word in enumerate(tokens[:max_tokens]):
        word_ids[j] = word2idx.get(word, unk_idx)
        for k, char in enumerate(word[:max_chars]):
            char_ids[j, k] = char2idx.get(char, unk_idx)
    return word_ids, char_ids


def eval(context, question):
    """Predict an answer span for ``question`` inside ``context`` with BiDAF.

    NOTE: the name shadows the builtin ``eval``; it is kept because it is this
    function's public name.

    Args:
        context: Raw passage text; it is cleaned with ``clean_text`` and
            tokenized with ``word_tokenize``.
        question: Raw question text; processed the same way.

    Returns:
        The predicted answer: the context tokens from the predicted start
        index to the predicted end index (inclusive), joined by single spaces.

    Inputs that exceed the configured maximum lengths trigger a printed
    warning and are truncated to the configured sizes before encoding.
    If the trained weights cannot be loaded, a message is printed and the
    randomly initialized model is used.
    """
    train_dir = os.path.join(config.data_dir, "train")
    # NOTE(security): pickle.load can execute arbitrary code; these files must
    # be trusted artifacts produced by this project's preprocessing.
    with open(os.path.join(train_dir, "word2idx.pkl"), "rb") as wi, \
         open(os.path.join(train_dir, "char2idx.pkl"), "rb") as ci, \
         open(os.path.join(train_dir, "word_embeddings.pkl"), "rb") as wb, \
         open(os.path.join(train_dir, "char_embeddings.pkl"), "rb") as cb:
        word2idx = pickle.load(wi)
        char2idx = pickle.load(ci)
        word_embedding_matrix = pickle.load(wb)
        char_embedding_matrix = pickle.load(cb)

    # transform them into Tensors
    word_embedding_matrix = torch.from_numpy(
        np.array(word_embedding_matrix)).type(torch.float32)
    char_embedding_matrix = torch.from_numpy(
        np.array(char_embedding_matrix)).type(torch.float32)

    context = clean_text(context)
    context = [w for w in word_tokenize(context) if w]

    question = clean_text(question)
    question = [w for w in word_tokenize(question) if w]

    # Warn about inputs outside the configured limits. `default=0` keeps the
    # longest-word checks from raising on an empty token list.
    if len(context) > config.max_len_context:
        print("The context is too long. Maximum accepted length is",
              config.max_len_context, "words.")
    if max((len(w) for w in context), default=0) > config.max_len_word:
        print("Some words in the context are longer than", config.max_len_word,
              "characters.")
    if len(question) > config.max_len_question:
        print("The question is too long. Maximum accepted length is",
              config.max_len_question, "words.")
    if max((len(w) for w in question), default=0) > config.max_len_word:
        print("Some words in the question are longer than",
              config.max_len_word, "characters.")
    if len(question) < 3:
        print(
            "The question is too short. It needs to be at least a three words question."
        )

    # Zero-padded id arrays; over-long inputs are truncated inside the helper.
    context_idx, context_char_idx = _encode_tokens(
        context, word2idx, char2idx,
        config.max_len_context, config.max_len_word)
    question_idx, question_char_idx = _encode_tokens(
        question, word2idx, char2idx,
        config.max_len_question, config.max_len_word)

    model = BiDAF(word_vectors=word_embedding_matrix,
                  char_vectors=char_embedding_matrix,
                  hidden_size=config.hidden_size,
                  drop_prob=config.drop_prob)

    weights_path = os.path.join(config.squad_models, "model_final.pkl")
    # Without CUDA, keep every storage where it was deserialized (CPU).
    map_location = None if config.cuda else (lambda storage, loc: storage)
    try:
        checkpoint = torch.load(weights_path, map_location=map_location)
        model.load_state_dict(checkpoint["state_dict"])
        print("Model weights successfully loaded.")
    except Exception as exc:
        # Best effort: continue with random weights, but report the real
        # cause instead of silently hiding it.
        print(
            "Model weights not found, initialized model with random weights.")
        print("Reason:", exc)

    model.to(device)
    model.eval()
    with torch.no_grad():
        # Add a leading batch dimension of size 1 and move inputs to the device.
        context_idx, context_char_idx, question_idx, question_char_idx = (
            torch.tensor(arr, dtype=torch.int64).unsqueeze(0).to(device)
            for arr in (context_idx, context_char_idx,
                        question_idx, question_char_idx)
        )

        pred1, pred2 = model(context_idx, context_char_idx, question_idx,
                             question_char_idx)
        # NOTE(review): 15 and False are presumably the maximum answer length
        # and a "no answer" switch -- confirm against `discretize`.
        starts, ends = discretize(pred1.exp(), pred2.exp(), 15, False)
        prediction = " ".join(context[starts.item():ends.item() + 1])

    return prediction