def test_global_variable_saver_from_arrays(variable, name, shape):
    """
    Arrays loaded with ``from_arrays`` must be read back unchanged by ``to_arrays``.

    Uses the ``tf.compat.v1`` session API (as ``test_global_variable_saver_to_arrays`` does)
    so the test also runs on TensorFlow 2, where ``tf.Session`` no longer exists.

    ``variable`` is not referenced directly; presumably its fixture creates the graph
    variable that the saver picks up — confirm against the fixture definition.
    """
    with tf.compat.v1.Session() as session:
        session.run(tf.compat.v1.global_variables_initializer())

        # NOTE(review): the saver is built with the literal string "name", not the ``name``
        # fixture value — presumably intentional (saver name vs. variable name); confirm.
        saver = GlobalVariableSaver("name")
        saver.from_arrays(session, {name: np.ones(shape)})

        arrays = saver.to_arrays(session)
        assert_arrays_ones_shape(arrays, shape, name)
def test_global_variable_saver_to_arrays(variable, name, shape):
    """
    Values assigned inside the session must be exported by ``to_arrays`` as all-ones arrays.
    """
    v1 = tf.compat.v1
    with v1.Session() as sess:
        sess.run(v1.global_variables_initializer())
        # Overwrite the initialized value with ones so there is a known value to read back.
        sess.run(variable.assign(tf.ones(shape)))

        exported = GlobalVariableSaver("name").to_arrays(sess)
        assert_arrays_ones_shape(exported, shape, name)
def test_global_variable_saver_from_string(variable, name, shape):
    """
    A pickled ``{name: ones}`` dict loaded with ``from_string`` must be read back by ``to_arrays``.

    Uses the ``tf.compat.v1`` session API (as ``test_global_variable_saver_to_arrays`` does)
    so the test also runs on TensorFlow 2, where ``tf.Session`` no longer exists.
    """
    with tf.compat.v1.Session() as session:
        session.run(tf.compat.v1.global_variables_initializer())

        saver = GlobalVariableSaver("name")
        # protocol=-1 selects the highest pickle protocol available.
        payload = pickle.dumps({name: np.ones(shape)}, protocol=-1)
        saver.from_string(session, payload)

        arrays = saver.to_arrays(session)
        assert_arrays_ones_shape(arrays, shape, name)
def test_global_variable_saver_to_string(variable, name, shape):
    """
    ``to_string`` output must unpickle into all-ones arrays after the variable is set to ones.

    Uses the ``tf.compat.v1`` session API (as ``test_global_variable_saver_to_arrays`` does)
    so the test also runs on TensorFlow 2, where ``tf.Session`` no longer exists.
    """
    with tf.compat.v1.Session() as session:
        session.run(tf.compat.v1.global_variables_initializer())
        session.run(variable.assign(tf.ones(shape)))

        saver = GlobalVariableSaver("name")
        string = saver.to_string(session)

        arrays = pickle.loads(string)
        assert_arrays_ones_shape(arrays, shape, name)
def save_policy(self, graph_manager):
    """
    Publish the policy held by ``graph_manager`` through redis.

    The serialized policy is written to the redis key named after the channel, so workers
    that subscribe late can still fetch it. A "new_policy" notification is then published
    on that same channel.
    """
    if self.saver is None:
        self.saver = GlobalVariableSaver()

    # TODO: only subscribe if this data store is being used to publish policies
    self._connect()
    channel = self.params.redis_channel
    # The publisher never reads notifications, so drop the subscription _connect just made.
    self.pubsub.unsubscribe(channel)

    serialized = self.saver.to_string(graph_manager.sess)
    connection = self.redis_connection
    connection.set(channel, serialized)
    connection.publish(channel, "new_policy")
def load_policy(self, graph_manager, require_new_policy=True, timeout=0):
    """
    Load a policy published through redis into ``graph_manager``.

    :param graph_manager: the graph_manager to load the policy into
    :param require_new_policy: if False, first try to load the most recently stored policy and
        return if one exists. Otherwise (or if none is stored), wait for a "new_policy"
        notification on the pubsub channel.
    :param timeout: seconds to keep polling the channel. 0 or None polls exactly once.
    :raises ValueError: if a "new_policy" notification arrives but no policy is stored, or if
        ``require_new_policy`` is True and no notification arrives before the timeout.
    """
    if self.saver is None:
        # the GlobalVariableSaver needs to be instantiated after the graph is created. For now,
        # it can be instantiated here, but it might be nicer to have a more explicit
        # on_graph_creation_end callback or similar to put it in
        self.saver = GlobalVariableSaver()

    self._connect()

    if not require_new_policy:
        # try just loading whatever policy is available most recently
        if self._load_policy(graph_manager):
            return

    # The docstring has always promised that timeout=None means "try once"; treat it like 0
    # instead of crashing on None arithmetic. A monotonic clock is immune to wall-clock jumps.
    wait_seconds = timeout or 0
    deadline = time.monotonic() + wait_seconds

    while True:
        message = self.pubsub.get_message()
        if message and message["type"] == "message":
            if message["data"] == b"end_of_policies":
                self._end_of_policies = True
                return
            if message["data"] == b"new_policy":
                if self._load_policy(graph_manager):
                    return
                raise ValueError(
                    "'new_policy' message was sent, but no policy was found."
                )
        # Check the deadline before sleeping, so a timed-out call does not waste a final second.
        if time.monotonic() >= deadline:
            break
        time.sleep(1.0)

    if require_new_policy:
        raise ValueError(
            "Waited for {timeout} seconds on channel {channel}, but no first policy was received."
            .format(timeout=wait_seconds, channel=self.params.redis_channel))
def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
    """
    Gather the checkpoint savers for this network (typically only one).

    :param parent_path_suffix: path suffix of the parent of the network (e.g. could be name of
        level manager plus name of agent). Not used by this implementation.
    :return: a SaverCollection containing a GlobalVariableSaver named after this network, or an
        empty collection when distributed training is enabled
    """
    collection = SaverCollection()
    # With distributed training, no GlobalVariableSaver is contributed here.
    if self.distributed_training:
        return collection
    collection.add(GlobalVariableSaver(self.name))
    return collection
class RedisDataStore(DataStore):
    """
    This DataStore sends policies over redis pubsub and get/set.

    Deployment
    ==========
    It assumes that a redis server is already available. We make this assumption because during
    multinode training at this time, redis is already used for communicating replay memories.

    Communication
    =============

    A redis pubsub channel is used by the training worker to signal to the rollout workers that a
    new policy is ready. When this occurs, a new policy is loaded from the redis key/value store,
    where the key is the same as the pubsub channel. Originally, just the pubsub was used, but that
    could result in a race condition: the master worker publishes the first policy and waits for
    the rollout workers to submit all rollouts, while a delayed rollout worker waits forever for
    the first policy because it subscribed to the channel after that policy was published.
    """

    def __init__(self, params: RedisDataStoreParameters):
        self.params = params
        # Created lazily in save_policy/load_policy, after the graph exists.
        self.saver = None
        self._end_of_policies = False

        # NOTE: a connection is not attempted at this stage because the address and port are likely
        # not available yet. This is because of how the kubernetes orchestrator works. At the time
        # of parameter construction, the address and port are not yet known since they are copied
        # out of the redis memory backend after it is deployed. One improvement would be to use
        # two separate redis deployments independently, and let this class deploy its own redis.

    def _connect(self):
        """
        Connect to redis and subscribe to the pubsub channel.

        Every call creates a fresh client and pubsub object and clears the end-of-policies flag.
        """
        self.redis_connection = redis.Redis(self.params.redis_address, self.params.redis_port)
        self.pubsub = self.redis_connection.pubsub(ignore_subscribe_messages=True)
        self.pubsub.subscribe(self.params.redis_channel)
        self._end_of_policies = False

    def deploy(self):
        """
        For now, this data store does not handle its own deployment; it piggybacks off of the
        redis memory backend.
        """
        return True

    def undeploy(self):
        """
        For now, this data store does not handle its own deployment; it piggybacks off of the
        redis memory backend.
        """
        pass

    def save_to_store(self):
        """
        No-op. save_to_store and load_from_store exist for data stores that synchronize
        checkpoints saved to disk into a central file system; this store transfers policies
        through redis instead.
        """
        pass

    def load_from_store(self):
        """
        No-op. save_to_store and load_from_store exist for data stores that synchronize
        checkpoints saved to disk into a central file system; this store transfers policies
        through redis instead.
        """
        pass

    def save_policy(self, graph_manager):
        """
        Serialize the policy in graph_manager, set it as the latest policy, and publish a
        new_policy event.
        """
        if self.saver is None:
            self.saver = GlobalVariableSaver()

        # TODO: only subscribe if this data store is being used to publish policies
        self._connect()
        self.pubsub.unsubscribe(self.params.redis_channel)

        policy_string = self.saver.to_string(graph_manager.sess)
        self.redis_connection.set(self.params.redis_channel, policy_string)
        self.redis_connection.publish(self.params.redis_channel, "new_policy")

    def _load_policy(self, graph_manager) -> bool:
        """
        Get the most recent policy from redis and load it into graph_manager.

        :return: False if no policy has been stored yet, True once a policy was loaded
        """
        policy_string = self.redis_connection.get(self.params.redis_channel)
        if policy_string is None:
            return False

        # NOTE(review): from_string presumably unpickles the payload; only point this store
        # at a trusted redis instance.
        self.saver.from_string(graph_manager.sess, policy_string)
        return True

    def load_policy(self, graph_manager, require_new_policy=True, timeout=0):
        """
        Load a policy published through redis into graph_manager.

        :param graph_manager: the graph_manager to load the policy into
        :param require_new_policy: if False, first try to load the most recently stored policy
            and return if one exists. Otherwise (or if none is stored), wait for a "new_policy"
            notification on the pubsub channel.
        :param timeout: seconds to keep polling the channel. 0 or None polls exactly once.
        :raises ValueError: if a "new_policy" notification arrives but no policy is stored, or
            if require_new_policy is True and no notification arrives before the timeout.
        """
        if self.saver is None:
            # the GlobalVariableSaver needs to be instantiated after the graph is created. For now,
            # it can be instantiated here, but it might be nicer to have a more explicit
            # on_graph_creation_end callback or similar to put it in
            self.saver = GlobalVariableSaver()

        self._connect()

        if not require_new_policy:
            # try just loading whatever policy is available most recently
            if self._load_policy(graph_manager):
                return

        # timeout=None is documented as "try once"; treat it like 0 instead of crashing on
        # None arithmetic. A monotonic clock is immune to wall-clock jumps.
        wait_seconds = timeout or 0
        deadline = time.monotonic() + wait_seconds

        while True:
            message = self.pubsub.get_message()
            if message and message["type"] == "message":
                if message["data"] == b"end_of_policies":
                    self._end_of_policies = True
                    return
                if message["data"] == b"new_policy":
                    if self._load_policy(graph_manager):
                        return
                    raise ValueError(
                        "'new_policy' message was sent, but no policy was found."
                    )
            # Check the deadline before sleeping, so a timed-out call does not waste a second.
            if time.monotonic() >= deadline:
                break
            time.sleep(1.0)

        if require_new_policy:
            raise ValueError(
                "Waited for {timeout} seconds on channel {channel}, but no first policy was received."
                .format(timeout=wait_seconds, channel=self.params.redis_channel))

    def end_of_policies(self) -> bool:
        """
        This is used by the rollout workers to detect a message from the training worker
        signaling that training is complete.
        """
        return self._end_of_policies