Code Example #1
File: sac_train_eval.py  Project: morgandu/agents
def main(_):
    logging.set_verbosity(logging.INFO)
    tf.compat.v1.enable_v2_behavior()

    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings)

    strategy = strategy_utils.get_strategy(FLAGS.tpu, FLAGS.use_gpu)

    train_eval(FLAGS.root_dir,
               strategy=strategy,
               num_iterations=FLAGS.num_iterations,
               reverb_port=FLAGS.reverb_port,
               eval_interval=FLAGS.eval_interval)
Code Example #2
File: sac_train.py  Project: pratikkorat26/agents
def main(_):
  logging.set_verbosity(logging.INFO)
  tf.enable_v2_behavior()

  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings)

  strategy = strategy_utils.get_strategy(FLAGS.tpu, FLAGS.use_gpu)

  train(
      root_dir=FLAGS.root_dir,
      environment_name=gin.REQUIRED,
      strategy=strategy,
      replay_buffer_server_address=FLAGS.replay_buffer_server_address,
      variable_container_server_address=FLAGS.variable_container_server_address)
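
The two entry points above rely on module-level absl flags that are not shown in the excerpts. The sketch below reconstructs the flag declarations they assume; the flag names come from the examples, while the defaults and help strings are assumptions. gin_file and gin_bindings are declared as multi-string flags because gin.parse_config_files_and_bindings accepts lists.

from absl import app, flags

FLAGS = flags.FLAGS

# Flags shared by both scripts.
flags.DEFINE_string('root_dir', None, 'Root directory for logs and checkpoints.')
flags.DEFINE_multi_string('gin_file', None, 'Paths to gin config files.')
flags.DEFINE_multi_string('gin_bindings', None, 'Gin parameter bindings.')
flags.DEFINE_string('tpu', None, 'BNS address of the TPU worker, or None.')
flags.DEFINE_bool('use_gpu', False, 'Whether to use a MirroredStrategy on local GPUs.')

# sac_train_eval.py only.
flags.DEFINE_integer('num_iterations', 1000000, 'Number of training iterations.')
flags.DEFINE_integer('reverb_port', None, 'Port for the Reverb replay server.')
flags.DEFINE_integer('eval_interval', 10000, 'Evaluation interval in train steps.')

# sac_train.py only (distributed setup).
flags.DEFINE_string('replay_buffer_server_address', None, 'Reverb replay buffer server address.')
flags.DEFINE_string('variable_container_server_address', None, 'Variable container server address.')

if __name__ == '__main__':
  app.run(main)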
Code Example #3
    def test_tpu_strategy(self, mock_tpu_cluster_resolver,
                          mock_experimental_connect_to_cluster,
                          mock_initialize_tpu_system, mock_tpu_strategy):
        resolver = mock.MagicMock()
        mock_tpu_cluster_resolver.return_value = resolver
        mock_strategy = mock.MagicMock()
        mock_tpu_strategy.return_value = mock_strategy

        strategy = strategy_utils.get_strategy(tpu='bns_address',
                                               use_gpu=False)

        mock_tpu_cluster_resolver.assert_called_with(tpu='bns_address')
        mock_experimental_connect_to_cluster.assert_called_with(resolver)
        mock_initialize_tpu_system.assert_called_with(resolver)
        self.assertIs(strategy, mock_strategy)
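
The test above receives four mock objects as arguments, which implies a stack of mock.patch decorators on the method or its class that the excerpt omits. The sketch below is one plausible reconstruction; the patch targets are assumptions based on the public TF TPU setup APIs, not the project's actual decorators. Decorators apply bottom-up, so the bottom-most patch supplies the first mock argument.

from unittest import mock

import tensorflow as tf
from tf_agents.train.utils import strategy_utils


class StrategyUtilsTest(tf.test.TestCase):

    # Bottom-up: cluster resolver, connect_to_cluster, initialize_tpu_system, TPUStrategy.
    @mock.patch.object(tf.distribute, 'TPUStrategy')
    @mock.patch.object(tf.tpu.experimental, 'initialize_tpu_system')
    @mock.patch.object(tf.config, 'experimental_connect_to_cluster')
    @mock.patch.object(tf.distribute.cluster_resolver, 'TPUClusterResolver')
    def test_tpu_strategy(self, mock_tpu_cluster_resolver,
                          mock_experimental_connect_to_cluster,
                          mock_initialize_tpu_system, mock_tpu_strategy):
        ...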
Code Example #4
    def test_mirrored_strategy(self, mock_mirrored_strategy):
        mirrored_strategy = mock.MagicMock()
        mock_mirrored_strategy.return_value = mirrored_strategy

        strategy = strategy_utils.get_strategy(False, use_gpu=True)
        self.assertIs(strategy, mirrored_strategy)
Code Example #5
    def test_get_distribution_strategy_default(self):
        # Get a default strategy to compare against.
        default_strategy = tf.distribute.get_strategy()

        strategy = strategy_utils.get_strategy(tpu=False, use_gpu=False)
        self.assertIsInstance(strategy, type(default_strategy))
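
Read together, examples #3, #4, and #5 pin down the selection contract of strategy_utils.get_strategy: a truthy tpu argument yields a TPUStrategy, use_gpu=True yields a MirroredStrategy, and otherwise the default strategy is returned. The sketch below restates that contract in simplified form; it is not the library source, only the behavior the tests assert.

import tensorflow as tf

def get_strategy_sketch(tpu, use_gpu):
    """Mirrors the behavior asserted by the tests above, in simplified form."""
    if tpu:
        # Resolve and initialize the TPU system, then build a TPU strategy.
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu)
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        return tf.distribute.TPUStrategy(resolver)
    if use_gpu:
        # Synchronous training across all visible local GPUs.
        return tf.distribute.MirroredStrategy()
    # Fall back to the default (single-device) strategy.
    return tf.distribute.get_strategy()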
Code Example #6
def train_eval(
        root_dir,
        dataset_path,
        env_name,
        # Training params
        tpu=False,
        use_gpu=False,
        num_gradient_updates=1000000,
        actor_fc_layers=(256, 256),
        critic_joint_fc_layers=(256, 256, 256),
        # Agent params
        batch_size=256,
        bc_steps=0,
        actor_learning_rate=3e-5,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        reward_scale_factor=1.0,
        cql_alpha_learning_rate=3e-4,
        cql_alpha=5.0,
        cql_tau=10.0,
        num_cql_samples=10,
        reward_noise_variance=0.0,
        include_critic_entropy_term=False,
        use_lagrange_cql_alpha=True,
        log_cql_alpha_clipping=None,
        softmax_temperature=1.0,
        # Data params
        reward_shift=0.0,
        action_clipping=None,
        use_trajectories=False,
        data_shuffle_buffer_size_per_record=1,
        data_shuffle_buffer_size=100,
        data_num_shards=1,
        data_block_length=10,
        data_parallel_reads=None,
        data_parallel_calls=10,
        data_prefetch=10,
        data_cycle_length=10,
        # Others
        policy_save_interval=10000,
        eval_interval=10000,
        summary_interval=1000,
        learner_iterations_per_call=1,
        eval_episodes=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        seed=None):
    """Trains and evaluates CQL-SAC."""
    logging.info('Training CQL-SAC on: %s', env_name)
    tf.random.set_seed(seed)
    np.random.seed(seed)

    # Load environment.
    env = load_d4rl(env_name)
    tf_env = tf_py_environment.TFPyEnvironment(env)
    strategy = strategy_utils.get_strategy(tpu, use_gpu)

    if not dataset_path.endswith('.tfrecord'):
        dataset_path = os.path.join(dataset_path, env_name,
                                    '%s*.tfrecord' % env_name)
    logging.info('Loading dataset from %s', dataset_path)
    dataset_paths = tf.io.gfile.glob(dataset_path)

    # Create dataset.
    with strategy.scope():
        dataset = create_tf_record_dataset(
            dataset_paths,
            batch_size,
            shuffle_buffer_size_per_record=data_shuffle_buffer_size_per_record,
            shuffle_buffer_size=data_shuffle_buffer_size,
            num_shards=data_num_shards,
            cycle_length=data_cycle_length,
            block_length=data_block_length,
            num_parallel_reads=data_parallel_reads,
            num_parallel_calls=data_parallel_calls,
            num_prefetch=data_prefetch,
            strategy=strategy,
            reward_shift=reward_shift,
            action_clipping=action_clipping,
            use_trajectories=use_trajectories)

    # Create agent.
    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()
    with strategy.scope():
        train_step = train_utils.create_train_step()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=tanh_normal_projection_network.
            TanhNormalProjectionNetwork)

        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            joint_fc_layer_params=critic_joint_fc_layers,
            kernel_initializer='glorot_uniform',
            last_kernel_initializer='glorot_uniform')

        agent = cql_sac_agent.CqlSacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.keras.optimizers.Adam(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.keras.optimizers.Adam(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.keras.optimizers.Adam(
                learning_rate=alpha_learning_rate),
            cql_alpha=cql_alpha,
            num_cql_samples=num_cql_samples,
            include_critic_entropy_term=include_critic_entropy_term,
            use_lagrange_cql_alpha=use_lagrange_cql_alpha,
            cql_alpha_learning_rate=cql_alpha_learning_rate,
            target_update_tau=5e-3,
            target_update_period=1,
            random_seed=seed,
            cql_tau=cql_tau,
            reward_noise_variance=reward_noise_variance,
            num_bc_steps=bc_steps,
            td_errors_loss_fn=tf.math.squared_difference,
            gamma=0.99,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=None,
            log_cql_alpha_clipping=log_cql_alpha_clipping,
            softmax_temperature=softmax_temperature,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=train_step)
        agent.initialize()

    # Create learner.
    saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
    collect_env_step_metric = py_metrics.EnvironmentSteps()
    learning_triggers = [
        triggers.PolicySavedModelTrigger(saved_model_dir,
                                         agent,
                                         train_step,
                                         interval=policy_save_interval,
                                         metadata_metrics={
                                             triggers.ENV_STEP_METADATA_KEY:
                                             collect_env_step_metric
                                         }),
        triggers.StepPerSecondLogTrigger(train_step, interval=100)
    ]
    cql_learner = learner.Learner(root_dir,
                                  train_step,
                                  agent,
                                  experience_dataset_fn=lambda: dataset,
                                  triggers=learning_triggers,
                                  summary_interval=summary_interval,
                                  strategy=strategy)

    # Create actor for evaluation.
    tf_greedy_policy = greedy_policy.GreedyPolicy(agent.policy)
    eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
        tf_greedy_policy, use_tf_function=True)
    eval_actor = actor.Actor(env,
                             eval_greedy_policy,
                             train_step,
                             metrics=actor.eval_metrics(eval_episodes),
                             summary_dir=os.path.join(root_dir, 'eval'),
                             episodes_per_run=eval_episodes)

    # Run.
    dummy_trajectory = trajectory.mid((), (), (), 0., 1.)
    num_learner_iterations = int(num_gradient_updates /
                                 learner_iterations_per_call)
    for _ in range(num_learner_iterations):
        # Mimic collecting environment steps since we loaded a static dataset.
        for _ in range(learner_iterations_per_call):
            collect_env_step_metric(dummy_trajectory)

        cql_learner.run(iterations=learner_iterations_per_call)
        if eval_interval and train_step.numpy() % eval_interval == 0:
            eval_actor.run_and_log()
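
For reference, a direct call to the train_eval function above could look like the following. The root directory, dataset path, and environment name are placeholders for illustration, not values taken from the original project.

if __name__ == '__main__':
    logging.set_verbosity(logging.INFO)
    # Hypothetical paths and environment name, for illustration only.
    train_eval(
        root_dir='/tmp/cql_sac',
        dataset_path='/tmp/d4rl_tfrecords',
        env_name='halfcheetah-medium-v0',
        use_gpu=True,
        num_gradient_updates=100000,
        eval_interval=10000)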
Code Example #7
from tf_agents.replay_buffers.tf_uniform_replay_buffer import TFUniformReplayBuffer
from tf_agents.metrics.tf_metrics import AverageReturnMetric
from tf_agents.policies import random_tf_policy
from drivers import TFRenderDriver

from tf_agents.policies.actor_policy import ActorPolicy
from tf_agents.trajectories import time_step as ts
from tf_agents.policies.policy_saver import PolicySaver

# Imports used by the snippet below but missing from the excerpt.
from tf_agents.environments import suite_pybullet
from tf_agents.environments import tf_py_environment
from tf_agents.specs import tensor_spec
from tf_agents.train.utils import strategy_utils

if __name__ == '__main__':

    py_env = suite_pybullet.load('AntBulletEnv-v0')
    py_env.render(mode="human")
    env = tf_py_environment.TFPyEnvironment(py_env)

    strategy = strategy_utils.get_strategy(tpu=False, use_gpu=True)

    replay_buffer_capacity = 2000
    learning_rate = 1e-3
    fc_layer_params = [128, 64, 64]

    num_iterations = 100

    log_interval = 2
    eval_interval = 2

    action_tensor_spec = tensor_spec.from_spec(env.action_spec())

    num_actions = action_tensor_spec.shape[0]

    with strategy.scope():