예제 #1
0
    def gather_to_master(self, key):
        """
        Gather the attribute named `key` from every worker onto the master.

        The exchange goes through the Redis client ``self.r``:

        - master (``self.rank < 0``): polls ``<key>_trigger_<rank>`` for each
          rank in ``range(self.world_size)``. When a trigger reads as 1, the
          payload stored under ``<key>_<rank>`` is unpickled and added to the
          result with ``+=``, and the trigger is reset to '0'. Once every
          rank has been collected, the combined list is stored on ``self``
          under `key`.
        - worker: publishes ``getattr(self, key)`` under ``<key>_<rank>``,
          sets its trigger to '1' and blocks until the trigger reads as 0.

        This method assumes that each gathered object can be added to a
        list with ``+=``.

        Parameters
        ----------
        key : str
            Attribute name; also used as the prefix of the Redis keys.
        """

        if self.rank < 0:
            objs = []
            # Only ranks that have not been collected yet are polled. This
            # avoids a Redis round trip per finished rank on every pass and
            # guarantees each rank contributes exactly once per call, even if
            # a worker re-arms its trigger before the others have finished.
            pending = list(range(self.world_size))
            while True:
                time.sleep(0.1)
                still_pending = []
                for rank in pending:
                    trigger_key = key + '_trigger' + "_{}".format(rank)
                    if _int(self.r.get(trigger_key)) == 1:
                        # NOTE: the payload is unpickled, so the Redis
                        # instance must only be writable by trusted peers.
                        obj = cloudpickle.loads(
                            self.r.get(key + "_{}".format(rank)))
                        objs += obj
                        # Acknowledge so the waiting worker can continue.
                        self.r.set(trigger_key, '0')
                    else:
                        still_pending.append(rank)
                pending = still_pending
                if not pending:
                    break
            setattr(self, key, objs)
        else:
            obj = getattr(self, key)
            self.r.set(key + "_{}".format(self.rank), cloudpickle.dumps(obj))
            self.r.set(key + '_trigger' + "_{}".format(self.rank), '1')
            # Wait for the master's acknowledgement (trigger back to 0).
            while True:
                time.sleep(0.1)
                if _int(self.r.get(key + '_trigger' + "_{}".format(self.rank))) == 0:
                    break
예제 #2
0
    def gather_to_master(self, key):
        """
        Gather the attribute named `key` from every sampler onto the master.

        master (``self.rank < 0``): polls ``<key>_trigger_<rank>`` for each
        rank in ``range(self.world_size)``. When a trigger reads as 1, the
        payload stored under ``<key>_<rank>`` is unpickled and added to the
        result with ``+=``, and the trigger is reset to '0'. Once every rank
        has been collected, the combined list is stored on ``self`` under
        `key`.

        sampler: stores ``getattr(self, key)`` under ``<key>_<rank>`` and then
        calls ``self.wait_trigger_completion`` with its trigger key
        (presumably this raises the trigger and waits for the master to fetch
        the value -- confirm against that helper).

        This method assumes that each gathered object can be added to a
        list with ``+=``.

        Parameters
        ----------
        key : str
            Attribute name; also used as the prefix of the Redis keys.
        """

        if self.rank < 0:
            objs = []
            # Only ranks that have not been collected yet are polled. This
            # avoids a Redis round trip per finished rank on every pass and
            # guarantees each rank contributes exactly once per call, even if
            # a sampler re-arms its trigger before the others have finished.
            pending = list(range(self.world_size))
            while True:
                time.sleep(0.1)
                still_pending = []
                for rank in pending:
                    trigger_key = key + '_trigger' + "_{}".format(rank)
                    if _int(self.r.get(trigger_key)) == 1:
                        # NOTE: the payload is unpickled, so the Redis
                        # instance must only be writable by trusted peers.
                        obj = cloudpickle.loads(
                            self.r.get(key + "_{}".format(rank)))
                        objs += obj
                        # Acknowledge so the waiting sampler can continue.
                        self.r.set(trigger_key, '0')
                    else:
                        still_pending.append(rank)
                pending = still_pending
                if not pending:
                    break
            setattr(self, key, objs)
        else:
            obj = getattr(self, key)
            self.r.set(key + "_{}".format(self.rank), cloudpickle.dumps(obj))
            trigger = '{}_trigger_{}'.format(key, self.rank)
            self.wait_trigger_completion(trigger)
예제 #3
0
 def sync(self, keys, target_value):
     """Block until every entry of `keys` reads as `target_value`.

     All keys are fetched from ``self.r`` in a single ``mget`` per attempt;
     between unsuccessful attempts the caller sleeps for 0.1 seconds.
     """
     everyone_ready = False
     while not everyone_ready:
         fetched = self.r.mget(keys)
         converted = [_int(raw) for raw in fetched]
         everyone_ready = all(item == target_value for item in converted)
         if not everyone_ready:
             time.sleep(0.1)
예제 #4
0
def sync(traj, master_rank=0):
    """
    Broadcast the master's traj to all other ranks through Redis.

    The process whose ``traj.rank`` equals `master_rank` pickles its traj
    into the 'Traj' key, sets the 'Traj_trigger_<rank>' flag of every other
    rank to '1' and waits until all flags read as 0 again. Every other
    process waits for its own flag to read as 1, loads the pickled traj,
    copies it into its own `traj` via ``traj.copy`` and resets its flag
    to '0'.

    Parameters
    ----------
    traj : Traj
        Local traj; its ``rank``, ``world_size`` and ``copy`` are used.
    master_rank : int
        Rank whose traj is sent to everyone else.

    Returns
    -------
    traj : Traj
        The object that was passed in.
    """
    my_rank = traj.rank
    client = get_redis()

    if my_rank != master_rank:
        # Receiver: wait until the master has flagged this rank.
        while True:
            time.sleep(0.1)
            flag = client.get("Traj_trigger_{}".format(my_rank))
            if _int(flag) == 1:
                break
        received = cloudpickle.loads(client.get('Traj'))

        traj.copy(received)
        # Tell the master that this rank is done.
        client.set("Traj_trigger_{}".format(my_rank), '0')
        return traj

    # Sender: publish the payload first, then raise the flags.
    payload = cloudpickle.dumps(traj)
    client.set('Traj', payload)
    flags = {}
    for peer in range(traj.world_size):
        flags["Traj_trigger_{}".format(peer)] = '1'
    # The master does not wait on itself.
    flags["Traj_trigger_{}".format(master_rank)] = '0'
    client.mset(flags)

    all_cleared = False
    while not all_cleared:
        time.sleep(0.1)
        states = [_int(raw) for raw in client.mget(flags)]
        all_cleared = all(state == 0 for state in states)

    return traj
예제 #5
0
    def scatter_from_master(self, key):
        """
        Send the master's attribute `key` to every worker through Redis.

        master (``self.rank < 0``): pickles ``getattr(self, key)`` into the
        Redis key `key`, sets ``<key>_trigger_<rank>`` to '1' for every rank
        in ``range(self.world_size)`` and waits until all of them read as 0.

        worker: waits until its own trigger reads as 1, loads the pickled
        value, stores it on ``self`` under `key` and resets its trigger
        to '0'.
        """

        if not self.rank < 0:
            # Worker side.
            while True:
                time.sleep(0.1)
                flag = self.r.get(
                    key + '_trigger' + "_{}".format(self.rank))
                if _int(flag) == 1:
                    break
            received = cloudpickle.loads(self.r.get(key))
            setattr(self, key, received)
            # Acknowledge receipt so the master can stop waiting.
            self.r.set(key + '_trigger' + "_{}".format(self.rank), '0')
            return

        # Master side: publish the payload before raising any trigger.
        payload = cloudpickle.dumps(getattr(self, key))
        self.r.set(key, payload)
        flags = {}
        for worker in range(self.world_size):
            flags[key + '_trigger' + "_{}".format(worker)] = '1'
        self.r.mset(flags)

        all_cleared = False
        while not all_cleared:
            time.sleep(0.1)
            states = [_int(raw) for raw in self.r.mget(flags)]
            all_cleared = all(state == 0 for state in states)