Example #1
def from_pretrained(cls, num_heads, vocabs, input_dir, nprobe, topk, gpuid, use_response_encoder=False):
    # load the saved training arguments and the pretrained query encoder
    model_args = torch.load(os.path.join(input_dir, 'args'))
    model = MultiProjEncoder.from_pretrained_projencoder(num_heads, vocabs['src'], model_args, os.path.join(input_dir, 'query_encoder'))
    # memory pool: one tokenized candidate per line
    mem_pool = [line.strip().split() for line in open(os.path.join(input_dir, 'candidates.txt')).readlines()]

    # either load a response encoder to recompute memory features on the fly,
    # or load the precomputed feature matrix
    if use_response_encoder:
        mem_feat_or_feat_maker = ProjEncoder.from_pretrained(vocabs['tgt'], model_args, os.path.join(input_dir, 'response_encoder'))
    else:
        mem_feat_or_feat_maker = torch.load(os.path.join(input_dir, 'feat.pt'))

    # restore the prebuilt MIPS index and the max norm recorded when it was built
    mips = MIPS.from_built(os.path.join(input_dir, 'mips_index'), nprobe=nprobe)
    mips_max_norm = torch.load(os.path.join(input_dir, 'max_norm.pt'))
    retriever = cls(vocabs, model, mips, mips_max_norm, mem_pool, mem_feat_or_feat_maker, num_heads, topk, gpuid)
    return retriever
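
A minimal usage sketch for this factory method, assuming `Retriever` is the class that defines it and that `input_dir` already contains the artifacts loaded above (`args`, `query_encoder`, `candidates.txt`, `feat.pt` or `response_encoder`, `mips_index`, `max_norm.pt`); the vocab paths and hyperparameter values are placeholders, not taken from the source:

# Hypothetical usage; 'Retriever' stands in for whichever class defines from_pretrained,
# and the file paths / hyperparameters below are placeholders.
vocabs = {'src': Vocab('src.vocab', 0, [BOS, EOS]),
          'tgt': Vocab('tgt.vocab', 0, [BOS, EOS])}
retriever = Retriever.from_pretrained(num_heads=4, vocabs=vocabs,
                                      input_dir='ckpt/index_dir',
                                      nprobe=64, topk=8, gpuid=0,
                                      use_response_encoder=False)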
Example #2
from mips import MIPS

def mem_test():
    m = MIPS(0, [])
    m.parse_instruction("lui r1, 10")
    m.parse_instruction("sw r1, 10(r0)")

    assert (m.ram[10] == 10)
    assert (m.registers[1] == 10)

    print("Memory tests passed")
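
In the same spirit, small arithmetic tests can preset registers directly (the loaders in Example #7 assign to `m.registers` and `m.ram` the same way). A sketch under the assumption that `add` follows the usual `rd, rs, rt` operand order in this simulator:

from mips import MIPS

def add_test():
    m = MIPS(0, [])
    # preset the source registers directly, as the register loader in Example #7 does
    m.registers[1] = 2
    m.registers[2] = 3
    m.parse_instruction("add r3, r1, r2")
    assert (m.registers[3] == 5)  # assumes the usual rd = rs + rt semantics
    print("Arithmetic test passed")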
Example #3
def main(args):
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)
    logger.info('Loading model...')
    logger.info("using %d gpus", torch.cuda.device_count())
    device = torch.device('cuda', 0)

    vocab = Vocab(args.vocab_path, 0, [BOS, EOS])
    model_args = torch.load(args.args_path)
    model = ProjEncoder.from_pretrained(vocab, model_args, args.ckpt_path)
    model.to(device)

    logger.info('Collecting data...')
    data = []
    line_id = -1
    with open(args.input_file) as f:
        for line in f.readlines():
            r = line.strip()
            line_id += 1
            data.append([r, line_id])

    if args.only_dump_feat:
        max_norm = torch.load(
            os.path.join(os.path.dirname(args.index_path), 'max_norm.pt'))
        used_data = [x[0] for x in data]
        used_ids = np.array([x[1] for x in data])
        used_data, used_ids, _ = get_features(args.batch_size, args.norm_th,
                                              vocab, model, used_data,
                                              used_ids, max_norm)
        used_data = used_data[:, 1:]
        assert (used_ids == np.sort(used_ids)).all()
        logger.info('Dumping %d instances', used_data.shape[0])
        torch.save(torch.from_numpy(used_data),
                   os.path.join(os.path.dirname(args.index_path), 'feat.pt'))
        exit(0)

    logger.info('Collected %d instances', len(data))
    max_norm = args.max_norm
    if args.train_index:
        random.shuffle(data)
        used_data = [x[0] for x in data[:args.max_training_instances]]
        used_ids = np.array([x[1] for x in data[:args.max_training_instances]])
        logger.info('Computing feature for training')
        used_data, used_ids, max_norm = get_features(
            args.batch_size,
            args.norm_th,
            vocab,
            model,
            used_data,
            used_ids,
            max_norm_cf=args.max_norm_cf)
        logger.info('Using %d instances for training', used_data.shape[0])
        mips = MIPS(model_args.output_dim + 1,
                    args.index_type,
                    efSearch=args.efSearch,
                    nprobe=args.nprobe)
        mips.to_gpu()
        mips.train(used_data)
        mips.to_cpu()
        if args.add_to_index:
            mips.add_with_ids(used_data, used_ids)
            data = data[args.max_training_instances:]
        mips.save(args.index_path)
        torch.save(
            max_norm,
            os.path.join(os.path.dirname(args.index_path), 'max_norm.pt'))
    else:
        mips = MIPS.from_built(args.index_path, nprobe=args.nprobe)
        max_norm = torch.load(
            os.path.join(os.path.dirname(args.index_path), 'max_norm.pt'))

    if args.add_to_index:
        cur = 0
        while cur < len(data):
            used_data = [x[0] for x in data[cur:cur + args.add_every]]
            used_ids = np.array([x[1] for x in data[cur:cur + args.add_every]])
            cur += args.add_every
            logger.info('Computing feature for indexing')
            used_data, used_ids, _ = get_features(args.batch_size,
                                                  args.norm_th, vocab, model,
                                                  used_data, used_ids,
                                                  max_norm)
            logger.info('Adding %d instances to index', used_data.shape[0])
            mips.add_with_ids(used_data, used_ids)
        mips.save(args.index_path)
Example #4
def main(args):
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)
    logger.info('Loading model...')
    device = torch.device('cuda', 0)

    vocab = Vocab(args.vocab_path, 0, [BOS, EOS])
    model_args = torch.load(args.args_path)
    model = ProjEncoder.from_pretrained(vocab, model_args, args.ckpt_path)
    model.to(device)

    logger.info('Collecting data...')

    data_r = []
    with open(args.index_file) as f:
        for line in f.readlines():
            r = line.strip()
            data_r.append(r)

    data_q = []
    data_qr = []
    with open(args.input_file, 'r') as f:
        for line in f.readlines():
            q, r = line.strip().split('\t')
            data_q.append(q)
            data_qr.append(r)

    logger.info('Collected %d instances', len(data_q))
    textq, textqr, textr = data_q, data_qr, data_r
    data_loader = DataLoader(data_q, vocab, args.batch_size)

    mips = MIPS.from_built(args.index_path, nprobe=args.nprobe)
    max_norm = torch.load(os.path.dirname(args.index_path) + '/max_norm.pt')
    mips.to_gpu()
    model.cuda()
    model = torch.nn.DataParallel(model,
                                  device_ids=list(
                                      range(torch.cuda.device_count())))
    model.eval()

    logger.info('Start search')
    cur, tot = 0, len(data_q)
    with open(args.output_file, 'w') as fo:
        for batch in asynchronous_load(data_loader):
            with torch.no_grad():
                q = move_to_device(batch, torch.device('cuda')).t()
                bsz = q.size(0)
                vecsq = model(q, batch_first=True).detach().cpu().numpy()
            vecsq = augment_query(vecsq)
            D, I = mips.search(vecsq, args.topk + 1)
            D = l2_to_ip(D, vecsq, max_norm) / (max_norm * max_norm)
            for i, (Ii, Di) in enumerate(zip(I, D)):
                item = [textq[cur + i], textqr[cur + i]]
                for pred, s in zip(Ii, Di):
                    if args.allow_hit or textr[pred] != textqr[cur + i]:
                        item.append(textr[pred])
                        item.append(str(float(s)))
                item = item[:2 + 2 * args.topk]
                assert len(item) == 2 + 2 * args.topk
                fo.write('\t'.join(item) + '\n')
            cur += bsz
            logger.info('finished %d / %d', cur, tot)
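
The `augment_query` and `l2_to_ip` helpers used above are not defined in these examples; they appear to implement the standard reduction of maximum inner-product search to L2 nearest-neighbour search, which would also explain the `output_dim + 1` index dimension and the `used_data[:, 1:]` slice in Example #3. A sketch of that reduction, assuming the index returns squared L2 distances; the function names below are illustrative, not the project's API:

import numpy as np

def augment_data_sketch(x, max_norm):
    # database vectors get an extra leading coordinate sqrt(max_norm^2 - ||x||^2),
    # so L2 nearest-neighbour search over the augmented vectors ranks by inner product
    extra = np.sqrt(np.maximum(max_norm ** 2 - (x ** 2).sum(axis=1, keepdims=True), 0.0))
    return np.hstack([extra, x]).astype('float32')

def augment_query_sketch(q):
    # query vectors get a leading zero, which leaves the inner product unchanged
    return np.hstack([np.zeros((q.shape[0], 1), dtype=q.dtype), q])

def l2_to_ip_sketch(d2, q, max_norm):
    # ||x_aug - q_aug||^2 = max_norm^2 + ||q||^2 - 2 <x, q>,
    # so the inner product can be recovered from the returned squared distances
    q_norm2 = (q ** 2).sum(axis=1, keepdims=True)
    return (max_norm ** 2 + q_norm2 - d2) / 2.0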
Example #5
def rebuild_index(self, index_dir, batch_size=2048, add_every=1000000, index_type='IVF1024_HNSW32,SQ8', norm_th=999, max_training_instances=1000000, max_norm_cf=1.0, nprobe=64, efSearch=128):
    if not os.path.exists(index_dir):
        os.mkdir(index_dir)
    max_norm = None
    data = [[' '.join(x), i] for i, x in enumerate(self.mem_pool)]
    random.shuffle(data)
    # train the index on a random subset of the memory pool
    used_data = [x[0] for x in data[:max_training_instances]]
    used_ids = np.array([x[1] for x in data[:max_training_instances]])
    logger.info('Computing feature for training')
    used_data, used_ids, max_norm = get_features(batch_size, norm_th, self.vocabs['tgt'], self.mem_feat_or_feat_maker, used_data, used_ids, max_norm_cf=max_norm_cf)
    torch.cuda.empty_cache()
    logger.info('Using %d instances for training', used_data.shape[0])
    mips = MIPS(self.model.output_dim + 1, index_type, efSearch=efSearch, nprobe=nprobe)
    mips.to_gpu()
    mips.train(used_data)
    mips.to_cpu()
    mips.add_with_ids(used_data, used_ids)
    data = data[max_training_instances:]
    torch.save(max_norm, os.path.join(index_dir, 'max_norm.pt'))

    # add the remaining instances to the trained index in chunks
    cur = 0
    while cur < len(data):
        used_data = [x[0] for x in data[cur:cur + add_every]]
        used_ids = np.array([x[1] for x in data[cur:cur + add_every]])
        cur += add_every
        logger.info('Computing feature for indexing')
        used_data, used_ids, _ = get_features(batch_size, norm_th, self.vocabs['tgt'], self.mem_feat_or_feat_maker, used_data, used_ids, max_norm)
        logger.info('Adding %d instances to index', used_data.shape[0])
        mips.add_with_ids(used_data, used_ids)
    mips.save(os.path.join(index_dir, 'mips_index'))
Example #6
def update_index(self, index_dir, nprobe):
    # reload the freshly built index (and its max norm) from disk
    self.mips = MIPS.from_built(os.path.join(index_dir, 'mips_index'), nprobe=nprobe)
    if self.gpuid >= 0:
        self.mips.to_gpu(gpuid=self.gpuid)
    self.mips_max_norm = torch.load(os.path.join(index_dir, 'max_norm.pt'))
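
Together with Example #5 this forms the retriever's refresh cycle: `rebuild_index` writes `mips_index` and `max_norm.pt` into a directory, and `update_index` reloads them (moving the index back to the GPU if one is configured). A minimal sketch of that cycle, with a placeholder directory name:

# hypothetical refresh cycle; 'new_index_dir' is a placeholder
retriever.rebuild_index('new_index_dir')
retriever.update_index('new_index_dir', nprobe=64)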
Example #7
from mips import MIPS
import re

m = MIPS(0, [])

# Read the data memory contents
for line in open("ram.txt", "r"):
    words = re.compile(r'\w+').findall(line)
    if len(words) == 2:
        m.ram[int(words[0])] = int(words[1])

# Read the register state
for line in open("reg.txt", "r"):
    words = re.compile(r'\w+').findall(line)
    if len(words) == 2:
        m.registers[int(words[0])] = int(words[1])

# Read, execute and analyze the instructions
for line in open("instr.txt", "r"):
    m.parse_instruction(line)

# Comment/uncomment the following lines as needed:

# Print the register state after every instruction
for i, x in enumerate(m.reg_hist):
    print("register state after instruction " + str(i))
    x.print_touched()

# Print the data memory (RAM) state after every instruction
for i, x in enumerate(m.ram_hist):
    print("ram state after instruction " + str(i))
    x.print_touched()
Example #8
from mips import MIPS

m = MIPS(0, [])

m.parse_instruction("add r6, r16, r5")
m.parse_instruction("sub r1, r6, r3")
m.parse_instruction("or r8, r6, r2")
m.parse_instruction("lw r1, 100(r7)")
m.parse_instruction("xor r10, r1, r11")
m.parse_instruction("sw r3, 50(r0)")
m.parse_instruction("nop")

for i, x in enumerate(m.reg_hist):
    print("register state after instruction " + str(i))
    x.print_touched()

for i, x in enumerate(m.ram_hist):
    print("ram state after instruction " + str(i))
    x.print_touched()

print("How to interpret this test:\n"
      "Only the registers and RAM cells that have been written to are printed.\n"
      "If a register or a RAM cell is not shown in the console output when the\n"
      "register/RAM state is printed after an instruction, it is assumed to\n"
      "hold the default value (0).")


m.print_table()
Example #9
def run_pred(args):
    if not os.path.exists(args.pred_dir):
        os.makedirs(args.pred_dir)

    with open(args.data_path, 'r') as fp:
        test_data = json.load(fp)
    pairs = []
    qid2text = {}
    for doc_idx, article in enumerate(test_data['data']):
        for para_idx, paragraph in enumerate(article['paragraphs']):
            for qa in paragraph['qas']:
                id_ = qa['id']
                question = qa['question']
                qid2text[id_] = question
                pairs.append([doc_idx, para_idx, id_, question])

    with h5py.File(args.question_dump_path, 'r') as question_dump:
        vecs = []
        q_texts = []
        for doc_idx, para_idx, id_, question in tqdm(pairs):
            vec = question_dump[id_][0, :]
            vecs.append(vec)

            if args.sparse:
                q_texts.append(qid2text[id_])

        query = np.stack(vecs, 0)
        if args.draft:
            query = query[:3]

    if not args.sparse:
        mips = MIPS(args.phrase_dump_dir,
                    args.index_path,
                    args.idx2id_path,
                    args.max_answer_length,
                    para=args.para,
                    num_dummy_zeros=args.num_dummy_zeros,
                    cuda=args.cuda)
    else:
        mips = MIPSSparse(args.phrase_dump_dir,
                          args.index_path,
                          args.idx2id_path,
                          args.ranker_path,
                          args.max_answer_length,
                          para=args.para,
                          tfidf_dump_dir=args.tfidf_dump_dir,
                          sparse_weight=args.sparse_weight,
                          sparse_type=args.sparse_type,
                          cuda=args.cuda,
                          max_norm_path=args.max_norm_path)

    # recall at k
    cd_results = []
    od_results = []
    step_size = args.step_size
    is_ = range(0, query.shape[0], step_size)
    #is_ = range(0, 500, step_size)
    for i in tqdm(is_):
        each_query = query[i:i + step_size]
        if args.sparse:
            each_q_text = q_texts[i:i + step_size]

        if args.no_od:
            doc_idxs, para_idxs, _, _ = zip(*pairs[i:i + step_size])
            if not args.sparse:
                each_results = mips.search(each_query,
                                           top_k=args.top_k,
                                           doc_idxs=doc_idxs,
                                           para_idxs=para_idxs)
            else:
                each_results = mips.search(each_query,
                                           top_k=args.top_k,
                                           doc_idxs=doc_idxs,
                                           para_idxs=para_idxs,
                                           start_top_k=args.start_top_k,
                                           q_texts=each_q_text)
            cd_results.extend(each_results)

        else:
            if not args.sparse:
                each_results = mips.search(each_query,
                                           top_k=args.top_k,
                                           nprobe=args.nprobe)
            else:
                each_results = mips.search(
                    each_query,
                    top_k=args.top_k,
                    nprobe=args.nprobe,
                    mid_top_k=args.mid_top_k,
                    start_top_k=args.start_top_k,
                    q_texts=each_q_text,
                    filter_=args.filter,
                    search_strategy=args.search_strategy,
                    doc_top_k=args.doc_top_k)
            od_results.extend(each_results)

    top_k_answers = {
        query_id: [(result["score"], result['answer'], result["context"])
                   for result in each_results]
        for (_, _, query_id, _), each_results in zip(pairs, od_results)
    }
    answers = {
        query_id: each_results[0]['answer']
        for (_, _, query_id, _), each_results in zip(pairs, cd_results)
    }

    if args.para:
        print('dumping %s' % args.cd_out_path)
        with open(args.cd_out_path, 'w') as fp:
            json.dump(answers, fp)

    print('dumping %s' % args.od_out_path)
    with open(args.od_out_path, 'w') as fp:
        json.dump(top_k_answers, fp)

    from collections import Counter
    counter = Counter(result['doc_idx'] for each in od_results
                      for result in each)
    with open(args.counter_path, 'w') as fp:
        json.dump(counter, fp)