Example #1
    def test_sac_fd(self):
        # Create a fake environment to test with.
        environment = fakes.ContinuousEnvironment(episode_length=10,
                                                  action_dim=3,
                                                  observation_dim=5,
                                                  bounded=True)
        spec = specs.make_environment_spec(environment)

        # Create the networks.
        sac_network = sac.make_networks(spec)

        batch_size = 10
        sac_config = sac.SACConfig(
            batch_size=batch_size,
            target_entropy=sac.target_entropy_from_env_spec(spec),
            min_replay_size=1)
        lfd_config = config.LfdConfig(initial_insert_count=0,
                                      demonstration_ratio=0.2)
        sac_fd_config = sacfd_agents.SACfDConfig(lfd_config=lfd_config,
                                                 sac_config=sac_config)
        counter = counting.Counter()
        agent = sacfd_agents.SACfD(spec=spec,
                                   sac_network=sac_network,
                                   sac_fd_config=sac_fd_config,
                                   lfd_iterator_fn=fake_demonstration_iterator,
                                   seed=0,
                                   counter=counter)

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent, counter=counter)
        loop.run(num_episodes=20)
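This test method is written against absl's testing framework; a minimal sketch of the harness it assumes (the class name is an assumption, and the acme/sacfd imports used in the body are defined elsewhere in the test module) could look like this:

# Hypothetical test harness around test_sac_fd above.
from absl.testing import absltest


class SACfDTest(absltest.TestCase):

    # test_sac_fd(self) as shown above would be defined here.
    ...


if __name__ == '__main__':
    absltest.main()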
Example #2
def main(_):
    # Create an environment, grab the spec, and use it to create networks.
    environment = helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)
    agent_networks = sac.make_networks(environment_spec)

    # Construct the agent.
    config = sac.SACConfig(
        target_entropy=sac.target_entropy_from_env_spec(environment_spec),
        num_sgd_steps_per_step=FLAGS.num_sgd_steps_per_step,
        seed=FLAGS.seed)
    agent = sac.SAC(environment_spec, agent_networks, config=config)

    # Create the environment loop used for training.
    train_loop = acme.EnvironmentLoop(environment, agent, label='train_loop')
    # Create the evaluation actor and loop.
    eval_actor = agent.builder.make_actor(
        policy_network=sac.apply_policy_and_sample(agent_networks,
                                                   eval_mode=True),
        variable_source=agent)
    eval_env = helpers.make_environment(task=FLAGS.env_name)
    eval_loop = acme.EnvironmentLoop(eval_env, eval_actor, label='eval_loop')

    assert FLAGS.num_steps % FLAGS.eval_every == 0
    for _ in range(FLAGS.num_steps // FLAGS.eval_every):
        eval_loop.run(num_episodes=5)
        train_loop.run(num_steps=FLAGS.eval_every)
    eval_loop.run(num_episodes=5)
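This main function references several absl flags and an app entry point that are not shown; a plausible set of definitions, with hypothetical default values, would be:

# Hypothetical flag definitions and entry point assumed by main() above; the
# flag names match the FLAGS.* references, the defaults are illustrative only.
from absl import app
from absl import flags

flags.DEFINE_string('env_name', 'MountainCarContinuous-v0',
                    'Environment name passed to helpers.make_environment.')
flags.DEFINE_integer('seed', 0, 'Random seed.')
flags.DEFINE_integer('num_steps', 1_000_000, 'Total environment steps.')
flags.DEFINE_integer('eval_every', 10_000,
                     'Training steps between evaluations.')
flags.DEFINE_integer('num_sgd_steps_per_step', 1,
                     'SGD steps performed per learner step.')
FLAGS = flags.FLAGS

if __name__ == '__main__':
    app.run(main)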
Example #3
def main(_):
    task = FLAGS.task
    environment_factory = lambda seed: helpers.make_environment(task)
    program = sac.DistributedSAC(
        environment_factory=environment_factory,
        network_factory=sac.make_networks,
        config=sac.SACConfig(**{'num_sgd_steps_per_step': 64}),
        num_actors=4,
        seed=1,
        max_number_of_steps=100).build()

    # Launch experiment.
    lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program))
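If XManager and Docker resources are not available, the same Launchpad program can usually be run locally instead; a sketch, assuming Launchpad's multi-threaded local launch type:

# Local alternative to the XManager launch above (a sketch; resource and
# terminal options are omitted).
lp.launch(program, launch_type='local_mt')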
Example #4
def main(_):
    task = FLAGS.task
    environment_factory = lambda seed: helpers.make_environment(task)
    sac_config = sac.SACConfig(num_sgd_steps_per_step=64)
    sac_builder = sac.SACBuilder(sac_config)

    ail_config = ail.AILConfig(direct_rl_batch_size=sac_config.batch_size *
                               sac_config.num_sgd_steps_per_step)

    def network_factory(spec: specs.EnvironmentSpec) -> ail.AILNetworks:
        def discriminator(*args, **kwargs) -> networks_lib.Logits:
            return ail.DiscriminatorModule(environment_spec=spec,
                                           use_action=True,
                                           use_next_obs=True,
                                           network_core=ail.DiscriminatorMLP(
                                               [4, 4]))(*args, **kwargs)

        discriminator_transformed = hk.without_apply_rng(
            hk.transform_with_state(discriminator))

        return ail.AILNetworks(ail.make_discriminator(
            spec, discriminator_transformed),
                               imitation_reward_fn=ail.rewards.gail_reward(),
                               direct_rl_networks=sac.make_networks(spec))

    def policy_network(
            network: ail.AILNetworks,
            eval_mode: bool = False) -> actor_core_lib.FeedForwardPolicy:
        return sac.apply_policy_and_sample(network.direct_rl_networks,
                                           eval_mode=eval_mode)

    program = ail.DistributedAIL(
        environment_factory=environment_factory,
        rl_agent=sac_builder,
        config=ail_config,
        network_factory=network_factory,
        seed=0,
        batch_size=sac_config.batch_size * sac_config.num_sgd_steps_per_step,
        make_demonstrations=functools.partial(
            helpers.make_demonstration_iterator,
            dataset_name=FLAGS.dataset_name),
        policy_network=policy_network,
        evaluator_policy_network=(lambda n: policy_network(n, eval_mode=True)),
        num_actors=4,
        max_number_of_steps=100,
        discriminator_loss=ail.losses.gail_loss()).build()

    # Launch experiment.
    lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program))
Example #5
def main(_):
    # Create an environment, grab the spec, and use it to create networks.
    environment = helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)
    agent_networks = sac.make_networks(environment_spec)

    # Construct the agent.
    config = sac.SACConfig(
        target_entropy=sac.target_entropy_from_env_spec(environment_spec),
        num_sgd_steps_per_step=FLAGS.num_sgd_steps_per_step)
    agent = sac.SAC(environment_spec,
                    agent_networks,
                    config=config,
                    seed=FLAGS.seed)

    # Create the environment loop used for training.
    logger = experiment_utils.make_experiment_logger(label='train',
                                                     steps_key='train_steps')
    train_loop = acme.EnvironmentLoop(environment,
                                      agent,
                                      counter=counting.Counter(prefix='train'),
                                      logger=logger)

    # Create the evaluation actor and loop.
    eval_logger = experiment_utils.make_experiment_logger(
        label='eval', steps_key='eval_steps')
    eval_actor = agent.builder.make_actor(
        random_key=jax.random.PRNGKey(FLAGS.seed),
        policy_network=sac.apply_policy_and_sample(agent_networks,
                                                   eval_mode=True),
        variable_source=agent)
    eval_env = helpers.make_environment(task=FLAGS.env_name)
    eval_loop = acme.EnvironmentLoop(eval_env,
                                     eval_actor,
                                     counter=counting.Counter(prefix='eval'),
                                     logger=eval_logger)

    assert FLAGS.num_steps % FLAGS.eval_every == 0
    for _ in range(FLAGS.num_steps // FLAGS.eval_every):
        eval_loop.run(num_episodes=5)
        train_loop.run(num_steps=FLAGS.eval_every)
    eval_loop.run(num_episodes=5)
Example #6
    def test_distributed_sac_fd(self):
        def make_env(seed):
            del seed
            return fakes.ContinuousEnvironment(episode_length=10,
                                               action_dim=3,
                                               observation_dim=5,
                                               bounded=True)

        spec = specs.make_environment_spec(make_env(seed=0))

        batch_size = 10
        sac_config = sac.SACConfig(
            batch_size=batch_size,
            target_entropy=sac.target_entropy_from_env_spec(spec),
            min_replay_size=16,
            samples_per_insert=2)
        lfd_config = config.LfdConfig(initial_insert_count=0,
                                      demonstration_ratio=0.2)
        sac_fd_config = sacfd_agents.SACfDConfig(lfd_config=lfd_config,
                                                 sac_config=sac_config)

        agent = sacfd_agents.DistributedSACfD(
            environment_factory=make_env,
            network_factory=sac.make_networks,
            sac_fd_config=sac_fd_config,
            lfd_iterator_fn=fake_demonstration_iterator,
            seed=0,
            num_actors=2)

        program = agent.build()
        (learner_node, ) = program.groups['learner']
        learner_node.disable_run()  # pytype: disable=attribute-error

        lp.launch(program, launch_type='test_mt')

        learner: acme.Learner = learner_node.create_handle().dereference()

        for _ in range(5):
            learner.step()
Example #7
def train_and_evaluate(distance_fn, rng):
    """Train a policy on the learned distance function and evaluate task success.

    Args:
      distance_fn: function mapping a (state, goal) pair to a state embedding
        and a distance estimate used for policy learning.
      rng: random key used to initialize the evaluation actor.
    """
    goal_image = load_goal_image(FLAGS.robot_data_path)
    logdir = FLAGS.logdir
    video_dir = paths.process_path(logdir, 'videos')
    print('Writing videos to', video_dir)
    counter = counting.Counter()
    eval_counter = counting.Counter(counter, prefix='eval', time_delta=0.0)
    # Include training episodes, steps and walltime in the first eval logs.
    counter.increment(episodes=0, steps=0, walltime=0)

    environment = make_environment(
        task=FLAGS.task,
        end_on_success=FLAGS.end_on_success,
        max_episode_steps=FLAGS.max_episode_steps,
        distance_fn=distance_fn,
        goal_image=goal_image,
        baseline_distance=FLAGS.baseline_distance,
        logdir=video_dir,
        counter=counter,
        record_every=FLAGS.record_episodes_frequency,
        num_episodes_to_record=FLAGS.num_episodes_to_record)
    environment_spec = specs.make_environment_spec(environment)
    print('Environment spec')
    print(environment_spec)
    agent_networks = sac.make_networks(environment_spec)

    config = sac.SACConfig(
        target_entropy=sac.target_entropy_from_env_spec(environment_spec),
        num_sgd_steps_per_step=FLAGS.num_sgd_steps_per_step,
        min_replay_size=FLAGS.min_replay_size)
    agent = deprecated_sac.SAC(environment_spec,
                               agent_networks,
                               config=config,
                               counter=counter,
                               seed=FLAGS.seed)

    env_logger = loggers.CSVLogger(logdir, 'env_loop', flush_every=5)
    eval_env_logger = loggers.CSVLogger(logdir, 'eval_env_loop', flush_every=1)
    train_loop = acme.EnvironmentLoop(environment,
                                      agent,
                                      label='train_loop',
                                      logger=env_logger,
                                      counter=counter)

    eval_actor = agent.builder.make_actor(random_key=rng,
                                          policy=sac.apply_policy_and_sample(
                                              agent_networks, eval_mode=True),
                                          environment_spec=environment_spec,
                                          variable_source=agent)

    eval_video_dir = paths.process_path(logdir, 'eval_videos')
    print('Writing eval videos to', eval_video_dir)
    if FLAGS.baseline_distance_from_goal_to_goal:
        state = goal_image
        if distance_fn.history_length > 1:
            state = np.stack([goal_image] * distance_fn.history_length,
                             axis=-1)
        unused_embeddings, baseline_distance = distance_fn(state, goal_image)
        print('Baseline prediction', baseline_distance)
    else:
        baseline_distance = FLAGS.baseline_distance
    eval_env = make_environment(task=FLAGS.task,
                                end_on_success=False,
                                max_episode_steps=FLAGS.max_episode_steps,
                                distance_fn=distance_fn,
                                goal_image=goal_image,
                                eval_mode=True,
                                logdir=eval_video_dir,
                                counter=eval_counter,
                                record_every=FLAGS.num_eval_episodes,
                                num_episodes_to_record=FLAGS.num_eval_episodes,
                                baseline_distance=baseline_distance)

    eval_loop = acme.EnvironmentLoop(eval_env,
                                     eval_actor,
                                     label='eval_loop',
                                     logger=eval_env_logger,
                                     counter=eval_counter)

    assert FLAGS.num_steps % FLAGS.eval_every == 0
    for _ in range(FLAGS.num_steps // FLAGS.eval_every):
        eval_loop.run(num_episodes=FLAGS.num_eval_episodes)
        train_loop.run(num_steps=FLAGS.eval_every)
    eval_loop.run(num_episodes=FLAGS.num_eval_episodes)
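The distance_fn argument is a learned model produced elsewhere; as a point of reference, a minimal stand-in that matches how it is used above (callable on a (state, goal) pair, returning an embedding and a scalar distance, and exposing a history_length attribute) might be sketched as:

# Purely illustrative stand-in for distance_fn; the real model is learned and
# the class name here is an assumption.
import numpy as np


class L2DistanceFn:
    """Maps a (state, goal) pair to (embedding, distance)."""

    history_length = 1  # Number of stacked frames expected in the state input.

    def __call__(self, state, goal):
        embedding = np.asarray(state, dtype=np.float32).ravel()
        goal_embedding = np.asarray(goal, dtype=np.float32).ravel()
        distance = float(np.linalg.norm(embedding - goal_embedding))
        return embedding, distance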
Example #8
def main(_):
    # Create an environment, grab the spec, and use it to create networks.
    environment = helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)
    agent_networks = sac.make_networks(environment_spec)

    # Construct the agent.
    # The local layout populates the replay buffer with min_replay_size initial
    # transitions, so no tolerance rate is needed. To avoid deadlocks, we also
    # disable the rate limiting that happens inside the SACBuilder. Both are
    # achieved via the min_replay_size and samples_per_insert_tolerance_rate
    # arguments.
    sac_config = sac.SACConfig(
        target_entropy=sac.target_entropy_from_env_spec(environment_spec),
        num_sgd_steps_per_step=FLAGS.num_sgd_steps_per_step,
        min_replay_size=1,
        samples_per_insert_tolerance_rate=float('inf'))
    sac_builder = sac.SACBuilder(sac_config)
    sac_networks = sac.make_networks(environment_spec)
    sac_networks = add_bc_pretraining(sac_networks)

    ail_config = ail.AILConfig(direct_rl_batch_size=sac_config.batch_size *
                               sac_config.num_sgd_steps_per_step)

    def discriminator(*args, **kwargs) -> networks_lib.Logits:
        return ail.DiscriminatorModule(environment_spec=environment_spec,
                                       use_action=True,
                                       use_next_obs=True,
                                       network_core=ail.DiscriminatorMLP(
                                           [4, 4]))(*args, **kwargs)

    discriminator_transformed = hk.without_apply_rng(
        hk.transform_with_state(discriminator))

    ail_network = ail.AILNetworks(
        ail.make_discriminator(environment_spec, discriminator_transformed),
        imitation_reward_fn=ail.rewards.gail_reward(),
        direct_rl_networks=sac_networks)

    agent = ail.AIL(spec=environment_spec,
                    rl_agent=sac_builder,
                    network=ail_network,
                    config=ail_config,
                    seed=FLAGS.seed,
                    batch_size=sac_config.batch_size *
                    sac_config.num_sgd_steps_per_step,
                    make_demonstrations=functools.partial(
                        helpers.make_demonstration_iterator,
                        dataset_name=FLAGS.dataset_name),
                    policy_network=sac.apply_policy_and_sample(sac_networks),
                    discriminator_loss=ail.losses.gail_loss())

    # Create the environment loop used for training.
    train_logger = experiment_utils.make_experiment_logger(
        label='train', steps_key='train_steps')
    train_loop = acme.EnvironmentLoop(environment,
                                      agent,
                                      counter=counting.Counter(prefix='train'),
                                      logger=train_logger)

    # Create the evaluation actor and loop.
    eval_logger = experiment_utils.make_experiment_logger(
        label='eval', steps_key='eval_steps')
    eval_actor = agent.builder.make_actor(
        random_key=jax.random.PRNGKey(FLAGS.seed),
        policy_network=sac.apply_policy_and_sample(agent_networks,
                                                   eval_mode=True),
        variable_source=agent)
    eval_env = helpers.make_environment(task=FLAGS.env_name)
    eval_loop = acme.EnvironmentLoop(eval_env,
                                     eval_actor,
                                     counter=counting.Counter(prefix='eval'),
                                     logger=eval_logger)

    assert FLAGS.num_steps % FLAGS.eval_every == 0
    for _ in range(FLAGS.num_steps // FLAGS.eval_every):
        eval_loop.run(num_episodes=5)
        train_loop.run(num_steps=FLAGS.eval_every)
    eval_loop.run(num_episodes=5)
Example #9
    def test_ail_flax(self):
        shutil.rmtree(flags.FLAGS.test_tmpdir)
        batch_size = 8
        # Mujoco environment and associated demonstration dataset.
        environment = fakes.ContinuousEnvironment(
            episode_length=EPISODE_LENGTH,
            action_dim=CONTINUOUS_ACTION_DIM,
            observation_dim=CONTINUOUS_OBS_DIM,
            bounded=True)
        spec = specs.make_environment_spec(environment)

        networks = sac.make_networks(spec=spec)
        config = sac.SACConfig(batch_size=batch_size,
                               samples_per_insert_tolerance_rate=float('inf'),
                               min_replay_size=1)
        base_builder = sac.SACBuilder(config=config)
        direct_rl_batch_size = batch_size
        behavior_policy = sac.apply_policy_and_sample(networks)

        discriminator_module = DiscriminatorModule(spec, linen.Dense(1))

        def apply_fn(params: networks_lib.Params,
                     policy_params: networks_lib.Params,
                     state: networks_lib.Params, transitions: types.Transition,
                     is_training: bool,
                     rng: networks_lib.PRNGKey) -> networks_lib.Logits:
            del policy_params
            variables = dict(params=params, **state)
            return discriminator_module.apply(variables,
                                              transitions.observation,
                                              transitions.action,
                                              transitions.next_observation,
                                              is_training=is_training,
                                              rng=rng,
                                              mutable=state.keys())

        def init_fn(rng):
            variables = discriminator_module.init(rng,
                                                  dummy_obs,
                                                  dummy_actions,
                                                  dummy_obs,
                                                  is_training=False,
                                                  rng=rng)
            init_state, discriminator_params = variables.pop('params')
            return discriminator_params, init_state

        dummy_obs = utils.zeros_like(spec.observations)
        dummy_obs = utils.add_batch_dim(dummy_obs)
        dummy_actions = utils.zeros_like(spec.actions)
        dummy_actions = utils.add_batch_dim(dummy_actions)
        discriminator_network = networks_lib.FeedForwardNetwork(init=init_fn,
                                                                apply=apply_fn)

        networks = ail.AILNetworks(discriminator_network, lambda x: x,
                                   networks)

        builder = ail.AILBuilder(
            base_builder,
            config=ail.AILConfig(is_sequence_based=False,
                                 share_iterator=True,
                                 direct_rl_batch_size=direct_rl_batch_size,
                                 discriminator_batch_size=2,
                                 policy_variable_name=None,
                                 min_replay_size=1),
            discriminator_loss=ail.losses.gail_loss(),
            make_demonstrations=fakes.transition_iterator(environment))

        counter = counting.Counter()
        # Construct the agent.
        agent = local_layout.LocalLayout(
            seed=0,
            environment_spec=spec,
            builder=builder,
            networks=networks,
            policy_network=behavior_policy,
            min_replay_size=1,
            batch_size=batch_size,
            counter=counter,
        )

        # Train the agent.
        train_loop = acme.EnvironmentLoop(environment, agent, counter=counter)
        train_loop.run(num_episodes=1)
Example #10
    def test_ail(self,
                 algo,
                 airl_discriminator=False,
                 subtract_logpi=False,
                 dropout=0.,
                 lipschitz_coeff=None):
        shutil.rmtree(flags.FLAGS.test_tmpdir, ignore_errors=True)
        batch_size = 8
        # Mujoco environment and associated demonstration dataset.
        if algo == 'ppo':
            environment = fakes.DiscreteEnvironment(
                num_actions=NUM_DISCRETE_ACTIONS,
                num_observations=NUM_OBSERVATIONS,
                obs_shape=OBS_SHAPE,
                obs_dtype=OBS_DTYPE,
                episode_length=EPISODE_LENGTH)
        else:
            environment = fakes.ContinuousEnvironment(
                episode_length=EPISODE_LENGTH,
                action_dim=CONTINUOUS_ACTION_DIM,
                observation_dim=CONTINUOUS_OBS_DIM,
                bounded=True)
        spec = specs.make_environment_spec(environment)

        if algo == 'sac':
            networks = sac.make_networks(spec=spec)
            config = sac.SACConfig(
                batch_size=batch_size,
                samples_per_insert_tolerance_rate=float('inf'),
                min_replay_size=1)
            base_builder = sac.SACBuilder(config=config)
            direct_rl_batch_size = batch_size
            behavior_policy = sac.apply_policy_and_sample(networks)
        elif algo == 'ppo':
            unroll_length = 5
            distribution_value_networks = make_ppo_networks(spec)
            networks = ppo.make_ppo_networks(distribution_value_networks)
            config = ppo.PPOConfig(unroll_length=unroll_length,
                                   num_minibatches=2,
                                   num_epochs=4,
                                   batch_size=batch_size)
            base_builder = ppo.PPOBuilder(config=config)
            direct_rl_batch_size = batch_size * unroll_length
            behavior_policy = jax.jit(ppo.make_inference_fn(networks),
                                      backend='cpu')
        else:
            raise ValueError(f'Unexpected algorithm {algo}')

        if subtract_logpi:
            assert algo == 'sac'
            logpi_fn = make_sac_logpi(networks)
        else:
            logpi_fn = None

        if algo == 'ppo':
            embedding = lambda x: jnp.reshape(x, list(x.shape[:-2]) + [-1])
        else:
            embedding = lambda x: x

        def discriminator(*args, **kwargs) -> networks_lib.Logits:
            if airl_discriminator:
                return ail.AIRLModule(
                    environment_spec=spec,
                    use_action=True,
                    use_next_obs=True,
                    discount=.99,
                    g_core=ail.DiscriminatorMLP(
                        [4, 4],
                        hidden_dropout_rate=dropout,
                        spectral_normalization_lipschitz_coeff=lipschitz_coeff
                    ),
                    h_core=ail.DiscriminatorMLP(
                        [4, 4],
                        hidden_dropout_rate=dropout,
                        spectral_normalization_lipschitz_coeff=lipschitz_coeff
                    ),
                    observation_embedding=embedding)(*args, **kwargs)
            else:
                return ail.DiscriminatorModule(
                    environment_spec=spec,
                    use_action=True,
                    use_next_obs=True,
                    network_core=ail.DiscriminatorMLP(
                        [4, 4],
                        hidden_dropout_rate=dropout,
                        spectral_normalization_lipschitz_coeff=lipschitz_coeff
                    ),
                    observation_embedding=embedding)(*args, **kwargs)

        discriminator_transformed = hk.without_apply_rng(
            hk.transform_with_state(discriminator))

        discriminator_network = ail.make_discriminator(
            environment_spec=spec,
            discriminator_transformed=discriminator_transformed,
            logpi_fn=logpi_fn)

        networks = ail.AILNetworks(discriminator_network, lambda x: x,
                                   networks)

        builder = ail.AILBuilder(
            base_builder,
            config=ail.AILConfig(
                is_sequence_based=(algo == 'ppo'),
                share_iterator=True,
                direct_rl_batch_size=direct_rl_batch_size,
                discriminator_batch_size=2,
                policy_variable_name='policy' if subtract_logpi else None,
                min_replay_size=1),
            discriminator_loss=ail.losses.gail_loss(),
            make_demonstrations=fakes.transition_iterator(environment))

        # Construct the agent.
        agent = local_layout.LocalLayout(seed=0,
                                         environment_spec=spec,
                                         builder=builder,
                                         networks=networks,
                                         policy_network=behavior_policy,
                                         min_replay_size=1,
                                         batch_size=batch_size)

        # Train the agent.
        train_loop = acme.EnvironmentLoop(environment, agent)
        train_loop.run(num_episodes=(10 if algo == 'ppo' else 1))
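The algo, airl_discriminator, subtract_logpi, dropout and lipschitz_coeff arguments suggest this method is driven by absl's parameterized test decorator; a hypothetical declaration of the surrounding class with a few cases:

# Hypothetical parameterization of test_ail above; the case names and values
# are assumptions chosen to exercise the branches in the method body.
from absl.testing import parameterized


class AILTest(parameterized.TestCase):

    @parameterized.named_parameters(
        ('sac', 'sac'),
        ('ppo', 'ppo'),
        ('sac_airl', 'sac', True),
        ('sac_dropout', 'sac', False, False, 0.1),
        ('sac_spectral_norm', 'sac', False, False, 0., 1.0),
    )
    def test_ail(self, algo, airl_discriminator=False, subtract_logpi=False,
                 dropout=0., lipschitz_coeff=None):
        ...  # Body as shown above.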