def main():
    FLAGS.set_seed()
    FLAGS.freeze()

    env = create_env(FLAGS.env.id,
                     seed=FLAGS.seed,
                     log_dir=FLAGS.log_dir,
                     absorbing_state=FLAGS.GAIL.learn_absorbing,
                     rescale_action=FLAGS.env.rescale_action)
    dim_state = env.observation_space.shape[0]
    dim_action = env.action_space.shape[0]

    normalizers = Normalizers(dim_action=dim_action, dim_state=dim_state)
    policy = GaussianMLPPolicy(dim_state,
                               dim_action,
                               FLAGS.TRPO.policy_hidden_sizes,
                               normalizer=normalizers.state)
    expert_actor = Actor(dim_state, dim_action, FLAGS.SAC.actor_hidden_sizes)
    tf.get_default_session().run(tf.global_variables_initializer())

    loader = nn.ModuleDict({'actor': expert_actor})
    actor_load = f'dataset/sac/{FLAGS.env.id}/policy.npy'
    loader.load_state_dict(np.load(actor_load, allow_pickle=True)[()])
    logger.warning('Load expert policy from %s' % actor_load)

    loader = nn.ModuleDict({'policy': policy})
    # policy_load = 'benchmarks/discounted-policies/bc/bc-Hopper-v2-100-2020-05-16-18-39-51/final.npy'
    policy_load = 'benchmarks/discounted-policies/gail_nn/gail-Hopper-v2-100-2020-05-17-00-50-42/final.npy'
    loader.load_state_dict(np.load(policy_load, allow_pickle=True)[()])
    logger.warning('Load policy from %s' % policy_load)

    for i in range(10):
        state = env.reset()
        return_ = 0.
        for t in range(env.max_episode_steps):
            env.render()
            action = expert_actor.get_actions(state[None],
                                              fetch='actions_mean')[0]

            next_state, reward, done, info = env.step(action)
            return_ += reward
            if done:
                break
            state = next_state
        print(return_)
    time.sleep(2)
    for i in range(10):
        state = env.reset()
        return_ = 0.
        for t in range(env.max_episode_steps):
            env.render()
            action = policy.get_actions(state[None], fetch='actions_mean')[0]

            next_state, reward, done, info = env.step(action)
            return_ += reward
            if done:
                break
            state = next_state
        print(return_)
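
# Sketch: later examples call an `evaluate(policy, env, gamma=...)` helper that is not
# shown in this listing. A minimal version of such a discounted-return evaluation,
# assuming the policy exposes `get_actions(state[None], fetch='actions_mean')` and the
# env follows the gym reset/step API as used above. Illustration only, not the
# repository's actual implementation.
import numpy as np


def evaluate_sketch(policy, env, n_episodes=10, gamma=1.0):
    returns, lengths = [], []
    for _ in range(n_episodes):
        state = env.reset()
        return_, discount, length = 0., 1., 0
        for _ in range(env.max_episode_steps):
            action = policy.get_actions(state[None], fetch='actions_mean')[0]
            next_state, reward, done, _ = env.step(action)
            return_ += discount * reward
            discount *= gamma
            length += 1
            if done:
                break
            state = next_state
        returns.append(return_)
        lengths.append(length)
    return np.array(returns), np.array(lengths)
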
def main():
    FLAGS.set_seed()
    FLAGS.freeze()

    collect_mb = FLAGS.env.env_type == 'mb'
    if collect_mb:
        env_id = 'MB' + FLAGS.env.id
        logger.warning('Collect dataset for imitating environments')
    else:
        env_id = FLAGS.env.id
        logger.warning('Collect dataset for imitating policies')
    env = create_env(env_id,
                     FLAGS.seed,
                     FLAGS.log_dir,
                     rescale_action=FLAGS.env.rescale_action)
    dim_state = env.observation_space.shape[0]
    dim_action = env.action_space.shape[0]

    actor = Actor(dim_state,
                  dim_action,
                  hidden_sizes=FLAGS.SAC.actor_hidden_sizes)

    tf.get_default_session().run(tf.global_variables_initializer())

    loader = nn.ModuleDict({'actor': actor})
    loader.load_state_dict(
        np.load(FLAGS.ckpt.policy_load, allow_pickle=True)[()])
    logger.info('Load policy from %s' % FLAGS.ckpt.policy_load)

    state_traj, action_traj, next_state_traj, reward_traj, len_traj = [], [], [], [], []
    returns = []
    while len(state_traj) < 50:
        states = np.zeros([env.max_episode_steps, dim_state], dtype=np.float32)
        actions = np.zeros([env.max_episode_steps, dim_action],
                           dtype=np.float32)
        next_states = np.zeros([env.max_episode_steps, dim_state],
                               dtype=np.float32)
        rewards = np.zeros([env.max_episode_steps], dtype=np.float32)
        state = env.reset()
        done = False
        t = 0
        while not done:
            action = actor.get_actions(state[None], fetch='actions_mean')[0]
            next_state, reward, done, info = env.step(action)

            states[t] = state
            actions[t] = action
            rewards[t] = reward
            next_states[t] = next_state
            t += 1
            if done:
                break
            state = next_state
        # keep only near-full-length, non-negative-return trajectories
        if t < 700 or np.sum(rewards) < 0:
            continue
        state_traj.append(states)
        action_traj.append(actions)
        next_state_traj.append(next_states)
        reward_traj.append(rewards)
        len_traj.append(t)

        returns.append(np.sum(rewards))
        logger.info('# %d: collect a trajectory return = %.4f length = %d',
                    len(state_traj), np.sum(rewards), t)

    state_traj = np.array(state_traj)
    action_traj = np.array(action_traj)
    next_state_traj = np.array(next_state_traj)
    reward_traj = np.array(reward_traj)
    len_traj = np.array(len_traj)
    assert len(state_traj.shape) == len(action_traj.shape) == 3
    assert len(reward_traj.shape) == 2 and len(len_traj.shape) == 1

    dataset = {
        'a_B_T_Da': action_traj,
        'len_B': len_traj,
        'obs_B_T_Do': state_traj,
        'r_B_T': reward_traj
    }
    if collect_mb:
        dataset['next_obs_B_T_Do'] = next_state_traj
    logger.info('Expert avg return = %.4f avg length = %d', np.mean(returns),
                np.mean(len_traj))

    if collect_mb:
        root_dir = 'dataset/mb2'
    else:
        root_dir = 'dataset/sac'

    save_dir = f'{root_dir}/{FLAGS.env.id}'
    os.makedirs(save_dir, exist_ok=True)
    shutil.copy(FLAGS.ckpt.policy_load, os.path.join(save_dir, 'policy.npy'))

    save_path = f'{root_dir}/{FLAGS.env.id}.h5'
    with h5py.File(save_path, 'w') as f:
        f.update(dataset)
    logger.info('save dataset into %s' % save_path)
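
# Sketch: reading back the dataset written above. The keys follow the B(atch) x
# T(ime) x D(im) naming used in `dataset`; 'next_obs_B_T_Do' is present only when
# collect_mb is True. The path below is illustrative -- substitute the actual
# '{root_dir}/{env_id}.h5' produced by this script.
import h5py
import numpy as np

with h5py.File('dataset/sac/Hopper-v2.h5', 'r') as f:  # example path
    obs = f['obs_B_T_Do'][()]       # [n_traj, max_episode_steps, dim_state]
    actions = f['a_B_T_Da'][()]     # [n_traj, max_episode_steps, dim_action]
    rewards = f['r_B_T'][()]        # [n_traj, max_episode_steps]
    lengths = f['len_B'][()]        # [n_traj]
returns = np.array([rewards[i, :lengths[i]].sum() for i in range(len(lengths))])
print('loaded %d trajectories, avg return = %.2f' % (len(lengths), returns.mean()))
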
def main():
    FLAGS.set_seed()
    FLAGS.freeze()

    env = create_env(FLAGS.env.id,
                     seed=FLAGS.seed,
                     log_dir=FLAGS.log_dir,
                     absorbing_state=FLAGS.GAIL.learn_absorbing,
                     rescale_action=FLAGS.env.rescale_action)
    env_eval = create_env(FLAGS.env.id,
                          seed=FLAGS.seed + 1000,
                          log_dir=FLAGS.log_dir,
                          absorbing_state=FLAGS.GAIL.learn_absorbing,
                          rescale_action=FLAGS.env.rescale_action)
    dim_state = env.observation_space.shape[0]
    dim_action = env.action_space.shape[0]

    normalizers = Normalizers(dim_action=dim_action, dim_state=dim_state)
    policy = GaussianMLPPolicy(dim_state,
                               dim_action,
                               FLAGS.TRPO.policy_hidden_sizes,
                               normalizer=normalizers.state)
    bc_loss = BehavioralCloningLoss(dim_state,
                                    dim_action,
                                    policy,
                                    lr=float(FLAGS.BC.lr),
                                    train_std=FLAGS.BC.train_std)

    expert_actor = Actor(dim_state, dim_action, FLAGS.SAC.actor_hidden_sizes)
    tf.get_default_session().run(tf.global_variables_initializer())

    loader = nn.ModuleDict({'actor': expert_actor})
    if FLAGS.BC.dagger:
        loader.load_state_dict(
            np.load(FLAGS.ckpt.policy_load, allow_pickle=True)[()])
        logger.warning('Load expert policy from %s' % FLAGS.ckpt.policy_load)
    runner = Runner(env, max_steps=env.max_episode_steps, rescale_action=False)

    subsampling_rate = env.max_episode_steps // FLAGS.GAIL.trajectory_size
    # load expert dataset
    set_random_seed(2020)
    expert_dataset = load_expert_dataset(FLAGS.GAIL.buf_load)
    expert_reward = expert_dataset.get_average_reward()
    logger.info('Expert Reward %f', expert_reward)
    if FLAGS.GAIL.learn_absorbing:
        expert_dataset.add_absorbing_states(env)
    expert_dataset.subsample_trajectories(FLAGS.GAIL.traj_limit)
    logger.info('Original dataset size {}'.format(len(expert_dataset)))
    expert_dataset.subsample_transitions(subsampling_rate)
    logger.info('Subsampled dataset size {}'.format(len(expert_dataset)))
    logger.info('np random: %d random : %d', np.random.randint(1000),
                random.randint(0, 1000))
    expert_batch = expert_dataset.sample(10)
    expert_state = np.stack([t.obs for t in expert_batch])
    expert_action = np.stack([t.action for t in expert_batch])
    logger.info('Sampled obs: %.4f, acs: %.4f', np.mean(expert_state),
                np.mean(expert_action))
    del expert_batch, expert_state, expert_action
    set_random_seed(FLAGS.seed)

    saver = nn.ModuleDict({'policy': policy, 'normalizers': normalizers})
    print(saver)

    batch_size = FLAGS.BC.batch_size
    eval_gamma = 0.999
    for t in range(FLAGS.BC.max_iters):
        if t % FLAGS.BC.eval_freq == 0:
            eval_returns, eval_lengths = evaluate(policy, env_eval)
            eval_returns_discount, eval_lengths_discount = evaluate(
                policy, env_eval, gamma=eval_gamma)
            log_kvs(prefix='Evaluate',
                    kvs=dict(iter=t,
                             episode=dict(returns=np.mean(eval_returns),
                                          lengths=int(np.mean(eval_lengths))),
                             discounted_episode=dict(
                                 returns=np.mean(eval_returns_discount),
                                 lengths=int(np.mean(eval_lengths_discount)))))

        expert_batch = expert_dataset.sample(batch_size)
        expert_state = np.stack([t.obs for t in expert_batch])
        expert_action = np.stack([t.action for t in expert_batch])
        _, loss, grad_norm = bc_loss.get_loss(expert_state,
                                              expert_action,
                                              fetch='train loss grad_norm')

        if FLAGS.BC.dagger and t % FLAGS.BC.collect_freq == 0 and t > 0:
            if t // FLAGS.BC.collect_freq == 1:
                collect_policy = expert_actor
                stochastic = False
                logger.info('Collect samples with expert actor...')
            else:
                collect_policy = policy
                stochastic = True
                logger.info('Collect samples with learned policy...')
            runner.reset()
            data, ep_infos = runner.run(collect_policy,
                                        FLAGS.BC.n_collect_samples, stochastic)
            data.action = expert_actor.get_actions(data.state,
                                                   fetch='actions_mean')
            returns = [info['return'] for info in ep_infos]
            lengths = [info['length'] for info in ep_infos]
            for i in range(len(data)):
                expert_dataset.push_back(data[i].state, data[i].action,
                                         data[i].next_state, data[i].reward,
                                         data[i].mask, data[i].timeout)
            logger.info('Collect %d samples avg return = %.4f avg length = %d',
                        len(data), np.mean(returns), np.mean(lengths))
        if t % 100 == 0:
            mse_loss = policy.get_mse_loss(expert_state, expert_action)
            log_kvs(prefix='BC',
                    kvs=dict(iter=t,
                             loss=loss,
                             grad_norm=grad_norm,
                             mse_loss=mse_loss))

    np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())

    dict_result = dict()
    for gamma in [0.9, 0.99, 0.999, 1.0]:
        eval_returns, eval_lengths = evaluate(policy, env_eval, gamma=gamma)
        dict_result[gamma] = [float(np.mean(eval_returns)), eval_returns]
        logger.info('[%s]: %.4f', gamma, np.mean(eval_returns))

    save_path = os.path.join(FLAGS.log_dir, 'evaluate.yml')
    yaml.dump(dict_result, open(save_path, 'w'), default_flow_style=False)
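
# Sketch: the BehavioralCloningLoss used above is not shown in this listing; for a
# Gaussian policy it typically minimizes the negative log-likelihood of the expert
# actions under the policy's diagonal Gaussian (with the log-std terms held fixed
# when train_std is False). The reference formula, in plain numpy:
import numpy as np


def diag_gaussian_bc_nll(pred_mean, pred_log_std, expert_action):
    # all arguments: [batch_size, dim_action]
    std = np.exp(pred_log_std)
    nll = (0.5 * np.square((expert_action - pred_mean) / std)
           + pred_log_std + 0.5 * np.log(2.0 * np.pi))
    return nll.sum(axis=-1).mean()  # sum over action dims, average over the batch
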
def main():
    FLAGS.set_seed()
    FLAGS.freeze()

    env = create_env(FLAGS.env.id,
                     FLAGS.seed,
                     rescale_action=FLAGS.env.rescale_action)
    dim_state = env.observation_space.shape[0]
    dim_action = env.action_space.shape[0]

    subsampling_rate = env.max_episode_steps // FLAGS.GAIL.trajectory_size
    # load expert dataset
    set_random_seed(2020)
    expert_dataset = load_expert_dataset(FLAGS.GAIL.buf_load)
    expert_state = np.stack([t.obs for t in expert_dataset.buffer()])
    expert_next_state = np.stack([t.next_obs for t in expert_dataset.buffer()])
    expert_done = np.stack([t.done for t in expert_dataset.buffer()])
    np.testing.assert_allclose(
        expert_next_state[:-1] * (1 - expert_done[:-1][:, None]),
        expert_state[1:] * (1 - expert_done[:-1][:, None]))
    del expert_state, expert_next_state, expert_done
    expert_reward = expert_dataset.get_average_reward()
    logger.info('Expert Reward %f', expert_reward)
    if FLAGS.GAIL.learn_absorbing:
        expert_dataset.add_absorbing_states(env)
    eval_batch = expert_dataset.sample(1024)
    eval_state = np.stack([t.obs for t in eval_batch])
    eval_action = np.stack([t.action for t in eval_batch])
    eval_next_state = np.stack([t.next_obs for t in eval_batch])
    logger.info('Sampled obs: %.4f, acs: %.4f', np.mean(eval_state),
                np.mean(eval_action))
    expert_dataset.subsample_trajectories(FLAGS.GAIL.traj_limit)
    logger.info('Original dataset size {}'.format(len(expert_dataset)))
    expert_dataset.subsample_transitions(subsampling_rate)
    logger.info('Subsampled dataset size {}'.format(len(expert_dataset)))
    logger.info('np random: %d random : %d', np.random.randint(1000),
                random.randint(0, 1000))
    set_random_seed(FLAGS.seed)

    # expert actor
    actor = Actor(dim_state,
                  dim_action,
                  hidden_sizes=FLAGS.SAC.actor_hidden_sizes)
    # generator
    normalizers = Normalizers(dim_action=dim_action, dim_state=dim_state)
    policy = GaussianMLPPolicy(dim_state,
                               dim_action,
                               FLAGS.TRPO.policy_hidden_sizes,
                               output_diff=FLAGS.TRPO.output_diff,
                               normalizers=normalizers)
    vfn = MLPVFunction(dim_state, dim_action, FLAGS.TRPO.vf_hidden_sizes,
                       normalizers.state)
    algo = TRPO(vfn=vfn,
                policy=policy,
                dim_state=dim_state,
                dim_action=dim_action,
                **FLAGS.TRPO.algo.as_dict())

    subsampling_rate = env.max_episode_steps // FLAGS.GAIL.trajectory_size
    if FLAGS.GAIL.reward_type == 'nn':
        expert_batch = expert_dataset.buffer()
        expert_state = np.stack([t.obs for t in expert_batch])
        loc = np.mean(expert_state, axis=0, keepdims=True)
        scale = np.std(expert_state, axis=0, keepdims=True)
        del expert_batch, expert_state
        logger.info('loc = {}\nscale={}'.format(loc, scale))
        discriminator = Discriminator(dim_state,
                                      dim_action,
                                      normalizers=normalizers,
                                      subsampling_rate=subsampling_rate,
                                      loc=loc,
                                      scale=scale,
                                      **FLAGS.GAIL.discriminator.as_dict())
    else:
        raise NotImplementedError
    bc_loss = BehavioralCloningLoss(dim_state,
                                    dim_action,
                                    policy,
                                    lr=FLAGS.BC.lr,
                                    train_std=FLAGS.BC.train_std)
    tf.get_default_session().run(tf.global_variables_initializer())

    loader = nn.ModuleDict({'actor': actor})
    loader.load_state_dict(
        np.load(FLAGS.ckpt.policy_load, allow_pickle=True)[()])
    logger.info('Load policy from %s' % FLAGS.ckpt.policy_load)
    saver = nn.ModuleDict({
        'policy': policy,
        'vfn': vfn,
        'normalizers': normalizers,
        'discriminator': discriminator
    })
    print(saver)

    # update normalizers
    expert_state = np.stack([t.obs for t in expert_dataset.buffer()])
    expert_action = np.stack([t.action for t in expert_dataset.buffer()])
    expert_next_state = np.stack([t.next_obs for t in expert_dataset.buffer()])
    normalizers.state.update(expert_state)
    normalizers.action.update(expert_action)
    normalizers.diff.update(expert_next_state - expert_state)
    del expert_state, expert_action, expert_next_state

    eval_gamma = 0.999
    eval_returns, eval_lengths = evaluate_on_true_env(actor,
                                                      env,
                                                      gamma=eval_gamma)
    logger.warning(
        'Test policy true value = %.4f true length = %d (gamma = %f)',
        np.mean(eval_returns), np.mean(eval_lengths), eval_gamma)

    # pretrain
    for n_updates in range(FLAGS.GAIL.pretrain_iters):
        expert_batch = expert_dataset.sample(FLAGS.BC.batch_size)
        expert_state = np.stack([t.obs for t in expert_batch])
        expert_action = np.stack([t.action for t in expert_batch])
        expert_next_state = np.stack([t.next_obs for t in expert_batch])
        _, loss, grad_norm = bc_loss.get_loss(expert_state,
                                              expert_action,
                                              expert_next_state,
                                              fetch='train loss grad_norm')
        if n_updates % 100 == 0:
            mse_loss = policy.get_mse_loss(expert_state, expert_action,
                                           expert_next_state)
            logger.info(
                '[Pretrain] iter = %d grad_norm = %.4f loss = %.4f mse_loss = %.4f',
                n_updates, grad_norm, loss, mse_loss)

    # virtual env
    virtual_env = VirtualEnv(policy,
                             env,
                             n_envs=FLAGS.env.num_env,
                             stochastic_model=True)
    virtual_runner = VirtualRunner(virtual_env,
                                   max_steps=env.max_episode_steps,
                                   gamma=FLAGS.TRPO.gamma,
                                   lambda_=FLAGS.TRPO.lambda_,
                                   rescale_action=False)
    env_eval_stochastic = VirtualEnv(policy,
                                     env,
                                     n_envs=4,
                                     stochastic_model=True)
    env_eval_deterministic = VirtualEnv(policy,
                                        env,
                                        n_envs=4,
                                        stochastic_model=False)

    max_ent_coef = FLAGS.TRPO.algo.ent_coef
    true_return = np.mean(eval_returns)
    for t in range(0, FLAGS.GAIL.total_timesteps,
                   FLAGS.TRPO.rollout_samples * FLAGS.GAIL.g_iters):
        time_st = time.time()
        if t % FLAGS.GAIL.eval_freq == 0:
            eval_returns_stochastic, eval_lengths_stochastic = evaluate_on_virtual_env(
                actor, env_eval_stochastic, gamma=eval_gamma)
            eval_returns_deterministic, eval_lengths_deterministic = evaluate_on_virtual_env(
                actor, env_eval_deterministic, gamma=eval_gamma)
            log_kvs(
                prefix='Evaluate',
                kvs=dict(
                    iter=t,
                    stochastic_episode=dict(
                        returns=np.mean(eval_returns_stochastic),
                        lengths=int(np.mean(eval_lengths_stochastic))),
                    episode=dict(returns=np.mean(eval_returns_deterministic),
                                 lengths=int(
                                     np.mean(eval_lengths_deterministic))),
                    evaluation_error=dict(
                        stochastic_error=true_return -
                        np.mean(eval_returns_stochastic),
                        stochastic_abs=np.abs(
                            true_return - np.mean(eval_returns_stochastic)),
                        stochastic_rel=np.abs(true_return -
                                              np.mean(eval_returns_stochastic))
                        / true_return,
                        deterministic_error=true_return -
                        np.mean(eval_returns_deterministic),
                        deterministic_abs=np.abs(
                            true_return - np.mean(eval_returns_deterministic)),
                        deterministic_rel=np.abs(true_return - np.mean(
                            eval_returns_deterministic)) / true_return)))
        # Generator
        generator_dataset = None
        for n_update in range(FLAGS.GAIL.g_iters):
            data, ep_infos = virtual_runner.run(actor,
                                                FLAGS.TRPO.rollout_samples,
                                                stochastic=False)
            # if FLAGS.TRPO.normalization:
            #     normalizers.state.update(data.state)
            #     normalizers.action.update(data.action)
            #     normalizers.diff.update(data.next_state - data.state)
            if t == 0:
                np.testing.assert_allclose(data.reward,
                                           env.mb_step(data.state, data.action,
                                                       data.next_state)[0],
                                           atol=1e-4,
                                           rtol=1e-4)
            if t == 0 and n_update == 0 and not FLAGS.GAIL.learn_absorbing:
                data_ = data.copy()
                data_ = data_.reshape(
                    [FLAGS.TRPO.rollout_samples // env.n_envs, env.n_envs])
                for e in range(env.n_envs):
                    samples = data_[:, e]
                    masks = 1 - (samples.done | samples.timeout)[...,
                                                                 np.newaxis]
                    masks = masks[:-1]
                    assert np.allclose(samples.state[1:] * masks,
                                       samples.next_state[:-1] * masks)
            t += FLAGS.TRPO.rollout_samples
            data.reward = discriminator.get_reward(data.state, data.action,
                                                   data.next_state)
            advantages, values = virtual_runner.compute_advantage(vfn, data)
            train_info = algo.train(max_ent_coef, data, advantages, values)
            fps = int(FLAGS.TRPO.rollout_samples / (time.time() - time_st))
            train_info['reward'] = np.mean(data.reward)
            train_info['fps'] = fps

            expert_batch = expert_dataset.sample(256)
            expert_state = np.stack([t.obs for t in expert_batch])
            expert_action = np.stack([t.action for t in expert_batch])
            expert_next_state = np.stack([t.next_obs for t in expert_batch])
            train_mse_loss = policy.get_mse_loss(expert_state, expert_action,
                                                 expert_next_state)
            eval_mse_loss = policy.get_mse_loss(eval_state, eval_action,
                                                eval_next_state)
            train_info['mse_loss'] = dict(train=train_mse_loss,
                                          eval=eval_mse_loss)
            log_kvs(prefix='TRPO', kvs=dict(iter=t, **train_info))

            generator_dataset = data

        # Discriminator
        for n_update in range(FLAGS.GAIL.d_iters):
            batch_size = FLAGS.GAIL.d_batch_size
            d_train_infos = dict()
            for generator_subset in generator_dataset.iterator(batch_size):
                expert_batch = expert_dataset.sample(batch_size)
                expert_state = np.stack([t.obs for t in expert_batch])
                expert_action = np.stack([t.action for t in expert_batch])
                expert_next_state = np.stack(
                    [t.next_obs for t in expert_batch])
                expert_mask = None
                train_info = discriminator.train(
                    expert_state,
                    expert_action,
                    expert_next_state,
                    generator_subset.state,
                    generator_subset.action,
                    generator_subset.next_state,
                    expert_mask,
                )
                for k, v in train_info.items():
                    if k not in d_train_infos:
                        d_train_infos[k] = []
                    d_train_infos[k].append(v)
            d_train_infos = {k: np.mean(v) for k, v in d_train_infos.items()}
            if n_update == FLAGS.GAIL.d_iters - 1:
                log_kvs(prefix='Discriminator',
                        kvs=dict(iter=t, **d_train_infos))

        if t % FLAGS.TRPO.save_freq == 0:
            np.save('{}/stage-{}'.format(FLAGS.log_dir, t), saver.state_dict())
            np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())
    np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())

    dict_result = dict()
    for gamma in [0.9, 0.99, 0.999, 1.0]:
        eval_returns, eval_lengths = evaluate_on_virtual_env(
            actor, env_eval_stochastic, gamma=gamma)
        dict_result[gamma] = [float(np.mean(eval_returns)), eval_returns]
        logger.info('[%s]: %.4f', gamma, np.mean(eval_returns))

    save_path = os.path.join(FLAGS.log_dir, 'evaluate.yml')
    yaml.dump(dict_result, open(save_path, 'w'), default_flow_style=False)
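
# Sketch: the surrogate reward returned by discriminator.get_reward(state, action,
# next_state) above is not shown in this listing. A common GAIL choice is
# r = -log(1 - D), where D is the discriminator's probability that a transition came
# from the expert; the variant log D - log(1 - D) is also widely used. For reference:
import numpy as np


def gail_surrogate_reward(logits, eps=1e-8):
    d = 1.0 / (1.0 + np.exp(-logits))   # sigmoid of discriminator logits
    return -np.log(1.0 - d + eps)       # high reward when D judges the sample expert-like
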
def main():
    FLAGS.set_seed()
    FLAGS.freeze()

    env = make_env(FLAGS.env.id,
                   FLAGS.env.env_type,
                   num_env=FLAGS.env.num_env,
                   seed=FLAGS.seed,
                   log_dir=FLAGS.log_dir,
                   rescale_action=FLAGS.env.rescale_action)
    env_eval = make_env(FLAGS.env.id,
                        FLAGS.env.env_type,
                        num_env=4,
                        seed=FLAGS.seed + 1000,
                        log_dir=FLAGS.log_dir,
                        rescale_action=FLAGS.env.rescale_action)
    dim_state = env.observation_space.shape[0]
    dim_action = env.action_space.shape[0]

    normalizers = Normalizers(dim_action=dim_action, dim_state=dim_state)
    policy = GaussianMLPPolicy(dim_state,
                               dim_action,
                               FLAGS.TRPO.policy_hidden_sizes,
                               normalizer=normalizers.state)
    vfn = MLPVFunction(dim_state, FLAGS.TRPO.vf_hidden_sizes,
                       normalizers.state)
    algo = TRPO(vfn=vfn,
                policy=policy,
                dim_state=dim_state,
                dim_action=dim_action,
                **FLAGS.TRPO.algo.as_dict())

    discriminator = Discriminator(dim_state,
                                  dim_action,
                                  normalizers=normalizers,
                                  **FLAGS.GAIL.discriminator.as_dict())

    tf.get_default_session().run(tf.global_variables_initializer())

    # load expert dataset
    if not os.path.exists(FLAGS.GAIL.buf_load):
        raise FileNotFoundError('Expert dataset (%s) does not exist' %
                                FLAGS.GAIL.buf_load)
    expert_dataset = Mujoco_Dset(FLAGS.GAIL.buf_load,
                                 train_fraction=FLAGS.GAIL.train_frac,
                                 traj_limitation=FLAGS.GAIL.traj_limit)

    saver = nn.ModuleDict({
        'policy': policy,
        'vfn': vfn,
        'normalizers': normalizers
    })
    runner = Runner(env,
                    max_steps=env.max_episode_steps,
                    gamma=FLAGS.TRPO.gamma,
                    lambda_=FLAGS.TRPO.lambda_)
    print(saver)

    max_ent_coef = FLAGS.TRPO.algo.ent_coef
    for t in range(0, FLAGS.GAIL.total_timesteps,
                   FLAGS.TRPO.rollout_samples * FLAGS.GAIL.g_iters):
        time_st = time.time()
        if t % FLAGS.GAIL.eval_freq == 0:
            eval_returns, eval_lengths = evaluate(policy, env_eval)
            log_kvs(prefix='Evaluate',
                    kvs=dict(iter=t,
                             episode=dict(returns=np.mean(eval_returns),
                                          lengths=int(np.mean(eval_lengths)))))

        # Generator
        generator_dataset = None
        for n_update in range(FLAGS.GAIL.g_iters):
            data, ep_infos = runner.run(policy, FLAGS.TRPO.rollout_samples)
            if FLAGS.TRPO.normalization:
                normalizers.state.update(data.state)
                normalizers.action.update(data.action)
                normalizers.diff.update(data.next_state - data.state)
            if t == 0 and n_update == 0:
                data_ = data.copy()
                data_ = data_.reshape(
                    [FLAGS.TRPO.rollout_samples // env.n_envs, env.n_envs])
                for e in range(env.n_envs):
                    samples = data_[:, e]
                    masks = 1 - (samples.done | samples.timeout)[...,
                                                                 np.newaxis]
                    masks = masks[:-1]
                    assert np.allclose(samples.state[1:] * masks,
                                       samples.next_state[:-1] * masks)
            t += FLAGS.TRPO.rollout_samples
            data.reward = discriminator.get_reward(data.state, data.action)
            advantages, values = runner.compute_advantage(vfn, data)
            train_info = algo.train(max_ent_coef, data, advantages, values)
            fps = int(FLAGS.TRPO.rollout_samples / (time.time() - time_st))
            train_info['reward'] = np.mean(data.reward)
            train_info['fps'] = fps
            log_kvs(prefix='TRPO', kvs=dict(iter=t, **train_info))

            generator_dataset = data

        # Discriminator
        for n_update in range(FLAGS.GAIL.d_iters):
            batch_size = FLAGS.GAIL.d_batch_size
            d_train_infos = dict()
            for generator_subset in generator_dataset.iterator(batch_size):
                expert_state, expert_action = expert_dataset.get_next_batch(
                    batch_size)
                train_info = discriminator.train(expert_state, expert_action,
                                                 generator_subset.state,
                                                 generator_subset.action)
                for k, v in train_info.items():
                    if k not in d_train_infos:
                        d_train_infos[k] = []
                    d_train_infos[k].append(v)
            d_train_infos = {k: np.mean(v) for k, v in d_train_infos.items()}
            if n_update == FLAGS.GAIL.d_iters - 1:
                log_kvs(prefix='Discriminator',
                        kvs=dict(iter=t, **d_train_infos))

        if t % FLAGS.TRPO.save_freq == 0:
            np.save('{}/stage-{}'.format(FLAGS.log_dir, t), saver.state_dict())
            np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())
    np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())
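
# Sketch: runner.compute_advantage(vfn, data) presumably performs generalized
# advantage estimation, since the Runner above is built with gamma and lambda_.
# A reference single-trajectory GAE(lambda) in numpy (not the repository's code):
import numpy as np


def gae(rewards, values, last_value, dones, gamma, lam):
    # rewards, values, dones: [T]; last_value bootstraps the state after step T-1
    advantages = np.zeros(len(rewards), dtype=np.float64)
    running = 0.0
    for t in reversed(range(len(rewards))):
        next_value = last_value if t == len(rewards) - 1 else values[t + 1]
        nonterminal = 1.0 - dones[t]
        delta = rewards[t] + gamma * next_value * nonterminal - values[t]
        running = delta + gamma * lam * nonterminal * running
        advantages[t] = running
    return advantages, advantages + values  # advantages and value-function targets
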
def main():
    FLAGS.set_seed()
    FLAGS.freeze()

    env = create_env(FLAGS.env.id,
                     seed=FLAGS.seed,
                     rescale_action=FLAGS.env.rescale_action)
    dim_state = env.observation_space.shape[0]
    dim_action = env.action_space.shape[0]

    normalizers = Normalizers(dim_action=dim_action, dim_state=dim_state)
    policy = GaussianMLPPolicy(dim_state,
                               dim_action,
                               FLAGS.TRPO.policy_hidden_sizes,
                               output_diff=FLAGS.TRPO.output_diff,
                               normalizers=normalizers)
    bc_loss = BehavioralCloningLoss(dim_state,
                                    dim_action,
                                    policy,
                                    lr=float(FLAGS.BC.lr),
                                    train_std=FLAGS.BC.train_std)

    actor = Actor(dim_state, dim_action, FLAGS.SAC.actor_hidden_sizes)
    tf.get_default_session().run(tf.global_variables_initializer())

    subsampling_rate = env.max_episode_steps // FLAGS.GAIL.trajectory_size
    # load expert dataset
    set_random_seed(2020)
    expert_dataset = load_expert_dataset(FLAGS.GAIL.buf_load)
    expert_state = np.stack([t.obs for t in expert_dataset.buffer()])
    expert_next_state = np.stack([t.next_obs for t in expert_dataset.buffer()])
    expert_done = np.stack([t.done for t in expert_dataset.buffer()])
    np.testing.assert_allclose(
        expert_next_state[:-1] * (1 - expert_done[:-1][:, None]),
        expert_state[1:] * (1 - expert_done[:-1][:, None]))
    del expert_state, expert_next_state, expert_done
    expert_reward = expert_dataset.get_average_reward()
    logger.info('Expert Reward %f', expert_reward)
    if FLAGS.GAIL.learn_absorbing:
        expert_dataset.add_absorbing_states(env)
    eval_batch = expert_dataset.sample(1024)
    eval_state = np.stack([t.obs for t in eval_batch])
    eval_action = np.stack([t.action for t in eval_batch])
    eval_next_state = np.stack([t.next_obs for t in eval_batch])
    logger.info('Sampled obs: %.4f, acs: %.4f', np.mean(eval_state),
                np.mean(eval_action))
    expert_dataset.subsample_trajectories(FLAGS.GAIL.traj_limit)
    logger.info('Original dataset size {}'.format(len(expert_dataset)))
    expert_dataset.subsample_transitions(subsampling_rate)
    logger.info('Subsampled dataset size {}'.format(len(expert_dataset)))
    logger.info('np random: %d random : %d', np.random.randint(1000),
                random.randint(0, 1000))
    set_random_seed(FLAGS.seed)

    loader = nn.ModuleDict({'actor': actor})
    loader.load_state_dict(
        np.load(FLAGS.ckpt.policy_load, allow_pickle=True)[()])
    logger.warning('Load expert policy from %s' % FLAGS.ckpt.policy_load)
    saver = nn.ModuleDict({'policy': policy, 'normalizers': normalizers})
    print(saver)

    # update normalizers
    expert_state = np.stack([t.obs for t in expert_dataset.buffer()])
    expert_action = np.stack([t.action for t in expert_dataset.buffer()])
    expert_next_state = np.stack([t.next_obs for t in expert_dataset.buffer()])
    normalizers.state.update(expert_state)
    normalizers.action.update(expert_action)
    normalizers.diff.update(expert_next_state - expert_state)
    del expert_state, expert_action, expert_next_state

    eval_gamma = 0.999
    eval_returns, eval_lengths = evaluate_on_true_env(actor,
                                                      env,
                                                      gamma=eval_gamma)
    logger.warning(
        'Test policy true value = %.4f true length = %d (gamma = %f)',
        np.mean(eval_returns), np.mean(eval_lengths), eval_gamma)

    # virtual env
    env_eval_stochastic = VirtualEnv(policy,
                                     env,
                                     n_envs=4,
                                     stochastic_model=True)
    env_eval_deterministic = VirtualEnv(policy,
                                        env,
                                        n_envs=4,
                                        stochastic_model=False)

    batch_size = FLAGS.BC.batch_size
    true_return = np.mean(eval_returns)
    for t in range(FLAGS.BC.max_iters):
        if t % FLAGS.BC.eval_freq == 0:
            eval_returns_stochastic, eval_lengths_stochastic = evaluate_on_virtual_env(
                actor, env_eval_stochastic, gamma=eval_gamma)
            eval_returns_deterministic, eval_lengths_deterministic = evaluate_on_virtual_env(
                actor, env_eval_deterministic, gamma=eval_gamma)
            log_kvs(
                prefix='Evaluate',
                kvs=dict(
                    iter=t,
                    stochastic_episode=dict(
                        returns=np.mean(eval_returns_stochastic),
                        lengths=int(np.mean(eval_lengths_stochastic))),
                    episode=dict(returns=np.mean(eval_returns_deterministic),
                                 lengths=int(
                                     np.mean(eval_lengths_deterministic))),
                    evaluation_error=dict(
                        stochastic_error=true_return -
                        np.mean(eval_returns_stochastic),
                        stochastic_abs=np.abs(
                            true_return - np.mean(eval_returns_stochastic)),
                        stochastic_rel=np.abs(true_return -
                                              np.mean(eval_returns_stochastic))
                        / true_return,
                        deterministic_error=true_return -
                        np.mean(eval_returns_deterministic),
                        deterministic_abs=np.abs(
                            true_return - np.mean(eval_returns_deterministic)),
                        deterministic_rel=np.abs(true_return - np.mean(
                            eval_returns_deterministic)) / true_return)))

        expert_batch = expert_dataset.sample(batch_size)
        expert_state = np.stack([t.obs for t in expert_batch])
        expert_action = np.stack([t.action for t in expert_batch])
        expert_next_state = np.stack([t.next_obs for t in expert_batch])
        _, loss, grad_norm = bc_loss.get_loss(expert_state,
                                              expert_action,
                                              expert_next_state,
                                              fetch='train loss grad_norm')

        if t % 100 == 0:
            train_mse_loss = policy.get_mse_loss(expert_state, expert_action,
                                                 expert_next_state)
            eval_mse_loss = policy.get_mse_loss(eval_state, eval_action,
                                                eval_next_state)
            log_kvs(prefix='BC',
                    kvs=dict(iter=t,
                             grad_norm=grad_norm,
                             loss=loss,
                             mse_loss=dict(train=train_mse_loss,
                                           eval=eval_mse_loss)))

    np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())

    dict_result = dict()
    for gamma in [0.9, 0.99, 0.999, 1.0]:
        eval_returns, eval_lengths = evaluate_on_virtual_env(
            actor, env_eval_stochastic, gamma=gamma)
        dict_result[gamma] = [float(np.mean(eval_returns)), eval_returns]
        logger.info('[%s]: %.4f', gamma, np.mean(eval_returns))

    save_path = os.path.join(FLAGS.log_dir, 'evaluate.yml')
    yaml.dump(dict_result, open(save_path, 'w'), default_flow_style=False)
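
# Sketch: with output_diff=True the GaussianMLPPolicy above acts as a dynamics model
# over state differences, which is why normalizers.diff is fitted on
# (next_state - state). A virtual-env transition would then look roughly like the
# following (the real VirtualEnv class is not shown in this listing):
import numpy as np


def virtual_step(predict_diff, state, action, stochastic=True):
    # predict_diff: callable returning (mean, std) of the predicted state difference
    mean, std = predict_diff(state, action)
    noise = np.random.randn(*np.shape(mean)) if stochastic else 0.0
    return state + mean + std * noise   # next state = current state + predicted delta
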
def main():
    FLAGS.set_seed()
    FLAGS.freeze()

    env = make_env(FLAGS.env.id,
                   FLAGS.env.env_type,
                   num_env=FLAGS.env.num_env,
                   seed=FLAGS.seed,
                   log_dir=FLAGS.log_dir)
    state_spec = env.observation_space
    action_spec = env.action_space

    logger.info('[{}]: state_spec:{}, action_spec:{}'.format(
        FLAGS.env.id, state_spec.shape, action_spec.n))

    dtype = gen_dtype(env,
                      'state action next_state mu reward done timeout info')
    buffer = ReplayBuffer(env.n_envs,
                          FLAGS.ACER.n_steps,
                          stacked_frame=FLAGS.env.env_type == 'atari',
                          dtype=dtype,
                          size=FLAGS.ACER.buffer_size)

    if len(state_spec.shape) == 3:
        policy = CNNPolicy(state_spec, action_spec)
    else:
        policy = MLPPolicy(state_spec, action_spec)

    algo = ACER(state_spec,
                action_spec,
                policy,
                lr=FLAGS.ACER.lr,
                lrschedule=FLAGS.ACER.lrschedule,
                total_timesteps=FLAGS.ACER.total_timesteps,
                ent_coef=FLAGS.ACER.ent_coef,
                q_coef=FLAGS.ACER.q_coef,
                trust_region=FLAGS.ACER.trust_region)
    runner = Runner(env,
                    max_steps=env.max_episode_steps,
                    gamma=FLAGS.ACER.gamma)
    saver = nn.ModuleDict({'policy': policy})
    print(saver)

    tf.get_default_session().run(tf.global_variables_initializer())
    algo.update_old_policy(0.)

    n_steps = FLAGS.ACER.n_steps
    n_batches = n_steps * env.n_envs
    n_stages = FLAGS.ACER.total_timesteps // n_batches

    returns = collections.deque(maxlen=40)
    lengths = collections.deque(maxlen=40)
    replay_reward = collections.deque(maxlen=40)
    time_st = time.time()
    for t in range(n_stages):
        data, ep_infos = runner.run(policy, n_steps)
        returns.extend([info['return'] for info in ep_infos])
        lengths.extend([info['length'] for info in ep_infos])

        if t == 0:  # check runner
            indices = np.arange(0, n_batches, env.n_envs)
            for _ in range(env.n_envs):
                samples = data[indices]
                masks = 1 - (samples.done | samples.timeout)
                masks = masks[:-1]
                masks = np.reshape(masks,
                                   [-1] + [1] * len(samples.state.shape[1:]))
                np.testing.assert_allclose(samples.state[1:] * masks,
                                           samples.next_state[:-1] * masks)
                indices += 1

        buffer.store_episode(data)
        if t == 1:  # check buffer
            data_ = buffer.sample(idx=[1 for _ in range(env.n_envs)])
            check_data_equal(data_, data, ('state', 'action', 'next_state',
                                           'mu', 'reward', 'done', 'timeout'))

        # on-policy training
        qret = runner.compute_qret(policy, data)
        train_info = algo.train(data, qret, t * n_batches)
        replay_reward.append(np.mean(data.reward))
        # off-policy training
        if t * n_batches > FLAGS.ACER.replay_start:
            n = np.random.poisson(FLAGS.ACER.replay_ratio)
            for _ in range(n):
                data = buffer.sample()
                qret = runner.compute_qret(policy, data)
                algo.train(data, qret, t * n_batches)
                replay_reward.append(np.mean(data.reward))

        if t * n_batches % FLAGS.ACER.log_interval == 0:
            fps = int(t * n_batches / (time.time() - time_st))
            kvs = dict(iter=t * n_batches,
                       episode=dict(
                           returns=np.mean(returns) if len(returns) > 0 else 0,
                           lengths=np.mean(lengths).astype(np.int32)
                           if len(lengths) > 0 else 0),
                       **train_info,
                       replay_reward=np.mean(replay_reward)
                       if len(replay_reward) > 0 else 0.,
                       fps=fps)
            log_kvs(prefix='ACER', kvs=kvs)

        if t * n_batches % FLAGS.ACER.save_freq == 0:
            np.save('{}/stage-{}'.format(FLAGS.log_dir, t), saver.state_dict())
            np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())
    np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())
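
# Sketch: runner.compute_qret(policy, data) presumably computes the Retrace targets
# used by ACER. A reference single-environment version in numpy, with the importance
# weights rho truncated at 1 as in the paper (not the repository's code):
import numpy as np


def retrace_targets(rewards, dones, q_taken, values, rho, last_value, gamma):
    # rewards, dones, q_taken, values, rho: [T]; last_value = V(s_T) bootstrap
    targets = np.zeros(len(rewards), dtype=np.float64)
    qret = last_value
    for t in reversed(range(len(rewards))):
        qret = rewards[t] + gamma * qret * (1.0 - dones[t])
        targets[t] = qret
        qret = min(1.0, rho[t]) * (qret - q_taken[t]) + values[t]
    return targets
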
def main():
    FLAGS.set_seed()
    FLAGS.freeze()

    env = create_env(FLAGS.env.id,
                     seed=FLAGS.seed,
                     rescale_action=FLAGS.env.rescale_action)
    dim_state = env.observation_space.shape[0]
    dim_action = env.action_space.shape[0]

    bc_normalizers = Normalizers(dim_action=dim_action, dim_state=dim_state)
    bc_policy = GaussianMLPPolicy(dim_state,
                                  dim_action,
                                  FLAGS.TRPO.policy_hidden_sizes,
                                  output_diff=FLAGS.TRPO.output_diff,
                                  normalizers=bc_normalizers)

    gail_normalizers = Normalizers(dim_action=dim_action, dim_state=dim_state)
    gail_policy = GaussianMLPPolicy(dim_state,
                                    dim_action,
                                    FLAGS.TRPO.policy_hidden_sizes,
                                    output_diff=FLAGS.TRPO.output_diff,
                                    normalizers=gail_normalizers)

    actor = Actor(dim_state, dim_action, FLAGS.SAC.actor_hidden_sizes)
    tf.get_default_session().run(tf.global_variables_initializer())

    loader = nn.ModuleDict({'actor': actor})
    policy_load = f'dataset/mb2/{FLAGS.env.id}/policy.npy'
    loader.load_state_dict(np.load(policy_load, allow_pickle=True)[()])
    logger.warning('Load expert policy from %s' % policy_load)

    bc_policy_load = "benchmarks/mbrl_benchmark/mbrl2_bc_30_1000/mbrl2_bc-Walker2d-v2-100-2020-05-22-16-02-12/final.npy"
    loader = nn.ModuleDict({
        'policy': bc_policy,
        'normalizers': bc_normalizers
    })
    loader.load_state_dict(np.load(bc_policy_load, allow_pickle=True)[()])
    logger.warning('Load bc policy from %s' % bc_policy_load)

    gail_policy_load = "benchmarks/mbrl_benchmark/mbrl2_gail_grad_penalty/mbrl2_gail-Walker2d-v2-100-2020-05-22-12-10-07/final.npy"
    loader = nn.ModuleDict({
        'policy': gail_policy,
        'normalizers': gail_normalizers
    })
    loader.load_state_dict(np.load(gail_policy_load, allow_pickle=True)[()])
    logger.warning('Load gail policy from %s' % gail_policy_load)

    eval_gamma = 0.999
    eval_returns, eval_lengths = evaluate_on_true_env(actor,
                                                      env,
                                                      gamma=eval_gamma)
    logger.warning(
        'Test policy true value = %.4f true length = %d (gamma = %f)',
        np.mean(eval_returns), np.mean(eval_lengths), eval_gamma)

    real_runner = Runner(env,
                         max_steps=env.max_episode_steps,
                         rescale_action=False)
    # virtual env
    env_bc_stochastic = VirtualEnv(bc_policy,
                                   env,
                                   n_envs=1,
                                   stochastic_model=True)
    env_bc_deterministic = VirtualEnv(bc_policy,
                                      env,
                                      n_envs=1,
                                      stochastic_model=False)
    runner_bc_stochastic = VirtualRunner(env_bc_stochastic,
                                         max_steps=env.max_episode_steps,
                                         rescale_action=False)
    runner_bc_deterministic = VirtualRunner(env_bc_deterministic,
                                            max_steps=env.max_episode_steps,
                                            rescale_action=False)

    env_gail_stochastic = VirtualEnv(gail_policy,
                                     env,
                                     n_envs=1,
                                     stochastic_model=True)
    env_gail_deterministic = VirtualEnv(gail_policy,
                                        env,
                                        n_envs=1,
                                        stochastic_model=False)
    runner_gail_stochastic = VirtualRunner(env_gail_stochastic,
                                           max_steps=env.max_episode_steps)
    runner_gail_deterministic = VirtualRunner(env_gail_deterministic,
                                              max_steps=env.max_episode_steps)

    data_actor, ep_infos = real_runner.run(actor,
                                           n_samples=int(2e3),
                                           stochastic=False)
    returns = [info['return'] for info in ep_infos]
    lengths = [info['length'] for info in ep_infos]
    logger.info(
        'Collect %d samples for actor avg return = %.4f avg length = %d',
        len(data_actor), np.mean(returns), np.mean(lengths))

    data_bc_stochastic, ep_infos = runner_bc_stochastic.run(actor,
                                                            n_samples=int(2e3),
                                                            stochastic=False)
    returns = [info['return'] for info in ep_infos]
    lengths = [info['length'] for info in ep_infos]
    logger.info(
        'Collect %d samples for bc stochastic policy avg return = %.4f avg length = %d',
        len(data_bc_stochastic), np.mean(returns), np.mean(lengths))

    reward_ref, _ = env.mb_step(data_bc_stochastic.state,
                                data_bc_stochastic.action,
                                data_bc_stochastic.next_state)
    np.testing.assert_allclose(reward_ref,
                               data_bc_stochastic.reward,
                               rtol=1e-4,
                               atol=1e-4)

    data_bc_deterministic, ep_infos = runner_bc_deterministic.run(
        actor, n_samples=int(2e3), stochastic=False)
    returns = [info['return'] for info in ep_infos]
    lengths = [info['length'] for info in ep_infos]
    logger.info(
        'Collect %d samples for bc deterministic policy avg return = %.4f avg length = %d',
        len(data_bc_deterministic), np.mean(returns), np.mean(lengths))

    reward_ref, _ = env.mb_step(data_bc_deterministic.state,
                                data_bc_deterministic.action,
                                data_bc_deterministic.next_state)
    np.testing.assert_allclose(reward_ref,
                               data_bc_deterministic.reward,
                               rtol=1e-4,
                               atol=1e-4)

    data_gail_stochastic, ep_infos = runner_gail_stochastic.run(
        actor, n_samples=int(2e3), stochastic=False)
    returns = [info['return'] for info in ep_infos]
    lengths = [info['length'] for info in ep_infos]
    logger.info(
        'Collect %d samples for gail stochastic policy avg return = %.4f avg length = %d',
        len(data_gail_stochastic), np.mean(returns), np.mean(lengths))
    data_gail_deterministic, ep_infos = runner_gail_deterministic.run(
        actor, n_samples=int(2e3), stochastic=False)
    returns = [info['return'] for info in ep_infos]
    lengths = [info['length'] for info in ep_infos]
    logger.info(
        'Collect %d samples for gail deterministic policy avg return = %.4f avg length = %d',
        len(data_gail_deterministic), np.mean(returns), np.mean(lengths))

    t_sne = manifold.TSNE(init='pca', random_state=2020)
    data = np.concatenate([
        data.state for data in [
            data_actor, data_bc_stochastic, data_bc_deterministic,
            data_gail_stochastic, data_gail_deterministic
        ]
    ],
                          axis=0)
    step = np.concatenate([
        data.step for data in [
            data_actor, data_bc_stochastic, data_bc_deterministic,
            data_gail_stochastic, data_gail_deterministic
        ]
    ],
                          axis=0)
    loc, scale = bc_normalizers.state.eval('mean std')
    data = (data - loc) / (1e-6 + scale)
    embedding = t_sne.fit_transform(data)

    fig, axarrs = plt.subplots(nrows=1,
                               ncols=5,
                               figsize=[6 * 5, 4],
                               squeeze=False,
                               sharex=True,
                               sharey=True,
                               dpi=300)
    start = 0
    indices = 0
    g2c = {}
    for title in [
            'expert', 'bc_stochastic', 'bc_deterministic', 'gail_stochastic',
            'gail_deterministic'
    ]:
        g2c[title] = axarrs[0][indices].scatter(
            embedding[start:start + 2000, 0],
            embedding[start:start + 2000, 1],
            c=step[start:start + 2000])
        axarrs[0][indices].set_title(title)
        indices += 1
        start += 2000
    plt.colorbar(list(g2c.values())[0], ax=axarrs.flatten())
    plt.tight_layout()
    plt.savefig(f'{FLAGS.log_dir}/visualize.png', bbox_inches='tight')

    data = {
        'expert': data_actor.state,
        'bc_stochastic': data_bc_stochastic.state,
        'bc_deterministic': data_bc_deterministic.state,
        'gail_stochastic': data_gail_stochastic.state,
        'gail_deterministic': data_gail_deterministic.state
    }
    np.savez(f'{FLAGS.log_dir}/data.npz', **data)
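
# Sketch: the arrays saved above can be reloaded later, e.g. to recompute the t-SNE
# embedding without re-collecting rollouts. The keys match the dict passed to
# np.savez; the path is illustrative.
import numpy as np

saved = np.load('data.npz')  # replace with '{FLAGS.log_dir}/data.npz'
for name in ['expert', 'bc_stochastic', 'bc_deterministic',
             'gail_stochastic', 'gail_deterministic']:
    print(name, saved[name].shape)
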
def main():
    FLAGS.set_seed()
    FLAGS.freeze()

    env = create_env(FLAGS.env.id, seed=FLAGS.seed, rescale_action=FLAGS.env.rescale_action)
    dim_state = env.observation_space.shape[0]
    dim_action = env.action_space.shape[0]

    # expert actor
    actor = Actor(dim_state, dim_action, init_std=0.)
    subsampling_rate = env.max_episode_steps // FLAGS.GAIL.trajectory_size
    expert_state, expert_action, expert_next_state, expert_reward = collect_samples_from_true_env(
        env=env, actor=actor, nb_episode=FLAGS.GAIL.traj_limit, subsampling_rate=subsampling_rate)
    logger.info('Collect %d samples avg return = %.4f', len(expert_state), np.mean(expert_reward))
    eval_state, eval_action, eval_next_state, eval_reward = collect_samples_from_true_env(
        env=env, actor=actor, nb_episode=3, seed=FLAGS.seed)
    loc, scale = np.mean(expert_state, axis=0, keepdims=True), np.std(expert_state, axis=0, keepdims=True)
    logger.info('loc = {}\nscale={}'.format(loc, scale))

    normalizers = Normalizers(dim_action=dim_action, dim_state=dim_state)
    policy = GaussianMLPPolicy(dim_state, dim_action, FLAGS.TRPO.policy_hidden_sizes,
                               output_diff=FLAGS.TRPO.output_diff, normalizers=normalizers)
    bc_loss = BehavioralCloningLoss(dim_state, dim_action, policy, lr=float(FLAGS.BC.lr), train_std=FLAGS.BC.train_std)

    tf.get_default_session().run(tf.global_variables_initializer())
    set_random_seed(FLAGS.seed)

    saver = nn.ModuleDict({'policy': policy, 'normalizers': normalizers})
    print(saver)

    # update normalizers
    normalizers.state.update(expert_state)
    normalizers.action.update(expert_action)
    normalizers.diff.update(expert_next_state - expert_state)

    eval_gamma = 0.999
    eval_returns, eval_lengths = evaluate_on_true_env(actor, env, gamma=eval_gamma)
    logger.warning('Test policy true value = %.4f true length = %d (gamma = %f)',
                   np.mean(eval_returns), np.mean(eval_lengths), eval_gamma)

    # virtual env
    env_eval_stochastic = VirtualEnv(policy, env, n_envs=4, stochastic_model=True)
    env_eval_deterministic = VirtualEnv(policy, env, n_envs=4, stochastic_model=False)

    batch_size = FLAGS.BC.batch_size
    true_return = np.mean(eval_returns)
    for t in range(FLAGS.BC.max_iters):
        if t % FLAGS.BC.eval_freq == 0:
            eval_returns_stochastic, eval_lengths_stochastic = evaluate_on_virtual_env(
                actor, env_eval_stochastic, gamma=eval_gamma)
            eval_returns_deterministic, eval_lengths_deterministic = evaluate_on_virtual_env(
                actor, env_eval_deterministic, gamma=eval_gamma)
            log_kvs(prefix='Evaluate', kvs=dict(
                iter=t, stochastic_episode=dict(
                    returns=np.mean(eval_returns_stochastic), lengths=int(np.mean(eval_lengths_stochastic))
                ), episode=dict(
                    returns=np.mean(eval_returns_deterministic), lengths=int(np.mean(eval_lengths_deterministic))
                ),  evaluation_error=dict(
                    stochastic_error=true_return-np.mean(eval_returns_stochastic),
                    stochastic_abs=np.abs(true_return-np.mean(eval_returns_stochastic)),
                    stochastic_rel=np.abs(true_return-np.mean(eval_returns_stochastic))/true_return,
                    deterministic_error=true_return-np.mean(eval_returns_deterministic),
                    deterministic_abs=np.abs(true_return - np.mean(eval_returns_deterministic)),
                    deterministic_rel=np.abs(true_return-np.mean(eval_returns_deterministic))/true_return
                )
            ))

        indices = np.random.randint(low=0, high=len(expert_state), size=batch_size)
        expert_state_ = expert_state[indices]
        expert_action_ = expert_action[indices]
        expert_next_state_ = expert_next_state[indices]
        _, loss, grad_norm = bc_loss.get_loss(expert_state_, expert_action_, expert_next_state_,
                                              fetch='train loss grad_norm')

        if t % 100 == 0:
            train_mse_loss = policy.get_mse_loss(expert_state_, expert_action_, expert_next_state_)
            eval_mse_loss = policy.get_mse_loss(eval_state, eval_action, eval_next_state)
            log_kvs(prefix='BC', kvs=dict(
                iter=t, grad_norm=grad_norm, loss=loss, mse_loss=dict(train=train_mse_loss, eval=eval_mse_loss)
            ))

    np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())

    dict_result = dict()
    for gamma in [0.9, 0.99, 0.999, 1.0]:
        eval_returns, eval_lengths = evaluate_on_virtual_env(actor, env_eval_deterministic, gamma=gamma)
        dict_result[gamma] = [float(np.mean(eval_returns)), eval_returns]
        logger.info('[%s]: %.4f', gamma, np.mean(eval_returns))

    save_path = os.path.join(FLAGS.log_dir, 'evaluate.yml')
    yaml.dump(dict_result, open(save_path, 'w'), default_flow_style=False)
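
# The evaluation calls above (evaluate_on_true_env / evaluate_on_virtual_env with a
# gamma argument) are assumed to compute discounted Monte-Carlo returns. A minimal
# single-environment sketch of such an evaluator is given below; `act_fn`,
# `n_episodes` and the Gym-style `env` interface are illustrative assumptions, not
# the repository's actual signatures.
import numpy as np


def evaluate_discounted(act_fn, env, n_episodes=10, gamma=0.999):
    """Estimate discounted returns and episode lengths by rolling out `act_fn`."""
    returns, lengths = [], []
    for _ in range(n_episodes):
        state = env.reset()
        return_, discount = 0., 1.
        for t in range(env.max_episode_steps):
            action = act_fn(state)
            state, reward, done, _ = env.step(action)
            return_ += discount * reward
            discount *= gamma
            if done:
                break
        returns.append(return_)
        lengths.append(t + 1)
    return np.array(returns), np.array(lengths)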
Example #10
def main():
    FLAGS.set_seed()
    FLAGS.freeze()

    env = create_env(FLAGS.env.id,
                     seed=FLAGS.seed,
                     log_dir=FLAGS.log_dir,
                     absorbing_state=FLAGS.GAIL.learn_absorbing,
                     rescale_action=FLAGS.env.rescale_action)
    env_eval = create_env(FLAGS.env.id,
                          seed=FLAGS.seed + 1000,
                          log_dir=FLAGS.log_dir,
                          absorbing_state=FLAGS.GAIL.learn_absorbing,
                          rescale_action=FLAGS.env.rescale_action)
    dim_state = env.observation_space.shape[0]
    dim_action = env.action_space.shape[0]

    # load expert dataset
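    # Thin each expert trajectory so that roughly FLAGS.GAIL.trajectory_size
    # transitions remain per full-length episode (presumably by keeping every
    # subsampling_rate-th transition).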
    subsampling_rate = env.max_episode_steps // FLAGS.GAIL.trajectory_size
    set_random_seed(2020)
    expert_dataset = load_expert_dataset(FLAGS.GAIL.buf_load)
    expert_reward = expert_dataset.get_average_reward()
    logger.info('Expert Reward %f', expert_reward)
    if FLAGS.GAIL.learn_absorbing:
        expert_dataset.add_absorbing_states(env)
    expert_dataset.subsample_trajectories(FLAGS.GAIL.traj_limit)
    logger.info('Original dataset size {}'.format(len(expert_dataset)))
    expert_dataset.subsample_transitions(subsampling_rate)
    logger.info('Subsampled dataset size {}'.format(len(expert_dataset)))
    logger.info('np random: %d random : %d', np.random.randint(1000),
                random.randint(0, 1000))
    expert_batch = expert_dataset.sample(10)
    expert_state = np.stack([t.obs for t in expert_batch])
    expert_action = np.stack([t.action for t in expert_batch])
    logger.info('Sampled obs: %.4f, acs: %.4f', np.mean(expert_state),
                np.mean(expert_action))
    del expert_batch, expert_state, expert_action
    set_random_seed(FLAGS.seed)

    normalizers = Normalizers(dim_action=dim_action, dim_state=dim_state)
    policy = GaussianMLPPolicy(dim_state,
                               dim_action,
                               FLAGS.TRPO.policy_hidden_sizes,
                               normalizer=normalizers.state)
    vfn = MLPVFunction(dim_state, FLAGS.TRPO.vf_hidden_sizes,
                       normalizers.state)
    algo = TRPO(vfn=vfn,
                policy=policy,
                dim_state=dim_state,
                dim_action=dim_action,
                **FLAGS.TRPO.algo.as_dict())

    if FLAGS.GAIL.reward_type == 'nn':
        expert_batch = expert_dataset.buffer()
        expert_state = np.stack([t.obs for t in expert_batch])
        loc = np.mean(expert_state, axis=0, keepdims=True)
        scale = np.std(expert_state, axis=0, keepdims=True)
        del expert_batch, expert_state
        discriminator = Discriminator(dim_state,
                                      dim_action,
                                      normalizers=normalizers,
                                      subsampling_rate=subsampling_rate,
                                      loc=loc,
                                      scale=scale,
                                      **FLAGS.GAIL.discriminator.as_dict())
    elif FLAGS.GAIL.reward_type in {'simplex', 'l2'}:
        discriminator = LinearReward(
            dim_state, dim_action, simplex=FLAGS.GAIL.reward_type == 'simplex')
    else:
        raise NotImplementedError
    tf.get_default_session().run(tf.global_variables_initializer())

    if FLAGS.GAIL.reward_type != 'nn':
        expert_batch = expert_dataset.buffer()
        expert_state = np.stack([t.obs for t in expert_batch])
        expert_action = np.stack([t.action for t in expert_batch])
        discriminator.build(expert_state, expert_action)
        del expert_batch, expert_state, expert_action

    saver = nn.ModuleDict({
        'policy': policy,
        'vfn': vfn,
        'normalizers': normalizers,
        'discriminator': discriminator
    })
    runner = Runner(env,
                    max_steps=env.max_episode_steps,
                    gamma=FLAGS.TRPO.gamma,
                    lambda_=FLAGS.TRPO.lambda_,
                    add_absorbing_state=FLAGS.GAIL.learn_absorbing)
    print(saver)

    max_ent_coef = FLAGS.TRPO.algo.ent_coef
    eval_gamma = 0.999
    for t in range(0, FLAGS.GAIL.total_timesteps,
                   FLAGS.TRPO.rollout_samples * FLAGS.GAIL.g_iters):
        time_st = time.time()
        if t % FLAGS.GAIL.eval_freq == 0:
            eval_returns, eval_lengths = evaluate(policy, env_eval)
            eval_returns_discount, eval_lengths_discount = evaluate(
                policy, env_eval, gamma=eval_gamma)
            log_kvs(prefix='Evaluate',
                    kvs=dict(iter=t,
                             episode=dict(returns=np.mean(eval_returns),
                                          lengths=int(np.mean(eval_lengths))),
                             discounted_episode=dict(
                                 returns=np.mean(eval_returns_discount),
                                 lengths=int(np.mean(eval_lengths_discount)))))

        # Generator
        generator_dataset = None
        for n_update in range(FLAGS.GAIL.g_iters):
            data, ep_infos = runner.run(policy, FLAGS.TRPO.rollout_samples)
            if FLAGS.TRPO.normalization:
                normalizers.state.update(data.state)
                normalizers.action.update(data.action)
                normalizers.diff.update(data.next_state - data.state)
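            # One-off sanity check (absorbing states disabled): within each env,
            # next_state at step k must equal state at step k + 1, except across
            # episode boundaries, which are masked out via done/timeout.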
            if t == 0 and n_update == 0 and not FLAGS.GAIL.learn_absorbing:
                data_ = data.copy()
                data_ = data_.reshape(
                    [FLAGS.TRPO.rollout_samples // env.n_envs, env.n_envs])
                for e in range(env.n_envs):
                    samples = data_[:, e]
                    masks = 1 - (samples.done | samples.timeout)[...,
                                                                 np.newaxis]
                    masks = masks[:-1]
                    assert np.allclose(samples.state[1:] * masks,
                                       samples.next_state[:-1] * masks)
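            # `t` is advanced by hand so logging and eval counters track the samples
            # actually collected; the outer range() steps by rollout_samples * g_iters,
            # matching the g_iters passes through this inner loop.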
            t += FLAGS.TRPO.rollout_samples
            data.reward = discriminator.get_reward(data.state, data.action)
            advantages, values = runner.compute_advantage(vfn, data)
            train_info = algo.train(max_ent_coef, data, advantages, values)
            fps = int(FLAGS.TRPO.rollout_samples / (time.time() - time_st))
            train_info['reward'] = np.mean(data.reward)
            train_info['fps'] = fps

            expert_batch = expert_dataset.sample(256)
            expert_state = np.stack([t.obs for t in expert_batch])
            expert_action = np.stack([t.action for t in expert_batch])
            train_info['mse_loss'] = policy.get_mse_loss(
                expert_state, expert_action)
            log_kvs(prefix='TRPO', kvs=dict(iter=t, **train_info))

            generator_dataset = data

        # Discriminator
        if FLAGS.GAIL.reward_type in {'nn', 'vb'}:
            for n_update in range(FLAGS.GAIL.d_iters):
                batch_size = FLAGS.GAIL.d_batch_size
                d_train_infos = dict()
                for generator_subset in generator_dataset.iterator(batch_size):
                    expert_batch = expert_dataset.sample(batch_size)
                    expert_state = np.stack([t.obs for t in expert_batch])
                    expert_action = np.stack([t.action for t in expert_batch])
                    expert_mask = np.stack([
                        t.mask for t in expert_batch
                    ]).flatten() if FLAGS.GAIL.learn_absorbing else None
                    train_info = discriminator.train(
                        expert_state,
                        expert_action,
                        generator_subset.state,
                        generator_subset.action,
                        expert_mask,
                    )
                    for k, v in train_info.items():
                        if k not in d_train_infos:
                            d_train_infos[k] = []
                        d_train_infos[k].append(v)
                d_train_infos = {
                    k: np.mean(v)
                    for k, v in d_train_infos.items()
                }
                if n_update == FLAGS.GAIL.d_iters - 1:
                    log_kvs(prefix='Discriminator',
                            kvs=dict(iter=t, **d_train_infos))
        else:
            train_info = discriminator.train(generator_dataset.state,
                                             generator_dataset.action)
            log_kvs(prefix='Discriminator', kvs=dict(iter=t, **train_info))

        if t % FLAGS.TRPO.save_freq == 0:
            np.save('{}/stage-{}'.format(FLAGS.log_dir, t), saver.state_dict())
            np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())
    np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())

    dict_result = dict()
    for gamma in [0.9, 0.99, 0.999, 1.0]:
        eval_returns, eval_lengths = evaluate(policy, env_eval, gamma=gamma)
        dict_result[gamma] = [float(np.mean(eval_returns)), eval_returns]
        logger.info('[%s]: %.4f', gamma, np.mean(eval_returns))

    save_path = os.path.join(FLAGS.log_dir, 'evaluate.yml')
    yaml.dump(dict_result, open(save_path, 'w'), default_flow_style=False)
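
# All of these scripts persist and restore weights through the same pattern:
# np.save(path, module_dict.state_dict()) and load_state_dict(np.load(path,
# allow_pickle=True)[()]). The toy round trip below (dummy keys, /tmp path) only
# illustrates why the trailing [()] is needed: np.save wraps a Python dict in a
# 0-d object array, and [()] unwraps it again.
import numpy as np

state = {'policy/w': np.zeros([4, 2]), 'policy/b': np.zeros(2)}  # stand-in for saver.state_dict()
np.save('/tmp/final', state)                                     # writes /tmp/final.npy
restored = np.load('/tmp/final.npy', allow_pickle=True)[()]
assert set(restored) == set(state)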
Example #11
File: main.py  Project: liziniu/RLX
def main():
    FLAGS.set_seed()
    FLAGS.freeze()

    env = make_env(FLAGS.env.id,
                   FLAGS.env.env_type,
                   num_env=FLAGS.env.num_env,
                   seed=FLAGS.seed,
                   log_dir=FLAGS.log_dir,
                   rescale_action=FLAGS.env.rescale_action)
    env_eval = make_env(FLAGS.env.id,
                        FLAGS.env.env_type,
                        num_env=4,
                        seed=FLAGS.seed + 1000,
                        log_dir=FLAGS.log_dir)
    dim_state = env.observation_space.shape[0]
    dim_action = env.action_space.shape[0]

    actor = Actor(dim_state,
                  dim_action,
                  hidden_sizes=FLAGS.TD3.actor_hidden_sizes)
    critic = Critic(dim_state,
                    dim_action,
                    hidden_sizes=FLAGS.TD3.critic_hidden_sizes)
    td3 = TD3(dim_state,
              dim_action,
              actor=actor,
              critic=critic,
              **FLAGS.TD3.algo.as_dict())

    tf.get_default_session().run(tf.global_variables_initializer())
    td3.update_actor_target(tau=0.0)
    td3.update_critic_target(tau=0.0)

    dtype = gen_dtype(env, 'state action next_state reward done timeout')
    buffer = Dataset(dtype=dtype, max_size=FLAGS.TD3.buffer_size)
    saver = nn.ModuleDict({'actor': actor, 'critic': critic})
    print(saver)

    n_steps = np.zeros(env.n_envs)
    n_returns = np.zeros(env.n_envs)

    train_returns = collections.deque(maxlen=40)
    train_lengths = collections.deque(maxlen=40)
    states = env.reset()
    time_st = time.time()
    for t in range(FLAGS.TD3.total_timesteps):
        if t < FLAGS.TD3.init_random_steps:
            actions = np.array(
                [env.action_space.sample() for _ in range(env.n_envs)])
        else:
            raw_actions = actor.get_actions(states)
            noises = np.random.normal(loc=0.,
                                      scale=FLAGS.TD3.explore_noise,
                                      size=raw_actions.shape)
            actions = np.clip(raw_actions + noises, -1, 1)
        next_states, rewards, dones, infos = env.step(actions)
        n_returns += rewards
        n_steps += 1
        timeouts = n_steps == env.max_episode_steps
        terminals = np.copy(dones)
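        # Episodes cut off by the time limit are not true terminals; clearing the
        # flag lets the critic keep bootstrapping from the target value at s'.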
        for e, info in enumerate(infos):
            if info.get('TimeLimit.truncated', False):
                terminals[e] = False

        transitions = [
            states, actions,
            next_states.copy(), rewards, terminals,
            timeouts.copy()
        ]
        buffer.extend(np.rec.fromarrays(transitions, dtype=dtype))

        indices = np.where(dones | timeouts)[0]
        if len(indices) > 0:
            next_states[indices] = env.partial_reset(indices)

            train_returns.extend(n_returns[indices])
            train_lengths.extend(n_steps[indices])
            n_returns[indices] = 0
            n_steps[indices] = 0
        states = next_states.copy()

        if t == 2000:
            assert env.n_envs == 1
            samples = buffer.sample(size=None, indices=np.arange(2000))
            masks = 1 - (samples.done | samples.timeout)[..., np.newaxis]
            masks = masks[:-1]
            assert np.allclose(samples.state[1:] * masks,
                               samples.next_state[:-1] * masks)

        if t >= FLAGS.TD3.init_random_steps:
            samples = buffer.sample(FLAGS.TD3.batch_size)
            train_info = td3.train(samples)
            if t % FLAGS.TD3.log_freq == 0:
                fps = int(t / (time.time() - time_st))
                train_info['fps'] = fps
                log_kvs(prefix='TD3',
                        kvs=dict(iter=t,
                                 episode=dict(
                                     returns=np.mean(train_returns)
                                     if len(train_returns) > 0 else 0.,
                                     lengths=int(
                                         np.mean(train_lengths)
                                         if len(train_lengths) > 0 else 0)),
                                 **train_info))

        if t % FLAGS.TD3.eval_freq == 0:
            eval_returns, eval_lengths = evaluate(actor,
                                                  env_eval,
                                                  deterministic=False)
            log_kvs(prefix='Evaluate',
                    kvs=dict(iter=t,
                             episode=dict(returns=np.mean(eval_returns),
                                          lengths=int(np.mean(eval_lengths)))))

        if t % FLAGS.TD3.save_freq == 0:
            np.save('{}/stage-{}'.format(FLAGS.log_dir, t), saver.state_dict())
            np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())

    np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())
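
# td3.update_actor_target(tau=0.0) / update_critic_target(tau=0.0) right after
# initialization suggests the convention target <- tau * target + (1 - tau) * online,
# so that tau = 0 performs a hard copy. A NumPy sketch of that soft (Polyak) update
# follows; the convention and parameter lists are assumptions about the TD3/SAC
# internals, not taken from this repository.
import numpy as np


def soft_update(target_params, online_params, tau):
    """Polyak averaging: target <- tau * target + (1 - tau) * online."""
    return [tau * t + (1.0 - tau) * o
            for t, o in zip(target_params, online_params)]


online = [np.ones(3), np.full((2, 2), 2.0)]
target = [np.zeros(3), np.zeros((2, 2))]
target = soft_update(target, online, tau=0.0)    # hard copy, as at initialization
target = soft_update(target, online, tau=0.995)  # slow tracking during training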
Example #12
def main():
    FLAGS.set_seed()
    FLAGS.freeze()

    env = make_env(FLAGS.env.id,
                   FLAGS.env.env_type,
                   num_env=FLAGS.env.num_env,
                   seed=FLAGS.seed,
                   log_dir=FLAGS.log_dir,
                   rescale_action=FLAGS.env.rescale_action)
    env_eval = make_env(FLAGS.env.id,
                        FLAGS.env.env_type,
                        num_env=4,
                        seed=FLAGS.seed + 1000,
                        log_dir=FLAGS.log_dir)
    dim_state = env.observation_space.shape[0]
    dim_action = env.action_space.shape[0]

    normalizers = Normalizers(dim_action=dim_action, dim_state=dim_state)
    policy = GaussianMLPPolicy(dim_state,
                               dim_action,
                               FLAGS.TRPO.policy_hidden_sizes,
                               normalizer=normalizers.state)
    vfn = MLPVFunction(dim_state, FLAGS.TRPO.vf_hidden_sizes,
                       normalizers.state)
    algo = TRPO(vfn=vfn,
                policy=policy,
                dim_state=dim_state,
                dim_action=dim_action,
                **FLAGS.TRPO.algo.as_dict())

    tf.get_default_session().run(tf.global_variables_initializer())

    saver = nn.ModuleDict({
        'policy': policy,
        'vfn': vfn,
        'normalizers': normalizers
    })
    runner = Runner(env,
                    max_steps=env.max_episode_steps,
                    gamma=FLAGS.TRPO.gamma,
                    lambda_=FLAGS.TRPO.lambda_,
                    partial_episode_bootstrapping=FLAGS.TRPO.peb)
    print(saver)

    max_ent_coef = FLAGS.TRPO.algo.ent_coef
    train_returns = collections.deque(maxlen=40)
    train_lengths = collections.deque(maxlen=40)
    for t in range(0, FLAGS.TRPO.total_timesteps, FLAGS.TRPO.rollout_samples):
        time_st = time.time()
        if t % FLAGS.TRPO.eval_freq == 0:
            eval_returns, eval_lengths = evaluate(policy, env_eval)
            log_kvs(prefix='Evaluate',
                    kvs=dict(iter=t,
                             episode=dict(returns=np.mean(eval_returns),
                                          lengths=int(np.mean(eval_lengths)))))

        data, ep_infos = runner.run(policy, FLAGS.TRPO.rollout_samples)
        if t == 0:
            data_ = data.copy()
            data_ = data_.reshape(
                [FLAGS.TRPO.rollout_samples // env.n_envs, env.n_envs])
            for e in range(env.n_envs):
                samples = data_[:, e]
                masks = 1 - (samples.done | samples.timeout)[..., np.newaxis]
                masks = masks[:-1]
                assert np.allclose(samples.state[1:] * masks,
                                   samples.next_state[:-1] * masks)

        if FLAGS.TRPO.normalization:
            normalizers.state.update(data.state)
            normalizers.action.update(data.action)
            normalizers.diff.update(data.next_state - data.state)
        advantages, values = runner.compute_advantage(vfn, data)
        train_info = algo.train(max_ent_coef, data, advantages, values)
        train_returns.extend([info['return'] for info in ep_infos])
        train_lengths.extend([info['length'] for info in ep_infos])
        fps = int(FLAGS.TRPO.rollout_samples / (time.time() - time_st))
        train_info['fps'] = fps
        log_kvs(prefix='TRPO',
                kvs=dict(iter=t,
                         episode=dict(
                             returns=np.mean(train_returns)
                             if len(train_returns) > 0 else 0.,
                             lengths=int(
                                 np.mean(train_lengths
                                         ) if len(train_lengths) > 0 else 0)),
                         **train_info))

        t += FLAGS.TRPO.rollout_samples
        if t % FLAGS.TRPO.save_freq == 0:
            np.save('{}/stage-{}'.format(FLAGS.log_dir, t), saver.state_dict())
            np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())
    np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())
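
# runner.compute_advantage(vfn, data) presumably performs generalized advantage
# estimation with the gamma and lambda_ handed to Runner. A minimal single-trajectory
# sketch (no timeout or absorbing-state handling) is shown below as an assumption
# about that computation, not as the repository's implementation.
import numpy as np


def gae(rewards, values, last_value, dones, gamma=0.99, lam=0.95):
    """Return (advantages, value_targets) for one trajectory.

    values[t] approximates V(s_t); last_value bootstraps the final transition.
    """
    horizon = len(rewards)
    advantages = np.zeros(horizon, dtype=np.float64)
    running = 0.0
    for t in reversed(range(horizon)):
        next_value = last_value if t == horizon - 1 else values[t + 1]
        non_terminal = 1.0 - float(dones[t])
        delta = rewards[t] + gamma * next_value * non_terminal - values[t]
        running = delta + gamma * lam * non_terminal * running
        advantages[t] = running
    return advantages, advantages + np.asarray(values, dtype=np.float64)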
Example #13
def main():
    FLAGS.set_seed()
    FLAGS.freeze()

    env = make_env(FLAGS.env.id,
                   FLAGS.env.env_type,
                   num_env=FLAGS.env.num_env,
                   seed=FLAGS.seed,
                   log_dir=FLAGS.log_dir,
                   rescale_action=FLAGS.env.rescale_action)
    env_eval = make_env(FLAGS.env.id,
                        FLAGS.env.env_type,
                        num_env=4,
                        seed=FLAGS.seed + 1000,
                        log_dir=FLAGS.log_dir)
    dim_state = env.observation_space.shape[0]
    dim_action = env.action_space.shape[0]

    actor = Actor(dim_state,
                  dim_action,
                  hidden_sizes=FLAGS.SAC.actor_hidden_sizes)
    critic = Critic(dim_state,
                    dim_action,
                    hidden_sizes=FLAGS.SAC.critic_hidden_sizes)
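    # Default to the common SAC heuristic of -|A| (negative action dimensionality)
    # when no target entropy is configured.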
    target_entropy = FLAGS.SAC.target_entropy
    if target_entropy is None:
        target_entropy = -dim_action
    sac = SAC(dim_state,
              dim_action,
              actor=actor,
              critic=critic,
              target_entropy=target_entropy,
              **FLAGS.SAC.algo.as_dict())

    tf.get_default_session().run(tf.global_variables_initializer())
    sac.update_critic_target(tau=0.0)

    dtype = gen_dtype(env, 'state action next_state reward done')
    buffer = Dataset(dtype=dtype, max_size=FLAGS.SAC.buffer_size)
    saver = nn.ModuleDict({'actor': actor, 'critic': critic})
    print(saver)

    n_steps = np.zeros(env.n_envs)
    n_returns = np.zeros(env.n_envs)

    train_returns = collections.deque(maxlen=40)
    train_lengths = collections.deque(maxlen=40)
    states = env.reset()
    time_st = time.time()
    for t in range(FLAGS.SAC.total_timesteps):
        if t < FLAGS.SAC.init_random_steps:
            actions = np.array(
                [env.action_space.sample() for _ in range(env.n_envs)])
        else:
            actions = actor.get_actions(states)
        next_states, rewards, dones, infos = env.step(actions)
        n_returns += rewards
        n_steps += 1
        timeouts = n_steps == env.max_episode_steps
        terminals = np.copy(dones)
        for e, info in enumerate(infos):
            if FLAGS.SAC.peb and info.get('TimeLimit.truncated', False):
                terminals[e] = False

        transitions = [states, actions, next_states.copy(), rewards, terminals]
        buffer.extend(np.rec.fromarrays(transitions, dtype=dtype))

        indices = np.where(dones | timeouts)[0]
        if len(indices) > 0:
            next_states[indices] = env.partial_reset(indices)

            train_returns.extend(n_returns[indices])
            train_lengths.extend(n_steps[indices])
            n_returns[indices] = 0
            n_steps[indices] = 0
        states = next_states.copy()

        if t >= FLAGS.SAC.init_random_steps:
            samples = buffer.sample(FLAGS.SAC.batch_size)
            train_info = sac.train(samples)
            if t % FLAGS.SAC.log_freq == 0:
                fps = int(t / (time.time() - time_st))
                train_info['fps'] = fps
                log_kvs(prefix='SAC',
                        kvs=dict(iter=t,
                                 episode=dict(
                                     returns=np.mean(train_returns)
                                     if len(train_returns) > 0 else 0.,
                                     lengths=int(
                                         np.mean(train_lengths)
                                         if len(train_lengths) > 0 else 0)),
                                 **train_info))

        if t % FLAGS.SAC.eval_freq == 0:
            eval_returns, eval_lengths = evaluate(actor, env_eval)
            log_kvs(prefix='Evaluate',
                    kvs=dict(iter=t,
                             episode=dict(returns=np.mean(eval_returns),
                                          lengths=int(np.mean(eval_lengths)))))

        if t % FLAGS.SAC.save_freq == 0:
            np.save('{}/stage-{}'.format(FLAGS.log_dir, t), saver.state_dict())
            np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())

    np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())
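
# The -dim_action default above is the usual SAC target-entropy heuristic, and the
# SAC class is presumed to tune its temperature alpha toward that target. The sketch
# below takes one gradient step on the detached temperature loss
# J(log_alpha) = -log_alpha * E[log_pi + target_entropy]; the formulation and the
# learning rate are assumptions about a standard SAC implementation.
import numpy as np


def update_temperature(log_alpha, log_pi_batch, target_entropy, lr=3e-4):
    """If policy entropy falls below the target, alpha grows (stronger entropy
    bonus); if it exceeds the target, alpha shrinks."""
    grad = -np.mean(log_pi_batch + target_entropy)  # dJ / d log_alpha
    log_alpha = log_alpha - lr * grad
    return log_alpha, np.exp(log_alpha)


target_entropy = -6                       # e.g. a 6-dimensional action space
log_pi_batch = np.random.randn(256) - 3.  # stand-in for log pi(a|s) over a batch
log_alpha, alpha = update_temperature(0.0, log_pi_batch, target_entropy)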
def main():
    FLAGS.set_seed()
    FLAGS.freeze()

    env = create_env(FLAGS.env.id,
                     seed=FLAGS.seed,
                     log_dir=FLAGS.log_dir,
                     absorbing_state=FLAGS.GAIL.learn_absorbing,
                     rescale_action=FLAGS.env.rescale_action)
    env_eval = create_env(FLAGS.env.id,
                          seed=FLAGS.seed + 1000,
                          log_dir=FLAGS.log_dir,
                          absorbing_state=FLAGS.GAIL.learn_absorbing,
                          rescale_action=FLAGS.env.rescale_action)
    dim_state = env.observation_space.shape[0]
    dim_action = env.action_space.shape[0]

    normalizers = Normalizers(dim_action=dim_action, dim_state=dim_state)
    policy = GaussianMLPPolicy(dim_state,
                               dim_action,
                               FLAGS.TRPO.policy_hidden_sizes,
                               normalizer=normalizers.state)

    tf.get_default_session().run(tf.global_variables_initializer())

    expert_result_path = os.path.join('logs', 'expert-%s.yml' % FLAGS.env.id)
    if not os.path.exists(expert_result_path):
        expert_dataset = load_expert_dataset(FLAGS.GAIL.buf_load)
        expert_reward = expert_dataset.get_average_reward()
        logger.info('Expert Reward %f', expert_reward)
        if FLAGS.GAIL.learn_absorbing:
            expert_dataset.add_absorbing_states(env)

        expert_result = dict()
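        # Split the flat expert buffer back into episodes at `done` flags and
        # accumulate each episode's discounted return, for every gamma of interest.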
        for gamma in [0.9, 0.99, 0.999, 1.0]:
            expert_returns = []
            discount = 1.
            expert_return = 0.
            for timestep in expert_dataset.buffer():
                expert_return += discount * timestep.reward[0]
                discount *= gamma
                if timestep.done:
                    expert_returns.append(float(expert_return))
                    discount = 1.
                    expert_return = 0.
            expert_result[gamma] = [
                float(np.mean(expert_returns)), expert_returns
            ]
            logger.info('Expert gamma = %f %.4f (n_episode = %d)', gamma,
                        np.mean(expert_returns), len(expert_returns))
        yaml.dump(expert_result,
                  open(expert_result_path, 'w'),
                  default_flow_style=False)

    # loader policy
    loader = nn.ModuleDict({'policy': policy})
    root_dir = 'logs/gail_l2'
    for save_dir in sorted(os.listdir(root_dir)):
        if FLAGS.env.id not in save_dir:
            continue
        policy_load = os.path.join(root_dir, save_dir, 'stage-3000000.npy')
        loader.load_state_dict(np.load(policy_load, allow_pickle=True)[()])
        logger.warning('Load {} from {}'.format(loader.keys(), policy_load))

        dict_result = dict()
        for gamma in [0.9, 0.99, 0.999, 1.0]:
            eval_returns, eval_lengths = evaluate(policy,
                                                  env_eval,
                                                  gamma=gamma)
            dict_result[gamma] = [float(np.mean(eval_returns)), eval_returns]
            logger.info('[%s]: %.4f', gamma, np.mean(eval_returns))

        save_path = os.path.join(root_dir, save_dir, 'evaluate.yml')
        yaml.dump(dict_result, open(save_path, 'w'), default_flow_style=False)
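
# Each run above ends by dumping an evaluate.yml mapping gamma to
# [mean_return, per_episode_returns]. The helper below gathers those files back up
# for comparison; the 'logs/gail_l2' layout comes from the loop above, everything
# else (and the assumption that the dumped returns are plain Python lists) is
# illustrative.
import os

import numpy as np
import yaml


def collect_results(root_dir='logs/gail_l2', gamma=0.999):
    results = {}
    for save_dir in sorted(os.listdir(root_dir)):
        path = os.path.join(root_dir, save_dir, 'evaluate.yml')
        if not os.path.exists(path):
            continue
        with open(path) as f:
            evaluation = yaml.safe_load(f)
        mean_return, _ = evaluation[gamma]
        results[save_dir] = mean_return
    if results:
        print('%d runs, mean return at gamma=%s: %.4f'
              % (len(results), gamma, np.mean(list(results.values()))))
    return results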