def build_random_buffer(self, num_samples):
    """Build training buffers by random selection for the introspection stage.

    *num_samples* is a comma-separated string; the first two fields give
    n0 (contiguous-slice buffers per document) and n1 (mixed
    positive+negative buffers per document).

    Returns a SimpleListDataset of the constructed buffers.
    """
    n0, n1 = [int(s) for s in num_samples.split(',')][:2]
    buffers = []
    # How many blocks (each padded by one separator token) fit in one capacity window.
    max_blk_num = CAPACITY // (BLOCK_SIZE + 1)
    logging.info('building buffers for introspection...')
    for qbuf, dbuf in tqdm(self.dataset):
        # Document-block budget once the query blocks are accounted for.
        budget = max_blk_num - len(qbuf)
        # 1. n0 back-to-back contiguous windows starting at a random offset.
        start = random.randint(0, max(0, len(dbuf) - budget * n0))
        for k in range(n0):
            window = Buffer()
            lo, hi = start + k * budget, start + (k + 1) * budget
            window.blocks = qbuf.blocks + dbuf.blocks[lo:hi]
            buffers.append(window)
        # 2. n1 random mixes: sample positives (relevance >= 1) first, then
        #    fill the remaining budget with residual (negative) blocks.
        pbuf, nbuf = dbuf.filtered(lambda blk, idx: blk.relevance >= 1, need_residue=True)
        for _ in range(n1):
            pos_sample = random.sample(pbuf.blocks, min(budget, len(pbuf)))
            neg_sample = random.sample(nbuf.blocks, min(budget - len(pos_sample), len(nbuf)))
            mixed = Buffer()
            mixed.blocks = qbuf.blocks + pos_sample + neg_sample
            buffers.append(mixed.sort_())
    return SimpleListDataset(buffers)
def build_promising_buffer(self, num_samples):
    """Build training buffers for the reasoning stage.

    *num_samples* is a comma-separated string; fields three and four give
    n2 (buffers drawn from the highest-estimation negative blocks) and
    n3 (buffers drawn from randomly sampled negative blocks).

    Returns a SimpleListDataset of the constructed buffers.
    """
    n2, n3 = [int(x) for x in num_samples.split(',')][2:]
    buffers = []
    # How many blocks (each padded by one separator token) fit in one capacity window.
    max_blk_num = CAPACITY // (BLOCK_SIZE + 1)
    logging.info('building buffers for reasoning...')
    for qbuf, dbuf in tqdm(self.dataset):
        # Split document blocks into positives (relevance >= 1) and the residue.
        pbuf, nbuf = dbuf.filtered(lambda blk, idx: blk.relevance >= 1, need_residue=True)
        # Keep at least one slot free for negative blocks.
        if len(pbuf) >= max_blk_num - len(qbuf):
            pbuf = pbuf.random_sample(max_blk_num - len(qbuf) - 1)
        # Negative-block budget per buffer.
        budget = max_blk_num - len(qbuf) - len(pbuf)
        # Rank negatives by estimation and keep the top n2 * budget of them.
        # NOTE(review): the long cast truncates fractional estimations —
        # confirm blk.estimation is integral here.
        estimations = torch.tensor([blk.estimation for blk in nbuf], dtype=torch.long)
        top = set(estimations.argsort(descending=True)[:n2 * budget].tolist())
        promising = [blk for idx, blk in enumerate(nbuf) if idx in top]
        # Tile the promising list until it can be cut into n2 full slices.
        while 0 < len(promising) < n2 * budget:
            promising = promising * (n2 * budget // len(promising) + 1)
        # n2 buffers from consecutive slices of the top-estimation negatives.
        for k in range(n2):
            buf = Buffer()
            buf.blocks = qbuf.blocks + pbuf.blocks + promising[k * budget:(k + 1) * budget]
            buffers.append(buf.sort_())
        # n3 buffers from uniformly sampled negatives.
        for _ in range(n3):
            buf = Buffer()
            buf.blocks = qbuf.blocks + pbuf.blocks + random.sample(nbuf.blocks, min(len(nbuf), budget))
            buffers.append(buf.sort_())
    return SimpleListDataset(buffers)