Example #1
    def __call__(self, data: DFDataChunk, preserve_index=True, return_hash=False):
        hash = None
        if data.cachable():
            # Build a cache key from the repr of every transform, the index flag
            # and the hash of the source chunk.
            rep = {f'tr_{idx}': ("%r" % t) for idx, t in enumerate(self.transforms)}
            hash = Cacher.build_hash(preserve_index, **rep, SRC=data.jit_hash())
            with jit_cacher.instance() as cacher:
                dc = cacher.read_datachunk(hash)
                # Cache hit: return the stored chunk as a DataFrame.
                if dc:
                    if not return_hash:
                        return dc.as_df()
                    else:
                        return dc.as_df(), hash

        # Cache miss (or non-cachable chunk): materialise the DataFrame and run the pipeline.
        data = data.as_df()
        if preserve_index:
            data['index'] = data.index
        for t in self.transforms:
            data = t(data)
            if data.empty:
                LOGGER.warning(f'{t.__class__.__name__} returned empty data. '
                               'Skipping all further transforms')
                if not return_hash:
                    return data
                else:
                    return data, None

        if hash is not None:
            # Store the transformed result so the next call with the same inputs is a cache hit.
            dc = DFDataChunk.from_df(data, hash)
            with jit_cacher.instance() as cacher:
                cacher.store_datachunk(hash, dc)

        if not return_hash:
            return data
        else:
            return data, hash
Example #2
    def refresh_df(self, df_name_to_refresh: str):
        db = self.ds.dataset_name
        hash = Cacher.build_hash(name=df_name_to_refresh, db=db)
        with jit_cacher.instance(self.ds.cacher) as cacher:
            df = cacher.read_df(hash, db=db)
        self.__dfs[df_name_to_refresh] = df
        return df
Example #3
        def __getitem__(self, key):
            if key not in self.__props:
                with jit_cacher.instance(self.ds.cacher) as cacher:
                    value = cacher.read_attr(self.ds.dataset_name, key)
                self.__props[key] = value

            return self.__props[key]
Example #4
    def set_df(self, df_name, df):
        assert self.ds.connected or multiprocessing.parent_process() is None, \
            "Direct write to the metainfo from a child process is forbidden and error-prone, " \
            "consider using modify_attr"
        db = self.ds.dataset_name
        hash = Cacher.build_hash(name=df_name, db=db)
        with jit_cacher.instance(self.ds.cacher) as cacher:
            cacher.store_df(hash, df, db=db)
        self.__dfs[df_name] = df
Example #5
        def get_df(self, df_name):
            if df_name not in self.__dfs:
                db = self.ds.dataset_name
                hash = Cacher.build_hash(name=df_name, db=db)
                with jit_cacher.instance(self.ds.cacher) as cacher:
                    df = cacher.read_df(hash, db=db)
                self.__dfs[df_name] = df

            return self.__dfs[df_name]
Example #6
        def update_df(self, df_name, df_update: Callable[[Any, Cacher], pd.DataFrame]):
            db = self.ds.dataset_name
            hash = Cacher.build_hash(name=df_name, db=db)
            with jit_cacher.instance(self.ds.cacher) as cacher:
                df = cacher.read_df(hash, db=db)
                df = df_update(df, cacher)
                cacher.store_df(hash, df, db)

            self.__dfs[df_name] = df

            return df
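
Examples #2 and #4 through #6 together form a small per-dataset DataFrame store: keys are built with Cacher.build_hash(name=..., db=dataset_name), writes go through the cacher, and an in-memory dict serves repeated reads. The sketch below ties them together; it assumes the metainfo object is reachable as dataset.meta (as Example #18's target_dataset.meta["cfg"] suggests), and the 'hits' DataFrame name is illustrative.

import pandas as pd

def demo_meta_dfs(dataset):
    # Sketch only: `dataset` stands for a connected AriadneDataset; `dataset.meta`
    # is assumed to be the metainfo object shown in Examples #2 and #4-#6.
    hits = pd.DataFrame({'event': [0, 0, 1], 'x': [0.1, 0.2, 0.3]})

    dataset.meta.set_df('hits', hits)       # write-through to the cache (Example #4)
    df = dataset.meta.get_df('hits')        # cached in-memory after the first read (Example #5)

    # Read-modify-write under a single cacher instance (Example #6)
    df = dataset.meta.update_df('hits', lambda d, cacher: d[d.event == 0])

    return dataset.meta.refresh_df('hits')  # force a re-read from the cache (Example #2)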
Example #7
        def __setitem__(self, key, item):
            assert self.ds.connected or multiprocessing.parent_process() is None, \
                "Direct write to the metainfo from a child process is forbidden and error-prone, " \
                "consider using modify_attr"

            with jit_cacher.instance(self.ds.cacher) as cacher:
                cacher.store_attr(self.ds.dataset_name, key, item)

            if key in self.__props:
                assert type(item) is type(self.__props[key]), \
                    f"Type mismatch: new value is {type(item)}, stored value is {type(self.__props[key])}"
            self.__props[key] = item
Example #8
    def __init__(self,
                 batch_size: int,
                 dataset_cls: TorchGraphDataset.__class__,
                 collate_fn: Callable,
                 subset_cls: Subset.__class__ = Subset,
                 with_random=True):
        # HACK: figure out how to share this mutex between processes with PyTorch multiprocessing
        lock = multiprocessing.Lock()
        jit_cacher.init_locks(lock)
        # END HACK

        super(GraphDataLoader_NEW, self).__init__(batch_size)
        self.subset_cls = subset_cls
        self.dataset = dataset_cls()
        with jit_cacher.instance() as cacher:
            self.dataset.connect(cacher,
                                 self.dataset.dataset_name,
                                 drop_old=False,
                                 mode='r')

        assert len(self.dataset) > 0, f"empty dataset {self.dataset.dataset_name}"
        n_train = int(len(self.dataset) * 0.8)
        n_valid = len(self.dataset) - n_train

        LOGGER.info(
            f"Initializing dataloader with {self.dataset.__class__.__name__}. Len = {len(self.dataset)}."
        )
        LOGGER.info(f"Num train samples: {n_train}")
        LOGGER.info(f"Num valid samples: {n_valid}")

        if with_random:
            self.train_data, self.val_data = random_split(
                self.dataset, [n_train, n_valid])
            self.train_data.__class__ = self.subset_cls
            self.val_data.__class__ = self.subset_cls
        else:
            self.train_data = self.subset_cls(self.dataset, range(0, n_train))
            self.val_data = self.subset_cls(self.dataset,
                                            range(n_train, n_train + n_valid))
        self.collate_fn = collate_fn
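
For reference, a hedged construction sketch for this loader; GraphDatasetImpl and collate_graphs are placeholders for a concrete TorchGraphDataset subclass and collate function, which the snippet above does not define.

# Sketch only: GraphDatasetImpl and collate_graphs are illustrative placeholders.
loader = GraphDataLoader_NEW(batch_size=32,
                             dataset_cls=GraphDatasetImpl,
                             collate_fn=collate_graphs)
# loader.train_data and loader.val_data now hold the 80/20 split created in __init__.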
Example #9
    def __init__(self,
                 parse_cfg: Dict,
                 global_transformer: Transformer,
                 n_stations: int):

        self.parse_cfg = parse_cfg
        self.loaded_model_state = None
        self.model_loader = None

        self.n_stations = n_stations

        self.global_transformer = global_transformer

        global_lock = multiprocessing.Lock()
        jit_cacher.init_locks(global_lock)

        with jit_cacher.instance() as cacher:
            cacher.init()

        self.data_df_transformed = []
        self.data_df_hashes = []
Example #10
def clean_jit_cache(str_to_clean):
    global cacher, date
    global_lock = multiprocessing.Lock()
    jit_cacher.init_locks(global_lock)
    with jit_cacher.instance() as cacher:
        cacher.init()
    df = cacher.cache

    # Split e.g. "2d" into the amount (2) and the unit suffix ("d").
    time_to_clean = str_to_clean[:-1]
    time = int(time_to_clean)
    clear_from = None
    if str_to_clean[-1] == 'w':
        clear_from = datetime.datetime.now() - datetime.timedelta(weeks=time)
    if str_to_clean[-1] == 'm':
        clear_from = datetime.datetime.now() - datetime.timedelta(minutes=time)
    if str_to_clean[-1] == 'd':
        clear_from = datetime.datetime.now() - datetime.timedelta(days=time)
    if str_to_clean[-1] == 'h':
        clear_from = datetime.datetime.now() - datetime.timedelta(hours=time)
    assert clear_from is not None, 'provide the time span to clean in a format like "1w" (one week), ' \
                                   '"4m" (4 minutes) or "3h" (three hours)'
    print(
        f"Warning: this will remove all cache entries from the last {str_to_clean}, i.e. from {clear_from} until now"
    )
    all_times = [
        datetime.datetime.fromisoformat(date).timestamp()
        for date in df.date.values
    ]
    removed = []
    for idx, val in enumerate(all_times):
        if val > clear_from.timestamp():
            removed.append(idx)
    if confirm():
        with cacher.handle(cacher.cache_db_path, mode='a') as f:
            for ind in removed:
                row = df.iloc[ind]
                path = row.key
                if path in f:
                    print(f"deleting path {path}")
                    del f[path]
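
A call sketch, using the time-span format described by the assert above:

# Sketch only: removes (after confirmation) every cache entry written during the last two days.
clean_jit_cache("2d")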
Example #11
File: graph.py Project: t3hseus/ariadne
def read_graph_hdf5(db, filename):
    with jit_cacher.instance() as cacher:
        return cacher.read_custom(db, filename, read_graph_hdf5_custom)
Example #12
def read_df_from_hash(hash):
    with jit_cacher.instance() as cacher:
        result_df = cacher.read_df(hash)
        if result_df is not None and not result_df.empty:
            return result_df
    return None
Example #13
    def modify_attr(self, key, update_meth: Callable[[Any], Any]):
        with jit_cacher.instance(self.ds.cacher) as cacher:
            value = cacher.update_attr(self.ds.dataset_name, key, update_meth)
        self.__props[key] = value
        return value
Example #14
def store_df_from_hash(df, hash):
    with jit_cacher.instance() as cacher:
        cacher.store_df(hash, df)
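
Examples #12 and #14 are symmetric helpers; a round-trip sketch follows. The 'tracks' name and 'bmn_dataset' db value are illustrative, and Cacher, jit_cacher and the two helpers are assumed to be importable from the project (their module paths are not shown in these snippets).

import pandas as pd

# Sketch only: in a fresh process the cacher is typically initialised first,
# as in Examples #9 and #10 (jit_cacher.init_locks(...) + cacher.init()).
df = pd.DataFrame({'event': [0, 1], 'station': [3, 5]})
df_hash = Cacher.build_hash(name='tracks', db='bmn_dataset')

store_df_from_hash(df, df_hash)          # Example #14
restored = read_df_from_hash(df_hash)    # Example #12; None if nothing is cached for this hash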
Example #15
    def refresh_attr(self, key_to_refresh: str):
        with jit_cacher.instance(self.ds.cacher) as cacher:
            value = cacher.read_attr(self.ds.dataset_name, key_to_refresh)
        self.__props[key_to_refresh] = value
        return value
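
Examples #3, #7, #13 and #15 expose dataset-level attributes as a mapping backed by the cacher. The sketch below ties them together; it again assumes the metainfo object is dataset.meta (cf. Example #18), and the 'n_events' key is illustrative.

def demo_meta_attrs(dataset):
    # Sketch only: `dataset.meta` is assumed to be the metainfo object from Examples #3, #7, #13, #15.
    dataset.meta["cfg"] = "batch_size=32"     # __setitem__ stores via store_attr (Example #7)
    cfg = dataset.meta["cfg"]                 # __getitem__ caches the value locally (Example #3)

    # From a child process, modify_attr is the recommended write path (see the assert in Example #7).
    dataset.meta.modify_attr("n_events", lambda old: (old or 0) + 1)

    return dataset.meta.refresh_attr("cfg")   # drop the local copy and re-read (Example #15)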
Example #16
File: graph.py Project: t3hseus/ariadne
def save_graph_hdf5(db, graph, filename):
    with jit_cacher.instance() as cacher:
        cacher.store_custom(db, filename, graph, save_graph_hdf5_custom)
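
Examples #11 and #16 are the matching read/write helpers for graphs; a round-trip sketch, with an illustrative db name and filename:

def roundtrip_graph(graph):
    # Sketch only: 'graphs_db' and the filename are illustrative values.
    save_graph_hdf5('graphs_db', graph, 'graph_event_0')
    return read_graph_hdf5('graphs_db', 'graph_event_0')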
Example #17
    def run(self):
        try:
            jit_cacher.init_locks(self.global_lock)

            with jit_cacher.instance() as cacher:
                cacher.init()
            data_df = cacher.read_df(self.df_hash)
            if data_df.empty:
                self.result_queue.put([self.idx, ''])
                return

            # Each worker works in its own dataset path derived from its index;
            # the parent merges the results later via global_submit (Example #18).
            old_path = self.main_dataset.dataset_name
            new_path = old_path + f"_{self.basename}_p{self.idx}"
        except KeyboardInterrupt:
            print(
                f"KeyboardInterrupt in process {os.getpid()}. No result will be returned."
            )
            self.result_queue.put([self.idx, ''])
            jit_cacher.fini_locks()
            return

        current_work = []
        try:
            data_df = data_df[(data_df.event >= self.work_slice[0])
                              & (data_df.event < self.work_slice[1])]
            for ev_id, event in data_df.groupby('event'):
                try:
                    #chunk = DFDataChunk.from_df(event)
                    processed = self.target_processor(event)
                    if processed is None:
                        continue
                    idx = f"graph_{self.basename}_{ev_id}"
                    current_work.append((processed, idx))
                    #postprocessed = self.target_postprocessor(processed, process_dataset, idx)
                    # Progress heartbeat: the parent advances its progress bar by 10 for every [-1].
                    if ev_id % 10 == 0:
                        self.result_queue.put([-1])
                except KeyboardInterrupt:
                    raise KeyboardInterrupt
                except:
                    stack = traceback.format_exc()
                    print(
                        f"Exception in process {os.getpid()}! details below: {stack}"
                    )
                    self.result_queue.put([-2, stack])
                    break
                # Stop early if the parent process has signalled cancellation.
                if not self.message_queue.empty():
                    break
        except KeyboardInterrupt:
            print(f"KeyboardInterrupt in process {os.getpid()}.")
        try:
            process_dataset: AriadneDataset = None
            with self.main_dataset.open_dataset(cacher,
                                                new_path) as process_dataset:
                for (processed, idx) in current_work:
                    self.target_postprocessor(processed, process_dataset, idx)
        except KeyboardInterrupt:
            print(f"KeyboardInterrupt in process {os.getpid()}.")
        try:
            print(
                f"Submitting data to the main storage for process {os.getpid()}..."
            )
            if process_dataset is None:
                self.result_queue.put([self.idx, ''])
                return

            process_dataset.dataset_name = old_path
            process_dataset.local_submit()
        except KeyboardInterrupt:
            print(
                f"KeyboardInterrupt while merging data in process {os.getpid()}. No result will be returned."
            )
            new_path = ''

        # finish signal
        self.result_queue.put([self.idx, new_path])

        jit_cacher.fini_locks()
Example #18
def preprocess_mp(transformer: Transformer,
                  target_processor: IPreprocessor,
                  target_postprocessor: IPostprocessor,
                  target_dataset: AriadneDataset,
                  process_num: int = None,
                  chunk_size: int = 1):
    os.makedirs(f"output/{target_dataset.dataset_name}", exist_ok=True)
    setup_logger(f"output/{target_dataset.dataset_name}",
                 target_processor.__class__.__name__)

    global_lock = multiprocessing.Lock()
    jit_cacher.init_locks(global_lock)

    with jit_cacher.instance() as cacher:
        cacher.init()

    with target_dataset.open_dataset(cacher) as ds:
        target_dataset = ds

    target_dataset.meta["cfg"] = gin.config_str()

    # warnings to exceptions:
    pd.set_option('mode.chained_assignment', 'raise')

    LOGGER.info(
        "GOT config: \n======config======\n %s \n========config=======" %
        gin.config_str())
    process_num = multiprocessing.cpu_count() if process_num is None else process_num
    LOGGER.info(
        f"Running with {process_num} processes, chunk_size={chunk_size}"
    )

    for data_df, basename, df_hash in parse():
        LOGGER.info("[Preprocess]: started processing a df with %d rows:" %
                    len(data_df))

        data_df, hash = transformer(DFDataChunk.from_df(data_df, df_hash),
                                    return_hash=True)
        event_count = data_df.event.nunique()
        events = sorted(list(data_df.event.unique()))
        chunk_size = event_count // process_num
        if chunk_size == 0:
            process_num = 1
            chunk_size = 1

        result_queue = multiprocessing.Queue()
        message_queue = multiprocessing.Queue()
        workers = []
        workers_result = [""] * process_num
        for i in range(0, process_num):

            # The last worker takes the tail of the event range.
            if i == process_num - 1:
                work_slice = (events[i * chunk_size], 1e10)
            else:
                work_slice = (events[i * chunk_size],
                              events[(i + 1) * chunk_size])

            workers.append(
                EventProcessor(hash, basename, target_processor,
                               target_postprocessor, target_dataset,
                               work_slice, i, result_queue, message_queue,
                               global_lock))
            workers[-1].start()
        canceled = False
        try:
            with tqdm(total=len(events)) as pbar:
                while any(workers):
                    obj = result_queue.get()
                    if obj[0] == -1:
                        pbar.update(n=10)
                    elif obj[0] == -2:
                        LOGGER.info(f"Process got exception: {obj[1]}.")
                        return
                    else:
                        pbar.update()
                        LOGGER.debug(
                            f"Process idx={obj[0]} has finished processing, joining..."
                        )
                        workers[obj[0]].join()
                        workers[obj[0]].close()
                        workers[obj[0]] = False
                        workers_result[obj[0]] = obj[1]
        except KeyboardInterrupt:
            LOGGER.info("KeyboardInterrupt! terminating all processes....")
            message_queue.put(1)
            try:
                [worker.join() for worker in workers if worker]
            except KeyboardInterrupt:
                LOGGER.info("KeyboardInterrupt! seems like a deadlock...")
            canceled = True

        LOGGER.info("Finished processing. Merging results....")

        while not result_queue.empty():
            try:
                obj = result_queue.get()
            except EOFError:
                LOGGER.info("Weird EOFError...")
                break
            if obj[0] >= 0:
                workers_result[obj[0]] = obj[1]

        for worker_id, worker_result in enumerate(workers_result):
            if worker_result == "":
                LOGGER.info(f"Worker {worker_id} failed...")

        workers_result = [
            worker_result for worker_result in workers_result
            if worker_result != ''
        ]

        with target_dataset.open_dataset(cacher,
                                         target_dataset.dataset_name,
                                         drop_old=False) as ds:
            ds.global_submit(workers_result)

        if canceled:
            break