Exemplo n.º 1
0
 def build_random_buffer(self, num_samples):
     """Assemble randomly composed buffers for every dataset entry.

     Two kinds of buffers are produced per ``(qbuf, dbuf)`` pair taken
     from ``self.dataset``:

     1. ``n0`` buffers holding the query blocks followed by consecutive
        windows of document blocks, the first window starting at a
        random offset.
     2. ``n1`` buffers holding the query blocks plus a random sample of
        the blocks that passed the ``relevance >= 1`` filter, topped up
        with a random sample of the residue; each of these buffers is
        passed through ``sort_()``.

     Args:
         num_samples: Comma-separated integers; the first two entries
             are read as ``n0`` and ``n1``.

     Returns:
         A ``SimpleListDataset`` wrapping all buffers built.
     """
     counts = list(map(int, num_samples.split(',')))
     n0, n1 = counts[:2]
     block_budget = CAPACITY // (BLOCK_SIZE + 1)
     logging.info('building buffers for introspection...')
     buffers = []
     for qbuf, dbuf in tqdm(self.dataset):
         query_blocks = qbuf.blocks
         # Number of document blocks that still fit next to the query.
         lb = block_budget - len(qbuf)

         # 1. Continuous windows of document blocks.
         latest_start = max(0, len(dbuf) - lb * n0)
         st = random.randint(0, latest_start)
         for window in range(n0):
             begin = st + window * lb
             contiguous = Buffer()
             contiguous.blocks = query_blocks + dbuf.blocks[begin:begin + lb]
             buffers.append(contiguous)

         # 2. Random mix of filtered ("positive") and residue ("negative")
         #    blocks.
         pbuf, nbuf = dbuf.filtered(lambda blk, idx: blk.relevance >= 1,
                                    need_residue=True)
         for _ in range(n1):
             positives = random.sample(pbuf.blocks, min(lb, len(pbuf)))
             negative_count = min(lb - len(positives), len(nbuf))
             negatives = random.sample(nbuf.blocks, negative_count)
             mixed = Buffer()
             mixed.blocks = query_blocks + positives + negatives
             buffers.append(mixed.sort_())
     return SimpleListDataset(buffers)
Exemplo n.º 2
0
 def build_promising_buffer(self, num_samples):
     """Build buffers from positive blocks plus top-estimated negatives.

     For every ``(qbuf, dbuf)`` pair taken from ``self.dataset`` the
     document blocks are split with the ``relevance >= 1`` filter into
     ``pbuf`` (blocks passing the filter) and ``nbuf`` (the residue).
     Then:

     * ``n2`` buffers are built from the query blocks, the ``pbuf``
       blocks and successive slices of the ``nbuf`` blocks with the
       highest ``estimation`` values;
     * ``n3`` buffers are built from the query blocks, the ``pbuf``
       blocks and a random sample of ``nbuf`` blocks.

     Every buffer is passed through ``sort_()`` before being stored.

     Args:
         num_samples: Comma-separated integers; the third and fourth
             entries are read as ``n2`` and ``n3``. Further entries are
             ignored, mirroring how ``build_random_buffer`` reads the
             first two.

     Returns:
         A ``SimpleListDataset`` wrapping all buffers built.

     Raises:
         ValueError: If ``num_samples`` has fewer than four entries or
             an entry is not an integer.
     """
     n2, n3 = [int(x) for x in num_samples.split(',')][2:4]
     ret = []
     max_blk_num = CAPACITY // (BLOCK_SIZE + 1)
     logging.info('building buffers for reasoning...')
     for qbuf, dbuf in tqdm(self.dataset):
         pbuf, nbuf = dbuf.filtered(lambda blk, idx: blk.relevance >= 1,
                                    need_residue=True)
         # Too many positives to fit next to the query: subsample them.
         # Presumably random_sample(k) returns k blocks, which leaves room
         # for one negative block -- TODO confirm.
         if len(pbuf) >= max_blk_num - len(qbuf):
             pbuf = pbuf.random_sample(max_blk_num - len(qbuf) - 1)
         # Slots left for negative blocks in each buffer.
         lb = max_blk_num - len(qbuf) - len(pbuf)

         # Rank the negatives by estimation and keep the top n2 * lb.
         # NOTE(review): dtype=torch.long truncates fractional estimations
         # before ranking; confirm estimations are integral, otherwise the
         # order within one integer bucket is arbitrary.
         estimations = torch.tensor([blk.estimation for blk in nbuf],
                                    dtype=torch.long)
         # A Python set gives O(1) membership; testing `i in tensor`
         # compares against the whole tensor on every lookup.
         kept_indices = set(
             estimations.argsort(descending=True)[:n2 * lb].tolist())
         # Selected blocks stay in their original nbuf order.
         selected_nblks = [
             blk for i, blk in enumerate(nbuf) if i in kept_indices
         ]
         # Not enough negatives for n2 full slices: repeat the list. One
         # repetition always suffices because
         # len * (n2 * lb // len + 1) > n2 * lb.
         if 0 < len(selected_nblks) < n2 * lb:
             selected_nblks = selected_nblks * (
                 n2 * lb // len(selected_nblks) + 1)

         base_blocks = qbuf.blocks + pbuf.blocks
         # 1. Positives plus consecutive slices of the top-ranked negatives.
         for i in range(n2):
             buf = Buffer()
             buf.blocks = base_blocks + selected_nblks[i * lb:(i + 1) * lb]
             ret.append(buf.sort_())
         # 2. Positives plus randomly sampled negatives.
         for _ in range(n3):
             buf = Buffer()
             buf.blocks = base_blocks + random.sample(
                 nbuf.blocks, min(len(nbuf), lb))
             ret.append(buf.sort_())
     return SimpleListDataset(ret)