Example #1
File: run_d4pg.py  Project: pchtsp/acme
def main(_):
  # Create an environment, grab the spec, and use it to create networks.
  environment = make_environment()
  environment_spec = specs.make_environment_spec(environment)
  agent_networks = make_networks(environment_spec.actions)

  # Construct the agent.
  agent = d4pg.D4PG(
      environment_spec=environment_spec,
      policy_network=agent_networks['policy'],
      critic_network=agent_networks['critic'],
      observation_network=agent_networks['observation'],  # pytype: disable=wrong-arg-types
  )

  # Create the environment loop used for training.
  train_loop = acme.EnvironmentLoop(environment, agent, label='train_loop')

  # Create the evaluation policy.
  eval_policy = snt.Sequential([
      agent_networks['observation'],
      agent_networks['policy'],
  ])

  # Create the evaluation actor and loop.
  eval_actor = actors.FeedForwardActor(policy_network=eval_policy)
  eval_env = make_environment()
  eval_loop = acme.EnvironmentLoop(eval_env, eval_actor, label='eval_loop')

  for _ in range(FLAGS.num_episodes // FLAGS.num_episodes_per_eval):
    train_loop.run(num_episodes=FLAGS.num_episodes_per_eval)
    eval_loop.run(num_episodes=1)
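
make_environment and make_networks are helpers defined elsewhere in run_d4pg.py. As a reference point, here is a minimal sketch of what make_networks typically looks like for D4PG in Acme's control examples; the layer sizes and the categorical critic's vmin/vmax/num_atoms are illustrative defaults, not necessarily those used in this project.

import numpy as np
import sonnet as snt
from acme.tf import networks
from acme.tf import utils as tf2_utils


def make_networks(action_spec,
                  policy_layer_sizes=(256, 256, 256),
                  critic_layer_sizes=(512, 512, 256),
                  vmin=-150., vmax=150., num_atoms=51):
  """Builds the policy/critic/observation networks used by D4PG."""
  num_dimensions = np.prod(action_spec.shape, dtype=int)

  policy_network = snt.Sequential([
      networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
      networks.NearZeroInitializedLinear(num_dimensions),
      networks.TanhToSpec(action_spec),
  ])
  critic_network = snt.Sequential([
      # The multiplexer concatenates observations and actions.
      networks.CriticMultiplexer(),
      networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
      networks.DiscreteValuedHead(vmin=vmin, vmax=vmax, num_atoms=num_atoms),
  ])

  return {
      'policy': policy_network,
      'critic': critic_network,
      # Flatten and concatenate (possibly nested) observations into one vector.
      'observation': tf2_utils.batch_concat,
  }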
Example #2
def main(_):
    # Load environment.
    environment = utils.load_environment(FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)

    # Create Rewarder.
    demonstrations = utils.load_demonstrations(demo_dir=FLAGS.demo_dir,
                                               env_name=FLAGS.env_name)
    pwil_rewarder = rewarder.PWILRewarder(
        demonstrations,
        subsampling=FLAGS.subsampling,
        env_specs=environment_spec,
        num_demonstrations=FLAGS.num_demonstrations,
        observation_only=FLAGS.state_only)

    # Define D4PG agent.
    agent_networks = utils.make_d4pg_networks(environment_spec.actions)
    agent = d4pg.D4PG(
        environment_spec=environment_spec,
        policy_network=agent_networks['policy'],
        critic_network=agent_networks['critic'],
        observation_network=agent_networks['observation'],
        samples_per_insert=FLAGS.samples_per_insert,
        sigma=FLAGS.sigma,
    )

    # Prefill the agent's Replay Buffer.
    utils.prefill_rb_with_demonstrations(
        agent=agent,
        demonstrations=pwil_rewarder.demonstrations,
        num_transitions_rb=FLAGS.num_transitions_rb,
        reward=pwil_rewarder.reward_scale)

    # Create the eval policy (without exploration noise).
    eval_policy = snt.Sequential([
        agent_networks['observation'],
        agent_networks['policy'],
    ])
    eval_agent = FeedForwardActor(policy_network=eval_policy)

    # Define train/eval loops.
    logger = csv_logger.CSVLogger(directory=FLAGS.workdir, label='train_logs')
    train_loop = imitation_loop.TrainEnvironmentLoop(environment,
                                                     agent,
                                                     pwil_rewarder,
                                                     logger=logger)

    eval_logger = csv_logger.CSVLogger(directory=FLAGS.workdir,
                                       label='eval_logs')
    eval_loop = imitation_loop.EvalEnvironmentLoop(environment,
                                                   eval_agent,
                                                   pwil_rewarder,
                                                   logger=eval_logger)

    for _ in range(FLAGS.num_iterations):
        train_loop.run(num_steps=FLAGS.num_steps_per_iteration)
        eval_loop.run(num_episodes=FLAGS.num_eval_episodes)
Example #3
def main(_):
    # Initialize Neptune and create an experiment.
    neptune.init(FLAGS.neptune_project_name)
    experiment = neptune.create_experiment(name='Acme example')

    # Create an environment, grab the spec, and use it to create networks.
    environment = make_environment()
    environment_spec = specs.make_environment_spec(environment)
    agent_networks = make_networks(environment_spec.actions)

    # Construct the agent.
    agent = d4pg.D4PG(
        environment_spec=environment_spec,
        policy_network=agent_networks['policy'],
        critic_network=agent_networks['critic'],
        observation_network=agent_networks['observation'],
        sigma=1.0,  # pytype: disable=wrong-arg-types
        logger=make_logger(experiment, prefix='learner'),
    )

    # Create the environment loop used for training.
    train_loop = acme.EnvironmentLoop(environment,
                                      agent,
                                      label='train_loop',
                                      logger=make_logger(
                                          experiment,
                                          prefix='train',
                                          smoothing_regex='return'))

    # Create the evaluation policy.
    eval_policy = snt.Sequential([
        agent_networks['observation'],
        agent_networks['policy'],
    ])

    # Create the evaluation actor and loop.
    eval_actor = actors.FeedForwardActor(policy_network=eval_policy)
    eval_env = make_environment()
    eval_logger = make_logger(experiment,
                              prefix='eval',
                              aggregate_regex='return')
    eval_loop = acme.EnvironmentLoop(
        eval_env,
        eval_actor,
        label='eval_loop',
        logger=eval_logger,
    )

    for _ in range(FLAGS.num_episodes // FLAGS.num_episodes_per_eval):
        train_loop.run(num_episodes=FLAGS.num_episodes_per_eval)
        eval_loop.run(num_episodes=5)
        eval_logger.dump()
Example #4
    def test_d4pg(self):
        # Create a fake environment to test with.
        environment = fakes.ContinuousEnvironment(episode_length=10,
                                                  bounded=True)
        spec = specs.make_environment_spec(environment)

        # Create the networks.
        agent_networks = make_networks(spec.actions)

        # Construct the agent.
        agent = d4pg.D4PG(
            environment_spec=spec,
            policy_network=agent_networks['policy'],
            critic_network=agent_networks['critic'],
            batch_size=10,
            samples_per_insert=2,
            min_replay_size=10,
        )

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent)
        loop.run(num_episodes=2)
Example #5
critic_network = snt.Sequential([
    # The multiplexer concatenates the observations/actions.
    networks.CriticMultiplexer(),
    networks.LayerNormMLP((512, 512, 256), activate_final=True),
    networks.DiscreteValuedHead(vmin=-150., vmax=150., num_atoms=51),
])

# Create a logger for the agent and environment loop.
agent_logger = loggers.TerminalLogger(label='agent', time_delta=10.)
env_loop_logger = loggers.TerminalLogger(label='env_loop', time_delta=10.)

# Create the D4PG agent.
agent = d4pg.D4PG(environment_spec=environment_spec,
                  policy_network=policy_network,
                  critic_network=critic_network,
                  observation_network=observation_network,
                  sigma=1.0,
                  logger=agent_logger,
                  checkpoint=False)

# Create a loop connecting this agent to the environment created above.
env_loop = environment_loop.EnvironmentLoop(environment,
                                            agent,
                                            logger=env_loop_logger)

# Run `num_episodes` training episodes.
# Rerun this cell until the agent has learned the given task.
env_loop.run(num_episodes=5000)
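
Once the loop above has trained the agent, the learned policy can be evaluated without exploration noise by wrapping the observation and policy networks in a feed-forward actor, following the same pattern as the earlier examples. This is a sketch, assuming observation_network, policy_network, environment, and env_loop_logger are the objects built above:

from acme.agents.tf import actors

# Evaluate the learned policy without exploration noise.
eval_policy = snt.Sequential([observation_network, policy_network])
eval_actor = actors.FeedForwardActor(policy_network=eval_policy)
eval_loop = environment_loop.EnvironmentLoop(environment,
                                             eval_actor,
                                             logger=env_loop_logger)
eval_loop.run(num_episodes=10)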


Example #6
def make_acme_agent(environment_spec,
                    residual_spec,
                    obs_network_type,
                    crop_frames,
                    full_image_size,
                    crop_margin_size,
                    late_fusion,
                    binary_grip_action=False,
                    input_type=None,
                    counter=None,
                    logdir=None,
                    agent_logger=None):
    """Initialize acme agent based on residual spec and agent flags."""
    # TODO(minttu): Is environment_spec needed or could we use residual_spec?
    del logdir  # Setting logdir for the learner ckpts not currently supported.
    obs_network = None
    if obs_network_type is not None:
        obs_network = agents.ObservationNet(network_type=obs_network_type,
                                            input_type=input_type,
                                            add_linear_layer=False,
                                            crop_frames=crop_frames,
                                            full_image_size=full_image_size,
                                            crop_margin_size=crop_margin_size,
                                            late_fusion=late_fusion)

    eval_policy = None
    if FLAGS.agent == 'MPO':
        agent_networks = networks.make_mpo_networks(
            environment_spec.actions,
            policy_init_std=FLAGS.policy_init_std,
            obs_network=obs_network)

        rl_agent = mpo.MPO(
            environment_spec=residual_spec,
            policy_network=agent_networks['policy'],
            critic_network=agent_networks['critic'],
            observation_network=agent_networks['observation'],
            discount=FLAGS.discount,
            batch_size=FLAGS.rl_batch_size,
            min_replay_size=FLAGS.min_replay_size,
            max_replay_size=FLAGS.max_replay_size,
            policy_optimizer=snt.optimizers.Adam(FLAGS.policy_lr),
            critic_optimizer=snt.optimizers.Adam(FLAGS.critic_lr),
            counter=counter,
            logger=agent_logger,
            checkpoint=FLAGS.write_acme_checkpoints,
        )
    elif FLAGS.agent == 'DMPO':
        agent_networks = networks.make_dmpo_networks(
            environment_spec.actions,
            policy_layer_sizes=FLAGS.rl_policy_layer_sizes,
            critic_layer_sizes=FLAGS.rl_critic_layer_sizes,
            vmin=FLAGS.critic_vmin,
            vmax=FLAGS.critic_vmax,
            num_atoms=FLAGS.critic_num_atoms,
            policy_init_std=FLAGS.policy_init_std,
            binary_grip_action=binary_grip_action,
            obs_network=obs_network)

        # spec = residual_spec if obs_network is None else environment_spec
        spec = residual_spec
        rl_agent = dmpo.DistributionalMPO(
            environment_spec=spec,
            policy_network=agent_networks['policy'],
            critic_network=agent_networks['critic'],
            observation_network=agent_networks['observation'],
            discount=FLAGS.discount,
            batch_size=FLAGS.rl_batch_size,
            min_replay_size=FLAGS.min_replay_size,
            max_replay_size=FLAGS.max_replay_size,
            policy_optimizer=snt.optimizers.Adam(FLAGS.policy_lr),
            critic_optimizer=snt.optimizers.Adam(FLAGS.critic_lr),
            counter=counter,
            # logdir=logdir,
            logger=agent_logger,
            checkpoint=FLAGS.write_acme_checkpoints,
        )
        # Learned policy without exploration.
        eval_policy = (tf.function(
            snt.Sequential([
                tf_utils.to_sonnet_module(agent_networks['observation']),
                agent_networks['policy'],
                tf_networks.StochasticMeanHead()
            ])))
    elif FLAGS.agent == 'D4PG':
        agent_networks = networks.make_d4pg_networks(
            residual_spec.actions,
            vmin=FLAGS.critic_vmin,
            vmax=FLAGS.critic_vmax,
            num_atoms=FLAGS.critic_num_atoms,
            policy_weights_init_scale=FLAGS.policy_weights_init_scale,
            obs_network=obs_network)

        # TODO(minttu): downscale action space to [-1, 1] to match clipped gaussian.
        rl_agent = d4pg.D4PG(
            environment_spec=residual_spec,
            policy_network=agent_networks['policy'],
            critic_network=agent_networks['critic'],
            observation_network=agent_networks['observation'],
            discount=FLAGS.discount,
            batch_size=FLAGS.rl_batch_size,
            min_replay_size=FLAGS.min_replay_size,
            max_replay_size=FLAGS.max_replay_size,
            policy_optimizer=snt.optimizers.Adam(FLAGS.policy_lr),
            critic_optimizer=snt.optimizers.Adam(FLAGS.critic_lr),
            sigma=FLAGS.policy_init_std,
            counter=counter,
            logger=agent_logger,
            checkpoint=FLAGS.write_acme_checkpoints,
        )

        # Learned policy without exploration.
        eval_policy = tf.function(
            snt.Sequential([
                tf_utils.to_sonnet_module(agent_networks['observation']),
                agent_networks['policy']
            ]))

    else:
        raise NotImplementedError('Supported agents: MPO, DMPO, D4PG.')
    return rl_agent, eval_policy
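
Below is a hypothetical call site showing how the returned agent and evaluation policy might be wired into Acme environment loops; the argument values and episode counts are purely illustrative and not taken from the original project.

import acme
from acme.agents.tf import actors

# Build the agent from specs prepared earlier (values are illustrative).
rl_agent, eval_policy = make_acme_agent(
    environment_spec=environment_spec,
    residual_spec=residual_spec,
    obs_network_type=None,
    crop_frames=False,
    full_image_size=None,
    crop_margin_size=None,
    late_fusion=False)

# Train with the residual agent.
train_loop = acme.EnvironmentLoop(environment, rl_agent, label='train_loop')
train_loop.run(num_episodes=1000)

# eval_policy is only returned for the DMPO and D4PG branches.
if eval_policy is not None:
    eval_actor = actors.FeedForwardActor(policy_network=eval_policy)
    eval_loop = acme.EnvironmentLoop(environment, eval_actor, label='eval_loop')
    eval_loop.run(num_episodes=10)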