示例#1
0
def test_global_variable_saver_from_arrays(variable, name, shape):
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())

        saver = GlobalVariableSaver("name")
        saver.from_arrays(session, {name: np.ones(shape)})
        arrays = saver.to_arrays(session)
        assert_arrays_ones_shape(arrays, shape, name)
示例#2
0
def test_global_variable_saver_to_arrays(variable, name, shape):
    with tf.compat.v1.Session() as session:
        session.run(tf.compat.v1.global_variables_initializer())
        session.run(variable.assign(tf.ones(shape)))

        saver = GlobalVariableSaver("name")
        arrays = saver.to_arrays(session)
        assert_arrays_ones_shape(arrays, shape, name)
示例#3
0
def test_global_variable_saver_from_string(variable, name, shape):
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())

        saver = GlobalVariableSaver("name")
        saver.from_string(session,
                          pickle.dumps({name: np.ones(shape)}, protocol=-1))
        arrays = saver.to_arrays(session)
        assert_arrays_ones_shape(arrays, shape, name)
示例#4
0
def test_global_variable_saver_to_string(variable, name, shape):
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        session.run(variable.assign(tf.ones(shape)))

        saver = GlobalVariableSaver("name")
        string = saver.to_string(session)
        arrays = pickle.loads(string)
        assert_arrays_ones_shape(arrays, shape, name)
    def save_policy(self, graph_manager):
        """
        Serialize the policy in graph_manager, set it as the latest policy and publish a new_policy
        event
        """
        if self.saver is None:
            self.saver = GlobalVariableSaver()

            # TODO: only subscribe if this data store is being used to publish policies
            self._connect()
            self.pubsub.unsubscribe(self.params.redis_channel)

        policy_string = self.saver.to_string(graph_manager.sess)
        self.redis_connection.set(self.params.redis_channel, policy_string)
        self.redis_connection.publish(self.params.redis_channel, "new_policy")
    def load_policy(self, graph_manager, require_new_policy=True, timeout=0):
        """
        :param graph_manager: the graph_manager to load the policy into
        :param require_new_policy: if True, only load a policy if it hasn't been loaded in this
        process yet before.
        :param timeout: Will only try to load the policy once if timeout is None, otherwise will
        retry for timeout seconds
        """
        if self.saver is None:
            # the GlobalVariableSaver needs to be instantiated after the graph is created. For now,
            # it can be instantiated here, but it might be nicer to have a more explicit
            # on_graph_creation_end callback or similar to put it in
            self.saver = GlobalVariableSaver()
            self._connect()

        if not require_new_policy:
            # try just loading whatever policy is available most recently
            if self._load_policy(graph_manager):
                return

        message = "first"
        timeout_ends = time.time() + timeout
        while time.time() < timeout_ends or message == "first":
            message = self.pubsub.get_message()

            if message and message["type"] == "message":
                if message["data"] == b"end_of_policies":
                    self._end_of_policies = True
                    return
                elif message["data"] == b"new_policy":
                    if self._load_policy(graph_manager):
                        return
                    else:
                        raise ValueError(
                            "'new_policy' message was sent, but no policy was found."
                        )

            time.sleep(1.0)

        if require_new_policy:
            raise ValueError(
                "Waited for {timeout} seconds on channel {channel}, but no first policy was received."
                .format(timeout=timeout, channel=self.params.redis_channel))
示例#7
0
 def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
     """
     Collection of all checkpoints for the network (typically only one checkpoint)
     :param parent_path_suffix: path suffix of the parent of the network
         (e.g. could be name of level manager plus name of agent)
     :return: checkpoint collection for the network
     """
     savers = SaverCollection()
     if not self.distributed_training:
         savers.add(GlobalVariableSaver(self.name))
     return savers
class RedisDataStore(DataStore):
    """
    This DataStore sends policies over redis pubsub and get/set.

    Deployment
    ==========
    It assumes that a redis server is already available. We make this assumption because during
    multinode training at this time, redis is already used for communicating replay memories.

    Communication
    =============

    A redis pubsub channel is used by the training worker to signal to the rollout workers that a
    new policy is ready. When this occurs, a new policy is loaded from the redis key/value store
    where key is the same as the pubsub channel. Originally, just the pubsub was used, but that
    could result in a race condition where the master worker publishes the first policy and waits
    for the rollout workers to submit all rollouts, while a delayed rollout worker waits for the
    first policy since it subscribed to the channel after the initial policy was published.
    """
    def __init__(self, params: RedisDataStoreParameters):
        self.params = params
        self.saver = None
        self._end_of_policies = False

        # NOTE: a connection is not attempted at this stage because the address and port are likely
        # not available yet. This is because of how the kubernetes orchestrator works. At the time
        # of parameter construction, the address and port are not yet known since they are copied
        # out of the redis memory backend after it is deployed. One improvement would be to use
        # two separate redis deployments independently, and let this class deploy its own redis.

    def _connect(self):
        """
        Connect to redis and subscribe to the pubsub channel
        """
        self.redis_connection = redis.Redis(self.params.redis_address,
                                            self.params.redis_port)
        self.pubsub = self.redis_connection.pubsub(
            ignore_subscribe_messages=True)
        self.pubsub.subscribe(self.params.redis_channel)

        self._end_of_policies = False

    def deploy(self):
        """
        For now, this data store does not handle its own deployment, it piggybacks off of the redis
        memory backend
        """
        return True

    def undeploy(self):
        """
        For now, this data store does not handle its own deployment, it piggybacks off of the redis
        memory backend
        """
        pass

    def save_to_store(self):
        """
        save_to_store and load_from_store are not used in the case where the data stored needs to
        synchronize checkpoints saved to disk into a central file system, and not used here
        """
        pass

    def load_from_store(self):
        """
        save_to_store and load_from_store are not used in the case where the data stored needs to
        synchronize checkpoints saved to disk into a central file system, and not used here
        """
        pass

    def save_policy(self, graph_manager):
        """
        Serialize the policy in graph_manager, set it as the latest policy and publish a new_policy
        event
        """
        if self.saver is None:
            self.saver = GlobalVariableSaver()

            # TODO: only subscribe if this data store is being used to publish policies
            self._connect()
            self.pubsub.unsubscribe(self.params.redis_channel)

        policy_string = self.saver.to_string(graph_manager.sess)
        self.redis_connection.set(self.params.redis_channel, policy_string)
        self.redis_connection.publish(self.params.redis_channel, "new_policy")

    def _load_policy(self, graph_manager) -> bool:
        """
        Get the most recent policy from redis and loaded into the graph_manager
        """
        policy_string = self.redis_connection.get(self.params.redis_channel)
        if policy_string is None:
            return False

        self.saver.from_string(graph_manager.sess, policy_string)
        return True

    def load_policy(self, graph_manager, require_new_policy=True, timeout=0):
        """
        :param graph_manager: the graph_manager to load the policy into
        :param require_new_policy: if True, only load a policy if it hasn't been loaded in this
        process yet before.
        :param timeout: Will only try to load the policy once if timeout is None, otherwise will
        retry for timeout seconds
        """
        if self.saver is None:
            # the GlobalVariableSaver needs to be instantiated after the graph is created. For now,
            # it can be instantiated here, but it might be nicer to have a more explicit
            # on_graph_creation_end callback or similar to put it in
            self.saver = GlobalVariableSaver()
            self._connect()

        if not require_new_policy:
            # try just loading whatever policy is available most recently
            if self._load_policy(graph_manager):
                return

        message = "first"
        timeout_ends = time.time() + timeout
        while time.time() < timeout_ends or message == "first":
            message = self.pubsub.get_message()

            if message and message["type"] == "message":
                if message["data"] == b"end_of_policies":
                    self._end_of_policies = True
                    return
                elif message["data"] == b"new_policy":
                    if self._load_policy(graph_manager):
                        return
                    else:
                        raise ValueError(
                            "'new_policy' message was sent, but no policy was found."
                        )

            time.sleep(1.0)

        if require_new_policy:
            raise ValueError(
                "Waited for {timeout} seconds on channel {channel}, but no first policy was received."
                .format(timeout=timeout, channel=self.params.redis_channel))

    def end_of_policies(self) -> bool:
        """
        This is used by the rollout workers to detect a message from the training worker signaling
        that training is complete.
        """
        return self._end_of_policies