Example #1
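# Online variant: collect transitions from a MuJoCo task loaded through
# tf_agents' suite_mujoco and train ALGAE from a uniform replay buffer.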
def main(_):
    tf.enable_v2_behavior()
    tf.random.set_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    random.seed(FLAGS.seed)

    base_env = suite_mujoco.load(FLAGS.env_name)
    if hasattr(base_env, 'max_episode_steps'):
        max_episode_steps = base_env.max_episode_steps
    else:
        logging.info('Unknown max episode steps. Setting to 1000.')
        max_episode_steps = 1000
    env = base_env.gym
    env = wrappers.check_and_normalize_box_actions(env)
    env.seed(FLAGS.seed)

    eval_env = suite_mujoco.load(FLAGS.env_name).gym
    eval_env = wrappers.check_and_normalize_box_actions(eval_env)
    eval_env.seed(FLAGS.seed + 1)

    spec = (
        tensor_spec.TensorSpec([env.observation_space.shape[0]], tf.float32,
                               'observation'),
        tensor_spec.TensorSpec([env.action_space.shape[0]], tf.float32,
                               'action'),
        tensor_spec.TensorSpec([env.observation_space.shape[0]], tf.float32,
                               'next_observation'),
        tensor_spec.TensorSpec([1], tf.float32, 'reward'),
        tensor_spec.TensorSpec([1], tf.float32, 'mask'),
    )
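    # A separate spec and buffer hold initial (reset) observations, which
    # ALGAE can use via the use_init_states flag.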
    init_spec = tensor_spec.TensorSpec([env.observation_space.shape[0]],
                                       tf.float32, 'observation')

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        spec, batch_size=1, max_length=FLAGS.max_timesteps)
    init_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        init_spec, batch_size=1, max_length=FLAGS.max_timesteps)

    hparam_str_dict = dict(seed=FLAGS.seed, env=FLAGS.env_name)
    hparam_str = ','.join([
        '%s=%s' % (k, str(hparam_str_dict[k]))
        for k in sorted(hparam_str_dict.keys())
    ])
    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

    rl_algo = algae.ALGAE(env.observation_space.shape[0],
                          env.action_space.shape[0],
                          FLAGS.log_interval,
                          critic_lr=FLAGS.critic_lr,
                          actor_lr=FLAGS.actor_lr,
                          use_dqn=FLAGS.use_dqn,
                          use_init_states=FLAGS.use_init_states,
                          algae_alpha=FLAGS.algae_alpha,
                          exponent=FLAGS.f_exponent)

    episode_return = 0
    episode_timesteps = 0
    done = True

    total_timesteps = 0
    previous_time = time.time()

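    # Training batches are sampled by iterating the replay buffers as datasets.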
    replay_buffer_iter = iter(
        replay_buffer.as_dataset(sample_batch_size=FLAGS.sample_batch_size))
    init_replay_buffer_iter = iter(
        init_replay_buffer.as_dataset(
            sample_batch_size=FLAGS.sample_batch_size))

    log_dir = os.path.join(FLAGS.save_dir, 'logs')
    log_filename = os.path.join(log_dir, hparam_str)
    if not gfile.isdir(log_dir):
        gfile.mkdir(log_dir)

    eval_returns = []

    with tqdm(total=FLAGS.max_timesteps, desc='') as pbar:
        # Final return is the average of the last 10 measurements.
        final_returns = collections.deque(maxlen=10)
        final_timesteps = 0
        while total_timesteps < FLAGS.max_timesteps:
            _update_pbar_msg(pbar, total_timesteps)
            if done:

                if episode_timesteps > 0:
                    current_time = time.time()

                    train_measurements = [
                        ('train/returns', episode_return),
                        ('train/FPS',
                         episode_timesteps / (current_time - previous_time)),
                    ]
                    _write_measurements(summary_writer, train_measurements,
                                        total_timesteps)
                obs = env.reset()
                episode_return = 0
                episode_timesteps = 0
                previous_time = time.time()

                # Store the reset observation as an initial state.
                init_replay_buffer.add_batch(np.array([obs.astype(np.float32)]))

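            # Act randomly for the first num_random_actions steps, then sample
            # from the learned actor.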
            if total_timesteps < FLAGS.num_random_actions:
                action = env.action_space.sample()
            else:
                _, action, _ = rl_algo.actor(np.array([obs]))
                action = action[0].numpy()

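            # After start_training_timesteps, take num_updates_per_env_step
            # gradient updates per environment step.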
            if total_timesteps >= FLAGS.start_training_timesteps:
                with summary_writer.as_default():
                    target_entropy = (-env.action_space.shape[0]
                                      if FLAGS.target_entropy is None else
                                      FLAGS.target_entropy)
                    for _ in range(FLAGS.num_updates_per_env_step):
                        rl_algo.train(
                            replay_buffer_iter,
                            init_replay_buffer_iter,
                            discount=FLAGS.discount,
                            tau=FLAGS.tau,
                            target_entropy=target_entropy,
                            actor_update_freq=FLAGS.actor_update_freq)

            next_obs, reward, done, _ = env.step(action)
            if (max_episode_steps is not None
                    and episode_timesteps + 1 == max_episode_steps):
                done = True

            # mask = 1.0 marks a non-terminal transition (including time-limit
            # truncation); mask = 0.0 marks a true environment termination.
            if not done or episode_timesteps + 1 == max_episode_steps:
                mask = 1.0
            else:
                mask = 0.0

            replay_buffer.add_batch((np.array([obs.astype(np.float32)]),
                                     np.array([action.astype(np.float32)]),
                                     np.array([next_obs.astype(np.float32)]),
                                     np.array([[reward]]).astype(np.float32),
                                     np.array([[mask]]).astype(np.float32)))

            episode_return += reward
            episode_timesteps += 1
            total_timesteps += 1
            pbar.update(1)

            obs = next_obs

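            # Periodic evaluation on the held-out eval env; eval returns are
            # also persisted to disk via np.save.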
            if total_timesteps % FLAGS.eval_interval == 0:
                logging.info('Performing policy eval.')
                average_returns, evaluation_timesteps = rl_algo.evaluate(
                    eval_env, max_episode_steps=max_episode_steps)

                eval_returns.append(average_returns)
                with gfile.GFile(log_filename, 'w') as f:
                    np.save(f, np.array(eval_returns))

                eval_measurements = [
                    ('eval/average returns', average_returns),
                    ('eval/average episode length', evaluation_timesteps),
                ]
                # TODO(sandrafaust) Make this average of the last N.
                final_returns.append(average_returns)
                final_timesteps = evaluation_timesteps

                _write_measurements(summary_writer, eval_measurements,
                                    total_timesteps)

                logging.info('Eval: ave returns=%f, ave episode length=%f',
                             average_returns, evaluation_timesteps)
        # Final measurement.
        final_measurements = [
            ('final/average returns', sum(final_returns) / len(final_returns)),
            ('final/average episode length', final_timesteps),
        ]
        _write_measurements(summary_writer, final_measurements,
                            total_timesteps)
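
Both examples reference absl command-line flags (FLAGS.env_name, FLAGS.seed, FLAGS.max_timesteps, ...) and the usual absl entry point, neither of which appears in the listings. The sketch below shows that assumed scaffolding for a few of the flags; the flag defaults and the trivial main body are illustrative assumptions, not taken from the code above.

# Minimal sketch of the assumed absl scaffolding (defaults are illustrative).
from absl import app, flags

flags.DEFINE_string('env_name', 'HalfCheetah-v2', 'Environment name.')
flags.DEFINE_integer('seed', 42, 'Random seed.')
flags.DEFINE_string('save_dir', '/tmp/algae', 'Directory for logs and summaries.')
flags.DEFINE_integer('max_timesteps', 1000000, 'Total environment steps.')
flags.DEFINE_integer('sample_batch_size', 256, 'Replay sample batch size.')
flags.DEFINE_integer('eval_interval', 10000, 'Steps between policy evaluations.')
FLAGS = flags.FLAGS


def main(_):
    # Placeholder body; the real training loops are shown in the examples.
    print('Would train on %s with seed %d' % (FLAGS.env_name, FLAGS.seed))


if __name__ == '__main__':
    app.run(main)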
Example #2
def main(_):
    tf.enable_v2_behavior()
    tf.random.set_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    random.seed(FLAGS.seed)

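    # Each run writes into a fresh uuid-named output directory and records its
    # parameters in params.json.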
    final_output_dir = os.path.join(FLAGS.save_dir, str(uuid.uuid4()))
    os.makedirs(final_output_dir, exist_ok=True)
    with open(os.path.join(final_output_dir, 'params.json'),
              'w') as params_file:
        json.dump({
            'env_name': FLAGS.env_name,
            'seed': FLAGS.seed,
        }, params_file)

    # Offline variant: load the task directly through gym.make (D4RL-style)
    # rather than suite_mujoco.load as in Example #1.
    base_env = gym.make(FLAGS.env_name)
    if hasattr(base_env, '_max_episode_steps'):
        max_episode_steps = base_env._max_episode_steps
    else:
        logging.info('Unknown max episode steps. Setting to 1000.')
        max_episode_steps = 1000
    env = base_env
    env = wrappers.check_and_normalize_box_actions(env)
    env.seed(FLAGS.seed)

    eval_env = gym.make(FLAGS.env_name)
    eval_env = wrappers.check_and_normalize_box_actions(eval_env)
    eval_env.seed(FLAGS.seed + 1)

    spec = (
        tensor_spec.TensorSpec([env.observation_space.shape[0]], tf.float32,
                               'observation'),
        tensor_spec.TensorSpec([env.action_space.shape[0]], tf.float32,
                               'action'),
        tensor_spec.TensorSpec([env.observation_space.shape[0]], tf.float32,
                               'next_observation'),
        tensor_spec.TensorSpec([1], tf.float32, 'reward'),
        tensor_spec.TensorSpec([1], tf.float32, 'mask'),
    )
    init_spec = tensor_spec.TensorSpec([env.observation_space.shape[0]],
                                       tf.float32, 'observation')

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        spec, batch_size=1, max_length=FLAGS.max_timesteps)
    init_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        init_spec, batch_size=1, max_length=FLAGS.max_timesteps)

    # Load the offline dataset (D4RL-style: the environment exposes
    # get_dataset()) and slice it into transition arrays.
    dataset = env.get_dataset()
    N = dataset['rewards'].shape[0]
    logging.info('Loading from buffer: %d items loaded.', N)
    obs = dataset['observations'][:N - 1].astype(np.float32)
    next_obs = dataset['observations'][1:].astype(np.float32)
    action = dataset['actions'][:N - 1].astype(np.float32)
    reward = np.expand_dims(dataset['rewards'][:N - 1],
                            axis=-1).astype(np.float32)
    # Match Example #1's convention: mask is 1.0 for non-terminal transitions
    # and 0.0 at true terminations, so invert the dataset's terminal flags.
    mask = 1.0 - np.expand_dims(dataset['terminals'][:N - 1],
                                axis=-1).astype(np.float32)

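    # Re-chunk the flat dataset into consecutive segments of length
    # max_episode_steps - 1 and add each segment to the replay buffer.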
    episode_steps = 0
    batch_obs = []
    batch_actions = []
    batch_next_obs = []
    batch_reward = []
    batch_mask = []
    for i in tqdm(range(0, N - 1)):
        batch_obs.append(obs[i])
        batch_actions.append(action[i])
        batch_next_obs.append(next_obs[i])
        batch_reward.append(reward[i])
        batch_mask.append(mask[i])
        episode_steps += 1

        if episode_steps == max_episode_steps - 1:
            data = (batch_obs, batch_actions, batch_next_obs, batch_reward,
                    batch_mask)
            data = tuple([np.array(_d) for _d in data])
            replay_buffer.add_batch(data)
            episode_steps = 0
            batch_obs = []
            batch_actions = []
            batch_next_obs = []
            batch_reward = []
            batch_mask = []

    hparam_str = 'seed_%d_env_%s' % (FLAGS.seed, FLAGS.env_name)
    summary_writer = tf.summary.create_file_writer(
        os.path.join(final_output_dir, 'tb', hparam_str))

    rl_algo = algae.ALGAE(env.observation_space.shape[0],
                          env.action_space.shape[0],
                          FLAGS.log_interval,
                          critic_lr=FLAGS.critic_lr,
                          actor_lr=FLAGS.actor_lr,
                          use_dqn=FLAGS.use_dqn,
                          use_init_states=FLAGS.use_init_states,
                          algae_alpha=FLAGS.algae_alpha,
                          exponent=FLAGS.f_exponent)

    episode_return = 0
    episode_timesteps = 0
    done = True

    total_timesteps = 0
    previous_time = time.time()

    replay_buffer_iter = iter(
        replay_buffer.as_dataset(sample_batch_size=FLAGS.sample_batch_size))
    init_replay_buffer_iter = iter(
        init_replay_buffer.as_dataset(
            sample_batch_size=FLAGS.sample_batch_size))

    log_dir = os.path.join(final_output_dir, 'logs')
    log_filename = os.path.join(log_dir, hparam_str + '_results.npy')
    if not gfile.isdir(log_dir):
        gfile.mkdir(log_dir)

    eval_returns = []

    with tqdm(total=FLAGS.max_timesteps,
              desc='',
              mininterval=2.0,
              miniters=100) as pbar:
        # Final return is the average of the last 10 measurements.
        final_returns = collections.deque(maxlen=10)
        final_timesteps = 0
        while total_timesteps < FLAGS.max_timesteps:
            _update_pbar_msg(pbar, total_timesteps)
            if done:

                if episode_timesteps > 0:
                    current_time = time.time()

                    train_measurements = [
                        ('train/returns', episode_return),
                        ('train/FPS',
                         episode_timesteps / (current_time - previous_time)),
                    ]
                    _write_measurements(summary_writer, train_measurements,
                                        total_timesteps)
                obs = env.reset()
                episode_return = 0
                episode_timesteps = 0
                previous_time = time.time()

                init_replay_buffer.add_batch(np.array([obs.astype(np.float32)]))

            if total_timesteps < FLAGS.num_random_actions:
                action = env.action_space.sample()
            else:
                _, action, _ = rl_algo.actor(np.array([obs]))
                action = action[0].numpy()

            if total_timesteps >= FLAGS.start_training_timesteps:
                with summary_writer.as_default():
                    target_entropy = (-env.action_space.shape[0]
                                      if FLAGS.target_entropy is None else
                                      FLAGS.target_entropy)
                    for _ in range(FLAGS.num_updates_per_env_step):
                        rl_algo.train(
                            replay_buffer_iter,
                            init_replay_buffer_iter,
                            discount=FLAGS.discount,
                            tau=FLAGS.tau,
                            target_entropy=target_entropy,
                            actor_update_freq=FLAGS.actor_update_freq)

            next_obs, reward, done, _ = env.step(action)
            if (max_episode_steps is not None
                    and episode_timesteps + 1 == max_episode_steps):
                done = True

            if not done or episode_timesteps + 1 == max_episode_steps:
                mask = 1.0
            else:
                mask = 0.0
            # Offline variant: transitions are not added to the replay buffer
            # online; it was already filled from the offline dataset above.
            # replay_buffer.add_batch((np.array([obs.astype(np.float32)]),
            #                          np.array([action.astype(np.float32)]),
            #                          np.array([next_obs.astype(np.float32)]),
            #                          np.array([[reward]]).astype(np.float32),
            #                          np.array([[mask]]).astype(np.float32)))

            episode_return += reward
            episode_timesteps += 1
            total_timesteps += 1
            pbar.update(1)

            obs = next_obs

            if total_timesteps % FLAGS.eval_interval == 0:
                logging.info('Performing policy eval.')
                average_returns, evaluation_timesteps = rl_algo.evaluate(
                    eval_env, max_episode_steps=max_episode_steps)

                eval_returns.append(average_returns)
                with gfile.GFile(log_filename, 'w') as f:
                    np.save(f, np.array(eval_returns))

                eval_measurements = [
                    ('eval/average returns', average_returns),
                    ('eval/average episode length', evaluation_timesteps),
                ]
                # TODO(sandrafaust) Make this average of the last N.
                final_returns.append(average_returns)
                final_timesteps = evaluation_timesteps

                _write_measurements(summary_writer, eval_measurements,
                                    total_timesteps)

                logging.info('Eval: ave returns=%f, ave episode length=%f',
                             average_returns, evaluation_timesteps)
        # Final measurement.
        final_measurements = [
            ('final/average returns', sum(final_returns) / len(final_returns)),
            ('final/average episode length', final_timesteps),
        ]
        _write_measurements(summary_writer, final_measurements,
                            total_timesteps)
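
Example #2 relies on env.get_dataset(), which is the D4RL convention for offline-RL datasets registered through gym. Below is a minimal sketch of that assumption, slicing such a dataset into (observation, action, next_observation, reward, mask) arrays with Example #1's mask convention; the task name is illustrative and D4RL must be installed.

# Minimal D4RL sketch (assumes the d4rl package and MuJoCo are installed).
import gym
import numpy as np
import d4rl  # noqa: F401  # importing d4rl registers the offline environments

env = gym.make('halfcheetah-medium-v0')  # illustrative D4RL task name
dataset = env.get_dataset()

obs = dataset['observations'][:-1].astype(np.float32)
next_obs = dataset['observations'][1:].astype(np.float32)
actions = dataset['actions'][:-1].astype(np.float32)
rewards = dataset['rewards'][:-1, None].astype(np.float32)
# mask = 1.0 for non-terminal transitions, matching Example #1's convention.
masks = (1.0 - dataset['terminals'][:-1, None]).astype(np.float32)

print('transitions:', obs.shape[0], 'obs dim:', obs.shape[1])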