Example #1
File: prep.py Project: jzbjyb/rri
def click_to_rel():
    """Convert click-model scores into discrete relevance labels by matching
    the label distribution of a reference judgement file."""
    data_dir = args.data_dir
    judge_click = os.path.join(data_dir, 'judgement_DCTR')
    judge_refer = args.judgement_refer
    judge_click = load_judge_file(judge_click, scale=float)
    judge_refer = load_judge_file(judge_refer, scale=int)
    rels = []
    for q in judge_refer:
        for d in judge_refer[q]:
            rels.append(judge_refer[q][d])
    clicks = []
    for q in judge_click:
        for d in judge_click[q]:
            clicks.append(judge_click[q][d])
    rels = sorted(rels)
    clicks = sorted(clicks)
    if len(rels) <= 0 or len(clicks) <= 0:
        raise Exception('judgement has no record')
    # cumulative distribution of relevance labels in the reference judgements
    ratio = []
    last = '#'
    for i in range(len(rels)):
        r = rels[i]
        if r != last:
            ratio.append([r, 0])
            if len(ratio) > 1:
                ratio[-2][1] = i / len(rels)
        last = r
    ratio[-1][1] = 1
    # click score at the quantile boundary of each relevance label
    threshold = []
    k = 0
    last = '#'
    for i in range(len(clicks)):
        while i / len(clicks) >= ratio[k][1]:
            k += 1
        if last != '#' and last[0] != ratio[k][0]:
            threshold.append(last)
        last = [ratio[k][0], clicks[i]]
    threshold.append(last)
    print('ratio: {}'.format(ratio))
    print('threshold: {}'.format(threshold))
    threshold = [[0, 0.05], [1, 0.3], [2, 1]]  # my guess
    judge_rel = defaultdict(lambda: defaultdict(lambda: None))

    def click2rel(click):
        k = 0
        while click > threshold[k][1]:
            k += 1
        return threshold[k][0]

    for q in judge_click:
        for d in judge_click[q]:
            judge_rel[q][d] = click2rel(judge_click[q][d])
    save_judge_file(judge_rel, os.path.join(data_dir, 'judgement_rel'))
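A minimal sketch (not part of prep.py) of how the hard-coded "my guess" thresholds map a continuous click score to a discrete relevance label:

threshold = [[0, 0.05], [1, 0.3], [2, 1]]  # (label, upper bound on click score)

def click2rel(click):
    # walk up the threshold list until the click score fits under a bound
    k = 0
    while click > threshold[k][1]:
        k += 1
    return threshold[k][0]

for click in (0.01, 0.2, 0.7):
    print(click, '->', click2rel(click))  # 0.01 -> 0, 0.2 -> 1, 0.7 -> 2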
Example #2
File: prep.py Project: jzbjyb/rri
def generate_train_test():
    """Split judged query-document pairs into train/test pointwise files."""
    data_dir = args.data_dir
    query_filepath = os.path.join(data_dir, 'query')
    judge_filepath = os.path.join(data_dir, 'judgement')
    run_filepath = os.path.join(data_dir, 'run')
    # split train and test dataset based on queries rather than qid
    query_dict = load_from_query_file(query_filepath)
    unique_queries = np.unique(list(query_dict.values()))
    np.random.shuffle(unique_queries)
    train_size = int(len(unique_queries) * args.train_test_ratio)
    test_size = len(unique_queries) - train_size
    if train_size <= 0 or test_size <= 0:
        raise Exception('train test dataset size is incorrect')
    print('#unique queries: {}, train size: {}, test size: {}'.format(
        len(unique_queries), train_size, test_size))
    train_queries = set(unique_queries[:train_size])
    test_queries = set(unique_queries[train_size:])
    train_qids = set([q for q in query_dict if query_dict[q] in train_queries])
    test_qids = set([q for q in query_dict if query_dict[q] in test_queries])
    miss_docs = set()
    have_docs = set()
    train_samples = []
    test_samples = []
    qd_judge = load_judge_file(judge_filepath)
    for q in qd_judge:
        for d in qd_judge[q]:
            if qd_judge[q][d] is None:  # skip documents without judgement
                continue
            if not os.path.exists(os.path.join(data_dir, 'docs', d + '.html')):
                miss_docs.add(d)
                continue
            have_docs.add(d)
            if q in train_qids:
                train_samples.append((q, d, qd_judge[q][d]))
            elif q in test_qids and not os.path.exists(run_filepath):
                test_samples.append((q, d, qd_judge[q][d]))
    if os.path.exists(run_filepath):
        run_result = load_run_file(run_filepath)
        for q, _, d, rank, score, _ in run_result:
            if qd_judge[q][d] is None:  # skip documents without judgement
                continue
            if not os.path.exists(os.path.join(data_dir, 'docs', d + '.html')):
                miss_docs.add(d)
                continue
            have_docs.add(d)
            if q in test_qids:
                test_samples.append((q, d, qd_judge[q][d]))
    print('have {} docs, miss {} docs'.format(len(have_docs), len(miss_docs)))
    save_train_test_file(train_samples,
                         os.path.join(data_dir, 'train.pointwise'))
    save_train_test_file(test_samples, os.path.join(data_dir,
                                                    'test.pointwise'))
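A small self-contained sketch (toy qids and query strings, assumed split ratio) of why the split is done over unique query strings rather than qids: qids that share the same query text always land on the same side of the split.

import numpy as np

query_dict = {'1': 'apple', '2': 'apple', '3': 'banana', '4': 'cherry'}  # qid -> query text
unique_queries = np.unique(list(query_dict.values()))
np.random.shuffle(unique_queries)
train_size = int(len(unique_queries) * 0.67)  # assumed train_test_ratio
train_queries = set(unique_queries[:train_size])
test_queries = set(unique_queries[train_size:])
# qids '1' and '2' share the query 'apple', so they can never straddle the split
train_qids = {q for q in query_dict if query_dict[q] in train_queries}
test_qids = {q for q in query_dict if query_dict[q] in test_queries}
print(train_qids, test_qids)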
Example #3
def click_model_to_rel(do=['train', 'test'],
                       files=['train.prep.pointwise', 'test.prep.pointwise']):
    """Map click-model scores in the prep files to relevance labels derived
    from quantile thresholds over the assumed label distribution."""
    train_file = os.path.join(args.data_dir, files[0])
    test_file = os.path.join(args.data_dir, files[1])
    train_click = load_judge_file(train_file, scale=float)
    clicks = []
    for q in train_click:
        for d in train_click[q]:
            clicks.append(train_click[q][d])
    clicks = sorted(clicks)
    ratio = [(0, 0.8), (1, 0.95),
             (2, 1)]  # the cumulative distribution of relevance label
    threshold = []
    k = 0
    last = '#'
    for i in range(len(clicks)):
        while i / len(clicks) >= ratio[k][1]:
            k += 1
        if last != '#' and last[0] != ratio[k][0]:
            threshold.append(last)
        last = [ratio[k][0], clicks[i]]
    threshold.append(last)
    print('ratio: {}'.format(ratio))
    print('threshold: {}'.format(threshold))

    #threshold = [[0, 0.05], [1, 0.3], [2, 1]]  # my guess
    def click2rel(click):
        k = 0
        while click > threshold[k][1]:
            k += 1
        return threshold[k][0]

    # save
    def map_fn(sample):
        nonlocal click2rel
        return PointwiseSample(sample.qid, sample.docid,
                               click2rel(sample.label), sample.query,
                               sample.doc)

    if 'train' in do:
        prep_file_mapper(train_file,
                         train_file + '.rel',
                         method='sample',
                         func=float,
                         map_fn=map_fn)
    if 'test' in do:
        prep_file_mapper(test_file,
                         test_file + '.rel',
                         method='sample',
                         func=float,
                         map_fn=map_fn)
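A minimal, self-contained sketch of the threshold derivation above: walk the sorted click scores and, each time the running quantile i / len(clicks) crosses into the next label's range of the cumulative distribution, record the last click score seen for the previous label. The click values below are toy data.

clicks = sorted((i + 1) / 25 for i in range(20))   # 0.04, 0.08, ..., 0.80
ratio = [(0, 0.8), (1, 0.95), (2, 1)]              # cumulative distribution of relevance labels

threshold, k, last = [], 0, None
for i, c in enumerate(clicks):
    while i / len(clicks) >= ratio[k][1]:
        k += 1
    if last is not None and last[0] != ratio[k][0]:
        threshold.append(last)                     # close out the previous label
    last = [ratio[k][0], c]
threshold.append(last)
print(threshold)  # [[0, 0.64], [1, 0.76], [2, 0.8]]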
Example #4
File: prep.py Project: jzbjyb/rri
def filter_judgement():
    """Drop judgements whose document URL points to a non-HTML resource."""
    filtered_ext = ['.pdf', '.ppt', '.pptx', '.doc', '.docx', '.txt']
    filtered_ext = tuple(filtered_ext + [ext.upper() for ext in filtered_ext])
    allowed_ext = tuple(['html', 'htm', 'com', 'cn', 'asp', 'shtml', 'php'])
    data_dir = args.data_dir
    docid_to_url = load_from_query_file(os.path.join(data_dir, 'docid_to_url'))
    qd_judge = load_judge_file(os.path.join(data_dir, 'judgement_rel'))
    qd_judge_new = defaultdict(lambda: defaultdict(lambda: None))
    count = 0
    for q in qd_judge:
        for d in qd_judge[q]:
            if docid_to_url[d].endswith(filtered_ext):
                count += 1
                continue
            qd_judge_new[q][d] = qd_judge[q][d]
    print('#non-html url: {}'.format(count))
    save_judge_file(qd_judge_new, os.path.join(data_dir, 'judgement'))
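A minimal sketch of the URL filter above: str.endswith accepts a tuple, so a single call checks every filtered extension (the upper-case variants are added explicitly because endswith is case-sensitive). The URLs are toy data.

filtered_ext = ['.pdf', '.ppt', '.pptx', '.doc', '.docx', '.txt']
filtered_ext = tuple(filtered_ext + [ext.upper() for ext in filtered_ext])

urls = ['http://a.com/page.html', 'http://b.com/slides.PPTX', 'http://c.com/report.pdf']
kept = [u for u in urls if not u.endswith(filtered_ext)]
print(kept)  # ['http://a.com/page.html']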
Example #5
def train_test():
    '''
    load config
    '''
    rel_level = 2
    max_q_len_consider = 10
    max_d_len_consider = 1000
    if args.config is not None:
        model_config = json.load(open(args.config))
        print('model config: {}'.format(model_config))
    '''
    load word vector
    '''
    w2v_file = os.path.join(args.data_dir, 'w2v')
    vocab_file = os.path.join(args.data_dir, 'vocab')
    print('loading word vector ...')
    wv = WordVector(filepath=w2v_file)
    vocab = Vocab(filepath=vocab_file, file_format=args.format)
    print('vocab size: {}, word vector dim: {}'.format(wv.vocab_size, wv.dim))
    if not args.tfrecord:
        '''
        load data (placeholder)
        '''
        train_file = os.path.join(args.data_dir,
                                  'train.prep.{}'.format(args.paradigm))
        test_file = os.path.join(args.data_dir,
                                 'test.prep.{}'.format(args.paradigm))
        test_file_judge = os.path.join(args.data_dir, 'test.prep.pointwise')
        doc_file = os.path.join(args.data_dir, 'docs.prep')
        if args.format == 'ir':
            query_file = os.path.join(args.data_dir, 'query.prep')
        print('loading query doc content ...')
        doc_raw = load_prep_file(doc_file, file_format=args.format)
        if args.format == 'ir':
            query_raw = load_prep_file(query_file, file_format=args.format)
        else:
            query_raw = doc_raw
        print('truncate long document')
        d_long_count = 0
        avg_doc_len, avg_truncate_doc_len = 0, 0
        truncate_len = max(max_q_len_consider, max_d_len_consider)
        for d in doc_raw:
            avg_doc_len += len(doc_raw[d])
            if len(doc_raw[d]) > truncate_len:
                d_long_count += 1
                doc_raw[d] = doc_raw[d][:truncate_len]
                avg_truncate_doc_len += truncate_len
            else:
                avg_truncate_doc_len += len(doc_raw[d])
        avg_doc_len = avg_doc_len / len(doc_raw)
        avg_truncate_doc_len = avg_truncate_doc_len / len(doc_raw)
        print(
            'total doc: {}, long doc: {}, average len: {}, average truncate len: {}'
            .format(len(doc_raw), d_long_count, avg_doc_len,
                    avg_truncate_doc_len))
        max_q_len = min(max_q_len_consider,
                        max([len(query_raw[q]) for q in query_raw]))
        max_d_len = min(max_d_len_consider,
                        max([len(doc_raw[d]) for d in doc_raw]))
        print('data assemble with max_q_len: {}, max_d_len: {} ...'.format(
            max_q_len, max_d_len))

        def relevance_mapper(r):
            if r < 0:
                return 0
            if r >= rel_level:
                return rel_level - 1
            return r

        train_X, train_y, batcher = data_assemble(
            train_file,
            query_raw,
            doc_raw,
            max_q_len,
            max_d_len,
            relevance_mapper=relevance_mapper)
        '''
        doc_len_list = []
        for q_x in train_X:
            for d in q_x['qd_size']:
                doc_len_list.append(d[1])
        doc_len_list = np.array(doc_len_list, dtype=np.int32)
        doc_len_list = [min(max_jump_offset ** 2 / d, max_jump_offset) for d in doc_len_list]
        plt.hist(doc_len_list, bins=max_jump_offset)
        plt.xlim(xmin=0, xmax=max_jump_offset)
        plt.xlabel('preserve number')
        plt.ylabel('number')
        plt.show()
        '''
        test_X, test_y, _ = data_assemble(test_file,
                                          query_raw,
                                          doc_raw,
                                          max_q_len,
                                          max_d_len,
                                          relevance_mapper=relevance_mapper)
        if args.paradigm == 'pairwise':
            test_X_judge, test_y_judge, _ = data_assemble(
                test_file_judge,
                query_raw,
                doc_raw,
                max_q_len,
                max_d_len,
                relevance_mapper=relevance_mapper)
        else:
            test_X_judge, test_y_judge = test_X, test_y
        print('number of training samples: {}'.format(
            sum([len(x['query']) for x in train_X])))
        '''
        load judge file
        '''
        test_qd_judge = load_judge_file(test_file_judge,
                                        file_format=args.format,
                                        reverse=args.reverse)
        for q in test_qd_judge:
            for d in test_qd_judge[q]:
                test_qd_judge[q][d] = relevance_mapper(test_qd_judge[q][d])
    else:
        '''
        load data (tfrecord)
        '''
        max_q_len = max_q_len_consider
        max_d_len = max_d_len_consider
        batcher = None
    '''
    train and test the model
    '''
    model_config_ = {
        'max_q_len': max_q_len,
        'max_d_len': max_d_len,
        'max_jump_step': 100,
        'word_vector': wv.get_vectors(normalize=True),
        'oov_word_vector': None,
        'vocab': vocab,
        'word_vector_trainable': False,
        'use_pad_word': True,
        'interaction': 'dot',
        'glimpse': 'all_next_hard',
        'glimpse_fix_size': 10,
        'min_density': -1,
        'use_ratio': False,
        'min_jump_offset': 3,
        'jump': 'min_density_hard',
        'represent': 'interaction_cnn_hard',
        'separate': False,
        'aggregate': 'max',
        'rnn_size': 16,
        'max_jump_offset': 50,
        'max_jump_offset2': max_q_len,
        'rel_level': rel_level,
        'loss_func': 'classification',
        'keep_prob': 1.0,
        'paradigm': args.paradigm,
        'learning_rate': 0.0002,
        'random_seed': SEED,
        'n_epochs': 30,
        'batch_size': 256,
        'batch_num': 400,
        'batcher': batcher,
        'verbose': 1,
        'save_epochs': 1,
        'reuse_model': args.reuse_model_path,
        'save_model': args.save_model_path,
        'summary_path': args.tf_summary_path,
        'tfrecord': args.tfrecord,
    }
    if args.config is not None:
        model_config_.update(model_config)
    rri = RRI(**model_config_)
    if not args.tfrecord:
        #train_X, train_y, test_X, test_y = train_X[:2560], train_y[:2560], test_X[:2560], test_y[:2560]
        print('train query: {}, test query: {}'.format(len(train_X),
                                                       len(test_X)))
        for e in rri.fit_iterable(train_X, train_y):
            start = time.time()
            loss, acc = rri.test(test_X, test_y)
            if args.format == 'ir':
                ranks, _ = rri.decision_function(test_X_judge)
                scores = evaluate(ranks, test_qd_judge, metric=ndcg, top_k=20)
                avg_score = np.mean(list(scores.values()))
            elif args.format == 'text':
                avg_score = float('nan')  # no ranking metric for the 'text' format
            print('\t{:>7}:{:>5.3f}:{:>5.3f}:{:>5.3f}'.format(
                'test_{:>3.1f}'.format((time.time() - start) / 60), loss, acc,
                avg_score),
                  end='',
                  flush=True)
    else:
        for e in rri.fit_iterable_tfrecord(
                'data/bing/test.prep.pairwise.tfrecord-???-of-???'):
            print(e)
            input()
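A minimal sketch of the relevance_mapper used above with rel_level = 2: labels are clamped into [0, rel_level - 1] so they can feed a classification loss with rel_level classes.

rel_level = 2

def relevance_mapper(r):
    # clamp any judgement value into the valid label range
    if r < 0:
        return 0
    if r >= rel_level:
        return rel_level - 1
    return r

print([relevance_mapper(r) for r in (-1, 0, 1, 2, 3)])  # [0, 0, 1, 1, 1]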