Example #1
def test_maxsize_value():
    """Limited Caching."""
    ac = AnyCache(maxsize=None)

    @ac.anycache()
    def myfunc(posarg, kwarg=3):
        # track invocations so cache hits vs. misses can be asserted
        myfunc.callcount += 1
        sleep(3)  # wait for slow windows file system
        return posarg + kwarg

    myfunc.callcount = 0

    # one call to measure the on-disk size of a single cache entry
    eq_(myfunc(4, 2), 6)
    entrysize = ac.size
    n = 3
    expected_calls = n * n

    for limit in (5 * entrysize, 8 * entrysize):
        ac.clear()
        myfunc.callcount = 0
        ac.maxsize = limit

        for a in range(n):
            for b in range(n):
                eq_(myfunc(a, b), a + b)
        # cache size is capped at the configured maximum
        eq_(limit, ac.size)
        eq_(myfunc.callcount, expected_calls)
        # the most recent result must still be cached
        eq_(myfunc(a, b), a + b)
        eq_(limit, ac.size)
        eq_(myfunc.callcount, expected_calls)
Example #2
 def __init__(
     self,
     url=None,
     raw_dir=None,
     save_dir=None,
     force_reload=False,
     verbose=False,
     fraction=1.0,
     mode="train",
     mem_cache=25000,
 ):
     """Initialize the dataset wrapper.

     Forwards the standard DGLDataset arguments to the base class, then sets
     up a disk-backed AnyCache in ``save_dir`` and an Argoverse loader over
     ``raw_dir``.

     Parameters
     ----------
     url : str, optional
         URL to download the raw dataset from.
     raw_dir : str, optional
         Directory holding (or receiving) the raw input data.
     save_dir : str, optional
         Directory used both for DGL's processed output and the AnyCache
         directory.
     force_reload : bool
         Whether to reload the dataset, bypassing any saved state.
     verbose : bool
         Whether to print progress information.
     fraction : float
         Proportion of the dataset to expose; values above 1.0 are clamped
         to 1.0.
     mode : str
         Split selector; ``'test'`` disables the in-memory cache.
     mem_cache : int
         In-memory cache capacity used outside of ``'test'`` mode.
     """
     super(MyDataset, self).__init__(name='dataset_name',
                                     url=url,
                                     raw_dir=raw_dir,
                                     save_dir=save_dir,
                                     force_reload=force_reload,
                                     verbose=verbose)
     # disk cache for processed samples, persisted in save_dir
     self.ac = AnyCache(cachedir=save_dir)
     self.argo_loader = ArgoverseForecastingLoader(raw_dir)
     # clamp fraction to at most 1.0 (no upscaling of the dataset)
     self.fraction = fraction if fraction <= 1.0 else 1.0
     self.mode = mode
     if self.mode == 'test':
         # NOTE(review): test mode turns the in-memory cache off entirely
         self.mem_cache = 0
     else:
         self.mem_cache = mem_cache
Example #3
def test_maxsize_none():
    """Unlimited Caching."""
    ac = AnyCache(maxsize=None)

    @ac.anycache()
    def myfunc(posarg, kwarg=3):
        # every cache miss bumps the counter
        myfunc.callcount += 1
        return posarg + kwarg

    myfunc.callcount = 0

    # one call to measure the on-disk size of a single cache entry
    eq_(myfunc(4, 2), 6)
    entrysize = ac.size
    n = 5
    expected_calls = n * n

    for a in range(n):
        for b in range(n):
            eq_(myfunc(a, b), a + b)
    # an unlimited cache keeps one entry per distinct argument combination
    eq_(expected_calls * entrysize, ac.size)
    eq_(myfunc.callcount, expected_calls)

    ac.clear()
    eq_(ac.size, 0)
Example #4
def test_del():
    """Deleting the cache handle after use must not raise."""
    ac = AnyCache()

    @ac.anycache()
    def myfunc(posarg, kwarg=3):
        return posarg + kwarg

    # populate the cache with a single entry
    eq_(myfunc(4, 5), 9)
    assert ac.size > 0

    # dropping the last reference triggers the destructor cleanly
    del ac
Example #5
def test_corrupt_cache():
    """Corrupted Cache."""
    cachedir = Path(mkdtemp())
    ac = AnyCache(cachedir=cachedir)

    @ac.anycache()
    def myfunc(posarg, kwarg=3):
        myfunc.callcount += 1
        return posarg + kwarg
    myfunc.callcount = 0

    def check(expected_callcount):
        # call twice: the second call must be served from the cache
        eq_(myfunc(4, 5), 9)
        eq_(myfunc.callcount, expected_callcount)
        eq_(myfunc(4, 5), 9)
        eq_(myfunc.callcount, expected_callcount)

    def corrupt(pattern):
        # overwrite the first matching cache artefact with garbage
        filepath = list(cachedir.glob(pattern))[0]
        with open(str(filepath), "w") as corruptfile:
            corruptfile.write("foo")

    check(1)

    # a corrupted cache file is detected and the value recomputed
    corrupt("*.cache")
    check(2)

    # a corrupted dependency file is detected and the value recomputed
    corrupt("*.dep")
    check(3)

    ac.clear()
Example #6
def test_size():
    """Size."""
    ac = AnyCache()

    @ac.anycache()
    def myfunc(posarg, kwarg=3):
        return posarg + kwarg

    def cachefiles():
        # number of *.cache files currently on disk
        return len(tuple(ac.cachedir.glob("*.cache")))

    # empty cache: zero size, zero files
    eq_(ac.size, 0)
    eq_(cachefiles(), 0)
    # first entry
    eq_(myfunc(4, 5), 9)
    eq_(cachefiles(), 1)
    entrysize = ac.size
    # a second distinct call doubles the reported size
    eq_(myfunc(4, 2), 6)
    eq_(ac.size, 2 * entrysize)
    eq_(cachefiles(), 2)
Example #7
def test_maxsize_0():
    """Disable Caching."""
    ac = AnyCache(maxsize=0)

    @ac.anycache()
    def myfunc(posarg, kwarg=3):
        # with caching disabled, every call is a miss
        myfunc.callcount += 1
        return posarg + kwarg

    myfunc.callcount = 0

    # even a repeated call recomputes; the counter rises on every invocation
    for count, (posarg, kwarg) in enumerate([(4, 5), (4, 2), (4, 2)], 1):
        eq_(myfunc(posarg, kwarg), posarg + kwarg)
        eq_(myfunc.callcount, count)
    # nothing is ever stored
    eq_(ac.size, 0)
Example #8
def test_cleanup():
    """Cleanup."""
    ac = AnyCache()
    cachedir = ac.cachedir

    @ac.anycache()
    def myfunc(posarg, kwarg=3):
        # count cache misses
        myfunc.callcount += 1
        return posarg + kwarg
    myfunc.callcount = 0

    # first use: two distinct calls, then one repeated (cached) call
    eq_(myfunc(4, 5), 9)
    eq_(myfunc.callcount, 1)
    eq_(myfunc(4, 2), 6)
    eq_(myfunc.callcount, 2)
    eq_(myfunc(4, 2), 6)
    eq_(myfunc.callcount, 2)
    assert ac.size > 0

    # clear() removes every file from the cache directory
    ac.clear()
    eq_(ac.size, 0)
    eq_(tuple(cachedir.glob("*")), tuple())

    # the cache is fully usable again after clearing
    eq_(myfunc(4, 4), 8)
    eq_(myfunc.callcount, 3)
    assert ac.size > 0

    # clearing repeatedly is harmless
    for _ in range(2):
        ac.clear()
        eq_(ac.size, 0)
Example #9
def test_ident():
    """Identifiers are stable and differ per function and per arguments."""
    ac = AnyCache()

    @ac.anycache()
    def onefunc(posarg, kwarg=3):
        return posarg + kwarg

    @ac.anycache()
    def otherfunc(posarg, kwarg=3):
        return posarg + kwarg

    # (function, positional args) -> expected identifier
    expectations = (
        (onefunc, (3,),
         'e41fb232f3d486a830bb807545ef52be582e907d505e7275a40d040b53bfe6a5'),
        (onefunc, (3, 3),
         '762c17e6404375af3a788bf29a0e92327c8d89383b1dacd1d6e80260b2d48500'),
        (onefunc, (4,),
         'e0e1f954b03011a72ce56802da766b3ab613b3f63b894bb43d353140bef0a5ea'),
        (otherfunc, (4,),
         '731001353ec04a1d6247224866b20b274a77943a7bb6e57bfad5ceb86cc0eebe'),
    )
    for func, args, ident in expectations:
        eq_(ac.get_ident(func, *args), ident)
Example #10
def test_persistent():
    """Persistent Cache over multiple instances."""
    cachedir = Path(mkdtemp())
    ac = AnyCache(cachedir=cachedir)

    @ac.anycache()
    def myfunc(posarg, kwarg=3):
        # count cache misses
        myfunc.callcount += 1
        return posarg + kwarg

    myfunc.callcount = 0

    # first instance: two distinct computations, one cache hit
    eq_(myfunc(4, 5), 9)
    eq_(myfunc.callcount, 1)
    eq_(myfunc(4, 5), 9)
    eq_(myfunc.callcount, 1)
    eq_(myfunc(4, 2), 6)
    eq_(myfunc.callcount, 2)

    del ac
    # the cache files survive the instance
    eq_(len(tuple(Path(cachedir).glob("*.cache"))), 2)

    ac = AnyCache(cachedir=cachedir)

    @ac.anycache()
    def myfunc(posarg, kwarg=3):
        # count cache misses
        myfunc.callcount += 1
        return posarg + kwarg

    myfunc.callcount = 0

    # second instance: every call is served from the persisted cache
    for posarg, kwarg, result in ((4, 5, 9), (4, 5, 9), (4, 2, 6)):
        eq_(myfunc(posarg, kwarg), result)
        eq_(myfunc.callcount, 0)

    ac.clear()
Example #11
def test_is_outdated_and_remove():
    """is_outdated(), remove()."""
    ac = AnyCache()

    @ac.anycache()
    def myfunc(posarg, kwarg=3):
        # count cache misses
        myfunc.callcount += 1
        return posarg + kwarg

    myfunc.callcount = 0

    # an uncached call is outdated; the query itself never invokes the function
    for _ in range(2):
        eq_(ac.is_outdated(myfunc, 3), True)
    eq_(myfunc.callcount, 0)

    eq_(myfunc(3), 6)
    eq_(myfunc.callcount, 1)

    # once cached, the entry is up to date
    for _ in range(2):
        eq_(ac.is_outdated(myfunc, 3), False)

    # remove() invalidates without calling the function
    ac.remove(myfunc, 3)
    eq_(myfunc.callcount, 1)
    eq_(ac.is_outdated(myfunc, 3), True)

    eq_(myfunc(3), 6)
    eq_(myfunc.callcount, 2)
    eq_(ac.is_outdated(myfunc, 3), False)

    # removing twice in a row is harmless
    ac.remove(myfunc, 3)
    ac.remove(myfunc, 3)
    eq_(ac.is_outdated(myfunc, 3), True)
Example #12
class MyDataset(DGLDataset, ABC):
    """ Template for customizing graph datasets in DGL.

    Parameters
    ----------
    url : str
        URL to download the raw dataset
    raw_dir : str
        Specifying the directory that will store the
        downloaded data or the directory that
        already stores the input data.
        Default: ~/.dgl/
    save_dir : str
        Directory to save the processed dataset.
        Default: the value of `raw_dir`
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose : bool
        Whether to print out progress information
    fraction : float
        Proportion of the dataset to expose; values above 1.0 are
        clamped to 1.0.
    mode : str
        Split selector; ``'test'`` disables the in-memory cache.
    mem_cache : int
        Capacity of the in-memory LRU cache (samples), outside 'test' mode.
    """
    def __init__(
        self,
        url=None,
        raw_dir=None,
        save_dir=None,
        force_reload=False,
        verbose=False,
        fraction=1.0,
        mode="train",
        mem_cache=25000,
    ):
        super(MyDataset, self).__init__(name='dataset_name',
                                        url=url,
                                        raw_dir=raw_dir,
                                        save_dir=save_dir,
                                        force_reload=force_reload,
                                        verbose=verbose)
        # disk cache for processed samples, persisted in save_dir
        self.ac = AnyCache(cachedir=save_dir)
        self.argo_loader = ArgoverseForecastingLoader(raw_dir)
        # clamp fraction to at most 1.0 (no upscaling of the dataset)
        self.fraction = fraction if fraction <= 1.0 else 1.0
        self.mode = mode
        if self.mode == 'test':
            # test mode: never serve stale results from memory
            self.mem_cache = 0
        else:
            self.mem_cache = mem_cache

    def download(self):
        # download raw data to local disk
        pass

    def process(self):
        # process raw data to graphs, labels, splitting masks
        pass

    def __getitem__(self, idx):
        """Return ``(graph, my_dict)`` for the sample at index *idx*.

        The per-sample feature extraction is cached on disk via AnyCache
        keyed on ``sample_id``.
        """
        # NOTE(review): this @lru_cache wraps a function that is re-created on
        # every __getitem__ call, so the in-memory cache never survives between
        # calls; only the disk-backed anycache is effective. Consider hoisting
        # the cache to the instance if in-memory caching is intended.
        @lru_cache(maxsize=self.mem_cache)
        @self.ac.anycache()
        def idx_to_graph(sample_id: int):
            my_dict = {}
            # BUGFIX: load the sample for the cache key `sample_id`, not the
            # closed-over `idx` — otherwise the cached value and the cache key
            # can refer to different samples.
            argo_sample = self.argo_loader[sample_id]
            seq_df = argo_sample.seq_df
            if self.mode == 'test':
                # keep the raw dataframe around for test-time evaluation
                my_dict['df'] = seq_df

            track_id_list = argo_sample.track_id_list

            # static metadata for this sequence
            my_dict['city'] = argo_sample.city
            my_dict['filename'] = os.path.splitext(
                os.path.basename(argo_sample.current_seq))[0]
            # sorted unique timestamps and a timestamp -> ordinal index map
            my_dict['timestamp'] = np.unique(
                seq_df["TIMESTAMP"].values).tolist()
            my_dict['time_dict'] = {
                t: i
                for i, t in enumerate(my_dict['timestamp'])
            }
            my_dict['start_time'], my_dict['end_time'] = my_dict['timestamp'][
                0], my_dict['timestamp'][-1]
            # timestamp index 19 separates observed history from the future
            my_dict['split_time'] = my_dict['timestamp'][19]
            my_dict['agent'] = argo_sample.agent_traj
            # local reference frame anchored at the agent's position at t=19
            my_dict['radix'] = {
                'x': my_dict['agent'][19][0],
                'y': my_dict['agent'][19][1],
                'yaw': 0,
            }
            my_dict['av'] = {}
            my_dict['agent_track_id'] = ""
            # find the track id of the AGENT row (first match wins)
            for _row_idx, row in seq_df.iterrows():
                if str(row['OBJECT_TYPE']) == "AGENT":
                    my_dict['agent_track_id'] = row['TRACK_ID']
                    break
            # collect (TIMESTAMP, X, Y) trajectories of the autonomous vehicle
            for track_id in track_id_list:
                track = seq_df[(seq_df["TRACK_ID"] == track_id)
                               & (seq_df["OBJECT_TYPE"] == "AV")][[
                                   "TIMESTAMP", "X", "Y"
                               ]].to_numpy()
                if len(track) > 0:
                    my_dict['av'][track_id] = track

            # collect trajectories of all remaining traffic participants
            my_dict['others'] = {}
            for track_id in track_id_list:
                track = seq_df[(seq_df["TRACK_ID"] == track_id)
                               & (seq_df["OBJECT_TYPE"] == "OTHERS")][[
                                   "TIMESTAMP", "X", "Y"
                               ]].to_numpy()
                if len(track) > 0:
                    my_dict['others'][track_id] = track

            # lane centerlines near the agent, truncated/zero-padded to 10 points
            my_dict['lanes'] = {
                lane_id:
                lane_seq[:10, :] if len(lane_seq) >= 10 else np.concatenate(
                    (lane_seq, np.zeros((10 - len(lane_seq), 2))))
                for lane_id, lane_seq in enumerate(
                    get_lane_center_lines(center_coordinate=my_dict['agent']
                                          [19],
                                          radius=50,
                                          city_name=my_dict['city']))
            }
            return my_dict

        my_dict = idx_to_graph(sample_id=idx)
        return dict_to_graph(my_dict), my_dict

    def __len__(self):
        # number of exposed examples, scaled by the configured fraction
        return int(self.fraction * len(self.argo_loader))

    def save(self):
        # save processed data to directory `self.save_path`
        pass

    def load(self):
        # load processed data from directory `self.save_path`
        pass

    def clear_cache(self):
        # clear the processed data from directory 'self.save_path'
        print('clear the cache')
        self.ac.clear()

    def has_cache(self):
        # check whether there are processed data in `self.save_path`
        pass

    @property
    def cache_size(self):
        # total on-disk size of the AnyCache in bytes
        return self.ac.size