Example #1
    def __init__(self,
                 batch_size: int,
                 dataset_cls: TorchGraphDataset.__class__,
                 collate_fn: Callable,
                 subset_cls: Subset.__class__ = Subset,
                 with_random=True):
        # HACK: figure out how to properly pass this mutex between processes under PyTorch multiprocessing
        lock = multiprocessing.Lock()
        jit_cacher.init_locks(lock)
        # END HACK

        super(GraphDataLoader_NEW, self).__init__(batch_size)
        self.subset_cls = subset_cls
        self.dataset = dataset_cls()
        with jit_cacher.instance() as cacher:
            self.dataset.connect(cacher,
                                 self.dataset.dataset_name,
                                 drop_old=False,
                                 mode='r')

        assert len(
            self.dataset) > 0, f"empty dataset {self.dataset.dataset_name}"
        n_train = int(len(self.dataset) * 0.8)
        n_valid = len(self.dataset) - n_train

        LOGGER.info(
            f"Initializing dataloader with {self.dataset.__class__.__name__}. Len = {len(self.dataset)}."
        )
        LOGGER.info(f"Num train samples: {n_train}")
        LOGGER.info(f"Num valid samples: {n_valid}")

        if with_random:
            self.train_data, self.val_data = random_split(
                self.dataset, [n_train, n_valid])
            # re-type the Subset objects returned by random_split to the requested subset class
            self.train_data.__class__ = self.subset_cls
            self.val_data.__class__ = self.subset_cls
        else:
            self.train_data = self.subset_cls(self.dataset, range(0, n_train))
            self.val_data = self.subset_cls(self.dataset,
                                            range(n_train, n_train + n_valid))
        self.collate_fn = collate_fn
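
A minimal usage sketch for this constructor, assuming GraphDataLoader_NEW and TorchGraphDataset are importable from the project that defines the snippet, with a placeholder collate function (the names are taken from the code above, not verified against the full module):

def batch_collate(batch):
    # placeholder collate function, for illustration only
    return batch

loader = GraphDataLoader_NEW(batch_size=32,
                             dataset_cls=TorchGraphDataset,
                             collate_fn=batch_collate,
                             with_random=True)
train_subset, val_subset = loader.train_data, loader.val_data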
Example #2
    def __init__(self,
                 parse_cfg: Dict,
                 global_transformer: Transformer,
                 n_stations: int):

        self.parse_cfg = parse_cfg
        self.loaded_model_state = None
        self.model_loader = None

        self.n_stations = n_stations

        self.global_transformer = global_transformer

        global_lock = multiprocessing.Lock()
        jit_cacher.init_locks(global_lock)

        with jit_cacher.instance() as cacher:
            cacher.init()

        self.data_df_transformed = []
        self.data_df_hashes = []
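
The lock and cacher setup in this constructor is the same idiom used in every example on this page; a standalone sketch of just that idiom, assuming jit_cacher is importable from the project (the exact module path is a guess):

import multiprocessing

from ariadne import jit_cacher  # assumed import path; adjust to the real project layout

global_lock = multiprocessing.Lock()   # one lock shared by all worker processes
jit_cacher.init_locks(global_lock)     # register the lock before using the cacher
with jit_cacher.instance() as cacher:  # obtain the process-wide cacher handle
    cacher.init()                      # open or create the on-disk cache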
Example #3
def clean_jit_cache(str_to_clean):
    global cacher, date
    global_lock = multiprocessing.Lock()
    jit_cacher.init_locks(global_lock)
    with jit_cacher.instance() as cacher:
        cacher.init()
    df = cacher.cache

    time_to_clean = str_to_clean[:-1]
    time = int(time_to_clean)
    clear_from = None
    if str_to_clean[-1] == 'w':
        clear_from = datetime.datetime.now() - datetime.timedelta(weeks=time)
    if str_to_clean[-1] == 'm':
        clear_from = datetime.datetime.now() - datetime.timedelta(minutes=time)
    if str_to_clean[-1] == 'd':
        clear_from = datetime.datetime.now() - datetime.timedelta(days=time)
    if str_to_clean[-1] == 'h':
        clear_from = datetime.datetime.now() - datetime.timedelta(hours=time)
    assert clear_from is not None, 'provide the time span in a format like "1w" (one week), ' \
                                   '"2d" (two days), "3h" (three hours) or "4m" (four minutes)'
    print(
        f"Warning: you are about to clean the whole cache for the last {str_to_clean}, "
        f"i.e. from {clear_from} until now"
    )
    all_times = [
        datetime.datetime.fromisoformat(date).timestamp()
        for date in df.date.values
    ]
    removed = []
    for idx, val in enumerate(all_times):
        if val > clear_from.timestamp():
            removed.append(idx)
    if confirm():
        with cacher.handle(cacher.cache_db_path, mode='a') as f:
            for ind in removed:
                row = df.iloc[ind]
                path = row.key
                if path in f:
                    print(f"deleting path {path}")
                    del f[path]
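
A hedged usage sketch for clean_jit_cache; the accepted suffixes mirror the ones parsed above ('w' weeks, 'd' days, 'h' hours, 'm' minutes), and confirm() plus the cacher globals come from the same module as the function:

clean_jit_cache("1w")    # drop every cache entry written during the last week
# clean_jit_cache("2d")  # last two days
# clean_jit_cache("3h")  # last three hours
# clean_jit_cache("4m")  # last four minutes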
Example #4
    def run(self):
        try:
            jit_cacher.init_locks(self.global_lock)

            with jit_cacher.instance() as cacher:
                cacher.init()
            data_df = cacher.read_df(self.df_hash)
            if data_df.empty:
                self.result_queue.put([self.idx, ''])
                return

            old_path = self.main_dataset.dataset_name
            new_path = old_path + f"_{self.basename}_p{self.idx}"
        except KeyboardInterrupt:
            print(
                f"KeyboardInterrupt in process {os.getpid()}. No result will be returned."
            )
            self.result_queue.put([self.idx, ''])
            jit_cacher.fini_locks()
            return

        current_work = []
        try:
            data_df = data_df[(data_df.event >= self.work_slice[0])
                              & (data_df.event < self.work_slice[1])]
            for ev_id, event in data_df.groupby('event'):
                try:
                    #chunk = DFDataChunk.from_df(event)
                    processed = self.target_processor(event)
                    if processed is None:
                        continue
                    idx = f"graph_{self.basename}_{ev_id}"
                    current_work.append((processed, idx))
                    #postprocessed = self.target_postprocessor(processed, process_dataset, idx)
                    if ev_id % 10 == 0:
                        self.result_queue.put([-1])
                except KeyboardInterrupt:
                    raise KeyboardInterrupt
                except:
                    stack = traceback.format_exc()
                    print(
                        f"Exception in process {os.getpid()}! details below: {stack}"
                    )
                    self.result_queue.put([-2, stack])
                    break
                # the parent signals cancellation by posting to message_queue (see preprocess_mp)
                if not self.message_queue.empty():
                    break
        except KeyboardInterrupt:
            print(f"KeyboardInterrupt in process {os.getpid()}.")
        try:
            process_dataset: AriadneDataset = None
            with self.main_dataset.open_dataset(cacher,
                                                new_path) as process_dataset:
                for (processed, idx) in current_work:
                    self.target_postprocessor(processed, process_dataset, idx)
        except KeyboardInterrupt:
            print(f"KeyboardInterrupt in process {os.getpid()}.")
        try:
            print(
                f"Submitting data to the main storage for process {os.getpid()}..."
            )
            if process_dataset is None:
                self.result_queue.put([self.idx, ''])
                return

            process_dataset.dataset_name = old_path
            process_dataset.local_submit()
        except KeyboardInterrupt:
            print(
                f"KeyboardInterrupt while merging data in process {os.getpid()}. No result will be returned."
            )
            new_path = ''

        # finish signal
        self.result_queue.put([self.idx, new_path])

        jit_cacher.fini_locks()
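
A sketch of how a worker exposing this run() method is driven; the constructor argument order is taken from the driver in Example #5 and is an assumption about EventProcessor.__init__, with df_hash, basename, target_processor, target_postprocessor and main_dataset standing in for project objects:

import multiprocessing

result_queue = multiprocessing.Queue()
message_queue = multiprocessing.Queue()
global_lock = multiprocessing.Lock()

worker = EventProcessor(df_hash, basename, target_processor, target_postprocessor,
                        main_dataset, (0, 1e10), 0, result_queue, message_queue,
                        global_lock)
worker.start()
# run() pushes [-1] progress items and ends with a final [idx, new_path] pair,
# so a real driver loops over result_queue.get() exactly as Example #5 does;
# posting anything to message_queue asks the worker to stop early.
worker.join()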
Example #5
def preprocess_mp(transformer: Transformer,
                  target_processor: IPreprocessor,
                  target_postprocessor: IPostprocessor,
                  target_dataset: AriadneDataset,
                  process_num: int = None,
                  chunk_size: int = 1):
    os.makedirs(f"output/{target_dataset.dataset_name}", exist_ok=True)
    setup_logger(f"output/{target_dataset.dataset_name}",
                 target_processor.__class__.__name__)

    global_lock = multiprocessing.Lock()
    jit_cacher.init_locks(global_lock)

    with jit_cacher.instance() as cacher:
        cacher.init()

    with target_dataset.open_dataset(cacher) as ds:
        target_dataset = ds

    target_dataset.meta["cfg"] = gin.config_str()

    # warnings to exceptions:
    pd.set_option('mode.chained_assignment', 'raise')

    LOGGER.info(
        "GOT config: \n======config======\n %s \n========config=======" %
        gin.config_str())
    process_num = multiprocessing.cpu_count() if process_num is None else process_num
    LOGGER.info(
        f"Running with {process_num} processes, chunk_size={chunk_size}"
    )

    for data_df, basename, df_hash in parse():
        LOGGER.info("[Preprocess]: started processing a df with %d rows:" %
                    len(data_df))

        data_df, hash = transformer(DFDataChunk.from_df(data_df, df_hash),
                                    return_hash=True)
        event_count = data_df.event.nunique()
        events = sorted(list(data_df.event.unique()))
        chunk_size = event_count // process_num
        if event_count // process_num == 0:
            process_num = 1
            chunk_size = 1

        result_queue = multiprocessing.Queue()
        message_queue = multiprocessing.Queue()
        workers = []
        workers_result = [""] * process_num
        for i in range(0, process_num):

            if i == process_num - 1:
                work_slice = (events[i * chunk_size], 1e10)
            else:
                work_slice = (events[i * chunk_size],
                              events[(i + 1) * chunk_size])

            workers.append(
                EventProcessor(hash, basename, target_processor,
                               target_postprocessor, target_dataset,
                               work_slice, i, result_queue, message_queue,
                               global_lock))
            workers[-1].start()
        canceled = False
        try:
            with tqdm(total=len(events)) as pbar:
                while any(workers):
                    obj = result_queue.get()
                    if obj[0] == -1:
                        pbar.update(n=10)
                    elif obj[0] == -2:
                        LOGGER.info(f"Process got exception: {obj[1]}.")
                        return
                    else:
                        pbar.update()
                        LOGGER.debug(
                            f"Process idx={obj[0]} has finished processing, joining..."
                        )
                        workers[obj[0]].join()
                        workers[obj[0]].close()
                        workers[obj[0]] = False
                        workers_result[obj[0]] = obj[1]
        except KeyboardInterrupt:
            LOGGER.info("KeyboardInterrupt! terminating all processes....")
            message_queue.put(1)
            try:
                [worker.join() for worker in workers if worker]
            except KeyboardInterrupt:
                LOGGER.info("KeyboardInterrupt! seems like a deadlock...")
            canceled = True

        LOGGER.info("Finished processing. Merging results....")

        while not result_queue.empty():
            try:
                obj = result_queue.get()
            except EOFError:
                LOGGER.info("Weird EOFError...")
                break
            if obj[0] >= 0:
                workers_result[obj[0]] = obj[1]

        for worker_id, worker_result in enumerate(workers_result):
            if worker_result == "":
                LOGGER.info(f"Worker {worker_id} failed...")

        workers_result = [
            worker_result for worker_result in workers_result
            if worker_result != ''
        ]

        with target_dataset.open_dataset(cacher,
                                         target_dataset.dataset_name,
                                         drop_old=False) as ds:
            ds.global_submit(workers_result)

        if canceled:
            break
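
Finally, a hedged sketch of how preprocess_mp is typically invoked; transformer, processor, postprocessor and dataset stand for instances of the project's Transformer, IPreprocessor, IPostprocessor and AriadneDataset types and are placeholders, not verified names:

preprocess_mp(transformer=transformer,
              target_processor=processor,
              target_postprocessor=postprocessor,
              target_dataset=dataset,
              process_num=8,   # defaults to multiprocessing.cpu_count() when None
              chunk_size=1)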