    def __init__(self,
                 args,
                 model,
                 loss,
                 train_data_loader,
                 dev_data_loader,
                 train_eval_file,
                 dev_eval_file,
                 optimizer,
                 scheduler,
                 epochs,
                 with_cuda,
                 save_dir,
                 verbosity=2,
                 save_freq=1,
                 print_freq=10,
                 resume=False,
                 identifier='',
                 debug=False,
                 debug_batchnum=2,
                 visualizer=None,
                 logger=None,
                 grad_clip=5.0,
                 decay=0.9999,
                 lr=0.001,
                 lr_warm_up_num=1000,
                 use_scheduler=False,
                 use_grad_clip=False,
                 use_ema=False,
                 ema=None,
                 use_early_stop=False,
                 early_stop=10):
        self.device = torch.device("cuda" if with_cuda else "cpu")
        self.args = args

        self.model = model
        self.loss = loss
        self.optimizer = optimizer
        self.epochs = epochs
        self.save_dir = save_dir
        self.save_freq = save_freq
        self.print_freq = print_freq
        self.verbosity = verbosity
        self.identifier = identifier
        self.visualizer = visualizer
        self.with_cuda = with_cuda

        self.train_data_loader = train_data_loader
        self.dev_data_loader = dev_data_loader
        self.dev_eval_dict = pickle_load_large_file(dev_eval_file)
        self.is_debug = debug
        self.debug_batchnum = debug_batchnum
        self.logger = logger
        self.unused = True  # True until the scheduler takes its first step

        self.lr = lr
        self.lr_warm_up_num = lr_warm_up_num
        self.decay = decay
        self.use_scheduler = use_scheduler
        self.scheduler = scheduler
        self.use_grad_clip = use_grad_clip
        self.grad_clip = grad_clip
        self.use_ema = use_ema
        self.ema = ema
        self.use_early_stop = use_early_stop
        self.early_stop = early_stop

        self.start_time = datetime.now().strftime('%b-%d_%H-%M')
        self.start_epoch = 1
        self.step = 0
        self.best_em = 0
        self.best_f1 = 0
        if resume:
            self._resume_checkpoint(resume)
            self.model = self.model.to(self.device)
            for state in self.optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.to(self.device)
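Every snippet on this page calls a pickle_load_large_file helper that is not
shown. A minimal sketch of such a helper (an assumption based on how it is
used here, not the project's verbatim code) reads the file in chunks to
sidestep the platform limit on single reads of very large pickles:

import os
import pickle

def pickle_load_large_file(filepath):
    # read in < 2**31-byte chunks, then unpickle the whole buffer
    max_bytes = 2**31 - 1
    input_size = os.path.getsize(filepath)
    bytes_in = bytearray(0)
    with open(filepath, "rb") as f:
        for _ in range(0, input_size, max_bytes):
            bytes_in += f.read(max_bytes)
    return pickle.loads(bytes_in)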
Example #2
def main(args):
    # show configuration
    print(args)
    random_seed = None  # set to an int for reproducible runs

    if random_seed is not None:
        random.seed(random_seed)
        np.random.seed(random_seed)
        torch.manual_seed(random_seed)

    # set log file
    log = sys.stdout
    if args.log_file is not None:
        log = open(args.log_file, "a")

    # set device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    if torch.cuda.is_available():
        print("device is cuda, # cuda is: ", n_gpu)
    else:
        print("device is cpu")

    # process word vectors and datasets
    if not args.processed_data:
        prepro(args)

    # load word vectors and datasets
    wv_tensor = torch.FloatTensor(
        np.array(pickle_load_large_file(args.word_emb_file), dtype=np.float32))
    cv_tensor = torch.FloatTensor(
        np.array(pickle_load_large_file(args.char_emb_file), dtype=np.float32))
    wv_word2ix = pickle_load_large_file(args.word_dictionary)

    train_dataloader = get_loader(args.train_examples_file,
                                  args.batch_size,
                                  shuffle=True)
    dev_dataloader = get_loader(args.dev_examples_file,
                                args.batch_size,
                                shuffle=True)

    # construct model
    model = QANet(wv_tensor,
                  cv_tensor,
                  args.para_limit,
                  args.ques_limit,
                  args.d_model,
                  num_head=args.num_head,
                  train_cemb=(not args.pretrained_char),
                  pad=wv_word2ix["<PAD>"])
    model.summary()
    if torch.cuda.device_count() > 1 and args.multi_gpu:
        model = nn.DataParallel(model)
    model.to(device)

    # exponential moving average
    ema = EMA(args.decay)
    if args.use_ema:
        for name, param in model.named_parameters():
            if param.requires_grad:
                ema.register(name, param.data)

    # set optimizer and scheduler
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adam(params=parameters,
                           lr=args.lr,
                           betas=(args.beta1, args.beta2),
                           eps=1e-8,
                           weight_decay=3e-7)
    cr = 1.0 / math.log(args.lr_warm_up_num)
    scheduler = optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda ee: cr * math.log(ee + 1)
        if ee < args.lr_warm_up_num else 1)

    # set loss, metrics
    loss = torch.nn.CrossEntropyLoss()

    # set visdom visualizer to store training process information
    # see the training process on http://localhost:8097/
    vis = None
    if args.visualizer:
        # note: os.system blocks until the spawned command exits, so the
        # visdom server is normally started in a separate shell beforehand
        os.system("python -m visdom.server")
        vis = Visualizer("main")

    # construct trainer
    # an identifier (prefix) for saved model
    identifier = type(model).__name__ + '_'
    trainer = Trainer(args,
                      model,
                      loss,
                      train_data_loader=train_dataloader,
                      dev_data_loader=dev_dataloader,
                      train_eval_file=args.train_eval_file,
                      dev_eval_file=args.dev_eval_file,
                      optimizer=optimizer,
                      scheduler=scheduler,
                      epochs=args.epochs,
                      with_cuda=args.with_cuda,
                      save_dir=args.save_dir,
                      verbosity=args.verbosity,
                      save_freq=args.save_freq,
                      print_freq=args.print_freq,
                      resume=args.resume,
                      identifier=identifier,
                      debug=args.debug,
                      debug_batchnum=args.debug_batchnum,
                      lr=args.lr,
                      lr_warm_up_num=args.lr_warm_up_num,
                      grad_clip=args.grad_clip,
                      decay=args.decay,
                      visualizer=vis,
                      logger=log,
                      use_scheduler=args.use_scheduler,
                      use_grad_clip=args.use_grad_clip,
                      use_ema=args.use_ema,
                      ema=ema,
                      use_early_stop=args.use_early_stop,
                      early_stop=args.early_stop)

    # start training!
    start = datetime.now()
    trainer.train()
    print("Time of training model ", datetime.now() - start)
Example #3
    def __init__(self, examples_file):
        self.examples = pickle_load_large_file(examples_file)
        self.num = len(self.examples)
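This constructor reads like the __init__ of a PyTorch Dataset. A sketch of
the surrounding class (the __len__ and __getitem__ bodies are assumptions,
not the project's code):

from torch.utils.data import Dataset

class ExamplesDataset(Dataset):
    def __init__(self, examples_file):
        self.examples = pickle_load_large_file(examples_file)
        self.num = len(self.examples)

    def __len__(self):
        return self.num

    def __getitem__(self, idx):
        return self.examples[idx]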
Example #4
def main(args):
    if args.net == "s2s_qanet":
        from model.QG_model import QGModel_S2S_CluePredict as Model
    else:
        print("Default use s2s_qanet model.")
        from model.QG_model import QGModel_S2S_CluePredict as Model
    # configuration
    emb_config["word"]["emb_size"] = args.tgt_vocab_limit
    args.emb_config["word"]["emb_size"] = args.tgt_vocab_limit
    args.brnn = True
    args.lower = True

    args.share_embedder = True
    args.use_ema = True
    args.use_clue_predict = True
    args.clue_predictor = "gcn"
    args.use_refine_copy_tgt_src = True
    args.add_word_freq_emb = True

    args.save_dir = get_auto_save_dir(args)
    if args.mode != "train":
        args.resume = args.save_dir + "model_best.pth.tar"  # NOTE: this overrides any --resume value passed on the command line
    print(args)

    # device, random seed, logger
    device, use_cuda, n_gpu = set_device(args.no_cuda)
    set_random_seed(args.seed)
    logger = set_logger(args.log_file)

    # preprocessing
    if args.not_processed_data:  # pass both --not_processed_data and --spacy_not_processed_data for a full preprocessing run
        prepro(args)

    # data
    emb_mats = pickle_load_large_file(args.emb_mats_file)
    emb_dicts = pickle_load_large_file(args.emb_dicts_file)

    train_dataloader = get_loader(args,
                                  emb_dicts,
                                  args.train_examples_file,
                                  args.batch_size,
                                  shuffle=True)
    dev_dataloader = get_loader(args,
                                emb_dicts,
                                args.dev_examples_file,
                                args.batch_size,
                                shuffle=False)
    test_dataloader = get_loader(args,
                                 emb_dicts,
                                 args.test_examples_file,
                                 args.batch_size,
                                 shuffle=False)

    # model
    model = Model(args, emb_mats, emb_dicts)
    summarize_model(model)
    if use_cuda and args.use_multi_gpu and n_gpu > 1:
        model = nn.DataParallel(model)
    model.to(device)
    partial_models = None
    partial_resumes = None
    partial_trainables = None

    # optimizer and scheduler
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    for p in parameters:
        if p.dim() == 1:
            p.data.normal_(0, math.sqrt(6 / (1 + p.size(0))))
        elif list(p.shape) == [args.tgt_vocab_limit, 300]:
            print("omit embeddings.")
        else:
            nn.init.xavier_normal_(p, math.sqrt(3))
    optimizer = Optim(args.optim,
                      args.learning_rate,
                      max_grad_norm=args.max_grad_norm,
                      max_weight_value=args.max_weight_value,
                      lr_decay=args.learning_rate_decay,
                      start_decay_at=args.start_decay_at,
                      decay_bad_count=args.halve_lr_bad_count)
    optimizer.set_parameters(model.parameters())
    scheduler = None

    loss = {}
    loss["P"] = torch.nn.CrossEntropyLoss()
    loss["D"] = torch.nn.BCEWithLogitsLoss(reduction="sum")

    # trainer
    trainer = Trainer(args,
                      model,
                      train_dataloader=train_dataloader,
                      dev_dataloader=dev_dataloader,
                      loss=loss,
                      optimizer=optimizer,
                      scheduler=scheduler,
                      device=device,
                      emb_dicts=emb_dicts,
                      logger=logger,
                      partial_models=partial_models,
                      partial_resumes=partial_resumes,
                      partial_trainables=partial_trainables)

    # start train/eval/test model
    start = datetime.now()
    if args.mode == "train":
        trainer.train()
    elif args.mode == "eval_train":
        args.use_ema = False
        trainer.eval(train_dataloader, args.train_eval_file,
                     args.train_output_file)
    elif args.mode in ["eval", "evaluation", "valid", "validation"]:
        args.use_ema = False
        trainer.eval(dev_dataloader, args.dev_eval_file, args.eval_output_file)
    elif args.mode == "test":
        args.use_ema = False
        trainer.eval(test_dataloader, args.test_eval_file,
                     args.test_output_file)
    else:
        print("Error: set mode to be train or eval or test.")
    print(("Time of {} model: {}").format(args.mode, datetime.now() - start))
Example #5
File: QG_data.py  Project: zxr445116086/QG
    def __init__(self, config, emb_dicts, examples_file):
        self.examples = pickle_load_large_file(examples_file)
        self.num = len(self.examples)

        # refine examples according to config here.
        start = datetime.now()

        if (config.use_refine_copy or config.use_refine_copy_tgt
                or config.use_refine_copy_src
                or config.use_refine_copy_tgt_src):
            assert (config.refined_src_vocab_limit <= config.tgt_vocab_limit)
            assert (config.refined_tgt_vocab_limit <=
                    config.refined_src_vocab_limit)
            assert (config.refined_copy_vocab_limit <=
                    config.refined_tgt_vocab_limit)
            OOV_id = emb_dicts["tgt"]["<OOV>"]
            for i in range(self.num):
                # refine switch and copy_position
                example = self.examples[i]
                switch = np.zeros(config.ques_limit, dtype=np.int32)
                copy_position = np.zeros(config.ques_limit, dtype=np.int32)
                tgt = np.zeros(config.ques_limit, dtype=np.int32)
                # iterate over question tokens
                for idx, tgt_word in enumerate(example["ques_tokens"]):
                    # get question token's word index and generated_tgt index
                    word_idx = None
                    generated_tgt_idx = None
                    for each in (tgt_word, tgt_word.lower(),
                                 tgt_word.capitalize(), tgt_word.upper()):
                        if each in emb_dicts["tgt"]:
                            word_idx = emb_dicts["tgt"][each]
                            generated_tgt_idx = emb_dicts[
                                "word2generated_tgt"][word_idx]
                            break

                    # get refined copy
                    compare_idx = word_idx
                    OOV_idx = emb_dicts["tgt"]["<OOV>"]
                    if config.use_generated_tgt_as_tgt_vocab:
                        compare_idx = generated_tgt_idx
                        OOV_idx = emb_dicts["generated_tgt"]["<OOV>"]

                    # oov or low-freq as copy target
                    if (compare_idx is None) or \
                            (compare_idx >= config.refined_copy_vocab_limit) or \
                            compare_idx == OOV_idx:
                        if tgt_word.lower() in example["src_tokens"]:
                            switch[idx] = 1
                            # NOTE: index() returns only the first occurrence,
                            # but tgt_word may appear multiple times in src_tokens
                            copy_position[idx] = \
                                example["src_tokens"].index(tgt_word.lower())

                    # get refined tgt
                    if (config.use_refine_copy_tgt
                            or config.use_refine_copy_tgt_src):
                        if (compare_idx is None) or \
                                (compare_idx >= config.refined_tgt_vocab_limit) or \
                                compare_idx == OOV_idx:
                            tgt[idx] = OOV_id
                        else:
                            tgt[idx] = word_idx
                # assign new values
                self.examples[i]["switch"] = switch
                self.examples[i]["copy_position"] = copy_position

                # refine tgt ids
                if (config.use_refine_copy_tgt
                        or config.use_refine_copy_tgt_src):
                    self.examples[i]["tgt"] = tgt

                # refine src ids
                if (config.use_refine_copy_src
                        or config.use_refine_copy_tgt_src):
                    c_mask = (example['ans_sent_word_ids'] >=
                              config.refined_src_vocab_limit)
                    self.examples[i]['ans_sent_word_ids'] = \
                        c_mask * OOV_id + \
                        (1 - c_mask) * example['ans_sent_word_ids']
                    q_mask = (example['ques_word_ids'] >=
                              config.refined_src_vocab_limit)
                    self.examples[i]['ques_word_ids'] = \
                        q_mask * OOV_id + \
                        (1 - q_mask) * example['ques_word_ids']

        for i in range(self.num):
            # add elmo embedding
            if config.add_elmo:
                example = self.examples[i]
                self.examples[i]["ans_sent_elmo_ids"] = tokens2ELMOids(
                    example["ans_sent_tokens"], config.sent_limit)
                self.examples[i]["ques_elmo_ids"] = tokens2ELMOids(
                    example["ques_tokens"], config.ques_limit)
            else:
                self.examples[i]["ans_sent_elmo_ids"] = np.array([0])
                self.examples[i]["ques_elmo_ids"] = np.array([0])

            # add word frequency embedding: 0 pad, 1 low-freq, 2 high-freq
            if config.add_word_freq_emb:
                example = self.examples[i]
                self.examples[i]["ans_sent_word_freq"] = get_word_freq_ids(
                    example["ans_sent_word_ids"], config.high_freq_bound,
                    config.low_freq_bound, emb_dicts["word"]["<PAD>"],
                    emb_dicts["word"]["<OOV>"])

                self.examples[i]["ques_word_freq"] = get_word_freq_ids(
                    example["ques_word_ids"], config.high_freq_bound,
                    config.low_freq_bound, emb_dicts["word"]["<PAD>"],
                    emb_dicts["word"]["<OOV>"])
            else:
                self.examples[i]["ans_sent_word_freq"] = np.array([0])
                self.examples[i]["ques_word_freq"] = np.array([0])

            # add hybrid clue target
            if config.use_hybrid_clue_tgt:
                example = self.examples[i]
                self.examples[i]["y_clue"] = get_hybrid_clue_target(
                    example, config.high_freq_bound, config.low_freq_bound,
                    emb_dicts["word"]["<PAD>"], emb_dicts["word"]["<OOV>"])
            else:
                example = self.examples[i]
                self.examples[i]["y_clue"] = \
                    example["ans_sent_is_overlap"] * \
                    abs(1 - example["ans_sent_is_stop"])

        print(("Time of refine data: {}").format(datetime.now() - start))
Example #6
File: QG_data.py  Project: zxr445116086/QG
def prepro(config):
    emb_tags = config.emb_config.keys()
    emb_config = config.emb_config
    emb_mats = {}
    emb_dicts = {}

    debug = config.debug
    debug_length = config.debug_batchnum * config.batch_size

    # get train spacy processed examples
    if config.spacy_not_processed_data:
        train_examples = get_raw_examples(config.train_file, config.data_type,
                                          debug, debug_length)
        train_examples, train_meta, train_eval = get_spacy_processed_examples(
            config, train_examples, debug, debug_length, shuffle=False)

        dev_examples = get_raw_examples(config.dev_file, config.data_type,
                                        debug, debug_length)
        dev_examples, dev_meta, dev_eval = get_spacy_processed_examples(
            config, dev_examples, debug, debug_length, shuffle=False)

        test_examples = get_raw_examples(config.test_file, config.data_type,
                                         debug, debug_length)
        test_examples, test_meta, test_eval = get_spacy_processed_examples(
            config, test_examples, debug, debug_length, shuffle=False)

        save(config.train_spacy_processed_examples_file,
             (train_examples, train_meta, train_eval),
             message="train spacy processed examples and meta")
        save(config.dev_spacy_processed_examples_file,
             (dev_examples, dev_meta, dev_eval),
             message="dev spacy processed examples and meta")
        save(config.test_spacy_processed_examples_file,
             (test_examples, test_meta, test_eval),
             message="test spacy processed examples and meta")
    else:
        train_examples, train_meta, train_eval = pickle_load_large_file(
            config.train_spacy_processed_examples_file)
        dev_examples, dev_meta, dev_eval = pickle_load_large_file(
            config.dev_spacy_processed_examples_file)
        test_examples, test_meta, test_eval = pickle_load_large_file(
            config.test_spacy_processed_examples_file)

    # get counters
    counters = get_updated_counters_by_examples(config,
                                                None,
                                                train_examples,
                                                increment=1,
                                                init=True,
                                                finish=True)
    # counters are built from the train data only
    final_counters = copy.deepcopy(counters)

    # get emb_mats and emb_dicts
    if not config.processed_emb:
        for tag in emb_tags:
            emb_mats[tag], emb_dicts[tag] = get_embedding(
                final_counters[tag],
                tag,
                emb_file=emb_config[tag]["emb_file"],
                size=emb_config[tag]["emb_size"],
                vec_size=emb_config[tag]["emb_dim"])
        emb_mats = init_emb_mat_by_glove(config, emb_mats, emb_dicts)
        emb_dicts["tgt"] = emb_dicts["word"]
        emb_mats["generated_tgt"], emb_dicts["generated_tgt"] = get_embedding(
            final_counters["generated_tgt"],
            "generated_tgt",
            emb_file=emb_config["word"]["emb_file"],
            size=emb_config["word"]["emb_size"],
            vec_size=emb_config["word"]["emb_dim"])
        emb_dicts["generated_tgt2word"] = get_value_maps_between_dicts(
            emb_dicts["generated_tgt"], emb_dicts["word"])
        emb_dicts["word2generated_tgt"] = get_value_maps_between_dicts(
            emb_dicts["word"], emb_dicts["generated_tgt"])
        emb_dicts["idx2tgt"] = {v: k for k, v in emb_dicts["tgt"].items()}
        emb_dicts["idx2generated_tgt"] = {
            v: k
            for k, v in emb_dicts["generated_tgt"].items()
        }
    else:
        emb_mats = pickle_load_large_file(config.emb_mats_file)
        emb_dicts = pickle_load_large_file(config.emb_dicts_file)
    for k in emb_dicts:
        print("Embedding dict length: " + k + " " + str(len(emb_dicts[k])))

    # get featured examples
    # TODO: handle potential insert SOS EOS problem
    #       when extracting tag features
    train_examples, train_meta = get_featured_examples(config, train_examples,
                                                       train_meta, "train",
                                                       emb_dicts)
    dev_examples, dev_meta = get_featured_examples(config, dev_examples,
                                                   dev_meta, "dev", emb_dicts)
    test_examples, test_meta = get_featured_examples(config, test_examples,
                                                     test_meta, "test",
                                                     emb_dicts)

    # save pickle
    save(config.emb_mats_file, emb_mats, message="embedding mats")
    save(config.emb_dicts_file, emb_dicts, message="embedding dicts")
    save(config.train_examples_file, train_examples, message="train examples")
    save(config.dev_examples_file, dev_examples, message="dev examples")
    save(config.test_examples_file, test_examples, message="test examples")
    save(config.train_meta_file, train_meta, message="train meta")
    save(config.dev_meta_file, dev_meta, message="dev meta")
    save(config.test_meta_file, test_meta, message="test meta")
    save(config.train_eval_file, train_eval, message="train eval")
    save(config.dev_eval_file, dev_eval, message="dev eval")
    save(config.test_eval_file, test_eval, message="test eval")
    save(config.counters_file, final_counters, message="counters")

    # print to txt to debug
    for k in emb_dicts:
        write_dict(emb_dicts[k], "output/emb_dicts_" + str(k) + ".txt")
    for k in counters:
        write_counter(counters[k], "output/counters_" + str(k) + ".txt")
    write_example(train_examples[5], "output/train_example.txt")
    write_example(dev_examples[5], "output/dev_example.txt")
    write_example(test_examples[5], "output/test_example.txt")
    write_dict(train_meta, "output/train_meta.txt")
    write_dict(dev_meta, "output/dev_meta.txt")
    write_dict(test_meta, "output/test_meta.txt")
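prepro builds the id maps between the word and generated_tgt dictionaries
with get_value_maps_between_dicts. A minimal sketch of what such a helper
plausibly does, consistent with how the maps are consumed in Example #5 (an
assumption, not the project's code):

def get_value_maps_between_dicts(dict_from, dict_to, missing_id=1):
    # map each id in dict_from to the id of the same token in dict_to,
    # falling back to missing_id (e.g. the <OOV> id) for absent tokens
    return {v: dict_to.get(k, missing_id) for k, v in dict_from.items()}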