def main(_):
  """Script entry point: configures logging/gin, then launches train_eval.

  Args:
    _: Unused positional argument supplied by app.run().
  """
  logging.set_verbosity(logging.INFO)
  tf.compat.v1.enable_v2_behavior()
  # Gin must be parsed before building anything gin-configurable below.
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings)
  train_eval(
      FLAGS.root_dir,
      strategy=strategy_utils.get_strategy(FLAGS.tpu, FLAGS.use_gpu),
      num_iterations=FLAGS.num_iterations,
      reverb_port=FLAGS.reverb_port,
      eval_interval=FLAGS.eval_interval)
def main(_):
  """Script entry point: configures logging/gin, then launches distributed train.

  Args:
    _: Unused positional argument supplied by app.run().
  """
  logging.set_verbosity(logging.INFO)
  # Use the fully-qualified compat path: `enable_v2_behavior` lives on the
  # TF1 compat namespace, so `tf.enable_v2_behavior()` breaks under a plain
  # TF2 import. `tf.compat.v1.enable_v2_behavior()` works either way.
  tf.compat.v1.enable_v2_behavior()
  # Gin must be parsed before building anything gin-configurable below.
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings)
  strategy = strategy_utils.get_strategy(FLAGS.tpu, FLAGS.use_gpu)
  train(
      root_dir=FLAGS.root_dir,
      # environment_name must be bound via gin; gin.REQUIRED fails fast if not.
      environment_name=gin.REQUIRED,
      strategy=strategy,
      replay_buffer_server_address=FLAGS.replay_buffer_server_address,
      variable_container_server_address=FLAGS.variable_container_server_address)
def test_tpu_strategy(self, mock_tpu_cluster_resolver,
                      mock_experimental_connect_to_cluster,
                      mock_initialize_tpu_system, mock_tpu_strategy):
  """get_strategy(tpu=...) resolves the cluster and returns the TPU strategy."""
  # Arrange: stub out the resolver and the strategy that TF would build.
  fake_resolver = mock.MagicMock()
  fake_strategy = mock.MagicMock()
  mock_tpu_cluster_resolver.return_value = fake_resolver
  mock_tpu_strategy.return_value = fake_strategy

  # Act.
  result = strategy_utils.get_strategy(tpu='bns_address', use_gpu=False)

  # Assert: the TPU bring-up sequence ran against our resolver, and the
  # strategy object produced by TPUStrategy is what get_strategy returned.
  mock_tpu_cluster_resolver.assert_called_with(tpu='bns_address')
  mock_experimental_connect_to_cluster.assert_called_with(fake_resolver)
  mock_initialize_tpu_system.assert_called_with(fake_resolver)
  self.assertIs(result, fake_strategy)
def test_mirrored_strategy(self, mock_mirrored_strategy):
  """get_strategy(use_gpu=True) returns TF's MirroredStrategy instance."""
  fake_strategy = mock.MagicMock()
  mock_mirrored_strategy.return_value = fake_strategy
  self.assertIs(
      strategy_utils.get_strategy(False, use_gpu=True), fake_strategy)
def test_get_distribution_strategy_default(self):
  """With neither TPU nor GPU requested, a default strategy is returned."""
  strategy = strategy_utils.get_strategy(tpu=False, use_gpu=False)
  # Compare against the type of TF's ambient default strategy.
  self.assertIsInstance(strategy, type(tf.distribute.get_strategy()))
def train_eval(
    root_dir,
    dataset_path,
    env_name,
    # Training params
    tpu=False,
    use_gpu=False,
    num_gradient_updates=1000000,
    actor_fc_layers=(256, 256),
    critic_joint_fc_layers=(256, 256, 256),
    # Agent params
    batch_size=256,
    bc_steps=0,
    actor_learning_rate=3e-5,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    reward_scale_factor=1.0,
    cql_alpha_learning_rate=3e-4,
    cql_alpha=5.0,
    cql_tau=10.0,
    num_cql_samples=10,
    reward_noise_variance=0.0,
    include_critic_entropy_term=False,
    use_lagrange_cql_alpha=True,
    log_cql_alpha_clipping=None,
    softmax_temperature=1.0,
    # Data params
    reward_shift=0.0,
    action_clipping=None,
    use_trajectories=False,
    data_shuffle_buffer_size_per_record=1,
    data_shuffle_buffer_size=100,
    data_num_shards=1,
    data_block_length=10,
    data_parallel_reads=None,
    data_parallel_calls=10,
    data_prefetch=10,
    data_cycle_length=10,
    # Others
    policy_save_interval=10000,
    eval_interval=10000,
    summary_interval=1000,
    learner_iterations_per_call=1,
    eval_episodes=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    seed=None):
  """Trains and evaluates CQL-SAC on a static (offline) dataset.

  Builds the environment, an input pipeline over TFRecord shards, a
  CqlSacAgent, a Learner for gradient updates, and an eval Actor; then runs
  `num_gradient_updates` updates, evaluating every `eval_interval` steps.

  Args:
    root_dir: Directory for checkpoints, saved policies and summaries.
    dataset_path: Either a path ending in '.tfrecord' or a directory; in the
      latter case shards are globbed as `<dataset_path>/<env_name>/<env_name>*`.
    env_name: D4RL environment name passed to `load_d4rl`.
    tpu: Forwarded to `strategy_utils.get_strategy` (TPU address or False).
    use_gpu: Forwarded to `strategy_utils.get_strategy`.
    num_gradient_updates: Total number of gradient updates to perform.
    actor_fc_layers: FC layer sizes for the actor network.
    critic_joint_fc_layers: Joint FC layer sizes for the critic network.
    batch_size: Batch size of the TFRecord dataset pipeline.
    bc_steps: Passed to the agent as `num_bc_steps`.
    actor_learning_rate: Adam LR for the actor optimizer.
    critic_learning_rate: Adam LR for the critic optimizer.
    alpha_learning_rate: Adam LR for the entropy-alpha optimizer.
    reward_scale_factor: Scale applied to rewards by the agent.
    cql_alpha_learning_rate: LR for the Lagrange CQL-alpha (when used).
    cql_alpha: Fixed CQL regularizer weight.
    cql_tau: CQL tau parameter forwarded to the agent.
    num_cql_samples: Number of action samples for the CQL term.
    reward_noise_variance: Forwarded to the agent.
    include_critic_entropy_term: Forwarded to the agent.
    use_lagrange_cql_alpha: Whether CQL alpha is learned via a Lagrange term.
    log_cql_alpha_clipping: Optional clipping for log CQL alpha.
    softmax_temperature: Forwarded to the agent.
    reward_shift: Additive reward shift applied in the dataset pipeline.
    action_clipping: Optional action clipping applied in the dataset pipeline.
    use_trajectories: Whether dataset records are Trajectories.
    data_shuffle_buffer_size_per_record: Per-record shuffle buffer size.
    data_shuffle_buffer_size: Global shuffle buffer size.
    data_num_shards: Number of dataset shards.
    data_block_length: Interleave block length.
    data_parallel_reads: Parallel TFRecord reads (None = default).
    data_parallel_calls: Parallel map calls.
    data_prefetch: Prefetch depth.
    data_cycle_length: Interleave cycle length.
    policy_save_interval: Train steps between policy saved-model exports.
    eval_interval: Train steps between evaluations (0/None disables).
    summary_interval: Train steps between summary writes.
    learner_iterations_per_call: Iterations per `Learner.run` call.
    eval_episodes: Episodes per evaluation run.
    debug_summaries: Forwarded to the agent.
    summarize_grads_and_vars: Forwarded to the agent.
    seed: Optional seed for TF, NumPy and the agent's sampling.
  """
  logging.info('Training CQL-SAC on: %s', env_name)
  tf.random.set_seed(seed)
  np.random.seed(seed)

  # Load environment.
  env = load_d4rl(env_name)
  tf_env = tf_py_environment.TFPyEnvironment(env)
  strategy = strategy_utils.get_strategy(tpu, use_gpu)

  # A bare directory means "glob the per-env shard files underneath it".
  if not dataset_path.endswith('.tfrecord'):
    dataset_path = os.path.join(dataset_path, env_name,
                                '%s*.tfrecord' % env_name)
  logging.info('Loading dataset from %s', dataset_path)
  dataset_paths = tf.io.gfile.glob(dataset_path)

  # Create dataset.
  with strategy.scope():
    dataset = create_tf_record_dataset(
        dataset_paths,
        batch_size,
        shuffle_buffer_size_per_record=data_shuffle_buffer_size_per_record,
        shuffle_buffer_size=data_shuffle_buffer_size,
        num_shards=data_num_shards,
        cycle_length=data_cycle_length,
        block_length=data_block_length,
        num_parallel_reads=data_parallel_reads,
        num_parallel_calls=data_parallel_calls,
        num_prefetch=data_prefetch,
        strategy=strategy,
        reward_shift=reward_shift,
        action_clipping=action_clipping,
        use_trajectories=use_trajectories)

  # Create agent.
  time_step_spec = tf_env.time_step_spec()
  observation_spec = time_step_spec.observation
  action_spec = tf_env.action_spec()

  # Networks, optimizers and the agent itself must all be built under the
  # same distribution-strategy scope so their variables are placed correctly.
  with strategy.scope():
    train_step = train_utils.create_train_step()
    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=tanh_normal_projection_network
        .TanhNormalProjectionNetwork)
    critic_net = critic_network.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=critic_joint_fc_layers,
        kernel_initializer='glorot_uniform',
        last_kernel_initializer='glorot_uniform')

    agent = cql_sac_agent.CqlSacAgent(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.keras.optimizers.Adam(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.keras.optimizers.Adam(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.keras.optimizers.Adam(
            learning_rate=alpha_learning_rate),
        cql_alpha=cql_alpha,
        num_cql_samples=num_cql_samples,
        include_critic_entropy_term=include_critic_entropy_term,
        use_lagrange_cql_alpha=use_lagrange_cql_alpha,
        cql_alpha_learning_rate=cql_alpha_learning_rate,
        target_update_tau=5e-3,
        target_update_period=1,
        random_seed=seed,
        cql_tau=cql_tau,
        reward_noise_variance=reward_noise_variance,
        num_bc_steps=bc_steps,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=0.99,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=None,
        log_cql_alpha_clipping=log_cql_alpha_clipping,
        softmax_temperature=softmax_temperature,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step)
    agent.initialize()

  # Create learner.
  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  collect_env_step_metric = py_metrics.EnvironmentSteps()
  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir,
          agent,
          train_step,
          interval=policy_save_interval,
          metadata_metrics={
              triggers.ENV_STEP_METADATA_KEY: collect_env_step_metric
          }),
      triggers.StepPerSecondLogTrigger(train_step, interval=100)
  ]
  cql_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn=lambda: dataset,
      triggers=learning_triggers,
      summary_interval=summary_interval,
      strategy=strategy)

  # Create actor for evaluation; wraps the agent's policy greedily so eval
  # reflects the deterministic policy rather than the stochastic one.
  tf_greedy_policy = greedy_policy.GreedyPolicy(agent.policy)
  eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_greedy_policy, use_tf_function=True)
  eval_actor = actor.Actor(
      env,
      eval_greedy_policy,
      train_step,
      metrics=actor.eval_metrics(eval_episodes),
      summary_dir=os.path.join(root_dir, 'eval'),
      episodes_per_run=eval_episodes)

  # Run.
  dummy_trajectory = trajectory.mid((), (), (), 0., 1.)
  num_learner_iterations = int(num_gradient_updates /
                               learner_iterations_per_call)
  for _ in range(num_learner_iterations):
    # Mimic collecting environment steps since we loaded a static dataset.
    for _ in range(learner_iterations_per_call):
      collect_env_step_metric(dummy_trajectory)

    cql_learner.run(iterations=learner_iterations_per_call)
    # eval_interval is measured in train steps; only evaluate on multiples.
    if eval_interval and train_step.numpy() % eval_interval == 0:
      eval_actor.run_and_log()
from tf_agents.replay_buffers.tf_uniform_replay_buffer import TFUniformReplayBuffer from tf_agents.metrics.tf_metrics import AverageReturnMetric from tf_agents.policies import random_tf_policy from drivers import TFRenderDriver from tf_agents.policies.actor_policy import ActorPolicy from tf_agents.trajectories import time_step as ts from tf_agents.policies.policy_saver import PolicySaver if __name__ == '__main__': py_env = suite_pybullet.load('AntBulletEnv-v0') py_env.render(mode="human") env = tf_py_environment.TFPyEnvironment(py_env) strategy = strategy_utils.get_strategy(tpu=False, use_gpu=True) replay_buffer_capacity = 2000 learning_rate = 1e-3 fc_layer_params = [128, 64, 64] num_iterations = 100 log_interval = 2 eval_interval = 2 action_tensor_spec = tensor_spec.from_spec(env.action_spec()) num_actions = action_tensor_spec.shape[0] with strategy.scope():