Beispiel #1
0
def train(config):
    """Run the training loop for a HotpotQA model.

    Picks SPModel when ``config.sp_lambda > 0`` (adds the supporting-fact
    auxiliary loss), otherwise Model; trains with Adam under
    nn.DataParallel, logs every ``config.period`` steps, evaluates on dev
    every ``config.checkpoint`` steps, saves the best-dev-F1 weights, and
    halves the LR after ``config.patience`` stale evaluations, stopping
    once the LR drops below 1% of ``config.init_lr``.
    """
    # Load pretrained word/char embedding matrices and the dev eval data.
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
    with open(config.idx2word_file, 'r') as fh:
        idx2word_dict = json.load(fh)  # NOTE(review): loaded but unused in this function

    # Seed every RNG source for reproducibility.
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    if config.cuda:
        torch.cuda.manual_seed_all(config.seed)

    # Timestamped experiment directory; also snapshots the listed scripts.
    config.save = '{}-{}'.format(config.save, time.strftime("%Y%m%d-%H%M%S"))
    create_exp_dir(
        config.save,
        scripts_to_save=['run.py', 'model.py', 'util.py', 'sp_model.py'])

    def logging(s, print_=True, log_=True):
        # Print to stdout and/or append to the experiment's log file.
        if print_:
            print(s)
        if log_:
            with open(os.path.join(config.save, 'log.txt'), 'a+') as f_log:
                f_log.write(s + '\n')

    logging('Config')
    for k, v in config.__dict__.items():
        logging('    - {} : {}'.format(k, v))

    logging("Building model...")
    train_buckets = get_buckets(config.train_record_file)
    dev_buckets = get_buckets(config.dev_record_file)

    def build_train_iterator():
        # Shuffled iterator over the training buckets.
        train_dataset = HotpotDataset(train_buckets)
        return DataIterator(train_dataset,
                            config.para_limit,
                            config.ques_limit,
                            config.char_limit,
                            config.sent_limit,
                            batch_size=config.batch_size,
                            sampler=RandomSampler(train_dataset),
                            num_workers=2)

    def build_dev_iterator():
        # Deterministic (unshuffled) iterator over the dev buckets.
        dev_dataset = HotpotDataset(dev_buckets)
        return DataIterator(dev_dataset,
                            config.para_limit,
                            config.ques_limit,
                            config.char_limit,
                            config.sent_limit,
                            batch_size=config.batch_size,
                            num_workers=2)

    # sp_lambda > 0 enables joint supporting-fact supervision.
    if config.sp_lambda > 0:
        model = SPModel(config, word_mat, char_mat)
    else:
        model = Model(config, word_mat, char_mat)

    logging('nparams {}'.format(
        sum([p.nelement() for p in model.parameters() if p.requires_grad])))
    # Keep a handle on the unwrapped model so checkpoints exclude the
    # DataParallel "module." prefix.
    ori_model = model.cuda() if config.cuda else model
    model = nn.DataParallel(ori_model)

    lr = config.init_lr
    # Optimize only parameters that require gradients (frozen embeddings
    # are excluded by the filter).
    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=config.init_lr)
    cur_patience = 0
    total_loss = 0
    global_step = 0
    best_dev_F1 = None
    stop_train = False
    start_time = time.time()
    eval_start_time = time.time()
    model.train()

    train_iterator = build_train_iterator()
    dev_iterator = build_dev_iterator()

    # Effectively "train until early stopping"; the epoch cap is a sentinel.
    for epoch in range(10000):
        for data in train_iterator:
            if config.cuda:
                # Move all tensors to GPU; 'ids' stays on CPU (example ids).
                data = {
                    k: (data[k].cuda() if k != 'ids' else data[k])
                    for k in data
                }
            context_idxs = data['context_idxs']
            ques_idxs = data['ques_idxs']
            context_char_idxs = data['context_char_idxs']
            ques_char_idxs = data['ques_char_idxs']
            context_lens = data['context_lens']
            y1 = data['y1']
            y2 = data['y2']
            q_type = data['q_type']
            is_support = data['is_support']
            start_mapping = data['start_mapping']
            end_mapping = data['end_mapping']
            all_mapping = data['all_mapping']

            # context_lens.sum(1).max().item(): longest total context in
            # the batch, used by the model to bound the sequence length.
            logit1, logit2, predict_type, predict_support = model(
                context_idxs,
                ques_idxs,
                context_char_idxs,
                ques_char_idxs,
                context_lens,
                start_mapping,
                end_mapping,
                all_mapping,
                context_lens.sum(1).max().item(),
                return_yp=False)
            # Span + answer-type loss, normalized by batch size.
            loss_1 = (nll_sum(predict_type, q_type) + nll_sum(logit1, y1) +
                      nll_sum(logit2, y2)) / context_idxs.size(0)
            # Supporting-fact classification loss.
            loss_2 = nll_average(predict_support.view(-1, 2),
                                 is_support.view(-1))
            loss = loss_1 + config.sp_lambda * loss_2

            optimizer.zero_grad()
            loss.backward()
            # Clip gradients; a non-positive max_grad_norm disables
            # clipping in practice (1e10 is effectively infinity).
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                config.max_grad_norm if config.max_grad_norm > 0 else 1e10)
            optimizer.step()

            total_loss += loss.item()
            global_step += 1

            # Periodic training-progress log (loss averaged over the period).
            if global_step % config.period == 0:
                cur_loss = total_loss / config.period
                elapsed = time.time() - start_time
                logging(
                    '| epoch {:3d} | step {:6d} | lr {:05.5f} | ms/batch {:5.2f} | train loss {:8.3f} | gradnorm: {:6.3}'
                    .format(epoch, global_step, lr,
                            elapsed * 1000 / config.period, cur_loss,
                            grad_norm))
                total_loss = 0
                start_time = time.time()

            # Periodic dev evaluation and checkpointing.
            if global_step % config.checkpoint == 0:
                model.eval()
                metrics = evaluate_batch(dev_iterator, model, 0, dev_eval_file,
                                         config)
                model.train()

                logging('-' * 89)
                logging(
                    '| eval {:6d} in epoch {:3d} | time: {:5.2f}s | dev loss {:8.3f} | EM {:.4f} | F1 {:.4f}'
                    .format(global_step // config.checkpoint, epoch,
                            time.time() - eval_start_time, metrics['loss'],
                            metrics['exact_match'], metrics['f1']))
                logging('-' * 89)

                eval_start_time = time.time()

                dev_F1 = metrics['f1']
                # Save only on a new best dev F1; otherwise run the
                # patience / LR-halving schedule.
                if best_dev_F1 is None or dev_F1 > best_dev_F1:
                    best_dev_F1 = dev_F1
                    torch.save(ori_model.state_dict(),
                               os.path.join(config.save, 'model.pt'))
                    cur_patience = 0
                else:
                    cur_patience += 1
                    if cur_patience >= config.patience:
                        lr /= 2.0
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                        # Stop once LR has decayed below 1% of its start.
                        if lr < config.init_lr * 1e-2:
                            stop_train = True
                            break
                        cur_patience = 0
        if stop_train: break
    logging('best_dev_F1 {}'.format(best_dev_F1))
Beispiel #2
0
def train(config):
    """Train a HotpotQA model with gradient accumulation over two batches.

    Same structure as the standard loop, but uses SGD and calls
    ``optimizer.step()`` / ``zero_grad()`` only every second batch,
    effectively doubling the batch size. Evaluates on dev every
    ``config.checkpoint`` steps, saves the best-dev-F1 weights, and halves
    the LR after ``config.patience`` stale evaluations, stopping once the
    LR drops below 1% of ``config.init_lr``.
    """
    # Load pretrained word/char embedding matrices and the dev eval data.
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
    with open(config.idx2word_file, 'r') as fh:
        idx2word_dict = json.load(fh)  # NOTE(review): loaded but unused here

    # Seed every RNG source for reproducibility.
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    # Timestamped experiment directory; also snapshots the listed scripts.
    config.save = '{}-{}'.format(config.save, time.strftime("%Y%m%d-%H%M%S"))
    create_exp_dir(config.save, scripts_to_save=['run.py', 'model.py', 'util.py', 'sp_model.py'])

    def logging(s, print_=True, log_=True):
        # Print to stdout and/or append to the experiment's log file.
        if print_:
            print(s)
        if log_:
            with open(os.path.join(config.save, 'log.txt'), 'a+') as f_log:
                f_log.write(s + '\n')

    logging('Config')
    for k, v in config.__dict__.items():
        logging('    - {} : {}'.format(k, v))

    logging("Building model...")
    train_buckets = get_buckets(config.train_record_file)
    dev_buckets = get_buckets(config.dev_record_file)

    def build_train_iterator():
        return DataIterator(train_buckets, config.batch_size, config.para_limit, config.ques_limit, config.char_limit, True, config.sent_limit)

    def build_dev_iterator():
        return DataIterator(dev_buckets, config.batch_size, config.para_limit, config.ques_limit, config.char_limit, False, config.sent_limit)

    # sp_lambda > 0 enables joint supporting-fact supervision.
    if config.sp_lambda > 0:
        model = SPModel(config, word_mat, char_mat)
    else:
        model = Model(config, word_mat, char_mat)

    logging('nparams {}'.format(sum([p.nelement() for p in model.parameters() if p.requires_grad])))
    # Keep the unwrapped model for checkpointing (no "module." prefix).
    ori_model = model.cuda()
    model = nn.DataParallel(ori_model)

    lr = config.init_lr
    # Optimize only parameters that require gradients; filter() keeps the
    # elements of model.parameters() for which p.requires_grad is true
    # (frozen embeddings are excluded).
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=config.init_lr)
    cur_patience = 0
    total_loss = 0
    global_step = 0
    best_dev_F1 = None
    stop_train = False
    start_time = time.time()
    eval_start_time = time.time()
    # Enable training mode (dropout layers etc.).
    model.train()

    # Effectively "train until early stopping"; the epoch cap is a sentinel.
    for epoch in range(10000):
        # Each `data` is one mini-batch produced by the iterator.
        for idx, data in enumerate(build_train_iterator()):
            context_idxs = Variable(data['context_idxs'])
            ques_idxs = Variable(data['ques_idxs'])
            context_char_idxs = Variable(data['context_char_idxs'])
            ques_char_idxs = Variable(data['ques_char_idxs'])
            context_lens = Variable(data['context_lens'])
            y1 = Variable(data['y1'])
            y2 = Variable(data['y2'])
            q_type = Variable(data['q_type'])
            is_support = Variable(data['is_support'])
            is_support_word = Variable(data['is_support_word'])
            start_mapping = Variable(data['start_mapping'])
            end_mapping = Variable(data['end_mapping'])
            all_mapping = Variable(data['all_mapping'])

            logit1, logit2, predict_type, predict_support = model(context_idxs, ques_idxs, context_char_idxs, ques_char_idxs, context_lens, start_mapping, end_mapping, all_mapping, is_support_word, return_yp=False)
            # Span + answer-type loss, normalized by batch size.
            loss_1 = (nll_sum(predict_type, q_type) + nll_sum(logit1, y1) + nll_sum(logit2, y2)) / context_idxs.size(0)
            # Supporting-fact classification loss.
            loss_2 = nll_average(predict_support.view(-1, 2), is_support.view(-1))
            loss = loss_1 + config.sp_lambda * loss_2

            loss.backward()
            # Gradient accumulation: step/reset only every second batch.
            # BUG FIX: the condition previously referenced an undefined
            # name `i` (NameError on the first batch); the loop index is `idx`.
            if (idx + 1) % 2 == 0:
                optimizer.step()
                optimizer.zero_grad()
            total_loss += loss.data.item()
            global_step += 1

            # Periodic training-progress log (loss averaged over the period).
            if global_step % config.period == 0:
                cur_loss = total_loss / config.period
                elapsed = time.time() - start_time
                logging('| epoch {:3d} | step {:6d} | lr {:05.5f} | ms/batch {:5.2f} | train loss {:8.3f}'.format(epoch, global_step, lr, elapsed*1000/config.period, cur_loss))
                total_loss = 0
                start_time = time.time()

            # Periodic dev evaluation and checkpointing.
            if global_step % config.checkpoint == 0:
                # Switch to eval mode for the dev pass.
                model.eval()
                # evaluate_batch returns a dict with loss/EM/F1 metrics.
                metrics = evaluate_batch(build_dev_iterator(), model, 0, dev_eval_file, config)
                model.train()

                logging('-' * 89)
                logging('| eval {:6d} in epoch {:3d} | time: {:5.2f}s | dev loss {:8.3f} | EM {:.4f} | F1 {:.4f}'.format(global_step//config.checkpoint,
                    epoch, time.time()-eval_start_time, metrics['loss'], metrics['exact_match'], metrics['f1']))
                logging('-' * 89)

                eval_start_time = time.time()

                dev_F1 = metrics['f1']
                # Save only on a new best dev F1; otherwise run the
                # patience / LR-halving schedule.
                if best_dev_F1 is None or dev_F1 > best_dev_F1:
                    best_dev_F1 = dev_F1
                    torch.save(ori_model.state_dict(), os.path.join(config.save, 'model.pt'))
                    cur_patience = 0
                else:
                    cur_patience += 1
                    if cur_patience >= config.patience:
                        lr /= 2.0
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                        # Stop once LR has decayed below 1% of its start.
                        if lr < config.init_lr * 1e-2:
                            stop_train = True
                            break
                        cur_patience = 0
        if stop_train: break
    logging('best_dev_F1 {}'.format(best_dev_F1))
Beispiel #3
0
def train(config):
    """Train a sentence-level supporting-fact classifier for HotpotQA.

    Unlike the span-prediction loop, each batch is re-packed from
    (context, question) pairs into one (sentence, question) example per
    real context sentence; the model then predicts whether each sentence
    is a supporting fact, and only that NLL loss is optimized. Keeps the
    checkpoint with the best dev supporting-fact F1 and decays the LR by
    0.75 after ``config.patience`` stale evaluations.
    """
    # Load pretrained word/char embedding matrices and the dev eval data.
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
    with open(config.idx2word_file, 'r') as fh:
        idx2word_dict = json.load(fh)  # NOTE(review): loaded but unused here

    # Seed every RNG source for reproducibility.
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    # Timestamped experiment directory; also snapshots the listed scripts.
    config.save = '{}-{}'.format(config.save, time.strftime("%Y%m%d-%H%M%S"))
    create_exp_dir(
        config.save,
        scripts_to_save=['run.py', 'model.py', 'util.py', 'sp_model.py'])

    def logging(s, print_=True, log_=True):
        # Print to stdout and/or append to the experiment's log file.
        if print_:
            print(s)
        if log_:
            with open(os.path.join(config.save, 'log.txt'), 'a+') as f_log:
                f_log.write(s + '\n')

    logging('Config')
    for k, v in config.__dict__.items():
        logging('    - {} : {}'.format(k, v))

    logging("Building model...")
    train_buckets = get_buckets(config.train_record_file)
    dev_buckets = get_buckets(config.dev_record_file)

    def build_train_iterator():
        return DataIterator(train_buckets, config.batch_size,
                            config.para_limit, config.ques_limit,
                            config.char_limit, True, config.sent_limit)

    def build_dev_iterator():
        return DataIterator(dev_buckets, config.batch_size, config.para_limit,
                            config.ques_limit, config.char_limit, False,
                            config.sent_limit)

    # sp_lambda > 0 enables the supporting-fact model.
    if config.sp_lambda > 0:
        model = SPModel(config, word_mat, char_mat)
    else:
        model = Model(config, word_mat, char_mat)

    logging('nparams {}'.format(
        sum([p.nelement() for p in model.parameters() if p.requires_grad])))
    # Single-GPU training; this variant deliberately skips DataParallel.
    model = model.cuda()

    lr = config.init_lr
    # Optimize only parameters that require gradients.
    optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                 model.parameters()),
                          lr=config.init_lr)
    cur_patience = 0
    total_loss = 0
    global_step = 0
    best_dev_F1 = None
    stop_train = False
    start_time = time.time()
    eval_start_time = time.time()
    model.train()

    # Running supporting-fact metrics, averaged every config.period steps.
    train_metrics = {'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0}

    best_dev_sp_f1 = 0

    # Dataset statistics accumulated once per epoch (from the last batch
    # still in scope after the inner loop).
    total_support_facts = 0
    total_contexes = 0

    for epoch in range(5):
        for data in build_train_iterator():
            context_idxs = Variable(data['context_idxs'])
            ques_idxs = Variable(data['ques_idxs'])
            context_char_idxs = Variable(data['context_char_idxs'])
            ques_char_idxs = Variable(data['ques_char_idxs'])
            context_lens = Variable(data['context_lens'])

            y1 = Variable(data['y1'])
            y2 = Variable(data['y2'])
            q_type = Variable(data['q_type'])

            is_support = Variable(data['is_support'])

            start_mapping = Variable(data['start_mapping'])
            end_mapping = Variable(data['end_mapping'])
            all_mapping = Variable(data['all_mapping'])

            # Re-pack the batch: one (sentence, question) pair per real
            # context sentence, so the model classifies supporting facts
            # directly from sentences.

            # Total number of real (non-padding) sentences in the batch;
            # each real sentence contributes exactly one 1 to start_mapping.
            number_of_sentences = int(
                torch.sum(torch.sum(start_mapping, dim=1)).item())
            ques_limit = config.ques_limit
            char_limit = config.char_limit
            # Per-sentence token budget. NOTE(review): hard-coded 600 --
            # assumes no sentence exceeds 600 tokens; confirm against data.
            sent_limit = 600

            # Pre-allocate the flattened per-sentence tensors (zero-padded).
            sentence_idxs = torch.zeros(number_of_sentences,
                                        sent_limit,
                                        dtype=torch.long).cuda()
            sentence_char_idxs = torch.zeros(number_of_sentences,
                                             sent_limit,
                                             char_limit,
                                             dtype=torch.long).cuda()
            sentence_lengths = torch.zeros(number_of_sentences,
                                           dtype=torch.long).cuda()
            is_support_fact = torch.zeros(number_of_sentences,
                                          dtype=torch.long).cuda()

            ques_idxs_sen_impl = torch.zeros(number_of_sentences,
                                             ques_limit,
                                             dtype=torch.long).cuda()
            ques_char_idxs_sen_impl = torch.zeros(number_of_sentences,
                                                  ques_limit,
                                                  char_limit,
                                                  dtype=torch.long).cuda()

            index = 0

            # For every batch element ...
            for b in range(all_mapping.size(0)):
                # ... and every sentence slot in that element.
                for i in range(all_mapping.size(2)):
                    # All-zero map => padding slot, skip it.
                    sentence_map = all_mapping[b, :, i]
                    s = torch.sum(sentence_map)
                    if s == 0:
                        continue

                    # Token span [start, end) of sentence i in context b,
                    # recovered from the one-hot start/end mappings.
                    starting_index = torch.argmax(start_mapping[b, :, i])
                    ending_index = torch.argmax(end_mapping[b, :, i]) + 1
                    sentence_length = int(torch.sum(all_mapping[b, :, i]))
                    sentence = context_idxs[b, starting_index:ending_index]
                    sentence_chars = context_char_idxs[
                        b, starting_index:ending_index, :]

                    sentence_idxs[index, :sentence_length] = sentence
                    sentence_char_idxs[
                        index, :sentence_length, :] = sentence_chars
                    sentence_lengths[index] = sentence_length
                    is_support_fact[index] = is_support[b, i]

                    # Replicate the question for every sentence example.
                    ques_idxs_sen_impl[index, :ques_idxs[
                        b, :].size(0)] = ques_idxs[b, :]
                    ques_char_idxs_sen_impl[index, :ques_idxs[
                        b, :].size(0), :] = ques_char_idxs[b, :, :]

                    index += 1

            # Longest sentence in the re-packed batch; passed to the model
            # in place of context lengths.
            sentence_length = torch.max(sentence_lengths)

            predict_support = model(sentence_idxs, ques_idxs_sen_impl,
                                    sentence_char_idxs,
                                    ques_char_idxs_sen_impl, sentence_length,
                                    start_mapping, end_mapping)

            # Only the supporting-fact classification loss is trained here.
            loss_2 = nll_average(predict_support.contiguous(),
                                 is_support_fact.contiguous())
            loss = loss_2

            # Accumulate SP EM/F1/precision/recall for the period report.
            train_metrics = update_sp(train_metrics, predict_support,
                                      is_support_fact)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.data.item()
            global_step += 1

            if global_step % config.period == 0:

                # Average the accumulated metrics over the period.
                for key in train_metrics:
                    train_metrics[key] /= float(config.period)

                # NOTE(review): total_loss is logged unaveraged (unlike the
                # metrics above) yet reset below -- confirm this is intended.
                cur_loss = total_loss

                elapsed = time.time() - start_time
                logging(
                    '| epoch {:3d} | step {:6d} | lr {:05.5f} | ms/batch {:5.2f} | train loss {:8.3f} | SP EM {:8.3f} | SP f1 {:8.3f} | SP Prec {:8.3f} | SP Recall {:8.3f}'
                    .format(epoch, global_step, lr,
                            elapsed * 1000 / config.period, cur_loss,
                            train_metrics['sp_em'], train_metrics['sp_f1'],
                            train_metrics['sp_prec'],
                            train_metrics['sp_recall']))

                total_loss = 0
                train_metrics = {
                    'sp_em': 0,
                    'sp_f1': 0,
                    'sp_prec': 0,
                    'sp_recall': 0
                }
                start_time = time.time()

            # Periodic dev evaluation and checkpointing.
            if global_step % config.checkpoint == 0:
                model.eval()

                eval_metrics = evaluate_batch(build_dev_iterator(), model, 500,
                                              dev_eval_file, config)
                model.train()

                logging('-' * 89)
                logging(
                    '| eval {:6d} in epoch {:3d} | time: {:5.2f}s | dev loss {:8.3f}| SP EM {:8.3f} | SP f1 {:8.3f} | SP Prec {:8.3f} | SP Recall {:8.3f}'
                    .format(global_step // config.checkpoint, epoch,
                            time.time() - eval_start_time,
                            eval_metrics['loss'], eval_metrics['sp_em'],
                            eval_metrics['sp_f1'], eval_metrics['sp_prec'],
                            eval_metrics['sp_recall']))
                logging('-' * 89)

                # Keep the checkpoint with the best dev supporting-fact F1;
                # otherwise run the patience / LR-decay schedule.
                if eval_metrics['sp_f1'] > best_dev_sp_f1:
                    best_dev_sp_f1 = eval_metrics['sp_f1']
                    torch.save(model.state_dict(),
                               os.path.join(config.save, 'model.pt'))
                    cur_patience = 0
                else:
                    cur_patience += 1
                    if cur_patience >= config.patience:
                        lr *= 0.75
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                        # Stop once LR has decayed below 1% of its start.
                        if lr < config.init_lr * 1e-2:
                            stop_train = True
                            break
                        cur_patience = 0

                eval_start_time = time.time()

        total_support_facts += torch.sum(torch.sum(is_support))
        # BUG FIX: the second factor was is_support.size(0), which squared
        # the batch dimension; contexts per batch = batch size * sentence
        # slots, i.e. size(0) * size(1) (matching the original intent).
        total_contexes += is_support.size(0) * is_support.size(1)

        if stop_train: break
    logging('best_dev_F1 {}'.format(best_dev_sp_f1))
Beispiel #4
0
def train(config):
    """Train a HotpotQA model with SGD, with optional auxiliary SQuAD data.

    Loads both the HotpotQA training buckets and a SQuAD bucket set
    ('squad_record.pkl'); ``build_train_iterator`` accepts a bucket set and
    batch size so either corpus can be iterated. Evaluates on dev every
    ``config.checkpoint`` steps, saves the best-dev-F1 weights, and halves
    the LR after ``config.patience`` stale evaluations, stopping once the
    LR drops below 1% of ``config.init_lr``.
    """
    # Load pretrained word/char embedding matrices and the dev eval data.
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
    with open(config.idx2word_file, 'r') as fh:
        idx2word_dict = json.load(fh)  # NOTE(review): loaded but unused here

    # Seed every RNG source for reproducibility.
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    # Timestamped experiment directory; also snapshots the listed scripts.
    config.save = '{}-{}'.format(config.save, time.strftime("%Y%m%d-%H%M%S"))
    create_exp_dir(
        config.save,
        scripts_to_save=['run.py', 'model.py', 'util.py', 'sp_model.py'])

    def logging(s, print_=True, log_=True):
        # Print to stdout and/or append to the experiment's log file.
        if print_:
            print(s)
        if log_:
            with open(os.path.join(config.save, 'log.txt'), 'a+') as f_log:
                f_log.write(s + '\n')

    logging('Config')
    for k, v in config.__dict__.items():
        logging('    - {} : {}'.format(k, v))

    logging("Building model...")
    train_buckets = get_buckets(config.train_record_file)
    # Auxiliary SQuAD buckets; path is hard-coded -- TODO move into config.
    train_buckets_squad = get_buckets('squad_record.pkl')
    dev_buckets = get_buckets(config.dev_record_file)

    def build_train_iterator(buckets=train_buckets,
                             batch_size=config.batch_size):
        # BUG FIX: this previously ignored `buckets` and always iterated
        # train_buckets, so passing train_buckets_squad had no effect.
        return DataIterator(buckets, batch_size, config.para_limit,
                            config.ques_limit, config.char_limit, True,
                            config.sent_limit)

    def build_dev_iterator():
        return DataIterator(dev_buckets, config.batch_size, config.para_limit,
                            config.ques_limit, config.char_limit, False,
                            config.sent_limit)

    # sp_lambda > 0 enables joint supporting-fact supervision.
    if config.sp_lambda > 0:
        model = SPModel(config, word_mat, char_mat)
    else:
        model = Model(config, word_mat, char_mat)

    logging('nparams {}'.format(
        sum([p.nelement() for p in model.parameters() if p.requires_grad])))
    # Keep the unwrapped model for checkpointing (no "module." prefix).
    ori_model = model.cuda()
    model = nn.DataParallel(ori_model)

    lr = config.init_lr
    # Optimize only parameters that require gradients.
    optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                 model.parameters()),
                          lr=config.init_lr)
    cur_patience = 0
    total_loss = 0
    global_step = 0
    best_dev_F1 = None
    stop_train = False
    start_time = time.time()
    eval_start_time = time.time()
    model.train()

    # Effectively "train until early stopping"; the epoch cap is a sentinel.
    for epoch in range(10000):
        # NOTE(review): batch size hard-coded to 24 here, overriding
        # config.batch_size -- confirm this is intended.
        for data in build_train_iterator(batch_size=24):
            context_idxs = Variable(data['context_idxs'])
            ques_idxs = Variable(data['ques_idxs'])
            context_char_idxs = Variable(data['context_char_idxs'])
            ques_char_idxs = Variable(data['ques_char_idxs'])
            context_lens = Variable(data['context_lens'])
            y1 = Variable(data['y1'])
            y2 = Variable(data['y2'])
            q_type = Variable(data['q_type'])
            is_support = Variable(data['is_support'])
            start_mapping = Variable(data['start_mapping'])
            end_mapping = Variable(data['end_mapping'])
            all_mapping = Variable(data['all_mapping'])

            logit1, logit2, predict_type, predict_support = model(
                context_idxs,
                ques_idxs,
                context_char_idxs,
                ques_char_idxs,
                context_lens,
                start_mapping,
                end_mapping,
                all_mapping,
                is_support=is_support,
                return_yp=False)
            # Span + answer-type loss, normalized by batch size.
            loss_1 = (nll_sum(predict_type, q_type) + nll_sum(logit1, y1) +
                      nll_sum(logit2, y2)) / context_idxs.size(0)
            # Supporting-fact classification loss.
            loss_2 = nll_average(predict_support.view(-1, 2),
                                 is_support.view(-1))
            loss = loss_1 + config.sp_lambda * loss_2

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() extracts a Python float; accumulating `loss.data`
            # kept a tensor alive and made total_loss a tensor.
            total_loss += loss.item()
            global_step += 1

            # Periodic training-progress log (loss averaged over the period).
            if global_step % config.period == 0:
                cur_loss = total_loss / config.period
                elapsed = time.time() - start_time
                logging(
                    '| epoch {:3d} | step {:6d} | lr {:05.5f} | ms/batch {:5.2f} | train loss {:8.3f}'
                    .format(epoch, global_step, lr,
                            elapsed * 1000 / config.period, cur_loss))
                total_loss = 0
                start_time = time.time()

            # Periodic dev evaluation and checkpointing.
            if global_step % config.checkpoint == 0:
                model.eval()
                metrics = evaluate_batch(build_dev_iterator(), model, 0,
                                         dev_eval_file, config)
                model.train()

                logging('-' * 89)
                logging(
                    '| eval {:6d} in epoch {:3d} | time: {:5.2f}s | dev loss {:8.3f} | EM {:.4f} | F1 {:.4f}'
                    .format(global_step // config.checkpoint, epoch,
                            time.time() - eval_start_time, metrics['loss'],
                            metrics['exact_match'], metrics['f1']))
                logging('-' * 89)

                eval_start_time = time.time()

                dev_F1 = metrics['f1']
                # Save only on a new best dev F1; otherwise run the
                # patience / LR-halving schedule.
                if best_dev_F1 is None or dev_F1 > best_dev_F1:
                    best_dev_F1 = dev_F1
                    torch.save(ori_model.state_dict(),
                               os.path.join(config.save, 'model.pt'))
                    cur_patience = 0
                else:
                    cur_patience += 1
                    if cur_patience >= config.patience:
                        lr /= 2.0
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                        # Stop once LR has decayed below 1% of its start.
                        if lr < config.init_lr * 1e-2:
                            stop_train = True
                            break
                        cur_patience = 0
        if stop_train: break
    logging('best_dev_F1 {}'.format(best_dev_F1))
Beispiel #5
0
def train(config):
    """Train a HotpotQA model and log metrics to Comet.ml.

    Trains with plain SGD, logs train losses every ``config.period`` steps
    and dev metrics every ``config.checkpoint`` steps, halves the learning
    rate whenever dev F1 stalls for ``config.patience`` evaluations, and
    stops once the rate drops below 1% of ``config.init_lr``.  The best
    model (by dev F1) is saved to ``<config.save>/model.pt``.

    Args:
        config: argparse-style namespace with data paths, model hyper-
            parameters, and training schedule settings.
    """
    # SECURITY: read the Comet API key from the environment rather than
    # committing a credential to source control (the original hard-coded a
    # live key here).
    experiment = Experiment(api_key=os.environ.get("COMET_API_KEY"),
                            project_name="hotpot",
                            workspace="fan-luo")
    experiment.set_name(config.run_name)

    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
    with open(config.idx2word_file, 'r') as fh:
        idx2word_dict = json.load(fh)

    # NOTE(review): unlike the sibling train() variants, RNGs are not seeded
    # here — confirm whether reproducibility matters for these runs.

    # Each run gets a timestamped experiment directory with a code snapshot.
    config.save = '{}-{}'.format(config.save, time.strftime("%Y%m%d-%H%M%S"))
    create_exp_dir(
        config.save,
        scripts_to_save=['run.py', 'model.py', 'util.py', 'sp_model.py'])

    def logging(s, print_=True, log_=True):
        # Echo to stdout and append to the run's log file.
        if print_:
            print(s)
        if log_:
            with open(os.path.join(config.save, 'log.txt'), 'a+') as f_log:
                f_log.write(s + '\n')

    logging('Config')
    for k, v in config.__dict__.items():
        logging('    - {} : {}'.format(k, v))

    logging("Building model...")
    train_buckets = get_buckets(config.train_record_file)
    dev_buckets = get_buckets(config.dev_record_file)

    def build_train_iterator():
        # shuffle=True for training batches.
        return DataIterator(train_buckets, config.batch_size,
                            config.para_limit, config.ques_limit,
                            config.char_limit, True, config.sent_limit)

    def build_dev_iterator():
        # shuffle=False for deterministic evaluation order.
        return DataIterator(dev_buckets, config.batch_size, config.para_limit,
                            config.ques_limit, config.char_limit, False,
                            config.sent_limit)

    # sp_lambda > 0 enables the supporting-facts auxiliary objective.
    if config.sp_lambda > 0:
        model = SPModel(config, word_mat, char_mat)
    else:
        model = Model(config, word_mat, char_mat)

    logging('nparams {}'.format(
        sum([p.nelement() for p in model.parameters() if p.requires_grad])))
    # Keep a handle on the raw (non-DataParallel) model for checkpointing.
    ori_model = model.cuda()
    model = nn.DataParallel(ori_model)
    print("next(model.parameters()).is_cuda: " +
          str(next(model.parameters()).is_cuda))
    print("next(ori_model.parameters()).is_cuda: " +
          str(next(ori_model.parameters()).is_cuda))

    lr = config.init_lr
    optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                 model.parameters()),
                          lr=config.init_lr)
    cur_patience = 0
    total_loss = 0
    total_ans_loss = 0
    total_sp_loss = 0
    global_step = 0
    best_dev_F1 = None
    stop_train = False
    start_time = time.time()
    eval_start_time = time.time()
    model.train()

    # Effectively "until early stopping": the LR-decay logic below is the
    # real termination condition, not the epoch count.
    for epoch in range(10000):
        for data in build_train_iterator():
            # Wrap batch tensors for the pre-0.4 PyTorch autograd API.
            context_idxs = Variable(data['context_idxs'])
            ques_idxs = Variable(data['ques_idxs'])
            context_char_idxs = Variable(data['context_char_idxs'])
            ques_char_idxs = Variable(data['ques_char_idxs'])
            context_lens = Variable(data['context_lens'])
            y1 = Variable(data['y1'])
            y2 = Variable(data['y2'])
            q_type = Variable(data['q_type'])
            is_support = Variable(data['is_support'])
            start_mapping = Variable(data['start_mapping'])
            end_mapping = Variable(data['end_mapping'])
            all_mapping = Variable(data['all_mapping'])

            logit1, logit2, predict_type, predict_support = model(
                context_idxs,
                ques_idxs,
                context_char_idxs,
                ques_char_idxs,
                context_lens,
                start_mapping,
                end_mapping,
                all_mapping,
                return_yp=False)
            # Answer loss: span start/end plus question-type classification,
            # normalised by batch size.
            loss_1 = (nll_sum(predict_type, q_type) + nll_sum(logit1, y1) +
                      nll_sum(logit2, y2)) / context_idxs.size(0)
            # Supporting-facts loss over per-sentence binary predictions.
            loss_2 = nll_average(predict_support.view(-1, 2),
                                 is_support.view(-1))
            loss = loss_1 + config.sp_lambda * loss_2

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .data[0] indexing targets the pre-0.4 PyTorch API.
            total_loss += loss.data[0]
            total_ans_loss += loss_1.data[0]
            total_sp_loss += loss_2.data[0]
            global_step += 1

            if global_step % config.period == 0:
                # Report losses averaged over the logging period.
                cur_loss = total_loss / config.period
                cur_ans_loss = total_ans_loss / config.period
                cur_sp_loss = total_sp_loss / config.period
                elapsed = time.time() - start_time
                logging(
                    '| epoch {:3d} | step {:6d} | lr {:05.5f} | ms/batch {:5.2f} | train loss {:8.3f} | answer loss {:8.3f} | supporting facts loss {:8.3f} '
                    .format(epoch, global_step, lr,
                            elapsed * 1000 / config.period, cur_loss,
                            cur_ans_loss, cur_sp_loss))
                experiment.log_metrics(
                    {
                        'train loss': cur_loss,
                        'train answer loss': cur_ans_loss,
                        'train supporting facts loss': cur_sp_loss
                    },
                    step=global_step)
                total_loss = 0
                total_ans_loss = 0
                total_sp_loss = 0
                start_time = time.time()

            if global_step % config.checkpoint == 0:
                # Periodic dev evaluation drives checkpointing and LR decay.
                model.eval()
                metrics = evaluate_batch(build_dev_iterator(), model, 0,
                                         dev_eval_file, config)
                model.train()

                logging('-' * 89)
                logging(
                    '| eval {:6d} in epoch {:3d} | time: {:5.2f}s | dev loss {:8.3f} | answer loss {:8.3f} | supporting facts loss {:8.3f} | EM {:.4f} | F1 {:.4f}'
                    .format(global_step // config.checkpoint, epoch,
                            time.time() - eval_start_time, metrics['loss'],
                            metrics['ans_loss'], metrics['sp_loss'],
                            metrics['exact_match'], metrics['f1']))
                logging('-' * 89)
                experiment.log_metrics(
                    {
                        'dev loss': metrics['loss'],
                        'dev answer loss': metrics['ans_loss'],
                        'dev supporting facts loss': metrics['sp_loss'],
                        'EM': metrics['exact_match'],
                        'F1': metrics['f1']
                    },
                    step=global_step)

                eval_start_time = time.time()

                dev_F1 = metrics['f1']
                if best_dev_F1 is None or dev_F1 > best_dev_F1:
                    # New best: save the raw model's weights and reset patience.
                    best_dev_F1 = dev_F1
                    torch.save(ori_model.state_dict(),
                               os.path.join(config.save, 'model.pt'))
                    cur_patience = 0
                else:
                    cur_patience += 1
                    if cur_patience >= config.patience:
                        # Halve the LR; stop once it falls below 1% of the
                        # initial rate.
                        lr /= 2.0
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                        if lr < config.init_lr * 1e-2:
                            stop_train = True
                            break
                        cur_patience = 0
        if stop_train: break
    logging('best_dev_F1 {}'.format(best_dev_F1))
Beispiel #6
0
def train(config):
    """Train a HotpotQA model with attention-based supporting-fact supervision.

    Supports optional precomputed BERT embeddings for context/question
    (``config.bert``), SGD or Adam optimisation, and cosine or plateau
    learning-rate scheduling.  Evaluates on dev every ``config.checkpoint``
    steps and keeps the best model (by dev F1) in ``<config.save>/model.pt``.

    Args:
        config: argparse-style namespace with data paths, model hyper-
            parameters, optimiser/scheduler choices, and schedule settings.
    """
    # Word embeddings: with BERT enabled, GloVe vectors are loaded only when
    # they are used alongside BERT (config.bert_with_glove).
    if config.bert:
        if config.bert_with_glove:
            with open(config.word_emb_file, "r") as fh:
                word_mat = np.array(json.load(fh), dtype=np.float32)
        else:
            word_mat = None
    else:
        with open(config.word_emb_file, "r") as fh:
            word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
    with open(config.idx2word_file, 'r') as fh:
        idx2word_dict = json.load(fh)

    # Seed every RNG for reproducibility.
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    # Each run gets a timestamped experiment directory with a code snapshot.
    config.save = '{}-{}'.format(config.save, time.strftime("%Y%m%d-%H%M%S"))
    create_exp_dir(config.save,
                   scripts_to_save=[
                       'run.py', 'model.py', 'util.py', 'sp_model.py',
                       'macnet_v2.py'
                   ])

    def logging(s, print_=True, log_=True):
        # Echo to stdout and append to the run's log file.
        if print_:
            print(s)
        if log_:
            with open(os.path.join(config.save, 'log.txt'), 'a+') as f_log:
                f_log.write(s + '\n')

    logging('Config')
    for k, v in config.__dict__.items():
        logging('    - {} : {}'.format(k, v))

    logging("Building model...")
    train_buckets = get_buckets(config.train_record_file)
    dev_buckets = get_buckets(config.dev_record_file)
    # Iterations per epoch; used as the cosine scheduler's period.
    max_iter = math.ceil(len(train_buckets[0]) / config.batch_size)

    def build_train_iterator():
        # shuffle=True for training; BERT embeddings are zipped in per bucket.
        if config.bert:
            train_context_buckets = get_buckets_bert(
                os.path.join(config.bert_dir, config.train_bert_emb_context))
            train_ques_buckets = get_buckets_bert(
                os.path.join(config.bert_dir, config.train_bert_emb_ques))
            return DataIterator(train_buckets,
                                config.batch_size,
                                config.para_limit,
                                config.ques_limit,
                                config.char_limit,
                                True,
                                config.sent_limit,
                                bert=True,
                                bert_buckets=list(
                                    zip(train_context_buckets,
                                        train_ques_buckets)),
                                new_spans=True)
        else:
            return DataIterator(train_buckets,
                                config.batch_size,
                                config.para_limit,
                                config.ques_limit,
                                config.char_limit,
                                True,
                                config.sent_limit,
                                bert=False,
                                new_spans=True)

    def build_dev_iterator():
        # Iterator for inference during training (shuffle=False).
        if config.bert:
            dev_context_buckets = get_buckets_bert(
                os.path.join(config.bert_dir, config.dev_bert_emb_context))
            dev_ques_buckets = get_buckets_bert(
                os.path.join(config.bert_dir, config.dev_bert_emb_ques))
            return DataIterator(dev_buckets,
                                config.batch_size,
                                config.para_limit,
                                config.ques_limit,
                                config.char_limit,
                                False,
                                config.sent_limit,
                                bert=True,
                                bert_buckets=list(
                                    zip(dev_context_buckets,
                                        dev_ques_buckets)))
        else:
            return DataIterator(dev_buckets,
                                config.batch_size,
                                config.para_limit,
                                config.ques_limit,
                                config.char_limit,
                                False,
                                config.sent_limit,
                                bert=False)

    # sp_lambda > 0 enables the supporting-facts auxiliary objective.
    if config.sp_lambda > 0:
        model = SPModel(config, word_mat, char_mat)
    else:
        model = Model(config, word_mat, char_mat)

    logging('nparams {}'.format(
        sum([p.nelement() for p in model.parameters() if p.requires_grad])))
    # Keep a handle on the raw (non-DataParallel) model for checkpointing.
    ori_model = model.cuda()
    model = nn.DataParallel(ori_model)

    if config.optim == 'sgd':
        optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                     model.parameters()),
                              lr=config.init_lr)
    elif config.optim == 'adam':
        optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                      model.parameters()),
                               lr=config.init_lr)
    else:
        # Fail fast instead of hitting a NameError at optimizer.step().
        raise ValueError('unsupported optimizer: {}'.format(config.optim))
    if config.scheduler == 'cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                               config.epoch *
                                                               max_iter,
                                                               eta_min=0.0001)
    elif config.scheduler == 'plateau':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='max',
            factor=0.5,
            patience=config.patience,
            verbose=True)
    else:
        # Fail fast instead of hitting a NameError in the training loop.
        raise ValueError('unsupported scheduler: {}'.format(config.scheduler))
    cur_patience = 0
    total_loss = 0
    global_step = 0
    best_dev_F1 = None
    # NOTE(review): stop_train is never set in this variant and cur_patience is
    # never compared, so early stopping is effectively disabled — training runs
    # for the full config.epoch epochs.  Confirm this is intentional.
    stop_train = False
    start_time = time.time()
    eval_start_time = time.time()
    model.train()

    for epoch in range(config.epoch):
        for data in build_train_iterator():
            if config.scheduler == 'cosine':
                # Cosine annealing advances once per iteration.  The plateau
                # scheduler must NOT be stepped here: ReduceLROnPlateau.step()
                # requires a metric argument and is driven by dev F1 at
                # checkpoint time below (stepping it here raised TypeError).
                scheduler.step()
            # Wrap batch tensors for the pre-0.4 PyTorch autograd API.
            # (y1/y2/is_y1/is_y2 from the batch are unused in this variant:
            # span supervision comes from data['gold_mention_spans'].)
            context_idxs = Variable(data['context_idxs'])
            ques_idxs = Variable(data['ques_idxs'])
            context_char_idxs = Variable(data['context_char_idxs'])
            ques_char_idxs = Variable(data['ques_char_idxs'])
            context_lens = Variable(data['context_lens'])

            q_type = Variable(data['q_type'])
            is_support = Variable(data['is_support'])
            start_mapping = Variable(data['start_mapping'])
            end_mapping = Variable(data['end_mapping'])
            all_mapping = Variable(data['all_mapping'])
            if config.bert:
                bert_context = Variable(data['bert_context'])
                bert_ques = Variable(data['bert_ques'])
            else:
                bert_context = None
                bert_ques = None

            # Collect, per batch item, the indices of supporting sentences.
            support_ids = (is_support == 1).nonzero()
            sp_dict = {idx: [] for idx in range(context_idxs.shape[0])}
            for row in support_ids:
                bid, sp_sent_id = row
                sp_dict[bid.data.cpu()[0]].append(sp_sent_id.data.cpu()[0])

            if config.sp_shuffle:
                # Randomise supporting-sentence order per item (plain loop —
                # the original used a list comprehension for side effects).
                for value in sp_dict.values():
                    random.shuffle(value)

            # Build per-step supporting-fact targets: -1 marks the "end"
            # slot after the last real support, -2 marks padding beyond it.
            # NOTE(review): an item with zero supporting sentences is skipped
            # by every branch, shrinking these lists below the batch size —
            # confirm upstream guarantees at least one support per item.
            sp1_labels = []
            sp2_labels = []
            sp3_labels = []
            sp4_labels = []
            sp5_labels = []
            for item in sorted(sp_dict.items(), key=lambda t: t[0]):
                bid, supports = item
                if len(supports) == 1:
                    sp1_labels.append(supports[0])
                    sp2_labels.append(-1)
                    sp3_labels.append(-2)
                    sp4_labels.append(-2)
                    sp5_labels.append(-2)
                elif len(supports) == 2:
                    sp1_labels.append(supports[0])
                    sp2_labels.append(supports[1])
                    sp3_labels.append(-1)
                    sp4_labels.append(-2)
                    sp5_labels.append(-2)
                elif len(supports) == 3:
                    sp1_labels.append(supports[0])
                    sp2_labels.append(supports[1])
                    sp3_labels.append(supports[2])
                    sp4_labels.append(-1)
                    sp5_labels.append(-2)
                elif len(
                        supports) >= 4:  # 4 or greater sp are treated the same
                    sp1_labels.append(supports[0])
                    sp2_labels.append(supports[1])
                    sp3_labels.append(supports[2])
                    sp4_labels.append(supports[3])
                    sp5_labels.append(-1)

            # We will append 2 vectors to the front (sp_output_with_end), so we increment indices by 2
            sp1_labels = np.array(sp1_labels) + 2
            sp2_labels = np.array(sp2_labels) + 2
            sp3_labels = np.array(sp3_labels) + 2
            sp4_labels = np.array(sp4_labels) + 2
            sp5_labels = np.array(sp5_labels) + 2
            sp_labels_mod = Variable(
                torch.LongTensor(
                    np.stack([
                        sp1_labels, sp2_labels, sp3_labels, sp4_labels,
                        sp5_labels
                    ], 1)).cuda())

            # After the +2 shift, values > 0 are real (or "end") targets;
            # padding slots (originally -2) become 0 and are masked out.
            sp_mod_mask = sp_labels_mod > 0

            if config.bert:
                logit1, logit2, grn_logit1, grn_logit2, coarse_att, predict_type, predict_support, sp_att_logits = model(
                    context_idxs,
                    ques_idxs,
                    context_char_idxs,
                    ques_char_idxs,
                    context_lens,
                    start_mapping,
                    end_mapping,
                    all_mapping,
                    return_yp=False,
                    support_labels=is_support,
                    is_train=True,
                    sp_labels_mod=sp_labels_mod,
                    bert=True,
                    bert_context=bert_context,
                    bert_ques=bert_ques)
            else:
                logit1, logit2, grn_logit1, grn_logit2, coarse_att, predict_type, predict_support, sp_att_logits = model(
                    context_idxs,
                    ques_idxs,
                    context_char_idxs,
                    ques_char_idxs,
                    context_lens,
                    start_mapping,
                    end_mapping,
                    all_mapping,
                    return_yp=False,
                    support_labels=is_support,
                    is_train=True,
                    sp_labels_mod=sp_labels_mod,
                    bert=False)

            # One attention-logit tensor per reasoning step.
            sp_att1 = sp_att_logits[0].squeeze(-1)
            sp_att2 = sp_att_logits[1].squeeze(-1)
            sp_att3 = sp_att_logits[2].squeeze(-1)
            if config.reasoning_steps == 4:
                sp_att4 = sp_att_logits[3].squeeze(-1)

            # Add masks to targets(labels) to ignore in loss calculation:
            # masked slots are pushed far negative so the loss skips them.
            sp1_labels = ((sp_mod_mask[:, 0].float() - 1) *
                          100).cpu().data.numpy() + sp1_labels
            sp2_labels = ((sp_mod_mask[:, 1].float() - 1) *
                          100).cpu().data.numpy() + sp2_labels
            sp3_labels = ((sp_mod_mask[:, 2].float() - 1) *
                          100).cpu().data.numpy() + sp3_labels
            sp4_labels = ((sp_mod_mask[:, 3].float() - 1) *
                          100).cpu().data.numpy() + sp4_labels

            sp1_loss = nll_average(
                sp_att1, Variable(torch.LongTensor(sp1_labels).cuda()))
            sp2_loss = nll_average(
                sp_att2, Variable(torch.LongTensor(sp2_labels).cuda()))
            sp3_loss = nll_average(
                sp_att3, Variable(torch.LongTensor(sp3_labels).cuda()))
            if config.reasoning_steps == 4:
                sp4_loss = nll_average(
                    sp_att4, Variable(torch.LongTensor(sp4_labels).cuda()))
            else:
                sp4_loss = 0

            GRN_SP_PRED = True  # temporary flag

            # Span loss averaged over all gold mention spans of each item.
            batch_losses = []
            for bid, spans in enumerate(data['gold_mention_spans']):
                if spans:  # skip items with no gold mention spans (None/[])
                    # (Removed a debugging ipdb.set_trace() breakpoint here;
                    # an IndexError now propagates instead of hanging the run.)
                    _loss_start_avg = F.cross_entropy(
                        logit1[bid].view(1, -1).expand(
                            len(spans), len(logit1[bid])),
                        Variable(torch.LongTensor(spans))[:, 0].cuda(),
                        reduce=True)
                    _loss_end_avg = F.cross_entropy(
                        logit2[bid].view(1, -1).expand(
                            len(spans), len(logit2[bid])),
                        Variable(torch.LongTensor(spans))[:, 1].cuda(),
                        reduce=True)
                    batch_losses.append((_loss_start_avg + _loss_end_avg))
            loss_1 = torch.mean(torch.stack(batch_losses)) + nll_sum(
                predict_type, q_type) / context_idxs.size(0)

            # With GRN-based SP prediction, the binary per-sentence loss is
            # replaced by the per-step attention losses above.
            if GRN_SP_PRED:
                loss_2 = 0
            else:
                loss_2 = nll_average(predict_support.view(-1, 2),
                                     is_support.view(-1))
            loss = loss_1 + config.sp_lambda * loss_2 + sp1_loss + sp2_loss + sp3_loss + sp4_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .data[0] indexing targets the pre-0.4 PyTorch API.
            total_loss += loss.data[0]
            global_step += 1
            if global_step % config.period == 0:
                # Report the loss averaged over the logging period.
                cur_loss = total_loss / config.period
                elapsed = time.time() - start_time
                logging(
                    '| epoch {:3d} | step {:6d} | lr {:05.5f} | ms/batch {:5.2f} | train loss {:8.3f}'
                    .format(epoch, global_step,
                            optimizer.param_groups[0]['lr'],
                            elapsed * 1000 / config.period, cur_loss))
                total_loss = 0
                start_time = time.time()

            if global_step % config.checkpoint == 0:
                # Periodic dev evaluation drives checkpointing (and the
                # plateau scheduler, if selected).
                model.eval()
                metrics = evaluate_batch(build_dev_iterator(), model, 0,
                                         dev_eval_file, config)
                model.train()
                logging('-' * 89)
                logging(
                    '| eval {:6d} in epoch {:3d} | time: {:5.2f}s | dev loss {:8.3f} | EM {:.4f} | F1 {:.4f}'
                    .format(global_step // config.checkpoint, epoch,
                            time.time() - eval_start_time, metrics['loss'],
                            metrics['exact_match'], metrics['f1']))
                logging('-' * 89)

                eval_start_time = time.time()

                dev_F1 = metrics['f1']
                if config.scheduler == 'plateau':
                    # ReduceLROnPlateau is driven by the dev metric.
                    scheduler.step(dev_F1)
                if best_dev_F1 is None or dev_F1 > best_dev_F1:
                    # New best: save the raw model's weights.
                    best_dev_F1 = dev_F1
                    torch.save(ori_model.state_dict(),
                               os.path.join(config.save, 'model.pt'))
                    cur_patience = 0
        if stop_train: break
    logging('best_dev_F1 {}'.format(best_dev_F1))
Beispiel #7
0
def train(config):
    """Train a HotpotQA model with Adam, resumable from temp checkpoints.

    If a previous run was interrupted (a ``temp_model`` file exists in the
    working directory), the model and optimizer states are restored before
    training.  The current state is re-saved every logging period so the
    run can always be resumed.  The best model (by dev F1) is kept in
    ``<config.save>/model.pt``; training stops early once dev F1 fails to
    improve for ``config.patience`` consecutive checkpoints.

    Args:
        config: argparse-style namespace with data paths, model hyper-
            parameters, and training schedule settings.
    """
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
    with open(config.idx2word_file, 'r') as fh:
        idx2word_dict = json.load(fh)

    # Seed every RNG for reproducibility.
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    # Each run gets a timestamped experiment directory with a code snapshot.
    config.save = '{}-{}'.format(config.save, time.strftime("%Y%m%d-%H%M%S"))
    create_exp_dir(
        config.save,
        scripts_to_save=['run.py', 'model.py', 'util.py', 'sp_model.py'])

    def logging(s, print_=True, log_=True):
        # Echo to stdout and append to the run's log file.
        if print_:
            print(s)
        if log_:
            with open(os.path.join(config.save, 'log.txt'), 'a+') as f_log:
                f_log.write(s + '\n')

    logging('Config')
    for k, v in config.__dict__.items():
        logging('    - {} : {}'.format(k, v))

    logging("Building model...")
    train_buckets = get_buckets(config.train_record_file)
    dev_buckets = get_buckets(config.dev_record_file)

    def build_train_iterator():
        # shuffle=True for training batches.
        return DataIterator(train_buckets, config.batch_size,
                            config.para_limit, config.ques_limit,
                            config.char_limit, True, config.sent_limit)

    def build_dev_iterator():
        # shuffle=False for deterministic evaluation order.
        return DataIterator(dev_buckets, config.batch_size, config.para_limit,
                            config.ques_limit, config.char_limit, False,
                            config.sent_limit)

    # sp_lambda > 0 enables the supporting-facts auxiliary objective.
    if config.sp_lambda > 0:
        model = SPModel(config, word_mat, char_mat)
    else:
        model = Model(config, word_mat, char_mat)

    logging('nparams {}'.format(
        sum([p.nelement() for p in model.parameters() if p.requires_grad])))
    # Keep a handle on the model for checkpointing (no DataParallel here).
    ori_model = model.cuda()

    # Adam with its default learning rate (no manual LR schedule in this
    # variant).
    optimizer = optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()))

    # If the previous training run was interrupted, restore its last state.
    logging('Checking if previous training was interrupted...')
    if Path('temp_model').exists():
        logging(
            'Previous training was interupted, loading model and optimizer state'
        )
        ori_model.load_state_dict(torch.load('temp_model'))
        optimizer.load_state_dict(torch.load('temp_optimizer'))

    cur_patience = 0
    total_loss = 0
    global_step = 0
    best_dev_F1 = None
    stop_train = False
    start_time = time.time()
    eval_start_time = time.time()
    model.train()

    # Effectively "until early stopping": the patience logic below is the
    # real termination condition, not the epoch count.
    for epoch in range(1000):
        for data in build_train_iterator():
            # Wrap batch tensors for the pre-0.4 PyTorch autograd API.
            context_idxs = Variable(data['context_idxs'])
            ques_idxs = Variable(data['ques_idxs'])
            context_char_idxs = Variable(data['context_char_idxs'])
            ques_char_idxs = Variable(data['ques_char_idxs'])
            context_lens = Variable(data['context_lens'])
            y1 = Variable(data['y1'])
            y2 = Variable(data['y2'])
            q_type = Variable(data['q_type'])
            is_support = Variable(data['is_support'])

            start_mapping = Variable(data['start_mapping'])
            end_mapping = Variable(data['end_mapping'])
            all_mapping = Variable(data['all_mapping'])

            # Extra per-sentence scores consumed by this model variant.
            sentence_scores = Variable(data['sentence_scores'])

            logit1, logit2, predict_type, predict_support = model(
                context_idxs,
                ques_idxs,
                context_char_idxs,
                ques_char_idxs,
                context_lens,
                start_mapping,
                end_mapping,
                all_mapping,
                sentence_scores,
                return_yp=False)

            # Answer loss: span start/end plus question-type classification,
            # normalised by batch size.
            loss_1 = (nll_sum(predict_type, q_type) + nll_sum(logit1, y1) +
                      nll_sum(logit2, y2)) / context_idxs.size(0)
            # Supporting-facts loss over per-sentence binary predictions.
            loss_2 = nll_average(predict_support.view(-1, 2),
                                 is_support.view(-1).long())

            loss = loss_1 + config.sp_lambda * loss_2

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.data.item()
            global_step += 1

            if global_step % config.period == 0:
                # Report the loss averaged over the logging period (the
                # original logged the raw sum, inconsistent with the sibling
                # variants and the per-period ms/batch normalisation).
                cur_loss = total_loss / config.period

                elapsed = time.time() - start_time
                logging(
                    '| epoch {:3d} | step {:6d} | ms/batch {:5.2f} | train loss {:8.3f}'
                    .format(epoch, global_step, elapsed * 1000 / config.period,
                            cur_loss))

                total_loss = 0
                start_time = time.time()

                # Persist resume state so an interrupted run can continue.
                torch.save(ori_model.state_dict(), 'temp_model')
                torch.save(optimizer.state_dict(), 'temp_optimizer')

            if global_step % config.checkpoint == 0:
                # Periodic dev evaluation drives checkpointing and early stop.
                model.eval()
                metrics = evaluate_batch(build_dev_iterator(), model, 0,
                                         dev_eval_file, config)
                model.train()

                logging('-' * 89)
                logging(
                    '| eval {:6d} in epoch {:3d} | time: {:5.2f}s | dev loss {:8.3f} | EM {:.4f} | F1 {:.4f}'
                    .format(global_step // config.checkpoint, epoch,
                            time.time() - eval_start_time, metrics['loss'],
                            metrics['exact_match'], metrics['f1']))
                logging('-' * 89)

                eval_start_time = time.time()

                dev_F1 = metrics['f1']
                if best_dev_F1 is None or dev_F1 > best_dev_F1:
                    # New best: save model and optimizer, reset patience.
                    best_dev_F1 = dev_F1
                    torch.save(ori_model.state_dict(),
                               os.path.join(config.save, 'model.pt'))
                    torch.save(optimizer.state_dict(),
                               os.path.join(config.save, 'optimizer.pt'))
                    cur_patience = 0
                else:
                    cur_patience += 1
                    if cur_patience >= config.patience:
                        stop_train = True
                        break

        if stop_train: break
    logging('best_dev_F1 {}'.format(best_dev_F1))

    # Training completed normally: remove the resume files.  They are only
    # created once the first logging period has elapsed, so guard against
    # their absence instead of crashing with FileNotFoundError.
    print('Deleting last temp model files...')
    for tmp_file in ('temp_model', 'temp_optimizer'):
        if Path(tmp_file).exists():
            os.remove(tmp_file)
Beispiel #8
0
def train():
    """Train a span-extraction QA model (HotpotQA-style) with SGD.

    Loads pre-built embedding matrices and record files referenced by the
    module-level ``Config``, then runs an epoch loop with:
      * a combined loss: answer span + question type, plus (optionally)
        a supporting-fact term weighted by ``Config.sp_lambda``;
      * periodic dev evaluation (every 10 steps);
      * best-dev-F1 checkpointing to ``Config.save/model.pt``;
      * patience-based learning-rate halving with early stopping once the
        LR falls below 1% of ``Config.init_lr``.

    Side effects: reads files from ``Config.*`` paths, writes the best
    checkpoint, and prints progress to stdout. Returns ``None``.
    """
    # 1. Load embedding matrices and the dev-set evaluation metadata.
    with open(Config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(Config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(Config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
    with open(Config.idx2word_file, 'r') as fh:
        idx2word_dict = json.load(fh)

    # Each record file holds one pre-tokenized bucket of examples.
    train_buckets = [torch.load(Config.train_record_file)]
    dev_buckets = [torch.load(Config.dev_record_file)]

    # DataIterator(buckets, bsz, para_limit, ques_limit, char_limit,
    #              shuffle, sent_limit) — shuffle only for training.
    def build_train_iterator():
        return DataIterator(train_buckets, Config.batch_size,
                            Config.para_limit, Config.ques_limit,
                            Config.char_limit, True, Config.sent_limit)

    def build_dev_iterator():
        return DataIterator(dev_buckets, Config.batch_size, Config.para_limit,
                            Config.ques_limit, Config.char_limit, False,
                            Config.sent_limit)

    # Use the supporting-fact-aware model only when its loss term is enabled.
    if Config.sp_lambda > 0:
        model = SPModel(word_mat, char_mat)
    else:
        model = Model(word_mat, char_mat)

    # Report the number of trainable parameters (e.g. 235636).
    print('需要更新的参数量:{}'.format(
        sum(p.nelement() for p in model.parameters() if p.requires_grad)))

    ori_model = model.to(Config.device)

    lr = Config.init_lr
    # Optimize only parameters that require gradients (frozen embeddings
    # are excluded by the filter).
    optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                 model.parameters()),
                          lr=Config.init_lr)

    cur_patience = 0          # consecutive evals without dev-F1 improvement
    global_step = 0
    best_dev_F1 = None        # None until the first evaluation
    stop_train = False
    model.train()
    # Effectively "train until early stopping"; 10000 is just a large cap.
    for epoch in range(10000):
        for data in build_train_iterator():
            global_step += 1
            # Move the whole batch onto the target device.
            context_idxs = data['context_idxs'].to(Config.device)
            ques_idxs = data['ques_idxs'].to(Config.device)
            context_char_idxs = data['context_char_idxs'].to(Config.device)
            ques_char_idxs = data['ques_char_idxs'].to(Config.device)
            context_lens = data['context_lens'].to(Config.device)

            y1 = data['y1'].to(Config.device)                    # answer start
            y2 = data['y2'].to(Config.device)                    # answer end
            q_type = data['q_type'].to(Config.device)            # span/yes/no
            is_support = data['is_support'].to(Config.device)    # SP labels
            start_mapping = data['start_mapping'].to(Config.device)
            end_mapping = data['end_mapping'].to(Config.device)
            all_mapping = data['all_mapping'].to(Config.device)

            logit1, logit2, predict_type, predict_support = model(
                context_idxs,
                ques_idxs,
                context_char_idxs,
                ques_char_idxs,
                context_lens,
                start_mapping,
                end_mapping,
                all_mapping,
                return_yp=False)

            # Answer loss: summed NLL over type + start + end, averaged
            # over the batch.
            loss_1 = (nll_sum(predict_type, q_type) + nll_sum(logit1, y1) +
                      nll_sum(logit2, y2)) / context_idxs.size(0)
            # Supporting-fact loss: mean NLL over all sentence slots.
            loss_2 = nll_average(predict_support.view(-1, 2),
                                 is_support.view(-1))
            loss = loss_1 + Config.sp_lambda * loss_2

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # .item() extracts the Python float from the 0-dim loss tensor.
            print(
                '| epoch {:3d} | step {:6d} | lr {:05.5f} | train loss {:8.3f}'
                .format(epoch, global_step, lr, loss.item()))

            if global_step % 10 == 0:
                model.eval()
                metrics = evaluate_batch(build_dev_iterator(), model, 0,
                                         dev_eval_file)
                model.train()

                # BUG FIX: the original format string had three placeholders
                # but four arguments, so "dev loss" printed the epoch, "EM"
                # printed the loss, "F1" printed EM, and the real F1 was
                # silently dropped. Add the epoch placeholder.
                print('| epoch {:3d} | dev loss {:8.3f} | EM {:.4f} | F1 {:.4f}'
                      .format(epoch, metrics['loss'], metrics['exact_match'],
                              metrics['f1']))

                dev_F1 = metrics['f1']
                if best_dev_F1 is None or dev_F1 > best_dev_F1:
                    # New best: checkpoint and reset patience.
                    best_dev_F1 = dev_F1
                    torch.save(ori_model.state_dict(),
                               os.path.join(Config.save, 'model.pt'))
                    cur_patience = 0
                else:
                    cur_patience += 1
                    if cur_patience >= Config.patience:
                        # Patience exhausted: halve the LR; stop entirely
                        # once it drops below 1% of the initial LR.
                        lr /= 2.0
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                        if lr < Config.init_lr * 1e-2:
                            stop_train = True
                            break
                        cur_patience = 0
        if stop_train:
            break
    print('best_dev_F1 {}'.format(best_dev_F1))