def select_actor_action(self, env_output, agent_output):
  oracle_next_action = env_output.observation[constants.ORACLE_NEXT_ACTION]
  oracle_next_action_indices = tf.where(
      tf.equal(env_output.observation[constants.CONN_IDS],
               oracle_next_action))
  oracle_next_action_idx = tf.reduce_min(oracle_next_action_indices)
  assert self._mode, 'mode must be set.'
  if self._mode == 'train':
    if self._loss_type == common.CE_LOSS:
      # This is teacher-forcing mode, so choose action same as oracle action.
      action_idx = oracle_next_action_idx
    elif self._loss_type == common.AC_LOSS:
      # Choose next pano from probability distribution over next panos.
      action_idx = tfp.distributions.Categorical(
          logits=agent_output.policy_logits).sample()
    else:
      raise ValueError('Unsupported loss type {}'.format(self._loss_type))
  else:
    # In non-train modes, choose greedily.
    action_idx = tf.argmax(agent_output.policy_logits, axis=-1)
  action_val = env_output.observation[constants.CONN_IDS][action_idx]
  return common.ActorAction(
      chosen_action_idx=int(action_idx.numpy()),
      oracle_next_action_idx=int(oracle_next_action_idx.numpy())), int(
          action_val.numpy())
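# Illustrative only: the value stored in chosen_action_idx is an index into
# CONN_IDS, while the value returned alongside the ActorAction (and later
# passed to env.step) is the pano id at that index. Tensor values below are
# made up for the example.
import tensorflow as tf

conn_ids = tf.constant([3, 15, 8])             # hypothetical neighbor pano ids
policy_logits = tf.constant([0.1, 2.3, -0.5])  # hypothetical agent logits
action_idx = tf.argmax(policy_logits, axis=-1)  # greedy choice -> index 1
action_val = conn_ids[action_idx]               # pano id 15, passed to env.step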
def select_actor_action(self, env_output, unused_agent_output):
  """Returns the next ground truth action pano id."""
  time_step = env_output.observation[constants.TIME_STEP]
  current_pano_id = env_output.observation[constants.PANO_ID]
  golden_path = env_output.observation[constants.GOLDEN_PATH]
  golden_path_len = sum(
      [1 for pid in golden_path if pid != constants.INVALID_NODE_ID])
  # Sanity check: ensure pano id is on the golden path.
  if current_pano_id != golden_path[time_step]:
    raise ValueError(
        'Current pano id does not match that in golden path: {} vs. {}'.format(
            current_pano_id, golden_path[time_step]))
  if ((current_pano_id == env_output.observation[constants.GOAL_PANO_ID] and
       time_step == golden_path_len - 1) or
      current_pano_id == constants.STOP_NODE_ID):
    next_golden_pano_id = constants.STOP_NODE_ID
  else:
    next_golden_pano_id = golden_path[time_step + 1]
  try:
    unused_action_idx = tf.where(
        tf.equal(env_output.observation[constants.CONN_IDS],
                 next_golden_pano_id))
  except ValueError:
    # Current and next panos are not connected, so use idx for invalid node.
    unused_action_idx = tf.where(
        tf.equal(env_output.observation[constants.CONN_IDS],
                 constants.INVALID_NODE_ID))
  unused_action_idx = tf.cast(tf.reduce_min(unused_action_idx), tf.int32)
  return common.ActorAction(
      chosen_action_idx=unused_action_idx.numpy(),
      oracle_next_action_idx=unused_action_idx.numpy()), int(
          next_golden_pano_id)
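# Illustrative only: how the tf.where + tf.reduce_min pattern above maps a
# pano id back to its position in CONN_IDS. Values are made up for the
# example.
import tensorflow as tf

conn_ids = tf.constant([12, 7, 42, 7])  # hypothetical connection ids
next_golden_pano_id = 7
indices = tf.where(tf.equal(conn_ids, next_golden_pano_id))  # [[1], [3]]
action_idx = tf.cast(tf.reduce_min(indices), tf.int32)       # smallest match: 1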
def select_actor_action(self, env_output, agent_output):
  # Always selects action=1 by default.
  action_idx = 1
  action_val = 1
  oracle_next_action_idx = 1
  return common.ActorAction(
      chosen_action_idx=action_idx,
      oracle_next_action_idx=oracle_next_action_idx), action_val
def select_actor_action(self, env_output, agent_output):
  # agent_output is unused here.
  oracle_next_action = env_output.observation[constants.ORACLE_NEXT_ACTION]
  oracle_next_action_indices = tf.where(
      tf.equal(env_output.observation[constants.CONN_IDS],
               oracle_next_action))
  oracle_next_action_idx = tf.reduce_min(oracle_next_action_indices)
  assert self._mode, 'mode must be set.'
  action_idx = oracle_next_action_idx
  action_val = env_output.observation[constants.CONN_IDS][action_idx]
  return common.ActorAction(
      chosen_action_idx=int(action_idx.numpy()),
      oracle_next_action_idx=int(oracle_next_action_idx.numpy())), int(
          action_val)
def __init__(self, unroll_length=1):
  self._env = MockEnv(state_space_size=4, unroll_length=unroll_length)
  self._agent = MockAgent(unroll_length=unroll_length)
  self._actor_output_spec = common.ActorOutput(
      initial_agent_state=tf.TensorSpec(shape=[5], dtype=tf.float32),
      env_output=self._env.env_spec,
      agent_output=self._agent.agent_spec,
      actor_action=common.ActorAction(
          chosen_action_idx=tf.TensorSpec(
              shape=[unroll_length + 1], dtype=tf.int32),
          oracle_next_action_idx=tf.TensorSpec(
              shape=[unroll_length + 1], dtype=tf.int32)),
      loss_type=tf.TensorSpec(shape=[], dtype=tf.int32),
      info=tf.TensorSpec(shape=[], dtype=tf.string),
  )
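# Illustrative only: a sketch of why the ActorAction specs above use
# unroll_length + 1. Each trajectory an actor enqueues starts from the
# previous unroll's final step, so stacking the per-step ActorActions yields
# unroll_length + 1 entries. `utils.stack_nested_tensors` is assumed to stack
# namedtuple fields element-wise, as in the actor loop below.
unroll_length = 1
actor_actions = [
    common.ActorAction(
        chosen_action_idx=tf.constant(0, dtype=tf.int32),
        oracle_next_action_idx=tf.constant(0, dtype=tf.int32))
    for _ in range(unroll_length + 1)
]
stacked = utils.stack_nested_tensors(actor_actions)
# stacked.chosen_action_idx now has shape [unroll_length + 1], matching the spec.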
def select_actor_action(self, env_output, agent_output):
  assert self._mode, 'mode must be set for selecting action in actor.'
  oracle_next_action = env_output.observation[
      streetview_constants.ORACLE_NEXT_ACTION]
  if self._mode == 'train':
    if self._loss_type == common.CE_LOSS:
      # This is teacher-forcing mode, so choose action same as oracle action.
      action_idx = oracle_next_action
    elif self._loss_type == common.AC_LOSS:
      action_idx = tfp.distributions.Categorical(
          logits=agent_output.policy_logits).sample()
  else:
    # In non-train modes, choose greedily.
    action_idx = tf.argmax(agent_output.policy_logits, axis=-1)
  # Return ActorAction and the action to be passed to the env step function.
  return common.ActorAction(
      chosen_action_idx=int(action_idx.numpy()),
      oracle_next_action_idx=int(
          oracle_next_action.numpy())), action_idx.numpy()
def select_actor_action(self, env_output, agent_output):
  oracle_next_action = env_output.observation[constants.ORACLE_NEXT_ACTION]
  oracle_next_action_indices = tf.where(
      tf.equal(env_output.observation[constants.CONN_IDS],
               oracle_next_action))
  oracle_next_action_idx = tf.reduce_min(oracle_next_action_indices)
  if self._loss_type == common.CE_LOSS:
    # This is teacher-forcing mode, so choose action same as oracle action.
    action_idx = oracle_next_action_idx
  elif self._loss_type == common.AC_LOSS:
    # Choose next pano from probability distribution over next panos.
    action_idx = tfp.distributions.Categorical(
        logits=agent_output.policy_logits).sample()
  else:
    raise ValueError('Unsupported loss type {}'.format(self._loss_type))
  action_val = env_output.observation[constants.CONN_IDS][action_idx]
  policy_logprob = tf.nn.log_softmax(agent_output.policy_logits)
  return common.ActorAction(
      chosen_action_idx=int(action_idx.numpy()),
      oracle_next_action_idx=int(oracle_next_action_idx.numpy()),
      action_val=int(action_val.numpy()),
      log_prob=float(policy_logprob[action_idx].numpy()))
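# Illustrative only: how the AC_LOSS branch above samples an action index and
# how its log-probability is read back from the policy logits. Logit values
# are made up for the example.
import tensorflow as tf
import tensorflow_probability as tfp

policy_logits = tf.constant([2.0, 0.5, -1.0])
action_idx = tfp.distributions.Categorical(logits=policy_logits).sample()
policy_logprob = tf.nn.log_softmax(policy_logits)
log_prob = float(policy_logprob[action_idx].numpy())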
def run_with_learner(problem_type: framework_problem_type.ProblemType,
                     learner_address: Text, hparams: Dict[Text, Any]):
  """Runs actor with the given learner address and problem type.

  Args:
    problem_type: An instance of `framework_problem_type.ProblemType`.
    learner_address: The network address of a learner exposing two methods:
      `variable_values`: which returns latest value of trainable variables.
      `enqueue`: which accepts nested tensors of type `ActorOutput` tuple.
    hparams: A dict containing hyperparameter settings.
  """
  env = problem_type.get_environment()
  agent = problem_type.get_agent()
  env_output = env.reset()
  initial_agent_state = agent.get_initial_state(
      utils.add_batch_dim(env_output.observation), batch_size=1)
  # Agent always expects time,batch dimensions. First add and then remove.
  env_output = utils.add_time_batch_dim(env_output)
  agent_output, _ = agent(env_output, initial_agent_state)
  env_output, agent_output = utils.remove_time_batch_dim(
      env_output, agent_output)
  actor_action = common.ActorAction(
      chosen_action_idx=tf.zeros([], dtype=tf.int32),
      oracle_next_action_idx=tf.zeros([], dtype=tf.int32))
  # Remove batch_dim from returned agent's initial state.
  initial_agent_state = tf.nest.map_structure(lambda t: tf.squeeze(t, 0),
                                              initial_agent_state)
  # Write TensorSpecs the learner can use for initialization.
  logging.info('My task id is %d', FLAGS.task)
  if FLAGS.task == 0:
    _write_tensor_specs(initial_agent_state, env_output, agent_output,
                        actor_action)
  # gRPC Client creation blocks until the server responds to an RPC. Since the
  # server blocks at startup looking for TensorSpecs, and will not respond to
  # gRPC calls until these TensorSpecs are written, client creation must happen
  # after the actor writes TensorSpecs in order to prevent a deadlock.
  logging.info('Connecting to learner: %s', learner_address)
  client = grpc.Client(learner_address)
  iter_steps = 0
  num_steps = 0
  sum_reward = 0.
  # Add batch_dim.
  agent_state = tf.nest.map_structure(lambda t: tf.expand_dims(t, 0),
                                      initial_agent_state)
  iterations = 0
  while iter_steps < hparams['max_iter'] or hparams['max_iter'] == -1:
    logging.info('Iteration %d of %d', iter_steps + 1, hparams['max_iter'])
    # Get fresh parameters from the trainer.
    var_dtypes = [v.dtype for v in agent.trainable_variables]
    # Trainer also adds `iterations` to the list of variables -- which is a
    # counter tracking number of iterations done so far.
    var_dtypes.append(tf.int64)
    new_values = []
    if iter_steps % hparams['sync_agent_every_n_steps'] == 0:
      new_values = client.variable_values()  # pytype: disable=attribute-error
    if new_values:
      logging.debug('Fetched variables from learner.')
      iterations = new_values[-1].numpy()
      updated_agent_vars = new_values[:-1]
      assert len(updated_agent_vars) == len(agent.trainable_variables)
      for x, y in zip(agent.trainable_variables, updated_agent_vars):
        x.assign(y)
    infos = []
    # Unroll agent.
    # Every episode sent by actor includes previous episode's final agent
    # state and output as well as final environment output.
    initial_agent_state = tf.nest.map_structure(lambda t: tf.squeeze(t, 0),
                                                agent_state)
    env_outputs = [env_output]
    agent_outputs = [agent_output]
    actor_actions = [actor_action]
    loss_type = problem_type.get_episode_loss_type(iterations)
    for i in range(FLAGS.unroll_length):
      logging.debug('Unroll step %d of %d', i + 1, FLAGS.unroll_length)
      # Agent expects time,batch dimensions in `env_output` and batch
      # dimension in `agent_state`. `agent_state` already has batch_dim.
      env_output = utils.add_time_batch_dim(env_output)
      agent_output, agent_state = agent(env_output, agent_state)
      env_output, agent_output = utils.remove_time_batch_dim(
          env_output, agent_output)
      actor_action, action_val = problem_type.select_actor_action(
          env_output, agent_output)
      env_output = env.step(action_val)
      env_outputs.append(env_output)
      agent_outputs.append(agent_output)
      actor_actions.append(actor_action)
      num_steps += 1
      sum_reward += env_output.reward
      if env_output.done:
        infos.append(
            problem_type.get_actor_info(env_output, sum_reward, num_steps))
        num_steps = 0
        sum_reward = 0.
    processed_env_output = problem_type.postprocessing(
        utils.stack_nested_tensors(env_outputs))
    actor_output = common.ActorOutput(
        initial_agent_state=initial_agent_state,
        env_output=processed_env_output,
        agent_output=utils.stack_nested_tensors(agent_outputs),
        actor_action=utils.stack_nested_tensors(actor_actions),
        loss_type=tf.convert_to_tensor(loss_type, tf.int32),
        info=pickle.dumps(infos))
    flattened = tf.nest.flatten(actor_output)
    client.enqueue(flattened)  # pytype: disable=attribute-error
    iter_steps += 1
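# Illustrative only: a sketch of how the learner side might rebuild the
# ActorOutput namedtuple from the flat tensor list enqueued above.
# `actor_output_spec` is assumed to be an ActorOutput of TensorSpecs, such as
# the one constructed in the test harness above or derived from the specs
# written by _write_tensor_specs.
def _unflatten_actor_output(actor_output_spec, flattened):
  # tf.nest.pack_sequence_as restores the nested structure in the same order
  # that tf.nest.flatten produced on the actor side.
  return tf.nest.pack_sequence_as(actor_output_spec, flattened)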
def plan_actor_action(self, agent_output, agent_state, agent_instance,
                      env_output, env_instance, beam_size, planning_horizon,
                      temperature=1.0):
  initial_env_state = env_instance.get_state()
  initial_time_step = env_output.observation[constants.TIME_STEP]
  beam = [
      common.PlanningState(
          score=0,
          agent_output=agent_output,
          agent_state=agent_state,
          env_output=env_output,
          env_state=initial_env_state,
          action_history=[])
  ]
  planning_step = 1
  while True:
    next_beam = []
    for state in beam:
      if state.action_history and (state.action_history[-1].action_val ==
                                   constants.STOP_NODE_ID):
        # Path is done. This won't be reflected in env_output.done since
        # stop actions are not performed during planning.
        next_beam.append(state)
        continue
      # Find the beam_size best next actions based on policy log probability.
      num_actions = tf.math.count_nonzero(
          state.env_output.observation[constants.CONN_IDS] >=
          constants.STOP_NODE_ID).numpy()
      policy_logprob = tf.nn.log_softmax(
          state.agent_output.policy_logits / temperature)
      logprob, ix = tf.math.top_k(
          policy_logprob, k=min(num_actions, beam_size))
      action_vals = tf.gather(
          state.env_output.observation[constants.CONN_IDS], ix)
      oracle_action = state.env_output.observation[
          constants.ORACLE_NEXT_ACTION]
      oracle_action_indices = tf.where(
          tf.equal(state.env_output.observation[constants.CONN_IDS],
                   oracle_action))
      oracle_action_idx = tf.reduce_min(oracle_action_indices)
      # Expand each action and add to the beam for the next iteration.
      for j, action_val in enumerate(action_vals.numpy()):
        next_action = common.ActorAction(
            chosen_action_idx=int(ix[j].numpy()),
            oracle_next_action_idx=int(oracle_action_idx.numpy()),
            action_val=int(action_val),
            log_prob=float(logprob[j].numpy()))
        if action_val == constants.STOP_NODE_ID:
          # Don't perform stop actions which trigger a new episode that can't
          # be reset using set_state.
          next_state = common.PlanningState(
              score=state.score + logprob[j],
              agent_output=state.agent_output,
              agent_state=state.agent_state,
              env_output=state.env_output,
              env_state=state.env_state,
              action_history=state.action_history + [next_action])
        else:
          # Perform the non-stop action.
          env_instance.set_state(state.env_state)
          next_env_output = env_instance.step(action_val)
          next_env_output = utils.add_time_batch_dim(next_env_output)
          next_agent_output, next_agent_state = agent_instance(
              next_env_output, state.agent_state)
          next_env_output, next_agent_output = utils.remove_time_batch_dim(
              next_env_output, next_agent_output)
          next_state = common.PlanningState(
              score=state.score + logprob[j],
              agent_output=next_agent_output,
              agent_state=next_agent_state,
              env_output=next_env_output,
              env_state=env_instance.get_state(),
              action_history=state.action_history + [next_action])
        next_beam.append(next_state)

    def _log_beam(beam):
      for item in beam:
        path_string = '\t'.join(
            [str(a.action_val) for a in item.action_history])
        score_string = '\t'.join(
            ['%.4f' % a.log_prob for a in item.action_history])
        logging.debug('Score: %.4f', item.score)
        logging.debug('Log prob: %s', score_string)
        logging.debug('Steps: %s', path_string)

    # Reduce the next beam to only the top beam_size paths.
    beam = sorted(next_beam, reverse=True, key=operator.attrgetter('score'))
    beam = beam[:beam_size]
    logging.debug('Planning step %d', planning_step)
    _log_beam(beam)
    # Break if all episodes are done.
    if all(item.action_history[-1].action_val == constants.STOP_NODE_ID
           for item in beam):
      break
    # Break if exceeded planning_horizon.
    if planning_step >= planning_horizon:
      break
    # Break if we are planning beyond the max actions per episode, since this
    # will also trigger a new episode (same as the stop action).
    if (initial_time_step + planning_step >=
        env_instance._max_actions_per_episode):
      break
    planning_step += 1
  # Restore the environment to its initial state so the agent can still act.
  env_instance.set_state(initial_env_state)
  return beam[0].action_history
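# Illustrative only: one way a caller might consume the planned action history
# returned above, executing just the first planned step in the real
# environment before re-planning. The beam_size and planning_horizon values
# are arbitrary example settings; `problem_type`, `agent`, and `env` are
# assumed to be the live instances used during evaluation.
planned_actions = problem_type.plan_actor_action(
    agent_output, agent_state, agent, env_output, env,
    beam_size=4, planning_horizon=8)
first_action = planned_actions[0]
env_output = env.step(first_action.action_val)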