def test_call_v2(self, zero_mask, init_with_text_state, avg_all_img_states):
  self._get_agent('v2', init_with_text_state, avg_all_img_states)
  if zero_mask:
    disc_mask = np.tile(np.array([[False]] * 3), [1, self.batch_size])
    observation = self._test_environment.observation
    observation[constants.DISC_MASK] = disc_mask
    self._test_environment = self._test_environment._replace(
        observation=observation)
  env_output = self._env.reset()
  observation = tf.nest.map_structure(lambda t: tf.expand_dims(t, 0),
                                      env_output.observation)
  # Tests batch_size = 1 for the actor's scenario.
  initial_agent_state = self._agent.get_initial_state(
      observation, batch_size=1)
  # Agent always expects time, batch dimensions. First add and then remove.
  env_output = utils.add_time_batch_dim(env_output)
  agent_output, _ = self._agent(env_output, initial_agent_state)
  # Output shape = [time, batch, ...].
  self.assertEqual(agent_output.policy_logits['similarity'].shape, [1, 1, 1])
  self.assertEqual(agent_output.policy_logits['labels'].shape, [1, 1])
  self.assertEqual(agent_output.baseline.shape, [1, 1])
  # Remove time-batch dims for single env_output (i.e., batch=1, timestep=1).
  env_output, agent_output = utils.remove_time_batch_dim(
      env_output, agent_output)
  self.assertEqual(agent_output.policy_logits['similarity'].shape, [1])
  self.assertEqual(agent_output.policy_logits['labels'].shape, [])
  self.assertEqual(agent_output.baseline.shape, [])

  # Tests with custom states and env.
  initial_input_state = [(tf.random.normal([self.batch_size, 512]),
                          tf.random.normal([self.batch_size, 512])),
                         (tf.random.normal([self.batch_size, 512]),
                          tf.random.normal([self.batch_size, 512]))]
  text_enc_output = tf.random.normal([self.batch_size, 5, 512])
  initial_agent_state = (initial_input_state,
                         (text_enc_output,
                          tf.nest.map_structure(tf.identity,
                                                initial_input_state)))
  agent_output, _ = self._agent(self._test_environment, initial_agent_state)
  # Note that the agent_output has an extra time dim.
  self.assertEqual(agent_output.policy_logits['similarity'].shape,
                   [1, self.batch_size, self.batch_size])
  self.assertEqual(agent_output.policy_logits['similarity_mask'].shape,
                   [1, self.batch_size, self.batch_size])
  self.assertEqual(agent_output.policy_logits['labels'].shape,
                   [1, self.batch_size])
  self.assertEqual(agent_output.baseline.shape, [1, self.batch_size])
  if zero_mask:
    expected_similarity_mask = [[[True, True], [True, True]]]
    expected_labels = [[0.0, 0.0]]
  else:
    expected_similarity_mask = [[[True, True], [True, True]]]
    expected_labels = [[1.0, 0.0]]
  self.assertAllEqual(agent_output.policy_logits['similarity_mask'],
                      expected_similarity_mask)
  self.assertAllEqual(agent_output.policy_logits['labels'], expected_labels)
def testTimeBatchDim(self):
  x = tf.ones(shape=(2, 3))
  y = tf.ones(shape=(2, 3, 4))
  x, y = utils.add_time_batch_dim(x, y)
  np.testing.assert_equal((1, 1, 2, 3), x.shape)
  np.testing.assert_equal((1, 1, 2, 3, 4), y.shape)

  x, y = utils.remove_time_batch_dim(x, y)
  np.testing.assert_equal((2, 3), x.shape)
  np.testing.assert_equal((2, 3, 4), y.shape)
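# The following is a minimal illustrative sketch (an assumption, not the
# actual `valan.framework.utils` implementation) of the two helpers exercised
# by the test above: `add_time_batch_dim` prepends leading [time, batch] axes
# to every tensor in its (possibly nested) arguments, and
# `remove_time_batch_dim` squeezes them back out.
def _add_time_batch_dim_sketch(*nested_tensors):
  """Adds two leading axes (time, batch) to every tensor in the inputs."""
  expanded = tf.nest.map_structure(
      lambda t: tf.expand_dims(tf.expand_dims(t, 0), 0), nested_tensors)
  return expanded[0] if len(expanded) == 1 else expanded


def _remove_time_batch_dim_sketch(*nested_tensors):
  """Squeezes out the two leading (time, batch) axes from every tensor."""
  squeezed = tf.nest.map_structure(
      lambda t: tf.squeeze(t, axis=[0, 1]), nested_tensors)
  return squeezed[0] if len(squeezed) == 1 else squeezed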
def test_call_ndh(self):
  self._agent = agent.R2RAgent(agent_config.get_ndh_agent_config())
  self.data_dir = FLAGS.test_srcdir + (
      'valan/r2r/testdata')
  self._env_config = hparam.HParams(
      problem='NDH',
      history='all',
      path_type='trusted_path',
      max_goal_room_panos=4,
      scan_base_dir=self.data_dir,
      data_base_dir=self.data_dir,
      vocab_dir=self.data_dir,
      problem_path=os.path.join(self.data_dir, 'NDH'),
      vocab_file='vocab.txt',
      images_per_pano=36,
      max_conns=14,
      image_encoding_dim=64,
      direction_encoding_dim=256,
      image_features_dir=os.path.join(self.data_dir, 'image_features'),
      instruction_len=50,
      max_agent_actions=6,
      reward_fn=env_config.RewardFunction.get_reward_fn('distance_to_goal'))
  self._runtime_config = common.RuntimeConfig(task_id=0, num_tasks=1)
  self._env = env.R2REnv(
      data_sources=['R2R_small_split'],
      runtime_config=self._runtime_config,
      env_config=self._env_config)

  env_output = self._env.reset()
  observation = tf.nest.map_structure(lambda t: tf.expand_dims(t, 0),
                                      env_output.observation)
  initial_agent_state = self._agent.get_initial_state(
      observation, batch_size=1)
  # Agent always expects time, batch dimensions. First add and then remove.
  env_output = utils.add_time_batch_dim(env_output)
  agent_output, _ = self._agent(env_output, initial_agent_state)

  self.assertEqual(agent_output.policy_logits.shape, [1, 1, 14])
  self.assertEqual(agent_output.baseline.shape, [1, 1])

  initial_agent_state = ([
      (tf.random.normal([self.batch_size, 512]),
       tf.random.normal([self.batch_size, 512])),
      (tf.random.normal([self.batch_size, 512]),
       tf.random.normal([self.batch_size, 512]))
  ], tf.random.normal([self.batch_size, 5, 512]))
  agent_output, _ = self._agent(self._test_environment, initial_agent_state)

  self.assertEqual(agent_output.policy_logits.shape,
                   [self.time_step, self.batch_size, 14])
  self.assertEqual(agent_output.baseline.shape,
                   [self.time_step, self.batch_size])
def test_call(self):
  self._get_agent('default')
  env_output = self._env.reset()
  observation = tf.nest.map_structure(lambda t: tf.expand_dims(t, 0),
                                      env_output.observation)
  initial_agent_state = self._agent.get_initial_state(
      observation, batch_size=1)
  # Agent always expects time, batch dimensions. First add and then remove.
  env_output = utils.add_time_batch_dim(env_output)
  agent_output, _ = self._agent(env_output, initial_agent_state)

  initial_agent_state = ([
      (tf.random.normal([self.batch_size, 512]),
       tf.random.normal([self.batch_size, 512])),
      (tf.random.normal([self.batch_size, 512]),
       tf.random.normal([self.batch_size, 512]))
  ], tf.random.normal([self.batch_size, 5, 512]))
  agent_output, _ = self._agent(self._test_environment, initial_agent_state)

  self.assertEqual(agent_output.policy_logits.shape, [3, self.batch_size])
def run_with_learner(problem_type: framework_problem_type.ProblemType,
                     learner_address: Text, hparams: Dict[Text, Any]):
  """Runs actor with the given learner address and problem type.

  Args:
    problem_type: An instance of `framework_problem_type.ProblemType`.
    learner_address: The network address of a learner exposing two methods:
      `variable_values`: which returns the latest values of trainable
        variables.
      `enqueue`: which accepts nested tensors of type `ActorOutput` tuple.
    hparams: A dict containing hyperparameter settings.
  """
  env = problem_type.get_environment()
  agent = problem_type.get_agent()
  env_output = env.reset()
  initial_agent_state = agent.get_initial_state(
      utils.add_batch_dim(env_output.observation), batch_size=1)
  # Agent always expects time, batch dimensions. First add and then remove.
  env_output = utils.add_time_batch_dim(env_output)
  agent_output, _ = agent(env_output, initial_agent_state)
  env_output, agent_output = utils.remove_time_batch_dim(
      env_output, agent_output)
  actor_action = common.ActorAction(
      chosen_action_idx=tf.zeros([], dtype=tf.int32),
      oracle_next_action_idx=tf.zeros([], dtype=tf.int32))
  # Remove batch_dim from the agent's returned initial state.
  initial_agent_state = tf.nest.map_structure(lambda t: tf.squeeze(t, 0),
                                              initial_agent_state)

  # Write TensorSpecs the learner can use for initialization.
  logging.info('My task id is %d', FLAGS.task)
  if FLAGS.task == 0:
    _write_tensor_specs(initial_agent_state, env_output, agent_output,
                        actor_action)

  # gRPC client creation blocks until the server responds to an RPC. Since the
  # server blocks at startup looking for TensorSpecs, and will not respond to
  # gRPC calls until these TensorSpecs are written, client creation must happen
  # after the actor writes TensorSpecs in order to prevent a deadlock.
  logging.info('Connecting to learner: %s', learner_address)
  client = grpc.Client(learner_address)

  iter_steps = 0
  num_steps = 0
  sum_reward = 0.
  # Add batch_dim.
  agent_state = tf.nest.map_structure(lambda t: tf.expand_dims(t, 0),
                                      initial_agent_state)
  iterations = 0
  while iter_steps < hparams['max_iter'] or hparams['max_iter'] == -1:
    logging.info('Iteration %d of %d', iter_steps + 1, hparams['max_iter'])
    # Get fresh parameters from the trainer.
    var_dtypes = [v.dtype for v in agent.trainable_variables]
    # The trainer also adds `iterations` to the list of variables -- a counter
    # tracking the number of iterations done so far.
    var_dtypes.append(tf.int64)
    new_values = []
    if iter_steps % hparams['sync_agent_every_n_steps'] == 0:
      new_values = client.variable_values()  # pytype: disable=attribute-error
    if new_values:
      logging.debug('Fetched variables from learner.')
      iterations = new_values[-1].numpy()
      updated_agent_vars = new_values[:-1]
      assert len(updated_agent_vars) == len(agent.trainable_variables)
      for x, y in zip(agent.trainable_variables, updated_agent_vars):
        x.assign(y)

    infos = []
    # Unroll agent.
    # Every episode sent by the actor includes the previous episode's final
    # agent state and output as well as the final environment output.
    initial_agent_state = tf.nest.map_structure(lambda t: tf.squeeze(t, 0),
                                                agent_state)
    env_outputs = [env_output]
    agent_outputs = [agent_output]
    actor_actions = [actor_action]
    loss_type = problem_type.get_episode_loss_type(iterations)

    for i in range(FLAGS.unroll_length):
      logging.debug('Unroll step %d of %d', i + 1, FLAGS.unroll_length)
      # Agent expects time, batch dimensions in `env_output` and a batch
      # dimension in `agent_state`. `agent_state` already has the batch_dim.
      env_output = utils.add_time_batch_dim(env_output)
      agent_output, agent_state = agent(env_output, agent_state)
      env_output, agent_output = utils.remove_time_batch_dim(
          env_output, agent_output)
      actor_action, action_val = problem_type.select_actor_action(
          env_output, agent_output)

      env_output = env.step(action_val)

      env_outputs.append(env_output)
      agent_outputs.append(agent_output)
      actor_actions.append(actor_action)
      num_steps += 1
      sum_reward += env_output.reward

      if env_output.done:
        infos.append(
            problem_type.get_actor_info(env_output, sum_reward, num_steps))
        num_steps = 0
        sum_reward = 0.

    processed_env_output = problem_type.postprocessing(
        utils.stack_nested_tensors(env_outputs))

    actor_output = common.ActorOutput(
        initial_agent_state=initial_agent_state,
        env_output=processed_env_output,
        agent_output=utils.stack_nested_tensors(agent_outputs),
        actor_action=utils.stack_nested_tensors(actor_actions),
        loss_type=tf.convert_to_tensor(loss_type, tf.int32),
        info=pickle.dumps(infos))
    flattened = tf.nest.flatten(actor_output)
    client.enqueue(flattened)  # pytype: disable=attribute-error
    iter_steps += 1
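# A minimal sketch (an assumption, not the actual `utils.stack_nested_tensors`
# code) of how a list of identically-structured nested tensors, e.g. the
# `env_outputs` collected during the unroll above, could be stacked along a
# new leading time axis before being enqueued to the learner.
def _stack_nested_tensors_sketch(list_of_nests):
  """Stacks corresponding leaf tensors across the list along a new axis 0."""
  return tf.nest.map_structure(lambda *tensors: tf.stack(tensors),
                               *list_of_nests)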
def run_with_aggregator(problem_type, aggregator_address: Text, hparams):
  """Runs evaluation actor with given problem_type, aggregator and hparams.

  Args:
    problem_type: An instance of `framework_problem_type.ProblemType`.
    aggregator_address: The aggregator address to which we will send data for
      batching.
    hparams: A dict containing hyperparameter settings.
  """
  assert isinstance(problem_type, framework_problem_type.ProblemType)
  env = problem_type.get_environment()
  agent = problem_type.get_agent()

  env_output = env.reset()
  agent_state = agent.get_initial_state(
      utils.add_batch_dim(env_output.observation), batch_size=1)
  # Agent always expects time, batch dimensions.
  _, _ = agent(utils.add_time_batch_dim(env_output), agent_state)

  logging.info('Connecting to aggregator %s', aggregator_address)
  aggregator = grpc.Client(aggregator_address)

  iter_steps = 0
  latest_checkpoint_path = ''
  while hparams['max_iter'] == -1 or iter_steps < hparams['max_iter']:
    logging.info('Iteration %d of %d', iter_steps + 1, hparams['max_iter'])
    checkpoint_directory = os.path.join(hparams['logdir'], 'model.ckpt')
    checkpoint_path = tf.train.latest_checkpoint(checkpoint_directory)
    if checkpoint_path == latest_checkpoint_path or not checkpoint_path:
      logging.info(
          'Waiting for next checkpoint. Previously evaluated checkpoint %s',
          latest_checkpoint_path)
      time.sleep(30)
      continue

    ckpt = tf.train.Checkpoint(agent=agent)
    ckpt.restore(checkpoint_path)
    latest_checkpoint_path = checkpoint_path
    logging.info('Evaluating latest checkpoint - %s', latest_checkpoint_path)

    step = int(latest_checkpoint_path.split('-')[-1])
    logging.debug('Step %d', step)

    for i in range(hparams['num_episodes_per_iter']):
      logging.debug('Episode number %d of %d', i + 1,
                    hparams['num_episodes_per_iter'])
      action_list = []
      env_output_list = [env_output]
      while True:
        env_output = utils.add_time_batch_dim(env_output)
        agent_output, agent_state = agent(env_output, agent_state)
        env_output, agent_output = utils.remove_time_batch_dim(
            env_output, agent_output)
        _, action_val = problem_type.select_actor_action(
            env_output, agent_output)

        env_output = env.step(action_val)

        action_list.append(action_val)
        env_output_list.append(env_output)

        if env_output.done:
          # eval_result is a dict.
          eval_result = problem_type.eval(action_list, env_output_list)
          eval_result[common.STEP] = step
          aggregator.eval_enqueue(pickle.dumps(eval_result))  # pytype: disable=attribute-error
          break
    iter_steps += 1
def plan_actor_action(self, agent_output, agent_state, agent_instance,
                      env_output, env_instance, beam_size, planning_horizon,
                      temperature=1.0):
  initial_env_state = env_instance.get_state()
  initial_time_step = env_output.observation[constants.TIME_STEP]
  beam = [
      common.PlanningState(
          score=0,
          agent_output=agent_output,
          agent_state=agent_state,
          env_output=env_output,
          env_state=initial_env_state,
          action_history=[])
  ]
  planning_step = 1
  while True:
    next_beam = []
    for state in beam:
      if state.action_history and (state.action_history[-1].action_val ==
                                   constants.STOP_NODE_ID):
        # Path is done. This won't be reflected in env_output.done since
        # stop actions are not performed during planning.
        next_beam.append(state)
        continue
      # Find the beam_size best next actions based on policy log probability.
      num_actions = tf.math.count_nonzero(
          state.env_output.observation[constants.CONN_IDS] >=
          constants.STOP_NODE_ID).numpy()
      policy_logprob = tf.nn.log_softmax(
          state.agent_output.policy_logits / temperature)
      logprob, ix = tf.math.top_k(
          policy_logprob, k=min(num_actions, beam_size))
      action_vals = tf.gather(
          state.env_output.observation[constants.CONN_IDS], ix)
      oracle_action = state.env_output.observation[
          constants.ORACLE_NEXT_ACTION]
      oracle_action_indices = tf.where(
          tf.equal(state.env_output.observation[constants.CONN_IDS],
                   oracle_action))
      oracle_action_idx = tf.reduce_min(oracle_action_indices)

      # Expand each action and add to the beam for the next iteration.
      for j, action_val in enumerate(action_vals.numpy()):
        next_action = common.ActorAction(
            chosen_action_idx=int(ix[j].numpy()),
            oracle_next_action_idx=int(oracle_action_idx.numpy()),
            action_val=int(action_val),
            log_prob=float(logprob[j].numpy()))
        if action_val == constants.STOP_NODE_ID:
          # Don't perform stop actions, which trigger a new episode that
          # can't be reset using set_state.
          next_state = common.PlanningState(
              score=state.score + logprob[j],
              agent_output=state.agent_output,
              agent_state=state.agent_state,
              env_output=state.env_output,
              env_state=state.env_state,
              action_history=state.action_history + [next_action])
        else:
          # Perform the non-stop action.
          env_instance.set_state(state.env_state)
          next_env_output = env_instance.step(action_val)
          next_env_output = utils.add_time_batch_dim(next_env_output)
          next_agent_output, next_agent_state = agent_instance(
              next_env_output, state.agent_state)
          next_env_output, next_agent_output = utils.remove_time_batch_dim(
              next_env_output, next_agent_output)
          next_state = common.PlanningState(
              score=state.score + logprob[j],
              agent_output=next_agent_output,
              agent_state=next_agent_state,
              env_output=next_env_output,
              env_state=env_instance.get_state(),
              action_history=state.action_history + [next_action])
        next_beam.append(next_state)

    def _log_beam(beam):
      for item in beam:
        path_string = '\t'.join(
            [str(a.action_val) for a in item.action_history])
        score_string = '\t'.join(
            ['%.4f' % a.log_prob for a in item.action_history])
        logging.debug('Score: %.4f', item.score)
        logging.debug('Log prob: %s', score_string)
        logging.debug('Steps: %s', path_string)

    # Reduce the next beam to only the top beam_size paths.
    beam = sorted(next_beam, reverse=True, key=operator.attrgetter('score'))
    beam = beam[:beam_size]
    logging.debug('Planning step %d', planning_step)
    _log_beam(beam)

    # Break if all episodes are done.
    if all(item.action_history[-1].action_val == constants.STOP_NODE_ID
           for item in beam):
      break
    # Break if we have exceeded the planning_horizon.
    if planning_step >= planning_horizon:
      break
    # Break if we are planning beyond the max actions per episode, since this
    # will also trigger a new episode (same as the stop action).
    if (initial_time_step + planning_step >=
        env_instance._max_actions_per_episode):
      break
    planning_step += 1

  # Restore the environment to its initial state so the agent can still act.
  env_instance.set_state(initial_env_state)
  return beam[0].action_history