def _write_tensor_specs(initial_agent_state: Any,
                        env_output: common.EnvOutput,
                        agent_output: common.AgentOutput,
                        actor_action: common.ActorAction,
                        loss_type: Optional[int] = common.AC_LOSS):
  """Writes tensor specs of ActorOutput tuple to disk.

  Args:
    initial_agent_state: A tensor or nested structure of tensors without any
      time or batch dimensions.
    env_output: An instance of `EnvOutput` where individual tensors don't have
      time and batch dimensions.
    agent_output: An instance of `AgentOutput` where individual tensors don't
      have time and batch dimensions.
    actor_action: An instance of `ActorAction`.
    loss_type: A scalar int denoting the loss type.
  """
  actor_output = common.ActorOutput(
      initial_agent_state,
      env_output,
      agent_output,
      actor_action,
      loss_type,
      info='')
  specs = tf.nest.map_structure(tf.convert_to_tensor, actor_output)
  specs = tf.nest.map_structure(tf.TensorSpec.from_tensor, specs)
  env_output = tf.nest.map_structure(add_time_dimension, specs.env_output)
  agent_output = tf.nest.map_structure(add_time_dimension, specs.agent_output)
  actor_action = tf.nest.map_structure(add_time_dimension, specs.actor_action)
  specs = specs._replace(
      env_output=env_output,
      agent_output=agent_output,
      actor_action=actor_action)
  utils.write_specs(FLAGS.logdir, specs)
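

# A small illustration (not part of the original module) of the spec pipeline
# used above for a single tensor: convert a value to a tensor, capture its
# TensorSpec, and prepend a time dimension of length `unroll_length + 1`. The
# explicit unroll length here is an assumption standing in for whatever
# `add_time_dimension` does in this module; treat this as a sketch only.
def _example_spec_pipeline_sketch(value, unroll_length=1):
  """Illustrative only: value -> TensorSpec -> TensorSpec with a time dim."""
  spec = tf.TensorSpec.from_tensor(tf.convert_to_tensor(value))
  return tf.TensorSpec([unroll_length + 1] + spec.shape.as_list(), spec.dtype)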


def __init__(self, unroll_length=1):
  self._env = MockEnv(state_space_size=4, unroll_length=unroll_length)
  self._agent = MockAgent(unroll_length=unroll_length)
  self._actor_output_spec = common.ActorOutput(
      initial_agent_state=tf.TensorSpec(shape=[5], dtype=tf.float32),
      env_output=self._env.env_spec,
      agent_output=self._agent.agent_spec,
      actor_action=common.ActorAction(
          chosen_action_idx=tf.TensorSpec(
              shape=[unroll_length + 1], dtype=tf.int32),
          oracle_next_action_idx=tf.TensorSpec(
              shape=[unroll_length + 1], dtype=tf.int32)),
      loss_type=tf.TensorSpec(shape=[], dtype=tf.int32),
      info=tf.TensorSpec(shape=[], dtype=tf.string),
  )
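

# A minimal sketch (not in the original test) of how the spec above could be
# used in an assertion. It is written as if it were a helper on the test case
# (hence `self`); the helper name and the way `actor_output` is obtained are
# assumptions. It only illustrates `TensorSpec.is_compatible_with` applied
# structure-wise via `tf.nest.map_structure`.
def _assert_matches_spec_sketch(self, actor_output):
  """Illustrative only: checks each tensor against its declared TensorSpec."""
  tf.nest.map_structure(
      lambda spec, t: self.assertTrue(
          spec.is_compatible_with(
              tf.TensorSpec.from_tensor(tf.convert_to_tensor(t)))),
      self._actor_output_spec, actor_output)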


def run_with_learner(problem_type: framework_problem_type.ProblemType,
                     learner_address: Text, hparams: Dict[Text, Any]):
  """Runs actor with the given learner address and problem type.

  Args:
    problem_type: An instance of `framework_problem_type.ProblemType`.
    learner_address: The network address of a learner exposing two methods:
      `variable_values`: which returns latest value of trainable variables.
      `enqueue`: which accepts nested tensors of type `ActorOutput` tuple.
    hparams: A dict containing hyperparameter settings.
  """
  env = problem_type.get_environment()
  agent = problem_type.get_agent()
  env_output = env.reset()
  initial_agent_state = agent.get_initial_state(
      utils.add_batch_dim(env_output.observation), batch_size=1)
  # Agent always expects time,batch dimensions. First add and then remove.
  env_output = utils.add_time_batch_dim(env_output)
  agent_output, _ = agent(env_output, initial_agent_state)
  env_output, agent_output = utils.remove_time_batch_dim(
      env_output, agent_output)
  actor_action = common.ActorAction(
      chosen_action_idx=tf.zeros([], dtype=tf.int32),
      oracle_next_action_idx=tf.zeros([], dtype=tf.int32))
  # Remove batch_dim from returned agent's initial state.
  initial_agent_state = tf.nest.map_structure(lambda t: tf.squeeze(t, 0),
                                              initial_agent_state)
  # Write TensorSpecs the learner can use for initialization.
  logging.info('My task id is %d', FLAGS.task)
  if FLAGS.task == 0:
    _write_tensor_specs(initial_agent_state, env_output, agent_output,
                        actor_action)
  # gRPC Client creation blocks until the server responds to an RPC. Since the
  # server blocks at startup looking for TensorSpecs, and will not respond to
  # gRPC calls until these TensorSpecs are written, client creation must happen
  # after the actor writes TensorSpecs in order to prevent a deadlock.
  logging.info('Connecting to learner: %s', learner_address)
  client = grpc.Client(learner_address)
  iter_steps = 0
  num_steps = 0
  sum_reward = 0.
  # add batch_dim
  agent_state = tf.nest.map_structure(lambda t: tf.expand_dims(t, 0),
                                      initial_agent_state)
  iterations = 0
  while iter_steps < hparams['max_iter'] or hparams['max_iter'] == -1:
    logging.info('Iteration %d of %d', iter_steps + 1, hparams['max_iter'])
    # Get fresh parameters from the trainer.
    var_dtypes = [v.dtype for v in agent.trainable_variables]
    # trainer also adds `iterations` to the list of variables -- which is a
    # counter tracking number of iterations done so far.
    var_dtypes.append(tf.int64)
    new_values = []
    if iter_steps % hparams['sync_agent_every_n_steps'] == 0:
      new_values = client.variable_values()  # pytype: disable=attribute-error
    if new_values:
      logging.debug('Fetched variables from learner.')
      iterations = new_values[-1].numpy()
      updated_agent_vars = new_values[:-1]
      assert len(updated_agent_vars) == len(agent.trainable_variables)
      for x, y in zip(agent.trainable_variables, updated_agent_vars):
        x.assign(y)
    infos = []
    # Unroll agent.
    # Every episode sent by actor includes previous episode's final agent
    # state and output as well as final environment output.
    initial_agent_state = tf.nest.map_structure(lambda t: tf.squeeze(t, 0),
                                                agent_state)
    env_outputs = [env_output]
    agent_outputs = [agent_output]
    actor_actions = [actor_action]
    loss_type = problem_type.get_episode_loss_type(iterations)
    for i in range(FLAGS.unroll_length):
      logging.debug('Unroll step %d of %d', i + 1, FLAGS.unroll_length)
      # Agent expects time,batch dimensions in `env_output` and batch
      # dimension in `agent_state`. `agent_state` already has batch_dim.
      env_output = utils.add_time_batch_dim(env_output)
      agent_output, agent_state = agent(env_output, agent_state)
      env_output, agent_output = utils.remove_time_batch_dim(
          env_output, agent_output)
      actor_action, action_val = problem_type.select_actor_action(
          env_output, agent_output)
      env_output = env.step(action_val)
      env_outputs.append(env_output)
      agent_outputs.append(agent_output)
      actor_actions.append(actor_action)
      num_steps += 1
      sum_reward += env_output.reward
      if env_output.done:
        infos.append(
            problem_type.get_actor_info(env_output, sum_reward, num_steps))
        num_steps = 0
        sum_reward = 0.
    processed_env_output = problem_type.postprocessing(
        utils.stack_nested_tensors(env_outputs))
    actor_output = common.ActorOutput(
        initial_agent_state=initial_agent_state,
        env_output=processed_env_output,
        agent_output=utils.stack_nested_tensors(agent_outputs),
        actor_action=utils.stack_nested_tensors(actor_actions),
        loss_type=tf.convert_to_tensor(loss_type, tf.int32),
        info=pickle.dumps(infos))
    flattened = tf.nest.flatten(actor_output)
    client.enqueue(flattened)  # pytype: disable=attribute-error
    iter_steps += 1
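

# A minimal usage sketch (not part of the original module): how an actor
# process might invoke `run_with_learner`. The problem-type instance and the
# learner address are placeholders; the hparams keys shown are the ones read
# by `run_with_learner` above.
def _example_run_sketch(problem_type, learner_address='localhost:8686'):
  """Illustrative only: drives the unroll loop against a learner address."""
  hparams = {
      'max_iter': -1,                  # Run until the process is stopped.
      'sync_agent_every_n_steps': 1,   # Fetch fresh weights every iteration.
  }
  run_with_learner(problem_type, learner_address, hparams)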