Example #1
    def __init__(self, args, model):
        torch.manual_seed(0)
        torch.autograd.set_detect_anomaly(True)
        self.args = args
        self.vocab2id = load_vocab(self.args.vocab_path, self.args.vocab_size)

        self.model = model
        if torch.cuda.is_available():
            self.model = self.model.cuda()
        if args.train_parallel:
            self.model = nn.DataParallel(self.model)
        self.loss_func = nn.NLLLoss(ignore_index=self.vocab2id[PAD_WORD])
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=self.args.learning_rate)
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer,
                                                   self.args.schedule_step,
                                                   self.args.schedule_gamma)
        self.logger = get_logger('train')
        self.train_loader = KeyphraseDataLoader(
            data_source=self.args.train_filename,
            vocab2id=self.vocab2id,
            mode='train',
            args=args)
        if self.args.train_from:
            self.dest_dir = os.path.dirname(self.args.train_from) + '/'
        else:
            timemark = time.strftime('%Y%m%d-%H%M%S',
                                     time.localtime(time.time()))
            self.dest_dir = os.path.join(
                self.args.dest_base_dir,
                self.args.exp_name + '-' + timemark) + '/'
            os.mkdir(self.dest_dir)

        fh = logging.FileHandler(os.path.join(self.dest_dir, args.logfile))
        fh.setLevel(logging.INFO)
        fh.setFormatter(logging.Formatter('[%(asctime)s] %(message)s'))
        self.logger.addHandler(fh)

        if not self.args.tensorboard_dir:
            tensorboard_dir = self.dest_dir + 'logs/'
        else:
            tensorboard_dir = self.args.tensorboard_dir
        self.writer = SummaryWriter(tensorboard_dir)
        self.eval_topn = (5, 10)
        self.macro_evaluator = KeyphraseEvaluator(self.eval_topn, 'macro',
                                                  args.token_field,
                                                  args.keyphrase_field)
        self.micro_evaluator = KeyphraseEvaluator(self.eval_topn, 'micro',
                                                  args.token_field,
                                                  args.keyphrase_field)
        self.best_f1 = None
        self.best_step = 0
        self.not_update_count = 0
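
The constructor above follows a common PyTorch training-setup pattern: seed the RNG, move the model to the GPU, optionally wrap it in nn.DataParallel, mask padding in the NLL loss, and pair Adam with a StepLR schedule. A minimal self-contained sketch of that pattern, with a placeholder model and made-up hyperparameter values (the repository's real values come from args):

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
model = nn.Linear(16, 4)  # placeholder model
if torch.cuda.is_available():
    model = model.cuda()
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)  # replicate across available GPUs
loss_func = nn.NLLLoss(ignore_index=0)  # assume the PAD token has id 0
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
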
Example #2
    def __init__(self):
        self.args = parse_args()
        self.model = PlmModel(self.args)
        if torch.cuda.is_available():
            self.model.cuda()
        self.tokenizer = self.model.tokenizer
        self.loader = VectorizationDataLoader(self.args.train_filename,
                                              self.model.tokenizer,
                                              args=self.args)
        self.logger = get_logger('vectorization trainer')
        if not os.path.exists(self.args.dest_base_dir):
            os.mkdir(self.args.dest_base_dir)
        self.logdir = self.args.dest_base_dir + '/logs'
        if not os.path.exists(self.logdir):
            os.mkdir(self.logdir)
        self.writer = SummaryWriter(self.logdir)
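
Both trainers above log through torch.utils.tensorboard. A minimal usage sketch of the SummaryWriter they construct (the log directory here is a placeholder):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('/tmp/logs')  # placeholder log directory
writer.add_scalar('train/loss', 0.42, global_step=1)  # log one scalar point
writer.close()
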
Example #3
    def __init__(self):
        self.args = self.parse_args()
        self.vocab2id = load_vocab(self.args.vocab_path)
        self.dest_base_dir = self.args.dest_base_dir
        self.writer = tf.summary.create_file_writer(self.dest_base_dir +
                                                    '/logs')
        self.exp_name = self.args.exp_name
        self.pad_idx = self.vocab2id[PAD_WORD]
        self.eval_topn = (5, 10)
        self.macro_evaluator = KeyphraseEvaluator(self.eval_topn, 'macro',
                                                  self.args.token_field,
                                                  self.args.keyphrase_field)
        self.micro_evaluator = KeyphraseEvaluator(self.eval_topn, 'micro',
                                                  self.args.token_field,
                                                  self.args.keyphrase_field)
        self.best_f1 = None
        self.best_step = 0
        self.not_update_count = 0
        self.logger = get_logger(__name__)
        self.total_vocab_size = len(self.vocab2id) + self.args.max_oov_count
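
Unlike the PyTorch examples, this trainer logs through TensorFlow 2's summary API. A minimal sketch of how such a writer is used (directory and values are placeholders):

import tensorflow as tf

writer = tf.summary.create_file_writer('/tmp/tf_logs')  # placeholder directory
with writer.as_default():
    tf.summary.scalar('loss', 0.42, step=1)  # log one scalar point
writer.flush()
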
Example #4
    def __init__(self):
        self.args = parse_args()
        self.plm_model_name = self.args.plm_model_name
        self.rerank_model_name = self.args.rerank_model_name
        self.model_info = MODEL_DICT[self.plm_model_name]
        if 'path' in self.model_info:
            tokenizer_path = self.model_info['path'] + 'vocab.txt'
        else:
            tokenizer_path = self.plm_model_name
        self.tokenizer = self.model_info['tokenizer_class'].from_pretrained(
            tokenizer_path)
        dest_dir = self.args.dest_base_dir
        if not os.path.exists(dest_dir):
            os.mkdir(dest_dir)
        self.train_loader = RerankDataLoader(self.args.train_filename,
                                             self.tokenizer, self.args,
                                             'train')
        self.logger = get_logger('rerank_trainer')
        logdir = self.args.dest_base_dir + 'logs/'
        if not os.path.exists(logdir):
            os.mkdir(logdir)
        self.writer = SummaryWriter(logdir)
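
The tokenizer_class lookup above resolves to a Hugging Face transformers tokenizer loaded via from_pretrained, from either a local vocab file or a model name. A hedged sketch with a public model name standing in for the entries of MODEL_DICT:

from transformers import BertTokenizer

# loads the vocab from the Hub by name; a local directory or vocab file path also works
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
ids = tokenizer.encode('an example query', add_special_tokens=True)
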
Example #5
class RerankDataBuilder(object):
    logger = get_logger('rank data builder')

    def __init__(self):
        self.args = self.parse_args()
        self.search_filename = self.args.search_filename
        self.golden_filename = self.args.golden_filename
        self.dest_filename = self.args.dest_filename
        if os.path.exists(self.dest_filename):
            os.remove(self.dest_filename)
        self.candidate_paper_id_list = read_lines(DATA_DIR +
                                                  'candidate_paper_id.txt')

    def parse_args(self):
        parser = argparse.ArgumentParser()
        parser.add_argument('-search_filename', type=str, required=True)
        parser.add_argument('-golden_filename', type=str, required=True)
        parser.add_argument('-dest_filename', type=str, required=True)
        parser.add_argument('-select_strategy',
                            type=str,
                            choices=[
                                'random', 'search_result_offset',
                                'search_result_false_top'
                            ],
                            required=True)
        parser.add_argument('-query_field',
                            type=str,
                            default='cites_text',
                            choices=['cites_text', 'description_text'])
        parser.add_argument('-sample_count', type=int, default=1)
        parser.add_argument('-aggregate_sample', action='store_true')
        parser.add_argument('-offset', type=int, default=50)
        args = parser.parse_args()
        return args

    def run(self):
        self.build_data()

    def build_data(self):
        pool = Pool(20)

        desc_id2item = {}

        for item in read_jsonline_lazy(self.golden_filename):
            desc_id = item['description_id']
            desc_id2item[desc_id] = item

        chunk_size = 50
        for item_chunk in get_chunk(read_jsonline_lazy(self.search_filename),
                                    chunk_size):
            new_item_chunk = []
            for item in item_chunk:
                true_item = desc_id2item[item['description_id']]
                true_paper_id = true_item['paper_id']
                cites_text = true_item['cites_text']
                docs = item['docs']
                item.pop('docs')
                item.pop('keywords')
                new_item_list = []
                new_item_dict = copy.deepcopy(item)
                new_item_dict['true_paper_id'] = true_paper_id
                new_item_dict['false_paper_id'] = []
                new_item_dict['cites_text'] = cites_text
                new_item_dict['description_text'] = true_item[
                    'description_text']
                for idx in range(self.args.sample_count):
                    train_pair = self.select_train_pair(
                        docs, true_paper_id, self.args.select_strategy, idx)
                    new_item = {
                        **train_pair,
                        **item, 'cites_text': cites_text,
                        'description_text': true_item['description_text']
                    }
                    new_item_list.append(new_item)
                    new_item_dict['false_paper_id'].append(
                        train_pair['false_paper_id'])
                if self.args.aggregate_sample:
                    new_item_chunk.append(new_item_dict)
                else:
                    new_item_chunk.extend(new_item_list)
            built_items = pool.map(self.build_single_query, new_item_chunk)
            built_items = [i for i in built_items if i]
            append_jsonlines(self.dest_filename, built_items)

    def select_train_pair(self, doc_list, true_doc_id, select_strategy,
                          intra_offset):
        offset = self.args.offset + intra_offset
        if select_strategy == 'search_result_offset':
            true_idx = index(doc_list, true_doc_id, -1)
            if true_idx == -1 or true_idx + offset >= len(doc_list):
                if len(doc_list) <= offset:
                    # When doc_list has no more than `offset` entries, fall back to a
                    # randomly chosen false id: taking the tail of a short result list
                    # hurts training because
                    # 1. unusual description texts get 3 predefined paper ids back (see search.py)
                    # 2. with a small top-k in the benchmark, the last result can share
                    #    context with the true paper and would confuse the model
                    false_paper_id = self.random_choose_false_id(true_doc_id)
                else:
                    false_idx = -self.args.sample_count + intra_offset
                    false_paper_id = doc_list[false_idx]
            else:
                false_paper_id = doc_list[true_idx + offset]
        elif select_strategy == 'random':
            false_paper_id = self.random_choose_false_id(true_doc_id)
        elif select_strategy == 'search_result_false_top':
            true_idx = index(doc_list, true_doc_id, -1)
            if true_idx == 0:
                false_paper_id = doc_list[1]
            else:
                false_paper_id = doc_list[0]
        else:
            raise ValueError('false instance select strategy error')
        return {'true_paper_id': true_doc_id, 'false_paper_id': false_paper_id}

    def random_choose_false_id(self, true_doc_id):
        false_paper_id = random.choice(self.candidate_paper_id_list)
        # resample until the false id differs from the true id
        while false_paper_id == true_doc_id:
            false_paper_id = random.choice(self.candidate_paper_id_list)
        return false_paper_id

    def build_single_query(self, item):
        query = item[self.args.query_field]
        true_paper = get_paper(item['true_paper_id'])
        true_text = true_paper['title'] + ' ' + true_paper['abstract']
        if isinstance(item['false_paper_id'], str):
            false_paper = get_paper(item['false_paper_id'])
            false_text = false_paper['title'] + ' ' + false_paper['abstract']
        elif isinstance(item['false_paper_id'], list):
            false_text = []
            for pid in item['false_paper_id']:
                paper = get_paper(pid)
                false_text.append(paper['title'] + ' ' + paper['abstract'])
        else:
            raise ValueError('false paper id type error')

        train_item = {
            'query': query,
            'true_doc': true_text,
            'false_doc': false_text,
            **item
        }
        train_item.pop('cites_text')
        return train_item
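
build_data above streams the search file in chunks of 50 and fans each chunk out to a multiprocessing pool before appending the results. A self-contained sketch of that chunk-and-map pattern; the get_chunk helper here is a stand-in written to mirror how the repository's version presumably behaves:

from multiprocessing import Pool

def get_chunk(iterable, size):
    # assumed behavior: yield successive lists of at most `size` items
    chunk = []
    for item in iterable:
        chunk.append(item)
        if len(chunk) == size:
            yield chunk
            chunk = []
    if chunk:
        yield chunk

def process(item):
    return item * item  # stand-in for build_single_query

if __name__ == '__main__':
    with Pool(4) as pool:
        for chunk in get_chunk(range(10), 3):
            print(pool.map(process, chunk))
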
Example #6
class ElasticSearchIndexer(object):
    logger = get_logger('elastic indexer')
    parallel_size = 2
    retry_count = 10

    def __init__(self):
        self.base_url = ES_API_URL

    def run(self):
        self.delete_index()
        self.create_fields()
        self.indexing_runner()

    def create_fields(self):
        mapping_url = self.base_url + '/_mapping'
        headers = {"Content-Type": "application/json"}
        base_data = json.dumps(read_json(DATA_DIR + 'setting.json'))
        field_data = json.dumps(read_json(DATA_DIR + 'fields.json'))
        ret = requests.put(self.base_url, data=base_data, headers=headers)
        if ret.status_code != 200:
            raise Exception('setting es error, {}'.format(ret.text))

        ret = requests.put(mapping_url, data=field_data, headers=headers)
        if ret.status_code != 200:
            raise Exception('create index error, {}'.format(ret.text))
        self.logger.info('create index success')

    def delete_index(self):
        requests.delete(self.base_url)

    def indexing_runner(self):
        filename = CANDIDATE_FILENAME
        pool = Pool(self.parallel_size)
        start = time.time()
        count = 0
        failed_doc_list = []
        for item_chunk in get_chunk(read_jsonline_lazy(filename), 500):
            ret = pool.map(self.index_doc, item_chunk)
            failed_doc_list.extend([i for i in ret if i])
            duration = time.time() - start
            count += len(item_chunk)
            msg = '{} completed, {}min {:.2f}s'.format(count, int(duration // 60),
                                                       duration % 60)
            self.logger.info(msg)

        for doc in failed_doc_list:
            self.index_doc(doc)

    def index_doc(self, doc):
        base_url = self.base_url + '/_doc/{}'
        headers = {"Content-Type": "application/json"}
        url = base_url.format(doc['paper_id'])
        if doc['keywords'] == '':
            doc['keywords'] = []
        else:
            doc['keywords'] = doc['keywords'].split(';')
        doc['TA'] = doc.get('title', '') + ' ' + doc.get('abstract', '')
        keyword_str = ' '.join(doc['keywords']).lower()
        if keyword_str:
            doc['TAK'] = doc['TA'] + ' ' + keyword_str
        else:
            doc['TAK'] = doc['TA']
        if not doc['paper_id'].strip():
            return

        input_str = json.dumps(doc)
        if not input_str:
            return

        try:
            ret = requests.put(url, input_str, headers=headers)
        except requests.RequestException:
            # retry failed requests up to retry_count times
            count = 0
            is_success = False
            while count < self.retry_count:
                count += 1
                try:
                    ret = requests.put(url, input_str, headers=headers)
                except requests.RequestException:
                    continue
                is_success = True
                break
            if not is_success:
                err_msg = '{} index failed, {}'.format(doc['paper_id'],
                                                       traceback.format_exc())
                self.logger.error(err_msg)
                return doc
        if ret.status_code not in (200, 201):
            msg = '{} error {} {}'.format(json.dumps(doc), ret.status_code,
                                          ret.text)
            self.logger.error(msg)
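
index_doc writes each paper with an HTTP PUT to the index's /_doc/{id} endpoint, so Elasticsearch answers 201 for a create and 200 for an update. A minimal sketch of that request against a local cluster (URL and document are placeholders):

import json

import requests

ES_API_URL = 'http://localhost:9200/papers'  # placeholder index URL
doc = {'paper_id': 'p1', 'title': 'A Title', 'abstract': 'An abstract.'}
ret = requests.put(ES_API_URL + '/_doc/' + doc['paper_id'],
                   data=json.dumps(doc),
                   headers={'Content-Type': 'application/json'})
print(ret.status_code)  # 201 on create, 200 on update
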