# Example 1
 def __init__(self,
              num_actions,
              buffer_size,
              latent_dim,
              hash_dim,
              gamma=0.99,
              bp=True,
              debug=True):
     """Set up the kNN episodic buffer and prioritized-sweeping bookkeeping."""
     # Action space and discounting.
     self.num_actions = num_actions
     self.gamma = gamma
     self.rmax = 100000
     # Behaviour flags.
     self.bp = bp
     self.debug = debug
     # kNN-indexed storage for latent states and per-(s, a) statistics.
     self.ec_buffer = LRU_KNN_GPU_PS(buffer_size, latent_dim, hash_dim,
                                     'game', num_actions, debug=debug)
     self.logger = logging.getLogger("ecbp")
     # Priority queue of states awaiting a value backup.
     self.pqueue = HashPQueue()
     self.sa_explore = 10
     self.max_iter = 1000
     # Cached results of the most recent kNN query.
     self.dist = None
     self.ind = None
     # Kernel bandwidths used for neighbour weighting.
     self.b = 10
     self.h = 1
# Example 2
    def __init__(self,
                 num_actions,
                 buffer_size,
                 latent_dim,
                 obs_dim,
                 conn,
                 gamma=0.99,
                 knn=4,
                 queue_threshold=5e-5,
                 density=True):
        """Record hyper-parameters and sweeping state for the learner process."""
        super(PSLearningProcess, self).__init__()
        # Environment / model hyper-parameters.
        self.num_actions = num_actions
        self.gamma = gamma
        self.rmax = 100000
        self.sa_explore = 10
        self.knn = knn
        # Buffer geometry; the buffer itself is built lazily in run().
        self.buffer_size = buffer_size
        self.latent_dim = latent_dim
        self.obs_dim = obs_dim
        self.ec_buffer = None
        # IPC channel to the parent process and logging.
        self.conn = conn
        self.logger = logging.getLogger("ec")
        # Sweeping control state.
        self.min_iter = 20
        self.run_sweep = True
        self.num_iters = 0
        self.update_enough = True
        self.iters_per_step = 0
        self.queue_threshold = queue_threshold
        self.pqueue = HashPQueue()
        # Episode bookkeeping.
        self.sequence = []
        self.first_ob_index = -10
        self.use_density = density
 def __init__(self,
              num_actions,
              buffer_size,
              latent_dim,
              hash_dim,
              gamma=0.99,
              bp=True,
              debug=True):
     """Build the shared-memory buffer/queue and launch the sweep worker."""
     # Core hyper-parameters.
     self.num_actions = num_actions
     self.gamma = gamma
     self.rmax = 100000
     self.bp = bp
     self.debug = debug
     # Shared namespace so the backup process sees the same objects.
     manager = Manager()
     shared = manager.Namespace()
     self.manager = manager
     self.ns = shared
     self.logger = logging.getLogger("ecbp")
     self.sa_explore = 10
     self.max_iter = 1000000
     self.run_sweep = True
     self.num_iters = 0
     # Buffer and priority queue live in the managed namespace; keep
     # local handles for convenience.
     shared.ec_buffer = LRU_KNN_GPU_PS(buffer_size, latent_dim, hash_dim,
                                       'game', num_actions, debug=debug)
     self.ec_buffer = shared.ec_buffer
     shared.pqueue = HashPQueue()
     self.pqueue = shared.pqueue
     self.queue_lock = manager.Lock()
     # Start the prioritized-sweeping backup worker.
     self.sweep_process = BackUpProcess(shared, self.queue_lock)
     self.begin_sweep()
# Example 4
 def __init__(self, num_actions, buffer_size, latent_dim, hash_dim, conn, gamma=0.99):
     """Record hyper-parameters; the kNN buffer itself is created in run()."""
     super(KernelBasedPriorSweepProcess, self).__init__()
     # Model hyper-parameters.
     self.num_actions = num_actions
     self.gamma = gamma
     self.rmax = 100000
     self.sa_explore = 10
     # Buffer geometry.
     self.buffer_size = buffer_size
     self.latent_dim = latent_dim
     self.hash_dim = hash_dim
     # Sweep control.
     self.max_iter = 1000000
     self.run_sweep = True
     self.num_iters = 0
     self.pqueue = HashPQueue()
     # Kernel bandwidths for distance-based weighting.
     self.b = 0.0001
     self.h = 0.0001
     # Cached kNN results from the latest peek.
     self.knn_dist = None
     self.knn_ind = None
     # IPC, logging, episode bookkeeping.
     self.conn = conn
     self.logger = logging.getLogger("ecbp")
     self.sequence = []
# Example 5
class PSLearningProcess(Process):
    """Child process that grows an episodic-control transition graph and
    keeps state values consistent via prioritized sweeping.

    The parent communicates over ``conn`` with (msg_id, payload) tuples;
    every request is answered with a (msg_id, result) tuple — see
    ``recv_msg`` for the message protocol.
    """

    def __init__(self,
                 num_actions,
                 buffer_size,
                 latent_dim,
                 obs_dim,
                 conn,
                 gamma=0.99,
                 knn=4,
                 queue_threshold=5e-5,
                 density=True):

        super(PSLearningProcess, self).__init__()
        self.num_actions = num_actions
        self.gamma = gamma
        # Large sentinel priority used to force end-of-episode backups.
        self.rmax = 100000
        self.logger = logging.getLogger("ec")
        self.sa_explore = 10
        # Minimum number of backup iterations per environment step.
        self.min_iter = 20
        self.run_sweep = True
        self.num_iters = 0
        self.conn = conn
        self.buffer_size = buffer_size
        self.latent_dim = latent_dim
        self.obs_dim = obs_dim
        self.knn = knn
        self.update_enough = True
        self.iters_per_step = 0
        # Value changes below this threshold are not worth queueing.
        self.queue_threshold = queue_threshold
        self.pqueue = HashPQueue()
        # Indices of states visited in the current episode.
        self.sequence = []

        self.first_ob_index = -10
        self.use_density = density
        # Created in run() so the buffer lives inside the child process.
        self.ec_buffer = None

    def log(self, *args, logtype='debug', sep=' '):
        """Join *args* with *sep* and emit through the requested log level."""
        getattr(self.logger, logtype)(sep.join(str(a) for a in args))

    def grow_model(self, sa_pair):  # grow model
        """Insert the transition described by *sa_pair* into the graph.

        Returns (index of the successor state, visit count of (s, a)).
        """
        index_t, action_t, reward_t, z_tp1, h_tp1, done_t = sa_pair
        index_tp1, knn_dist, knn_ind = self.ec_buffer.peek(z_tp1)
        if index_tp1 < 0:
            # Unseen successor: allocate a node (may evict an old one).
            index_tp1, override = self.ec_buffer.add_node(
                z_tp1, knn_dist, knn_ind)
            if override:
                # The recycled slot's queued priority is stale now.
                self.pqueue.remove(index_tp1)

        sa_count = self.ec_buffer.add_edge(index_t, index_tp1, action_t,
                                           reward_t, done_t)
        return index_tp1, sa_count

    def save(self, filedir):
        """Flush all pending backups, then pickle the buffer into *filedir*."""
        while len(self.pqueue) > 0:  # empty pqueue
            self.backup()
        # Bug fix: use a context manager so the handle is closed even if
        # pickling raises (the file object previously leaked).
        with open(os.path.join(filedir, "ec_buffer.pkl"), "wb") as ec_buffer_file:
            pkl.dump(self.ec_buffer, ec_buffer_file)

    def load(self, filedir):
        """Restore a pickled buffer from *filedir* and rebuild its index.

        Silently returns if no snapshot exists.
        """
        try:
            # Bug fix: context manager closes the handle (previously leaked).
            with open(os.path.join(filedir, "ec_buffer.pkl"), "rb") as ec_buffer_file:
                self.ec_buffer = pkl.load(ec_buffer_file)
        except FileNotFoundError:
            return
        self.ec_buffer.allocate()
        batch_size = 32
        # Re-index all stored latent states in batches.
        for i in range(int(np.ceil((self.buffer_size + 1) / batch_size))):
            low = i * batch_size
            high = min(self.buffer_size, (i + 1) * batch_size)
            z_to_update = self.ec_buffer.states[low:high]
            self.ec_buffer.update(np.arange(low, high), np.array(z_to_update))

    def observe(self, sa_pair):
        """Integrate one transition: grow the model, update Q(s, a), and
        queue the state for sweeping; replies (2, successor index) on conn."""
        # grow model
        index_tp1, count_t = self.grow_model(sa_pair)
        # update current value
        index_t, action_t, reward_t, z_tp1, h_tp1, done_t = sa_pair
        self.sequence.append(index_t)
        if done_t:
            # delayed update, so that the value can be efficiently propagated
            if np.isnan(self.ec_buffer.external_value[index_t, action_t]):
                self.ec_buffer.external_value[index_t, action_t] = reward_t
                total_count_tp1 = sum(
                    self.ec_buffer.next_id[index_t][action_t].values())
                for s_tp1, count_tp1 in self.ec_buffer.next_id[index_t][
                        action_t].items():
                    trans_p = count_tp1 / total_count_tp1
                    value_tp1 = np.nan_to_num(
                        self.ec_buffer.state_value_u[s_tp1])
                    self.ec_buffer.external_value[
                        index_t, action_t] += self.gamma * trans_p * value_tp1
            else:
                value_tp1 = np.nan_to_num(
                    self.ec_buffer.state_value_u[index_tp1], copy=True)
                self.ec_buffer.external_value[
                    index_t, action_t] += 1 / count_t * (
                        reward_t + self.gamma * value_tp1 -
                        self.ec_buffer.external_value[index_t, action_t])
            self.ec_buffer.state_value_v[index_t] = np.nanmax(
                self.ec_buffer.external_value[index_t, :])
            self.update_sequence()
            self.conn.send((2, index_tp1))
            return

        if np.isnan(
                self.ec_buffer.external_value[index_t, action_t]) and np.isnan(
                    self.ec_buffer.state_value_u[index_tp1]):
            # if next value is nan, we can't infer anything about current q value,
            # so we should return immediately without updating q values
            self.conn.send((2, index_tp1))
            return
        value_tp1 = np.nan_to_num(self.ec_buffer.state_value_u[index_tp1],
                                  copy=True)

        self.log("ps update s,a,count,new_value,old_value", index_t, action_t,
                 count_t, reward_t + self.gamma * value_tp1,
                 self.ec_buffer.external_value[index_t, action_t])
        if np.isnan(self.ec_buffer.external_value[index_t, action_t]):
            self.ec_buffer.external_value[index_t, action_t] = 0
        # Running-average Q update towards the one-step TD target.
        self.ec_buffer.external_value[index_t, action_t] += 1 / count_t * (
            reward_t + self.gamma * value_tp1 -
            self.ec_buffer.external_value[index_t, action_t])
        self.ec_buffer.state_value_v[index_t] = np.nanmax(
            self.ec_buffer.external_value[index_t, :])

        priority = abs(
            self.ec_buffer.state_value_v[index_t] -
            np.nan_to_num(self.ec_buffer.state_value_u[index_t], copy=True))
        if priority > self.queue_threshold:
            self.pqueue.push(priority, index_t)

        assert index_tp1 != -1
        self.conn.send((2, index_tp1))

    def update_sequence(self):
        """Queue the whole episode (in order) with very high priority so the
        terminal signal propagates quickly, episodic-control style."""
        for p, s in enumerate(self.sequence):
            self.pqueue.push(p + self.rmax, s)
            self.ec_buffer.newly_added[s] = False
        self.sequence = []
        self.update_enough = False
        self.iters_per_step = 0

    def backup(self):
        """Pop the highest-priority state and propagate its value change to
        every predecessor (s, a) pair, queueing them in turn."""
        self.num_iters += 1
        if len(self.pqueue) > 0:
            if self.iters_per_step < self.min_iter:
                self.iters_per_step += 1
            priority, index = self.pqueue.pop()
            delta_u = self.ec_buffer.state_value_v[index] - np.nan_to_num(
                self.ec_buffer.state_value_u[index], copy=True)
            self.ec_buffer.state_value_u[index] = self.ec_buffer.state_value_v[
                index]

            self.log("backup node", index, "priority", priority, "new value",
                     self.ec_buffer.state_value_v[index], "delta", delta_u)

            for sa_pair in self.ec_buffer.prev_id[index]:
                state_tm1, action_tm1 = sa_pair
                self.ec_buffer.update_q_value(state_tm1, action_tm1, index,
                                              delta_u)
                self.ec_buffer.state_value_v[state_tm1] = np.nanmax(
                    self.ec_buffer.external_value[state_tm1, :])
                priority_tm1 = abs(
                    self.ec_buffer.state_value_v[state_tm1] - np.nan_to_num(
                        self.ec_buffer.state_value_u[state_tm1], copy=True))
                if priority_tm1 > self.queue_threshold:
                    self.pqueue.push(priority_tm1, state_tm1)
            # A low-priority pop means the backlog has drained enough to
            # accept new messages.
            if priority < self.rmax and not self.update_enough:
                self.update_enough = True
        if len(self.pqueue) == 0 and not self.update_enough:
            self.update_enough = True
        if self.num_iters % 100000 == 0:
            self.log("backup count", self.num_iters)

    def run(self):
        """Process entry point: build the buffer in-process, then alternate
        between backups and serving parent messages until told to stop."""
        buffer = LRU_KNN_GPU_PS_DENSITY if self.use_density else LRU_KNN_GPU_PS
        self.ec_buffer = buffer(self.buffer_size, self.latent_dim, 'game', 0,
                                self.num_actions, self.knn)

        while self.run_sweep:
            self.backup()
            if self.update_enough:
                self.recv_msg()

    def retrieve_q_value(self, obj):
        """Look up Q estimates for latent z; replies (0, results) on conn."""
        z, h, knn = obj

        extrinsic_qs, intrinsic_qs, find, neighbour_ind, neighbour_dist = self.ec_buffer.act_value_ec(
            z, knn)
        self.conn.send((0, (extrinsic_qs, intrinsic_qs, find, neighbour_ind,
                            neighbour_dist)))

    def peek_node(self, obj):
        """Find (or create) the node for latent z; replies (1, index)."""
        z, h = obj
        ind, knn_dist, knn_ind = self.ec_buffer.peek(z)
        if ind == -1:
            ind, _ = self.ec_buffer.add_node(z, knn_dist, knn_ind)
            self.log("add node for first ob ", ind)
            self.first_ob_index = ind
        assert ind != -1
        self.conn.send((1, ind))
        self.log("send finish")

    def recv_msg(self):
        """Drain pending requests from the parent and dispatch them.

        Protocol (msg id -> meaning):
          0 retrieve q values    1 peek or add node    2 observe transition
          3 stop sweeping        4 sample batch        5 update embeddings
          6 fetch states         7 max q of indexes    8 recompute density
          9 reset buffer        10 save to dir        11 load from dir
        """
        while self.conn.poll():
            msg, obj = self.conn.recv()
            if msg == 0:
                self.retrieve_q_value(obj)
            elif msg == 1:
                self.peek_node(obj)
            elif msg == 2:
                self.observe(obj)
            elif msg == 3:
                self.run_sweep = False
                self.conn.send((3, True))
            elif msg == 4:
                sampled = self.ec_buffer.sample(*obj)
                self.conn.send((4, sampled))
            elif msg == 5:
                indexes, z_new = obj
                self.ec_buffer.update(indexes, z_new)
                self.conn.send((5, True))
            elif msg == 6:
                indexes = obj
                self.conn.send((6, self.ec_buffer.states[indexes]))
            elif msg == 7:
                indexes = obj
                self.conn.send(
                    (7, np.nanmax(self.ec_buffer.external_value[indexes, :])))
            elif msg == 8:
                self.ec_buffer.recompute_density()
                self.conn.send((8, 0))
            elif msg == 9:
                # Rebuild a fresh, empty buffer.
                buffer = LRU_KNN_GPU_PS_DENSITY if self.use_density else LRU_KNN_GPU_PS
                self.ec_buffer = buffer(self.buffer_size, self.latent_dim,
                                        'game', 0, self.num_actions, self.knn)
                self.conn.send((9, 0))
            elif msg == 10:  # save
                filedir = obj
                self.save(filedir)
                self.conn.send((10, 0))
            elif msg == 11:  # load
                filedir = obj
                self.load(filedir)
                self.conn.send((11, 0))
            else:
                raise NotImplementedError
# Example 6
class LRU_KNN_KBPS(object):
    """Kernel-based prioritized sweeping over an LRU kNN episodic buffer.

    Transitions observed at a state are shared with its k nearest
    neighbours through distance-based kernel weights (bandwidths ``b``
    and ``h``), and value changes are propagated backwards through the
    model with a priority queue.
    """

    def __init__(self,
                 num_actions,
                 buffer_size,
                 latent_dim,
                 hash_dim,
                 gamma=0.99,
                 bp=True,
                 debug=True):
        # gamma: discount factor; rmax: sentinel "very large" value.
        self.num_actions = num_actions
        self.gamma = gamma
        self.rmax = 100000
        self.bp = bp
        self.debug = debug
        # kNN-indexed storage of latent states and per-(s, a) statistics.
        self.ec_buffer = LRU_KNN_GPU_PS(buffer_size,
                                        latent_dim,
                                        hash_dim,
                                        'game',
                                        num_actions,
                                        debug=debug)
        self.logger = logging.getLogger("ecbp")
        # Priority queue of states awaiting a value backup.
        self.pqueue = HashPQueue()
        self.sa_explore = 10
        self.max_iter = 1000
        # Cached results of the most recent kNN query.
        self.dist = None
        self.ind = None
        # Kernel bandwidths for neighbour weighting.
        self.b = 10
        self.h = 1

    # def act_value(self, keys, action, knn):
    #     return self.ec_buffer.act_value(keys, knn)

    def log(self, *args, logtype='debug', sep=' '):
        """Join *args* with *sep* and emit through the requested log level."""
        getattr(self.logger, logtype)(sep.join(str(a) for a in args))

    def grow_model(self, sa_pair):  # grow model
        """Insert a transition and kernel-smooth its statistics across the
        cached nearest neighbours (``self.ind`` / ``self.dist``).

        Returns (successor index, visit count of (index_t, action_t)).
        """
        index_t, action_t, reward_t, z_tp1, done_t = sa_pair
        index_tp1, _, _ = self.peek(z_tp1)

        if index_tp1 < 0:
            index_tp1, override = self.ec_buffer.add_node(z_tp1)
            self.log("add node", index_tp1, logtype='debug')
            if override:
                # Slot was recycled; drop any stale queued priority.
                self.pqueue.remove(index_tp1)

        # if (index_t, action_t) not in self.ec_buffer.prev_id[index_tp1]:
        self.log("add edge", index_t, action_t, index_tp1, logtype='debug')
        sa_count = self.ec_buffer.add_edge(index_t, index_tp1, action_t,
                                           reward_t, done_t)
        # NOTE(review): exponent is positive here (exp(+d / b)), so weights
        # GROW with distance; kernel weights normally decay (exp(-d / b)).
        # Confirm the intended sign.
        coeff = np.exp(np.array(self.dist).reshape(-1) / self.b)
        self.log("coeff", coeff.shape, coeff)
        self.ec_buffer.pseudo_count[index_t][action_t] = {}
        # NOTE(review): pseudo_reward[index_t, action_t] is accumulated below
        # without being reset first — verify old mass should carry over.
        for i, s in enumerate(self.ind):
            # Share neighbour s's outgoing transitions with index_t.
            for sp in self.ec_buffer.next_id[s][action_t].keys():
                dist = self.ec_buffer.distance(
                    self.ec_buffer.states[sp], self.ec_buffer.states[sp] +
                    self.ec_buffer.states[index_t] - self.ec_buffer.states[s])
                # NOTE(review): positive exponent again — see coeff above.
                reweight = np.exp(np.array(dist).squeeze() / self.h)
                try:
                    self.ec_buffer.pseudo_count[index_t][action_t][sp] += reweight * coeff[i] * \
                                                                          self.ec_buffer.next_id[s][action_t][sp]
                except KeyError:
                    self.ec_buffer.pseudo_count[index_t][action_t][sp] = reweight * coeff[i] * \
                                                                         self.ec_buffer.next_id[s][action_t][sp]

            self.ec_buffer.pseudo_reward[
                index_t,
                action_t] += coeff[i] * self.ec_buffer.reward[s, action_t]
            # Symmetrically lend index_t's transitions to neighbour s.
            for sp in self.ec_buffer.next_id[index_t][action_t].keys():
                dist = self.ec_buffer.distance(
                    self.ec_buffer.states[sp], self.ec_buffer.states[sp] +
                    self.ec_buffer.states[s] - self.ec_buffer.states[index_t])
                reweight = np.exp(np.array(dist).squeeze() / self.h)
                try:
                    self.ec_buffer.pseudo_count[s][action_t][sp] += reweight * coeff[i] * \
                                                                    self.ec_buffer.next_id[index_t][action_t][
                                                                        sp]
                except KeyError:
                    self.ec_buffer.pseudo_count[s][action_t][sp] = reweight * coeff[i] * \
                                                                   self.ec_buffer.next_id[index_t][action_t][
                                                                       sp]
            self.ec_buffer.pseudo_reward[
                s, action_t] += coeff[i] * self.ec_buffer.reward[index_t,
                                                                 action_t]
        if sa_count > self.sa_explore:
            # Visited often enough — drop the exploration bonus.
            self.ec_buffer.internal_value[index_t, action_t] = 0
        return index_tp1, sa_count

    def prioritized_sweeping(self, sa_pair):
        """Grow the model with *sa_pair*, refresh neighbour Q values, then
        run up to ``max_iter`` prioritized backups.

        Returns the index of the successor state.
        """
        # grow model
        index_tp1, count_t = self.grow_model(sa_pair)
        # update current value
        index_t, action_t, reward_t, z_tp1, done_t = sa_pair
        assert index_t in self.ind, "self should be a neighbor of self"
        for index in self.ind:
            self.update_q_value(index, action_t)
            # NOTE(review): V/priority below are computed for index_t inside
            # the loop over neighbours — possibly meant `index`; confirm.
            self.ec_buffer.state_value_v[index_t] = max(
                self.ec_buffer.external_value[index_t, :])
            priority = abs(self.ec_buffer.state_value_v[index_t] -
                           self.ec_buffer.state_value_u[index_t])
            if priority > 1e-7:
                self.pqueue.push(priority, index_t)
        # recursive backup
        # self.log("begin backup")
        num_iters = 0
        while len(self.pqueue) > 0 and num_iters < self.max_iter:
            num_iters += 1
            priority, state = self.pqueue.pop()
            delta_u = self.ec_buffer.state_value_v[
                state] - self.ec_buffer.state_value_u[state]
            self.ec_buffer.state_value_u[state] = self.ec_buffer.state_value_v[
                state]
            # self.log("backup node", state, "priority", priority, "new value", self.ec_buffer.state_value_v[state],
            #          "delta", delta_u)
            for sa_pair in self.ec_buffer.prev_id[state]:
                state_tm1, action_tm1 = sa_pair
                # self.log("update s,a,s',delta", state_tm1, action_tm1, state, delta_u)
                self.update_q_value_backup(state_tm1, action_tm1, state,
                                           delta_u)
                self.ec_buffer.state_value_v[state_tm1] = max(
                    self.ec_buffer.external_value[state_tm1, :])
                priority = abs(self.ec_buffer.state_value_v[state_tm1] -
                               self.ec_buffer.state_value_u[state_tm1])
                if priority > 1e-7:
                    self.pqueue.push(priority, state_tm1)
        # self.log("finish backup")
        return index_tp1

    def peek(self, state):
        """Delegate the kNN lookup of *state* to the buffer."""
        ind = self.ec_buffer.peek(state)
        return ind

    def update_q_value(self, state, action):
        """Recompute Q(state, action) from kernel-smoothed counts/rewards."""
        # NOTE(review): no guard for n_sa == 0 — division below would raise
        # or produce NaN if no pseudo counts exist; confirm callers prevent it.
        n_sa = sum(self.ec_buffer.pseudo_count[state][action].values())
        r_smooth = self.ec_buffer.pseudo_reward[state, action] / n_sa
        # n_sasp = sum([coeff[i] * self.ec_buffer.next_id[s][action].get(state_tp1, 0) for i, s in enumerate(self.ind)])
        self.ec_buffer.external_value[state, action] = r_smooth
        for state_tp1 in self.ec_buffer.pseudo_count[state][action].keys():
            value_tp1 = self.ec_buffer.state_value_u[state_tp1]
            trans_p = self.ec_buffer.pseudo_count[state][action][
                state_tp1] / n_sa
            self.ec_buffer.external_value[
                state, action] += trans_p * self.gamma * value_tp1

    def update_q_value_backup(self, state, action, state_tp1, delta_u):
        """Apply a backed-up value change *delta_u*, scaled by the smoothed
        transition probability P(state_tp1 | state, action)."""
        # NOTE(review): no n_sa == 0 guard here either; see update_q_value.
        n_sa = sum(self.ec_buffer.pseudo_count[state][action].values())
        n_sasp = self.ec_buffer.pseudo_count[state][action].get(state_tp1, 0)
        trans_p = n_sasp / n_sa
        self.ec_buffer.external_value[state,
                                      action] += self.gamma * trans_p * delta_u
# Example 7
class KernelBasedPriorSweepProcess(Process):
    """Worker process running kernel-based prioritized sweeping.

    Serves (msg_id, payload) requests from the parent over ``conn``:
    0 retrieve q values, 1 peek/add node, 2 observe transition, 3 stop.
    """

    def __init__(self, num_actions, buffer_size, latent_dim, hash_dim, conn, gamma=0.99):
        super(KernelBasedPriorSweepProcess, self).__init__()
        self.num_actions = num_actions
        self.gamma = gamma
        self.rmax = 100000
        self.logger = logging.getLogger("ecbp")
        self.sa_explore = 10
        self.max_iter = 1000000
        self.run_sweep = True
        self.num_iters = 0
        self.conn = conn
        self.buffer_size = buffer_size
        self.latent_dim = latent_dim
        self.hash_dim = hash_dim
        self.pqueue = HashPQueue()
        # Kernel bandwidths for distance-based weighting.
        self.b = 0.0001
        self.h = 0.0001
        # kNN results of the most recent peek_node call.
        self.knn_dist = None
        self.knn_ind = None
        self.sequence = []

    def log(self, *args, logtype='debug', sep=' '):
        """Join *args* with *sep* and emit through the requested log level."""
        getattr(self.logger, logtype)(sep.join(str(a) for a in args))

    def grow_model(self, sa_pair):  # grow model
        """Insert a transition and kernel-smooth its statistics across the
        current state's k nearest neighbours.

        Returns (successor index, raw visit count of (index_t, action_t)).
        """
        index_t, action_t, reward_t, z_tp1, done_t = sa_pair
        index_tp1, _, _ = self.peek(z_tp1)

        if index_tp1 < 0:
            index_tp1, override = self.ec_buffer.add_node(z_tp1)
            self.log("add node", index_tp1, logtype='debug')
            if override:
                # Slot was recycled: its queued priority is stale.
                self.pqueue.remove(index_tp1)

        self.log("add edge", index_t, action_t, index_tp1, logtype='debug')
        sa_count = self.ec_buffer.add_edge(index_t, index_tp1, action_t, reward_t, done_t)
        # Kernel weight of each neighbour, decaying with latent distance.
        coeff = np.exp(-np.array(self.knn_dist).reshape(-1) / self.b)
        self.log("coeff", coeff.shape, coeff)
        self.ec_buffer.pseudo_count[index_t][action_t] = {}
        self.ec_buffer.pseudo_reward[index_t, action_t] = 0
        assert index_t in self.knn_ind, "self should be a neighbour of self"
        for i, s in enumerate(self.knn_ind):
            # Borrow neighbour s's outgoing transitions for index_t.
            for sp in self.ec_buffer.next_id[s][action_t].keys():
                dist = self.ec_buffer.distance(self.ec_buffer.states[sp],
                                               self.ec_buffer.states[sp] + self.ec_buffer.states[index_t] -
                                               self.ec_buffer.states[s])
                reweight = np.exp(-np.array(dist).squeeze() / self.h)
                weighted_count = reweight * coeff[i] * self.ec_buffer.next_id[s][action_t][sp]
                try:
                    self.ec_buffer.pseudo_count[index_t][action_t][sp] += weighted_count
                except KeyError:
                    self.ec_buffer.pseudo_count[index_t][action_t][sp] = weighted_count

                self.ec_buffer.pseudo_prev[sp][(index_t, action_t)] = 1
                self.ec_buffer.pseudo_reward[index_t, action_t] += weighted_count * self.ec_buffer.reward[
                    s, action_t]
            if index_t == s:
                # Don't smooth the state onto itself a second time.
                continue
            # Symmetrically lend index_t's transitions to neighbour s.
            for sp in self.ec_buffer.next_id[index_t][action_t].keys():
                dist = self.ec_buffer.distance(self.ec_buffer.states[sp],
                                               self.ec_buffer.states[sp] + self.ec_buffer.states[s] -
                                               self.ec_buffer.states[index_t])
                reweight = np.exp(-np.array(dist).squeeze() / self.h)
                weighted_count = reweight * coeff[i] * self.ec_buffer.next_id[index_t][action_t][sp]
                try:
                    # Bug fix: accumulate the full weighted count. The old
                    # increment used `reweight * coeff[i]` only, dropping the
                    # next_id count factor, inconsistent with its own except
                    # branch and with the mirrored loop above.
                    self.ec_buffer.pseudo_count[s][action_t][sp] += weighted_count
                except KeyError:
                    self.ec_buffer.pseudo_count[s][action_t][sp] = weighted_count
                self.ec_buffer.pseudo_prev[sp][(s, action_t)] = 1
                # Bug fix: weight the reward by the same weighted count used
                # for the transition mass, matching the first loop.
                self.ec_buffer.pseudo_reward[s, action_t] += weighted_count * self.ec_buffer.reward[
                    index_t, action_t]
        if sa_count > self.sa_explore:
            # Visited often enough — drop the exploration bonus.
            self.ec_buffer.internal_value[index_t, action_t] = 0
        return index_tp1, sa_count

    def observe(self, sa_pair):
        """Integrate one transition: grow the model, refresh Q values of all
        neighbours, and queue changed states; replies (2, successor)."""
        # grow model
        index_tp1, count_t = self.grow_model(sa_pair)
        # update current value
        index_t, action_t, reward_t, z_tp1, done_t = sa_pair
        self.sequence.append(index_t)
        self.log("self neighbour", index_t, self.knn_ind)
        assert index_t in self.knn_ind, "self should be a neighbor of self"
        for index in self.knn_ind:
            self.update_q_value(index, action_t)
            self.ec_buffer.state_value_v[index_t] = np.nanmax(self.ec_buffer.external_value[index_t, :])
            priority = abs(
                self.ec_buffer.state_value_v[index_t] - np.nan_to_num(self.ec_buffer.state_value_u[index_t], copy=True))
            if priority > 1e-7:
                self.pqueue.push(priority, index_t)
        if done_t:
            self.update_sequence()
        self.conn.send((2, index_tp1))

    def backup(self):
        """Pop the highest-priority state and propagate its value change to
        all (kernel-smoothed) predecessor (s, a) pairs."""
        self.num_iters += 1
        if len(self.pqueue) > 0:
            priority, index = self.pqueue.pop()
            delta_u = self.ec_buffer.state_value_v[index] - np.nan_to_num(self.ec_buffer.state_value_u[index],
                                                                          copy=True)
            self.ec_buffer.state_value_u[index] = self.ec_buffer.state_value_v[index]
            self.log("backup node", index, "priority", priority, "new value",
                     self.ec_buffer.state_value_v[index],
                     "delta", delta_u)
            for sa_pair in self.ec_buffer.pseudo_prev[index].keys():
                state_tm1, action_tm1 = sa_pair
                self.update_q_value_backup(state_tm1, action_tm1, index, delta_u)
                self.ec_buffer.state_value_v[state_tm1] = np.nanmax(self.ec_buffer.external_value[state_tm1, :])
                priority = abs(
                    self.ec_buffer.state_value_v[state_tm1] - np.nan_to_num(
                        self.ec_buffer.state_value_u[state_tm1], copy=True))
                if priority > 1e-7:
                    self.pqueue.push(priority, state_tm1)
        if self.num_iters % 100000 == 0:
            self.log("backup count", self.num_iters)

    def update_sequence(self):
        """End-of-episode bookkeeping: mark visited states as no longer new."""
        # to make sure that the final signal can be fast propagated through
        # the state, we need a sequence update like episodic control
        for p, s in enumerate(self.sequence):
            # self.pqueue.push(p + self.rmax, s)
            self.ec_buffer.newly_added[s] = False
        self.sequence = []

    def update_q_value(self, state, action):
        """Recompute Q(state, action) from kernel-smoothed counts/rewards.

        No-op when the pseudo-count mass is (near) zero.
        """
        n_sa = sum(self.ec_buffer.pseudo_count[state][action].values())
        if n_sa < 1e-7:
            return
        r_smooth = np.nan_to_num(self.ec_buffer.pseudo_reward[state, action] / n_sa)
        self.ec_buffer.external_value[state, action] = r_smooth
        for state_tp1 in self.ec_buffer.pseudo_count[state][action].keys():
            value_tp1 = np.nan_to_num(self.ec_buffer.state_value_u[state_tp1])
            trans_p = self.ec_buffer.pseudo_count[state][action][state_tp1] / n_sa
            self.ec_buffer.external_value[state, action] += trans_p * self.gamma * value_tp1

    def update_q_value_backup(self, state, action, state_tp1, delta_u):
        """Apply a backed-up value change *delta_u*, scaled by the smoothed
        transition probability P(state_tp1 | state, action)."""
        n_sa = sum(self.ec_buffer.pseudo_count[state][action].values())
        if n_sa < 1e-7:
            return
        n_sasp = self.ec_buffer.pseudo_count[state][action].get(state_tp1, 0)
        trans_p = n_sasp / n_sa
        assert 0 <= trans_p <= 1, "nsa{} nsap{} trans{}".format(n_sa, n_sasp, trans_p)
        if np.isnan(self.ec_buffer.external_value[state, action]):
            self.ec_buffer.external_value[state, action] = 0
        self.ec_buffer.external_value[state, action] += self.gamma * trans_p * delta_u

    def peek(self, state):
        """Delegate the kNN lookup of *state* to the buffer."""
        ind = self.ec_buffer.peek(state)
        return ind

    def run(self):
        """Process entry point: build the buffer in-process, then alternate
        between backups and serving parent messages until told to stop."""
        self.ec_buffer = LRU_KNN_GPU_PS(self.buffer_size, self.hash_dim, 'game', 0, self.num_actions)
        while self.run_sweep:
            self.backup()
            self.recv_msg()

    def retrieve_q_value(self, obj):
        """Look up Q estimates for latent z; replies (0, results) on conn."""
        z, knn = obj
        extrinsic_qs, intrinsic_qs, find = self.ec_buffer.act_value_ec(z, knn)
        self.conn.send((0, (extrinsic_qs, intrinsic_qs, find)))

    def peek_node(self, obj):
        """Find (or create) the node for latent z, caching its kNN results
        for the next grow_model call; replies (1, index)."""
        z = obj
        ind, knn_dist, knn_ind = self.ec_buffer.peek(z)
        knn_dist = np.array(knn_dist).reshape(-1).tolist()
        knn_ind = np.array(knn_ind).reshape(-1).tolist()
        if ind == -1:
            ind, _ = self.ec_buffer.add_node(z)
            # A brand-new node is its own nearest neighbour at distance 0.
            knn_dist = [0] + knn_dist
            knn_ind = [ind] + knn_ind
            self.log("add node for first ob ", ind)
        self.knn_dist = knn_dist
        self.knn_ind = knn_ind
        self.conn.send((1, ind))

    def recv_msg(self):
        """Drain pending parent requests and dispatch them.

        Protocol: 0 retrieve q values, 1 peek or add node, 2 observe
        transition, 3 stop sweeping.
        """
        while self.conn.poll():
            msg, obj = self.conn.recv()
            if msg == 0:
                self.retrieve_q_value(obj)
            elif msg == 1:
                self.peek_node(obj)
            elif msg == 2:
                self.observe(obj)
            elif msg == 3:
                self.run_sweep = False
                self.conn.send((3, True))
            else:
                raise NotImplementedError