Example #1
def test(config):
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    if config.data_split == 'dev':
        with open(config.dev_eval_file, "r") as fh:
            dev_eval_file = json.load(fh)
    else:
        with open(config.test_eval_file, 'r') as fh:
            dev_eval_file = json.load(fh)
    with open(config.idx2word_file, 'r') as fh:
        idx2word_dict = json.load(fh)

    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    if config.cuda:
        torch.cuda.manual_seed_all(config.seed)

    def logging(s, print_=True, log_=True):
        if print_:
            print(s)
        if log_:
            with open(os.path.join(config.save, 'log.txt'), 'a+') as f_log:
                f_log.write(s + '\n')

    if config.data_split == 'dev':
        dev_buckets = get_buckets(config.dev_record_file)
        para_limit = config.para_limit
        ques_limit = config.ques_limit
    elif config.data_split == 'test':
        para_limit = None
        ques_limit = None
        dev_buckets = get_buckets(config.test_record_file)

    def build_dev_iterator():
        dev_dataset = HotpotDataset(dev_buckets)
        return DataIterator(dev_dataset,
                            config.para_limit,
                            config.ques_limit,
                            config.char_limit,
                            config.sent_limit,
                            batch_size=config.batch_size,
                            num_workers=2)

    if config.sp_lambda > 0:
        model = SPModel(config, word_mat, char_mat)
    else:
        model = Model(config, word_mat, char_mat)
    ori_model = model.cuda() if config.cuda else model
    ori_model.load_state_dict(
        torch.load(os.path.join(config.save, 'model.pt'),
                   map_location=lambda storage, loc: storage))
    #    model = nn.DataParallel(ori_model)
    model = ori_model

    model.eval()
    predict(build_dev_iterator(), model, dev_eval_file, config,
            config.prediction_file)
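
These test/train functions expect a config object whose attributes point at the preprocessed files and set the model limits. A minimal, hypothetical driver for the test() above, built with argparse (attribute names are taken from the code; every default value is a placeholder, not part of the original project):

import argparse

parser = argparse.ArgumentParser()
# File paths read by test(); defaults are placeholders.
parser.add_argument('--word_emb_file', default='word_emb.json')
parser.add_argument('--char_emb_file', default='char_emb.json')
parser.add_argument('--dev_eval_file', default='dev_eval.json')
parser.add_argument('--test_eval_file', default='test_eval.json')
parser.add_argument('--idx2word_file', default='idx2word.json')
parser.add_argument('--dev_record_file', default='dev_record.pkl')
parser.add_argument('--test_record_file', default='test_record.pkl')
parser.add_argument('--prediction_file', default='pred.json')
parser.add_argument('--save', default='HOTPOT-checkpoint')
# Split, seeding and model limits referenced above.
parser.add_argument('--data_split', default='dev', choices=['dev', 'test'])
parser.add_argument('--seed', type=int, default=13)
parser.add_argument('--cuda', action='store_true')
parser.add_argument('--sp_lambda', type=float, default=1.0)
parser.add_argument('--batch_size', type=int, default=24)
parser.add_argument('--para_limit', type=int, default=2250)
parser.add_argument('--ques_limit', type=int, default=80)
parser.add_argument('--char_limit', type=int, default=16)
parser.add_argument('--sent_limit', type=int, default=100)
config = parser.parse_args()

test(config)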
Example #2
def test(config):
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    if config.data_split == 'dev':
        with open(config.dev_eval_file, "r") as fh:
            dev_eval_file = json.load(fh)
    else:
        with open(config.test_eval_file, 'r') as fh:
            dev_eval_file = json.load(fh)
    with open(config.idx2word_file, 'r') as fh:
        idx2word_dict = json.load(fh)
    with open(config.relation_emb_file, "r") as fh:
        relation_mat = np.array(json.load(fh), dtype=np.float32)  # (20, 100)
    with open(config.idx2relation_file, 'r') as fh:
        idx2relation_dict = json.load(fh)

    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    def logging(s, print_=True, log_=True):
        if print_:
            print(s)
        if log_:
            with open(os.path.join(config.save, 'log.txt'), 'a+') as f_log:
                f_log.write(s + '\n')

    if config.data_split == 'dev':
        dev_buckets = get_buckets(config.dev_record_file)
        para_limit = config.para_limit
        ques_limit = config.ques_limit
    elif config.data_split == 'test':
        para_limit = None
        ques_limit = None
        dev_buckets = get_buckets(config.test_record_file)

    def build_dev_iterator():
        return DataIterator(dev_buckets, config.batch_size, para_limit,
                            ques_limit, config.char_limit, False,
                            config.sent_limit, config.num_relations)

    model = Model(config, word_mat, char_mat, relation_mat)

    ori_model = model.cuda()
    ori_model.load_state_dict(torch.load(os.path.join(config.save,
                                                      'model.pt')))
    model = nn.DataParallel(ori_model)

    model.eval()
    predict(build_dev_iterator(), model, dev_eval_file, config,
            config.prediction_file, idx2relation_dict)
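
get_buckets, HotpotDataset and DataIterator come from the surrounding project and are not shown in these examples. For reference, a sketch of get_buckets consistent with how it is used above, assuming the preprocessed record file was written with torch.save:

import torch

def get_buckets(record_file):
    # Assumed: the record file holds a single torch-saved list of examples,
    # returned as a one-element list of "buckets".
    datapoints = torch.load(record_file)
    return [datapoints]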
Example #3
def test(config):
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    if config.data_split == 'dev':
        with open(config.dev_eval_file, "r") as fh:
            dev_eval_file = json.load(fh)
    else:
        with open(config.test_eval_file, 'r') as fh:
            dev_eval_file = json.load(fh)
    with open(config.idx2word_file, 'r') as fh:
        idx2word_dict = json.load(fh)

    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    def logging(s, print_=True, log_=True):
        if print_:
            print(s)
        if log_:
            with open(os.path.join(config.save, 'log.txt'), 'a+') as f_log:
                f_log.write(s + '\n')

    if config.data_split == 'dev':
#         For now we use the dev split
#         (the exact meaning of data_split is unclear)
        dev_buckets = get_buckets(config.dev_record_file)
        para_limit = config.para_limit
        ques_limit = config.ques_limit
    elif config.data_split == 'test':
        para_limit = None
        ques_limit = None
        dev_buckets = get_buckets(config.test_record_file)

    def build_dev_iterator():
        return DataIterator(dev_buckets, config.batch_size, para_limit,
            ques_limit, config.char_limit, False, config.sent_limit)
Example #4
def train(use_attention, num_steps=1000, ckpt_dir="./ckp-dir/", write_summary=True, tag=None):
    test_buckets, data_buckets, train_buckets_scale = get_buckets()

    if not use_attention:
        model = BasicChatBotModel(batch_size=config.BATCH_SIZE)
    else:
        model = AttentionChatBotModel(batch_size=config.BATCH_SIZE)
    model.build()

    cfg = tf.ConfigProto()
    cfg.gpu_options.allow_growth = True
    with tf.Session(config=cfg) as sess:
        saver = tf.train.Saver()
        log_root = "./logs/"
        exp_name = (("attention" if use_attention else "basic") +
                    "-step_" + str(num_steps) +
                    "-batch_" + str(config.BATCH_SIZE) +
                    "-lr_" + str(config.LR))
        if tag:
            exp_name += "-" + tag
        summary_writer = tf.summary.FileWriter(log_root + exp_name, graph=sess.graph)
        sess.run(tf.global_variables_initializer())
        for step in range(num_steps + 1):
            bucket_id = _get_random_bucket(train_buckets_scale)
            encoder_inputs, decoder_inputs, decoder_masks = data.get_batch(
                data_buckets[bucket_id], bucket_id, batch_size=config.BATCH_SIZE)
            decoder_lens = np.sum(np.transpose(np.array(decoder_masks), (1, 0)), axis=1)
            feed_dict = {model.encoder_inputs_tensor: encoder_inputs, model.decoder_inputs_tensor: decoder_inputs,
                         model.decoder_length_tensor: decoder_lens,
                         model.bucket_length: config.BUCKETS[bucket_id]}
            output_logits, res_loss, _ = sess.run([model.final_outputs, model.loss, model.train_op],
                                                  feed_dict=feed_dict)

            if need_print_log(step):
                print("Iteration {} - loss:{}".format(step, res_loss))
                if write_summary:
                    summaries = sess.run(model.summaries, feed_dict=feed_dict)
                    summary_writer.add_summary(summaries, step)
                saver.save(sess, ckpt_dir + exp_name + "/checkpoints", global_step=step)
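
_get_random_bucket and need_print_log are helpers defined elsewhere. A plausible sketch, assuming train_buckets_scale holds cumulative data fractions per bucket and an assumed fixed logging interval:

import random

def _get_random_bucket(train_buckets_scale):
    # Sample a bucket id with probability proportional to its share of the data,
    # using the cumulative fractions in train_buckets_scale.
    rand = random.random()
    return min(i for i in range(len(train_buckets_scale))
               if train_buckets_scale[i] > rand)

def need_print_log(step):
    # Assumed schedule: log and checkpoint every 100 training steps.
    return step % 100 == 0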
Example #5
def train(config):
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
    with open(config.idx2word_file, 'r') as fh:
        idx2word_dict = json.load(fh)

    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    if config.cuda:
        torch.cuda.manual_seed_all(config.seed)

    config.save = '{}-{}'.format(config.save, time.strftime("%Y%m%d-%H%M%S"))
    create_exp_dir(
        config.save,
        scripts_to_save=['run.py', 'model.py', 'util.py', 'sp_model.py'])

    def logging(s, print_=True, log_=True):
        if print_:
            print(s)
        if log_:
            with open(os.path.join(config.save, 'log.txt'), 'a+') as f_log:
                f_log.write(s + '\n')

    logging('Config')
    for k, v in config.__dict__.items():
        logging('    - {} : {}'.format(k, v))

    logging("Building model...")
    train_buckets = get_buckets(config.train_record_file)
    dev_buckets = get_buckets(config.dev_record_file)

    def build_train_iterator():
        train_dataset = HotpotDataset(train_buckets)
        return DataIterator(train_dataset,
                            config.para_limit,
                            config.ques_limit,
                            config.char_limit,
                            config.sent_limit,
                            batch_size=config.batch_size,
                            sampler=RandomSampler(train_dataset),
                            num_workers=2)

    def build_dev_iterator():
        dev_dataset = HotpotDataset(dev_buckets)
        return DataIterator(dev_dataset,
                            config.para_limit,
                            config.ques_limit,
                            config.char_limit,
                            config.sent_limit,
                            batch_size=config.batch_size,
                            num_workers=2)

    if config.sp_lambda > 0:
        model = SPModel(config, word_mat, char_mat)
    else:
        model = Model(config, word_mat, char_mat)

    logging('nparams {}'.format(
        sum([p.nelement() for p in model.parameters() if p.requires_grad])))
    ori_model = model.cuda() if config.cuda else model
    model = nn.DataParallel(ori_model)

    lr = config.init_lr
    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=config.init_lr)
    cur_patience = 0
    total_loss = 0
    global_step = 0
    best_dev_F1 = None
    stop_train = False
    start_time = time.time()
    eval_start_time = time.time()
    model.train()

    train_iterator = build_train_iterator()
    dev_iterator = build_dev_iterator()

    for epoch in range(10000):
        for data in train_iterator:
            if config.cuda:
                data = {
                    k: (data[k].cuda() if k != 'ids' else data[k])
                    for k in data
                }
            context_idxs = data['context_idxs']
            ques_idxs = data['ques_idxs']
            context_char_idxs = data['context_char_idxs']
            ques_char_idxs = data['ques_char_idxs']
            context_lens = data['context_lens']
            y1 = data['y1']
            y2 = data['y2']
            q_type = data['q_type']
            is_support = data['is_support']
            start_mapping = data['start_mapping']
            end_mapping = data['end_mapping']
            all_mapping = data['all_mapping']

            logit1, logit2, predict_type, predict_support = model(
                context_idxs,
                ques_idxs,
                context_char_idxs,
                ques_char_idxs,
                context_lens,
                start_mapping,
                end_mapping,
                all_mapping,
                context_lens.sum(1).max().item(),
                return_yp=False)
            loss_1 = (nll_sum(predict_type, q_type) + nll_sum(logit1, y1) +
                      nll_sum(logit2, y2)) / context_idxs.size(0)
            loss_2 = nll_average(predict_support.view(-1, 2),
                                 is_support.view(-1))
            loss = loss_1 + config.sp_lambda * loss_2

            optimizer.zero_grad()
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                config.max_grad_norm if config.max_grad_norm > 0 else 1e10)
            optimizer.step()

            total_loss += loss.item()
            global_step += 1

            if global_step % config.period == 0:
                cur_loss = total_loss / config.period
                elapsed = time.time() - start_time
                logging(
                    '| epoch {:3d} | step {:6d} | lr {:05.5f} | ms/batch {:5.2f} | train loss {:8.3f} | gradnorm: {:6.3}'
                    .format(epoch, global_step, lr,
                            elapsed * 1000 / config.period, cur_loss,
                            grad_norm))
                total_loss = 0
                start_time = time.time()

            if global_step % config.checkpoint == 0:
                model.eval()
                metrics = evaluate_batch(dev_iterator, model, 0, dev_eval_file,
                                         config)
                model.train()

                logging('-' * 89)
                logging(
                    '| eval {:6d} in epoch {:3d} | time: {:5.2f}s | dev loss {:8.3f} | EM {:.4f} | F1 {:.4f}'
                    .format(global_step // config.checkpoint, epoch,
                            time.time() - eval_start_time, metrics['loss'],
                            metrics['exact_match'], metrics['f1']))
                logging('-' * 89)

                eval_start_time = time.time()

                dev_F1 = metrics['f1']
                if best_dev_F1 is None or dev_F1 > best_dev_F1:
                    best_dev_F1 = dev_F1
                    torch.save(ori_model.state_dict(),
                               os.path.join(config.save, 'model.pt'))
                    cur_patience = 0
                else:
                    cur_patience += 1
                    if cur_patience >= config.patience:
                        lr /= 2.0
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                        if lr < config.init_lr * 1e-2:
                            stop_train = True
                            break
                        cur_patience = 0
        if stop_train: break
    logging('best_dev_F1 {}'.format(best_dev_F1))
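
nll_sum and nll_average are criteria defined outside these snippets. A plausible definition, assuming sum- and mean-reduced cross-entropy over logits with an ignore index for padded labels:

import torch.nn as nn

IGNORE_INDEX = -100  # assumed label value used for padded positions

# Summed over the batch; callers divide by context_idxs.size(0) themselves.
nll_sum = nn.CrossEntropyLoss(reduction='sum', ignore_index=IGNORE_INDEX)
# Averaged over all (prediction, label) pairs, e.g. per-sentence support labels.
nll_average = nn.CrossEntropyLoss(reduction='mean', ignore_index=IGNORE_INDEX)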
Example #6
def train(config):
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
    with open(config.idx2word_file, 'r') as fh:
        idx2word_dict = json.load(fh)

    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    config.save = '{}-{}'.format(config.save, time.strftime("%Y%m%d-%H%M%S"))
    create_exp_dir(
        config.save,
        scripts_to_save=['run.py', 'model.py', 'util.py', 'sp_model.py'])

    def logging(s, print_=True, log_=True):
        if print_:
            print(s)
        if log_:
            with open(os.path.join(config.save, 'log.txt'), 'a+') as f_log:
                f_log.write(s + '\n')

    logging('Config')
    for k, v in config.__dict__.items():
        logging('    - {} : {}'.format(k, v))

    logging("Building model...")
    train_buckets = get_buckets(config.train_record_file)
    dev_buckets = get_buckets(config.dev_record_file)

    def build_train_iterator():
        return DataIterator(train_buckets, config.batch_size,
                            config.para_limit, config.ques_limit,
                            config.char_limit, True, config.sent_limit)

    def build_dev_iterator():
        return DataIterator(dev_buckets, config.batch_size, config.para_limit,
                            config.ques_limit, config.char_limit, False,
                            config.sent_limit)

    if config.sp_lambda > 0:
        model = SPModel(config, word_mat, char_mat)
    else:
        model = Model(config, word_mat, char_mat)

    logging('nparams {}'.format(
        sum([p.nelement() for p in model.parameters() if p.requires_grad])))
    #ori_model = model.cuda()
    #ori_model = model
    #model = nn.DataParallel(ori_model)
    model = model.cuda()

    lr = config.init_lr
    optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                 model.parameters()),
                          lr=config.init_lr)
    cur_patience = 0
    total_loss = 0
    global_step = 0
    best_dev_F1 = None
    stop_train = False
    start_time = time.time()
    eval_start_time = time.time()
    model.train()

    train_metrics = {'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0}

    best_dev_sp_f1 = 0

    # total_support_facts = 0
    # total_contexes = 0
    total_support_facts = 0
    total_contexes = 0

    for epoch in range(5):
        for data in build_train_iterator():
            context_idxs = Variable(data['context_idxs'])
            ques_idxs = Variable(data['ques_idxs'])
            context_char_idxs = Variable(data['context_char_idxs'])
            ques_char_idxs = Variable(data['ques_char_idxs'])
            context_lens = Variable(data['context_lens'])

            y1 = Variable(data['y1'])
            y2 = Variable(data['y2'])
            q_type = Variable(data['q_type'])

            is_support = Variable(data['is_support'])

            start_mapping = Variable(data['start_mapping'])
            end_mapping = Variable(data['end_mapping'])
            all_mapping = Variable(data['all_mapping'])

            # print(all_mapping.size())

            # total_support_facts += torch.sum(torch.sum(is_support))
            # total_contexes += is_support.size(0) * is_support.size(1)

            # context_idxs : torch.Size([1, 1767])
            # ques_idxs : torch.Size([1, 17])
            # context_char_idxs : torch.Size([1, 1767, 16])
            # ques_char_idxs : torch.Size([1, 17, 16])
            # context_lens : tensor([1767])

            # start_mapping : torch.Size([1, 1767, 65])
            # end_mapping : torch.Size([1, 1767, 65])
            # all_mapping : torch.Size([1, 1767, 65])

            # continue

            # change the input format into Sentences (input) -> Is support fact (target)

            # get total number of sentences
            number_of_sentences = int(
                torch.sum(torch.sum(start_mapping, dim=1)).item())
            #print('number_of_sentences=', number_of_sentences)
            # get sentence limit
            sentence_limit = config.sent_limit
            # get question limit
            ques_limit = config.ques_limit
            # get character limit
            char_limit = config.char_limit
            sent_limit = 600

            # allocate space
            sentence_idxs = torch.zeros(number_of_sentences,
                                        sent_limit,
                                        dtype=torch.long).cuda()
            sentence_char_idxs = torch.zeros(number_of_sentences,
                                             sent_limit,
                                             char_limit,
                                             dtype=torch.long).cuda()
            sentence_lengths = torch.zeros(number_of_sentences,
                                           dtype=torch.long).cuda()
            is_support_fact = torch.zeros(number_of_sentences,
                                          dtype=torch.long).cuda()

            ques_idxs_sen_impl = torch.zeros(number_of_sentences,
                                             ques_limit,
                                             dtype=torch.long).cuda()
            ques_char_idxs_sen_impl = torch.zeros(number_of_sentences,
                                                  ques_limit,
                                                  char_limit,
                                                  dtype=torch.long).cuda()

            # sentence_idxs = []
            # sentence_char_idxs = []
            # sentence_lengths = []
            # is_support_fact = []

            # ques_idxs_sen_impl = []
            # ques_char_idxs_sen_impl = []

            index = 0

            # for every batch
            for b in range(all_mapping.size(0)):
                # for every sentence
                for i in range(all_mapping.size(2)):
                    # get sentence map
                    sentence_map = all_mapping[b, :, i]
                    s = torch.sum(sentence_map)
                    # if this sentence slot contains only zero-padding, skip to the next slot
                    if s == 0:
                        continue

                    # get sentence

                    # get starting index
                    starting_index = torch.argmax(start_mapping[b, :, i])
                    # get ending index
                    ending_index = torch.argmax(end_mapping[b, :, i]) + 1
                    # get sentence length
                    sentence_length = int(torch.sum(all_mapping[b, :, i]))
                    sentence = context_idxs[b, starting_index:ending_index]
                    sentence_chars = context_char_idxs[
                        b, starting_index:ending_index, :]
                    #print('sentence=', sentence)

                    #if sentence_length>100:
                    #    print('Sentence starts :', starting_index, ', & end :', ending_index, ', Total tokens sentence :', torch.sum(all_mapping[b,:,i]), 'sentence_length=', sentence_length, 'start mapping=',start_mapping[b,:,i], 'end mapping=', end_mapping[b,:,i])
                    #    os.system("pause")

                    sentence_idxs[index, :sentence_length] = sentence
                    sentence_char_idxs[
                        index, :sentence_length, :] = sentence_chars
                    sentence_lengths[index] = sentence_length
                    is_support_fact[index] = is_support[b, i]
                    # repeat for the question

                    ques_idxs_sen_impl[index, :ques_idxs[
                        b, :].size(0)] = ques_idxs[b, :]
                    ques_char_idxs_sen_impl[index, :ques_idxs[
                        b, :].size(0), :] = ques_char_idxs[b, :, :]
                    # append to lists
                    # sentence_idxs.append(sentence)
                    # sentence_char_idxs.append(sentence_chars)
                    # sentence_lengths.append(sentence_length)
                    # is_support_fact.append(is_support[b,i])

                    # repeat for the question
                    # ques_idxs_sen_impl.append(ques_idxs[b,:])
                    # ques_char_idxs_sen_impl.append(ques_char_idxs[b,:,:])

                    index += 1

            # zero-pad width: length of the longest sentence in the batch
            sentence_length = torch.max(sentence_lengths)

            # torch.Tensor()

            # for i in range(len(sentence_idxs)):

            # sentence_idxs = torch.stack(sentence_idxs)
            # sentence_char_idxs = torch.stack(sentence_char_idxs)
            # sentence_lengths = torch.stack(sentence_lengths)
            # is_support_fact = torch.stack(is_support_fact)

            # ques_idxs_sen_impl = torch.stack(ques_idxs_sen_impl)
            # ques_char_idxs_sen_impl = torch.stack(ques_char_idxs_sen_impl)

            # predict_support = model(context_idxs, ques_idxs, context_char_idxs, ques_char_idxs, context_lens, start_mapping, end_mapping, all_mapping, return_yp=False)

            predict_support = model(sentence_idxs, ques_idxs_sen_impl,
                                    sentence_char_idxs,
                                    ques_char_idxs_sen_impl, sentence_length,
                                    start_mapping, end_mapping)

            # logit1, logit2, predict_type, predict_support = model(context_idxs, ques_idxs, context_char_idxs, ques_char_idxs, context_lens, start_mapping, end_mapping, all_mapping, return_yp=False)
            # loss_1 = (nll_sum(predict_type, q_type) + nll_sum(logit1, y1) + nll_sum(logit2, y2)) / context_idxs.size(0)
            #loss_2 = nll_average(predict_support.view(-1, 2), is_support.view(-1))
            #print('predict_support sz=',predict_support.size(), 'is_support_fact sz=', is_support_fact.size(), 'is_support_fact.unsqueeze=', is_support_fact.unsqueeze(1).size())
            loss_2 = nll_average(predict_support.contiguous(),
                                 is_support_fact.contiguous())
            # loss = loss_1 + config.sp_lambda * loss_2
            loss = loss_2

            # update train metrics
            #            train_metrics = update_sp(train_metrics, predict_support.view(-1, 2), is_support.view(-1))
            train_metrics = update_sp(train_metrics, predict_support,
                                      is_support_fact)

            #exit()

            # ps = predict_support.view(-1, 2)
            # iss = is_support.view(-1)

            # print('Predicted SP output  and ground truth:')
            # length = predict_support.view(-1, 2).shape[0]
            # for jj in range(length):
            #     print(ps[jj,1] , '   :   ', iss[jj])

            # temp = torch.cat([ predict_support.view(-1, 2).float(), is_support.view(-1)], dim=-1).contiguous()
            # print(temp)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.data.item()
            global_step += 1

            if global_step % config.period == 0:

                # average metrics over the logging period
                for key in train_metrics:
                    train_metrics[key] /= float(config.period)

                # cur_loss = total_loss / config.period
                cur_loss = total_loss

                elapsed = time.time() - start_time
                logging(
                    '| epoch {:3d} | step {:6d} | lr {:05.5f} | ms/batch {:5.2f} | train loss {:8.3f} | SP EM {:8.3f} | SP f1 {:8.3f} | SP Prec {:8.3f} | SP Recall {:8.3f}'
                    .format(epoch, global_step, lr,
                            elapsed * 1000 / config.period, cur_loss,
                            train_metrics['sp_em'], train_metrics['sp_f1'],
                            train_metrics['sp_prec'],
                            train_metrics['sp_recall']))

                total_loss = 0
                train_metrics = {
                    'sp_em': 0,
                    'sp_f1': 0,
                    'sp_prec': 0,
                    'sp_recall': 0
                }
                start_time = time.time()

            if global_step % config.checkpoint == 0:
                model.eval()

                # metrics = evaluate_batch(build_dev_iterator(), model, 5, dev_eval_file, config)
                eval_metrics = evaluate_batch(build_dev_iterator(), model, 500,
                                              dev_eval_file, config)
                model.train()

                logging('-' * 89)
                # logging('| eval {:6d} in epoch {:3d} | time: {:5.2f}s | dev loss {:8.3f} | EM {:.4f} | F1 {:.4f}'.format(global_step//config.checkpoint,
                #     epoch, time.time()-eval_start_time, metrics['loss'], metrics['exact_match'], metrics['f1']))
                logging(
                    '| eval {:6d} in epoch {:3d} | time: {:5.2f}s | dev loss {:8.3f} | SP EM {:8.3f} | SP f1 {:8.3f} | SP Prec {:8.3f} | SP Recall {:8.3f}'
                    .format(global_step // config.checkpoint, epoch,
                            time.time() - eval_start_time,
                            eval_metrics['loss'], eval_metrics['sp_em'],
                            eval_metrics['sp_f1'], eval_metrics['sp_prec'],
                            eval_metrics['sp_recall']))
                logging('-' * 89)

                if eval_metrics['sp_f1'] > best_dev_sp_f1:
                    best_dev_sp_f1 = eval_metrics['sp_f1']
                    torch.save(model.state_dict(),
                               os.path.join(config.save, 'model.pt'))
                    cur_patience = 0
                else:
                    cur_patience += 1
                    if cur_patience >= config.patience:
                        lr *= 0.75
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                        if lr < config.init_lr * 1e-2:
                            stop_train = True
                            break
                        cur_patience = 0

                eval_start_time = time.time()

        total_support_facts += torch.sum(torch.sum(is_support))
        total_contexes += is_support.size(0) * is_support.size(1)

        # print('total_support_facts :', total_support_facts)
        # print('total_contexes :', total_contexes)
        # exit()

        if stop_train: break
    logging('best_dev_F1 {}'.format(best_dev_sp_f1))
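
update_sp, which accumulates supporting-fact metrics above, is also defined elsewhere. A minimal sketch of what it might compute, assuming (N, 2) per-sentence logits and binary gold labels:

def update_sp(metrics, predict_support, is_support):
    # Hypothetical sketch: per-batch EM / F1 / precision / recall for
    # supporting-fact prediction, accumulated into the metrics dict.
    pred = predict_support.argmax(dim=-1)
    gold = is_support
    tp = ((pred == 1) & (gold == 1)).sum().item()
    fp = ((pred == 1) & (gold == 0)).sum().item()
    fn = ((pred == 0) & (gold == 1)).sum().item()
    prec = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0
    metrics['sp_em'] += float(fp == 0 and fn == 0)
    metrics['sp_f1'] += f1
    metrics['sp_prec'] += prec
    metrics['sp_recall'] += recall
    return metrics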
Example #7
def train(config):
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.relation_emb_file, "r") as fh:
        relation_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
    with open(config.idx2word_file, 'r') as fh:
        idx2word_dict = json.load(fh)

    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    config.save = '{}-{}'.format(config.save, time.strftime("%Y%m%d-%H%M%S"))
    create_exp_dir(config.save,
                   scripts_to_save=['run.py', 'model.py', 'util.py'])

    def logging(s, print_=True, log_=True):
        if print_:
            print(s)
        if log_:
            with open(os.path.join(config.save, 'log.txt'), 'a+') as f_log:
                f_log.write(s + '\n')

    logging('Config')
    for k, v in config.__dict__.items():
        logging('    - {} : {}'.format(k, v))

    logging("Building model...")
    train_buckets = get_buckets(config.train_record_file)
    dev_buckets = get_buckets(config.dev_record_file)

    def build_train_iterator():
        return DataIterator(train_buckets, config.batch_size, config.para_limit, config.ques_limit, \
            config.char_limit, True, config.sent_limit, config.num_relations)

    def build_dev_iterator():
        return DataIterator(dev_buckets, config.batch_size, config.para_limit, config.ques_limit, \
            config.char_limit, False, config.sent_limit, config.num_relations)

    model = Model(config, word_mat, char_mat, relation_mat)

    logging('nparams {}'.format(
        sum([p.nelement() for p in model.parameters() if p.requires_grad])))
    ori_model = model.cuda()
    model = nn.DataParallel(ori_model)

    lr = config.init_lr
    optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                 model.parameters()),
                          lr=config.init_lr)
    cur_patience = 0
    total_loss = 0
    global_step = 0
    best_dev_F1 = None
    stop_train = False
    start_time = time.time()
    eval_start_time = time.time()
    model.train()

    for epoch in range(10000):
        for data in build_train_iterator():
            context_idxs = Variable(data['context_idxs'])
            ques_idxs = Variable(data['ques_idxs'])
            context_char_idxs = Variable(data['context_char_idxs'])
            ques_char_idxs = Variable(data['ques_char_idxs'])
            #
            context_lens = Variable(data['context_lens'])
            y1 = Variable(data['y1'])
            y2 = Variable(data['y2'])
            q_type = Variable(data['q_type'])
            is_support = Variable(data['is_support'])
            start_mapping = Variable(data['start_mapping'])
            end_mapping = Variable(data['end_mapping'])
            all_mapping = Variable(data['all_mapping'])
            #
            subject_y1 = Variable(data['subject_y1'])
            subject_y2 = Variable(data['subject_y2'])
            object_y1 = Variable(data['object_y1'])
            object_y2 = Variable(data['object_y2'])
            relations = Variable(data['relations'])
            #
            #
            model_results = model(context_idxs, ques_idxs, context_char_idxs, ques_char_idxs, relations, \
                context_lens, start_mapping, end_mapping, all_mapping, return_yp=False)
            #
            (logit1, logit2, predict_type, predict_support, logit_subject_start, logit_subject_end, \
                logit_object_start, logit_object_end, k_relations, loss_relation) = model_results
            loss_1 = (nll_sum(predict_type, q_type) + nll_sum(logit1, y1) +
                      nll_sum(logit2, y2)) / context_idxs.size(0)
            loss_2 = nll_average(predict_support.view(-1, 2),
                                 is_support.view(-1))
            loss_3_r = torch.sum(loss_relation)
            loss_3_s = (nll_sum(logit_subject_start, subject_y1) + nll_sum(
                logit_subject_end, subject_y2)) / context_idxs.size(0)
            loss_3_o = (nll_sum(logit_object_start, object_y1) + nll_sum(
                logit_object_end, object_y2)) / context_idxs.size(0)
            #
            loss = loss_1 + config.sp_lambda * loss_2 + config.evi_lambda * (
                loss_3_s + loss_3_r + loss_3_o)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()  # total_loss += loss.data[0]
            global_step += 1

            if global_step % config.period == 0:
                cur_loss = total_loss / config.period
                elapsed = time.time() - start_time
                logging(
                    '| epoch {:3d} | step {:6d} | lr {:05.5f} | ms/batch {:5.2f} | train loss {:8.3f}'
                    .format(epoch, global_step, lr,
                            elapsed * 1000 / config.period, cur_loss))
                total_loss = 0
                start_time = time.time()

            if global_step % config.checkpoint == 0:
                model.eval()
                metrics = evaluate_batch(build_dev_iterator(), model, 0,
                                         dev_eval_file, config)
                model.train()

                logging('-' * 89)
                logging(
                    '| eval {:6d} in epoch {:3d} | time: {:5.2f}s | dev loss {:8.3f} | EM {:.4f} | F1 {:.4f}'
                    .format(global_step // config.checkpoint, epoch,
                            time.time() - eval_start_time, metrics['loss'],
                            metrics['exact_match'], metrics['f1']))
                logging('-' * 89)

                eval_start_time = time.time()

                dev_F1 = metrics['f1']
                if best_dev_F1 is None or dev_F1 > best_dev_F1:
                    best_dev_F1 = dev_F1
                    torch.save(ori_model.state_dict(),
                               os.path.join(config.save, 'model.pt'))
                    cur_patience = 0
                else:
                    cur_patience += 1
                    if cur_patience >= config.patience:
                        lr /= 2.0
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                        if lr < config.init_lr * 1e-2:
                            stop_train = True
                            break
                        cur_patience = 0
        if stop_train: break
    logging('best_dev_F1 {}'.format(best_dev_F1))
Example #8
def train(config):
    experiment = Experiment(api_key="Q8LzfxMlAfA3ABWwq9fJDoR6r",
                            project_name="hotpot",
                            workspace="fan-luo")
    experiment.set_name(config.run_name)

    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
    with open(config.idx2word_file, 'r') as fh:
        idx2word_dict = json.load(fh)

    config.save = '{}-{}'.format(config.save, time.strftime("%Y%m%d-%H%M%S"))
    create_exp_dir(
        config.save,
        scripts_to_save=['run.py', 'model.py', 'util.py', 'sp_model.py'])

    def logging(s, print_=True, log_=True):
        if print_:
            print(s)
        if log_:
            with open(os.path.join(config.save, 'log.txt'), 'a+') as f_log:
                f_log.write(s + '\n')

    logging('Config')
    for k, v in config.__dict__.items():
        logging('    - {} : {}'.format(k, v))

    logging("Building model...")
    train_buckets = get_buckets(config.train_record_file)
    dev_buckets = get_buckets(config.dev_record_file)

    def build_train_iterator():
        return DataIterator(train_buckets, config.batch_size,
                            config.para_limit, config.ques_limit,
                            config.char_limit, True, config.sent_limit)

    def build_dev_iterator():
        return DataIterator(dev_buckets, config.batch_size, config.para_limit,
                            config.ques_limit, config.char_limit, False,
                            config.sent_limit)

    if config.sp_lambda > 0:
        model = SPModel(config, word_mat, char_mat)
    else:
        model = Model(config, word_mat, char_mat)

    logging('nparams {}'.format(
        sum([p.nelement() for p in model.parameters() if p.requires_grad])))
    ori_model = model.cuda()
    model = nn.DataParallel(ori_model)
    print("next(model.parameters()).is_cuda: " +
          str(next(model.parameters()).is_cuda))
    print("next(ori_model.parameters()).is_cuda: " +
          str(next(ori_model.parameters()).is_cuda))

    lr = config.init_lr
    optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                 model.parameters()),
                          lr=config.init_lr)
    cur_patience = 0
    total_loss = 0
    total_ans_loss = 0
    total_sp_loss = 0
    global_step = 0
    best_dev_F1 = None
    stop_train = False
    start_time = time.time()
    eval_start_time = time.time()
    model.train()

    for epoch in range(10000):
        for data in build_train_iterator():
            context_idxs = Variable(data['context_idxs'])
            ques_idxs = Variable(data['ques_idxs'])
            context_char_idxs = Variable(data['context_char_idxs'])
            ques_char_idxs = Variable(data['ques_char_idxs'])
            context_lens = Variable(data['context_lens'])
            y1 = Variable(data['y1'])
            y2 = Variable(data['y2'])
            q_type = Variable(data['q_type'])
            is_support = Variable(data['is_support'])
            start_mapping = Variable(data['start_mapping'])
            end_mapping = Variable(data['end_mapping'])
            all_mapping = Variable(data['all_mapping'])

            logit1, logit2, predict_type, predict_support = model(
                context_idxs,
                ques_idxs,
                context_char_idxs,
                ques_char_idxs,
                context_lens,
                start_mapping,
                end_mapping,
                all_mapping,
                return_yp=False)
            loss_1 = (nll_sum(predict_type, q_type) + nll_sum(logit1, y1) +
                      nll_sum(logit2, y2)) / context_idxs.size(0)
            loss_2 = nll_average(predict_support.view(-1, 2),
                                 is_support.view(-1))
            loss = loss_1 + config.sp_lambda * loss_2

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_ans_loss += loss_1.item()
            total_sp_loss += loss_2.item()
            global_step += 1

            if global_step % config.period == 0:
                cur_loss = total_loss / config.period
                cur_ans_loss = total_ans_loss / config.period
                cur_sp_loss = total_sp_loss / config.period
                elapsed = time.time() - start_time
                logging(
                    '| epoch {:3d} | step {:6d} | lr {:05.5f} | ms/batch {:5.2f} | train loss {:8.3f} | answer loss {:8.3f} | supporting facts loss {:8.3f} '
                    .format(epoch, global_step, lr,
                            elapsed * 1000 / config.period, cur_loss,
                            cur_ans_loss, cur_sp_loss))
                experiment.log_metrics(
                    {
                        'train loss': cur_loss,
                        'train answer loss': cur_ans_loss,
                        'train supporting facts loss': cur_sp_loss
                    },
                    step=global_step)
                total_loss = 0
                total_ans_loss = 0
                total_sp_loss = 0
                start_time = time.time()

            if global_step % config.checkpoint == 0:
                model.eval()
                metrics = evaluate_batch(build_dev_iterator(), model, 0,
                                         dev_eval_file, config)
                model.train()

                logging('-' * 89)
                logging(
                    '| eval {:6d} in epoch {:3d} | time: {:5.2f}s | dev loss {:8.3f} | answer loss {:8.3f} | supporting facts loss {:8.3f} | EM {:.4f} | F1 {:.4f}'
                    .format(global_step // config.checkpoint, epoch,
                            time.time() - eval_start_time, metrics['loss'],
                            metrics['ans_loss'], metrics['sp_loss'],
                            metrics['exact_match'], metrics['f1']))
                logging('-' * 89)
                experiment.log_metrics(
                    {
                        'dev loss': metrics['loss'],
                        'dev answer loss': metrics['ans_loss'],
                        'dev supporting facts loss': metrics['sp_loss'],
                        'EM': metrics['exact_match'],
                        'F1': metrics['f1']
                    },
                    step=global_step)

                eval_start_time = time.time()

                dev_F1 = metrics['f1']
                if best_dev_F1 is None or dev_F1 > best_dev_F1:
                    best_dev_F1 = dev_F1
                    torch.save(ori_model.state_dict(),
                               os.path.join(config.save, 'model.pt'))
                    cur_patience = 0
                else:
                    cur_patience += 1
                    if cur_patience >= config.patience:
                        lr /= 2.0
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                        if lr < config.init_lr * 1e-2:
                            stop_train = True
                            break
                        cur_patience = 0
        if stop_train: break
    logging('best_dev_F1 {}'.format(best_dev_F1))
Example #9
def test(config):
    # Inference mode (after training)
    if config.bert:
        if config.bert_with_glove:
            with open(config.word_emb_file, "r") as fh:
                word_mat = np.array(json.load(fh), dtype=np.float32)
        else:
            word_mat = None
    else:
        with open(config.word_emb_file, "r") as fh:
            word_mat = np.array(json.load(fh), dtype=np.float32)

    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    if config.data_split == 'dev':
        with open(config.dev_eval_file, "r") as fh:
            dev_eval_file = json.load(fh)
    else:
        with open(config.test_eval_file, 'r') as fh:
            dev_eval_file = json.load(fh)
    with open(config.idx2word_file, 'r') as fh:
        idx2word_dict = json.load(fh)

    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    def logging(s, print_=True, log_=True):
        if print_:
            print(s)
        if log_:
            with open(os.path.join(config.save, 'log.txt'), 'a+') as f_log:
                f_log.write(s + '\n')

    if config.data_split == 'dev':
        dev_buckets = get_buckets(
            config.dev_record_file,
            os.path.join(config.bert_dir, config.dev_example_ids))
        if config.bert:
            dev_context_buckets = get_buckets_bert(
                os.path.join(config.bert_dir, config.dev_bert_emb_context))
            dev_ques_buckets = get_buckets_bert(
                os.path.join(config.bert_dir, config.dev_bert_emb_ques))

        para_limit = config.para_limit
        ques_limit = config.ques_limit
    elif config.data_split == 'test':
        para_limit = None
        ques_limit = None
        dev_buckets = get_buckets(
            config.test_record_file,
            os.path.join(config.bert_dir, config.test_example_ids))
        if config.bert:
            dev_context_buckets = get_buckets_bert(
                os.path.join(config.bert_dir, config.test_bert_emb_context))
            dev_ques_buckets = get_buckets_bert(
                os.path.join(config.bert_dir, config.test_bert_emb_ques))

    def build_dev_iterator():
        if config.bert:
            return DataIterator(dev_buckets,
                                config.batch_size,
                                para_limit,
                                ques_limit,
                                config.char_limit,
                                False,
                                config.sent_limit,
                                bert=True,
                                bert_buckets=list(
                                    zip(dev_context_buckets,
                                        dev_ques_buckets)))
        else:
            return DataIterator(dev_buckets,
                                config.batch_size,
                                para_limit,
                                ques_limit,
                                config.char_limit,
                                False,
                                config.sent_limit,
                                bert=False)

    if config.sp_lambda > 0:
        model = SPModel(config, word_mat, char_mat)
    else:
        model = Model(config, word_mat, char_mat)
    ori_model = model.cuda()
    ori_model.load_state_dict(torch.load(os.path.join(config.save,
                                                      'model.pt')))
    model = nn.DataParallel(ori_model)

    model.eval()
    predict(build_dev_iterator(), model, dev_eval_file, config,
            config.prediction_file)
Example #10
def train(config):
    if config.bert:
        if config.bert_with_glove:
            with open(config.word_emb_file, "r") as fh:
                word_mat = np.array(json.load(fh), dtype=np.float32)
        else:
            word_mat = None
    else:
        with open(config.word_emb_file, "r") as fh:
            word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
    with open(config.idx2word_file, 'r') as fh:
        idx2word_dict = json.load(fh)

    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    config.save = '{}-{}'.format(config.save, time.strftime("%Y%m%d-%H%M%S"))
    create_exp_dir(config.save,
                   scripts_to_save=[
                       'run.py', 'model.py', 'util.py', 'sp_model.py',
                       'macnet_v2.py'
                   ])

    def logging(s, print_=True, log_=True):
        if print_:
            print(s)
        if log_:
            with open(os.path.join(config.save, 'log.txt'), 'a+') as f_log:
                f_log.write(s + '\n')

    logging('Config')
    for k, v in config.__dict__.items():
        logging('    - {} : {}'.format(k, v))

    logging("Building model...")
    train_buckets = get_buckets(config.train_record_file)
    dev_buckets = get_buckets(config.dev_record_file)
    max_iter = math.ceil(len(train_buckets[0]) / config.batch_size)

    def build_train_iterator():
        if config.bert:
            train_context_buckets = get_buckets_bert(
                os.path.join(config.bert_dir, config.train_bert_emb_context))
            train_ques_buckets = get_buckets_bert(
                os.path.join(config.bert_dir, config.train_bert_emb_ques))
            return DataIterator(train_buckets,
                                config.batch_size,
                                config.para_limit,
                                config.ques_limit,
                                config.char_limit,
                                True,
                                config.sent_limit,
                                bert=True,
                                bert_buckets=list(
                                    zip(train_context_buckets,
                                        train_ques_buckets)),
                                new_spans=True)
        else:
            return DataIterator(train_buckets,
                                config.batch_size,
                                config.para_limit,
                                config.ques_limit,
                                config.char_limit,
                                True,
                                config.sent_limit,
                                bert=False,
                                new_spans=True)

    def build_dev_iterator():
        # Iterator for inference during training
        if config.bert:
            dev_context_buckets = get_buckets_bert(
                os.path.join(config.bert_dir, config.dev_bert_emb_context))
            dev_ques_buckets = get_buckets_bert(
                os.path.join(config.bert_dir, config.dev_bert_emb_ques))
            return DataIterator(dev_buckets,
                                config.batch_size,
                                config.para_limit,
                                config.ques_limit,
                                config.char_limit,
                                False,
                                config.sent_limit,
                                bert=True,
                                bert_buckets=list(
                                    zip(dev_context_buckets,
                                        dev_ques_buckets)))
        else:
            return DataIterator(dev_buckets,
                                config.batch_size,
                                config.para_limit,
                                config.ques_limit,
                                config.char_limit,
                                False,
                                config.sent_limit,
                                bert=False)

    if config.sp_lambda > 0:
        model = SPModel(config, word_mat, char_mat)
    else:
        model = Model(config, word_mat, char_mat)

    logging('nparams {}'.format(
        sum([p.nelement() for p in model.parameters() if p.requires_grad])))
    ori_model = model.cuda()
    model = nn.DataParallel(ori_model)

    lr = config.init_lr
    if config.optim == 'sgd':
        optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                     model.parameters()),
                              lr=config.init_lr)
    elif config.optim == 'adam':
        optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                      model.parameters()),
                               lr=config.init_lr)
    if config.scheduler == 'cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                               config.epoch *
                                                               max_iter,
                                                               eta_min=0.0001)
    elif config.scheduler == 'plateau':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='max',
            factor=0.5,
            patience=config.patience,
            verbose=True)
    cur_patience = 0
    total_loss = 0
    global_step = 0
    best_dev_F1 = None
    stop_train = False
    start_time = time.time()
    eval_start_time = time.time()
    model.train()

    for epoch in range(config.epoch):
        for data in build_train_iterator():
            # Note: per-batch stepping matches the cosine schedule's T_max above;
            # ReduceLROnPlateau.step() expects a validation metric and is not
            # meant to be called here without one.
            scheduler.step()
            context_idxs = Variable(data['context_idxs'])
            ques_idxs = Variable(data['ques_idxs'])
            context_char_idxs = Variable(data['context_char_idxs'])
            ques_char_idxs = Variable(data['ques_char_idxs'])
            context_lens = Variable(data['context_lens'])
            y1 = Variable(data['y1'])
            y2 = Variable(data['y2'])
            is_y1 = Variable(data['is_y1'])
            is_y2 = Variable(data['is_y2'])

            q_type = Variable(data['q_type'])
            is_support = Variable(data['is_support'])
            start_mapping = Variable(data['start_mapping'])
            end_mapping = Variable(data['end_mapping'])
            all_mapping = Variable(data['all_mapping'])
            if config.bert:
                bert_context = Variable(data['bert_context'])
                bert_ques = Variable(data['bert_ques'])
            else:
                bert_context = None
                bert_ques = None

            support_ids = (is_support == 1).nonzero()
            sp_dict = {idx: [] for idx in range(context_idxs.shape[0])}
            for row in support_ids:
                bid, sp_sent_id = row
                sp_dict[bid.data.cpu()[0]].append(sp_sent_id.data.cpu()[0])

            if config.sp_shuffle:
                [random.shuffle(value) for value in sp_dict.values()]

            sp1_labels = []
            sp2_labels = []
            sp3_labels = []
            sp4_labels = []
            sp5_labels = []
            for item in sorted(sp_dict.items(), key=lambda t: t[0]):
                bid, supports = item
                if len(supports) == 1:
                    sp1_labels.append(supports[0])
                    sp2_labels.append(-1)
                    sp3_labels.append(-2)
                    sp4_labels.append(-2)
                    sp5_labels.append(-2)
                elif len(supports) == 2:
                    sp1_labels.append(supports[0])
                    sp2_labels.append(supports[1])
                    sp3_labels.append(-1)
                    sp4_labels.append(-2)
                    sp5_labels.append(-2)
                elif len(supports) == 3:
                    sp1_labels.append(supports[0])
                    sp2_labels.append(supports[1])
                    sp3_labels.append(supports[2])
                    sp4_labels.append(-1)
                    sp5_labels.append(-2)
                elif len(
                        supports) >= 4:  # 4 or greater sp are treated the same
                    sp1_labels.append(supports[0])
                    sp2_labels.append(supports[1])
                    sp3_labels.append(supports[2])
                    sp4_labels.append(supports[3])
                    sp5_labels.append(-1)

            # We will append 2 vectors to the front (sp_output_with_end), so we increment indices by 2
            sp1_labels = np.array(sp1_labels) + 2
            sp2_labels = np.array(sp2_labels) + 2
            sp3_labels = np.array(sp3_labels) + 2
            sp4_labels = np.array(sp4_labels) + 2
            sp5_labels = np.array(sp5_labels) + 2
            sp_labels_mod = Variable(
                torch.LongTensor(
                    np.stack([
                        sp1_labels, sp2_labels, sp3_labels, sp4_labels,
                        sp5_labels
                    ], 1)).cuda())

            sp_mod_mask = sp_labels_mod > 0

            if config.bert:
                logit1, logit2, grn_logit1, grn_logit2, coarse_att, predict_type, predict_support, sp_att_logits = model(
                    context_idxs,
                    ques_idxs,
                    context_char_idxs,
                    ques_char_idxs,
                    context_lens,
                    start_mapping,
                    end_mapping,
                    all_mapping,
                    return_yp=False,
                    support_labels=is_support,
                    is_train=True,
                    sp_labels_mod=sp_labels_mod,
                    bert=True,
                    bert_context=bert_context,
                    bert_ques=bert_ques)
            else:
                logit1, logit2, grn_logit1, grn_logit2, coarse_att, predict_type, predict_support, sp_att_logits = model(
                    context_idxs,
                    ques_idxs,
                    context_char_idxs,
                    ques_char_idxs,
                    context_lens,
                    start_mapping,
                    end_mapping,
                    all_mapping,
                    return_yp=False,
                    support_labels=is_support,
                    is_train=True,
                    sp_labels_mod=sp_labels_mod,
                    bert=False)

            sp_att1 = sp_att_logits[0].squeeze(-1)
            sp_att2 = sp_att_logits[1].squeeze(-1)
            sp_att3 = sp_att_logits[2].squeeze(-1)
            if config.reasoning_steps == 4:
                sp_att4 = sp_att_logits[3].squeeze(-1)

            # Fold the mask into the targets (labels) so that masked positions are ignored in the loss calculation
            sp1_labels = ((sp_mod_mask[:, 0].float() - 1) *
                          100).cpu().data.numpy() + sp1_labels
            sp2_labels = ((sp_mod_mask[:, 1].float() - 1) *
                          100).cpu().data.numpy() + sp2_labels
            sp3_labels = ((sp_mod_mask[:, 2].float() - 1) *
                          100).cpu().data.numpy() + sp3_labels
            sp4_labels = ((sp_mod_mask[:, 3].float() - 1) *
                          100).cpu().data.numpy() + sp4_labels
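            # Valid positions are left unchanged ((1 - 1) * 100 = 0); masked positions
            # are shifted by -100 to out-of-range targets, which nll_average is
            # presumably set up to ignore.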

            sp1_loss = nll_average(
                sp_att1, Variable(torch.LongTensor(sp1_labels).cuda()))
            sp2_loss = nll_average(
                sp_att2, Variable(torch.LongTensor(sp2_labels).cuda()))
            sp3_loss = nll_average(
                sp_att3, Variable(torch.LongTensor(sp3_labels).cuda()))
            if config.reasoning_steps == 4:
                sp4_loss = nll_average(
                    sp_att4, Variable(torch.LongTensor(sp4_labels).cuda()))
            else:
                sp4_loss = 0

            GRN_SP_PRED = True  # temporary flag
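            # When True, the per-sentence supporting-fact classification loss (loss_2)
            # is dropped and supporting facts are supervised only through the
            # sp attention losses computed above.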

            batch_losses = []
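            # Answer-span loss: for each example, average the start/end cross-entropy
            # over all gold mention spans, then add the question-type loss.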
            for bid, spans in enumerate(data['gold_mention_spans']):
                if spans:
                    try:
                        _loss_start_avg = F.cross_entropy(
                            logit1[bid].view(1, -1).expand(
                                len(spans), len(logit1[bid])),
                            Variable(torch.LongTensor(spans))[:, 0].cuda(),
                            reduce=True)
                        _loss_end_avg = F.cross_entropy(
                            logit2[bid].view(1, -1).expand(
                                len(spans), len(logit2[bid])),
                            Variable(torch.LongTensor(spans))[:, 1].cuda(),
                            reduce=True)
                    except IndexError:
                        ipdb.set_trace()
                    batch_losses.append((_loss_start_avg + _loss_end_avg))
            loss_1 = torch.mean(torch.stack(batch_losses)) + nll_sum(
                predict_type, q_type) / context_idxs.size(0)

            if GRN_SP_PRED:
                loss_2 = 0
            else:
                loss_2 = nll_average(predict_support.view(-1, 2),
                                     is_support.view(-1))
            loss = loss_1 + config.sp_lambda * loss_2 + sp1_loss + sp2_loss + sp3_loss + sp4_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.data[0]
            global_step += 1
            if global_step % config.period == 0:
                cur_loss = total_loss / config.period
                elapsed = time.time() - start_time
                logging(
                    '| epoch {:3d} | step {:6d} | lr {:05.5f} | ms/batch {:5.2f} | train loss {:8.3f}'
                    .format(epoch, global_step,
                            optimizer.param_groups[0]['lr'],
                            elapsed * 1000 / config.period, cur_loss))
                total_loss = 0
                start_time = time.time()

            if global_step % config.checkpoint == 0:
                model.eval()
                metrics = evaluate_batch(build_dev_iterator(), model, 0,
                                         dev_eval_file, config)
                model.train()
                logging('-' * 89)
                logging(
                    '| eval {:6d} in epoch {:3d} | time: {:5.2f}s | dev loss {:8.3f} | EM {:.4f} | F1 {:.4f}'
                    .format(global_step // config.checkpoint, epoch,
                            time.time() - eval_start_time, metrics['loss'],
                            metrics['exact_match'], metrics['f1']))
                logging('-' * 89)

                eval_start_time = time.time()

                dev_F1 = metrics['f1']
                if config.scheduler == 'plateau':
                    scheduler.step(dev_F1)
                if best_dev_F1 is None or dev_F1 > best_dev_F1:
                    best_dev_F1 = dev_F1
                    torch.save(ori_model.state_dict(),
                               os.path.join(config.save, 'model.pt'))
                    cur_patience = 0
        if stop_train: break
    logging('best_dev_F1 {}'.format(best_dev_F1))
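A compact way to read the supporting-fact label construction above: each example's list of supporting sentence ids is padded to five fixed slots with a -1 "stop" sentinel followed by -2 padding, then everything is shifted by +2 so the two vectors prepended to sp_output_with_end occupy indices 0 and 1. A minimal, self-contained sketch of that step (the helper name is illustrative, not part of the code above):

import numpy as np

def build_sp_label_slots(supports, num_slots=5, stop=-1, pad=-2, shift=2):
    # keep at most num_slots - 1 real supporting-fact ids
    slots = list(supports[:num_slots - 1])
    # one stop sentinel marks "no more supporting facts"
    if len(slots) < num_slots:
        slots.append(stop)
    # remaining slots are padding and get masked out of the loss later
    slots += [pad] * (num_slots - len(slots))
    return np.array(slots) + shift

# build_sp_label_slots([7, 12]) -> array([ 9, 14,  1,  0,  0])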
Exemple #11
0
def train(config):
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
    with open(config.idx2word_file, 'r') as fh:
        idx2word_dict = json.load(fh)

    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    config.save = '{}-{}'.format(config.save, time.strftime("%Y%m%d-%H%M%S"))
    create_exp_dir(
        config.save,
        scripts_to_save=['run.py', 'model.py', 'util.py', 'sp_model.py'])

    def logging(s, print_=True, log_=True):
        if print_:
            print(s)
        if log_:
            with open(os.path.join(config.save, 'log.txt'), 'a+') as f_log:
                f_log.write(s + '\n')

    logging('Config')
    for k, v in config.__dict__.items():
        logging('    - {} : {}'.format(k, v))

    logging("Building model...")
    train_buckets = get_buckets(config.train_record_file)
    dev_buckets = get_buckets(config.dev_record_file)

    # train_buckets = dev_buckets

    def build_train_iterator():
        return DataIterator(train_buckets, config.batch_size,
                            config.para_limit, config.ques_limit,
                            config.char_limit, True, config.sent_limit)

    def build_dev_iterator():
        return DataIterator(dev_buckets, config.batch_size, config.para_limit,
                            config.ques_limit, config.char_limit, False,
                            config.sent_limit)

    if config.sp_lambda > 0:
        model = SPModel(config, word_mat, char_mat)
    else:
        model = Model(config, word_mat, char_mat)

    logging('nparams {}'.format(
        sum([p.nelement() for p in model.parameters() if p.requires_grad])))
    ori_model = model.cuda()
    # ori_model = model.cpu()

    # flag (tracks whether the learning rate will be loaded from a file)
    # lr = 0

    optimizer = optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()))
    # optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=config.init_lr)

    # if training was interrupted, load the last saved state of the model (and optimizer)
    logging('Checking if previous training was interrupted...')

    my_file = Path('temp_model')
    if my_file.exists():
        logging('Previous training was interrupted, loading model and optimizer state')
        ori_model.load_state_dict(torch.load('temp_model'))
        optimizer.load_state_dict(torch.load('temp_optimizer'))

        # training_state = training_state.load_whole_class('last_training_state.pickle')

    # # if os.path.exists('last_model.pt'):
    # for dp, dn, filenames in os.walk('.'):
    #     model_filenames = []
    #     corresponding_learning_rates = []
    #     for ff in filenames:
    #         if ff.endswith("last_model.pt"):
    #             # putting all found models on list
    #             lr = float(ff.split('_')[0])
    #             model_filenames.append(ff)
    #             corresponding_learning_rates.append(lr)

    #     if len( model_filenames) > 0:
    #         # selecting the model with the smallest learning rate to be loaded
    #         loading_model_index = np.argmin(corresponding_learning_rates)
    #         # continuing with the previous learning rate
    #         lr = corresponding_learning_rates[loading_model_index]

    #         logging('Previous training was interrupted so loading last saved state model and continuing.')
    #         logging('Was stopped with learning rate: ' +  str( corresponding_learning_rates[loading_model_index] ) )
    #         logging('Loading file : ' + model_filenames[loading_model_index])
    #         ori_model.load_state_dict(torch.load(model_filenames[loading_model_index]))

    # model = nn.DataParallel(ori_model)
    # if the learning rate was not loaded then we set it equal to the initial one
    # if lr == 0:
    #     lr = config.init_lr

    cur_patience = 0
    total_loss = 0
    global_step = 0
    best_dev_F1 = None
    stop_train = False
    start_time = time.time()
    eval_start_time = time.time()
    model.train()

    train_metrics = {'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0}

    best_dev_sp_f1 = 0

    # total_support_facts = 0
    # total_contexes = 0

    for epoch in range(1000):
        for data in build_train_iterator():
            context_idxs = Variable(data['context_idxs'])
            ques_idxs = Variable(data['ques_idxs'])
            context_char_idxs = Variable(data['context_char_idxs'])
            ques_char_idxs = Variable(data['ques_char_idxs'])
            context_lens = Variable(data['context_lens'])
            y1 = Variable(data['y1'])
            y2 = Variable(data['y2'])
            q_type = Variable(data['q_type'])
            is_support = Variable(data['is_support'])

            start_mapping = Variable(data['start_mapping'])
            end_mapping = Variable(data['end_mapping'])
            all_mapping = Variable(data['all_mapping'])

            sentence_scores = Variable(data['sentence_scores'])

            # get model's output
            # predict_support = model(context_idxs, ques_idxs, context_char_idxs, ques_char_idxs, context_lens, start_mapping, end_mapping, all_mapping, return_yp=False)

            # calculating loss
            # loss_2 = nll_average(predict_support.view(-1, 2), is_support.view(-1) )

            # logging('Batch training loss :' + str(loss_2.item()) )

            # loss = loss_2

            logit1, logit2, predict_type, predict_support = model(
                context_idxs,
                ques_idxs,
                context_char_idxs,
                ques_char_idxs,
                context_lens,
                start_mapping,
                end_mapping,
                all_mapping,
                sentence_scores,
                return_yp=False)

            # print('predict_type :', predict_type)
            # print(' q_type:',q_type )
            # print(' logit1:', logit1)
            # print('y1 :',y1 )
            # print(' logit2:', logit2)
            # print(' y2:', y2)

            # print(' predict_support:', predict_support)
            # print(' is_support:', is_support)

            loss_1 = (nll_sum(predict_type, q_type) + nll_sum(logit1, y1) +
                      nll_sum(logit2, y2)) / context_idxs.size(0)
            loss_2 = nll_average(predict_support.view(-1, 2),
                                 is_support.view(-1).long())

            loss = loss_1 + config.sp_lambda * loss_2

            # log both losses :

            # with open(os.path.join(config.save, 'losses_log.txt'), 'a+') as f_log:
            #     s = 'Answer loss :' + str(loss_1.data.item()) + ', SF loss :' + str(loss_2.data.item()) + ', Total loss :' + str(loss.data.item())
            #     f_log.write(s + '\n')

            # update train metrics
            # train_metrics = update_sp(train_metrics, predict_support, is_support)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.data.item()
            global_step += 1

            if global_step % config.period == 0:

                # average metrics
                for key in train_metrics:
                    train_metrics[key] /= float(config.period)

                # cur_loss = total_loss / config.period
                cur_loss = total_loss

                elapsed = time.time() - start_time
                logging(
                    '| epoch {:3d} | step {:6d} | ms/batch {:5.2f} | train loss {:8.3f}'
                    .format(epoch, global_step, elapsed * 1000 / config.period,
                            cur_loss))
                # logging('| epoch {:3d} | step {:6d} | ms/batch {:5.2f} | train loss {:8.3f} | SP EM {:8.3f} | SP f1 {:8.3f} | SP Prec {:8.3f} | SP Recall {:8.3f}'.format(epoch, global_step, elapsed*1000/config.period, cur_loss, train_metrics['sp_em'], train_metrics['sp_f1'], train_metrics['sp_prec'], train_metrics['sp_recall']))

                total_loss = 0
                # train_metrics = {'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0}
                start_time = time.time()

                # logging('Saving model...')
                torch.save(ori_model.state_dict(), 'temp_model')
                torch.save(optimizer.state_dict(), 'temp_optimizer')

            # if global_step % config.checkpoint == 0:
            #     model.eval()

            #     # metrics = evaluate_batch(build_dev_iterator(), model, 5, dev_eval_file, config)
            #     eval_metrics = evaluate_batch(build_dev_iterator(), model, 0, dev_eval_file, config)
            #     model.train()

            #     logging('-' * 89)
            #     # logging('| eval {:6d} in epoch {:3d} | time: {:5.2f}s | dev loss {:8.3f} | EM {:.4f} | F1 {:.4f}'.format(global_step//config.checkpoint,
            #     #     epoch, time.time()-eval_start_time, metrics['loss'], metrics['exact_match'], metrics['f1']))
            #     logging('| eval {:6d} in epoch {:3d} | time: {:5.2f}s | dev loss {:8.3f}| SP EM {:8.3f} | SP f1 {:8.3f} | SP Prec {:8.3f} | SP Recall {:8.3f}'.format(global_step//config.checkpoint,
            #         epoch, time.time()-eval_start_time, eval_metrics['loss'], eval_metrics['sp_em'], eval_metrics['sp_f1'], eval_metrics['sp_prec'], eval_metrics['sp_recall']))
            #     logging('-' * 89)

            #     eval_start_time = time.time()

            if global_step % config.checkpoint == 0:
                model.eval()
                metrics = evaluate_batch(build_dev_iterator(), model, 0,
                                         dev_eval_file, config)
                model.train()

                logging('-' * 89)
                logging(
                    '| eval {:6d} in epoch {:3d} | time: {:5.2f}s | dev loss {:8.3f} | EM {:.4f} | F1 {:.4f}'
                    .format(global_step // config.checkpoint, epoch,
                            time.time() - eval_start_time, metrics['loss'],
                            metrics['exact_match'], metrics['f1']))
                logging('-' * 89)

                eval_start_time = time.time()

                dev_F1 = metrics['f1']
                if best_dev_F1 is None or dev_F1 > best_dev_F1:
                    best_dev_F1 = dev_F1
                    torch.save(ori_model.state_dict(),
                               os.path.join(config.save, 'model.pt'))
                    torch.save(optimizer.state_dict(),
                               os.path.join(config.save, 'optimizer.pt'))
                    cur_patience = 0
                else:
                    cur_patience += 1
                    if cur_patience >= config.patience:
                        stop_train = True
                        break

                # if eval_metrics['sp_f1'] > best_dev_sp_f1:
                #     best_dev_sp_f1 = eval_metrics['sp_f1']
                #     torch.save(ori_model.state_dict(), os.path.join(config.save, 'model.pt'))
                #     torch.save(optimizer.state_dict(), os.path.join(config.save,  'optimizer.pt') )
                #     # cur_lr_decrease_patience = 0
                #     cur_patience = 0
                # else:
                #     cur_patience += 1
                #     if cur_patience >= config.patience:
                #         stop_train = True
                #         break

                # lr *= 0.75
                # for param_group in optimizer.param_groups:
                #     param_group['lr'] = lr
                # if lr < config.init_lr * 1e-2:
                #     stop_train = True
                #     break
                # cur_patience = 0

                # cur_lr_decrease_patience += 1
                # if cur_lr_decrease_patience >= config.lr_decrease_patience:
                #     lr *= 0.75
                #     cur_early_stop_patience +=1
                #     if cur_early_stop_patience >= config.early_stop_patience:
                #         stop_train = True
                #         break

                #     for param_group in optimizer.param_groups:
                #         param_group['lr'] = lr

        if stop_train: break
    logging('best_dev_F1 {}'.format(best_dev_F1))

    # delete the last temporary model files, since training has completed
    print('Deleting last temp model files...')
    # for dp, dn, filenames in os.walk('.'):
    #     for ff in filenames:
    # if ff.endswith("last_model.pt"):
    os.remove('temp_model')
    os.remove('temp_optimizer')
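
The resume-on-interrupt pattern above (write temp_model / temp_optimizer every config.period steps, reload them at startup if they exist, delete them once training finishes) can be factored into two small helpers. A minimal sketch reusing the same file names; everything else is illustrative:

from pathlib import Path

import torch


def maybe_resume(model, optimizer, model_path='temp_model', opt_path='temp_optimizer'):
    # restore the last autosaved state if a previous run was interrupted
    if Path(model_path).exists():
        model.load_state_dict(torch.load(model_path))
        optimizer.load_state_dict(torch.load(opt_path))
        return True
    return False


def autosave(model, optimizer, model_path='temp_model', opt_path='temp_optimizer'):
    # overwrite the temporary checkpoint; called every config.period steps
    torch.save(model.state_dict(), model_path)
    torch.save(optimizer.state_dict(), opt_path)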
Exemple #12
0
def train(config):
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
    with open(config.idx2word_file, 'r') as fh:
        idx2word_dict = json.load(fh)

    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    config.save = '{}-{}'.format(config.save, time.strftime("%Y%m%d-%H%M%S"))
    create_exp_dir(config.save, scripts_to_save=['run.py', 'model.py', 'util.py', 'sp_model.py'])
    def logging(s, print_=True, log_=True):
        if print_:
            print(s)
        if log_:
            with open(os.path.join(config.save, 'log.txt'), 'a+') as f_log:
                f_log.write(s + '\n')

    logging('Config')
    for k, v in config.__dict__.items():
        logging('    - {} : {}'.format(k, v))

    logging("Building model...")
    train_buckets = get_buckets(config.train_record_file)
    dev_buckets = get_buckets(config.dev_record_file)

    def build_train_iterator():
        return DataIterator(train_buckets, config.batch_size, config.para_limit, config.ques_limit, config.char_limit, True, config.sent_limit)

    def build_dev_iterator():
        return DataIterator(dev_buckets, config.batch_size, config.para_limit, config.ques_limit, config.char_limit, False, config.sent_limit)

    # if config.sp_lambda > 0:
    #     model = SPModel(config, word_mat, char_mat)
    # else:
    model = Model(config, word_mat, char_mat)

    logging('nparams {}'.format(sum([p.nelement() for p in model.parameters() if p.requires_grad])))
    ori_model = model.cuda()
    model = nn.DataParallel(ori_model)

    lr = config.init_lr
    optimizer = torch.optim.Adam(model.parameters(), lr=float(0.5))

    num_trial = 0
    train_iter = patience = cum_loss = report_loss = cum_tgt_words = report_tgt_words = 0
    cum_examples = report_examples = epoch = valid_num = 0
    hist_valid_scores = []
    clip_grad = float(5.0)
    train_time = begin_time = time.time()
    # optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=config.init_lr)
    # cur_patience = 0
    # total_loss = 0
    # global_step = 0
    # best_dev_F1 = None
    # stop_train = False
    # start_time = time.time()
    # eval_start_time = time.time()
    model.train()

    for epoch in range(10000):
        for data in build_train_iterator():
            context_idxs = Variable(data['context_idxs'])
            ques_idxs = Variable(data['ques_idxs'])
            context_char_idxs = Variable(data['context_char_idxs'])
            ques_char_idxs = Variable(data['ques_char_idxs'])
            context_lens = Variable(data['context_lens'])

            question_lens = Variable(data['question_lens'])

            y1 = Variable(data['y1'])
            y2 = Variable(data['y2'])
            q_type = Variable(data['q_type'])
            is_support = Variable(data['is_support'])
            start_mapping = Variable(data['start_mapping'])
            end_mapping = Variable(data['end_mapping'])
            all_mapping = Variable(data['all_mapping'])

            # logit1, logit2, predict_type, predict_support = model(context_idxs, ques_idxs, context_char_idxs, ques_char_idxs, context_lens, start_mapping, end_mapping, all_mapping,question_lens, return_yp=False)
            # loss_1 = (nll_sum(predict_type, q_type) + nll_sum(logit1, y1) + nll_sum(logit2, y2)) / context_idxs.size(0)
            # loss_2 = nll_average(predict_support.view(-1, 2), is_support.view(-1))
            # loss = loss_1 + config.sp_lambda * loss_2
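            # Sort questions by descending length and reorder ques_idxs to match;
            # descending order is what pack_padded_sequence traditionally expects
            # inside the model (assumption, the model itself is not shown here).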
            question_lens, indices = torch.sort(question_lens, descending=True)

            ques_idxs = ques_idxs.cpu().detach().numpy()
            indices = indices.data.cpu().numpy()
            ques_idxs = torch.tensor(ques_idxs[indices], device='cuda:0')
            train_iter += 1
            optimizer.zero_grad()
            example_losses = -model(context_idxs, ques_idxs, context_char_idxs, ques_char_idxs, context_lens, question_lens,start_mapping, end_mapping, all_mapping)
            batch_loss = example_losses.sum()
            loss = batch_loss / config.batch_size

            loss.backward()
            # clip gradient
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)

            optimizer.step()

            batch_losses_val = batch_loss.item()
            report_loss += batch_losses_val
            cum_loss += batch_losses_val

            tgt_words_num_to_predict = sum(len(s[1:]) for s in ques_idxs)  # omitting leading `<s>`
            report_tgt_words += tgt_words_num_to_predict
            cum_tgt_words += tgt_words_num_to_predict
            report_examples += config.batch_size
            cum_examples += config.batch_size


            if train_iter % 100 == 0:
                print(report_loss)
                print(report_examples)
                print('epoch %d, iter %d, avg. loss %.2f, avg. ppl %.2f, '
                      'cum. examples %d, speed %.2f words/sec, time elapsed %.2f sec' % (
                          epoch, train_iter,
                          report_loss / report_examples,
                          math.exp(report_loss / report_tgt_words),
                          cum_examples,
                          report_tgt_words / (time.time() - train_time),
                          time.time() - begin_time), file=sys.stderr)

                train_time = time.time()
                report_loss = report_tgt_words = report_examples = 0.
Exemple #13
0
def gevent_sync(bucket_number=10):
    pages = range(1, 50)
    buckets = get_buckets(pages, bucket_number)

    jobs = [gevent.spawn(sync_code, ps) for ps in buckets]
    gevent.joinall(jobs)
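
gevent_sync above relies on a get_buckets helper that splits the page range into bucket_number chunks, one per greenlet; that helper is not shown in the snippet. A minimal sketch of one plausible implementation:

def get_buckets(items, bucket_number):
    # split items into bucket_number chunks of roughly equal size
    items = list(items)
    size = -(-len(items) // bucket_number)  # ceiling division
    return [items[i:i + size] for i in range(0, len(items), size)]

# get_buckets(range(1, 50), 10) -> 10 buckets with at most 5 pages each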
Exemple #14
0
def train(config):
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    # with open(config.train_eval_file, "r") as fh:
    #     train_eval_file = json.load(fh)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
    with open(config.idx2word_file, 'r') as fh:
        idx2word_dict = json.load(fh)

    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    if config.pre_att_id != '':
        config.save = 'T16-v2-{}-kp0{}-cond{}-ori{}-attcnt{}-gatefuse{}-lr{}-opt{}'.format(config.pre_att_id, config.keep_prob0, int(config.condition), int(config.original_ptr), config.att_cnt, config.gate_fuse, config.init_lr, config.optim)
        if config.use_elmo:
            config.save += "_ELMO"
        if config.train_emb:
            raise ValueError
            config.save += "_TE"
        if config.trnn:
            config.save += '_TRNN'
    else:
        config.save = 'baseline-{}'.format(time.strftime("%Y%m%d-%H%M%S"))
        if config.use_elmo:
            config.save += "_ELMO"
        if config.uniform_graph:
            config.save += '_UNIFORM'
    # avoid overwriting an existing experiment directory
    # if os.path.exists(config.save):
    #     sys.exit(1)
    create_exp_dir(config.save, scripts_to_save=['run.py', 'model.py', 'util.py', 'main.py'])
    def logging(s, print_=True, log_=True):
        if print_:
            print(s)
        if log_:
            with open(os.path.join(config.save, 'log.txt'), 'a+') as f_log:
                f_log.write(s + '\n')

    if config.pre_att_id != '':
        sys.path.insert(0, '../pretrain')
        from data import Vocab
        vocab = Vocab('../pretrain/vocabv2.pkl', 100000, '<unk>')
        
        # from model8 import StructurePredictor
        # model = StructurePredictor(512, len(vocab), 1, 1, 0.0)
        # model.load_state_dict(torch.load('../skip_thought/{}/st_predictor.pt'.format(config.pre_att_id)))
        # model.cuda()
        # model.eval()

        model = torch.load('../pretrain/{}/model.pt'.format(config.pre_att_id))
        # if 'gru' in config.pre_att_id:
        #     model.set_gru(True)
        # elif 'add' in config.pre_att_id:
        #     model.set_gru(False)
        # else:
        #     assert False
        model.cuda()
        ori_model = model
        model = nn.DataParallel(model)
        model.eval()
        import re
        try:
            nly = int(re.search(r'ly(\d+)', config.pre_att_id).group(1))
        except:
            nly = len(ori_model.enc_net.nets)
        if config.gate_fuse < 3:
            config.num_mixt = nly * 8
        else:
            config.num_mixt = (nly + nly - 1) * 8

        # old_model = torch.load('../skip_thought/{}/model.pt'.format(config.pre_att_id))
        # from model5 import GraphModel
        # model = GraphModel(old_model).cuda()
        # model = nn.DataParallel(model)
        # model.eval()
        # del old_model
        # import gc
        # gc.collect()
        # from data import Vocab
        # vocab = Vocab('../skip_thought/vocabv2.pkl', 100000, '<unk>')
        
        del sys.path[0]
        
        pre_att_data = {'model': model, 'vocab': vocab}
    else:
        pre_att_data = None

    logging('Config')
    for k, v in config.__dict__.items():
        logging('    - {} : {}'.format(k, v))

    if config.use_elmo and config.load_elmo:
        ee = torch.load(config.elmo_ee_file)
    else:
        ee = None

    logging("Building model...")
    train_buckets = get_buckets(config.train_record_file, config, limit=True)
    dev_buckets = get_buckets(config.dev_record_file, config, limit=False)

    def build_train_iterator():
        return DataIterator(train_buckets, config.batch_size, config.para_limit, config.ques_limit, config.char_limit, True, pre_att_data, config, ee, idx2word_dict, 'train')

    def build_dev_iterator():
        return DataIterator(dev_buckets, config.batch_size, config.para_limit, config.ques_limit, config.char_limit, False, pre_att_data, config, ee, idx2word_dict, 'dev')

    model = Model(config, word_mat, char_mat) if not config.trnn else ModelTRNN(config, word_mat, char_mat)
    # logging('nparams {}'.format(sum([p.nelement() for p in model.parameters() if p.requires_grad])))
    ori_model = model.cuda()
    # ori_model.word_emb.cpu()
    # model = ori_model
    model = nn.DataParallel(ori_model)

    lr = config.init_lr
    # optimizer = optim.SGD(model.parameters(), lr=config.init_lr, momentum=config.momentum)
    if config.optim == "adadelta":  # default
        optimizer = optim.Adadelta(filter(lambda p: p.requires_grad, model.parameters()), lr=config.init_lr, rho=0.95)
    elif config.optim == "sgd":
        optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=config.init_lr, momentum=config.momentum)
    elif config.optim == "adam":
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.init_lr, betas=(config.momentum, 0.999))
    cur_patience = 0
    total_loss = 0
    global_step = 0
    best_dev_F1 = None
    stop_train = False
    start_time = time.time()
    eval_start_time = time.time()
    model.train()
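    # Note: the epoch count below and the checkpoint interval both scale with
    # 32 / config.batch_size, presumably to keep the total number of examples seen
    # and the evaluation frequency constant relative to a reference batch size of 32.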

    for epoch in range(10000 * 32 // config.batch_size):
        for data in build_train_iterator():
            context_idxs = Variable(data['context_idxs'])
            ques_idxs = Variable(data['ques_idxs'])
            context_char_idxs = Variable(data['context_char_idxs'])
            ques_char_idxs = Variable(data['ques_char_idxs'])
            context_lens = Variable(data['context_lens'])
            y1 = Variable(data['y1'])
            y2 = Variable(data['y2'])

            graph = data['graph']
            graph_q = data['graph_q']
            if graph is not None:
                graph.volatile = False
                graph.requires_grad = False
                graph_q.volatile = False
                graph_q.requires_grad = False

            elmo, elmo_q = data['elmo'], data['elmo_q']
            if elmo is not None:
                elmo.volatile = False
                elmo.requires_grad = False
                elmo_q.volatile = False
                elmo_q.requires_grad = False

            logit1, logit2 = model(context_idxs, ques_idxs, context_char_idxs, ques_char_idxs, context_lens, pre_att=graph, pre_att_q=graph_q, elmo=elmo, elmo_q=elmo_q)
            loss = criterion(logit1, y1) + criterion(logit2, y2)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            import gc; gc.collect()

            total_loss += loss.data[0]
            global_step += 1

            if global_step % config.period == 0:
                cur_loss = total_loss / config.period
                elapsed = time.time() - start_time
                logging('| epoch {:3d} | step {:6d} | lr {:05.5f} | ms/batch {:5.2f} | train loss {:8.3f}'.format(epoch, global_step, lr, elapsed*1000/config.period, cur_loss))
                total_loss = 0
                start_time = time.time()

            if global_step % (config.checkpoint * 32 // config.batch_size) == 0:
                model.eval()
                metrics = evaluate_batch(build_dev_iterator(), model, 0, dev_eval_file)
                model.train()

                logging('-' * 89)
                logging('| eval {:6d} in epoch {:3d} | time: {:5.2f}s | dev loss {:8.3f} | EM {:.4f} | F1 {:.4f}'.format(global_step//config.checkpoint,
                    epoch, time.time()-eval_start_time, metrics['loss'], metrics['exact_match'], metrics['f1']))
                debug_s = ''
                if hasattr(ori_model, 'scales'):
                    debug_s += '| scales {} '.format(ori_model.scales.data.cpu().numpy().tolist())
                # if hasattr(ori_model, 'mixt_logits') and (not hasattr(ori_model, 'condition') or not ori_model.condition):
                #     debug_s += '| mixt {}'.format(F.softmax(ori_model.mixt_logits, dim=-1).data.cpu().numpy().tolist())
                if debug_s != '':
                    logging(debug_s)
                logging('-' * 89)

                eval_start_time = time.time()

                dev_F1 = metrics['f1']
                if best_dev_F1 is None or dev_F1 > best_dev_F1:
                    best_dev_F1 = dev_F1
                    torch.save(ori_model.state_dict(), os.path.join(config.save, 'model.pt'))
                    cur_patience = 0
                else:
                    cur_patience += 1
                    if cur_patience >= config.patience:
                        lr /= 2.0
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                        if lr < config.init_lr * 1e-2:
                            stop_train = True
                            break
                        cur_patience = 0
        if stop_train: break
    logging('best_dev_F1 {}'.format(best_dev_F1))