コード例 #1
0
 def test_size(self):
     """A one-pair SkipList reports a positive deep size below 5000 bytes."""
     skiplist = SkipList()
     skiplist.insert('foo', 'bar')
     nbytes = getsize(skiplist)
     self.assertIsInstance(nbytes, int)
     self.assertGreater(nbytes, 0)
     self.assertLess(nbytes, 5000)
コード例 #2
0
ファイル: replay.py プロジェクト: alcinos/dps
    def __init__(self,
                 size,
                 n_partitions,
                 priority_func,
                 alpha,
                 beta_schedule,
                 min_experiences=None,
                 name=None):
        """Store the replay configuration and set up empty storage.

        Experiences are kept in ``self._experiences``; their ranking is kept in
        ``self.skip_list``, which orders keys ascending, so priorities are
        stored negated to get highest-priority-first ordering.
        """
        # Configuration (set before build_distribution, which may read it).
        self.size, self.n_partitions = size, n_partitions
        self.priority_func = priority_func
        self.alpha, self.beta_schedule = alpha, beta_schedule
        self.min_experiences = min_experiences

        # Ring-buffer write position and experience storage.
        self.index = 0
        self._experiences = {}

        # Ascending (MIN) ordering: callers insert negated priorities so the
        # highest priority sits at position 0.
        self.skip_list = SkipList()
        self.distributions = self.build_distribution()

        # Set by get_batch, consumed by update_priority.
        self._active_set = None

        super(PrioritizedReplayBuffer, self).__init__(name)
コード例 #3
0
ファイル: perf_skiplist.py プロジェクト: geertj/pyskiplist
 def _create_skiplist(self, n):
     """Return a SkipList of *n* pairs: random keys in [0, 100*n] -> 0..n-1."""
     result = SkipList()
     upper = 100*n
     for value in range(n):
         result.insert(random.randint(0, upper), value)
     return result
コード例 #4
0
ファイル: test_skiplist.py プロジェクト: geertj/pyskiplist
 def test_size(self):
     """A one-pair SkipList reports a positive deep size below 5000 bytes."""
     skiplist = SkipList()
     skiplist.insert('foo', 'bar')
     nbytes = getsize(skiplist)
     self.assertIsInstance(nbytes, int)
     self.assertGreater(nbytes, 0)
     self.assertLess(nbytes, 5000)
コード例 #5
0
 def _create_skiplist(self, n):
     """Return a SkipList of *n* pairs: random keys in [0, 100 * n] -> 0..n-1."""
     result = SkipList()
     upper = 100 * n
     for value in range(n):
         result.insert(random.randint(0, upper), value)
     return result
コード例 #6
0
 def test_node_size(self):
     """Average deep size per node of a 1000-entry list stays below 250 bytes."""
     count = 1000
     skiplist = SkipList()
     for key in range(count):
         skiplist.insert(key, None)
     total = getsize(skiplist)
     self.assertIsInstance(total, int)
     self.assertGreater(total, 0)
     self.assertLess(total / count, 250)
コード例 #7
0
ファイル: mem_skiplist.py プロジェクト: geertj/pyskiplist
 def mem_node_size(self):
     """Record the average deep size per node for 10**3, 10**4 and 10**5 nodes."""
     for exponent in (3, 4, 5):
         n_items = 10 ** exponent
         skiplist = SkipList()
         for k in range(n_items):
             skiplist.insert(k, k)
         per_node = getsize(skiplist) / n_items
         self.add_result(per_node, suffix=n_items)
コード例 #8
0
ファイル: mem_skiplist.py プロジェクト: geertj/pyskiplist
 def mem_node_overhead(self):
     """Record per-node overhead beyond the int keys and values themselves."""
     for exponent in (3, 4, 5):
         n_items = 10 ** exponent
         skiplist = SkipList()
         for k in range(n_items):
             skiplist.insert(k, k)
         total = getsize(skiplist)
         # Each node carries an int key and an int value; approximate both by
         # the shallow size of the last key inserted.
         payload = n_items * 2 * sys.getsizeof(k)
         self.add_result((total - payload) / n_items, suffix=n_items)
コード例 #9
0
ファイル: test_skiplist.py プロジェクト: geertj/pyskiplist
 def test_node_size(self):
     """Average deep size per node of a 1000-entry list stays below 250 bytes."""
     count = 1000
     skiplist = SkipList()
     for key in range(count):
         skiplist.insert(key, None)
     total = getsize(skiplist)
     self.assertIsInstance(total, int)
     self.assertGreater(total, 0)
     self.assertLess(total/count, 250)
コード例 #10
0
 def mem_node_overhead(self):
     """Record per-node overhead beyond the int keys and values themselves."""
     for exponent in (3, 4, 5):
         n_items = 10**exponent
         skiplist = SkipList()
         for k in range(n_items):
             skiplist.insert(k, k)
         total = getsize(skiplist)
         # Each node carries an int key and an int value; approximate both by
         # the shallow size of the last key inserted.
         payload = n_items * 2 * sys.getsizeof(k)
         self.add_result((total - payload) / n_items, suffix=n_items)
コード例 #11
0
 def mem_node_size(self):
     """Record the average deep size per node for 10**3, 10**4 and 10**5 nodes."""
     for exponent in (3, 4, 5):
         n_items = 10**exponent
         skiplist = SkipList()
         for k in range(n_items):
             skiplist.insert(k, k)
         per_node = getsize(skiplist) / n_items
         self.add_result(per_node, suffix=n_items)
コード例 #12
0
 def test_dump(self):
     """dump() writes a non-trivial text description of a two-pair list."""
     skiplist = SkipList()
     for key, value in (('foo', 'bar'), ('baz', 'qux')):
         skiplist.insert(key, value)
     buf = six.StringIO()
     dump(skiplist, buf)
     text = buf.getvalue()
     self.assertIsInstance(text, str)
     self.assertGreater(len(text), 20)
コード例 #13
0
ファイル: test_skiplist.py プロジェクト: geertj/pyskiplist
 def test_bool(self):
     """A SkipList is falsy while empty and truthy once it holds a pair."""
     skiplist = SkipList()

     def verify(expected):
         assert_truth = self.assertTrue if expected else self.assertFalse
         assert_truth(skiplist)
         assert_truth(bool(skiplist))
         check(skiplist)

     verify(False)
     skiplist.insert('foo', 'bar')
     verify(True)
コード例 #14
0
ファイル: test_skiplist.py プロジェクト: geertj/pyskiplist
 def test_dump(self):
     """dump() writes a non-trivial text description of a two-pair list."""
     skiplist = SkipList()
     for key, value in (('foo', 'bar'), ('baz', 'qux')):
         skiplist.insert(key, value)
     buf = six.StringIO()
     dump(skiplist, buf)
     text = buf.getvalue()
     self.assertIsInstance(text, str)
     self.assertGreater(len(text), 20)
コード例 #15
0
 def test_bool(self):
     """A SkipList is falsy while empty and truthy once it holds a pair."""
     skiplist = SkipList()

     def verify(expected):
         assert_truth = self.assertTrue if expected else self.assertFalse
         assert_truth(skiplist)
         assert_truth(bool(skiplist))
         check(skiplist)

     verify(False)
     skiplist.insert('foo', 'bar')
     verify(True)
コード例 #16
0
ファイル: test_skiplist.py プロジェクト: geertj/pyskiplist
 def test_clear(self):
     """clear() empties a multi-level list and resets its level to 1."""
     n = self.size
     skiplist = SkipList()
     for _ in range(n):
         key = random.randint(0, 2*n)
         skiplist.insert(key, random.randint(0, 10*n))
     self.assertGreater(skiplist.level, 1)
     self.assertEqual(len(skiplist), n)
     skiplist.clear()
     check(skiplist)
     self.assertEqual(list(skiplist), [])
     self.assertEqual(skiplist.level, 1)
コード例 #17
0
ファイル: test_skiplist.py プロジェクト: geertj/pyskiplist
 def test_len(self):
     """len() tracks every insert; iteration matches a stable sort by key."""
     n = self.size
     skiplist = SkipList()
     expected = []
     for count in range(1, n + 1):
         key, value = random.randint(0, 2*n), random.randint(0, 10*n)
         skiplist.insert(key, value)
         # Stable sort: equal keys keep insertion order, like SkipList.insert.
         expected.append((key, value))
         expected.sort(key=lambda kv: kv[0])
         self.assertEqual(len(skiplist), count)
         check(skiplist)
         self.assertEqual(list(skiplist), expected)
コード例 #18
0
ファイル: test_skiplist.py プロジェクト: geertj/pyskiplist
 def test_replace(self):
     """replace() keeps one pair per key holding the latest value written."""
     n = self.size
     skiplist = SkipList()
     latest = {}
     for _ in range(n):
         key = random.randint(0, 2*n)
         value = random.randint(0, 10*n)
         skiplist.replace(key, value)
         latest[key] = value
         check(skiplist)
         # Keys in `latest` are unique, so a plain tuple sort orders by key.
         self.assertEqual(list(skiplist), sorted(latest.items()))
     self.assertGreater(skiplist.level, 1)
コード例 #19
0
 def test_len(self):
     """len() tracks every insert; iteration matches a stable sort by key."""
     n = self.size
     skiplist = SkipList()
     expected = []
     for count in range(1, n + 1):
         key, value = random.randint(0, 2 * n), random.randint(0, 10 * n)
         skiplist.insert(key, value)
         # Stable sort: equal keys keep insertion order, like SkipList.insert.
         expected.append((key, value))
         expected.sort(key=lambda kv: kv[0])
         self.assertEqual(len(skiplist), count)
         check(skiplist)
         self.assertEqual(list(skiplist), expected)
コード例 #20
0
 def test_replace(self):
     """replace() keeps one pair per key holding the latest value written."""
     n = self.size
     skiplist = SkipList()
     latest = {}
     for _ in range(n):
         key = random.randint(0, 2 * n)
         value = random.randint(0, 10 * n)
         skiplist.replace(key, value)
         latest[key] = value
         check(skiplist)
         # Keys in `latest` are unique, so a plain tuple sort orders by key.
         self.assertEqual(list(skiplist), sorted(latest.items()))
     self.assertGreater(skiplist.level, 1)
コード例 #21
0
 def _create_skiplist(self, size, keysize, valuesize):
     """Build a random SkipList together with reference data.

     Returns ``(sl, pairs, values)``: *pairs* are the inserted pairs stably
     sorted by key, and *values* maps each key to its values in insertion order.
     """
     skiplist = SkipList()
     inserted = []
     by_key = {}
     for _ in range(size):
         key = random.randint(0, keysize)
         value = random.randint(0, valuesize)
         skiplist.insert(key, value)
         inserted.append((key, value))
         by_key.setdefault(key, []).append(value)
     inserted.sort(key=lambda kv: kv[0])
     return skiplist, inserted, by_key
コード例 #22
0
ファイル: test_skiplist.py プロジェクト: geertj/pyskiplist
 def _create_skiplist(self, size, keysize, valuesize):
     """Build a random SkipList together with reference data.

     Returns ``(sl, pairs, values)``: *pairs* are the inserted pairs stably
     sorted by key, and *values* maps each key to its values in insertion order.
     """
     skiplist = SkipList()
     inserted = []
     by_key = {}
     for _ in range(size):
         key = random.randint(0, keysize)
         value = random.randint(0, valuesize)
         skiplist.insert(key, value)
         inserted.append((key, value))
         by_key.setdefault(key, []).append(value)
     inserted.sort(key=lambda kv: kv[0])
     return skiplist, inserted, by_key
コード例 #23
0
ファイル: sparsedlist.py プロジェクト: bdragon300/sparsedlist
    def __init__(self, initlist=None, inititems=None, required=False):
        """
        :param initlist: Optional. Initial data. Elements will be placed sequentially, starting at index 0
        :param inititems: Optional. Initial ``(index, value)`` pairs, applied after *initlist*. An index that
            is already set (or repeated) is overwritten rather than duplicated.
        :param required: Optional. If True, getting unset elements causes IndexError. Otherwise, unset elements will
            be substituted by None. Default is False.
        """
        self.data = SkipList()
        self._required = required

        if initlist is not None:
            for i, v in enumerate(initlist):
                self.data.insert(i, v)

        if inititems is not None:
            # replace(), not insert(): each index maps to a single value, so a
            # repeated index must overwrite instead of adding a duplicate key.
            for i, v in inititems:
                self.data.replace(i, v)
コード例 #24
0
ファイル: sparsedlist.py プロジェクト: bdragon300/sparsedlist
    def insert(self, index, value):
        """Insert *value* at *index*, shifting every item at or after it up by one.

        A negative *index* counts from one past the last set index (so -1
        inserts before the last item, as with ``list.insert``) and is clamped
        at 0. On an empty list every negative index becomes 0.
        """
        index = int(index)
        if index < 0:
            last = self.data[-1][0] if self.data else -1
            index = max(last + 1 + index, 0)

        # SkipList keys cannot be changed in place, so rebuild the list with
        # every key >= index shifted up by one.
        new = SkipList()
        for k, v in self.data.items(stop=index):
            new.insert(k, v)
        new.insert(index, value)
        for k, v in self.data.items(start=index):
            new.insert(k + 1, v)

        self.data = new
コード例 #25
0
ファイル: Skiplist.py プロジェクト: Jakobis/OrderedSequences
class Skiplist(Template):
    """Ordered-sequence adapter over a ``SkipList``.

    Every element is stored as both key and value, so the underlying list
    keeps elements in sorted order and may contain duplicates.
    """

    def __init__(self, preload=None):
        """Create the sequence, adding every element of *preload* if given."""
        self.li = SkipList()
        if preload is not None:
            for element in preload:
                self.add(element)

    def add(self, element):
        """Insert *element* into the sorted sequence."""
        self.li.insert(element, element)

    def delete(self, index):
        """Remove the element at position *index*."""
        del self.li[index]

    def remove(self, element):
        """Remove one occurrence of *element* (delegates to ``SkipList.remove``)."""
        self.li.remove(element)

    def rank(self, element):
        """Return the position of *element* (delegates to ``SkipList.index``)."""
        return self.li.index(element)

    def select(self, index):
        """Return the element at position *index*."""
        return self.li[index][0]

    def iter(self):
        """Return an iterator over the elements in ascending order."""
        return self.li.values()

    def reversed(self):
        """Return a reverse iterator over the underlying (key, value) pairs."""
        return reversed(self.li)

    def count(self, value):
        """Return how many times *value* is stored."""
        return self.li.count(value)

    def successor(self, value):
        """Return the first stored element after all copies of *value*, or None.

        *value* must be present; rank() is used to locate it.
        """
        index = self.rank(value) + 1
        size = len(self.li)
        # Skip over duplicates of *value*; stop at the end of the sequence
        # instead of indexing past it.
        while index < size:
            candidate = self.select(index)
            if candidate != value:
                return candidate
            index += 1
        return None

    def predecessor(self, value):
        """Return the element just before *value*'s position, or None if it is first.

        *value* must be present; rank() is used to locate it.
        """
        index = self.rank(value)
        if index == 0:
            # Without this guard, select(-1) would wrap around to the largest element.
            return None
        return self.select(index - 1)

    def size(self):
        """Return the number of stored elements."""
        return len(self.li)
コード例 #26
0
 def test_clear(self):
     """clear() empties a multi-level list and resets its level to 1."""
     n = self.size
     skiplist = SkipList()
     for _ in range(n):
         key = random.randint(0, 2 * n)
         skiplist.insert(key, random.randint(0, 10 * n))
     self.assertGreater(skiplist.level, 1)
     self.assertEqual(len(skiplist), n)
     skiplist.clear()
     check(skiplist)
     self.assertEqual(list(skiplist), [])
     self.assertEqual(skiplist.level, 1)
コード例 #27
0
ファイル: test_skiplist.py プロジェクト: geertj/pyskiplist
 def test_repr(self):
     """repr() renders the stored pairs as a nested tuple literal."""
     skiplist = SkipList()
     for key, value in ((1, 2), (3, 4)):
         skiplist.insert(key, value)
     expected = 'SkipList(((1, 2), (3, 4)))'
     self.assertEqual(repr(skiplist), expected)
     check(skiplist)
コード例 #28
0
 def test_level(self):
     """A freshly created SkipList starts at level 1 and passes check()."""
     fresh = SkipList()
     self.assertEqual(fresh.level, 1)
     check(fresh)
コード例 #29
0
 def test_repr(self):
     """repr() renders the stored pairs as a nested tuple literal."""
     skiplist = SkipList()
     for key, value in ((1, 2), (3, 4)):
         skiplist.insert(key, value)
     expected = 'SkipList(((1, 2), (3, 4)))'
     self.assertEqual(repr(skiplist), expected)
     check(skiplist)
コード例 #30
0
ファイル: sparsedlist.py プロジェクト: bdragon300/sparsedlist
class SparsedList(MutableSequence):
    """List-like container that only stores the indices that have been set.

    Items live in ``self.data``, a SkipList keyed by integer index. Reading an
    index that was never set yields None, or raises IndexError when the list
    was created with ``required=True``. Negative indices and bounds count back
    from the last set index, while ``len()`` is the number of set items.
    """

    def __init__(self, initlist=None, inititems=None, required=False):
        """
        :param initlist: Optional. Initial data. Elements will be placed sequentially, starting at index 0
        :param inititems: Optional. Initial ``(index, value)`` pairs, applied after *initlist*. An index that
            is already set (or repeated) is overwritten rather than duplicated.
        :param required: Optional. If True, getting unset elements causes IndexError. Otherwise, unset elements will
            be substituted by None. Default is False.
        """
        self.data = SkipList()
        self._required = required

        if initlist is not None:
            for i, v in enumerate(initlist):
                self.data.insert(i, v)

        if inititems is not None:
            # replace(), not insert(): each index maps to a single value, so a
            # repeated index must overwrite instead of adding a duplicate key.
            for i, v in inititems:
                self.data.replace(i, v)

    def _clone(self):
        """Return an empty SparsedList with the same ``required`` setting."""
        return self.__class__(required=self._required)

    def _unset(self, index):
        """Value reported for an unset *index*: None, or IndexError if required."""
        if not self._required:
            return None
        else:
            raise IndexError(
                "Item with index '{}' does not exist".format(index))

    def _next_index(self):
        """Return the index right after the last set item, or 0 when empty."""
        try:
            return self.tail() + 1
        except IndexError:
            return 0

    def _normalize_bound(self, bound):
        """Resolve a negative start/stop bound against the last set index.

        None and non-negative bounds are returned unchanged; negative bounds
        are clamped at 0 (an empty list resolves them to 0).
        """
        if bound is None or bound >= 0:
            return bound
        return max(self._next_index() + bound, 0)

    def __repr__(self):
        return 'SparsedList{' + str(dict(self.data.items())) + '}'

    def __eq__(self, other):
        return len(self.data) == len(self.__cast(other)) \
               and all(a[1] == b[1] and a[0] == b[0] for a, b in zip(self.data, self.__cast(other)))

    def __ne__(self, other):
        return not self.__eq__(other)

    @staticmethod
    def __cast(other):
        return other.data if isinstance(other, SparsedList) else other

    def __contains__(self, item):
        return any(x == item for x in self.data.values())

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        def objs(start, stop, step):
            c = start or 0
            step = step or 1

            # An open-ended slice ends after the last set index; without this
            # bound the trailing loop below would yield unset values forever.
            if stop is None:
                stop = self._next_index()

            # E.g. [5:5] or [10:5]
            if c >= stop:
                return

            items = self.data.items(start=start, stop=stop)  # generator
            for i in items:
                while c < i[0]:
                    yield self._unset(c)
                    c += step

                if c == i[0]:
                    yield i[1]
                    c += step

            while c < stop:
                yield self._unset(c)
                c += step

        if isinstance(item, slice):
            start, stop, step = self._slice_indexes(item)

            return objs(start, stop, step)
        else:
            item = int(item)
            if item < 0:
                last_ind = self.data[-1][0]  # IndexError if empty
                item = last_ind + item + 1
                if item < 0:
                    raise IndexError("Negative index overlaps the list start")

            val = self.data.search(item, default=DOES_NOT_EXIST)
            if val is DOES_NOT_EXIST:
                return self._unset(item)

            return val

    def __setitem__(self, key, value):
        if isinstance(key, slice):
            if not isinstance(value, Iterable):
                raise TypeError('Can only assign an iterable')

            start, stop, step = self._slice_indexes(key)
            step = step or 1
            start = start or 0
            c = start
            vi = iter(value)

            while stop is None or c < stop:
                try:
                    self.data.replace(c, next(vi))
                except StopIteration:
                    # The iterable is shorter than the slice: remove the rest
                    # of the slice's on-step items. Collect the keys first so
                    # the SkipList is not mutated while it is being iterated.
                    leftovers = [k for k, _ in self.data.items(start=c, stop=stop)
                                 if (k - start) % step == 0]
                    for k in leftovers:
                        self.data.remove(k)
                    return

                c += step

        else:
            key = int(key)
            if key < 0:
                last_ind = self.data[-1][0]  # IndexError if empty
                key = last_ind + key + 1
                if key < 0:
                    raise IndexError("Negative index overlaps the list start")

            self.data.replace(key, value)

    def __delitem__(self, key):
        if isinstance(key, slice):
            start, stop, step = self._slice_indexes(key)

            step = step or 1
            start = start or 0
            # Only set indices can be removed, so walk the existing keys in
            # range rather than every integer index. A None stop means "to the
            # end"; an explicit stop of 0 stays an empty range.
            doomed = [k for k in self.data.keys(start=start, stop=stop)
                      if (k - start) % step == 0]
            for k in doomed:
                self.data.remove(k)
        else:
            try:
                key = int(key)
                if key < 0:
                    last_ind = self.data[-1][0]  # IndexError if empty
                    key = last_ind + key + 1
                    if key < 0:
                        raise IndexError(
                            "Negative index overlaps the list start")

                self.data.remove(key)
            except KeyError:
                raise IndexError(
                    "Item with index '{}' does not exist".format(key))

    def __iter__(self):
        c = 0
        for k, v in self.data.items():
            while k > c:
                yield self._unset(c)
                c += 1

            yield v
            c += 1

    def __reversed__(self):
        l = len(self.data)
        return (self.data[i][1] for i in range(l - 1, -1, -1))

    def __add__(self, other):
        obj = self._clone()
        offset = self._next_index()

        if isinstance(other, SparsedList):
            other = ((i + offset, v) for i, v in other.data)
        else:
            other = enumerate(other, start=offset)

        for i in itertools.chain(self.data, other):
            obj.data.insert(*i)

        return obj

    def __radd__(self, other):
        obj = self._clone()

        if isinstance(other, SparsedList):
            offset = other._next_index()
            other = other.data
        else:
            offset = len(other)
            other = enumerate(other)

        this = ((i + offset, v) for i, v in self.data)

        for i in itertools.chain(other, this):
            obj.data.insert(*i)

        return obj

    def __iadd__(self, other):
        offset = self._next_index()

        if isinstance(other, SparsedList):
            other = ((i + offset, v) for i, v in other.data)
        else:
            other = enumerate(other, start=offset)

        for i in other:
            self.data.insert(*i)

        return self

    def __mul__(self, n):
        if not isinstance(n, int):
            raise TypeError(
                "can't multiply sequence by non-int of type '{}'".format(
                    type(n)))

        obj = self._clone()
        offset = self._next_index()

        # An empty list has offset 0, which is not a valid range() step.
        if offset:
            for c in range(0, offset * n, offset):
                for i, v in self.data:
                    obj.data.insert(i + c, v)

        return obj

    __rmul__ = __mul__

    def __imul__(self, n):
        if not isinstance(n, int):
            raise TypeError(
                "can't multiply sequence by non-int of type '{}'".format(
                    type(n)))

        if n <= 0:
            # Same result as ``self * n``: an empty list.
            self.data.clear()
            return self

        offset = self._next_index()
        # An empty list has offset 0, which is not a valid range() step.
        if offset:
            # Snapshot the original block: the copies go into the same SkipList.
            block = list(self.data.items(stop=offset))
            for c in range(offset, offset * n, offset):
                for i, v in block:
                    self.data.insert(i + c, v)

        return self

    def __copy__(self):
        return self.copy()

    def insert(self, index, value):
        """Insert *value* at *index*, shifting every item at or after it up by one.

        A negative *index* counts from one past the last set index (so -1
        inserts before the last item, as with ``list.insert``) and is clamped at 0.
        """
        index = int(index)
        if index < 0:
            index = max(self._next_index() + index, 0)

        # SkipList keys cannot be changed in place, so rebuild the list with
        # every key >= index shifted up by one.
        new = SkipList()
        for k, v in self.data.items(stop=index):
            new.insert(k, v)
        new.insert(index, value)
        for k, v in self.data.items(start=index):
            new.insert(k + 1, v)

        self.data = new

    def append(self, value):
        """Append given value in place after the last item (at index 0 when empty)"""
        self.data.insert(self._next_index(), value)

    def extend(self, items):
        """
        Extend (merge) SparsedList with given items. Already existing items will be overwritten
        :param items: key/value pairs iterable
        """
        for i, v in items:
            self.data.replace(i, v)

    def clear(self):
        """Clear all data"""
        self.data.clear()

    def reverse(self):
        l = len(self.data)
        for k1, k2 in zip(range(l // 2), range(l - 1, 0, -1)):
            self.data[k1], self.data[k2] = self.data[k2][1], self.data[k1][1]

    def pop(self, index=-1):
        """Pop the item with given index. Negative indexes counted from position of the last existing item"""
        if index < 0:
            index = max(self.tail() + index + 1, 0)

        try:
            return self.data.pop(index)
        except KeyError:
            raise IndexError('Pop from empty SparsedList')

    def remove(self, value):
        """Remove the first item from the list whose value is equal to x. ValueError is raised if value not found"""
        ind = self.index(value)
        self.data.remove(ind)

    def sort(self, *args, **kwds):
        """Sort the set values in place; the set of occupied indices is unchanged.

        Arguments are forwarded to ``sorted`` (e.g. ``key=``, ``reverse=``)
        instead of being silently ignored.
        """
        for k, v in enumerate(sorted(self.data.values(), *args, **kwds)):
            self.data[k] = v

    def copy(self):
        obj = self._clone()
        for p in self.data.items():
            obj.data.insert(*p)

        return obj

    def index(self, value, start=None, stop=None):
        """
        Return zero-based index in the list of the first item whose value is equal to x.
        Raises a ValueError if there is no such item.
        """
        start = self._normalize_bound(start)
        stop = self._normalize_bound(stop)

        for i, v in self.data.items(start, stop):
            if v == value:
                return i

        raise ValueError("'{}' is not in SparsedList".format(value))

    def count(self, item):
        """
        Return total number of occurrences of given `item` in list
        :param item:
        """
        return sum(1 for x in self.data.values() if x == item)

    def items(self, start=None, stop=None):
        """Return ``(index, value)`` pairs of set items between *start* and *stop*."""
        return self.data.items(start=self._normalize_bound(start),
                               stop=self._normalize_bound(stop))

    def keys(self, start=None, stop=None):  # NOQA
        """
        Return keys of non-empty items
        :param start: first index (negative counts back from the last set index)
        :param stop: index to stop before (negative counts back from the last set index)
        :return: iterator over indices
        """
        return self.data.keys(start=self._normalize_bound(start),
                              stop=self._normalize_bound(stop))

    def values(self, start=None, stop=None):  # NOQA
        """
        Return values of non-empty items
        :param start: first index (negative counts back from the last set index)
        :param stop: index to stop before (negative counts back from the last set index)
        :return: iterator over values
        """
        return self.data.values(start=self._normalize_bound(start),
                                stop=self._normalize_bound(stop))

    def tail(self):
        """
        Return index of the last element
        :raises IndexError: no elements in list
        """
        return self.data[-1][0]

    def _slice_indexes(self, s):
        """
        Calculate positive index bounds from slice. If slice param is None, then it will be left as None
        :param s: slice object
        :return: start, stop, step
        """
        pieces = [s.start, s.stop, s.step]

        for i in [0, 1]:
            if pieces[i] is not None:
                pieces[i] = int(pieces[i])
                if pieces[i] < 0:
                    try:
                        last_ind = self.tail()  # IndexError if empty
                    except IndexError:
                        last_ind = 0
                    pieces[i] = max(last_ind + pieces[i] + 1, 0)

        if pieces[2] is not None:
            if pieces[2] < 0:
                raise ValueError('Negative slice step is not supported')
            elif pieces[2] == 0:
                raise ValueError('Slice step cannot be zero')

        return tuple(pieces)
コード例 #31
0
ファイル: CacheTest.py プロジェクト: Jakobis/OrderedSequences
# Build the container named by argv[1], pre-filled with 0 .. 10**N - 1.
# NOTE(review): N is not defined in this excerpt; presumably set earlier in the script.
if sys.argv[1] == 'Blist':
    l = blist.sortedlist(list(range(10**N)))
elif sys.argv[1] == 'SortedArray':
    l = SCArray._SortedArray('q', list(range(10**N)))
elif sys.argv[1] == 'SortedList':
    l = sortedcontainers.SortedList(list(range(10**N)))
elif sys.argv[1] == 'AutoLoad':
    l = SCAutoBalance.SortedList(list(range(10**N)))
elif sys.argv[1] == 'BadBisect':
    l = SCPyBisect.SortedList(list(range(10**N)))
elif sys.argv[1] == 'RBSTree':
    l = RedBlackBST()
    for i in range(10**N):
        l.put(i, 1)
elif sys.argv[1] == 'SkipList':
    l = SkipList()
    for i in range(10**N):
        l.insert(i, i)
else:
    # Diagnostics go to stderr. sys.exit is used because the bare exit()
    # helper is injected by the site module for interactive use only.
    print("Error, did not supply working thingy", file=sys.stderr)
    sys.exit(-1)

print(f"Testing {sys.argv[1]}")
print(len(l))
if sys.argv[3] == 'Base':
    pass
elif sys.argv[3] == 'Add':
    if sys.argv[1] == 'SkipList':
        for i in range(1, 10**N):  #No duplicates
            l.insert(-i, -i)
    elif sys.argv[1] == 'RBSTree':
コード例 #32
0
ファイル: replay.py プロジェクト: alcinos/dps
class PrioritizedReplayBuffer(RLObject):
    """ Implements rank-based version of Prioritized Experience Replay.

    Parameters
    ----------
    size: int
        Maximum number of experiences to store.
    n_partitions: int
        Number of partitions to use for the sampling approximation.
    priority_func: callable
        Maps from an RLContext object to a signal to use as the priorities for the replay buffer.
    alpha: float > 0
        Degree of prioritization (similar to a softmax temperature); 0 corresponds to no prioritization
        (uniform distribution), inf corresponds to degenerate distribution which always picks element
        with highest priority.
    beta_schedule: 1 > float > 0
        Degree of importance sampling correction, 0 corresponds to no correction, 1 corresponds to
        full correction. Usually anneal linearly from an initial value beta_0 to a value of 1 by the
        end of learning.
    min_experiences: int > 0
        Minimum number of experiences that must be stored in the replay buffer before it will return
        a valid batch when `get_batch` is called. Before this point, it returns None, indicating that
        whatever is making use of this replay memory should not make an update.

    """
    def __init__(self,
                 size,
                 n_partitions,
                 priority_func,
                 alpha,
                 beta_schedule,
                 min_experiences=None,
                 name=None):
        self.size = size
        self.n_partitions = n_partitions
        self.priority_func = priority_func
        self.alpha = alpha
        self.beta_schedule = beta_schedule
        self.min_experiences = min_experiences

        self.index = 0

        self._experiences = {}

        # Current skip-list key (negated priority) of every stored experience,
        # by experience index. Entries are located through this map because
        # skip-list positions shift on every insert/delete.
        self._priority_keys = {}

        # Note this is actually a MIN priority queue, so to make it act like a MAX priority
        # queue, we use the negative of the provided priorities.
        self.skip_list = SkipList()
        # build_distribution() stores its results on self and returns None;
        # the attribute is kept for backward compatibility.
        self.distributions = self.build_distribution()

        self._active_set = None

        super(PrioritizedReplayBuffer, self).__init__(name)

    def build_core_signals(self, context):
        self.beta = context.get_signal("beta", self)
        self.priority_signal = tf.reshape(self.priority_func(context), [-1])

    def generate_signal(self, signal_key, context):
        if signal_key == "beta":
            return build_scheduled_value(self.beta_schedule,
                                         '{}-beta'.format(self.name))
        else:
            raise Exception("NotImplemented")

    def post_update(self, feed_dict, context):
        if self._active_set is not None:
            priority = tf.get_default_session().run(self.priority_signal,
                                                    feed_dict=feed_dict)
            self.update_priority(priority)

    @property
    def n_experiences(self):
        return len(self._experiences)

    def build_distribution(self):
        """Precompute the rank-based sampling distribution and its strata.

        Sets ``self.pdf`` (weight of rank i proportional to ``i ** -alpha``,
        normalised to sum to 1) and ``self.strata_starts`` / ``strata_ends`` /
        ``strata_sizes``, the rank range ``[start, end)`` of each of the
        ``n_partitions`` strata.
        """
        # Float dtype: numpy rejects integer arrays raised to negative integer
        # powers, and the in-place division needs a float array, so an integer
        # alpha (e.g. 0 or 1) would otherwise raise.
        pdf = np.arange(1, self.size + 1, dtype=float)**-self.alpha
        pdf /= pdf.sum()

        cdf = np.cumsum(pdf)

        # Whenever the CDF crosses one of the discretization bucket boundaries,
        # we assign the index where the crossing occurred to the next bucket rather
        # than the current one.
        strata_starts, strata_ends, strata_sizes = [], [], []
        start_idx, end_idx = 0, 0
        for s in range(self.n_partitions):
            strata_starts.append(start_idx)
            if s == self.n_partitions - 1:
                strata_ends.append(len(cdf))
            else:
                while cdf[end_idx] < (s + 1) / self.n_partitions:
                    end_idx += 1
                if start_idx == end_idx:
                    strata_ends.append(end_idx + 1)
                else:
                    strata_ends.append(end_idx)
            strata_sizes.append(strata_ends[-1] - strata_starts[-1])
            start_idx = end_idx

        self.strata_starts = strata_starts
        self.strata_ends = strata_ends
        self.strata_sizes = strata_sizes
        self.pdf = pdf

    def _remove_entry(self, key, e_idx):
        """Delete the skip-list pair ``(key, e_idx)``; return True if it was found.

        Several experiences can share a key, so scan the run of pairs with that
        key for the one whose value is *e_idx*.
        """
        try:
            pos = self.skip_list.index(key)
        except KeyError:
            return False

        n = len(self.skip_list)
        while pos < n:
            k, v = self.skip_list[pos]
            if k != key:
                break
            if v == e_idx:
                del self.skip_list[pos]
                return True
            pos += 1
        return False

    def add_rollouts(self, rollouts):
        """Store each rollout of *rollouts*, overwriting the oldest slot when full."""
        assert isinstance(rollouts, RolloutBatch)
        for r in rollouts.split():
            # If there was already an experience at location `self.index`, it is
            # ejected: drop its ranking entry as well, otherwise the skip list
            # grows past `size` and keeps sampling this slot under a stale key.
            old_key = self._priority_keys.pop(self.index, None)
            if old_key is not None:
                self._remove_entry(old_key, self.index)
            self._experiences[self.index] = r

            # Insert with the minimum key currently present (keys are negated
            # priorities, so this is the highest priority seen so far).
            if self.skip_list:
                priority = self.skip_list[0][0]
            else:
                priority = 0.0
            self.skip_list.insert(priority, self.index)
            self._priority_keys[self.index] = priority

            self.index = (self.index + 1) % self.size

    def update_priority(self, priorities):
        """ update priority after calling `get_batch` """
        if self._active_set is None:
            raise Exception(
                "``update_priority`` should only called after calling ``get_batch``."
            )

        for (p_idx, e_idx,
             old_priority), new_priority in zip(self._active_set, priorities):
            # negate `new_priority` because SkipList puts lowest first.
            new_key = -new_priority
            current_key = self._priority_keys.get(e_idx, old_priority)
            if current_key != new_key:
                # Locate the entry by (key, e_idx), not by `p_idx`: positions
                # recorded in get_batch go stale after the first delete/insert,
                # and a batch may contain the same experience twice.
                self._remove_entry(current_key, e_idx)
                self.skip_list.insert(new_key, e_idx)
                self._priority_keys[e_idx] = new_key

        self._active_set = None

    def get_batch(self, batch_size):
        no_sample = ((self.min_experiences is not None
                      and self.n_experiences < self.min_experiences)
                     or self.n_experiences < batch_size)
        if no_sample:
            return None, None

        priority_indices = []
        permutation = np.random.permutation(self.n_partitions)

        selected_sizes = []
        for i in islice(cycle(permutation), batch_size):
            start, end = self.strata_starts[i], self.strata_ends[i]
            priority_indices.append(np.random.randint(start, end))
            selected_sizes.append(self.strata_sizes[i])

        # We set p_x to be the actual probability that we sampled with,
        # namely 1 / (n_partitions * size_of_partition), rather than
        # the pdf that our sampling method approximates, namely `self.pdf`.
        # This is both more faithful, and works better when the memory is not full.
        p_x = (self.n_partitions * np.array(selected_sizes))**-1.

        beta = tf.get_default_session().run(self.beta)

        weights = (p_x * self.size)**-beta
        weights /= weights.max()

        # When we aren't full, map priority_indices (which are in range(self.size)) down to range(self.n_experiences)
        if self.n_experiences < self.size:
            priority_indices = [
                int(np.floor(self.n_experiences * (i / self.size)))
                for i in priority_indices
            ]

        self._active_set = []
        for p_idx in priority_indices:
            priority, e_idx = self.skip_list[p_idx]
            self._active_set.append((p_idx, e_idx, priority))

        experiences = RolloutBatch.join(
            [self._experiences[e_idx] for _, e_idx, _ in self._active_set])

        return experiences, weights
コード例 #33
0
ファイル: Skiplist.py プロジェクト: Jakobis/OrderedSequences
 def __init__(self, preload=None):
     """Create the sequence, adding every element of *preload* if given.

     Uses a None default instead of a shared mutable list, and calls
     ``self.add(element)`` (passing ``self`` again raised TypeError).
     """
     self.li = SkipList()
     if preload is not None:
         for element in preload:
             self.add(element)
コード例 #34
0
def constructSL(aList):
    """Return a SkipList in which every value of *aList* maps to itself."""
    result = SkipList()
    for item in aList:
        result.insert(item, item)
    return result
コード例 #35
0
 def mem_size(self):
     """Record the deep size of a freshly created, empty SkipList."""
     empty = SkipList()
     nbytes = getsize(empty)
     self.add_result(nbytes)