Example #1
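# Evaluation-only entry point: restores a trained GeneralModel checkpoint and
# reports its validation reward alongside the LEAD-3 baseline.
# Standard imports this excerpt relies on (get_args, Config, model_all and
# evaluate are assumed to come from the surrounding project):
import pickle
import torch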
def main5():
    # evaluation test
    torch.manual_seed(233)
    torch.cuda.set_device(0)
    args = get_args()
    if args.data == "nyt":
        vocab_file = '/home/ml/ydong26/data/nyt_c/processed/vocab_100d.p'
        with open(vocab_file, "rb") as f:
            vocab = pickle.load(f, encoding='latin1')
    else:
        vocab_file = '/home/ml/ydong26/data/CNNDM/CNN_DM_pickle_data/vocab_100d.p'
        with open(vocab_file, "rb") as f:
            vocab = pickle.load(f, encoding='latin1')
    config = Config(
        vocab_size=vocab.embedding.shape[0],
        embedding_dim=vocab.embedding.shape[1],
        category_size=args.category_size,
        category_dim=50,
        word_input_size=100,
        sent_input_size=2 * args.hidden,
        word_GRU_hidden_units=args.hidden,
        sent_GRU_hidden_units=args.hidden,
        pretrained_embedding=vocab.embedding,
        word2id=vocab.w2i,
        id2word=vocab.i2w,
    )
    extract_net = model_all.GeneralModel(config)
    extract_net.cuda()
    model_name = "/home/ml/lyu40/PycharmProjects/E_Yue/model/nyt/5/model.epoch.9.gm.tr"
    checkpoint = torch.load(model_name)
    best_eval_reward = checkpoint['best_eval_reward']
    extract_net.load_state_dict(checkpoint['state_dict'])

    eval_reward, lead3_reward = evaluate.ext_model_eval(
        extract_net, vocab, args, "val")
    print('epoch 9 reward in validation for gm model on nyt data set: ' +
          str(eval_reward) + ' lead3: ' + str(lead3_reward) +
          " best eval award: " + str(best_eval_reward))
Example #2
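# Trains model1 (a Bandit content selector) with a mixed objective: a
# gamma-weighted bandit (RL) loss plus a record-count regression term, blended
# with a supervised BCE loss on the gold record indices.
# Standard imports this excerpt relies on (Config1, Config2, PickleReader,
# BatchDataLoader, ContextualBandit, model and evaluate are assumed to come
# from the surrounding project):
import logging
import time
import traceback

import numpy as np
import torch
from tqdm import tqdm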
def train_model(args, vocab1, vocab2, device):
    print(args)
    print("generating config")
    config1 = Config1(
        vocab_size=len(vocab1),
        embedding_dim=args.embedding_dim,
        LSTM_layers=args.lstm_layer_1,
        LSTM_hidden_units=args.hidden,
        train_embed=args.train_embed,
        # pretrained_embedding=vocab1.embedding,
        word2id=vocab1.word_to_index,
        id2word=vocab1.index_to_word,
        dropout=args.dropout)
    config2 = Config2(
        vocab_size=len(vocab2),
        embedding_dim=args.embedding_dim,
        LSTM_layers=args.lstm_layer_2,
        LSTM_hidden_units=args.hidden,
        train_embed=args.train_embed,
        # pretrained_embedding=vocab2.embedding,
        word2id=vocab2.word_to_index,
        id2word=vocab2.index_to_word,
        dropout=args.dropout,
        decode_type=args.decode_type)
    model_name_1 = ".".join(
        (args.model_file_1, str(args.rl_baseline_method), args.sampling_method,
         "gamma", str(args.gamma), "beta", str(args.beta), "batch",
         str(args.train_batch), "learning_rate", str(args.lr_1), "bsz",
         str(args.batch_size), "data", args.data_dir.split('/')[0], "emb",
         str(config1.embedding_dim), "dropout", str(args.dropout), "max_num",
         str(args.max_num_of_ans), "train_embed", str(args.train_embed),
         'd2s'))
    # model_name_2 = ".".join((args.model_file_2,
    #                        "gamma",str(args.gamma),
    #                        "beta",str(args.beta),
    #                        "batch",str(args.train_batch),
    #                        "learning_rate",str(args.lr_2),
    #                        "data", args.data_dir.split('/')[0],
    #                        "emb", str(config2.embedding_dim),
    #                        "dropout", str(args.dropout),
    #                        'decode_type',str(args.decode_type),
    #                        'd2s'))

    log_name = ".".join(
        ("log/model", str(args.rl_baseline_method), args.sampling_method,
         "gamma", str(args.gamma), "beta", str(args.beta), "batch",
         str(args.train_batch), "lr_1", str(args.lr_1), "lr_2", str(args.lr_1),
         args.sampling_method, "bsz", str(args.batch_size), "data",
         args.data_dir.split('/')[0], "emb1", str(config1.embedding_dim),
         "emb2", str(config2.embedding_dim), "dropout", str(args.dropout),
         'decode_type', str(args.decode_type), "train_embed",
         str(args.train_embed), 'd2s'))

    print("initialising data loader and RL learner")
    data_loader = PickleReader(args.data_dir)
    data = args.data_dir.split('/')[0]
    num_data = 3398

    # init statistics
    reward_list = []
    loss_list1 = []
    loss_list2 = []
    best_eval_reward = 0.
    model_save_name_1 = model_name_1
    # model_save_name_2 = model_name_2

    bandit = ContextualBandit(b=args.batch_size,
                              rl_baseline_method=args.rl_baseline_method,
                              vocab=vocab2,
                              sample_method=args.sampling_method,
                              device=device)

    print("Loaded the Bandit")

    model1 = model.Bandit(config1).to(device)
    # model2 = model.Generator(config2).to(device)
    print("Loaded the models")

    if args.load_ext:
        model_name_1 = args.model_file_1
        # model_name_2 = args.model_file_2
        model_save_name_1 = model_name_1
        # model_save_name_2 = model_name_2
        print("loading existing models:1->%s" % model_name_1)
        # print("loading existing models:2->%s" % model_name_2)
        model1 = torch.load(model_name_1,
                            map_location=lambda storage, loc: storage)
        model1.to(device)
        # model2 = torch.load(model_name_2, map_location=lambda storage, loc: storage)
        # model2.to(device)
        log_name = 'log/' + model_name_1.split('/')[-1]
        print("finish loading and evaluate models:")
        # evaluate.ext_model_eval(extract_net, vocab, args, eval_data="test")
        best_eval_reward = evaluate.ext_model_eval(model1, None, vocab2, args,
                                                   "val", device)

    logging.basicConfig(filename='%s.log' % log_name,
                        level=logging.DEBUG,
                        format='%(asctime)s %(levelname)-10s %(message)s')
    logging.info("prev best eval reward:%.4f" % (best_eval_reward))
    # Loss and Optimizer
    optimizer1 = torch.optim.Adam([
        param for param in model1.parameters() if param.requires_grad
    ],
                                  lr=args.lr_1,
                                  betas=(args.beta, 0.999),
                                  weight_decay=1e-6)
    # optimizer2 = torch.optim.Adam([param for param in model2.parameters() if param.requires_grad == True ], lr=args.lr_2, betas=(args.beta, 0.999),weight_decay=1e-6)

    # if args.lr_sch ==1:
    #     scheduler = ReduceLROnPlateau(optimizer_ans, 'max',verbose=1,factor=0.9,patience=3,cooldown=3,min_lr=9e-5,epsilon=1e-6)
    #     if best_eval_reward:
    #         scheduler.step(best_eval_reward,0)
    #         print("init_scheduler")
    # elif args.lr_sch ==2:
    #     scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer_ans,args.lr, args.lr_2, step_size_up=3*int(num_data/args.train_batch), step_size_down=3*int(num_data/args.train_batch), mode='exp_range', gamma=0.98,cycle_momentum=False)
    print("starting training")
    start_time = time.time()
    n_step = 100
    gamma = args.gamma
    n_val = int(num_data / (7 * args.train_batch))
    supervised_loss = torch.nn.BCELoss()
    regression_loss = torch.nn.MSELoss()
    with torch.autograd.set_detect_anomaly(True):
        for epoch in tqdm(range(args.epochs_ext), desc="epoch:"):
            train_iter = data_loader.chunked_data_reader(
                "train", data_quota=args.train_example_quota)  #-1
            step_in_epoch = 0
            for dataset in train_iter:
                for step, contexts in tqdm(
                        enumerate(
                            BatchDataLoader(dataset,
                                            batch_size=args.train_batch,
                                            shuffle=True))):
                    try:
                        model1.train()
                        # model2.train()
                        step_in_epoch += 1
                        loss = 0.
                        reward = 0.
                        for context in contexts:
                            records = context.records
                            target = context.summary
                            records = torch.autograd.Variable(
                                torch.LongTensor(records)).to(device)
                            # target = torch.autograd.Variable(torch.LongTensor(target)).to(device)
                            # target_len = len(target)
                            prob, num_r = model1(records)
                            num_of_records = int(num_r.item() * 100)
                            sample_content, greedy_cp = bandit.sample(
                                prob, context, num_of_records)
                            # # apply data_parallel after this step
                            # sample_content.append((greedy_cp,0))
                            # gen_summaries = []
                            # total_loss = 0.
                            # for cp in [(greedy_cp.data,0)]:
                            #     gen_input = torch.autograd.Variable(r_cs[cp[0]].data).to(device)
                            #     e_k,prev_hidden, prev_emb = model2(gen_input,vocab2)
                            #     z_k = torch.autograd.Variable(records[cp[0]][:,0].data).to(device)
                            #     prev_t =0
                            #     loss=0.
                            #     gen_summary =[]
                            #     ## perform bptt here
                            #     for y_t in range(target_len):
                            #         p_out, prev_hidden = model2.forward_step(prev_emb,prev_hidden,gen_input,e_k,z_k)
                            #         topv,topi = p_out.topk(1)
                            #         gen_summary.append(topi)
                            #         prev_emb = model2.get_embedding(topi)
                            #         loss += decode_loss(p_out,target[y_t].unsqueeze(0))

                            #         if (y_t-prev_t)==50:
                            #             prev_t = y_t
                            #             loss.backward(retain_graph=True)
                            #             loss.detach()
                            #     if prev_t < target_len:
                            #         loss.backward()
                            #         loss.detach()
                            #     gen_summaries.append((gen_summary,cp[1]))
                            #     loss/=float(target_len)
                            #     total_loss+=loss
                            # optimizer2.step()
                            # optimizer2.zero_grad()
                            # total_loss/=len(sample_content)
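                            # blend the bandit (RL) loss and the record-count
                            # regression loss with the supervised BCE on the
                            # gold indices, weighted by gamma (loss_e below)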
                            bandit_loss, reward_b = bandit.calculate_loss(
                                sample_content, context.gold_index, greedy_cp)
                            true_numr = context.num_of_records / 100.
                            r_loss = regression_loss(
                                num_r,
                                torch.tensor(true_numr).type(
                                    torch.float).to(device))
                            #greedy_cp,bandit_loss = greedy_sample(prob,num_of_records+1,device)
                            #reward_b = generate_reward(None,None,gold_cp=context.gold_index,cp=greedy_cp)
                            labels = np.zeros(len(prob))
                            labels[context.gold_index] = 1.0
                            ml_loss = supervised_loss(
                                prob.view(-1),
                                torch.tensor(labels).type(
                                    torch.float).to(device))
                            loss_e = (gamma * (bandit_loss + r_loss)) + (
                                (1 - gamma) * ml_loss)
                            loss_e.backward()
                            reward += reward_b
                            loss += loss_e.item()

                        optimizer1.step()
                        optimizer1.zero_grad()
                        loss /= args.train_batch
                        reward /= args.train_batch
                        reward_list.append(reward)
                        loss_list1.append(loss)
                        # loss_list2.append(total_loss)

                        # if args.lr_sch==2:
                        #     scheduler.step()
                        # logging.info('Epoch %d Step %d Reward %.4f Loss1 %.4f Loss2 %.4f' % (epoch, step_in_epoch, reward,bandit_loss,total_loss))
                        logging.info(
                            'Epoch %d Step %d Reward %.4f Loss1 %.4f' %
                            (epoch, step_in_epoch, reward, loss))

                    except Exception as e:
                        print(e)
                        traceback.print_exc()

                    if (step_in_epoch) % n_step == 0 and step_in_epoch != 0:
                        # logging.info('Epoch ' + str(epoch) + ' Step ' + str(step_in_epoch) +
                        #     ' reward: ' + str(np.mean(reward_list))+' loss1: ' + str(np.mean(loss_list1))+' loss2: ' + str(np.mean(loss_list2)))
                        logging.info('Epoch ' + str(epoch) + ' Step ' +
                                     str(step_in_epoch) + ' reward: ' +
                                     str(np.mean(reward_list)) + ' loss1: ' +
                                     str(np.mean(loss_list1)))
                        reward_list = []
                        loss_list1 = []
                        # loss_list2=[]

                    if (step_in_epoch) % n_val == 0 and step_in_epoch != 0:
                        print("doing evaluation")
                        model1.eval()
                        # model2.eval()
                        #eval_reward = evaluate.ext_model_eval(mcan_cb, vocab, args, "test")
                        eval_reward = evaluate.ext_model_eval(
                            model1, None, vocab2, args, "val", device)

                        if eval_reward > best_eval_reward:
                            best_eval_reward = eval_reward
                            print(
                                "saving model %s with eval_reward:" %
                                model_save_name_1, eval_reward)
                            logging.debug("saving model " +
                                          str(model_save_name_1) +
                                          " with eval_reward:" +
                                          str(eval_reward))
                            torch.save(model1, model_save_name_1)
                            # torch.save(model2,model_save_name_2)
                        print('epoch ' + str(epoch) +
                              ' reward in validation: ' + str(eval_reward))
                        logging.debug('epoch ' + str(epoch) +
                                      ' reward in validation: ' +
                                      str(eval_reward))
                        logging.debug('time elapsed:' +
                                      str(time.time() - start_time))
            # if args.lr_sch ==1:
            #     mcan_cb.eval()
            #     eval_reward = evaluate.ext_model_eval(mcan_cb, vocab, args, "val")
            #     #eval_reward = evaluate.ext_model_eval(mcan_cb, vocab, args, "test")
            #     scheduler.step(eval_reward[0],epoch)
    return model1
Example #3
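# REINFORCE training loop for an extractive summarizer: the chosen extractor
# (SummaRuNNer, GruRuNNer, etc.) scores sentences and is updated through the
# ReinforceReward learner.
# Standard imports this excerpt relies on (Config, PickleReader, BatchDataLoader,
# ReinforceReward, model, evaluate, tokens_to_sentences and prepare_data are
# assumed to come from the surrounding project):
import logging
import os

import numpy as np
import torch
from torch.autograd import Variable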
def extractive_training(args, vocab):
    print(args)
    print("generating config")
    config = Config(
        vocab_size=vocab.embedding.shape[0],
        embedding_dim=vocab.embedding.shape[1],
        position_size=500,
        position_dim=50,
        word_input_size=100,
        sent_input_size=2 * args.hidden,
        word_GRU_hidden_units=args.hidden,
        sent_GRU_hidden_units=args.hidden,
        pretrained_embedding=vocab.embedding,
        word2id=vocab.w2i,
        id2word=vocab.i2w,
        dropout=args.dropout,
    )
    model_name = ".".join((args.model_file,
                           str(args.ext_model),
                           str(args.rouge_metric), str(args.std_rouge),
                           str(args.rl_baseline_method), "oracle_l", str(args.oracle_length),
                           "bsz", str(args.batch_size), "rl_loss", str(args.rl_loss_method),
                           "train_example_quota", str(args.train_example_quota),
                           "length_limit", str(args.length_limit),
                           "data", os.path.split(args.data_dir)[-1],
                           "hidden", str(args.hidden),
                           "dropout", str(args.dropout),
                           'ext'))
    print(model_name)

    log_name = ".".join(("../log/model",
                         str(args.ext_model),
                         str(args.rouge_metric), str(args.std_rouge),
                         str(args.rl_baseline_method), "oracle_l", str(args.oracle_length),
                         "bsz", str(args.batch_size), "rl_loss", str(args.rl_loss_method),
                         "train_example_quota", str(args.train_example_quota),
                         "length_limit", str(args.length_limit),
                         "hidden", str(args.hidden),
                         "dropout", str(args.dropout),
                         'ext'))

    print("init data loader and RL learner")
    data_loader = PickleReader(args.data_dir)

    # init statistics
    reward_list = []
    best_eval_reward = 0.
    model_save_name = model_name

    if args.fine_tune:
        model_save_name = model_name + ".fine_tune"
        log_name = log_name + ".fine_tune"
        args.std_rouge = True
        print("fine_tune model with std_rouge, args.std_rouge changed to %s" % args.std_rouge)

    reinforce = ReinforceReward(std_rouge=args.std_rouge, rouge_metric=args.rouge_metric,
                                b=args.batch_size, rl_baseline_method=args.rl_baseline_method,
                                loss_method=1)

    print('init extractive model')

    if args.ext_model == "lstm_summarunner":
        extract_net = model.SummaRuNNer(config)
    elif args.ext_model == "gru_summarunner":
        extract_net = model.GruRuNNer(config)
    elif args.ext_model == "bag_of_words":
        extract_net = model.SimpleRuNNer(config)
    elif args.ext_model == "simpleRNN":
        extract_net = model.SimpleRNN(config)
    elif args.ext_model == "RNES":
        extract_net = model.RNES(config)
    elif args.ext_model == "Refresh":
        extract_net = model.Refresh(config)
    elif args.ext_model == "simpleCONV":
        extract_net = model.simpleCONV(config)
    else:
        # without a valid choice, extract_net would be undefined below
        raise ValueError("no such extractive model: %s" % args.ext_model)

    extract_net.cuda()

    # print("current model name: %s"%model_name)
    # print("current log file: %s"%log_name)

    logging.basicConfig(filename='%s.log' % log_name,
                        level=logging.INFO, format='%(asctime)s [INFO] %(message)s')
    if args.load_ext:
        print("loading existing model%s" % model_name)
        extract_net = torch.load(model_name, map_location=lambda storage, loc: storage)
        extract_net.cuda()
        print("finish loading and evaluate model %s" % model_name)
        # evaluate.ext_model_eval(extract_net, vocab, args, eval_data="test")
        best_eval_reward, _ = evaluate.ext_model_eval(extract_net, vocab, args, "val")

    # Loss and Optimizer
    optimizer_ext = torch.optim.Adam(extract_net.parameters(), lr=args.lr, betas=(0., 0.999))

    print("starting training")
    n_step = 100
    for epoch in range(args.epochs_ext):
        train_iter = data_loader.chunked_data_reader("train", data_quota=args.train_example_quota)
        step_in_epoch = 0
        for dataset in train_iter:
            for step, docs in enumerate(BatchDataLoader(dataset, shuffle=True)):
                try:
                    extract_net.train()
                    # if True:
                    step_in_epoch += 1
                    # for i in range(1):  # how many times a single data gets updated before proceeding
                    doc = docs[0]
                    doc.content = tokens_to_sentences(doc.content)
                    doc.summary = tokens_to_sentences(doc.summary)
                    if args.oracle_length == -1:  # use true oracle length
                        oracle_summary_sent_num = len(doc.summary)
                    else:
                        oracle_summary_sent_num = args.oracle_length

                    x = prepare_data(doc, vocab)
                    if min(x.shape) == 0:
                        continue
                    sents = Variable(torch.from_numpy(x)).cuda()
                    outputs = extract_net(sents)

                    if args.prt_inf and np.random.randint(0, 100) == 0:
                        prt = True
                    else:
                        prt = False

                    loss, reward = reinforce.train(outputs, doc,
                                                   max_num_of_sents=oracle_summary_sent_num,
                                                   max_num_of_bytes=args.length_limit,
                                                   prt=prt)
                    if prt:
                        print('Probabilities: ', outputs.squeeze().data.cpu().numpy())
                        print('-' * 80)

                    reward_list.append(reward)

                    if isinstance(loss, Variable):
                        loss.backward()

                    if step % 1 == 0:
                        torch.nn.utils.clip_grad_norm(extract_net.parameters(), 1)  # gradient clipping
                        optimizer_ext.step()
                        optimizer_ext.zero_grad()
                    # print('Epoch %d Step %d Reward %.4f'%(epoch,step_in_epoch,reward))
                    logging.info('Epoch %d Step %d Reward %.4f' % (epoch, step_in_epoch, reward))
                except Exception as e:
                    print(e)

                if (step_in_epoch) % n_step == 0 and step_in_epoch != 0:
                    print('Epoch ' + str(epoch) + ' Step ' + str(step_in_epoch) +
                          ' reward: ' + str(np.mean(reward_list)))
                    reward_list = []

                if (step_in_epoch) % 10000 == 0 and step_in_epoch != 0:
                    print("doing evaluation")
                    extract_net.eval()
                    eval_reward, lead3_reward = evaluate.ext_model_eval(extract_net, vocab, args, "val")
                    if eval_reward > best_eval_reward:
                        best_eval_reward = eval_reward
                        print("saving model %s with eval_reward:" % model_save_name, eval_reward, "leadreward",
                              lead3_reward)
                        torch.save(extract_net, model_name)
                    print('epoch ' + str(epoch) + ' reward in validation: '
                          + str(eval_reward) + ' lead3: ' + str(lead3_reward))
    return extract_net
Example #4
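# Trains a BERT-based contextual bandit (BERT_CB) for answer selection on
# wiki_qa or trec_qa, mixing the bandit loss with a supervised BCE term and
# optionally applying a learning-rate scheduler.
# Standard imports this excerpt relies on (Config, PickleReader, BatchDataLoader,
# ContextualBandit, model2 and evaluate are assumed to come from the surrounding
# project):
import logging
import time
import traceback

import numpy as np
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm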
def train_model(args):
    print(args)
    print("generating config")
    config = Config(
        input_dim=args.input_dim,
        dropout=args.dropout,
        highway=args.highway,
        nn_layers=args.nn_layers,
    )
    model_name = ".".join(
        (args.model_file, str(args.rl_baseline_method), args.sampling_method,
         "gamma", str(args.gamma), "beta", str(args.beta), "batch",
         str(args.train_batch),
         "learning_rate", str(args.lr) + "-" + str(args.lr_sch), "bsz",
         str(args.batch_size), "data", args.data_dir.split('/')[0],
         args.eval_data, "input_dim", str(config.input_dim), "max_num",
         str(args.max_num_of_ans), "reward", str(args.reward_type), "dropout",
         str(args.dropout) + "-" + str(args.clip_grad), "highway",
         str(args.highway), "nn-" + str(args.nn_layers), 'ans'))

    log_name = ".".join(
        ("log_bert/model", str(args.rl_baseline_method), args.sampling_method,
         "gamma", str(args.gamma), "beta", str(args.beta), "batch",
         str(args.train_batch), "lr", str(args.lr) + "-" + str(args.lr_sch),
         "bsz", str(args.batch_size), "data", args.data_dir.split('/')[0],
         args.eval_data, "input_dim", str(config.input_dim), "max_num",
         str(args.max_num_of_ans), "reward", str(args.reward_type), "dropout",
         str(args.dropout) + "-" + str(args.clip_grad), "highway",
         str(args.highway), "nn-" + str(args.nn_layers), 'ans'))

    print("initialising data loader and RL learner")
    data_loader = PickleReader(args.data_dir)
    data = args.data_dir.split('/')[0]
    num_data = 0
    if data == "wiki_qa":
        num_data = 873
    elif data == "trec_qa":
        num_data = 1229
    else:
        assert False, "unsupported data directory: %s" % data
    # init statistics
    reward_list = []
    loss_list = []
    best_eval_reward = 0.
    model_save_name = model_name

    bandit = ContextualBandit(b=args.batch_size,
                              rl_baseline_method=args.rl_baseline_method,
                              sample_method=args.sampling_method)

    print("Loaded the Bandit")

    bert_cb = model2.BERT_CB(config)

    print("Loaded the model")

    bert_cb.cuda()
    vocab = "vocab"

    if args.load_ext:
        model_name = args.model_file
        print("loading existing model%s" % model_name)
        bert_cb = torch.load(model_name,
                             map_location=lambda storage, loc: storage)
        bert_cb.cuda()
        model_save_name = model_name
        log_name = "/".join(("log_bert", model_name.split("/")[1]))
        print("finish loading and evaluate model %s" % model_name)
        # evaluate.ext_model_eval(extract_net, vocab, args, eval_data="test")
        best_eval_reward = evaluate.ext_model_eval(bert_cb, vocab, args,
                                                   args.eval_data)[0]
    logging.basicConfig(filename='%s.log' % log_name,
                        level=logging.DEBUG,
                        format='%(asctime)s %(levelname)-10s %(message)s')
    # Loss and Optimizer
    optimizer_ans = torch.optim.Adam([
        param for param in bert_cb.parameters() if param.requires_grad
    ],
                                     lr=args.lr,
                                     betas=(args.beta, 0.999),
                                     weight_decay=1e-6)
    if args.lr_sch == 1:
        scheduler = ReduceLROnPlateau(optimizer_ans,
                                      'max',
                                      verbose=1,
                                      factor=0.9,
                                      patience=3,
                                      cooldown=3,
                                      min_lr=9e-5,
                                      eps=1e-6)
        if best_eval_reward:
            scheduler.step(best_eval_reward, 0)
            print("init_scheduler")
    elif args.lr_sch == 2:
        scheduler = torch.optim.lr_scheduler.CyclicLR(
            optimizer_ans,
            args.lr,
            args.lr_2,
            step_size_up=3 * int(num_data / args.train_batch),
            step_size_down=3 * int(num_data / args.train_batch),
            mode='exp_range',
            gamma=0.98,
            cycle_momentum=False)
    print("starting training")
    start_time = time.time()
    n_step = 100
    gamma = args.gamma
    #vocab = "vocab"
    if num_data < 2000:
        n_val = int(num_data / (5 * args.train_batch))
    else:
        n_val = int(num_data / (7 * args.train_batch))
    with torch.autograd.set_detect_anomaly(True):
        for epoch in tqdm(range(args.epochs_ext), desc="epoch:"):
            train_iter = data_loader.chunked_data_reader(
                "train", data_quota=args.train_example_quota)  #-1
            step_in_epoch = 0
            for dataset in train_iter:
                for step, contexts in tqdm(
                        enumerate(
                            BatchDataLoader(dataset,
                                            batch_size=args.train_batch,
                                            shuffle=True))):
                    try:
                        bert_cb.train()
                        step_in_epoch += 1
                        loss = 0.
                        reward = 0.
                        for context in contexts:

                            # q_a = torch.autograd.Variable(torch.from_numpy(context.features)).cuda()
                            pre_processed, a_len, sorted_id = model2.bert_preprocess(
                                context.answers)
                            q_a = torch.autograd.Variable(
                                pre_processed.type(torch.float))
                            a_len = torch.autograd.Variable(a_len)

                            outputs = bert_cb(q_a, a_len)
                            context.labels = np.array(
                                context.labels)[sorted_id]

                            if args.prt_inf and np.random.randint(0, 100) == 0:
                                prt = True
                            else:
                                prt = False

                            loss_t, reward_t = bandit.train(
                                outputs,
                                context,
                                max_num_of_ans=args.max_num_of_ans,
                                reward_type=args.reward_type,
                                prt=prt)
                            #print(str(loss_t)+' '+str(len(a_len)))

                            #    loss_t = loss_t.view(-1)
                            true_labels = np.zeros(len(context.labels))
                            gold_labels = np.array(context.labels)
                            true_labels[gold_labels > 0] = 1.0
                            # ml_loss = F.binary_cross_entropy(outputs.view(-1),torch.tensor(true_labels).type(torch.float).cuda())
                            ml_loss = F.binary_cross_entropy(
                                outputs.view(-1),
                                torch.tensor(true_labels).type(
                                    torch.float).cuda())

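                            # blend the bandit (RL) loss with the supervised BCE
                            # on the gold answer labels, weighted by gamma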
                            loss_e = ((gamma * loss_t) +
                                      ((1 - gamma) * ml_loss))
                            loss_e.backward()
                            loss += loss_e.item()
                            reward += reward_t
                        loss = loss / args.train_batch
                        reward = reward / args.train_batch
                        if prt:
                            print('Probabilities: ',
                                  outputs.squeeze().data.cpu().numpy())
                            print('-' * 80)

                        reward_list.append(reward)
                        loss_list.append(loss)
                        #if isinstance(loss, Variable):
                        #    loss.backward()

                        if step % 1 == 0:
                            if args.clip_grad:
                                torch.nn.utils.clip_grad_norm_(
                                    bert_cb.parameters(),
                                    args.clip_grad)  # gradient clipping
                            optimizer_ans.step()
                            optimizer_ans.zero_grad()
                        if args.lr_sch == 2:
                            scheduler.step()
                        logging.info('Epoch %d Step %d Reward %.4f Loss %.4f' %
                                     (epoch, step_in_epoch, reward, loss))
                    except Exception as e:
                        print(e)
                        #print(loss)
                        #print(loss_e)
                        traceback.print_exc()

                    if (step_in_epoch) % n_step == 0 and step_in_epoch != 0:
                        logging.info('Epoch ' + str(epoch) + ' Step ' +
                                     str(step_in_epoch) + ' reward: ' +
                                     str(np.mean(reward_list)) + ' loss: ' +
                                     str(np.mean(loss_list)))
                        reward_list = []
                        loss_list = []

                    if (step_in_epoch) % n_val == 0 and step_in_epoch != 0:
                        print("doing evaluation")
                        bert_cb.eval()
                        eval_reward = evaluate.ext_model_eval(
                            bert_cb, vocab, args, args.eval_data)

                        if eval_reward[0] > best_eval_reward:
                            best_eval_reward = eval_reward[0]
                            print(
                                "saving model %s with eval_reward:" %
                                model_save_name, eval_reward)
                            logging.debug("saving model" +
                                          str(model_save_name) +
                                          "with eval_reward:" +
                                          str(eval_reward))
                            torch.save(bert_cb, model_name)
                        print('epoch ' + str(epoch) +
                              ' reward in validation: ' + str(eval_reward))
                        logging.debug('epoch ' + str(epoch) +
                                      ' reward in validation: ' +
                                      str(eval_reward))
                        logging.debug('time elapsed:' +
                                      str(time.time() - start_time))
            if args.lr_sch == 1:
                bert_cb.eval()
                eval_reward = evaluate.ext_model_eval(bert_cb, vocab, args,
                                                      args.eval_data)
                scheduler.step(eval_reward[0], epoch)
    return bert_cb
Example #5
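# REINFORCE training for the SHE extractive model with TensorBoard logging;
# very short documents (fewer than 3 sentences) bypass the policy and are
# scored directly with from_summary_index_compute_rouge.
# Standard imports this excerpt relies on (Config, PickleReader, BatchDataLoader,
# ReinforceReward, model, evaluate, tokens_to_sentences, prepare_data and
# from_summary_index_compute_rouge are assumed to come from the surrounding
# project; SummaryWriter may come from tensorboardX or torch.utils.tensorboard):
import logging

import numpy as np
import torch
from torch.autograd import Variable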
def extractive_training(args, vocab):
    writer = SummaryWriter('../log')
    print(args)
    print("generating config")
    config = Config(
        vocab_size=vocab.embedding.shape[0],
        embedding_dim=vocab.embedding.shape[1],
        position_size=500,
        position_dim=50,
        word_input_size=100,
        sent_input_size=2 * args.hidden,
        word_GRU_hidden_units=args.hidden,
        sent_GRU_hidden_units=args.hidden,
        pretrained_embedding=vocab.embedding,
        word2id=vocab.w2i,
        id2word=vocab.i2w,
        dropout=args.dropout,
        pooling_way=args.pooling_way,
        num_layers = args.num_layers,
        num_directions = args.num_directions,
        fixed_length=args.fixed_length,
        num_filters=args.num_filters,
        filter_sizes=args.filter_sizes,
        batch_size=args.batch_size,
        novelty=args.novelty,
    )
    model_name = ".".join(("../model/"+str(args.ext_model),
                         "termination_", str(args.terminated_way),
                         "pooling_", str(args.pooling_way),
                         "max_sent", str(args.oracle_length),
                         "min_sents", str(args.min_num_of_sents),
                         "rl_m",str(args.rl_baseline_method), 
                         "oracle_l", str(args.oracle_length),
                         "bsz", str(args.batch_size), 
                         "rl_loss", str(args.rl_loss_method),
                         "hidden", str(args.hidden),
                         "dropout", str(args.dropout),
                         'ext'))
    print(model_name)

    log_name = ".".join(("../log/"+str(args.ext_model),
                         "termination_", str(args.terminated_way),
                         "pooling_", str(args.pooling_way),
                         "max_sent", str(args.oracle_length),
                         "min_sents", str(args.min_num_of_sents),
                         "rl_m",str(args.rl_baseline_method), 
                         "oracle_l", str(args.oracle_length),
                         "bsz", str(args.batch_size), 
                         "rl_loss", str(args.rl_loss_method),
                         "hidden", str(args.hidden),
                         "dropout", str(args.dropout),
                         'log'))

    print("init data loader and RL learner")
    data_loader = PickleReader(args.data_dir)

    # init statistics
    reward_list = []
    best_eval_reward = 0.
    model_save_name = model_name

    if args.fine_tune:
        model_save_name = model_name + ".fine_tune"
        log_name = log_name + ".fine_tune"
        args.std_rouge = True
        print("fine_tune model with std_rouge, args.std_rouge changed to %s" % args.std_rouge)

    print('init extractive model')

    extract_net = model.SHE(config).cuda()
    reinforce = ReinforceReward(terminated_way=args.terminated_way, std_rouge=args.std_rouge, rouge_metric=args.rouge_metric,
                                    b=args.batch_size, rl_baseline_method=args.rl_baseline_method,
                                    loss_method=1)
    extract_net.cuda()


    logging.basicConfig(filename='%s' % log_name,
                        level=logging.INFO, format='%(asctime)s [INFO] %(message)s')
    if args.load_ext:
        print("loading existing model%s" % model_name)
        extract_net = torch.load(model_name, map_location=lambda storage, loc: storage)
        extract_net.cuda()
        print("finish loading and evaluate model %s" % model_name)
        # evaluate.ext_model_eval(extract_net, vocab, args, eval_data="test")
        best_eval_reward, _ = evaluate.ext_model_eval(extract_net, vocab, args, "val")

    # Loss and Optimizer
    optimizer_ext = torch.optim.Adam(extract_net.parameters(), lr=args.lr, betas=(0., 0.999))

    print("starting training")
    n_step = 100
    error_counter = 0

    for epoch in range(args.epochs_ext):
        train_iter = data_loader.chunked_data_reader("train", data_quota=args.train_example_quota)
        step_in_epoch = 0
        for dataset in train_iter:
            # for step, docs in enumerate(BatchDataLoader(dataset, shuffle=True, batch_size=args.batch_size )):
            for step, docs in enumerate(BatchDataLoader(dataset, shuffle=True)):
                try:
                    extract_net.train()
                    # if True:
                    step_in_epoch += 1
                    # for i in range(1):  # how many times a single data gets updated before proceeding
                    doc = docs[0]
                    doc.content = tokens_to_sentences(doc.content)
                    doc.summary = tokens_to_sentences(doc.summary)

                    if len(doc.content) == 0 or len(doc.summary) == 0:
                        continue

                    if len(doc.content) < 3:
                        summary_index_list = range(min(len(doc.content), 3))
                        loss = 0
                        reward = from_summary_index_compute_rouge(doc, summary_index_list,
                                                                  std_rouge=args.std_rouge,
                                                                  rouge_metric=args.rouge_metric,
                                                                  max_num_of_bytes=args.length_limit)
                    else:
                        if args.oracle_length == -1:  # use true oracle length
                            oracle_summary_sent_num = len(doc.summary)
                        else:
                            oracle_summary_sent_num = args.oracle_length

                        x = prepare_data(doc, vocab)
                        if min(x.shape) == 0:
                            continue
                        sents = Variable(torch.from_numpy(x)).cuda()

                        outputs = extract_net(sents)

                        if args.prt_inf and np.random.randint(0, 1000) == 0:
                            prt = True
                        else:
                            prt = False
                        loss, reward = reinforce.train(outputs, doc,
                                                min_num_of_sents=args.min_num_of_sents,
                                                max_num_of_sents=oracle_summary_sent_num,
                                                max_num_of_bytes=args.length_limit,
                                                prt=prt)
                        if prt:
                            print('Probabilities: ', outputs.squeeze().data.cpu().numpy())
                            print('-' * 80)
                    reward_list.append(reward)

                    if isinstance(loss, Variable):
                        loss.backward()

                    if step % 1 == 0:
                        torch.nn.utils.clip_grad_norm(extract_net.parameters(), 1)  # gradient clipping
                        optimizer_ext.step()
                        optimizer_ext.zero_grad()
                    # print('Epoch %d Step %d Reward %.4f'%(epoch,step_in_epoch,reward))
                    logging.info('Epoch %d Step %d Reward %.4f' % (epoch, step_in_epoch, reward))

                except Exception as e:
                    error_counter += 1
                    print(e)

                if (step_in_epoch) % n_step == 0 and step_in_epoch != 0:
                    print('Epoch ' + str(epoch) + ' Step ' + str(step_in_epoch) +
                       ' reward: ' + str(np.mean(reward_list)))
                    print('error_count: ',error_counter)
                    mean_reward = np.mean(reward_list)
                    writer.add_scalar('Train/SHE', mean_reward, step_in_epoch)
                    reward_list = []

                if (step_in_epoch) % 2000 == 0 and step_in_epoch != 0:
                    print("doing evaluation")
                    extract_net.eval()
                    eval_reward, lead3_reward = evaluate.ext_model_eval(extract_net, vocab, args, "val")
                    if eval_reward > best_eval_reward:
                        best_eval_reward = eval_reward
                        print("saving model %s with eval_reward:" % model_save_name, eval_reward, "leadreward",
                              lead3_reward)
                        torch.save(extract_net, model_name)
                    writer.add_scalar('val/SHE', eval_reward, step_in_epoch)
                    with open('log/learning_curve', 'a') as f:
                        f.write(str(eval_reward) + '\t' + str(lead3_reward) + '\n')
                    print('epoch ' + str(epoch) + ' reward in validation: '
                          + str(eval_reward) +  ' lead3: ' + str(lead3_reward))
                    print('Error Counter: ',error_counter)
        

    return extract_net
Example #6
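# REINFORCE training for the multi-domain extractive models (FullyShare,
# PrivateShare, DomainModel, GeneralModel), with periodic checkpointing via
# save_checkpoint and validation against LEAD-3.
# Standard imports this excerpt relies on (Config, PickleReader, BatchDataLoader,
# ReinforceReward, model_all, evaluate, prepare_data and save_checkpoint are
# assumed to come from the surrounding project):
import logging
import os

import numpy as np
import torch
from torch.autograd import Variable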
def extractive_training(args, vocab):
    print(args)
    print("generating config")
    config = Config(
        vocab_size=vocab.embedding.shape[0],
        embedding_dim=vocab.embedding.shape[1],
        category_size=args.category_size,
        category_dim=50,
        word_input_size=100,
        sent_input_size=2 * args.hidden,
        word_GRU_hidden_units=args.hidden,
        sent_GRU_hidden_units=args.hidden,
        pretrained_embedding=vocab.embedding,
        word2id=vocab.w2i,
        id2word=vocab.i2w,
    )

    def create_model_name(epoch):
        # builds the checkpoint path used for both loading and saving a model
        path = args.model_file + args.data + "/" + str(args.num_topics) + "/model"
        return ".".join((path, 'epoch', str(epoch), args.ext_model, 'tr'))

    model_name = create_model_name(args.start_epoch)
    print(model_name)

    log_name = '/home/ml/lyu40/PycharmProjects/E_Yue/log/' + args.data + "/" + str(
        args.num_topics) + "/" + args.ext_model + ".tr"
    eval_file_name = '/home/ml/lyu40/PycharmProjects/E_Yue/log/' + args.data + "/" + str(
        args.num_topics) + "/" + args.ext_model + ".eval"

    print("init data loader and RL learner")
    data_loader = PickleReader()

    # init statistics
    reward_list = []
    best_eval_reward = 0.
    model_save_name = args.resume
    reinforce = ReinforceReward(std_rouge=args.std_rouge,
                                rouge_metric=args.rouge_metric,
                                b=args.batch_size,
                                rl_baseline_method=args.rl_baseline_method,
                                loss_method=1)

    print('init extractive model')
    if args.ext_model == "fs":
        extract_net = model_all.FullyShare(config)
    elif args.ext_model == "ps":
        extract_net = model_all.PrivateShare(config)
    elif args.ext_model == "dm":
        extract_net = model_all.DomainModel(config)
    elif args.ext_model == "gm":
        extract_net = model_all.GeneralModel(config)
    else:
        # without a valid choice, extract_net would be undefined below
        raise NotImplementedError("this model is not implemented yet: %s" % args.ext_model)
    # Loss and Optimizer
    optimizer = torch.optim.Adam(extract_net.parameters(),
                                 lr=args.lr,
                                 betas=(0., 0.999))
    logging.basicConfig(filename='%s.log' % log_name,
                        level=logging.INFO,
                        format='%(asctime)s [INFO] %(message)s')

    if args.resume:
        if os.path.isfile(model_name):
            try:
                print("=> loading checkpoint '{}'".format(model_name))
                checkpoint = torch.load(model_name)
                args.start_epoch = checkpoint['epoch']
                best_eval_reward = checkpoint['best_eval_reward']
                extract_net.load_state_dict(checkpoint['state_dict'])
                # optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    model_name, checkpoint['epoch']))
            except:
                extract_net = torch.load(
                    model_name, map_location=lambda storage, loc: storage)
                print("=> finish loaded checkpoint '{}' (epoch {})".format(
                    model_name, args.start_epoch))
        else:
            print("=> no checkpoint found at '{}'".format(model_name))
        # evaluate.ext_model_eval(extract_net, vocab, args, eval_data="test")
        # best_eval_reward, _ = evaluate.ext_model_eval(extract_net, vocab, args, eval_data="val")
    extract_net.cuda()

    #do a quick test, remove afterwards
    # evaluate.ext_model_eval(extract_net, vocab, args, "test")
    print("starting training")
    for epoch in range(args.start_epoch + 1, args.epochs_ext):
        train_iter = data_loader.chunked_data_reader(
            "train", data_quota=args.train_example_quota)
        # train_iter: the data sets for this training epoch
        print("finish loading the data for this epoch")
        step_in_epoch = 0
        for dataset in train_iter:
            for step, docs in enumerate(BatchDataLoader(dataset,
                                                        shuffle=False)):

                try:
                    # if True:
                    #     print("trying step %d"%step_in_epoch)
                    step_in_epoch += 1
                    doc = docs[0]
                    if args.oracle_length == -1:  # use true oracle length
                        oracle_summary_sent_num = len(doc.summary)
                    else:
                        oracle_summary_sent_num = args.oracle_length

                    x = prepare_data(doc, vocab.w2i)
                    if min(x.shape) == 0:
                        continue
                    sents = Variable(torch.from_numpy(x)).cuda()
                    label_idx = Variable(
                        torch.from_numpy(np.array([doc.label_idx]))).cuda()
                    print(
                        "label_idx:", label_idx
                    )  # label_idx: tensor([ 2], dtype=torch.int32, device='cuda:0')
                    #print("content:", doc.content)
                    #print("summary:", doc.summary)

                    if label_idx.dim() == 2:
                        outputs = extract_net(sents, label_idx[0])
                    else:
                        outputs = extract_net(sents, label_idx)
                    #print("outputs: ", outputs)

                    # if np.random.randint(0, 100) == 0:
                    #     prt = True
                    # else:
                    #     prt = False
                    prt = False
                    loss, reward, summary_index_list = reinforce.train(
                        outputs,
                        doc,
                        max_num_of_sents=oracle_summary_sent_num,
                        max_num_of_chars=args.length_limit,
                        prt=prt)
                    if prt:
                        print('Probabilities: ',
                              outputs.squeeze().data.cpu().numpy())
                        print('-' * 80)

                    reward_list.append(reward)

                    if isinstance(loss, Variable):
                        loss.backward()

                    if step % 10 == 0:
                        torch.nn.utils.clip_grad_norm(extract_net.parameters(),
                                                      1)  # gradient clipping
                        optimizer.step()
                        optimizer.zero_grad()
                    #print('Epoch %d Step %d Reward %.4f'%(epoch,step_in_epoch,reward))
                    if reward < 0.0001:
                        print(
                            "very low rouge score for this instance, with reward =",
                            reward)
                        print("outputs:", outputs)
                        print("content:", doc.content)
                        print("summary:", doc.summary)
                        print("selected sentences index list:",
                              summary_index_list)
                        print("*" * 40)
                    logging.info('Epoch %d Step %d Reward %.4f' %
                                 (epoch, step_in_epoch, reward))
                except Exception as e:
                    print(
                        "skipping one example because of an error during training, input is %s"
                        % docs[0].content)
                    print("Exception:")
                    print(e)
                    pass

                n_step = 200
                if (step_in_epoch) % n_step == 0 and step_in_epoch != 0:
                    print('Epoch ' + str(epoch) + ' Step ' +
                          str(step_in_epoch) + ' reward: ' +
                          str(np.mean(reward_list)))
                    reward_list = []

                if (step_in_epoch) % 50000 == 0 and step_in_epoch != 0:
                    save_checkpoint(
                        {
                            'epoch': epoch,
                            'state_dict': extract_net.state_dict(),
                            'best_eval_reward': best_eval_reward,
                            'optimizer': optimizer.state_dict(),
                        },
                        False,
                        filename=create_model_name(epoch))

                    print("doing evaluation")
                    eval_reward, lead3_reward = evaluate.ext_model_eval(
                        extract_net, vocab, args, "val")
                    if eval_reward > best_eval_reward:
                        best_eval_reward = eval_reward
                        print(
                            "saving model %s with eval_reward:" %
                            model_save_name, eval_reward, "leadreward",
                            lead3_reward)
                        try:
                            save_checkpoint(
                                {
                                    'epoch': epoch,
                                    'step_in_epoch': step_in_epoch,
                                    'state_dict': extract_net.state_dict(),
                                    'best_eval_reward': best_eval_reward,
                                    'optimizer': optimizer.state_dict(),
                                },
                                True,
                                filename=create_model_name(epoch))
                        except:
                            print(
                                "can't save the model since shutil doesn't work")

                    print('epoch ' + str(epoch) + ' reward in validation: ' +
                          str(eval_reward) + ' lead3: ' + str(lead3_reward))
                    with open(eval_file_name, "a") as file:
                        file.write('epoch ' + str(epoch) +
                                   ' reward in validation: ' +
                                   str(eval_reward) + ' lead3: ' +
                                   str(lead3_reward) + "\n")
    return extract_net