예제 #1
0
def dump_encodings(data: Collection[OpTrace],
                   encoder: PandasGraphEncoder,
                   sid: str,
                   path: str = None):
    """Encode every op trace in ``data`` and store the results at ``path``.

    When ``data`` is an ``IndexedFileReader`` and no ``path`` is given, the
    target defaults to ``f"{data.path}.encoded"``. For reader-backed input,
    a file that already exists at the target path is reused instead of being
    rebuilt. For any other input without an explicit path the target is
    ``'train.pkl'``, and no existence check is made.

    Args:
        data: Op traces to encode; iterated once, in order.
        encoder: Provides the encoding callable through ``get_encoder(sid)``.
        sid: Identifier passed to ``encoder.get_encoder``. Note that each
            individual encoding call receives ``op.op_info.sid`` instead.
        path: Destination file, or ``None`` to use the defaults above.

    Returns:
        An ``IndexedFileReader`` opened on the encodings file.
    """
    reader_backed = isinstance(data, IndexedFileReader)

    if reader_backed and path is None:
        path = f"{data.path}.encoded"

    # Only reader-backed input is allowed to short-circuit on an existing file.
    if reader_backed and os.path.exists(path):
        return IndexedFileReader(path)

    if path is None:
        path = 'train.pkl'

    # NOTE(review): the writer is not closed if encoding raises part-way;
    # confirm how a partially written file should be treated by the cache
    # check above before changing that.
    writer = IndexedFileWriter(path)
    encode = encoder.get_encoder(sid)
    for trace in data:
        encoded = encode(domain=trace.domain,
                         context=trace.context,
                         choice=trace.choice,
                         sid=trace.op_info.sid)
        writer.append(encoded)
    writer.close()

    return IndexedFileReader(path)
예제 #2
0
def dump_encodings(data: Collection[OpTrace], encode_func: Callable, path: str = None):
    """Encode every item of ``data`` with ``encode_op`` and write the results.

    When ``data`` is an ``IndexedFileReader`` and no ``path`` is given, the
    target defaults to ``f"{data.path}.encoded"``; if that target already
    exists it is reported and reused instead of being rebuilt. Otherwise a
    falsy ``path`` falls back to ``"data.encoded"``.

    Args:
        data: Items to encode; iterated once, in order.
        encode_func: Stored in the module-level ``global_encode_func`` before
            encoding starts. ``encode_op`` is defined outside this block and
            presumably reads that global -- confirm against its definition.
        path: Destination file, or ``None`` to use the defaults above.

    Returns:
        An ``IndexedFileReader`` opened on the encodings file.
    """
    global global_encode_func

    if isinstance(data, IndexedFileReader):
        if path is None:
            path = f"{data.path}.encoded"

        if os.path.exists(path):
            print(f"- Use cached encoding {path}")
            return IndexedFileReader(path)

    if not path:
        path = "data.encoded"

    global_encode_func = encode_func

    # Encoding runs serially in-process (a multiprocessing pool was sketched
    # here once but left disabled). All items are encoded before the output
    # file is opened.
    encoded_graphs = list(map(encode_op, data))

    writer = IndexedFileWriter(path)
    for encoded in encoded_graphs:
        writer.append(encoded)
    writer.close()

    return IndexedFileReader(path)
예제 #3
0
파일: imitation.py 프로젝트: qinxuye/atlas
    def create_operator_datasets(
            self,
            traces: Collection[GeneratorTrace],
            mode: str = 'training') -> Dict[OpInfo, Collection[OpTrace]]:
        """Group the op traces found in ``traces`` by their ``op_info``.

        With ``self.USE_DISK`` set, each operator's traces are appended to
        ``{self.work_dir}/data/{op_info.sid}/{mode}_op_data.pkl`` (the
        directory is created on demand) and the result maps each ``op_info``
        to an ``IndexedFileReader`` on that file. Otherwise the traces are
        kept in memory and a ``collections.defaultdict(list)`` is returned.

        Args:
            traces: Generator traces; each one's ``op_traces`` is iterated.
            mode: Prefix of the per-operator data file name.

        Returns:
            A mapping from ``op_info`` to that operator's traces, in the
            order they were encountered.
        """
        if not self.USE_DISK:
            data: Dict[OpInfo, List[OpTrace]] = collections.defaultdict(list)
            for trace in tqdm.tqdm(traces):
                for op in trace.op_traces:
                    data[op.op_info].append(op)

            return data

        # Both maps are keyed by the op_info object itself, not by its sid.
        file_maps: Dict[OpInfo, IndexedFileWriter] = {}
        path_maps: Dict[OpInfo, str] = {}
        try:
            for trace in tqdm.tqdm(traces):
                for op in trace.op_traces:
                    op_info = op.op_info
                    if op_info not in file_maps:
                        op_dir = f"{self.work_dir}/data/{op_info.sid}"
                        os.makedirs(op_dir, exist_ok=True)
                        data_path = f"{op_dir}/{mode}_op_data.pkl"
                        file_maps[op_info] = IndexedFileWriter(data_path)
                        path_maps[op_info] = data_path

                    file_maps[op_info].append(op)
        finally:
            # Close every writer opened so far, even when a trace fails
            # part-way through, so their handles are not leaked.
            for writer in file_maps.values():
                writer.close()

        return {k: IndexedFileReader(v) for k, v in path_maps.items()}
예제 #4
0
파일: imitation.py 프로젝트: SambhavS/atlas
    def create_operator_datasets(self, traces: Collection[GeneratorTrace],
                                 mode: str = 'training') -> Dict[str, Collection[OpTrace]]:
        """Write the op traces found in ``traces`` to one file per operator sid.

        Each op trace is appended to
        ``{self.work_dir}/data/{sid}/{mode}_op_data.pkl``, where ``sid`` is
        ``op.op_info.sid``; the directory is created on demand.

        Args:
            traces: Generator traces; each one's ``op_traces`` is iterated.
            mode: Prefix of the per-operator data file name.

        Returns:
            A dict mapping each sid to an ``IndexedFileReader`` opened on
            that operator's data file.
        """
        file_maps: Dict[str, IndexedFileWriter] = {}
        path_maps: Dict[str, str] = {}
        try:
            for trace in tqdm.tqdm(traces):
                for op in trace.op_traces:
                    sid = op.op_info.sid
                    if sid not in file_maps:
                        op_dir = f"{self.work_dir}/data/{sid}"
                        os.makedirs(op_dir, exist_ok=True)
                        data_path = f"{op_dir}/{mode}_op_data.pkl"
                        file_maps[sid] = IndexedFileWriter(data_path)
                        path_maps[sid] = data_path

                    file_maps[sid].append(op)
        finally:
            # Close every writer opened so far, even when a trace fails
            # part-way through, so their handles are not leaked.
            for writer in file_maps.values():
                writer.close()

        return {k: IndexedFileReader(v) for k, v in path_maps.items()}
예제 #5
0
파일: imitation.py 프로젝트: qinxuye/atlas
 def load_operator_datasets(self, path_maps: Dict[str, str]):
     """Open an ``IndexedFileReader`` for every path in ``path_maps``.

     Args:
         path_maps: Mapping from a key to the path of its data file.

     Returns:
         A new dict with the same keys, each mapped to an
         ``IndexedFileReader`` constructed from the corresponding path.
     """
     readers = {}
     for key, file_path in path_maps.items():
         readers[key] = IndexedFileReader(file_path)
     return readers