Example #1
 def testGinConfig(self):
     gin.parse_config_file(
         test_utils.test_src_dir_path(
             'environments/configs/suite_mujoco.gin'))
     env = suite_mujoco.load()
     self.assertIsInstance(env, py_environment.Base)
     self.assertIsInstance(env, wrappers.TimeLimit)
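The test above presumes the TF-Agents test utilities and a gin config file shipped with the package. As a hedged, standalone sketch (assuming `tf-agents` and the Gym MuJoCo environments are installed), the suite can also be loaded directly and wrapped for TF-based agents:

from tf_agents.environments import suite_mujoco
from tf_agents.environments import tf_py_environment

# Load the Python environment and wrap it so agents can consume TF tensors.
py_env = suite_mujoco.load('HalfCheetah-v2')
tf_env = tf_py_environment.TFPyEnvironment(py_env)
print(tf_env.time_step_spec())
print(tf_env.action_spec())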
Example #2
def create_env(env_name):
    """Creates Environment."""
    if env_name == 'Pendulum':
        env = gym.make('Pendulum-v0')
    elif env_name == 'Hopper':
        env = suite_mujoco.load('Hopper-v2')
    elif env_name == 'Walker2D':
        env = suite_mujoco.load('Walker2d-v2')
    elif env_name == 'HalfCheetah':
        env = suite_mujoco.load('HalfCheetah-v2')
    elif env_name == 'Ant':
        env = suite_mujoco.load('Ant-v2')
    elif env_name == 'Humanoid':
        env = suite_mujoco.load('Humanoid-v2')
    else:
        raise ValueError('Unsupported environment: %s' % env_name)
    return env
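Note that `create_env` returns a plain Gym environment for 'Pendulum' but a TF-Agents `PyEnvironment` for the MuJoCo names. A short usage sketch for the MuJoCo branch (the wrapping step is an assumption, not part of the original snippet):

import gym
from tf_agents.environments import suite_mujoco
from tf_agents.environments import tf_py_environment

env = create_env('HalfCheetah')                  # TF-Agents PyEnvironment
tf_env = tf_py_environment.TFPyEnvironment(env)  # batched TF wrapper for training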
Example #3
def main(_):
  logging.set_verbosity(logging.INFO)
  tf.enable_v2_behavior()

  collect(
      FLAGS.task,
      FLAGS.root_dir,
      replay_buffer_server_address=FLAGS.replay_buffer_server_address,
      variable_container_server_address=FLAGS.variable_container_server_address,
      create_env_fn=lambda: suite_mujoco.load('HalfCheetah-v2'))
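This `main` relies on absl flags defined elsewhere in the module. A hedged sketch of the flag definitions the call site appears to assume (names are taken from the call; defaults and help strings are illustrative):

from absl import flags

flags.DEFINE_string('root_dir', None,
                    'Root directory for checkpoints and summaries.')
flags.DEFINE_integer('task', 0, 'Task ID of this collect worker.')
flags.DEFINE_string('replay_buffer_server_address', None,
                    'Address of the Reverb replay buffer server.')
flags.DEFINE_string('variable_container_server_address', None,
                    'Address of the server holding policy variables.')

FLAGS = flags.FLAGS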
Example #4
def main(_):
  logging.set_verbosity(logging.INFO)
  tf.enable_v2_behavior()

  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings)

  strategy = strategy_utils.get_strategy(FLAGS.tpu, FLAGS.use_gpu)

  train(
      FLAGS.root_dir,
      strategy,
      replay_buffer_server_address=FLAGS.replay_buffer_server_address,
      variable_container_server_address=FLAGS.variable_container_server_address,
      create_agent_fn=_create_agent,
      create_env_fn=lambda: suite_mujoco.load('HalfCheetah-v2'),
      num_iterations=FLAGS.num_iterations,
  )
Example #5
def main():

    # environment
    eval_env = tf_py_environment.TFPyEnvironment(
        suite_mujoco.load('HalfCheetah-v2'))
    # deserialize saved policy
    saved_policy = tf.compat.v2.saved_model.load('checkpoints/policy_9500/')
    # apply_policy and visualize
    total_return = 0.0
    for _ in range(10):
        episode_return = 0.0
        status = eval_env.reset()
        policy_state = saved_policy.get_initial_state(eval_env.batch_size)
        while not status.is_last():
            action = saved_policy.action(status, policy_state)
            status = eval_env.step(action.action)
            policy_state = action.state
            cv2.imshow('halfcheetah', eval_env.pyenv.envs[0].render())
            cv2.waitKey(25)
            episode_return += status.reward
        total_return += episode_return
    avg_return = total_return / 10
    print("average return is %f" % avg_return)
Example #6
def main(_):
    tf.enable_v2_behavior()
    tf.random.set_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    random.seed(FLAGS.seed)

    base_env = suite_mujoco.load(FLAGS.env_name)
    if hasattr(base_env, 'max_episode_steps'):
        max_episode_steps = base_env.max_episode_steps
    else:
        logging.info('Unknown max episode steps. Setting to 1000.')
        max_episode_steps = 1000
    env = base_env.gym
    env = wrappers.check_and_normalize_box_actions(env)
    env.seed(FLAGS.seed)

    eval_env = suite_mujoco.load(FLAGS.env_name).gym
    eval_env = wrappers.check_and_normalize_box_actions(eval_env)
    eval_env.seed(FLAGS.seed + 1)

    spec = (
        tensor_spec.TensorSpec([env.observation_space.shape[0]], tf.float32,
                               'observation'),
        tensor_spec.TensorSpec([env.action_space.shape[0]], tf.float32,
                               'action'),
        tensor_spec.TensorSpec([env.observation_space.shape[0]], tf.float32,
                               'next_observation'),
        tensor_spec.TensorSpec([1], tf.float32, 'reward'),
        tensor_spec.TensorSpec([1], tf.float32, 'mask'),
    )
    init_spec = tensor_spec.TensorSpec([env.observation_space.shape[0]],
                                       tf.float32, 'observation')

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        spec, batch_size=1, max_length=FLAGS.max_timesteps)
    init_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        init_spec, batch_size=1, max_length=FLAGS.max_timesteps)

    hparam_str_dict = dict(seed=FLAGS.seed, env=FLAGS.env_name)
    hparam_str = ','.join([
        '%s=%s' % (k, str(hparam_str_dict[k]))
        for k in sorted(hparam_str_dict.keys())
    ])
    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

    rl_algo = algae.ALGAE(env.observation_space.shape[0],
                          env.action_space.shape[0],
                          FLAGS.log_interval,
                          critic_lr=FLAGS.critic_lr,
                          actor_lr=FLAGS.actor_lr,
                          use_dqn=FLAGS.use_dqn,
                          use_init_states=FLAGS.use_init_states,
                          algae_alpha=FLAGS.algae_alpha,
                          exponent=FLAGS.f_exponent)

    episode_return = 0
    episode_timesteps = 0
    done = True

    total_timesteps = 0
    previous_time = time.time()

    replay_buffer_iter = iter(
        replay_buffer.as_dataset(sample_batch_size=FLAGS.sample_batch_size))
    init_replay_buffer_iter = iter(
        init_replay_buffer.as_dataset(
            sample_batch_size=FLAGS.sample_batch_size))

    log_dir = os.path.join(FLAGS.save_dir, 'logs')
    log_filename = os.path.join(log_dir, hparam_str)
    if not gfile.isdir(log_dir):
        gfile.mkdir(log_dir)

    eval_returns = []

    with tqdm(total=FLAGS.max_timesteps, desc='') as pbar:
        # Final return is the average of the last 10 measurements.
        final_returns = collections.deque(maxlen=10)
        final_timesteps = 0
        while total_timesteps < FLAGS.max_timesteps:
            _update_pbar_msg(pbar, total_timesteps)
            if done:

                if episode_timesteps > 0:
                    current_time = time.time()

                    train_measurements = [
                        ('train/returns', episode_return),
                        ('train/FPS',
                         episode_timesteps / (current_time - previous_time)),
                    ]
                    _write_measurements(summary_writer, train_measurements,
                                        total_timesteps)
                obs = env.reset()
                episode_return = 0
                episode_timesteps = 0
                previous_time = time.time()

                init_replay_buffer.add_batch(np.array([obs.astype(np.float32)
                                                       ]))

            if total_timesteps < FLAGS.num_random_actions:
                action = env.action_space.sample()
            else:
                _, action, _ = rl_algo.actor(np.array([obs]))
                action = action[0].numpy()

            if total_timesteps >= FLAGS.start_training_timesteps:
                with summary_writer.as_default():
                    target_entropy = (-env.action_space.shape[0]
                                      if FLAGS.target_entropy is None else
                                      FLAGS.target_entropy)
                    for _ in range(FLAGS.num_updates_per_env_step):
                        rl_algo.train(
                            replay_buffer_iter,
                            init_replay_buffer_iter,
                            discount=FLAGS.discount,
                            tau=FLAGS.tau,
                            target_entropy=target_entropy,
                            actor_update_freq=FLAGS.actor_update_freq)

            next_obs, reward, done, _ = env.step(action)
            if (max_episode_steps is not None
                    and episode_timesteps + 1 == max_episode_steps):
                done = True

            if not done or episode_timesteps + 1 == max_episode_steps:  # pylint: disable=protected-access
                mask = 1.0
            else:
                mask = 0.0

            replay_buffer.add_batch((np.array([obs.astype(np.float32)]),
                                     np.array([action.astype(np.float32)]),
                                     np.array([next_obs.astype(np.float32)]),
                                     np.array([[reward]]).astype(np.float32),
                                     np.array([[mask]]).astype(np.float32)))

            episode_return += reward
            episode_timesteps += 1
            total_timesteps += 1
            pbar.update(1)

            obs = next_obs

            if total_timesteps % FLAGS.eval_interval == 0:
                logging.info('Performing policy eval.')
                average_returns, evaluation_timesteps = rl_algo.evaluate(
                    eval_env, max_episode_steps=max_episode_steps)

                eval_returns.append(average_returns)
                fin = gfile.GFile(log_filename, 'w')
                np.save(fin, np.array(eval_returns))
                fin.close()

                eval_measurements = [
                    ('eval/average returns', average_returns),
                    ('eval/average episode length', evaluation_timesteps),
                ]
                # TODO(sandrafaust) Make this average of the last N.
                final_returns.append(average_returns)
                final_timesteps = evaluation_timesteps

                _write_measurements(summary_writer, eval_measurements,
                                    total_timesteps)

                logging.info('Eval: ave returns=%f, ave episode length=%f',
                             average_returns, evaluation_timesteps)
        # Final measurement.
        final_measurements = [
            ('final/average returns', sum(final_returns) / len(final_returns)),
            ('final/average episode length', final_timesteps),
        ]
        _write_measurements(summary_writer, final_measurements,
                            total_timesteps)
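The training loop calls two helpers, `_update_pbar_msg` and `_write_measurements`, that are defined elsewhere in the original module. A hedged sketch of what `_write_measurements` plausibly does (the signature is inferred from the call sites; the body is an assumption):

import tensorflow as tf

def _write_measurements(summary_writer, labels_and_values, step):
    """Writes each (label, value) pair as a scalar summary at the given step."""
    with summary_writer.as_default():
        for label, value in labels_and_values:
            tf.summary.scalar(label, value, step=step)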
Example #7
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        num_iterations=1000000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Params for collect
        initial_collect_steps=10000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for SAC."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]
    eval_summary_flush_op = eval_summary_writer.flush()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Create the environment.
        tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))
        eval_py_env = suite_mujoco.load(env_name)

        # Get the data specs from the environment
        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=normal_projection_net)
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers)

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        eval_py_policy = py_tf_policy.PyTFPolicy(
            greedy_policy.GreedyPolicy(tf_agent.policy))

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()),
            tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()),
        ]

        collect_policy = tf_agent.collect_policy
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        initial_collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps).run()

        collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration).run()

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            return ~trajectories.is_boundary()[0]

        dataset = replay_buffer.as_dataset(
            sample_batch_size=5 * batch_size,
            num_steps=2).apply(tf.data.experimental.unbatch()).filter(
                _filter_invalid_transition).batch(batch_size).prefetch(
                    batch_size * 5)
        dataset_iterator = tf.compat.v1.data.make_initializable_iterator(
            dataset)
        trajectories, unused_info = dataset_iterator.get_next()
        train_op = tf_agent.train(trajectories)

        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2]))

        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=tf_agent.policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        with tf.compat.v1.Session() as sess:
            # Initialize graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)

            # Initialize training.
            sess.run(dataset_iterator.initializer)
            common.initialize_uninitialized_variables(sess)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            global_step_val = sess.run(global_step)

            if global_step_val == 0:
                # Initial eval of randomly initialized policy
                metric_utils.compute_summaries(
                    eval_metrics,
                    eval_py_env,
                    eval_py_policy,
                    num_episodes=num_eval_episodes,
                    global_step=global_step_val,
                    callback=eval_metrics_callback,
                    log=True,
                )
                sess.run(eval_summary_flush_op)

                # Run initial collect.
                logging.info('Global step %d: Running initial collect op.',
                             global_step_val)
                sess.run(initial_collect_op)

                # Checkpoint the initial replay buffer contents.
                rb_checkpointer.save(global_step=global_step_val)

                logging.info('Finished initial collect.')
            else:
                logging.info('Global step %d: Skipping initial collect op.',
                             global_step_val)

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable([train_op, summary_ops])
            global_step_call = sess.make_callable(global_step)

            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step)

            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                for _ in range(train_steps_per_iteration):
                    total_loss, _ = train_step_call()
                time_acc += time.time() - start_time
                global_step_val = global_step_call()
                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 total_loss.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        log=True,
                    )
                    sess.run(eval_summary_flush_op)

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)
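This SAC `train_eval` passes a `normal_projection_net` callable to the actor network, but its definition is not shown. A hedged sketch of one possible definition (a squashed-Normal projection; the exact keyword arguments here are assumptions):

from tf_agents.networks import normal_projection_network

def normal_projection_net(action_spec, init_means_output_factor=0.1):
    # Project the actor trunk's features onto a Normal distribution whose
    # samples are rescaled into the bounds of the action spec.
    return normal_projection_network.NormalProjectionNetwork(
        action_spec,
        mean_transform=None,
        state_dependent_std=True,
        init_means_output_factor=init_means_output_factor,
        scale_distribution=True)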
Example #8
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        num_iterations=2000000,
        actor_fc_layers=(400, 300),
        critic_obs_fc_layers=(400, ),
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300, ),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=100000,
        exploration_noise_std=0.1,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        actor_update_period=2,
        actor_learning_rate=1e-4,
        critic_learning_rate=1e-3,
        dqda_clipping=None,
        td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
        gamma=0.995,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=10000,
        # Params for checkpoints, summaries, and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for TD3."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_mujoco.load(env_name))

        actor_net = actor_network.ActorNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            fc_layer_params=actor_fc_layers,
        )

        critic_net_input_specs = (tf_env.time_step_spec().observation,
                                  tf_env.action_spec())

        critic_net = critic_network.CriticNetwork(
            critic_net_input_specs,
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
        )

        tf_agent = td3_agent.Td3Agent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            exploration_noise_std=exploration_noise_std,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            actor_update_period=actor_update_period,
            dqda_clipping=dqda_clipping,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step,
        )
        tf_agent.initialize()

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch],
            num_steps=initial_collect_steps)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)
        initial_collect_driver.run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=2).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss.loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)

        return train_loss
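A minimal usage sketch for this eager-mode TD3 `train_eval` (the directory, iteration count, and intervals are illustrative values, not ones from the original):

# Assumes TF2 eager execution (the default in TF >= 2.0).
final_loss = train_eval('/tmp/td3_halfcheetah',
                        num_iterations=10000,
                        eval_interval=2000,
                        log_interval=500)
print('final loss:', final_loss.loss.numpy())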
Example #9
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    num_iterations=2000000,
    actor_fc_layers=(400, 300),
    critic_obs_fc_layers=(400,),
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(300,),
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=100000,
    exploration_noise_std=0.1,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    actor_update_period=2,
    actor_learning_rate=1e-4,
    critic_learning_rate=1e-3,
    dqda_clipping=None,
    td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
    gamma=0.995,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=10000,
    # Params for checkpoints, summaries, and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=20000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for TD3."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))
    eval_py_env = suite_mujoco.load(env_name)

    actor_net = actor_network.ActorNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        fc_layer_params=actor_fc_layers,
    )

    critic_net_input_specs = (tf_env.time_step_spec().observation,
                              tf_env.action_spec())

    critic_net = critic_network.CriticNetwork(
        critic_net_input_specs,
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
    )

    tf_agent = td3_agent.Td3Agent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        exploration_noise_std=exploration_noise_std,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        actor_update_period=actor_update_period,
        dqda_clipping=dqda_clipping,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step,
    )

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    collect_policy = tf_agent.collect_policy
    initial_collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=initial_collect_steps).run()

    collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=collect_steps_per_iteration).run()

    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=2).prefetch(3)
    iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
    trajectories, unused_info = iterator.get_next()

    train_fn = common.function(tf_agent.train)
    train_op = train_fn(experience=trajectories)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    summary_ops = []
    for train_metric in train_metrics:
      summary_ops.append(train_metric.tf_summaries(
          train_step=global_step, step_metrics=train_metrics[:2]))

    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries(train_step=global_step)

    init_agent_op = tf_agent.initialize()

    with tf.compat.v1.Session() as sess:
      # Initialize the graph.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)
      sess.run(iterator.initializer)
      # TODO(b/126239733): Remove once Periodically can be saved.
      common.initialize_uninitialized_variables(sess)

      sess.run(init_agent_op)
      sess.run(train_summary_writer.init())
      sess.run(eval_summary_writer.init())
      sess.run(initial_collect_op)

      global_step_val = sess.run(global_step)
      metric_utils.compute_summaries(
          eval_metrics,
          eval_py_env,
          eval_py_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
          log=True,
      )

      collect_call = sess.make_callable(collect_op)
      train_step_call = sess.make_callable([train_op, summary_ops, global_step])

      timed_at_step = sess.run(global_step)
      time_acc = 0
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.compat.v2.summary.scalar(
          name='global_steps_per_sec', data=steps_per_second_ph,
          step=global_step)

      for _ in range(num_iterations):
        start_time = time.time()
        collect_call()
        for _ in range(train_steps_per_iteration):
          loss_info_value, _, global_step_val = train_step_call()
        time_acc += time.time() - start_time

        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val,
                       loss_info_value.loss)
          steps_per_sec = (global_step_val - timed_at_step) / time_acc
          logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          timed_at_step = global_step_val
          time_acc = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              callback=eval_metrics_callback,
              log=True,
          )
Example #10
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    # Training params
    num_iterations=1600,
    actor_fc_layers=(64, 64),
    value_fc_layers=(64, 64),
    learning_rate=3e-4,
    collect_sequence_length=2048,
    minibatch_size=64,
    num_epochs=10,
    # Agent params
    importance_ratio_clipping=0.2,
    lambda_value=0.95,
    discount_factor=0.99,
    entropy_regularization=0.,
    value_pred_loss_coef=0.5,
    use_gae=True,
    use_td_lambda_return=True,
    gradient_clipping=0.5,
    value_clipping=None,
    # Replay params
    reverb_port=None,
    replay_capacity=10000,
    # Others
    policy_save_interval=5000,
    summary_interval=1000,
    eval_interval=10000,
    eval_episodes=100,
    debug_summaries=False,
    summarize_grads_and_vars=False):
  """Trains and evaluates PPO (Importance Ratio Clipping).

  Args:
    root_dir: Main directory path where checkpoints, saved_models, and summaries
      will be written to.
    env_name: Name for the Mujoco environment to load.
    num_iterations: The number of iterations to perform collection and training.
    actor_fc_layers: List of fully_connected parameters for the actor network,
      where each item is the number of units in the layer.
    value_fc_layers: : List of fully_connected parameters for the value network,
      where each item is the number of units in the layer.
    learning_rate: Learning rate used on the Adam optimizer.
    collect_sequence_length: Number of steps to take in each collect run.
    minibatch_size: Number of elements in each mini batch. If `None`, the entire
      collected sequence will be treated as one batch.
    num_epochs: Number of iterations to repeat over all collected data per data
      collection step. (Schulman, 2017) sets this to 10 for Mujoco, 15 for
      Roboschool and 3 for Atari.
    importance_ratio_clipping: Epsilon in clipped, surrogate PPO objective. For
      more detail, see explanation at the top of the doc.
    lambda_value: Lambda parameter for TD-lambda computation.
    discount_factor: Discount factor for return computation. Defaults to `0.99`,
      which is the value used for all environments from (Schulman, 2017).
    entropy_regularization: Coefficient for entropy regularization loss term.
      Defaults to `0.0` because no entropy bonus was used in (Schulman, 2017).
    value_pred_loss_coef: Multiplier for value prediction loss to balance with
      policy gradient loss. Defaults to `0.5`, which was used for all
      environments in the OpenAI baseline implementation. This parameter is
      irrelevant unless you are sharing part of actor_net and value_net. In that
      case, you would want to tune this coefficient, whose value depends on the
      network architecture of your choice.
    use_gae: If True (default False), uses generalized advantage estimation for
      computing per-timestep advantage. Else, just subtracts value predictions
      from empirical return.
    use_td_lambda_return: If True (default False), uses td_lambda_return for
      training value function; here `td_lambda_return = gae_advantage +
      value_predictions`. `use_gae` must be set to `True` as well to enable
      TD-lambda returns. If `use_td_lambda_return` is set to True while
      `use_gae` is False, the empirical return will be used and a warning will
      be logged.
    gradient_clipping: Norm length to clip gradients.
    value_clipping: Difference between new and old value predictions is clipped
      to this threshold. Value clipping could be helpful when training
      very deep networks. Default: no clipping.
    reverb_port: Port for reverb server, if None, use a randomly chosen unused
      port.
    replay_capacity: The maximum number of elements for the replay buffer. Items
      will be wasted if this is smaller than collect_sequence_length.
    policy_save_interval: How often, in train_steps, the policy will be saved.
    summary_interval: How often to write data into Tensorboard.
    eval_interval: How often to run evaluation, in train_steps.
    eval_episodes: Number of episodes to evaluate over.
    debug_summaries: Boolean for whether to gather debug summaries.
    summarize_grads_and_vars: If true, gradient summaries will be written.
  """
  collect_env = suite_mujoco.load(env_name)
  eval_env = suite_mujoco.load(env_name)
  num_environments = 1

  observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(collect_env))
  # TODO(b/172267869): Remove this conversion once TensorNormalizer stops
  # converting float64 inputs to float32.
  observation_tensor_spec = tf.TensorSpec(
      dtype=tf.float32, shape=observation_tensor_spec.shape)

  train_step = train_utils.create_train_step()
  actor_net_builder = ppo_actor_network.PPOActorNetwork()
  actor_net = actor_net_builder.create_sequential_actor_net(
      actor_fc_layers, action_tensor_spec)
  value_net = value_network.ValueNetwork(
      observation_tensor_spec,
      fc_layer_params=value_fc_layers,
      kernel_initializer=tf.keras.initializers.Orthogonal())

  current_iteration = tf.Variable(0, dtype=tf.int64)
  def learning_rate_fn():
    # Linearly decay the learning rate.
    return learning_rate * (1 - current_iteration / num_iterations)

  agent = ppo_clip_agent.PPOClipAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      optimizer=tf.keras.optimizers.Adam(
          learning_rate=learning_rate_fn, epsilon=1e-5),
      actor_net=actor_net,
      value_net=value_net,
      importance_ratio_clipping=importance_ratio_clipping,
      lambda_value=lambda_value,
      discount_factor=discount_factor,
      entropy_regularization=entropy_regularization,
      value_pred_loss_coef=value_pred_loss_coef,
      # This is a legacy argument for the number of times we repeat the data
      # inside of the train function, incompatible with mini batch learning.
      # We set the epoch number from the replay buffer and tf.Data instead.
      num_epochs=1,
      use_gae=use_gae,
      use_td_lambda_return=use_td_lambda_return,
      gradient_clipping=gradient_clipping,
      value_clipping=value_clipping,
      # TODO(b/150244758): Default compute_value_and_advantage_in_train to False
      # after Reverb open source.
      compute_value_and_advantage_in_train=False,
      # Skips updating normalizers in the agent, as it's handled in the learner.
      update_normalizers_in_train=False,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=train_step)
  agent.initialize()

  reverb_server = reverb.Server(
      [
          reverb.Table(  # Replay buffer storing experience for training.
              name='training_table',
              sampler=reverb.selectors.Fifo(),
              remover=reverb.selectors.Fifo(),
              rate_limiter=reverb.rate_limiters.MinSize(1),
              max_size=replay_capacity,
              max_times_sampled=1,
          ),
          reverb.Table(  # Replay buffer storing experience for normalization.
              name='normalization_table',
              sampler=reverb.selectors.Fifo(),
              remover=reverb.selectors.Fifo(),
              rate_limiter=reverb.rate_limiters.MinSize(1),
              max_size=replay_capacity,
              max_times_sampled=1,
          )
      ],
      port=reverb_port)

  # Create the replay buffer.
  reverb_replay_train = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=collect_sequence_length,
      table_name='training_table',
      server_address='localhost:{}'.format(reverb_server.port),
      # The only collected sequence is used to populate the batches.
      max_cycle_length=1,
      rate_limiter_timeout_ms=1000)
  reverb_replay_normalization = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=collect_sequence_length,
      table_name='normalization_table',
      server_address='localhost:{}'.format(reverb_server.port),
      # The only collected sequence is used to populate the batches.
      max_cycle_length=1,
      rate_limiter_timeout_ms=1000)

  rb_observer = reverb_utils.ReverbTrajectorySequenceObserver(
      reverb_replay_train.py_client, ['training_table', 'normalization_table'],
      sequence_length=collect_sequence_length,
      stride_length=collect_sequence_length)

  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  collect_env_step_metric = py_metrics.EnvironmentSteps()
  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir,
          agent,
          train_step,
          interval=policy_save_interval,
          metadata_metrics={
              triggers.ENV_STEP_METADATA_KEY: collect_env_step_metric
          }),
      triggers.StepPerSecondLogTrigger(train_step, interval=summary_interval),
  ]

  def training_dataset_fn():
    return reverb_replay_train.as_dataset(
        sample_batch_size=num_environments,
        sequence_preprocess_fn=agent.preprocess_sequence)

  def normalization_dataset_fn():
    return reverb_replay_normalization.as_dataset(
        sample_batch_size=num_environments,
        sequence_preprocess_fn=agent.preprocess_sequence)

  agent_learner = ppo_learner.PPOLearner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn=training_dataset_fn,
      normalization_dataset_fn=normalization_dataset_fn,
      num_samples=1,
      num_epochs=num_epochs,
      minibatch_size=minibatch_size,
      shuffle_buffer_size=collect_sequence_length,
      triggers=learning_triggers)

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_collect_policy, use_tf_function=True)

  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=collect_sequence_length,
      observers=[rb_observer],
      metrics=actor.collect_metrics(buffer_size=10) + [collect_env_step_metric],
      reference_metrics=[collect_env_step_metric],
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
      summary_interval=summary_interval)

  eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
      agent.policy, use_tf_function=True)

  if eval_interval:
    logging.info('Initial evaluation.')
    eval_actor = actor.Actor(
        eval_env,
        eval_greedy_policy,
        train_step,
        metrics=actor.eval_metrics(eval_episodes),
        reference_metrics=[collect_env_step_metric],
        summary_dir=os.path.join(root_dir, 'eval'),
        episodes_per_run=eval_episodes)

    eval_actor.run_and_log()

  logging.info('Training on %s', env_name)
  last_eval_step = 0
  for i in range(num_iterations):
    collect_actor.run()
    rb_observer.flush()
    agent_learner.run()
    reverb_replay_train.clear()
    reverb_replay_normalization.clear()
    current_iteration.assign_add(1)

    # Eval only if `eval_interval` has been set. Then, eval if the current train
    # step is equal or greater than the `last_eval_step` + `eval_interval` or if
    # this is the last iteration. This logic exists because agent_learner.run()
    # does not return after every train step.
    if (eval_interval and
        (agent_learner.train_step_numpy >= eval_interval + last_eval_step
         or i == num_iterations - 1)):
      logging.info('Evaluating.')
      eval_actor.run_and_log()
      last_eval_step = agent_learner.train_step_numpy

  rb_observer.close()
  reverb_server.stop()
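A short hedged usage sketch for the PPO/Reverb `train_eval` above (paths and iteration counts are illustrative; `reverb` must be installed, and `reverb_port=None` lets the server pick an unused port):

train_eval(
    root_dir='/tmp/ppo_halfcheetah',
    num_iterations=10,
    eval_interval=None,  # skip evaluation for a quick smoke run
    reverb_port=None)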
Example #11
 def testObservationSpec(self):
     env = suite_mujoco.load('HalfCheetah-v1')
     self.assertEqual(np.float32, env.observation_spec().dtype)
     self.assertEqual((17, ), env.observation_spec().shape)
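These short test methods (here and in Examples #1 and #14) come from a larger test class. A hedged sketch of the harness they presume (the class and test names here are illustrative; `py_environment.Base` is the base class used in the TF-Agents version these snippets target):

import tensorflow as tf

from tf_agents.environments import py_environment
from tf_agents.environments import suite_mujoco


class SuiteMujocoTest(tf.test.TestCase):

    def testLoads(self):
        env = suite_mujoco.load('HalfCheetah-v2')
        self.assertIsInstance(env, py_environment.Base)


if __name__ == '__main__':
    tf.test.main()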
Example #12
def load_multiple_mugs_env(
    universe,
    action_mode,
    env_name=None,
    render_size=128,
    observation_render_size=64,
    observations_whitelist=None,
    action_repeat=1,
    num_train_tasks=30,
    num_eval_tasks=10,
    eval_on_holdout_tasks=True,
    return_multiple_tasks=False,
    model_input=None,
    auto_reset_task_each_episode=False,
):

    ### HARDCODED
    # temporary sanity
    assert env_name == 'SawyerShelfMT-v0'
    assert return_multiple_tasks
    assert universe == 'gym'

    # get eval and train tasks by loading a sample env
    sample_env = suite_mujoco.load(env_name)
    # train env
    train_tasks = sample_env.init_tasks(num_tasks=num_train_tasks,
                                        is_eval_env=False)
    # eval env
    eval_tasks = sample_env.init_tasks(num_tasks=num_eval_tasks,
                                       is_eval_env=eval_on_holdout_tasks)
    del sample_env

    print("train weights", train_tasks)
    print("eval weights", eval_tasks)
    if env_name == 'SawyerShelfMT-v0':
        from meld.environments.envs.shelf.assets.generate_sawyer_shelf_xml import generate_and_save_xml_file
    else:
        raise NotImplementedError

    train_xml_path = generate_and_save_xml_file(train_tasks,
                                                action_mode,
                                                is_eval=False)
    eval_xml_path = generate_and_save_xml_file(eval_tasks,
                                               action_mode,
                                               is_eval=True)

    ### train env
    # gym-level observation wrappers (distinct from the TF-Agents `wrappers` module used below)
    gym_wrappers = get_wrappers(device_id=0,
                                model_input=model_input,
                                render_size=render_size,
                                observation_render_size=observation_render_size,
                                observations_whitelist=observations_whitelist)
    # load env
    gym_kwargs = {"action_mode": action_mode, "xml_path": train_xml_path}
    py_env = suite_gym.load(env_name,
                            gym_env_wrappers=gym_wrappers,
                            gym_kwargs=gym_kwargs)
    if action_repeat > 1:
        py_env = wrappers.ActionRepeat(py_env, action_repeat)

    ### eval env
    # gym-level observation wrappers for the eval environment
    gym_wrappers = get_wrappers(device_id=1,
                                model_input=model_input,
                                render_size=render_size,
                                observation_render_size=observation_render_size,
                                observations_whitelist=observations_whitelist)
    # load env
    gym_kwargs = {"action_mode": action_mode, "xml_path": eval_xml_path}
    eval_py_env = suite_gym.load(env_name,
                                 gym_env_wrappers=gym_wrappers,
                                 gym_kwargs=gym_kwargs)
    eval_py_env = video_wrapper.VideoWrapper(eval_py_env)
    if action_repeat > 1:
        eval_py_env = wrappers.ActionRepeat(eval_py_env, action_repeat)

    py_env.assign_tasks(train_tasks)
    eval_py_env.assign_tasks(eval_tasks)

    # set task list and reset variable to true
    if auto_reset_task_each_episode:
        py_env.wrapped_env().set_auto_reset_task(train_tasks)
        eval_py_env.wrapped_env().set_auto_reset_task(eval_tasks)

    return py_env, eval_py_env, train_tasks, eval_tasks
Example #13
def load_env(env_name,
             seed,
             action_repeat=0,
             frame_stack=1,
             obs_type='pixels'):
    """Loads a learning environment.

  Args:
    env_name: Name of the environment.
    seed: Random seed.
    action_repeat: (optional) action repeat multiplier. Useful for DM control
      suite tasks.
    frame_stack: (optional) frame stack.
    obs_type: `pixels` or `state`
  Returns:
    Learning environment.
  """

    action_repeat_applied = False
    state_env = None

    if env_name.startswith('dm'):
        _, domain_name, task_name = env_name.split('-')
        if 'manipulation' in domain_name:
            env = manipulation.load(task_name)
            env = dm_control_wrapper.DmControlWrapper(env)
        else:
            env = _load_dm_env(domain_name,
                               task_name,
                               pixels=False,
                               action_repeat=action_repeat)
            action_repeat_applied = True
        env = wrappers.FlattenObservationsWrapper(env)

    elif env_name.startswith('pixels-dm'):
        if 'distractor' in env_name:
            _, _, domain_name, task_name, _ = env_name.split('-')
            distractor = True
        else:
            _, _, domain_name, task_name = env_name.split('-')
            distractor = False
        # TODO(tompson): Are there DMC environments that have other
        # max_episode_steps?
        env = _load_dm_env(domain_name,
                           task_name,
                           pixels=True,
                           action_repeat=action_repeat,
                           max_episode_steps=1000,
                           obs_type=obs_type,
                           distractor=distractor)
        action_repeat_applied = True
        if obs_type == 'pixels':
            env = FlattenImageObservationsWrapper(env)
            state_env = None
        else:
            env = JointImageObservationsWrapper(env)
            state_env = tf_py_environment.TFPyEnvironment(
                wrappers.FlattenObservationsWrapper(
                    _load_dm_env(domain_name,
                                 task_name,
                                 pixels=False,
                                 action_repeat=action_repeat)))

    else:
        env = suite_mujoco.load(env_name)
        env.seed(seed)

    if action_repeat > 1 and not action_repeat_applied:
        env = wrappers.ActionRepeat(env, action_repeat)
    if frame_stack > 1:
        env = FrameStackWrapperTfAgents(env, frame_stack)

    env = tf_py_environment.TFPyEnvironment(env)

    return env, state_env
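A hedged usage sketch for `load_env` (the environment name and seed are illustrative; `state_env` is only populated for the pixel-based DM control branch with `obs_type='state'`):

env, state_env = load_env('HalfCheetah-v2', seed=0)
print(env.observation_spec())  # spec of the wrapping TFPyEnvironment
assert state_env is None       # not a 'pixels-dm-...' environment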
Example #14
 def testMujocoEnvRegistered(self):
     env = suite_mujoco.load('HalfCheetah-v1')
     self.assertIsInstance(env, py_environment.Base)
     self.assertIsInstance(env, wrappers.TimeLimit)
Example #15
def main(_):

    # environment serves as the dataset in reinforcement learning
    train_env = tf_py_environment.TFPyEnvironment(
        ParallelPyEnvironment([lambda: suite_mujoco.load('HalfCheetah-v2')] *
                              batch_size))
    eval_env = tf_py_environment.TFPyEnvironment(
        suite_mujoco.load('HalfCheetah-v2'))
    # create agent
    actor_net = ActorDistributionRnnNetwork(train_env.observation_spec(),
                                            train_env.action_spec(),
                                            lstm_size=(100, 100))
    value_net = ValueRnnNetwork(train_env.observation_spec())
    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=1e-3)
    tf_agent = ppo_agent.PPOAgent(train_env.time_step_spec(),
                                  train_env.action_spec(),
                                  optimizer=optimizer,
                                  actor_net=actor_net,
                                  value_net=value_net,
                                  normalize_observations=False,
                                  normalize_rewards=False,
                                  use_gae=True,
                                  num_epochs=25)
    tf_agent.initialize()
    # replay buffer
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=train_env.batch_size,
        max_length=1000000)
    # policy saver
    saver = policy_saver.PolicySaver(tf_agent.policy)
    # define trajectory collector
    train_episode_count = tf_metrics.NumberOfEpisodes()
    train_total_steps = tf_metrics.EnvironmentSteps()
    train_avg_reward = tf_metrics.AverageReturnMetric(
        batch_size=train_env.batch_size)
    train_avg_episode_len = tf_metrics.AverageEpisodeLengthMetric(
        batch_size=train_env.batch_size)
    train_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        train_env,
        tf_agent.collect_policy,  # NOTE: use PPOPolicy to collect episode
        observers=[
            replay_buffer.add_batch, train_episode_count, train_total_steps,
            train_avg_reward, train_avg_episode_len
        ],  # callbacks when an episode is completely collected
        num_episodes=30,  # how many episodes are collected in an iteration
    )
    # training
    eval_avg_reward = tf_metrics.AverageReturnMetric(buffer_size=30)
    eval_avg_episode_len = tf_metrics.AverageEpisodeLengthMetric(
        buffer_size=30)
    while train_total_steps.result() < 25000000:
        train_driver.run()
        trajectories = replay_buffer.gather_all()
        loss, _ = tf_agent.train(experience=trajectories)
        replay_buffer.clear()
        # clear collected episodes right after training
        if tf_agent.train_step_counter.numpy() % 50 == 0:
            print('step = {0}: loss = {1}'.format(
                tf_agent.train_step_counter.numpy(), loss))
        if tf_agent.train_step_counter.numpy() % 500 == 0:
            # save checkpoint
            saver.save('checkpoints/policy_%d' %
                       tf_agent.train_step_counter.numpy())
            # evaluate the updated policy
            eval_avg_reward.reset()
            eval_avg_episode_len.reset()
            eval_driver = dynamic_episode_driver.DynamicEpisodeDriver(
                eval_env,
                tf_agent.policy,
                observers=[
                    eval_avg_reward,
                    eval_avg_episode_len,
                ],
                num_episodes=30,  # how many episodes are collected per evaluation
            )
            eval_driver.run()
            print(
                'step = {0}: Average Return = {1} Average Episode Length = {2}'
                .format(tf_agent.train_step_counter.numpy(),
                        eval_avg_reward.result(),
                        eval_avg_episode_len.result()))
    # play HalfCheetah for 3 final episodes and visualize
    import cv2
    for _ in range(3):
        status = eval_env.reset()
        policy_state = tf_agent.policy.get_initial_state(eval_env.batch_size)
        while not status.is_last():
            action = tf_agent.policy.action(status, policy_state)
            # NOTE: use greedy policy to test
            status = eval_env.step(action.action)
            policy_state = action.state
            cv2.imshow('halfcheetah', eval_env.pyenv.envs[0].render())
            cv2.waitKey(25)
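
If no display is available (e.g. on a headless machine), the same render loop can be captured to a video file with OpenCV instead of `cv2.imshow`; this is a sketch under the assumption that `render()` returns an RGB `HxWx3` uint8 array, as the loop above implies:

# Sketch: record one evaluation episode to disk instead of displaying it.
frame = eval_env.pyenv.envs[0].render()
height, width, _ = frame.shape
writer = cv2.VideoWriter('halfcheetah.mp4',
                         cv2.VideoWriter_fourcc(*'mp4v'), 30, (width, height))
status = eval_env.reset()
policy_state = tf_agent.policy.get_initial_state(eval_env.batch_size)
while not status.is_last():
    action = tf_agent.policy.action(status, policy_state)
    status = eval_env.step(action.action)
    policy_state = action.state
    writer.write(cv2.cvtColor(eval_env.pyenv.envs[0].render(), cv2.COLOR_RGB2BGR))
writer.release()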
Ejemplo n.º 16
0
def train_eval(
        root_dir,
        strategy: tf.distribute.Strategy,
        env_name='HalfCheetah-v2',
        # Training params
        initial_collect_steps=10000,
        num_iterations=3200000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Agent params
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        gamma=0.99,
        target_update_tau=0.005,
        target_update_period=1,
        reward_scale_factor=0.1,
        # Replay params
        reverb_port=None,
        replay_capacity=1000000,
        # Others
        policy_save_interval=10000,
        replay_buffer_save_interval=100000,
        eval_interval=10000,
        eval_episodes=30,
        debug_summaries=False,
        summarize_grads_and_vars=False):
    """Trains and evaluates SAC."""
    logging.info('Training SAC on: %s', env_name)
    collect_env = suite_mujoco.load(env_name)
    eval_env = suite_mujoco.load(env_name)

    _, action_tensor_spec, time_step_tensor_spec = (
        spec_utils.get_tensor_specs(collect_env))

    actor_net = create_sequential_actor_network(
        actor_fc_layers=actor_fc_layers, action_tensor_spec=action_tensor_spec)

    critic_net = create_sequential_critic_network(
        obs_fc_layer_units=critic_obs_fc_layers,
        action_fc_layer_units=critic_action_fc_layers,
        joint_fc_layer_units=critic_joint_fc_layers)

    with strategy.scope():
        train_step = train_utils.create_train_step()
        agent = sac_agent.SacAgent(
            time_step_tensor_spec,
            action_tensor_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.keras.optimizers.Adam(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.keras.optimizers.Adam(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.keras.optimizers.Adam(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=tf.math.squared_difference,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=None,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=train_step)
        agent.initialize()

    table_name = 'uniform_table'
    table = reverb.Table(table_name,
                         max_size=replay_capacity,
                         sampler=reverb.selectors.Uniform(),
                         remover=reverb.selectors.Fifo(),
                         rate_limiter=reverb.rate_limiters.MinSize(1))

    reverb_checkpoint_dir = os.path.join(root_dir, learner.TRAIN_DIR,
                                         learner.REPLAY_BUFFER_CHECKPOINT_DIR)
    reverb_checkpointer = reverb.platform.checkpointers_lib.DefaultCheckpointer(
        path=reverb_checkpoint_dir)
    reverb_server = reverb.Server([table],
                                  port=reverb_port,
                                  checkpointer=reverb_checkpointer)
    reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
        agent.collect_data_spec,
        sequence_length=2,
        table_name=table_name,
        local_server=reverb_server)
    rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
        reverb_replay.py_client,
        table_name,
        sequence_length=2,
        stride_length=1)

    def experience_dataset_fn():
        return reverb_replay.as_dataset(sample_batch_size=batch_size,
                                        num_steps=2).prefetch(50)

    saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
    env_step_metric = py_metrics.EnvironmentSteps()
    learning_triggers = [
        triggers.PolicySavedModelTrigger(
            saved_model_dir,
            agent,
            train_step,
            interval=policy_save_interval,
            metadata_metrics={triggers.ENV_STEP_METADATA_KEY:
                              env_step_metric}),
        triggers.ReverbCheckpointTrigger(
            train_step,
            interval=replay_buffer_save_interval,
            reverb_client=reverb_replay.py_client),
        # TODO(b/165023684): Add SIGTERM handler to checkpoint before preemption.
        triggers.StepPerSecondLogTrigger(train_step, interval=1000),
    ]

    agent_learner = learner.Learner(root_dir,
                                    train_step,
                                    agent,
                                    experience_dataset_fn,
                                    triggers=learning_triggers,
                                    strategy=strategy)

    random_policy = random_py_policy.RandomPyPolicy(
        collect_env.time_step_spec(), collect_env.action_spec())
    initial_collect_actor = actor.Actor(collect_env,
                                        random_policy,
                                        train_step,
                                        steps_per_run=initial_collect_steps,
                                        observers=[rb_observer])
    logging.info('Doing initial collect.')
    initial_collect_actor.run()

    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_collect_policy,
                                                        use_tf_function=True)

    collect_actor = actor.Actor(collect_env,
                                collect_policy,
                                train_step,
                                steps_per_run=1,
                                metrics=actor.collect_metrics(10),
                                summary_dir=os.path.join(
                                    root_dir, learner.TRAIN_DIR),
                                observers=[rb_observer, env_step_metric])

    tf_greedy_policy = greedy_policy.GreedyPolicy(agent.policy)
    eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
        tf_greedy_policy, use_tf_function=True)

    eval_actor = actor.Actor(
        eval_env,
        eval_greedy_policy,
        train_step,
        episodes_per_run=eval_episodes,
        metrics=actor.eval_metrics(eval_episodes),
        summary_dir=os.path.join(root_dir, 'eval'),
    )

    if eval_interval:
        logging.info('Evaluating.')
        eval_actor.run_and_log()

    logging.info('Training.')
    for _ in range(num_iterations):
        collect_actor.run()
        agent_learner.run(iterations=1)

        if eval_interval and agent_learner.train_step_numpy % eval_interval == 0:
            logging.info('Evaluating.')
            eval_actor.run_and_log()

    rb_observer.close()
    reverb_server.stop()
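
For debugging inside `train_eval`, a quick sanity check on what the learner consumes is to pull a single batch from `experience_dataset_fn`; a sketch, assuming the initial collect above has already populated the Reverb table:

# Sketch: the Reverb dataset yields (Trajectory, SampleInfo) pairs with a
# time dimension of 2, matching sequence_length=2 above.
experience, sample_info = next(iter(experience_dataset_fn()))
print(experience.observation.shape)  # [batch_size, 2, observation_dim]
print(experience.reward.shape)       # [batch_size, 2]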
Ejemplo n.º 17
0
def load_environments(universe,
                      env_name=None,
                      domain_name=None,
                      task_name=None,
                      render_size=128,
                      observation_render_size=64,
                      observations_whitelist=None,
                      action_repeat=1):
    """Loads train and eval environments.

  The universe can either be gym, in which case domain_name and task_name are
  ignored, or dm_control, in which case env_name is ignored.
  """
    if universe == 'gym':
        tf.compat.v1.logging.info(
            'Using environment {} from {} universe.'.format(
                env_name, universe))
        gym_env_wrappers = [
            functools.partial(gym_wrappers.RenderGymWrapper,
                              render_kwargs={
                                  'height': render_size,
                                  'width': render_size,
                                  'device_id': 0
                              }),
            functools.partial(gym_wrappers.PixelObservationsGymWrapper,
                              observations_whitelist=observations_whitelist,
                              render_kwargs={
                                  'height': observation_render_size,
                                  'width': observation_render_size,
                                  'device_id': 0
                              })
        ]
        eval_gym_env_wrappers = [
            functools.partial(gym_wrappers.RenderGymWrapper,
                              render_kwargs={
                                  'height': render_size,
                                  'width': render_size,
                                  'device_id': 1
                              }),
            # segfaults if the device is the same as train env
            functools.partial(gym_wrappers.PixelObservationsGymWrapper,
                              observations_whitelist=observations_whitelist,
                              render_kwargs={
                                  'height': observation_render_size,
                                  'width': observation_render_size,
                                  'device_id': 1
                              })
        ]
        py_env = suite_mujoco.load(env_name, gym_env_wrappers=gym_env_wrappers)
        eval_py_env = suite_mujoco.load(env_name,
                                        gym_env_wrappers=eval_gym_env_wrappers)
    elif universe == 'dm_control':
        tf.compat.v1.logging.info(
            'Using domain {} and task {} from {} universe.'.format(
                domain_name, task_name, universe))
        render_kwargs = {
            'height': render_size,
            'width': render_size,
            'camera_id': 0,
        }
        dm_env_wrappers = [
            wrappers.FlattenObservationsWrapper,  # combine position and velocity
            functools.partial(
                dm_control_wrappers.PixelObservationsDmControlWrapper,
                observations_whitelist=observations_whitelist,
                render_kwargs={
                    'height': observation_render_size,
                    'width': observation_render_size,
                    'camera_id': 0
                })
        ]
        py_env = suite_dm_control.load(domain_name,
                                       task_name,
                                       render_kwargs=render_kwargs,
                                       env_wrappers=dm_env_wrappers)
        eval_py_env = suite_dm_control.load(domain_name,
                                            task_name,
                                            render_kwargs=render_kwargs,
                                            env_wrappers=dm_env_wrappers)
    else:
        raise ValueError('Invalid universe %s.' % universe)

    eval_py_env = video_wrapper.VideoWrapper(eval_py_env)

    if action_repeat > 1:
        py_env = wrappers.ActionRepeat(py_env, action_repeat)
        eval_py_env = wrappers.ActionRepeat(eval_py_env, action_repeat)

    return py_env, eval_py_env
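
A minimal call sketch for the helper above using the gym universe; the whitelist value is illustrative and depends on what `PixelObservationsGymWrapper` expects in the original code:

# Sketch: build train/eval environments and wrap them for TF-Agents.
py_env, eval_py_env = load_environments(
    'gym',
    env_name='HalfCheetah-v2',
    observations_whitelist=['state'],  # assumed key; not shown in this excerpt
    action_repeat=2)
train_env = tf_py_environment.TFPyEnvironment(py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)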
def env_factory(env_name):
    py_env = suite_mujoco.load(env_name)
    tf_env = tf_py_environment.TFPyEnvironment(py_env)
    return tf_env
Ejemplo n.º 19
0
def main(_):
    tf.random.set_seed(FLAGS.seed)

    if FLAGS.models_dir is None:
        raise ValueError('You must set a value for models_dir.')

    env = suite_mujoco.load(FLAGS.env_name)
    env.seed(FLAGS.seed)
    env = tf_py_environment.TFPyEnvironment(env)

    sac = actor_lib.Actor(env.observation_spec().shape[0], env.action_spec())

    model_filename = os.path.join(FLAGS.models_dir, 'DM-' + FLAGS.env_name,
                                  str(FLAGS.model_seed), '1000000')
    sac.load_weights(model_filename)

    if FLAGS.std is None:
        if 'Reacher' in FLAGS.env_name:
            std = 0.5
        elif 'Ant' in FLAGS.env_name:
            std = 0.4
        elif 'Walker' in FLAGS.env_name:
            std = 2.0
        else:
            std = 0.75
    else:
        std = FLAGS.std

    def get_action(state):
        _, action, log_prob = sac(state, std)
        return action, log_prob

    dataset = dict(model_filename=model_filename,
                   behavior_std=std,
                   trajectories=dict(states=[],
                                     actions=[],
                                     log_probs=[],
                                     next_states=[],
                                     rewards=[],
                                     masks=[]))

    for i in range(FLAGS.num_episodes):
        timestep = env.reset()
        trajectory = dict(states=[],
                          actions=[],
                          log_probs=[],
                          next_states=[],
                          rewards=[],
                          masks=[])

        while not timestep.is_last():
            action, log_prob = get_action(timestep.observation)
            next_timestep = env.step(action)

            trajectory['states'].append(timestep.observation)
            trajectory['actions'].append(action)
            trajectory['log_probs'].append(log_prob)
            trajectory['next_states'].append(next_timestep.observation)
            trajectory['rewards'].append(next_timestep.reward)
            trajectory['masks'].append(next_timestep.discount)

            timestep = next_timestep

        for k, v in trajectory.items():
            dataset['trajectories'][k].append(tf.concat(v, 0).numpy())

        logging.info('%d trajectories', i + 1)

    data_save_dir = os.path.join(FLAGS.save_dir, FLAGS.env_name,
                                 str(FLAGS.model_seed))
    if not tf.io.gfile.isdir(data_save_dir):
        tf.io.gfile.makedirs(data_save_dir)

    save_filename = os.path.join(data_save_dir, f'dualdice_{FLAGS.std}.pckl')
    with tf.io.gfile.GFile(save_filename, 'wb') as f:
        pickle.dump(dataset, f)
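
A short sketch of loading the pickled dataset written above back into memory; the file layout mirrors the `dataset` dict constructed in `main`:

# Sketch: reload the collected behavior data for offline use.
with tf.io.gfile.GFile(save_filename, 'rb') as f:
    dataset = pickle.load(f)
states = dataset['trajectories']['states']    # one [T, obs_dim] array per episode
rewards = dataset['trajectories']['rewards']  # one [T] array per episode
print(len(states), 'episodes,', sum(r.shape[0] for r in rewards), 'transitions')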
Ejemplo n.º 20
0
def main(_):
    tf.random.set_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)

    hparam_str = make_hparam_string(seed=FLAGS.seed, env_name=FLAGS.env_name)
    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))
    summary_writer.set_as_default()

    if FLAGS.d4rl:
        d4rl_env = gym.make(FLAGS.env_name)
        gym_spec = gym.spec(FLAGS.env_name)
        if gym_spec.max_episode_steps in [0, None]:  # Add TimeLimit wrapper.
            gym_env = time_limit.TimeLimit(d4rl_env, max_episode_steps=1000)
        else:
            gym_env = d4rl_env
        gym_env.seed(FLAGS.seed)
        env = tf_py_environment.TFPyEnvironment(
            gym_wrapper.GymWrapper(gym_env))

        behavior_dataset = D4rlDataset(
            d4rl_env,
            normalize_states=FLAGS.normalize_states,
            normalize_rewards=FLAGS.normalize_rewards,
            noise_scale=FLAGS.noise_scale,
            bootstrap=FLAGS.bootstrap)
    else:
        env = suite_mujoco.load(FLAGS.env_name)
        env.seed(FLAGS.seed)
        env = tf_py_environment.TFPyEnvironment(env)

        data_file_name = os.path.join(
            FLAGS.data_dir, FLAGS.env_name, '0',
            f'dualdice_{FLAGS.behavior_policy_std}.pckl')
        behavior_dataset = Dataset(data_file_name,
                                   FLAGS.num_trajectories,
                                   normalize_states=FLAGS.normalize_states,
                                   normalize_rewards=FLAGS.normalize_rewards,
                                   noise_scale=FLAGS.noise_scale,
                                   bootstrap=FLAGS.bootstrap)

    tf_dataset = behavior_dataset.with_uniform_sampling(
        FLAGS.sample_batch_size)
    tf_dataset_iter = iter(tf_dataset)

    if FLAGS.d4rl:
        with tf.io.gfile.GFile(FLAGS.d4rl_policy_filename, 'rb') as f:
            policy_weights = pickle.load(f)
        actor = utils.D4rlActor(env,
                                policy_weights,
                                is_dapg='dapg' in FLAGS.d4rl_policy_filename)
    else:
        actor = Actor(env.observation_spec().shape[0], env.action_spec())
        actor.load_weights(behavior_dataset.model_filename)

    policy_returns = utils.estimate_monte_carlo_returns(
        env, FLAGS.discount, actor, FLAGS.target_policy_std,
        FLAGS.num_mc_episodes)
    logging.info('Estimated Per-Step Average Returns=%f', policy_returns)

    if 'fqe' in FLAGS.algo or 'dr' in FLAGS.algo:
        model = QFitter(env.observation_spec().shape[0],
                        env.action_spec().shape[0], FLAGS.lr,
                        FLAGS.weight_decay, FLAGS.tau)
    elif 'mb' in FLAGS.algo:
        model = ModelBased(env.observation_spec().shape[0],
                           env.action_spec().shape[0],
                           learning_rate=FLAGS.lr,
                           weight_decay=FLAGS.weight_decay)
    elif 'dual_dice' in FLAGS.algo:
        model = DualDICE(env.observation_spec().shape[0],
                         env.action_spec().shape[0], FLAGS.weight_decay)
    if 'iw' in FLAGS.algo or 'dr' in FLAGS.algo:
        behavior = BehaviorCloning(env.observation_spec().shape[0],
                                   env.action_spec(), FLAGS.lr,
                                   FLAGS.weight_decay)

    @tf.function
    def get_target_actions(states):
        return actor(tf.cast(behavior_dataset.unnormalize_states(states),
                             env.observation_spec().dtype),
                     std=FLAGS.target_policy_std)[1]

    @tf.function
    def get_target_logprobs(states, actions):
        log_probs = actor(tf.cast(behavior_dataset.unnormalize_states(states),
                                  env.observation_spec().dtype),
                          actions=actions,
                          std=FLAGS.target_policy_std)[2]
        if tf.rank(log_probs) > 1:
            log_probs = tf.reduce_sum(log_probs, -1)
        return log_probs

    min_reward = tf.reduce_min(behavior_dataset.rewards)
    max_reward = tf.reduce_max(behavior_dataset.rewards)
    min_state = tf.reduce_min(behavior_dataset.states, 0)
    max_state = tf.reduce_max(behavior_dataset.states, 0)

    @tf.function
    def update_step():
        (states, actions, next_states, rewards, masks, weights,
         _) = next(tf_dataset_iter)
        initial_actions = get_target_actions(behavior_dataset.initial_states)
        next_actions = get_target_actions(next_states)

        if 'fqe' in FLAGS.algo or 'dr' in FLAGS.algo:
            model.update(states, actions, next_states, next_actions, rewards,
                         masks, weights, FLAGS.discount, min_reward,
                         max_reward)
        elif 'mb' in FLAGS.algo:
            model.update(states, actions, next_states, rewards, masks, weights)
        elif 'dual_dice' in FLAGS.algo:
            model.update(behavior_dataset.initial_states, initial_actions,
                         behavior_dataset.initial_weights, states, actions,
                         next_states, next_actions, masks, weights,
                         FLAGS.discount)

        if 'iw' in FLAGS.algo or 'dr' in FLAGS.algo:
            behavior.update(states, actions, weights)

    gc.collect()

    for i in tqdm.tqdm(range(FLAGS.num_updates), desc='Running Training'):
        update_step()

        if i % FLAGS.eval_interval == 0:
            if 'fqe' in FLAGS.algo:
                pred_returns = model.estimate_returns(
                    behavior_dataset.initial_states,
                    behavior_dataset.initial_weights, get_target_actions)
            elif 'mb' in FLAGS.algo:
                pred_returns = model.estimate_returns(
                    behavior_dataset.initial_states,
                    behavior_dataset.initial_weights, get_target_actions,
                    FLAGS.discount, min_reward, max_reward, min_state,
                    max_state)
            elif FLAGS.algo in ['dual_dice']:
                pred_returns, pred_ratio = model.estimate_returns(
                    iter(tf_dataset))

                tf.summary.scalar('train/pred ratio', pred_ratio, step=i)
            elif 'iw' in FLAGS.algo or 'dr' in FLAGS.algo:
                discount = FLAGS.discount
                _, behavior_log_probs = behavior(behavior_dataset.states,
                                                 behavior_dataset.actions)
                target_log_probs = get_target_logprobs(
                    behavior_dataset.states, behavior_dataset.actions)
                offset = 0.0
                rewards = behavior_dataset.rewards
                if 'dr' in FLAGS.algo:
                    # Doubly-robust is effectively the same as importance-weighting but
                    # transforming rewards at (s,a) to r(s,a) + gamma * V^pi(s') -
                    # Q^pi(s,a) and adding an offset to each trajectory equal to V^pi(s0).
                    offset = model.estimate_returns(
                        behavior_dataset.initial_states,
                        behavior_dataset.initial_weights, get_target_actions)
                    q_values = (model(behavior_dataset.states,
                                      behavior_dataset.actions) /
                                (1 - discount))
                    n_samples = 10
                    next_actions = [
                        get_target_actions(behavior_dataset.next_states)
                        for _ in range(n_samples)
                    ]
                    next_q_values = sum([
                        model(behavior_dataset.next_states, next_action) /
                        (1 - discount) for next_action in next_actions
                    ]) / n_samples
                    rewards = rewards + discount * next_q_values - q_values

                # Now we compute the self-normalized importance weights.
                # Self-normalization happens over trajectories per-step, so we
                # restructure the dataset as [num_trajectories, num_steps].
                num_trajectories = len(behavior_dataset.initial_states)
                max_trajectory_length = np.max(behavior_dataset.steps) + 1
                trajectory_weights = behavior_dataset.initial_weights
                trajectory_starts = np.where(
                    np.equal(behavior_dataset.steps, 0))[0]

                batched_rewards = np.zeros(
                    [num_trajectories, max_trajectory_length])
                batched_masks = np.zeros(
                    [num_trajectories, max_trajectory_length])
                batched_log_probs = np.zeros(
                    [num_trajectories, max_trajectory_length])

                for traj_idx, traj_start in enumerate(trajectory_starts):
                    traj_end = (trajectory_starts[traj_idx + 1] if traj_idx +
                                1 < len(trajectory_starts) else len(rewards))
                    traj_length = traj_end - traj_start
                    batched_rewards[
                        traj_idx, :traj_length] = rewards[traj_start:traj_end]
                    batched_masks[traj_idx, :traj_length] = 1.
                    batched_log_probs[traj_idx, :traj_length] = (
                        -behavior_log_probs[traj_start:traj_end] +
                        target_log_probs[traj_start:traj_end])

                batched_weights = (
                    batched_masks *
                    (discount**np.arange(max_trajectory_length))[None, :])

                clipped_log_probs = np.clip(batched_log_probs, -6., 2.)
                cum_log_probs = batched_masks * np.cumsum(clipped_log_probs,
                                                          axis=1)
                cum_log_probs_offset = np.max(cum_log_probs, axis=0)
                cum_probs = np.exp(cum_log_probs -
                                   cum_log_probs_offset[None, :])
                avg_cum_probs = (
                    np.sum(cum_probs * trajectory_weights[:, None], axis=0) /
                    (1e-10 + np.sum(
                        batched_masks * trajectory_weights[:, None], axis=0)))
                norm_cum_probs = cum_probs / (1e-10 + avg_cum_probs[None, :])

                weighted_rewards = batched_weights * batched_rewards * norm_cum_probs
                trajectory_values = np.sum(weighted_rewards, axis=1)
                avg_trajectory_value = (
                    (1 - discount) *
                    np.sum(trajectory_values * trajectory_weights) /
                    np.sum(trajectory_weights))
                pred_returns = offset + avg_trajectory_value

            pred_returns = behavior_dataset.unnormalize_rewards(pred_returns)

            tf.summary.scalar('train/pred returns', pred_returns, step=i)
            logging.info('pred returns=%f', pred_returns)

            tf.summary.scalar('train/true minus pred returns',
                              policy_returns - pred_returns,
                              step=i)
            logging.info('true minus pred returns=%f',
                         policy_returns - pred_returns)
Ejemplo n.º 21
0
def main(_):
    tf.enable_v2_behavior()
    tf.random.set_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    random.seed(FLAGS.seed)

    base_env = suite_mujoco.load(FLAGS.env_name)
    if hasattr(base_env, 'max_episode_steps'):
        max_episode_steps = base_env.max_episode_steps
    else:
        logging.info('Unknown max episode steps. Setting to 1000.')
        max_episode_steps = 1000
    env = base_env.gym
    env = wrappers.check_and_normalize_box_actions(env)
    env.seed(FLAGS.seed)

    eval_env = suite_mujoco.load(FLAGS.env_name).gym
    eval_env = wrappers.check_and_normalize_box_actions(eval_env)
    eval_env.seed(FLAGS.seed + 1)

    hparam_str_dict = dict(algo=FLAGS.algo,
                           seed=FLAGS.seed,
                           env=FLAGS.env_name,
                           dqn=FLAGS.use_dqn)
    hparam_str = ','.join([
        '%s=%s' % (k, str(hparam_str_dict[k]))
        for k in sorted(hparam_str_dict.keys())
    ])
    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

    rl_algo = algae.ALGAE(env.observation_space.shape[0],
                          env.action_space.shape[0], [
                              float(env.action_space.low.min()),
                              float(env.action_space.high.max())
                          ],
                          FLAGS.log_interval,
                          critic_lr=FLAGS.critic_lr,
                          actor_lr=FLAGS.actor_lr,
                          use_dqn=FLAGS.use_dqn,
                          use_init_states=FLAGS.use_init_states,
                          algae_alpha=FLAGS.algae_alpha,
                          exponent=FLAGS.f_exponent)

    episode_return = 0
    episode_timesteps = 0
    done = True

    total_timesteps = 0
    previous_time = time.time()

    replay_buffer = utils.ReplayBuffer(obs_shape=env.observation_space.shape,
                                       action_shape=env.action_space.shape,
                                       capacity=FLAGS.max_timesteps * 2,
                                       batch_size=FLAGS.sample_batch_size,
                                       device=device)

    log_dir = os.path.join(FLAGS.save_dir, 'logs')
    log_filename = os.path.join(log_dir, hparam_str)
    if not gfile.isdir(log_dir):
        gfile.mkdir(log_dir)

    eval_returns = []

    with tqdm(total=FLAGS.max_timesteps, desc='') as pbar:
        # Final return is the average of the last 10 measurements.
        final_returns = collections.deque(maxlen=10)
        final_timesteps = 0
        while total_timesteps < FLAGS.max_timesteps:
            _update_pbar_msg(pbar, total_timesteps)
            if done:
                print('episodic return: {}'.format(episode_return))
                if episode_timesteps > 0:
                    current_time = time.time()
                    train_measurements = [
                        ('train/returns', episode_return),
                        ('train/FPS',
                         episode_timesteps / (current_time - previous_time)),
                    ]
                    _write_measurements(summary_writer, train_measurements,
                                        total_timesteps)

                obs = env.reset()
                episode_return = 0
                episode_timesteps = 0
                previous_time = time.time()

            #init_replay_buffer.add_batch(np.array([obs.astype(np.float32)]))

            if total_timesteps < FLAGS.num_random_actions:
                action = env.action_space.sample()
            else:
                action = rl_algo.act(obs, sample=True)

            if total_timesteps >= FLAGS.start_training_timesteps:
                with summary_writer.as_default():
                    target_entropy = (-env.action_space.shape[0]
                                      if FLAGS.target_entropy is None else
                                      FLAGS.target_entropy)
                    for _ in range(FLAGS.num_updates_per_env_step):
                        rl_algo.update(
                            replay_buffer,
                            total_timesteps=total_timesteps,
                            discount=FLAGS.discount,
                            tau=FLAGS.tau,
                            target_entropy=target_entropy,
                            actor_update_freq=FLAGS.actor_update_freq)

            next_obs, reward, done, _ = env.step(action)
            if (max_episode_steps is not None
                    and episode_timesteps + 1 == max_episode_steps):
                done = True

            done_bool = 0 if episode_timesteps + 1 == max_episode_steps else float(
                done)

            replay_buffer.add(obs, action, reward, next_obs, done_bool)

            episode_return += reward
            episode_timesteps += 1
            total_timesteps += 1
            pbar.update(1)

            obs = next_obs

            if total_timesteps % FLAGS.eval_interval == 0:
                logging.info('Performing policy eval.')
                average_returns, evaluation_timesteps = evaluate(
                    eval_env, rl_algo, max_episode_steps=max_episode_steps)

                eval_returns.append(average_returns)
                fin = gfile.GFile(log_filename, 'w')
                np.save(fin, np.array(eval_returns))
                fin.close()

                eval_measurements = [
                    ('eval/average returns', average_returns),
                    ('eval/average episode length', evaluation_timesteps),
                ]
                # TODO(sandrafaust) Make this average of the last N.
                final_returns.append(average_returns)
                final_timesteps = evaluation_timesteps

                _write_measurements(summary_writer, eval_measurements,
                                    total_timesteps)

                logging.info('Eval: ave returns=%f, ave episode length=%f',
                             average_returns, evaluation_timesteps)
    # Final measurement.
    final_measurements = [
        ('final/average returns', sum(final_returns) / len(final_returns)),
        ('final/average episode length', final_timesteps),
    ]
    _write_measurements(summary_writer, final_measurements, total_timesteps)
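
The `evaluate` helper called in the loop above is not part of this excerpt; a minimal version consistent with how it is called (a gym-style environment and an agent exposing `act`) might look like this hypothetical sketch:

# Hypothetical evaluate() consistent with the call site above; the original
# implementation may differ.
def evaluate(env, algo, max_episode_steps=1000, n_episodes=10):
    total_return = 0.0
    total_steps = 0
    for _ in range(n_episodes):
        obs = env.reset()
        done = False
        steps = 0
        while not done and steps < max_episode_steps:
            action = algo.act(obs, sample=False)
            obs, reward, done, _ = env.step(action)
            total_return += reward
            steps += 1
        total_steps += steps
    return total_return / n_episodes, total_steps / n_episodes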
Ejemplo n.º 22
0
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        # Training params
        initial_collect_steps=10000,
        num_iterations=3200000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Agent params
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        gamma=0.99,
        target_update_tau=0.005,
        target_update_period=1,
        reward_scale_factor=0.1,
        # Replay params
        reverb_port=None,
        replay_capacity=1000000,
        # Others
        # Defaults to not checkpointing saved policy. If you wish to enable this,
        # please note the caveat explained in README.md.
        policy_save_interval=-1,
        eval_interval=10000,
        eval_episodes=30,
        debug_summaries=False,
        summarize_grads_and_vars=False):
    """Trains and evaluates SAC."""
    logging.info('Training SAC on: %s', env_name)
    collect_env = suite_mujoco.load(env_name)
    eval_env = suite_mujoco.load(env_name)

    observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
        spec_utils.get_tensor_specs(collect_env))

    train_step = train_utils.create_train_step()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_tensor_spec,
        action_tensor_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=tanh_normal_projection_network.
        TanhNormalProjectionNetwork)
    critic_net = critic_network.CriticNetwork(
        (observation_tensor_spec, action_tensor_spec),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
        kernel_initializer='glorot_uniform',
        last_kernel_initializer='glorot_uniform')

    agent = sac_agent.SacAgent(
        time_step_tensor_spec,
        action_tensor_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=None,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step)
    agent.initialize()

    table_name = 'uniform_table'
    table = reverb.Table(table_name,
                         max_size=replay_capacity,
                         sampler=reverb.selectors.Uniform(),
                         remover=reverb.selectors.Fifo(),
                         rate_limiter=reverb.rate_limiters.MinSize(1))

    reverb_server = reverb.Server([table], port=reverb_port)
    reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
        agent.collect_data_spec,
        sequence_length=2,
        table_name=table_name,
        local_server=reverb_server)
    rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
        reverb_replay.py_client,
        table_name,
        sequence_length=2,
        stride_length=1)

    dataset = reverb_replay.as_dataset(sample_batch_size=batch_size,
                                       num_steps=2).prefetch(50)
    experience_dataset_fn = lambda: dataset

    saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
    env_step_metric = py_metrics.EnvironmentSteps()
    learning_triggers = [
        triggers.PolicySavedModelTrigger(
            saved_model_dir,
            agent,
            train_step,
            interval=policy_save_interval,
            metadata_metrics={triggers.ENV_STEP_METADATA_KEY:
                              env_step_metric}),
        triggers.StepPerSecondLogTrigger(train_step, interval=1000),
    ]

    agent_learner = learner.Learner(root_dir,
                                    train_step,
                                    agent,
                                    experience_dataset_fn,
                                    triggers=learning_triggers)

    random_policy = random_py_policy.RandomPyPolicy(
        collect_env.time_step_spec(), collect_env.action_spec())
    initial_collect_actor = actor.Actor(collect_env,
                                        random_policy,
                                        train_step,
                                        steps_per_run=initial_collect_steps,
                                        observers=[rb_observer])
    logging.info('Doing initial collect.')
    initial_collect_actor.run()

    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_collect_policy,
                                                        use_tf_function=True)

    collect_actor = actor.Actor(collect_env,
                                collect_policy,
                                train_step,
                                steps_per_run=1,
                                metrics=actor.collect_metrics(10),
                                summary_dir=os.path.join(
                                    root_dir, learner.TRAIN_DIR),
                                observers=[rb_observer, env_step_metric])

    tf_greedy_policy = greedy_policy.GreedyPolicy(agent.policy)
    eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
        tf_greedy_policy, use_tf_function=True)

    eval_actor = actor.Actor(
        eval_env,
        eval_greedy_policy,
        train_step,
        episodes_per_run=eval_episodes,
        metrics=actor.eval_metrics(eval_episodes),
        summary_dir=os.path.join(root_dir, 'eval'),
    )

    if eval_interval:
        logging.info('Evaluating.')
        eval_actor.run_and_log()

    logging.info('Training.')
    for _ in range(num_iterations):
        collect_actor.run()
        agent_learner.run(iterations=1)

        if eval_interval and agent_learner.train_step_numpy % eval_interval == 0:
            logging.info('Evaluating.')
            eval_actor.run_and_log()

    rb_observer.close()
    reverb_server.stop()
Ejemplo n.º 23
0
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    num_iterations=1000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Params for collect
    initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for SAC."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))
    eval_tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=normal_projection_net)
    critic_net = critic_network.CriticNetwork(
        (observation_spec, action_spec),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers)

    tf_agent = sac_agent.SacAgent(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()

    # Make the replay buffer.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=1,
        max_length=replay_buffer_capacity)
    replay_observer = [replay_buffer.add_batch]

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()),
        tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()),
    ]

    eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    collect_policy = tf_agent.collect_policy

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=replay_observer,
        num_steps=initial_collect_steps)

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration)

    if use_tf_functions:
      initial_collect_driver.run = common.function(initial_collect_driver.run)
      collect_driver.run = common.function(collect_driver.run)
      tf_agent.train = common.function(tf_agent.train)

    # Collect initial replay data.
    logging.info(
        'Initializing replay buffer by collecting experience for %d steps with '
        'a random policy.', initial_collect_steps)
    initial_collect_driver.run()

    results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    if eval_metrics_callback is not None:
      eval_metrics_callback(results, global_step.numpy())
    metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=2).prefetch(3)
    iterator = iter(dataset)

    for _ in range(num_iterations):
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      for _ in range(train_steps_per_iteration):
        experience, _ = next(iterator)
        train_loss = tf_agent.train(experience)
      time_acc += time.time() - start_time

      if global_step.numpy() % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step.numpy(),
                     train_loss.loss)
        steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec, step=global_step)
        timed_at_step = global_step.numpy()
        time_acc = 0

      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=train_metrics[:2])

      if global_step.numpy() % eval_interval == 0:
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

      global_step_val = global_step.numpy()
      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)

      if global_step_val % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=global_step_val)
    return train_loss
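
A typical invocation sketch for the function above; argument values are illustrative:

# Sketch: run a short SAC training job and keep the final training loss.
final_loss = train_eval(
    root_dir='/tmp/sac_halfcheetah',
    env_name='HalfCheetah-v2',
    num_iterations=100000,
    eval_interval=5000,
    log_interval=1000)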