Example #1
def _write_tensor_specs(initial_agent_state: Any,
                        env_output: common.EnvOutput,
                        agent_output: common.AgentOutput,
                        actor_action: common.ActorAction,
                        loss_type: Optional[int] = common.AC_LOSS):
    """Writes tensor specs of ActorOutput tuple to disk.

    Args:
      initial_agent_state: A tensor or nested structure of tensors without any
        time or batch dimensions.
      env_output: An instance of `EnvOutput` where individual tensors don't
        have time and batch dimensions.
      agent_output: An instance of `AgentOutput` where individual tensors don't
        have time and batch dimensions.
      actor_action: An instance of `ActorAction`.
      loss_type: A scalar int denoting the loss type.
    """
    actor_output = common.ActorOutput(initial_agent_state,
                                      env_output,
                                      agent_output,
                                      actor_action,
                                      loss_type,
                                      info='')
    specs = tf.nest.map_structure(tf.convert_to_tensor, actor_output)
    specs = tf.nest.map_structure(tf.TensorSpec.from_tensor, specs)
    env_output = tf.nest.map_structure(add_time_dimension, specs.env_output)
    agent_output = tf.nest.map_structure(add_time_dimension,
                                         specs.agent_output)
    actor_action = tf.nest.map_structure(add_time_dimension,
                                         specs.actor_action)
    specs = specs._replace(env_output=env_output,
                           agent_output=agent_output,
                           actor_action=actor_action)
    utils.write_specs(FLAGS.logdir, specs)
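
For reference, `add_time_dimension` above is mapped over the individual `tf.TensorSpec`s of the per-step outputs. A minimal sketch of such a helper, assuming it only prepends a time axis of length `FLAGS.unroll_length + 1` (matching the `[unroll_length + 1]` action specs in Example #2 below); the actual helper may differ:

import tensorflow as tf


def add_time_dimension(spec: tf.TensorSpec) -> tf.TensorSpec:
    """Hypothetical sketch: prepends a time axis to a per-step tensor spec."""
    return tf.TensorSpec(
        shape=[FLAGS.unroll_length + 1] + spec.shape.as_list(),
        dtype=spec.dtype,
        name=spec.name)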
Example #2
def __init__(self, unroll_length=1):
    self._env = MockEnv(state_space_size=4, unroll_length=unroll_length)
    self._agent = MockAgent(unroll_length=unroll_length)
    self._actor_output_spec = common.ActorOutput(
        initial_agent_state=tf.TensorSpec(shape=[5], dtype=tf.float32),
        env_output=self._env.env_spec,
        agent_output=self._agent.agent_spec,
        actor_action=common.ActorAction(
            chosen_action_idx=tf.TensorSpec(shape=[unroll_length + 1],
                                            dtype=tf.int32),
            oracle_next_action_idx=tf.TensorSpec(shape=[unroll_length + 1],
                                                 dtype=tf.int32)),
        loss_type=tf.TensorSpec(shape=[], dtype=tf.int32),
        info=tf.TensorSpec(shape=[], dtype=tf.string),
    )
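
The spec tuple built above can double as a structural check on actual actor outputs. A hedged illustration (the `_validate_actor_output` helper is hypothetical, not part of the snippet; it relies only on the standard `tf.nest` and `tf.TensorSpec` APIs):

def _validate_actor_output(self, actor_output):
    # Structures (field names and nesting) must match the declared spec.
    tf.nest.assert_same_structure(actor_output, self._actor_output_spec)
    # Each tensor must be compatible with its spec's dtype and shape.
    for tensor, spec in zip(tf.nest.flatten(actor_output),
                            tf.nest.flatten(self._actor_output_spec)):
        assert spec.is_compatible_with(tensor), (tensor, spec)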
Example #3
def run_with_learner(problem_type: framework_problem_type.ProblemType,
                     learner_address: Text, hparams: Dict[Text, Any]):
    """Runs actor with the given learner address and problem type.

    Args:
      problem_type: An instance of `framework_problem_type.ProblemType`.
      learner_address: The network address of a learner exposing two methods:
        `variable_values`: returns the latest values of the trainable variables.
        `enqueue`: accepts nested tensors of type `ActorOutput` tuple.
      hparams: A dict containing hyperparameter settings.
    """
    env = problem_type.get_environment()
    agent = problem_type.get_agent()
    env_output = env.reset()
    initial_agent_state = agent.get_initial_state(utils.add_batch_dim(
        env_output.observation),
                                                  batch_size=1)
    # The agent always expects time and batch dimensions. Add them first, then
    # remove them.
    env_output = utils.add_time_batch_dim(env_output)
    agent_output, _ = agent(env_output, initial_agent_state)
    env_output, agent_output = utils.remove_time_batch_dim(
        env_output, agent_output)
    actor_action = common.ActorAction(
        chosen_action_idx=tf.zeros([], dtype=tf.int32),
        oracle_next_action_idx=tf.zeros([], dtype=tf.int32))
    # Remove the batch dimension from the agent's returned initial state.
    initial_agent_state = tf.nest.map_structure(lambda t: tf.squeeze(t, 0),
                                                initial_agent_state)

    # Write TensorSpecs the learner can use for initialization.
    logging.info('My task id is %d', FLAGS.task)
    if FLAGS.task == 0:
        _write_tensor_specs(initial_agent_state, env_output, agent_output,
                            actor_action)

    # gRPC Client creation blocks until the server responds to an RPC. Since the
    # server blocks at startup looking for TensorSpecs, and will not respond to
    # gRPC calls until these TensorSpecs are written, client creation must happen
    # after the actor writes TensorSpecs in order to prevent a deadlock.
    logging.info('Connecting to learner: %s', learner_address)
    client = grpc.Client(learner_address)

    iter_steps = 0
    num_steps = 0
    sum_reward = 0.
    # Add the batch dimension back to the agent state.
    agent_state = tf.nest.map_structure(lambda t: tf.expand_dims(t, 0),
                                        initial_agent_state)

    iterations = 0
    while iter_steps < hparams['max_iter'] or hparams['max_iter'] == -1:
        logging.info('Iteration %d of %d', iter_steps + 1, hparams['max_iter'])
        # Get fresh parameters from the trainer.
        var_dtypes = [v.dtype for v in agent.trainable_variables]
        # The trainer also adds `iterations` to the list of variables -- a
        # counter tracking the number of iterations done so far.
        var_dtypes.append(tf.int64)
        new_values = []
        if iter_steps % hparams['sync_agent_every_n_steps'] == 0:
            new_values = client.variable_values()  # pytype: disable=attribute-error
        if new_values:
            logging.debug('Fetched variables from learner.')
            iterations = new_values[-1].numpy()
            updated_agent_vars = new_values[:-1]
            assert len(updated_agent_vars) == len(agent.trainable_variables)
            for x, y in zip(agent.trainable_variables, updated_agent_vars):
                x.assign(y)

        infos = []
        # Unroll agent.
        # Every episode sent by the actor includes the previous episode's final
        # agent state and output, as well as the final environment output.
        initial_agent_state = tf.nest.map_structure(lambda t: tf.squeeze(t, 0),
                                                    agent_state)
        env_outputs = [env_output]
        agent_outputs = [agent_output]
        actor_actions = [actor_action]
        loss_type = problem_type.get_episode_loss_type(iterations)

        for i in range(FLAGS.unroll_length):
            logging.debug('Unroll step %d of %d', i + 1, FLAGS.unroll_length)
            # The agent expects time and batch dimensions in `env_output` and a
            # batch dimension in `agent_state`. `agent_state` already has the
            # batch dimension.
            env_output = utils.add_time_batch_dim(env_output)
            agent_output, agent_state = agent(env_output, agent_state)

            env_output, agent_output = utils.remove_time_batch_dim(
                env_output, agent_output)

            actor_action, action_val = problem_type.select_actor_action(
                env_output, agent_output)

            env_output = env.step(action_val)

            env_outputs.append(env_output)
            agent_outputs.append(agent_output)
            actor_actions.append(actor_action)
            num_steps += 1
            sum_reward += env_output.reward

            if env_output.done:
                infos.append(
                    problem_type.get_actor_info(env_output, sum_reward,
                                                num_steps))
                num_steps = 0
                sum_reward = 0.

        processed_env_output = problem_type.postprocessing(
            utils.stack_nested_tensors(env_outputs))

        actor_output = common.ActorOutput(
            initial_agent_state=initial_agent_state,
            env_output=processed_env_output,
            agent_output=utils.stack_nested_tensors(agent_outputs),
            actor_action=utils.stack_nested_tensors(actor_actions),
            loss_type=tf.convert_to_tensor(loss_type, tf.int32),
            info=pickle.dumps(infos))
        flattened = tf.nest.flatten(actor_output)
        client.enqueue(flattened)  # pytype: disable=attribute-error
        iter_steps += 1
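
For context, `utils.stack_nested_tensors` above turns the per-step lists (`env_outputs`, `agent_outputs`, `actor_actions`) into single nests with a leading time axis. A minimal sketch under that assumption (the real utility may handle more cases):

import tensorflow as tf


def stack_nested_tensors(list_of_nests):
    """Hypothetical sketch: stacks corresponding leaves along a new time axis."""
    # All nests in the list must share the same structure; their corresponding
    # leaves are stacked into one tensor with a leading time dimension.
    return tf.nest.map_structure(lambda *tensors: tf.stack(tensors),
                                 *list_of_nests)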