Example #1
def main(_):
    # Create an environment, grab the spec, and use it to create networks.
    environment = helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)
    agent_networks = ppo.make_continuous_networks(environment_spec)

    # Construct the agent.
    config = ppo.PPOConfig(unroll_length=FLAGS.unroll_length,
                           num_minibatches=FLAGS.num_minibatches,
                           num_epochs=FLAGS.num_epochs,
                           batch_size=FLAGS.batch_size)

    learner_logger = experiment_utils.make_experiment_logger(
        label='learner', steps_key='learner_steps')
    agent = ppo.PPO(environment_spec,
                    agent_networks,
                    config=config,
                    seed=FLAGS.seed,
                    counter=counting.Counter(prefix='learner'),
                    logger=learner_logger)

    # Create the environment loop used for training.
    train_logger = experiment_utils.make_experiment_logger(
        label='train', steps_key='train_steps')
    train_loop = acme.EnvironmentLoop(environment,
                                      agent,
                                      counter=counting.Counter(prefix='train'),
                                      logger=train_logger)

    # Create the evaluation actor and loop.
    eval_logger = experiment_utils.make_experiment_logger(
        label='eval', steps_key='eval_steps')
    eval_actor = agent.builder.make_actor(
        random_key=jax.random.PRNGKey(FLAGS.seed),
        policy_network=ppo.make_inference_fn(agent_networks, evaluation=True),
        variable_source=agent)
    eval_env = helpers.make_environment(task=FLAGS.env_name)
    eval_loop = acme.EnvironmentLoop(eval_env,
                                     eval_actor,
                                     counter=counting.Counter(prefix='eval'),
                                     logger=eval_logger)

    assert FLAGS.num_steps % FLAGS.eval_every == 0
    for _ in range(FLAGS.num_steps // FLAGS.eval_every):
        eval_loop.run(num_episodes=5)
        train_loop.run(num_steps=FLAGS.eval_every)
    eval_loop.run(num_episodes=5)
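
All five examples assume the same surrounding scaffolding from the Acme examples tree: absl flags, an app.run entry point, and local helpers / experiment_utils modules. The sketch below reconstructs that preamble; the import paths, flag defaults, and the per-agent flag list are assumptions for illustration, not verbatim from the original scripts.

import functools

from absl import app
from absl import flags
import acme
from acme import specs
from acme.agents.jax import ail
from acme.agents.jax import ppo
from acme.agents.jax import td3
from acme.agents.jax import value_dice
from acme.jax import networks as networks_lib
from acme.utils import counting
from acme.utils import experiment_utils  # assumed location of make_experiment_logger
import haiku as hk
import jax

import helpers  # assumed: local module providing make_environment and dataset helpers

FLAGS = flags.FLAGS
flags.DEFINE_string('env_name', 'HalfCheetah-v2', 'Gym task to run.')
flags.DEFINE_integer('seed', 0, 'Random seed.')
flags.DEFINE_integer('num_steps', 1_000_000, 'Total environment steps to train for.')
flags.DEFINE_integer('eval_every', 50_000, 'Training steps between evaluations.')
# ...plus the per-agent flags referenced in the examples (unroll_length,
# num_minibatches, num_epochs, batch_size, num_sgd_steps_per_step,
# dataset_name, pretrain, and so on).

if __name__ == '__main__':
    app.run(main)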
Example #2
def main(_):
    # Create an environment, grab the spec, and use it to create networks.
    environment = helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)
    agent_networks = value_dice.make_networks(environment_spec)

    # Construct the agent.
    config = value_dice.ValueDiceConfig(
        num_sgd_steps_per_step=FLAGS.num_sgd_steps_per_step)
    agent = value_dice.ValueDice(environment_spec,
                                 agent_networks,
                                 config=config,
                                 make_demonstrations=functools.partial(
                                     helpers.make_demonstration_iterator,
                                     dataset_name=FLAGS.dataset_name),
                                 seed=FLAGS.seed)

    # Create the environment loop used for training.
    train_logger = experiment_utils.make_experiment_logger(
        label='train', steps_key='train_steps')
    train_loop = acme.EnvironmentLoop(environment,
                                      agent,
                                      counter=counting.Counter(prefix='train'),
                                      logger=train_logger)

    # Create the evaluation actor and loop.
    eval_logger = experiment_utils.make_experiment_logger(
        label='eval', steps_key='eval_steps')
    eval_actor = agent.builder.make_actor(
        random_key=jax.random.PRNGKey(FLAGS.seed),
        policy_network=value_dice.apply_policy_and_sample(agent_networks,
                                                          eval_mode=True),
        variable_source=agent)
    eval_env = helpers.make_environment(task=FLAGS.env_name)
    eval_loop = acme.EnvironmentLoop(eval_env,
                                     eval_actor,
                                     counter=counting.Counter(prefix='eval'),
                                     logger=eval_logger)

    assert FLAGS.num_steps % FLAGS.eval_every == 0
    for _ in range(FLAGS.num_steps // FLAGS.eval_every):
        eval_loop.run(num_episodes=5)
        train_loop.run(num_steps=FLAGS.eval_every)
    eval_loop.run(num_episodes=5)
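
ValueDice learns from demonstrations, so the agent is handed a make_demonstrations factory: given a batch size, it must return an iterator over batched demonstration transitions. A minimal sketch of such a factory, assuming a TFDS-hosted transitions dataset (the real helpers.make_demonstration_iterator may convert records to Acme's Transition type and differ in details):

import tensorflow_datasets as tfds

def make_demonstration_iterator(batch_size, dataset_name):
    # Hypothetical sketch: stream shuffled, batched demonstration records
    # indefinitely, since the learner samples from this iterator every step.
    ds = tfds.load(dataset_name, split='train')
    ds = ds.repeat().shuffle(10_000).batch(batch_size)
    return ds.as_numpy_iterator()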
Example #3
def main(_):
    # Create an environment, grab the spec, and use it to create networks.
    environment = helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)
    agent_networks = td3.make_networks(environment_spec)

    # Construct the agent.
    config = td3.TD3Config(num_sgd_steps_per_step=FLAGS.num_sgd_steps_per_step)
    agent = td3.TD3(environment_spec,
                    agent_networks,
                    config=config,
                    seed=FLAGS.seed)

    # Create the environment loop used for training.
    train_logger = experiment_utils.make_experiment_logger(
        label='train', steps_key='train_steps')
    train_loop = acme.EnvironmentLoop(environment,
                                      agent,
                                      counter=counting.Counter(prefix='train'),
                                      logger=train_logger)

    # Create the evaluation actor and loop.
    eval_logger = experiment_utils.make_experiment_logger(
        label='eval', steps_key='eval_steps')
    eval_actor = agent.builder.make_actor(
        random_key=jax.random.PRNGKey(FLAGS.seed),
        policy_network=td3.get_default_behavior_policy(
            agent_networks, environment_spec.actions, sigma=0.),
        variable_source=agent)
    eval_env = helpers.make_environment(task=FLAGS.env_name)
    eval_loop = acme.EnvironmentLoop(eval_env,
                                     eval_actor,
                                     counter=counting.Counter(prefix='eval'),
                                     logger=eval_logger)

    assert FLAGS.num_steps % FLAGS.eval_every == 0
    for _ in range(FLAGS.num_steps // FLAGS.eval_every):
        eval_loop.run(num_episodes=5)
        train_loop.run(num_steps=FLAGS.eval_every)
    eval_loop.run(num_episodes=5)
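
Note the sigma=0. when building the evaluation actor: td3.get_default_behavior_policy wraps TD3's deterministic actor with zero-mean Gaussian exploration noise of scale sigma, so zeroing it recovers the greedy policy for evaluation. Conceptually it behaves like the following sketch (an illustration of the idea, not Acme's exact implementation):

import jax.numpy as jnp

def noisy_policy(params, key, observation, sigma):
    # Deterministic actor output plus scaled Gaussian noise; sigma=0. makes
    # the policy deterministic, which is what we want at evaluation time.
    action = agent_networks.policy_network.apply(params, observation)
    noise = sigma * jax.random.normal(key, action.shape)
    return jnp.clip(action + noise, -1.0, 1.0)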
Example #4
def main(_):
    # Create an environment, grab the spec, and use it to create networks.
    environment = helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)

    # Construct the agent.
    # The local layout pre-populates the replay buffer with min_replay_size
    # initial transitions, so no tolerance rate is needed. To avoid deadlocks
    # we must also disable the rate limiting that happens inside the
    # TD3Builder; both are achieved via the min_replay_size and
    # samples_per_insert_tolerance_rate arguments below.
    td3_config = td3.TD3Config(
        num_sgd_steps_per_step=FLAGS.num_sgd_steps_per_step,
        min_replay_size=1,
        samples_per_insert_tolerance_rate=float('inf'))
    td3_networks = td3.make_networks(environment_spec)
    if FLAGS.pretrain:
        td3_networks = add_bc_pretraining(td3_networks)

    ail_config = ail.AILConfig(direct_rl_batch_size=td3_config.batch_size *
                               td3_config.num_sgd_steps_per_step)
    dac_config = ail.DACConfig(ail_config, td3_config)

    def discriminator(*args, **kwargs) -> networks_lib.Logits:
        return ail.DiscriminatorModule(
            environment_spec=environment_spec,
            use_action=True,
            use_next_obs=True,
            network_core=ail.DiscriminatorMLP([4, 4]),
        )(*args, **kwargs)

    discriminator_transformed = hk.without_apply_rng(
        hk.transform_with_state(discriminator))

    ail_network = ail.AILNetworks(
        ail.make_discriminator(environment_spec, discriminator_transformed),
        imitation_reward_fn=ail.rewards.gail_reward(),
        direct_rl_networks=td3_networks)

    agent = ail.DAC(spec=environment_spec,
                    network=ail_network,
                    config=dac_config,
                    seed=FLAGS.seed,
                    batch_size=td3_config.batch_size *
                    td3_config.num_sgd_steps_per_step,
                    make_demonstrations=functools.partial(
                        helpers.make_demonstration_iterator,
                        dataset_name=FLAGS.dataset_name),
                    policy_network=td3.get_default_behavior_policy(
                        td3_networks,
                        action_specs=environment_spec.actions,
                        sigma=td3_config.sigma))

    # Create the environment loop used for training.
    train_logger = experiment_utils.make_experiment_logger(
        label='train', steps_key='train_steps')
    train_loop = acme.EnvironmentLoop(environment,
                                      agent,
                                      counter=counting.Counter(prefix='train'),
                                      logger=train_logger)

    # Create the evaluation actor and loop.
    # TODO(lukstafi): sigma=0 for eval?
    eval_logger = experiment_utils.make_experiment_logger(
        label='eval', steps_key='eval_steps')
    eval_actor = agent.builder.make_actor(
        random_key=jax.random.PRNGKey(FLAGS.seed),
        policy_network=td3.get_default_behavior_policy(
            td3_networks, action_specs=environment_spec.actions, sigma=0.),
        variable_source=agent)
    eval_env = helpers.make_environment(task=FLAGS.env_name)
    eval_loop = acme.EnvironmentLoop(eval_env,
                                     eval_actor,
                                     counter=counting.Counter(prefix='eval'),
                                     logger=eval_logger)

    assert FLAGS.num_steps % FLAGS.eval_every == 0
    for _ in range(FLAGS.num_steps // FLAGS.eval_every):
        eval_loop.run(num_episodes=5)
        train_loop.run(num_steps=FLAGS.eval_every)
    eval_loop.run(num_episodes=5)
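
DAC trains TD3 on an imitation reward rather than the task reward. With ail.rewards.gail_reward(), the discriminator output is mapped to the standard GAIL reward r(s, a) = -log(1 - D(s, a)), which in logit space is simply a softplus. A sketch assuming that convention (Acme's implementation may add scaling or clipping):

import jax.numpy as jnp

def gail_reward_from_logits(logits):
    # -log(1 - sigmoid(logits)) == softplus(logits): transitions the
    # discriminator attributes to the demonstrator receive high reward.
    return jnp.logaddexp(0.0, logits)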
Example #5
def main(_):
    # Create an environment and grab its spec.
    environment = helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)

    # Construct the agent.
    ppo_config = ppo.PPOConfig(unroll_length=FLAGS.unroll_length,
                               num_minibatches=FLAGS.ppo_num_minibatches,
                               num_epochs=FLAGS.ppo_num_epochs,
                               batch_size=FLAGS.transition_batch_size //
                               FLAGS.unroll_length,
                               learning_rate=0.0003,
                               entropy_cost=0,
                               gae_lambda=0.8,
                               value_cost=0.25)
    ppo_networks = ppo.make_continuous_networks(environment_spec)
    if FLAGS.pretrain:
        ppo_networks = add_bc_pretraining(ppo_networks)

    discriminator_batch_size = FLAGS.transition_batch_size
    ail_config = ail.AILConfig(
        direct_rl_batch_size=ppo_config.batch_size * ppo_config.unroll_length,
        discriminator_batch_size=discriminator_batch_size,
        is_sequence_based=True,
        num_sgd_steps_per_step=FLAGS.num_discriminator_steps_per_step,
        share_iterator=FLAGS.share_iterator,
    )

    def discriminator(*args, **kwargs) -> networks_lib.Logits:
        # Note: no observation embedding is needed for state-based tasks such
        # as MuJoCo.
        return ail.DiscriminatorModule(
            environment_spec=environment_spec,
            use_action=True,
            use_next_obs=True,
            network_core=ail.DiscriminatorMLP([4, 4]),
        )(*args, **kwargs)

    discriminator_transformed = hk.without_apply_rng(
        hk.transform_with_state(discriminator))

    ail_network = ail.AILNetworks(
        ail.make_discriminator(environment_spec, discriminator_transformed),
        imitation_reward_fn=ail.rewards.gail_reward(),
        direct_rl_networks=ppo_networks)

    agent = ail.GAIL(spec=environment_spec,
                     network=ail_network,
                     config=ail.GAILConfig(ail_config, ppo_config),
                     seed=FLAGS.seed,
                     batch_size=ppo_config.batch_size,
                     make_demonstrations=functools.partial(
                         helpers.make_demonstration_iterator,
                         dataset_name=FLAGS.dataset_name),
                     policy_network=ppo.make_inference_fn(ppo_networks))

    # Create the environment loop used for training.
    train_logger = experiment_utils.make_experiment_logger(
        label='train', steps_key='train_steps')
    train_loop = acme.EnvironmentLoop(environment,
                                      agent,
                                      counter=counting.Counter(prefix='train'),
                                      logger=train_logger)

    # Create the evaluation actor and loop.
    eval_logger = experiment_utils.make_experiment_logger(
        label='eval', steps_key='eval_steps')
    eval_actor = agent.builder.make_actor(
        random_key=jax.random.PRNGKey(FLAGS.seed),
        policy_network=ppo.make_inference_fn(ppo_networks, evaluation=True),
        variable_source=agent)
    eval_env = helpers.make_environment(task=FLAGS.env_name)
    eval_loop = acme.EnvironmentLoop(eval_env,
                                     eval_actor,
                                     counter=counting.Counter(prefix='eval'),
                                     logger=eval_logger)

    assert FLAGS.num_steps % FLAGS.eval_every == 0
    for _ in range(FLAGS.num_steps // FLAGS.eval_every):
        eval_loop.run(num_episodes=5)
        train_loop.run(num_steps=FLAGS.eval_every)
    eval_loop.run(num_episodes=5)
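
The batch bookkeeping in this last example is easy to misread, so here is the arithmetic with illustrative numbers:

# With transition_batch_size = 2048 and unroll_length = 16 (illustrative):
#   ppo_config.batch_size           = 2048 // 16 = 128   (sequences per batch)
#   ail_config.direct_rl_batch_size = 128 * 16   = 2048  (transitions per batch)
# PPO trains on unroll_length-step sequences (hence is_sequence_based=True),
# so PPO and the discriminator consume the same number of environment
# transitions per learner step.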