def testSavedModel(self):
  path = os.path.join(self.get_temp_dir(), 'saved_policy')
  saver = policy_saver.PolicySaver(self.tf_policy)
  saver.save(path)

  eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
      path, self.time_step_spec, self.action_spec)

  rng = np.random.RandomState()
  sample_time_step = array_spec.sample_spec_nest(self.time_step_spec, rng)
  batched_sample_time_step = nest_utils.batch_nested_array(sample_time_step)

  original_action = self.tf_policy.action(batched_sample_time_step)
  unbatched_original_action = nest_utils.unbatch_nested_tensors(
      original_action)
  original_action_np = tf.nest.map_structure(lambda t: t.numpy(),
                                             unbatched_original_action)
  saved_policy_action = eager_py_policy.action(sample_time_step)

  tf.nest.assert_same_structure(saved_policy_action.action, self.action_spec)
  np.testing.assert_array_almost_equal(original_action_np.action,
                                       saved_policy_action.action)

def testUnBatchSingleTensor(self):
  batched_tensor = tf.zeros([1, 2, 3], dtype=tf.float32)
  spec = tensor_spec.TensorSpec([2, 3], dtype=tf.float32)

  tensor = nest_utils.unbatch_nested_tensors(batched_tensor, spec)

  self.assertEqual(tensor.shape.as_list(), [2, 3])

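# Hedged sketch (not part of the original tests): the batch/unbatch round
# trip exercised above, written standalone. Imports assume the usual
# TF-Agents module layout.
import tensorflow as tf
from tf_agents.specs import tensor_spec
from tf_agents.utils import nest_utils

spec = tensor_spec.TensorSpec([2, 3], dtype=tf.float32)
batched = tf.zeros([1, 2, 3], dtype=tf.float32)  # leading dim is the batch
unbatched = nest_utils.unbatch_nested_tensors(batched, spec)  # shape [2, 3]
rebatched = nest_utils.batch_nested_tensors(unbatched, spec)  # shape [1, 2, 3]
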
def _action(self, time_step, policy_state, seed):
  if seed is not None:
    raise NotImplementedError(
        'seed is not supported; but saw seed: {}'.format(seed))

  def _action_fn(*flattened_time_step_and_policy_state):
    packed_py_time_step, packed_py_policy_state = tf.nest.pack_sequence_as(
        structure=(self._py_policy.time_step_spec,
                   self._py_policy.policy_state_spec),
        flat_sequence=flattened_time_step_and_policy_state)
    py_action_step = self._py_policy.action(
        time_step=packed_py_time_step, policy_state=packed_py_policy_state)
    return tf.nest.flatten(py_action_step)

  with tf.name_scope('action'):
    if not self._py_policy_is_batched:
      time_step = nest_utils.unbatch_nested_tensors(time_step)
    flattened_input_tensors = tf.nest.flatten((time_step, policy_state))

    flat_action_step = tf.numpy_function(
        _action_fn,
        flattened_input_tensors,
        self._policy_step_dtypes,
        name='action_numpy_function')
    action_step = tf.nest.pack_sequence_as(
        structure=self.policy_step_spec, flat_sequence=flat_action_step)
    if not self._py_policy_is_batched:
      action_step = action_step._replace(
          action=nest_utils.batch_nested_tensors(action_step.action))
    return action_step

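# Hedged sketch of the tf.numpy_function bridge used in _action above: a
# plain Python/numpy callable is embedded in the TF graph. Output dtypes must
# be declared up front, and static shape information is lost (which is why
# the wrapper re-packs and re-batches the result). _py_double is illustrative.
import numpy as np
import tensorflow as tf

def _py_double(x):
  # Runs as ordinary Python on numpy arrays, even inside a tf.function.
  return np.asarray(x) * 2.0

doubled = tf.numpy_function(
    _py_double, [tf.constant([1.0, 2.0])], tf.float32, name='double')
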
def _action(self, time_step, policy_state, seed):
  del seed

  def _mode(dist, spec):
    action = dist.mode()
    return tf.reshape(action, [-1] + spec.shape.as_list())

  # TODO(oars): Remove batched data checks when tf_env is batched.
  time_step_batched = nest_utils.is_batched_nested_tensors(
      time_step, self._time_step_spec)
  if not time_step_batched:
    time_step = nest_utils.batch_nested_tensors(time_step,
                                                self._time_step_spec)

  distribution_step = self._wrapped_policy.distribution(time_step,
                                                        policy_state)
  actions = tf.nest.map_structure(_mode, distribution_step.action,
                                  self._action_spec)

  if not time_step_batched:
    actions = nest_utils.unbatch_nested_tensors(actions, self._action_spec)
  return policy_step.PolicyStep(actions, distribution_step.state,
                                distribution_step.info)

def _start_new_episode(self):
  self._time_step = self._env.reset()
  if self._env.batch_size is not None:
    self._time_step = nest_utils.unbatch_nested_tensors(self._time_step)
  self._step_type = self._time_step.step_type
  self._discount = self._time_step.discount
  self._first_step_type = self._step_type
  self._policy_state = self._policy.get_initial_state(None)
  self._start_on_next_step = False
  self._cur_step_num = 0

def testUnBatchedNestedTensors(self, include_sparse=False):
  shape = [2, 3]

  specs = self.nest_spec(shape, include_sparse=include_sparse)
  unbatched_tensors = self.zeros_from_spec(specs)
  tf.nest.assert_same_structure(unbatched_tensors, specs)

  tensors = nest_utils.unbatch_nested_tensors(unbatched_tensors, specs)
  tf.nest.assert_same_structure(specs, tensors)

  assert_shapes = lambda t: self.assertEqual(t.shape.as_list(), shape, t)
  tf.nest.map_structure(assert_shapes, tensors)

def testUnBatchNestedTensors(self):
  shape = [2, 3]
  batch_size = 1

  specs = self.nest_spec(shape)
  batched_tensors = self.zeros_from_spec(specs, batch_size=batch_size)
  tf.nest.assert_same_structure(batched_tensors, specs)

  tensors = nest_utils.unbatch_nested_tensors(batched_tensors, specs)
  tf.nest.assert_same_structure(specs, tensors)

  assert_shapes = lambda t: self.assertEqual(t.shape.as_list(), shape, t)
  tf.nest.map_structure(assert_shapes, tensors)

def testInferenceFromCheckpoint(self):
  path = os.path.join(self.get_temp_dir(), 'saved_policy')
  saver = policy_saver.PolicySaver(self.tf_policy)
  saver.save(path)

  rng = np.random.RandomState()
  sample_time_step = array_spec.sample_spec_nest(self.time_step_spec, rng)
  batched_sample_time_step = nest_utils.batch_nested_array(sample_time_step)

  self.evaluate(
      tf.nest.map_structure(lambda v: v.assign(v * 0 + -1),
                            self.tf_policy.variables()))
  checkpoint_path = os.path.join(self.get_temp_dir(), 'checkpoint')
  saver.save_checkpoint(checkpoint_path)

  eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
      path, self.time_step_spec, self.action_spec)

  # Use evaluate to force a copy.
  saved_model_variables = self.evaluate(eager_py_policy.variables())

  checkpoint = tf.train.Checkpoint(policy=eager_py_policy._policy)
  manager = tf.train.CheckpointManager(
      checkpoint, directory=checkpoint_path, max_to_keep=None)
  eager_py_policy.update_from_checkpoint(manager.latest_checkpoint)

  assert_np_not_equal = lambda a, b: self.assertFalse(np.equal(a, b).all())
  tf.nest.map_structure(assert_np_not_equal, saved_model_variables,
                        self.evaluate(eager_py_policy.variables()))

  assert_np_all_equal = lambda a, b: self.assertTrue(np.equal(a, b).all())
  tf.nest.map_structure(assert_np_all_equal,
                        self.evaluate(self.tf_policy.variables()),
                        self.evaluate(eager_py_policy.variables()))

  # Can't check that the action changed after the update, since depending on
  # variable initialization it may be unchanged; instead check that the
  # restored policy and the current policy always agree.
  checkpoint_action = eager_py_policy.action(sample_time_step)

  current_policy_action = self.tf_policy.action(batched_sample_time_step)
  current_policy_action = self.evaluate(
      nest_utils.unbatch_nested_tensors(current_policy_action))
  tf.nest.map_structure(assert_np_all_equal, current_policy_action,
                        checkpoint_action)

def _distribution(self, time_step, policy_state):
  batched = nest_utils.is_batched_nested_tensors(time_step,
                                                 self._time_step_spec)
  if not batched:
    time_step = nest_utils.batch_nested_tensors(time_step)

  policy_dist_step = self._wrapped_policy.distribution(time_step,
                                                       policy_state)
  policy_state = policy_dist_step.state
  policy_mean_action = policy_dist_step.action.mean()
  policy_info = policy_dist_step.info

  if not batched:
    policy_state = nest_utils.unbatch_nested_tensors(policy_state)
    policy_mean_action = nest_utils.unbatch_nested_tensors(
        policy_mean_action)
    policy_info = nest_utils.unbatch_nested_tensors(policy_info)

  gaussian_dist = tfp.distributions.MultivariateNormalDiag(
      loc=policy_mean_action,
      scale_diag=tf.ones_like(policy_mean_action) * self._scale)

  return policy_step.PolicyStep(gaussian_dist, policy_state, policy_info)

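# Hedged, illustrative sketch of what the wrapper above returns: a
# fixed-scale diagonal Gaussian centered on the wrapped policy's mean action.
# The values of mean_action and scale are made up for the example.
import tensorflow as tf
import tensorflow_probability as tfp

mean_action = tf.constant([0.1, -0.3])
scale = 0.5
dist = tfp.distributions.MultivariateNormalDiag(
    loc=mean_action, scale_diag=tf.ones_like(mean_action) * scale)
noisy_action = dist.sample()  # exploration noise around the mean
greedy_action = dist.mode()   # equals mean_action for a Gaussian
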
def testSavedModel(self):
  if not common.has_eager_been_enabled():
    self.skipTest('Only supported in eager.')

  observation_spec = array_spec.ArraySpec([2], np.float32)
  action_spec = array_spec.BoundedArraySpec([1], np.float32, 2, 3)
  time_step_spec = ts.time_step_spec(observation_spec)

  observation_tensor_spec = tensor_spec.from_spec(observation_spec)
  action_tensor_spec = tensor_spec.from_spec(action_spec)
  time_step_tensor_spec = tensor_spec.from_spec(time_step_spec)

  actor_net = actor_network.ActorNetwork(
      observation_tensor_spec,
      action_tensor_spec,
      fc_layer_params=(10,),
  )
  tf_policy = actor_policy.ActorPolicy(
      time_step_tensor_spec, action_tensor_spec, actor_network=actor_net)

  path = os.path.join(self.get_temp_dir(), 'saved_policy')
  saver = policy_saver.PolicySaver(tf_policy)
  saver.save(path)

  eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
      path, time_step_spec, action_spec)

  rng = np.random.RandomState()
  sample_time_step = array_spec.sample_spec_nest(time_step_spec, rng)
  batched_sample_time_step = nest_utils.batch_nested_array(sample_time_step)

  original_action = tf_policy.action(batched_sample_time_step)
  unbatched_original_action = nest_utils.unbatch_nested_tensors(
      original_action)
  original_action_np = tf.nest.map_structure(lambda t: t.numpy(),
                                             unbatched_original_action)
  saved_policy_action = eager_py_policy.action(sample_time_step)

  tf.nest.assert_same_structure(saved_policy_action.action, action_spec)
  np.testing.assert_array_almost_equal(original_action_np.action,
                                       saved_policy_action.action)

def sample(self, batch_size):
  dummy_action_step = policy_step.PolicyStep(
      action=tf.constant([tf.int32.min]))
  dummy_time_step = ts.TimeStep(
      step_type=tf.constant([tf.int32.min]),
      reward=(np.nan * tf.ones(1)),
      discount=(np.nan * tf.ones(1)),
      observation=None)
  trajs = []
  for transition in random.sample(self.buffer, batch_size):
    traj1 = trajectory.from_transition(transition.time_step,
                                       transition.action_step,
                                       transition.next_time_step)
    traj2 = trajectory.from_transition(transition.next_time_step,
                                       dummy_action_step, dummy_time_step)
    trajs.append(
        nest_utils.unbatch_nested_tensors(
            nest_utils.stack_nested_tensors([traj1, traj2], axis=1)))
  return nest_utils.stack_nested_tensors(trajs)

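# Hedged sketch of the stacking pattern in sample() above: two structurally
# identical batched nests are stacked along axis=1 to form a two-step time
# dimension, then the batch dim is squeezed. The dict stands in for a
# Trajectory nest.
import tensorflow as tf
from tf_agents.utils import nest_utils

t1 = {'observation': tf.zeros([1, 4])}
t2 = {'observation': tf.ones([1, 4])}
pair = nest_utils.stack_nested_tensors([t1, t2], axis=1)  # obs: [1, 2, 4]
two_step = nest_utils.unbatch_nested_tensors(pair)        # obs: [2, 4]
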
def relabel_function(cur_episode, last_step, reward_fn, full_buffer):
  all_data = cur_episode.gather_all()

  # add all actual interaction to the replay buffer
  all_data = nest_utils.unbatch_nested_tensors(all_data)
  for cur_trajectory in nest_utils.unstack_nested_tensors(
      all_data, full_buffer.data_spec):
    # was already added by previous iteration
    if cur_trajectory.step_type.numpy() != 2:
      full_buffer.add_batch(nest_utils.batch_nested_tensors(cur_trajectory))

  last_traj = cur_trajectory._replace(  # pylint: disable=undefined-loop-variable
      step_type=tf.constant(2),
      observation=last_step.observation[0],
      next_step_type=tf.constant(0),
      reward=tf.constant(0.0),
      discount=tf.constant(1., dtype=tf.float32))
  full_buffer.add_batch(nest_utils.batch_nested_tensors(last_traj))

  def _relabel_given_goal(relabel_goal):
    obs_dim = relabel_goal.shape[0]
    all_trajectories = nest_utils.unstack_nested_tensors(
        all_data, full_buffer.data_spec)
    last_traj_idx = len(all_trajectories)
    for traj_idx, cur_trajectory in enumerate(all_trajectories):
      if cur_trajectory.step_type.numpy() != 2:
        new_obs = tf.concat(
            [cur_trajectory.observation[:obs_dim], relabel_goal], axis=0)
        if traj_idx == len(all_trajectories) - 1:
          next_obs = tf.concat(
              [last_step.observation[0, :obs_dim], relabel_goal], axis=0)
        else:
          next_obs = tf.concat([
              all_trajectories[traj_idx + 1].observation[:obs_dim],
              relabel_goal
          ], axis=0)
        new_reward = tf.constant(reward_fn(obs=next_obs))
        # terminate episode
        if new_reward.numpy() > 0.0:
          new_traj = cur_trajectory._replace(
              observation=new_obs,
              next_step_type=tf.constant(2),
              reward=new_reward,
              discount=tf.constant(0., dtype=tf.float32))
          last_traj_idx = traj_idx + 1
          full_buffer.add_batch(nest_utils.batch_nested_tensors(new_traj))
          break
        else:
          new_traj = cur_trajectory._replace(
              observation=new_obs,
              reward=new_reward,
          )
          full_buffer.add_batch(nest_utils.batch_nested_tensors(new_traj))

    if last_traj_idx == len(all_trajectories):
      last_observation = tf.concat(
          [last_step.observation[0, :obs_dim], relabel_goal], axis=0)
    else:
      last_observation = tf.concat([
          all_trajectories[last_traj_idx].observation[:obs_dim], relabel_goal
      ], axis=0)

    last_traj = cur_trajectory._replace(  # pylint: disable=undefined-loop-variable
        step_type=tf.constant(2),
        observation=last_observation,
        next_step_type=tf.constant(0),
        reward=tf.constant(0.0),
        discount=tf.constant(1., dtype=tf.float32))
    full_buffer.add_batch(nest_utils.batch_nested_tensors(last_traj))

  # relabel with last time step achieved in the episode
  if FLAGS.goal_relabel_type == 0 or (FLAGS.goal_relabel_type == 1 and
                                      last_step.reward.numpy()[0] <= 0.):
    obs_dim = last_step.observation.shape[1] // 2
    _relabel_given_goal(last_step.observation[0, :obs_dim])
  elif FLAGS.goal_relabel_type == 2 and last_step.reward.numpy()[0] <= 0.:
    goals = [
        [1.2, 0., 2.5, 0., -1., -1.],
        [2., 0., 2.4, 0., 0., 0.],
        [0.8, 0., 1.2, 0., 0., 0.],
        [-0.1, -0.3, 0.3, -0.3, 0., 0.],
        [-0.6, -1., -0.2, -1., 0., 0.],
        [-1.8, -1., -1.4, -1., 0., 0.],
        [-2.8, -0.8, -2.4, -1., -1., -1.],
        [-2.4, 0., -2.4, -1., -1., -1.],
        [-1.2, 0., -2.4, -1., -1., -1.],
        [0.0, 0.0, -2.5, -1, -1., -1.],
    ]
    goals = np.stack(goals).astype('float32')
    print('unrelabelled goal:', last_step.observation[0, 6:].numpy())
    relabel_goal_idxs = np.arange(goals.shape[0])
    np.random.shuffle(relabel_goal_idxs)
    obs_dim = last_step.observation.shape[1] // 2
    relabel_count = 0
    for goal_idx in relabel_goal_idxs:
      chosen_goal = goals[goal_idx]
      if (chosen_goal == last_step.observation[0, obs_dim:].numpy()).all():
        continue
      print('goal for relabelling:', chosen_goal)
      _relabel_given_goal(relabel_goal=tf.constant(chosen_goal))
      relabel_count += 1
      if relabel_count >= FLAGS.num_relabelled_goals:
        break
  else:
    print('not adding relabelled trajectories')

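# Hedged sketch of the relabeling trick above: observations are assumed to be
# [state; goal] vectors split into equal halves, and relabeling replaces the
# goal half while keeping the state half. All numbers are illustrative.
import tensorflow as tf

observation = tf.constant([0.1, 0.2, 0.9, 0.9])  # state=[0.1, 0.2], goal=[0.9, 0.9]
obs_dim = observation.shape[0] // 2
relabel_goal = tf.constant([0.1, 0.2])  # e.g. a state the agent actually reached
new_obs = tf.concat([observation[:obs_dim], relabel_goal], axis=0)
# new_obs == [0.1, 0.2, 0.1, 0.2]: same state, hindsight goal.
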
def _get_step(self) -> EnvStep:
  if self._start_on_next_step:
    self._start_new_episode()

  if StepType.is_last(self._step_type):
    # This is the last (terminating) observation of the environment.
    self._start_on_next_step = True
    self._num_total_steps += 1
    self._num_episodes += 1
    # The policy is not run on the terminal step, so we just carry over the
    # reward, action, and policy_info from the previous step.
    return EnvStep(self._step_type,
                   tf.cast(self._cur_step_num, dtype=tf.int64),
                   self._time_step.observation, self._action,
                   self._time_step.reward, self._time_step.discount,
                   self._policy_info, {}, {})

  self._action, self._policy_state, self._policy_info = self._policy.action(
      self._time_step, self._policy_state)

  # Update type of log-probs to tf.float32... a bit of a bug in TF-Agents.
  if hasattr(self._policy_info, 'log_probability'):
    self._policy_info = policy_step.set_log_probability(
        self._policy_info,
        tf.cast(self._policy_info.log_probability, tf.float32))

  # Sample action from policy.
  env_action = self._action
  if self._env.batch_size is not None:
    env_action = nest_utils.batch_nested_tensors(env_action)

  # Sample next step from environment.
  self._next_time_step = self._env.step(env_action)
  if self._env.batch_size is not None:
    self._next_time_step = nest_utils.unbatch_nested_tensors(
        self._next_time_step)
  self._next_step_type = self._next_time_step.step_type
  self._cur_step_num += 1
  if (self._episode_step_limit and
      self._cur_step_num >= self._episode_step_limit):
    self._next_step_type = tf.convert_to_tensor(  # Overwrite step type.
        value=StepType.LAST, dtype=self._first_step_type.dtype)
    self._next_step_type = tf.reshape(self._next_step_type,
                                      tf.shape(self._first_step_type))

  step = EnvStep(
      self._step_type,
      tf.cast(self._cur_step_num - 1, tf.int64),
      self._time_step.observation,
      self._action,
      # Immediate reward given by next time step.
      self._next_time_step.reward,
      self._time_step.discount,
      self._policy_info,
      {},
      {})

  self._num_steps += 1
  self._num_total_steps += 1
  if StepType.is_first(self._step_type):
    self._num_total_episodes += 1

  self._time_step = self._next_time_step
  self._step_type = self._next_step_type

  return step

def copy_replay_buffer(small_buffer, big_buffer):
  """Copy small buffer into the big buffer."""
  all_data = nest_utils.unbatch_nested_tensors(small_buffer.gather_all())
  for trajectory in nest_utils.unstack_nested_tensors(
      all_data, big_buffer.data_spec):
    big_buffer.add_batch(nest_utils.batch_nested_tensors(trajectory))

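# Hedged usage sketch for copy_replay_buffer, with a toy one-tensor spec in
# place of a real trajectory spec; assumes eager mode and batch_size=1
# buffers sharing the same data_spec.
import tensorflow as tf
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.specs import tensor_spec

data_spec = tensor_spec.TensorSpec([3], tf.float32, name='obs')
small = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=data_spec, batch_size=1, max_length=10)
big = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=data_spec, batch_size=1, max_length=100)
small.add_batch(tf.ones([1, 3]))
copy_replay_buffer(small, big)  # big now holds the single item from small
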
def train_eval(
    root_dir,
    offline_dir=None,
    random_seed=None,
    env_name='sawyer_push',
    eval_env_name=None,
    env_load_fn=get_env,
    max_episode_steps=1000,
    eval_episode_steps=1000,
    # The SAC paper reported:
    # Hopper and Cartpole results up to 1000000 iters,
    # Humanoid results up to 10000000 iters,
    # Other mujoco tasks up to 3000000 iters.
    num_iterations=3000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Params for collect
    # Follow https://github.com/haarnoja/sac/blob/master/examples/variants.py
    # HalfCheetah and Ant take 10000 initial collection steps.
    # Other mujoco tasks take 1000.
    # Different choices roughly keep the initial episodes about the same.
    initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    reset_goal_frequency=1000,  # virtual episode size for reset-free training
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    # reset-free parameters
    use_minimum=True,
    reset_lagrange_learning_rate=3e-4,
    value_threshold=None,
    td_errors_loss_fn=tf.math.squared_difference,
    gamma=0.99,
    reward_scale_factor=0.1,
    # Td3 parameters
    actor_update_period=1,
    exploration_noise_std=0.1,
    target_policy_noise=0.1,
    target_policy_noise_clip=0.1,
    dqda_clipping=None,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    # video recording for the environment
    video_record_interval=10000,
    num_videos=0,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  start_time = time.time()

  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')
  video_dir = os.path.join(eval_dir, 'videos')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    if random_seed is not None:
      tf.compat.v1.set_random_seed(random_seed)

    if FLAGS.use_reset_goals in [-1]:
      gym_env_wrappers = (
          functools.partial(
              reset_free_wrapper.GoalTerminalResetWrapper,
              num_success_states=FLAGS.num_success_states,
              full_reset_frequency=max_episode_steps),)
    elif FLAGS.use_reset_goals in [0, 1]:
      gym_env_wrappers = (
          functools.partial(
              reset_free_wrapper.ResetFreeWrapper,
              reset_goal_frequency=reset_goal_frequency,
              variable_horizon_for_reset=FLAGS.variable_reset_horizon,
              num_success_states=FLAGS.num_success_states,
              full_reset_frequency=max_episode_steps),)
    elif FLAGS.use_reset_goals in [2]:
      gym_env_wrappers = (
          functools.partial(
              reset_free_wrapper.CustomOracleResetWrapper,
              partial_reset_frequency=reset_goal_frequency,
              episodes_before_full_reset=max_episode_steps //
              reset_goal_frequency),)
    elif FLAGS.use_reset_goals in [3, 4]:
      gym_env_wrappers = (
          functools.partial(
              reset_free_wrapper.GoalTerminalResetFreeWrapper,
              reset_goal_frequency=reset_goal_frequency,
              num_success_states=FLAGS.num_success_states,
              full_reset_frequency=max_episode_steps),)
    elif FLAGS.use_reset_goals in [5, 7]:
      gym_env_wrappers = (
          functools.partial(
              reset_free_wrapper.CustomOracleResetGoalTerminalWrapper,
              partial_reset_frequency=reset_goal_frequency,
              episodes_before_full_reset=max_episode_steps //
              reset_goal_frequency),)
    elif FLAGS.use_reset_goals in [6]:
      gym_env_wrappers = (
          functools.partial(
              reset_free_wrapper.VariableGoalTerminalResetWrapper,
              full_reset_frequency=max_episode_steps),)

    if env_name == 'playpen_reduced':
      train_env_load_fn = functools.partial(
          env_load_fn, reset_at_goal=FLAGS.reset_at_goal)
    else:
      train_env_load_fn = env_load_fn
    env, env_train_metrics, env_eval_metrics, aux_info = train_env_load_fn(
        name=env_name,
        max_episode_steps=None,
        gym_env_wrappers=gym_env_wrappers)
    tf_env = tf_py_environment.TFPyEnvironment(env)

    eval_env_name = eval_env_name or env_name
    eval_tf_env = tf_py_environment.TFPyEnvironment(
        env_load_fn(name=eval_env_name,
                    max_episode_steps=eval_episode_steps)[0])
    eval_metrics += env_eval_metrics

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    if FLAGS.agent_type == 'sac':
      actor_net = actor_distribution_network.ActorDistributionNetwork(
          observation_spec,
          action_spec,
          fc_layer_params=actor_fc_layers,
          continuous_projection_net=functools.partial(
              tanh_normal_projection_network.TanhNormalProjectionNetwork,
              std_transform=std_clip_transform))
      critic_net = critic_network.CriticNetwork(
          (observation_spec, action_spec),
          observation_fc_layer_params=critic_obs_fc_layers,
          action_fc_layer_params=critic_action_fc_layers,
          joint_fc_layer_params=critic_joint_fc_layers,
          kernel_initializer='glorot_uniform',
          last_kernel_initializer='glorot_uniform',
      )
      critic_net_no_entropy = None
      critic_no_entropy_optimizer = None
      if FLAGS.use_no_entropy_q:
        critic_net_no_entropy = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            kernel_initializer='glorot_uniform',
            last_kernel_initializer='glorot_uniform',
            name='CriticNetworkNoEntropy1')
        critic_no_entropy_optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate)

      tf_agent = SacAgent(
          time_step_spec,
          action_spec,
          num_action_samples=FLAGS.num_action_samples,
          actor_network=actor_net,
          critic_network=critic_net,
          critic_network_no_entropy=critic_net_no_entropy,
          actor_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=actor_learning_rate),
          critic_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=critic_learning_rate),
          critic_no_entropy_optimizer=critic_no_entropy_optimizer,
          alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=alpha_learning_rate),
          target_update_tau=target_update_tau,
          target_update_period=target_update_period,
          td_errors_loss_fn=td_errors_loss_fn,
          gamma=gamma,
          reward_scale_factor=reward_scale_factor,
          gradient_clipping=gradient_clipping,
          debug_summaries=debug_summaries,
          summarize_grads_and_vars=summarize_grads_and_vars,
          train_step_counter=global_step)
    elif FLAGS.agent_type == 'td3':
      actor_net = actor_network.ActorNetwork(
          tf_env.time_step_spec().observation,
          tf_env.action_spec(),
          fc_layer_params=actor_fc_layers,
      )
      critic_net = critic_network.CriticNetwork(
          (observation_spec, action_spec),
          observation_fc_layer_params=critic_obs_fc_layers,
          action_fc_layer_params=critic_action_fc_layers,
          joint_fc_layer_params=critic_joint_fc_layers,
          kernel_initializer='glorot_uniform',
          last_kernel_initializer='glorot_uniform')

      tf_agent = Td3Agent(
          tf_env.time_step_spec(),
          tf_env.action_spec(),
          actor_network=actor_net,
          critic_network=critic_net,
          actor_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=actor_learning_rate),
          critic_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=critic_learning_rate),
          exploration_noise_std=exploration_noise_std,
          target_update_tau=target_update_tau,
          target_update_period=target_update_period,
          actor_update_period=actor_update_period,
          dqda_clipping=dqda_clipping,
          td_errors_loss_fn=td_errors_loss_fn,
          gamma=gamma,
          reward_scale_factor=reward_scale_factor,
          target_policy_noise=target_policy_noise,
          target_policy_noise_clip=target_policy_noise_clip,
          gradient_clipping=gradient_clipping,
          debug_summaries=debug_summaries,
          summarize_grads_and_vars=summarize_grads_and_vars,
          train_step_counter=global_step,
      )

    tf_agent.initialize()

    if FLAGS.use_reset_goals > 0:
      if FLAGS.use_reset_goals in [4, 5, 6]:
        reset_goal_generator = ScheduledResetGoal(
            goal_dim=aux_info['reset_state_shape'][0],
            num_success_for_switch=FLAGS.num_success_for_switch,
            num_chunks=FLAGS.num_chunks,
            name='ScheduledResetGoalGenerator')
      else:
        # distance to initial state distribution
        initial_state_distance = state_distribution_distance.L2Distance(
            initial_state_shape=aux_info['reset_state_shape'])
        initial_state_distance.update(
            tf.constant(aux_info['reset_states'], dtype=tf.float32),
            update_type='complete')

        if use_tf_functions:
          initial_state_distance.distance = common.function(
              initial_state_distance.distance)
          tf_agent.compute_value = common.function(tf_agent.compute_value)

        # initialize reset / practice goal proposer
        if reset_lagrange_learning_rate > 0:
          reset_goal_generator = ResetGoalGenerator(
              goal_dim=aux_info['reset_state_shape'][0],
              compute_value_fn=tf_agent.compute_value,
              distance_fn=initial_state_distance,
              use_minimum=use_minimum,
              value_threshold=value_threshold,
              lagrange_variable_max=FLAGS.lagrange_max,
              optimizer=tf.compat.v1.train.AdamOptimizer(
                  learning_rate=reset_lagrange_learning_rate),
              name='reset_goal_generator')
        else:
          reset_goal_generator = FixedResetGoal(
              distance_fn=initial_state_distance)

      # if use_tf_functions:
      #   reset_goal_generator.get_reset_goal = common.function(
      #       reset_goal_generator.get_reset_goal)

      # modify the reset-free wrapper to use the reset goal generator
      tf_env.pyenv.envs[0].set_reset_goal_fn(
          reset_goal_generator.get_reset_goal)

    # Make the replay buffer.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=1,
        max_length=replay_buffer_capacity)
    replay_observer = [replay_buffer.add_batch]

    if FLAGS.relabel_goals:
      cur_episode_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
          data_spec=tf_agent.collect_data_spec,
          batch_size=1,
          scope='CurEpisodeReplayBuffer',
          max_length=int(2 * min(reset_goal_frequency, max_episode_steps)))
      # NOTE: the buffer is replaced because cannot have two buffers.add_batch
      replay_observer = [cur_episode_buffer.add_batch]

    # initialize metrics and observers
    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
    ]
    train_metrics += env_train_metrics

    eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
    eval_py_policy = py_tf_eager_policy.PyTFEagerPolicy(
        tf_agent.policy, use_tf_function=True)
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    collect_policy = tf_agent.collect_policy

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration)

    if use_tf_functions:
      collect_driver.run = common.function(collect_driver.run)
      tf_agent.train = common.function(tf_agent.train)

    if offline_dir is not None:
      offline_data = tf_uniform_replay_buffer.TFUniformReplayBuffer(
          data_spec=tf_agent.collect_data_spec,
          batch_size=1,
          max_length=int(1e5))  # this has to be 100_000
      offline_checkpointer = common.Checkpointer(
          ckpt_dir=offline_dir, max_to_keep=1, replay_buffer=offline_data)
      offline_checkpointer.initialize_or_restore()

      # set the reset candidates to be all the data in offline buffer
      if (FLAGS.use_reset_goals > 0 and reset_lagrange_learning_rate > 0
         ) or FLAGS.use_reset_goals in [4, 5, 6, 7]:
        tf_env.pyenv.envs[0].set_reset_candidates(
            nest_utils.unbatch_nested_tensors(offline_data.gather_all()))

    if replay_buffer.num_frames() == 0:
      if offline_dir is not None:
        copy_replay_buffer(offline_data, replay_buffer)
        print(replay_buffer.num_frames())
        # multiply offline data
        if FLAGS.relabel_offline_data:
          data_multiplier(replay_buffer,
                          tf_env.pyenv.envs[0].env.compute_reward)
          print('after data multiplication:', replay_buffer.num_frames())

      initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
          tf_env,
          initial_collect_policy,
          observers=replay_observer + train_metrics,
          num_steps=1)
      if use_tf_functions:
        initial_collect_driver.run = common.function(
            initial_collect_driver.run)

      # Collect initial replay data.
      logging.info(
          'Initializing replay buffer by collecting experience for %d steps '
          'with a random policy.', initial_collect_steps)
      time_step = None
      policy_state = collect_policy.get_initial_state(tf_env.batch_size)
      for iter_idx in range(initial_collect_steps):
        time_step, policy_state = initial_collect_driver.run(
            time_step=time_step, policy_state=policy_state)
        if time_step.is_last() and FLAGS.relabel_goals:
          reward_fn = tf_env.pyenv.envs[0].env.compute_reward
          relabel_function(cur_episode_buffer, time_step, reward_fn,
                           replay_buffer)
          cur_episode_buffer.clear()
        if (FLAGS.use_reset_goals > 0 and time_step.is_last() and
            FLAGS.num_reset_candidates > 0):
          tf_env.pyenv.envs[0].set_reset_candidates(
              replay_buffer.get_next(
                  sample_batch_size=FLAGS.num_reset_candidates)[0])
    else:
      time_step = None
      policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    if eval_metrics_callback is not None:
      eval_metrics_callback(results, global_step.numpy())
    metric_utils.log_metrics(eval_metrics)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Prepare replay buffer as dataset with invalid transitions filtered.
    def _filter_invalid_transition(trajectories, unused_arg1):
      return ~trajectories.is_boundary()[0]

    dataset = replay_buffer.as_dataset(
        sample_batch_size=batch_size,
        num_steps=2).unbatch().filter(_filter_invalid_transition).batch(
            batch_size).prefetch(5)
    # Dataset generates trajectories with shape [Bx2x...]
    iterator = iter(dataset)

    def train_step():
      experience, _ = next(iterator)
      return tf_agent.train(experience)

    if use_tf_functions:
      train_step = common.function(train_step)

    # manual data save for plotting utils
    np_custom_save(os.path.join(eval_dir, 'eval_interval.npy'), eval_interval)
    try:
      average_eval_return = np_custom_load(
          os.path.join(eval_dir, 'average_eval_return.npy')).tolist()
      average_eval_success = np_custom_load(
          os.path.join(eval_dir, 'average_eval_success.npy')).tolist()
      average_eval_final_success = np_custom_load(
          os.path.join(eval_dir, 'average_eval_final_success.npy')).tolist()
    except:  # pylint: disable=bare-except
      average_eval_return = []
      average_eval_success = []
      average_eval_final_success = []

    print('initialization_time:', time.time() - start_time)

    for iter_idx in range(num_iterations):
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      if time_step.is_last() and FLAGS.relabel_goals:
        reward_fn = tf_env.pyenv.envs[0].env.compute_reward
        relabel_function(cur_episode_buffer, time_step, reward_fn,
                         replay_buffer)
        cur_episode_buffer.clear()

      # reset goal generator updates
      if FLAGS.use_reset_goals > 0 and iter_idx % (
          FLAGS.reset_goal_frequency * collect_steps_per_iteration) == 0:
        if FLAGS.num_reset_candidates > 0:
          tf_env.pyenv.envs[0].set_reset_candidates(
              replay_buffer.get_next(
                  sample_batch_size=FLAGS.num_reset_candidates)[0])
        if reset_lagrange_learning_rate > 0:
          reset_goal_generator.update_lagrange_multipliers()

      for _ in range(train_steps_per_iteration):
        train_loss = train_step()
      time_acc += time.time() - start_time

      global_step_val = global_step.numpy()

      if global_step_val % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step_val,
                     train_loss.loss)
        steps_per_sec = (global_step_val - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec,
            step=global_step)
        timed_at_step = global_step_val
        time_acc = 0

      for train_metric in train_metrics:
        if 'Heatmap' in train_metric.name:
          if global_step_val % summary_interval == 0:
            train_metric.tf_summaries(
                train_step=global_step, step_metrics=train_metrics[:2])
        else:
          train_metric.tf_summaries(
              train_step=global_step, step_metrics=train_metrics[:2])

      if (global_step_val % summary_interval == 0 and
          FLAGS.use_reset_goals > 0 and reset_lagrange_learning_rate > 0):
        reset_states, values, initial_state_distance_vals, lagrangian = (
            reset_goal_generator.update_summaries(step_counter=global_step))
        for vf_viz_metric in aux_info['value_fn_viz_metrics']:
          vf_viz_metric.tf_summaries(
              reset_states,
              values,
              train_step=global_step,
              step_metrics=train_metrics[:2])

        if FLAGS.debug_value_fn_for_reset:
          num_test_lagrange = 20
          hyp_lagranges = [
              1.0 * increment / num_test_lagrange
              for increment in range(num_test_lagrange + 1)
          ]
          door_pos = reset_states[np.argmin(
              initial_state_distance_vals.numpy() -
              lagrangian.numpy() * values.numpy())][3:5]
          print('cur lagrange: %.2f, cur reset goal: (%.2f, %.2f)' %
                (lagrangian.numpy(), door_pos[0], door_pos[1]))
          for lagrange in hyp_lagranges:
            door_pos = reset_states[np.argmin(
                initial_state_distance_vals.numpy() -
                lagrange * values.numpy())][3:5]
            print('test lagrange: %.2f, cur reset goal: (%.2f, %.2f)' %
                  (lagrange, door_pos[0], door_pos[1]))
          print('\n')

      if global_step_val % eval_interval == 0:
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, global_step_val)
        metric_utils.log_metrics(eval_metrics)

        # numpy saves for plotting
        if 'AverageReturn' in results.keys():
          average_eval_return.append(results['AverageReturn'].numpy())
        if 'EvalSuccessfulAtAnyStep' in results.keys():
          average_eval_success.append(
              results['EvalSuccessfulAtAnyStep'].numpy())
        if 'EvalSuccessfulEpisodes' in results.keys():
          average_eval_final_success.append(
              results['EvalSuccessfulEpisodes'].numpy())
        elif 'EvalSuccessfulAtLastStep' in results.keys():
          average_eval_final_success.append(
              results['EvalSuccessfulAtLastStep'].numpy())

        if average_eval_return:
          np_custom_save(
              os.path.join(eval_dir, 'average_eval_return.npy'),
              average_eval_return)
        if average_eval_success:
          np_custom_save(
              os.path.join(eval_dir, 'average_eval_success.npy'),
              average_eval_success)
        if average_eval_final_success:
          np_custom_save(
              os.path.join(eval_dir, 'average_eval_final_success.npy'),
              average_eval_final_success)

      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)

      if global_step_val % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=global_step_val)

      if global_step_val % video_record_interval == 0:
        for video_idx in range(num_videos):
          video_name = os.path.join(video_dir, str(global_step_val),
                                    'video_' + str(video_idx) + '.mp4')
          record_video(
              lambda: env_load_fn(  # pylint: disable=g-long-lambda
                  name=env_name,
                  max_episode_steps=max_episode_steps)[0],
              video_name,
              eval_py_policy,
              max_episode_length=eval_episode_steps)

  return train_loss

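# Standalone, hedged sketch of the _filter_invalid_transition pattern used in
# train_eval above: with num_steps=2, every dataset element is a two-step
# window, and windows whose first step is an episode boundary are dropped.
# Boolean flags stand in for trajectory.is_boundary().
import tensorflow as tf

windows = tf.data.Dataset.from_tensor_slices(
    tf.constant([[False, False], [True, False], [False, True]]))

def keep_window(is_boundary):
  return ~is_boundary[0]  # drop windows starting on a boundary

valid = windows.filter(keep_window)
# Yields [False, False] and [False, True]; [True, False] is filtered out.
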
def data_multiplier(offline_data, reward_fn):

  def _custom_print(some_traj):  # pylint: disable=unused-variable
    np.set_printoptions(precision=2, suppress=True)
    print('step', some_traj.step_type.numpy(), 'obs',
          some_traj.observation.numpy(), 'action', some_traj.action.numpy(),
          'reward', some_traj.reward.numpy(), 'next_step',
          some_traj.next_step_type.numpy(), 'discount',
          some_traj.discount.numpy())

  all_data = nest_utils.unbatch_nested_tensors(offline_data.gather_all())
  all_trajs = nest_utils.unstack_nested_tensors(all_data,
                                                offline_data.data_spec)
  for idx, traj in enumerate(all_trajs):
    # print('index:', idx)
    if traj.step_type.numpy() == 0:
      ep_start_idx = idx
      # print('new start index:', ep_start_idx)
    # TODO(architsh): remove this and change to else:
    # elif idx in [12, 24, 36, 48, 60, 72, 84, 96, 108]:
    else:
      # print('adding new trajectory')
      obs_dim = traj.observation.shape[0] // 2
      relabel_goal = traj.observation[:obs_dim]
      # print('new goal:', relabel_goal)
      last_traj_idx = len(all_trajs[ep_start_idx:idx + 1])
      for traj_idx, cur_trajectory in enumerate(
          all_trajs[ep_start_idx:idx + 1]):
        if cur_trajectory.step_type.numpy() != 2:
          new_obs = tf.concat(
              [cur_trajectory.observation[:obs_dim], relabel_goal], axis=0)
          next_obs = tf.concat([
              all_trajs[ep_start_idx + traj_idx + 1].observation[:obs_dim],
              relabel_goal
          ], axis=0)
          new_reward = tf.constant(reward_fn(obs=next_obs))
          # terminate episode
          if new_reward.numpy() > 0.0:
            new_traj = cur_trajectory._replace(
                observation=new_obs,
                next_step_type=tf.constant(2),
                reward=new_reward,
                discount=tf.constant(0., dtype=tf.float32))
            last_traj_idx = ep_start_idx + traj_idx + 1
            # _custom_print(new_traj)
            offline_data.add_batch(nest_utils.batch_nested_tensors(new_traj))
            break
          else:
            new_traj = cur_trajectory._replace(
                observation=new_obs,
                reward=new_reward,
            )
            # _custom_print(new_traj)
            offline_data.add_batch(nest_utils.batch_nested_tensors(new_traj))

      last_observation = tf.concat(
          [all_trajs[last_traj_idx].observation[:obs_dim], relabel_goal],
          axis=0)
      last_traj = cur_trajectory._replace(  # pylint: disable=undefined-loop-variable
          step_type=tf.constant(2),
          observation=last_observation,
          next_step_type=tf.constant(0),
          reward=tf.constant(0.0),
          discount=tf.constant(1., dtype=tf.float32))
      # _custom_print(last_traj)
      offline_data.add_batch(nest_utils.batch_nested_tensors(last_traj))

def copy_replay_buffer(small_buffer, big_buffer):
  """Copy small buffer into the big buffer."""
  all_data = nest_utils.unbatch_nested_tensors(small_buffer.gather_all())
  for trajectory in nest_utils.unstack_nested_tensors(  # pylint: disable=redefined-outer-name
      all_data, big_buffer.data_spec):
    big_buffer.add_batch(nest_utils.batch_nested_tensors(trajectory))