def _render_episodes(agent, env, n_episodes=10):
    """Roll out ``agent`` with mean actions on ``env``, rendering every step.

    Each episode stops when ``env.step`` reports ``done`` or after
    ``env.max_episode_steps`` steps; the undiscounted return of every episode
    is printed.

    Args:
        agent: object with ``get_actions(states, fetch='actions_mean')`` that
            takes a batch of states and returns a batch of actions.
        env: environment exposing ``reset``, ``step``, ``render`` and
            ``max_episode_steps``.
        n_episodes: number of episodes to roll out.
    """
    for _ in range(n_episodes):
        state = env.reset()
        return_ = 0.
        for _ in range(env.max_episode_steps):
            env.render()
            # Query with a batch of one state, then drop the batch axis.
            action = agent.get_actions(state[None], fetch='actions_mean')[0]

            next_state, reward, done, _ = env.step(action)
            return_ += reward
            if done:
                break
            state = next_state
        print(return_)


def main():
    """Visually compare the SAC expert with a saved imitation policy.

    Restores the expert actor from ``dataset/sac/{env_id}/policy.npy`` and a
    ``GaussianMLPPolicy`` from a fixed checkpoint. It then renders 10 episodes
    of the expert, pauses 2 seconds, and renders 10 episodes of the policy,
    printing each episode's return.
    """
    FLAGS.set_seed()
    FLAGS.freeze()

    env = create_env(FLAGS.env.id,
                     seed=FLAGS.seed,
                     log_dir=FLAGS.log_dir,
                     absorbing_state=FLAGS.GAIL.learn_absorbing,
                     rescale_action=FLAGS.env.rescale_action)
    dim_state = env.observation_space.shape[0]
    dim_action = env.action_space.shape[0]

    normalizers = Normalizers(dim_action=dim_action, dim_state=dim_state)
    policy = GaussianMLPPolicy(dim_state,
                               dim_action,
                               FLAGS.TRPO.policy_hidden_sizes,
                               normalizer=normalizers.state)
    expert_actor = Actor(dim_state, dim_action, FLAGS.SAC.actor_hidden_sizes)
    tf.get_default_session().run(tf.global_variables_initializer())

    # Checkpoints are dicts stored with np.save; ``[()]`` unwraps the 0-d
    # object array. allow_pickle=True: only load trusted checkpoints.
    loader = nn.ModuleDict({'actor': expert_actor})
    actor_load = f'dataset/sac/{FLAGS.env.id}/policy.npy'
    loader.load_state_dict(np.load(actor_load, allow_pickle=True)[()])
    logger.warning('Load expert policy from %s', actor_load)

    loader = nn.ModuleDict({'policy': policy})
    # NOTE(review): this path is hard-coded to a Hopper-v2 GAIL run and does
    # not follow FLAGS.env.id; it only matches when FLAGS.env.id is Hopper-v2.
    # Alternative BC checkpoint:
    # policy_load = 'benchmarks/discounted-policies/bc/bc-Hopper-v2-100-2020-05-16-18-39-51/final.npy'
    policy_load = 'benchmarks/discounted-policies/gail_nn/gail-Hopper-v2-100-2020-05-17-00-50-42/final.npy'
    loader.load_state_dict(np.load(policy_load, allow_pickle=True)[()])
    logger.warning('Load policy from %s', policy_load)

    _render_episodes(expert_actor, env)
    time.sleep(2)
    _render_episodes(policy, env)
# Example 2
# 0
# File: flags.py  Project: liziniu/RLX
    def set_seed(cls):
        """Seed numpy, TensorFlow and Python's ``random`` from ``cls.seed``.

        A ``cls.seed`` of 0 means "pick one automatically". In that case a
        3-byte value from ``os.urandom`` plus one is stored back into
        ``cls.seed``, so the stored seed is never 0.
        """
        if cls.seed == 0:
            entropy = os.urandom(3)
            # +1 keeps the seed non-zero; 0 is reserved for "use urandom".
            cls.seed = 1 + int.from_bytes(entropy, 'little')
        logger.warning("Setting random seed to %s", cls.seed)

        import numpy as np
        import tensorflow as tf
        import random

        # Each library receives the base seed shifted by a distinct offset.
        base_seed = cls.seed
        np.random.seed(base_seed)
        tf.set_random_seed(base_seed + 1000)
        random.seed(base_seed + 2000)
# Example 3
# 0
    def set_seed(cls):
        """Seed numpy, TensorFlow, PyTorch (CPU and all CUDA devices) and ``random``.

        A ``cls.seed`` of 0 is replaced by a fresh non-zero value derived from
        ``os.urandom``. numpy is seeded with ``cls.seed``. Every other generator
        gets its own seed, drawn from the freshly seeded numpy RNG in a fixed
        order, so the whole set is determined by ``cls.seed``. Also switches
        cuDNN to deterministic mode.
        """
        if cls.seed == 0:
            # +1 keeps the seed non-zero; 0 is reserved for "use urandom".
            cls.seed = 1 + int.from_bytes(os.urandom(3), 'little')
        logger.warning("Setting random seed to %s", cls.seed)

        import numpy as np
        import tensorflow as tf
        import torch
        import random

        np.random.seed(cls.seed)
        # One randint call per generator, in the order they are seeded below
        # (tf, torch, random, torch.cuda).
        tf_seed, torch_seed, py_seed, cuda_seed = [
            np.random.randint(2**30) for _ in range(4)
        ]
        tf.set_random_seed(tf_seed)
        torch.manual_seed(torch_seed)
        random.seed(py_seed)
        torch.cuda.manual_seed_all(cuda_seed)
        torch.backends.cudnn.deterministic = True
    def verify(self, n=10000, eps=1e-4):
        """Check that the batched ``mb_step`` agrees with ``step``.

        Collects ``n`` transitions using ``action_space.sample()`` actions,
        resetting whenever an episode ends. It replays them through
        ``self.mb_step(states, actions, next_states)`` and asserts that:

        * the ``done`` flags returned by ``mb_step`` match the recorded ones,
        * no state dimension has (near-)zero standard deviation, and
        * the maximum absolute reward difference is below ``eps``.

        Args:
            n: number of transitions to collect.
            eps: tolerance on the maximum absolute reward difference.
        """
        dataset = Dataset(gen_dtype(self, 'state action next_state reward done'), n)
        state = self.reset()
        for _ in range(n):
            action = self.action_space.sample()
            next_state, reward, done, _ = self.step(action)
            dataset.append((state, action, next_state, reward, done))
            state = self.reset() if done else next_state

        model_rewards, model_dones = self.mb_step(dataset.state, dataset.action,
                                                  dataset.next_state)
        max_abs_error = np.abs(dataset.reward - model_rewards).max()
        logger.warning('rewarder difference: %.6f', max_abs_error)

        np.testing.assert_allclose(model_dones, dataset.done)
        state_std = np.std(dataset.state, axis=0)
        assert not np.isclose(state_std, 0.).any(), 'state.std:{}'.format(state_std)
        assert max_abs_error < eps
# Example 5
# 0
def find_monkey_patch_keys(avoid_set=None):
    """Heuristically list ``tf`` callables whose first argument is a Tensor.

    Every callable in ``tf.__dict__`` whose docstring contains an ``Args:``
    section is inspected. It is selected when the text after the first
    ``': '`` on the line following ``Args:`` starts with "A `Tensor`" or
    "`Tensor`", or when its name starts with ``reduce_``. Docstrings whose
    first argument line has no ``': '`` are skipped instead of aborting the
    scan with an IndexError, unless the name starts with ``reduce_``.

    Args:
        avoid_set: names to skip. Defaults to ``{"shape"}`` because
            ``tf.shape`` conflicts with the ``Tensor.shape`` attribute.

    Returns:
        List of selected attribute names, in ``tf.__dict__`` iteration order.
    """
    if avoid_set is None:
        avoid_set = {"shape"}  # tf.shape conflicts with Tensor.shape
    args_header = 'Args:\n'
    tensor_prefixes = ('A `Tensor`', '`Tensor`')
    patched = []
    for key, value in tf.__dict__.items():
        if not callable(value) or key in avoid_set:
            continue
        doc = value.__doc__
        if not isinstance(doc, str):
            continue
        loc = doc.find(args_header)
        if loc == -1:
            continue

        # Parsing Google-style docs: the first line after "Args:" should look
        # like "    x: A `Tensor` ...". PyTorch's `x (Tensor): ...` convention
        # would be easier to parse.
        first_arg_line = doc[loc + len(args_header):].split('\n')[0]
        _, sep, first_arg_doc = first_arg_line.partition(': ')
        if not sep:
            first_arg_doc = ''
        if first_arg_doc.startswith(tensor_prefixes) or key.startswith('reduce_'):
            patched.append(key)
    logger.warning('Monkey patched TensorFlow: %s', patched)
    return patched
# Example 6
# 0
def monkey_patch(avoid_set=None):
    """Attach common ``tf`` ops as methods of ``tf.Tensor`` and ``tf.Variable``.

    Each listed ``tf`` function is set as a class attribute, so for example
    ``x.reduce_sum(axis=1)`` calls ``tf.reduce_sum(x, axis=1)``. ``mul`` and
    ``sub`` are added as aliases of ``multiply`` and ``subtract``.

    Args:
        avoid_set: not used by the current body. It is only referenced by the
            commented-out ``find_monkey_patch_keys`` call below.
    """
    logger.warning('Monkey patching TensorFlow...')

    op_names = [
        'abs', 'acos', 'acosh', 'add', 'angle', 'argmax', 'argmin', 'asin',
        'asinh', 'atan', 'atan2', 'atanh', 'betainc', 'cast', 'ceil',
        'check_numerics', 'clip_by_average_norm', 'clip_by_norm',
        'clip_by_value', 'complex', 'conj', 'cos', 'cosh', 'cross', 'cumprod',
        'cumsum', 'dequantize', 'diag', 'digamma', 'div', 'equal', 'erf',
        'erfc', 'exp', 'expand_dims', 'expm1', 'fill', 'floor', 'floor_div',
        'floordiv', 'floormod', 'gather', 'gather_nd', 'greater',
        'greater_equal', 'hessians', 'identity', 'igamma', 'igammac', 'imag',
        'is_finite', 'is_inf', 'is_nan', 'less', 'less_equal', 'lgamma', 'log',
        'log1p', 'logical_and', 'logical_not', 'logical_or', 'matmul',
        'maximum', 'meshgrid', 'minimum', 'mod', 'multiply', 'negative',
        'norm', 'not_equal', 'one_hot', 'ones_like', 'pad', 'polygamma', 'pow',
        'quantize', 'real', 'realdiv', 'reciprocal', 'reduce_all',
        'reduce_any', 'reduce_logsumexp', 'reduce_max', 'reduce_mean',
        'reduce_min', 'reduce_prod', 'reduce_sum', 'reshape', 'reverse',
        'rint', 'round', 'rsqrt', 'scatter_nd', 'sign', 'sin', 'sinh', 'size',
        'slice', 'sqrt', 'square', 'squeeze', 'stop_gradient', 'subtract',
        'tan', 'tensordot', 'tile', 'to_bfloat16', 'to_complex128',
        'to_complex64', 'to_double', 'to_float', 'to_int32', 'to_int64',
        'transpose', 'truediv', 'truncatediv', 'truncatemod', 'unique',
        'where', 'zeros_like', 'zeta'
    ]
    aliases = {
        'mul': 'multiply',
        'sub': 'subtract',
    }

    # use the code below for more ops
    # op_names = find_monkey_patch_keys(avoid_set)

    # (attribute name on the class, name of the tf function it points to)
    bindings = [(name, name) for name in op_names] + list(aliases.items())
    for attr_name, tf_name in bindings:
        op = tf.__dict__[tf_name]
        for target in (tf.Tensor, tf.Variable):
            setattr(target, attr_name, op)
def main():
    """Roll out a pretrained SAC actor and save expert trajectories to HDF5.

    When ``FLAGS.env.env_type == 'mb'``, the ``'MB'``-prefixed environment is
    used and next states are stored as well (dataset for imitating
    environments). Otherwise the plain environment is used (dataset for
    imitating policies).

    Mean actions are rolled out until 50 trajectories have been kept. An
    episode is kept only if it lasts at least 700 steps and its return is
    non-negative. Per-step arrays are zero-padded to ``env.max_episode_steps``.

    Outputs, with ``root_dir`` = ``dataset/mb2`` (mb) or ``dataset/sac``:
      * ``{root_dir}/{FLAGS.env.id}.h5`` with keys ``a_B_T_Da``, ``len_B``,
        ``obs_B_T_Do``, ``r_B_T`` (plus ``next_obs_B_T_Do`` in mb mode);
      * a copy of the actor checkpoint at
        ``{root_dir}/{FLAGS.env.id}/policy.npy``.
    """
    FLAGS.set_seed()
    FLAGS.freeze()

    n_trajectories = 50  # number of accepted trajectories to collect
    min_length = 700  # episodes shorter than this are discarded

    collect_mb = FLAGS.env.env_type == 'mb'
    if collect_mb:
        env_id = 'MB' + FLAGS.env.id
        logger.warning('Collect dataset for imitating environments')
    else:
        env_id = FLAGS.env.id
        logger.warning('Collect dataset for imitating policies')
    env = create_env(env_id,
                     FLAGS.seed,
                     FLAGS.log_dir,
                     rescale_action=FLAGS.env.rescale_action)
    dim_state = env.observation_space.shape[0]
    dim_action = env.action_space.shape[0]

    actor = Actor(dim_state,
                  dim_action,
                  hidden_sizes=FLAGS.SAC.actor_hidden_sizes)

    tf.get_default_session().run(tf.global_variables_initializer())

    # The checkpoint is a dict stored with np.save; ``[()]`` unwraps the 0-d
    # object array. allow_pickle=True: only load trusted checkpoints.
    loader = nn.ModuleDict({'actor': actor})
    loader.load_state_dict(
        np.load(FLAGS.ckpt.policy_load, allow_pickle=True)[()])
    logger.info('Load policy from %s', FLAGS.ckpt.policy_load)

    state_traj, action_traj, next_state_traj, reward_traj, len_traj = [], [], [], [], []
    returns = []
    max_steps = env.max_episode_steps
    while len(state_traj) < n_trajectories:
        # Per-episode buffers, zero-padded up to the episode limit.
        states = np.zeros([max_steps, dim_state], dtype=np.float32)
        actions = np.zeros([max_steps, dim_action], dtype=np.float32)
        next_states = np.zeros([max_steps, dim_state], dtype=np.float32)
        rewards = np.zeros([max_steps], dtype=np.float32)
        state = env.reset()
        done = False
        t = 0
        while not done:
            # get_actions takes and returns a batch; strip the batch axis so a
            # single (dim_action,) action is stepped and stored.
            action = actor.get_actions(state[None], fetch='actions_mean')[0]
            next_state, reward, done, _ = env.step(action)

            states[t] = state
            actions[t] = action
            rewards[t] = reward
            next_states[t] = next_state
            t += 1
            state = next_state

        episode_return = np.sum(rewards)
        if t < min_length or episode_return < 0:
            continue
        state_traj.append(states)
        action_traj.append(actions)
        next_state_traj.append(next_states)
        reward_traj.append(rewards)
        len_traj.append(t)

        returns.append(episode_return)
        logger.info('# %d: collect a trajectory return = %.4f length = %d',
                    len(state_traj), episode_return, t)

    state_traj = np.array(state_traj)
    action_traj = np.array(action_traj)
    next_state_traj = np.array(next_state_traj)
    reward_traj = np.array(reward_traj)
    len_traj = np.array(len_traj)
    assert len(state_traj.shape) == len(action_traj.shape) == 3
    assert len(reward_traj.shape) == 2 and len(len_traj.shape) == 1

    dataset = {
        'a_B_T_Da': action_traj,
        'len_B': len_traj,
        'obs_B_T_Do': state_traj,
        'r_B_T': reward_traj
    }
    if collect_mb:
        dataset['next_obs_B_T_Do'] = next_state_traj
    logger.info('Expert avg return = %.4f avg length = %d', np.mean(returns),
                np.mean(len_traj))

    root_dir = 'dataset/mb2' if collect_mb else 'dataset/sac'

    save_dir = f'{root_dir}/{FLAGS.env.id}'
    os.makedirs(save_dir, exist_ok=True)
    shutil.copy(FLAGS.ckpt.policy_load, os.path.join(save_dir, 'policy.npy'))

    save_path = f'{root_dir}/{FLAGS.env.id}.h5'
    # Context manager so the HDF5 file is closed even if writing fails.
    with h5py.File(save_path, 'w') as f:
        f.update(dataset)
    logger.info('save dataset into %s', save_path)
def main():
    """Train a Gaussian MLP policy by behavioral cloning, optionally with DAgger.

    The expert dataset is loaded from ``FLAGS.GAIL.buf_load``. Absorbing states
    are added if ``FLAGS.GAIL.learn_absorbing`` is set, and the dataset is then
    subsampled by trajectories (``FLAGS.GAIL.traj_limit``) and transitions.
    ``FLAGS.BC.max_iters`` BC updates follow.

    With ``FLAGS.BC.dagger``, new data is collected every
    ``FLAGS.BC.collect_freq`` iterations. The first collection uses the expert
    actor (mean actions) and later ones use the learned policy (stochastic).
    The collected states are relabelled with the expert's mean actions and
    appended to the dataset.

    Writes ``{log_dir}/final.npy`` (policy + normalizers) and
    ``{log_dir}/evaluate.yml``. The latter maps each discount factor to
    ``[mean return, per-episode returns]`` on the evaluation env.
    """
    FLAGS.set_seed()
    FLAGS.freeze()

    env = create_env(FLAGS.env.id,
                     seed=FLAGS.seed,
                     log_dir=FLAGS.log_dir,
                     absorbing_state=FLAGS.GAIL.learn_absorbing,
                     rescale_action=FLAGS.env.rescale_action)
    # Separate evaluation env with an offset seed.
    env_eval = create_env(FLAGS.env.id,
                          seed=FLAGS.seed + 1000,
                          log_dir=FLAGS.log_dir,
                          absorbing_state=FLAGS.GAIL.learn_absorbing,
                          rescale_action=FLAGS.env.rescale_action)
    dim_state = env.observation_space.shape[0]
    dim_action = env.action_space.shape[0]

    normalizers = Normalizers(dim_action=dim_action, dim_state=dim_state)
    policy = GaussianMLPPolicy(dim_state,
                               dim_action,
                               FLAGS.TRPO.policy_hidden_sizes,
                               normalizer=normalizers.state)
    bc_loss = BehavioralCloningLoss(dim_state,
                                    dim_action,
                                    policy,
                                    lr=float(FLAGS.BC.lr),
                                    train_std=FLAGS.BC.train_std)

    expert_actor = Actor(dim_state, dim_action, FLAGS.SAC.actor_hidden_sizes)
    tf.get_default_session().run(tf.global_variables_initializer())

    loader = nn.ModuleDict({'actor': expert_actor})
    if FLAGS.BC.dagger:
        # The expert is only needed to collect and relabel DAgger samples.
        # allow_pickle=True: only load trusted checkpoints.
        loader.load_state_dict(
            np.load(FLAGS.ckpt.policy_load, allow_pickle=True)[()])
        logger.warning('Load expert policy from %s', FLAGS.ckpt.policy_load)
    runner = Runner(env, max_steps=env.max_episode_steps, rescale_action=False)

    subsampling_rate = env.max_episode_steps // FLAGS.GAIL.trajectory_size
    # load expert dataset
    # The data-loading/subsampling phase runs under a fixed seed (2020);
    # the run seed is restored with set_random_seed(FLAGS.seed) below.
    set_random_seed(2020)
    expert_dataset = load_expert_dataset(FLAGS.GAIL.buf_load)
    expert_reward = expert_dataset.get_average_reward()
    logger.info('Expert Reward %f', expert_reward)
    if FLAGS.GAIL.learn_absorbing:
        expert_dataset.add_absorbing_states(env)
    expert_dataset.subsample_trajectories(FLAGS.GAIL.traj_limit)
    logger.info('Original dataset size {}'.format(len(expert_dataset)))
    expert_dataset.subsample_transitions(subsampling_rate)
    logger.info('Subsampled dataset size {}'.format(len(expert_dataset)))
    logger.info('np random: %d random : %d', np.random.randint(1000),
                random.randint(0, 1000))
    # Log statistics of a small sample so runs can be compared.
    expert_batch = expert_dataset.sample(10)
    expert_state = np.stack([tr.obs for tr in expert_batch])
    expert_action = np.stack([tr.action for tr in expert_batch])
    logger.info('Sampled obs: %.4f, acs: %.4f', np.mean(expert_state),
                np.mean(expert_action))
    del expert_batch, expert_state, expert_action
    set_random_seed(FLAGS.seed)

    saver = nn.ModuleDict({'policy': policy, 'normalizers': normalizers})
    print(saver)

    batch_size = FLAGS.BC.batch_size
    eval_gamma = 0.999
    for t in range(FLAGS.BC.max_iters):
        if t % FLAGS.BC.eval_freq == 0:
            eval_returns, eval_lengths = evaluate(policy, env_eval)
            eval_returns_discount, eval_lengths_discount = evaluate(
                policy, env_eval, gamma=eval_gamma)
            log_kvs(prefix='Evaluate',
                    kvs=dict(iter=t,
                             episode=dict(returns=np.mean(eval_returns),
                                          lengths=int(np.mean(eval_lengths))),
                             discounted_episode=dict(
                                 returns=np.mean(eval_returns_discount),
                                 lengths=int(np.mean(eval_lengths_discount)))))

        expert_batch = expert_dataset.sample(batch_size)
        expert_state = np.stack([tr.obs for tr in expert_batch])
        expert_action = np.stack([tr.action for tr in expert_batch])
        _, loss, grad_norm = bc_loss.get_loss(expert_state,
                                              expert_action,
                                              fetch='train loss grad_norm')

        if FLAGS.BC.dagger and t % FLAGS.BC.collect_freq == 0 and t > 0:
            if t // FLAGS.BC.collect_freq == 1:
                collect_policy = expert_actor
                stochastic = False
                logger.info('Collect samples with expert actor...')
            else:
                collect_policy = policy
                stochastic = True
                logger.info('Collect samples with learned policy...')
            runner.reset()
            data, ep_infos = runner.run(collect_policy,
                                        FLAGS.BC.n_collect_samples, stochastic)
            # DAgger relabelling: keep the visited states, use expert actions.
            data.action = expert_actor.get_actions(data.state,
                                                   fetch='actions_mean')
            returns = [info['return'] for info in ep_infos]
            lengths = [info['length'] for info in ep_infos]
            for i in range(len(data)):
                sample = data[i]
                expert_dataset.push_back(sample.state, sample.action,
                                         sample.next_state, sample.reward,
                                         sample.mask, sample.timeout)
            logger.info('Collect %d samples avg return = %.4f avg length = %d',
                        len(data), np.mean(returns), np.mean(lengths))
        if t % 100 == 0:
            mse_loss = policy.get_mse_loss(expert_state, expert_action)
            log_kvs(prefix='BC',
                    kvs=dict(iter=t,
                             loss=loss,
                             grad_norm=grad_norm,
                             mse_loss=mse_loss))

    np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())

    dict_result = dict()
    for gamma in [0.9, 0.99, 0.999, 1.0]:
        eval_returns, eval_lengths = evaluate(policy, env_eval, gamma=gamma)
        dict_result[gamma] = [float(np.mean(eval_returns)), eval_returns]
        logger.info('[%s]: %.4f', gamma, np.mean(eval_returns))

    save_path = os.path.join(FLAGS.log_dir, 'evaluate.yml')
    # Close the file deterministically (the original leaked the handle).
    with open(save_path, 'w') as f:
        yaml.dump(dict_result, f, default_flow_style=False)
# Example 9
# 0
def main():
    """Learn an environment model with GAIL, driven by a fixed expert actor.

    ``actor`` is a SAC actor restored from ``FLAGS.ckpt.policy_load``. It is
    the behaviour policy and is not trained here. ``policy`` is passed to
    ``VirtualEnv`` as its model and is trained on expert
    (state, action, next_state) triples in two stages:
      1. Behavioral-cloning pretraining (``FLAGS.GAIL.pretrain_iters`` updates).
      2. TRPO on rewards from ``discriminator``, which is itself trained on
         expert vs. generated (state, action, next_state) batches.

    Evaluation compares the actor's discounted return (gamma = 0.999) inside
    the learned model, with stochastic and deterministic model rollouts,
    against its return in the real environment (``true_return``).

    Writes ``{log_dir}/stage-{t}.npy`` and ``{log_dir}/final.npy`` whenever the
    step counter is a multiple of ``FLAGS.TRPO.save_freq``, ``final.npy`` again
    at the end, and ``{log_dir}/evaluate.yml``. The latter maps each discount
    factor to ``[mean return, per-episode returns]`` in the learned model.

    Raises:
        NotImplementedError: if ``FLAGS.GAIL.reward_type`` is not ``'nn'``.
    """
    FLAGS.set_seed()
    FLAGS.freeze()

    env = create_env(FLAGS.env.id,
                     FLAGS.seed,
                     rescale_action=FLAGS.env.rescale_action)
    dim_state = env.observation_space.shape[0]
    dim_action = env.action_space.shape[0]

    subsampling_rate = env.max_episode_steps // FLAGS.GAIL.trajectory_size
    # load expert dataset
    # The data-loading/subsampling phase runs under a fixed seed (2020);
    # the run seed is restored with set_random_seed(FLAGS.seed) below.
    set_random_seed(2020)
    expert_dataset = load_expert_dataset(FLAGS.GAIL.buf_load)
    expert_state = np.stack([tr.obs for tr in expert_dataset.buffer()])
    expert_next_state = np.stack([tr.next_obs for tr in expert_dataset.buffer()])
    expert_done = np.stack([tr.done for tr in expert_dataset.buffer()])
    # Sanity check: each transition's next_obs equals the following
    # transition's obs, except where the transition ends an episode (masked).
    np.testing.assert_allclose(
        expert_next_state[:-1] * (1 - expert_done[:-1][:, None]),
        expert_state[1:] * (1 - expert_done[:-1][:, None]))
    del expert_state, expert_next_state, expert_done
    expert_reward = expert_dataset.get_average_reward()
    logger.info('Expert Reward %f', expert_reward)
    if FLAGS.GAIL.learn_absorbing:
        expert_dataset.add_absorbing_states(env)
    # Fixed batch for reporting eval_mse_loss during training; it is drawn
    # before the dataset is subsampled.
    eval_batch = expert_dataset.sample(1024)
    eval_state = np.stack([tr.obs for tr in eval_batch])
    eval_action = np.stack([tr.action for tr in eval_batch])
    eval_next_state = np.stack([tr.next_obs for tr in eval_batch])
    logger.info('Sampled obs: %.4f, acs: %.4f', np.mean(eval_state),
                np.mean(eval_action))
    expert_dataset.subsample_trajectories(FLAGS.GAIL.traj_limit)
    logger.info('Original dataset size {}'.format(len(expert_dataset)))
    expert_dataset.subsample_transitions(subsampling_rate)
    logger.info('Subsampled dataset size {}'.format(len(expert_dataset)))
    logger.info('np random: %d random : %d', np.random.randint(1000),
                random.randint(0, 1000))
    set_random_seed(FLAGS.seed)

    # expert actor
    actor = Actor(dim_state,
                  dim_action,
                  hidden_sizes=FLAGS.SAC.actor_hidden_sizes)
    # generator (used as the model of VirtualEnv below)
    normalizers = Normalizers(dim_action=dim_action, dim_state=dim_state)
    policy = GaussianMLPPolicy(dim_state,
                               dim_action,
                               FLAGS.TRPO.policy_hidden_sizes,
                               output_diff=FLAGS.TRPO.output_diff,
                               normalizers=normalizers)
    vfn = MLPVFunction(dim_state, dim_action, FLAGS.TRPO.vf_hidden_sizes,
                       normalizers.state)
    algo = TRPO(vfn=vfn,
                policy=policy,
                dim_state=dim_state,
                dim_action=dim_action,
                **FLAGS.TRPO.algo.as_dict())

    if FLAGS.GAIL.reward_type == 'nn':
        # Per-dimension mean/std of expert states, passed to the discriminator.
        expert_batch = expert_dataset.buffer()
        expert_state = np.stack([tr.obs for tr in expert_batch])
        loc = np.mean(expert_state, axis=0, keepdims=True)
        scale = np.std(expert_state, axis=0, keepdims=True)
        del expert_batch, expert_state
        logger.info('loc = {}\nscale={}'.format(loc, scale))
        discriminator = Discriminator(dim_state,
                                      dim_action,
                                      normalizers=normalizers,
                                      subsampling_rate=subsampling_rate,
                                      loc=loc,
                                      scale=scale,
                                      **FLAGS.GAIL.discriminator.as_dict())
    else:
        raise NotImplementedError
    # float(): matches the other BehavioralCloningLoss call sites; the config
    # value may be parsed as a string (e.g. YAML reads "1e-3" as str).
    bc_loss = BehavioralCloningLoss(dim_state,
                                    dim_action,
                                    policy,
                                    lr=float(FLAGS.BC.lr),
                                    train_std=FLAGS.BC.train_std)
    tf.get_default_session().run(tf.global_variables_initializer())

    # allow_pickle=True: only load trusted checkpoints.
    loader = nn.ModuleDict({'actor': actor})
    loader.load_state_dict(
        np.load(FLAGS.ckpt.policy_load, allow_pickle=True)[()])
    logger.info('Load policy from %s', FLAGS.ckpt.policy_load)
    saver = nn.ModuleDict({
        'policy': policy,
        'vfn': vfn,
        'normalizers': normalizers,
        'discriminator': discriminator
    })
    print(saver)

    # updater normalizer
    expert_state = np.stack([tr.obs for tr in expert_dataset.buffer()])
    expert_action = np.stack([tr.action for tr in expert_dataset.buffer()])
    expert_next_state = np.stack([tr.next_obs for tr in expert_dataset.buffer()])
    normalizers.state.update(expert_state)
    normalizers.action.update(expert_action)
    normalizers.diff.update(expert_next_state - expert_state)
    del expert_state, expert_action, expert_next_state

    # Reference value: the actor's discounted return in the real environment.
    eval_gamma = 0.999
    eval_returns, eval_lengths = evaluate_on_true_env(actor,
                                                      env,
                                                      gamma=eval_gamma)
    logger.warning(
        'Test policy true value = %.4f true length = %d (gamma = %f)',
        np.mean(eval_returns), np.mean(eval_lengths), eval_gamma)

    # pretrain
    for n_updates in range(FLAGS.GAIL.pretrain_iters):
        expert_batch = expert_dataset.sample(FLAGS.BC.batch_size)
        expert_state = np.stack([tr.obs for tr in expert_batch])
        expert_action = np.stack([tr.action for tr in expert_batch])
        expert_next_state = np.stack([tr.next_obs for tr in expert_batch])
        _, loss, grad_norm = bc_loss.get_loss(expert_state,
                                              expert_action,
                                              expert_next_state,
                                              fetch='train loss grad_norm')
        if n_updates % 100 == 0:
            mse_loss = policy.get_mse_loss(expert_state, expert_action,
                                           expert_next_state)
            logger.info(
                '[Pretrain] iter = %d grad_norm = %.4f loss = %.4f mse_loss = %.4f',
                n_updates, grad_norm, loss, mse_loss)

    # virtual env
    virtual_env = VirtualEnv(policy,
                             env,
                             n_envs=FLAGS.env.num_env,
                             stochastic_model=True)
    virtual_runner = VirtualRunner(virtual_env,
                                   max_steps=env.max_episode_steps,
                                   gamma=FLAGS.TRPO.gamma,
                                   lambda_=FLAGS.TRPO.lambda_,
                                   rescale_action=False)
    env_eval_stochastic = VirtualEnv(policy,
                                     env,
                                     n_envs=4,
                                     stochastic_model=True)
    env_eval_deterministic = VirtualEnv(policy,
                                        env,
                                        n_envs=4,
                                        stochastic_model=False)

    max_ent_coef = FLAGS.TRPO.algo.ent_coef
    true_return = np.mean(eval_returns)
    for t in range(0, FLAGS.GAIL.total_timesteps,
                   FLAGS.TRPO.rollout_samples * FLAGS.GAIL.g_iters):
        time_st = time.time()
        if t % FLAGS.GAIL.eval_freq == 0:
            eval_returns_stochastic, eval_lengths_stochastic = evaluate_on_virtual_env(
                actor, env_eval_stochastic, gamma=eval_gamma)
            eval_returns_deterministic, eval_lengths_deterministic = evaluate_on_virtual_env(
                actor, env_eval_deterministic, gamma=eval_gamma)
            mean_stochastic = np.mean(eval_returns_stochastic)
            mean_deterministic = np.mean(eval_returns_deterministic)
            # NOTE(review): the *_rel entries divide by true_return and are
            # inf/nan when true_return == 0.
            log_kvs(
                prefix='Evaluate',
                kvs=dict(
                    iter=t,
                    stochastic_episode=dict(
                        returns=mean_stochastic,
                        lengths=int(np.mean(eval_lengths_stochastic))),
                    episode=dict(
                        returns=mean_deterministic,
                        lengths=int(np.mean(eval_lengths_deterministic))),
                    evaluation_error=dict(
                        stochastic_error=true_return - mean_stochastic,
                        stochastic_abs=np.abs(true_return - mean_stochastic),
                        stochastic_rel=np.abs(true_return - mean_stochastic)
                        / true_return,
                        deterministic_error=true_return - mean_deterministic,
                        deterministic_abs=np.abs(true_return -
                                                 mean_deterministic),
                        deterministic_rel=np.abs(true_return -
                                                 mean_deterministic)
                        / true_return)))
        # Generator
        generator_dataset = None
        for n_update in range(FLAGS.GAIL.g_iters):
            data, ep_infos = virtual_runner.run(actor,
                                                FLAGS.TRPO.rollout_samples,
                                                stochastic=False)
            # if FLAGS.TRPO.normalization:
            #     normalizers.state.update(data.state)
            #     normalizers.action.update(data.action)
            #     normalizers.diff.update(data.next_state - data.state)
            if t == 0:
                # Sanity check: the env's batched reward function reproduces
                # the rewards recorded during the rollout.
                np.testing.assert_allclose(data.reward,
                                           env.mb_step(data.state, data.action,
                                                       data.next_state)[0],
                                           atol=1e-4,
                                           rtol=1e-4)
            if t == 0 and n_update == 0 and not FLAGS.GAIL.learn_absorbing:
                # Sanity check per env copy: state[i + 1] == next_state[i]
                # except across episode boundaries (done or timeout).
                # NOTE(review): the rollout comes from virtual_env
                # (n_envs=FLAGS.env.num_env) but the reshape uses env.n_envs;
                # confirm both are equal.
                data_ = data.copy()
                data_ = data_.reshape(
                    [FLAGS.TRPO.rollout_samples // env.n_envs, env.n_envs])
                for e in range(env.n_envs):
                    samples = data_[:, e]
                    masks = 1 - (samples.done | samples.timeout)[...,
                                                                 np.newaxis]
                    masks = masks[:-1]
                    assert np.allclose(samples.state[1:] * masks,
                                       samples.next_state[:-1] * masks)
            # Advances the counter used for logging/saving within this outer
            # step; range() rebinds t at the next outer iteration.
            t += FLAGS.TRPO.rollout_samples
            # TRPO is trained on discriminator rewards, not the rollout rewards.
            data.reward = discriminator.get_reward(data.state, data.action,
                                                   data.next_state)
            advantages, values = virtual_runner.compute_advantage(vfn, data)
            train_info = algo.train(max_ent_coef, data, advantages, values)
            fps = int(FLAGS.TRPO.rollout_samples / (time.time() - time_st))
            train_info['reward'] = np.mean(data.reward)
            train_info['fps'] = fps

            expert_batch = expert_dataset.sample(256)
            expert_state = np.stack([tr.obs for tr in expert_batch])
            expert_action = np.stack([tr.action for tr in expert_batch])
            expert_next_state = np.stack([tr.next_obs for tr in expert_batch])
            train_mse_loss = policy.get_mse_loss(expert_state, expert_action,
                                                 expert_next_state)
            eval_mse_loss = policy.get_mse_loss(eval_state, eval_action,
                                                eval_next_state)
            train_info['mse_loss'] = dict(train=train_mse_loss,
                                          eval=eval_mse_loss)
            log_kvs(prefix='TRPO', kvs=dict(iter=t, **train_info))

            generator_dataset = data

        # Discriminator (trained on the last generator batch).
        # NOTE(review): generator_dataset is None if FLAGS.GAIL.g_iters == 0.
        for n_update in range(FLAGS.GAIL.d_iters):
            batch_size = FLAGS.GAIL.d_batch_size
            d_train_infos = dict()
            for generator_subset in generator_dataset.iterator(batch_size):
                expert_batch = expert_dataset.sample(batch_size)
                expert_state = np.stack([tr.obs for tr in expert_batch])
                expert_action = np.stack([tr.action for tr in expert_batch])
                expert_next_state = np.stack(
                    [tr.next_obs for tr in expert_batch])
                expert_mask = None
                train_info = discriminator.train(
                    expert_state,
                    expert_action,
                    expert_next_state,
                    generator_subset.state,
                    generator_subset.action,
                    generator_subset.next_state,
                    expert_mask,
                )
                for k, v in train_info.items():
                    if k not in d_train_infos:
                        d_train_infos[k] = []
                    d_train_infos[k].append(v)
            d_train_infos = {k: np.mean(v) for k, v in d_train_infos.items()}
            if n_update == FLAGS.GAIL.d_iters - 1:
                log_kvs(prefix='Discriminator',
                        kvs=dict(iter=t, **d_train_infos))

        if t % FLAGS.TRPO.save_freq == 0:
            np.save('{}/stage-{}'.format(FLAGS.log_dir, t), saver.state_dict())
            np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())
    np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())

    dict_result = dict()
    for gamma in [0.9, 0.99, 0.999, 1.0]:
        eval_returns, eval_lengths = evaluate_on_virtual_env(
            actor, env_eval_stochastic, gamma=gamma)
        dict_result[gamma] = [float(np.mean(eval_returns)), eval_returns]
        logger.info('[%s]: %.4f', gamma, np.mean(eval_returns))

    save_path = os.path.join(FLAGS.log_dir, 'evaluate.yml')
    # Close the file deterministically (the original leaked the handle).
    with open(save_path, 'w') as f:
        yaml.dump(dict_result, f, default_flow_style=False)
def main():
    """Fit a learned environment model by behavioral cloning and evaluate it.

    Loads an expert transition dataset, fits the normalizers on it and
    trains ``policy`` on (state, action, next_state) triples with
    ``BehavioralCloningLoss``.  The trained ``policy`` is wrapped by
    ``VirtualEnv``, i.e. it is used as a learned transition model.  Every
    ``FLAGS.BC.eval_freq`` iterations, the return of the pre-trained
    ``actor`` inside that model is compared with its return on the true
    environment.  The final weights and a per-gamma evaluation summary
    (``evaluate.yml``) are written to ``FLAGS.log_dir``.
    """
    FLAGS.set_seed()
    FLAGS.freeze()

    env = create_env(FLAGS.env.id,
                     seed=FLAGS.seed,
                     rescale_action=FLAGS.env.rescale_action)
    dim_state = env.observation_space.shape[0]
    dim_action = env.action_space.shape[0]

    normalizers = Normalizers(dim_action=dim_action, dim_state=dim_state)
    policy = GaussianMLPPolicy(dim_state,
                               dim_action,
                               FLAGS.TRPO.policy_hidden_sizes,
                               output_diff=FLAGS.TRPO.output_diff,
                               normalizers=normalizers)
    bc_loss = BehavioralCloningLoss(dim_state,
                                    dim_action,
                                    policy,
                                    lr=float(FLAGS.BC.lr),
                                    train_std=FLAGS.BC.train_std)

    # Pre-trained agent whose return is evaluated; weights are loaded below.
    actor = Actor(dim_state, dim_action, FLAGS.SAC.actor_hidden_sizes)
    tf.get_default_session().run(tf.global_variables_initializer())

    subsampling_rate = env.max_episode_steps // FLAGS.GAIL.trajectory_size
    # Load and subsample the expert data under a fixed seed (2020); the
    # run's own seed is restored once the dataset is prepared.
    set_random_seed(2020)
    expert_dataset = load_expert_dataset(FLAGS.GAIL.buf_load)
    # Sanity check: inside an episode next_obs[i] must equal obs[i + 1];
    # transitions flagged `done` are masked out of the comparison.
    expert_state = np.stack([tr.obs for tr in expert_dataset.buffer()])
    expert_next_state = np.stack(
        [tr.next_obs for tr in expert_dataset.buffer()])
    expert_done = np.stack([tr.done for tr in expert_dataset.buffer()])
    not_done = 1 - expert_done[:-1][:, None]
    np.testing.assert_allclose(expert_next_state[:-1] * not_done,
                               expert_state[1:] * not_done)
    del expert_state, expert_next_state, expert_done, not_done
    expert_reward = expert_dataset.get_average_reward()
    logger.info('Expert Reward %f', expert_reward)
    if FLAGS.GAIL.learn_absorbing:
        expert_dataset.add_absorbing_states(env)
    # Fixed batch for the "eval" MSE.  It is drawn before subsampling, so it
    # may overlap the transitions used for training.
    eval_batch = expert_dataset.sample(1024)
    eval_state = np.stack([tr.obs for tr in eval_batch])
    eval_action = np.stack([tr.action for tr in eval_batch])
    eval_next_state = np.stack([tr.next_obs for tr in eval_batch])
    logger.info('Sampled obs: %.4f, acs: %.4f', np.mean(eval_state),
                np.mean(eval_action))
    expert_dataset.subsample_trajectories(FLAGS.GAIL.traj_limit)
    logger.info('Original dataset size {}'.format(len(expert_dataset)))
    expert_dataset.subsample_transitions(subsampling_rate)
    logger.info('Subsampled dataset size {}'.format(len(expert_dataset)))
    # Logged to make the RNG state after dataset preparation visible.
    logger.info('np random: %d random : %d', np.random.randint(1000),
                random.randint(0, 1000))
    set_random_seed(FLAGS.seed)

    loader = nn.ModuleDict({'actor': actor})
    loader.load_state_dict(
        np.load(FLAGS.ckpt.policy_load, allow_pickle=True)[()])
    logger.warning('Load expert policy from %s' % FLAGS.ckpt.policy_load)
    saver = nn.ModuleDict({'policy': policy, 'normalizers': normalizers})
    print(saver)

    # Fit the state / action / state-difference normalizers on the
    # (subsampled) expert data.
    expert_state = np.stack([tr.obs for tr in expert_dataset.buffer()])
    expert_action = np.stack([tr.action for tr in expert_dataset.buffer()])
    expert_next_state = np.stack(
        [tr.next_obs for tr in expert_dataset.buffer()])
    normalizers.state.update(expert_state)
    normalizers.action.update(expert_action)
    normalizers.diff.update(expert_next_state - expert_state)
    del expert_state, expert_action, expert_next_state

    # Reference value: the actor's discounted return on the true env.
    eval_gamma = 0.999
    eval_returns, eval_lengths = evaluate_on_true_env(actor,
                                                      env,
                                                      gamma=eval_gamma)
    logger.warning(
        'Test policy true value = %.4f true length = %d (gamma = %f)',
        np.mean(eval_returns), np.mean(eval_lengths), eval_gamma)

    # Learned-model environments, sampling from the model or using its mean.
    env_eval_stochastic = VirtualEnv(policy,
                                     env,
                                     n_envs=4,
                                     stochastic_model=True)
    env_eval_deterministic = VirtualEnv(policy,
                                        env,
                                        n_envs=4,
                                        stochastic_model=False)

    batch_size = FLAGS.BC.batch_size
    true_return = np.mean(eval_returns)
    for t in range(FLAGS.BC.max_iters):
        if t % FLAGS.BC.eval_freq == 0:
            eval_returns_stochastic, eval_lengths_stochastic = evaluate_on_virtual_env(
                actor, env_eval_stochastic, gamma=eval_gamma)
            eval_returns_deterministic, eval_lengths_deterministic = evaluate_on_virtual_env(
                actor, env_eval_deterministic, gamma=eval_gamma)
            stochastic_return = np.mean(eval_returns_stochastic)
            deterministic_return = np.mean(eval_returns_deterministic)
            log_kvs(
                prefix='Evaluate',
                kvs=dict(
                    iter=t,
                    stochastic_episode=dict(
                        returns=stochastic_return,
                        lengths=int(np.mean(eval_lengths_stochastic))),
                    episode=dict(
                        returns=deterministic_return,
                        lengths=int(np.mean(eval_lengths_deterministic))),
                    # Signed, absolute and relative gap between the true
                    # return and the return obtained in the learned model.
                    evaluation_error=dict(
                        stochastic_error=true_return - stochastic_return,
                        stochastic_abs=np.abs(true_return -
                                              stochastic_return),
                        stochastic_rel=np.abs(true_return -
                                              stochastic_return) /
                        true_return,
                        deterministic_error=true_return -
                        deterministic_return,
                        deterministic_abs=np.abs(true_return -
                                                 deterministic_return),
                        deterministic_rel=np.abs(true_return -
                                                 deterministic_return) /
                        true_return)))

        expert_batch = expert_dataset.sample(batch_size)
        expert_state = np.stack([tr.obs for tr in expert_batch])
        expert_action = np.stack([tr.action for tr in expert_batch])
        expert_next_state = np.stack([tr.next_obs for tr in expert_batch])
        _, loss, grad_norm = bc_loss.get_loss(expert_state,
                                              expert_action,
                                              expert_next_state,
                                              fetch='train loss grad_norm')

        if t % 100 == 0:
            train_mse_loss = policy.get_mse_loss(expert_state, expert_action,
                                                 expert_next_state)
            eval_mse_loss = policy.get_mse_loss(eval_state, eval_action,
                                                eval_next_state)
            log_kvs(prefix='BC',
                    kvs=dict(iter=t,
                             grad_norm=grad_norm,
                             loss=loss,
                             mse_loss=dict(train=train_mse_loss,
                                           eval=eval_mse_loss)))

    np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())

    # Evaluate the actor in the learned (stochastic) model for several
    # discount factors and store mean + per-episode returns as YAML.
    dict_result = dict()
    for gamma in [0.9, 0.99, 0.999, 1.0]:
        eval_returns, eval_lengths = evaluate_on_virtual_env(
            actor, env_eval_stochastic, gamma=gamma)
        dict_result[gamma] = [float(np.mean(eval_returns)), eval_returns]
        logger.info('[%s]: %.4f', gamma, np.mean(eval_returns))

    save_path = os.path.join(FLAGS.log_dir, 'evaluate.yml')
    # Close (and flush) the file deterministically rather than relying on GC.
    with open(save_path, 'w') as f:
        yaml.dump(dict_result, f, default_flow_style=False)
Ejemplo n.º 11
0
def main():
    """Run SLBO-style model-based training (model-free TRPO if 'MF').

    Each stage collects real transitions with an OU-noise-perturbed policy
    and optionally updates the normalizers.  It then alternates between
    fitting the dynamics model with ``MultiStepLoss`` and updating the
    policy with TRPO on rollouts from ``runners['train']``.  That runner is
    the learned model unless ``FLAGS.algorithm == 'MF'``, in which case it
    is the real environment.  Checkpoints (and, when
    ``FLAGS.ckpt.n_save_stages == 1``, each stage's collected buffer) are
    written to ``FLAGS.log_dir``.
    """
    FLAGS.set_seed()
    FLAGS.freeze()

    env = make_env(FLAGS.env.id)
    dim_state = int(np.prod(env.observation_space.shape))
    dim_action = int(np.prod(env.action_space.shape))

    env.verify()

    normalizers = Normalizers(dim_action=dim_action, dim_state=dim_state)

    dtype = gen_dtype(env, 'state action next_state reward done timeout')
    train_set = Dataset(dtype, FLAGS.rollout.max_buf_size)
    dev_set = Dataset(dtype, FLAGS.rollout.max_buf_size)

    policy = GaussianMLPPolicy(dim_state,
                               dim_action,
                               normalizer=normalizers.state,
                               **FLAGS.policy.as_dict())
    # Exploration noise used when collecting real data (batched noises).
    noise = OUNoise(env.action_space,
                    theta=FLAGS.OUNoise.theta,
                    sigma=FLAGS.OUNoise.sigma,
                    shape=(1, dim_action))
    vfn = MLPVFunction(dim_state, [64, 64], normalizers.state)
    model = DynamicsModel(dim_state, dim_action, normalizers,
                          FLAGS.model.hidden_sizes)

    virt_env = VirtualEnv(model,
                          make_env(FLAGS.env.id),
                          FLAGS.plan.n_envs,
                          opt_model=FLAGS.slbo.opt_model)
    virt_runner = Runner(
        virt_env, **{
            **FLAGS.runner.as_dict(), 'max_steps': FLAGS.plan.max_steps
        })

    criterion_map = {
        'L1': nn.L1Loss(),
        'L2': nn.L2Loss(),
        'MSE': nn.MSELoss(),
    }
    criterion = criterion_map[FLAGS.model.loss]
    loss_mod = MultiStepLoss(model, normalizers, dim_state, dim_action,
                             criterion, FLAGS.model.multi_step)
    loss_mod.build_backward(FLAGS.model.lr, FLAGS.model.weight_decay)
    algo = TRPO(vfn=vfn,
                policy=policy,
                dim_state=dim_state,
                dim_action=dim_action,
                **FLAGS.TRPO.as_dict())

    tf.get_default_session().run(tf.global_variables_initializer())

    runners = {
        'test': make_real_runner(4),
        'collect': make_real_runner(1),
        'dev': make_real_runner(1),
        # Model-free ('MF') trains on the real env, otherwise on the model.
        'train': (make_real_runner(FLAGS.plan.n_envs)
                  if FLAGS.algorithm == 'MF' else virt_runner),
    }
    settings = [(runners['test'], policy, 'Real Env'),
                (runners['train'], policy, 'Virt Env')]

    saver = nn.ModuleDict({'policy': policy, 'model': model, 'vfn': vfn})
    print(saver)

    if FLAGS.ckpt.model_load:
        # Checkpoints are dicts saved with `np.save` (see below), i.e. object
        # arrays; NumPy >= 1.16.3 refuses to load them without allow_pickle.
        saver.load_state_dict(
            np.load(FLAGS.ckpt.model_load, allow_pickle=True)[()])
        logger.warning('Load model from %s', FLAGS.ckpt.model_load)

    if FLAGS.ckpt.buf_load:
        n_samples = 0
        for i in range(FLAGS.ckpt.buf_load_index):
            # NOTE: unpickling can execute arbitrary code; only load buffers
            # written by this script (see the end of each stage).
            buf_path = f'{FLAGS.ckpt.buf_load}/stage-{i}.inc-buf.pkl'
            with open(buf_path, 'rb') as f:
                data = pickle.load(f)
            add_multi_step(data, train_set)
            n_samples += len(data)
        logger.warning('Loading %d samples from %s', n_samples,
                       FLAGS.ckpt.buf_load)

    max_ent_coef = FLAGS.TRPO.ent_coef

    for T in range(FLAGS.slbo.n_stages):
        logger.info('------ Starting Stage %d --------', T)
        evaluate(settings, 'episode')

        if not FLAGS.use_prev:
            train_set.clear()
            dev_set.clear()

        # collect data
        recent_train_set, ep_infos = runners['collect'].run(
            noise.make(policy), FLAGS.rollout.n_train_samples)
        add_multi_step(recent_train_set, train_set)
        add_multi_step(
            runners['dev'].run(noise.make(policy),
                               FLAGS.rollout.n_dev_samples)[0],
            dev_set,
        )

        returns = np.array([ep_info['return'] for ep_info in ep_infos])
        if len(returns) > 0:
            logger.info("episode: %s", np.mean(returns))

        if T == 0:
            # Sanity check: consecutive steps of a multi-step sample must be
            # contiguous unless the episode ended (done or timeout).
            samples = train_set.sample_multi_step(100, 1,
                                                  FLAGS.model.multi_step)
            for i in range(FLAGS.model.multi_step - 1):
                masks = 1 - (samples.done[i] |
                             samples.timeout[i])[..., np.newaxis]
                assert np.allclose(samples.state[i + 1] * masks,
                                   samples.next_state[i] * masks)

        # 'policy': refit normalizers on every stage's fresh data;
        # 'uniform': only on the first stage.  (Same grouping Python's
        # `and`-before-`or` precedence gave the unparenthesized original.)
        if (FLAGS.rollout.normalizer == 'policy'
                or (FLAGS.rollout.normalizer == 'uniform' and T == 0)):
            normalizers.state.update(recent_train_set.state)
            normalizers.action.update(recent_train_set.action)
            normalizers.diff.update(recent_train_set.next_state -
                                    recent_train_set.state)

        # Entropy bonus is switched off from stage 50 onwards.
        if T == 50:
            max_ent_coef = 0.

        for i in range(FLAGS.slbo.n_iters):
            if i % FLAGS.slbo.n_evaluate_iters == 0 and i != 0:
                evaluate(settings, 'iteration')

            # Fit the dynamics model.
            losses = deque(maxlen=FLAGS.slbo.n_model_iters)
            grad_norm_meter = AverageMeter()
            n_model_iters = FLAGS.slbo.n_model_iters
            for _ in range(n_model_iters):
                samples = train_set.sample_multi_step(
                    FLAGS.model.train_batch_size, 1, FLAGS.model.multi_step)
                _, train_loss, grad_norm = loss_mod.get_loss(
                    samples.state,
                    samples.next_state,
                    samples.action,
                    ~samples.done & ~samples.timeout,
                    fetch='train loss grad_norm')
                losses.append(train_loss.mean())
                grad_norm_meter.update(grad_norm)
                # ideally, we should define an Optimizer class, which takes parameters as inputs.
                # The `update` method of `Optimizer` will invalidate all parameters during updates.
                for param in model.parameters():
                    param.invalidate()

            if i % FLAGS.model.validation_freq == 0:
                # NOTE(review): the value logged as "dev" is computed on a
                # fresh batch from `train_set`; `dev_set` is filled above but
                # not used here.
                samples = train_set.sample_multi_step(
                    FLAGS.model.train_batch_size, 1, FLAGS.model.multi_step)
                loss = loss_mod.get_loss(samples.state, samples.next_state,
                                         samples.action,
                                         ~samples.done & ~samples.timeout)
                loss = loss.mean()
                if np.isnan(loss) or np.isnan(np.mean(losses)):
                    logger.info('nan! %s %s', np.isnan(loss),
                                np.isnan(np.mean(losses)))
                logger.info(
                    '# Iter %3d: Loss = [train = %.3f, dev = %.3f], after %d steps, grad_norm = %.6f',
                    i, np.mean(losses), loss, n_model_iters,
                    grad_norm_meter.get())

            # Update the policy with TRPO on rollouts from the train runner.
            for n_updates in range(FLAGS.slbo.n_policy_iters):
                if FLAGS.algorithm != 'MF' and FLAGS.slbo.start == 'buffer':
                    runners['train'].set_state(
                        train_set.sample(FLAGS.plan.n_envs).state)
                else:
                    runners['train'].reset()

                data, ep_infos = runners['train'].run(
                    policy, FLAGS.plan.n_trpo_samples)
                advantages, values = runners['train'].compute_advantage(
                    vfn, data)
                dist_mean, dist_std, vf_loss = algo.train(
                    max_ent_coef, data, advantages, values)
                returns = [info['return'] for info in ep_infos]
                logger.info(
                    '[TRPO] # %d: n_episodes = %d, returns: {mean = %.0f, std = %.0f}, '
                    'dist std = %.10f, dist mean = %.10f, vf_loss = %.3f',
                    n_updates, len(returns), np.mean(returns),
                    np.std(returns) / np.sqrt(len(returns)), dist_std,
                    dist_mean, vf_loss)

        if T % FLAGS.ckpt.n_save_stages == 0:
            np.save(f'{FLAGS.log_dir}/stage-{T}', saver.state_dict())
            np.save(f'{FLAGS.log_dir}/final', saver.state_dict())
        if FLAGS.ckpt.n_save_stages == 1:
            # `with` guarantees the buffer is fully flushed to disk.
            with open(f'{FLAGS.log_dir}/stage-{T}.inc-buf.pkl', 'wb') as f:
                pickle.dump(recent_train_set, f)
def main():
    """Behavioral-cloning variant that collects its own expert data.

    Rolls out ``actor`` on the true environment to build a (subsampled)
    training set and a 3-episode evaluation set.  It then fits ``policy`` to
    the (state, action, next_state) transitions with
    ``BehavioralCloningLoss``.  Along the way it tracks how well
    ``VirtualEnv(policy, env)`` reproduces the actor's true return.  The
    final weights and a per-gamma evaluation summary (``evaluate.yml``) are
    written to ``FLAGS.log_dir``.
    """
    FLAGS.set_seed()
    FLAGS.freeze()

    env = create_env(FLAGS.env.id, seed=FLAGS.seed, rescale_action=FLAGS.env.rescale_action)
    dim_state = env.observation_space.shape[0]
    dim_action = env.action_space.shape[0]

    # Expert actor.  NOTE(review): no checkpoint is loaded for it here;
    # presumably `Actor` is a fixed/scripted expert in this setup — confirm.
    actor = Actor(dim_state, dim_action, init_std=0.)
    subsampling_rate = env.max_episode_steps // FLAGS.GAIL.trajectory_size
    expert_state, expert_action, expert_next_state, expert_reward = collect_samples_from_true_env(
        env=env, actor=actor, nb_episode=FLAGS.GAIL.traj_limit, subsampling_rate=subsampling_rate)
    logger.info('Collect %d samples avg return = %.4f', len(expert_state), np.mean(expert_reward))
    eval_state, eval_action, eval_next_state, eval_reward = collect_samples_from_true_env(
        env=env, actor=actor, nb_episode=3, seed=FLAGS.seed)
    loc, scale = np.mean(expert_state, axis=0, keepdims=True), np.std(expert_state, axis=0, keepdims=True)
    logger.info('loc = {}\nscale={}'.format(loc, scale))

    normalizers = Normalizers(dim_action=dim_action, dim_state=dim_state)
    policy = GaussianMLPPolicy(dim_state, dim_action, FLAGS.TRPO.policy_hidden_sizes,
                               output_diff=FLAGS.TRPO.output_diff, normalizers=normalizers)
    bc_loss = BehavioralCloningLoss(dim_state, dim_action, policy, lr=float(FLAGS.BC.lr), train_std=FLAGS.BC.train_std)

    tf.get_default_session().run(tf.global_variables_initializer())
    set_random_seed(FLAGS.seed)

    saver = nn.ModuleDict({'policy': policy, 'normalizers': normalizers})
    print(saver)

    # Fit the state / action / state-difference normalizers on the
    # collected expert data.
    normalizers.state.update(expert_state)
    normalizers.action.update(expert_action)
    normalizers.diff.update(expert_next_state - expert_state)

    # Reference value: the actor's discounted return on the true env.
    eval_gamma = 0.999
    eval_returns, eval_lengths = evaluate_on_true_env(actor, env, gamma=eval_gamma)
    logger.warning('Test policy true value = %.4f true length = %d (gamma = %f)',
                   np.mean(eval_returns), np.mean(eval_lengths), eval_gamma)

    # Learned-model environments, sampling from the model or using its mean.
    env_eval_stochastic = VirtualEnv(policy, env, n_envs=4, stochastic_model=True)
    env_eval_deterministic = VirtualEnv(policy, env, n_envs=4, stochastic_model=False)

    batch_size = FLAGS.BC.batch_size
    true_return = np.mean(eval_returns)
    for t in range(FLAGS.BC.max_iters):
        if t % FLAGS.BC.eval_freq == 0:
            eval_returns_stochastic, eval_lengths_stochastic = evaluate_on_virtual_env(
                actor, env_eval_stochastic, gamma=eval_gamma)
            eval_returns_deterministic, eval_lengths_deterministic = evaluate_on_virtual_env(
                actor, env_eval_deterministic, gamma=eval_gamma)
            stochastic_return = np.mean(eval_returns_stochastic)
            deterministic_return = np.mean(eval_returns_deterministic)
            log_kvs(prefix='Evaluate', kvs=dict(
                iter=t, stochastic_episode=dict(
                    returns=stochastic_return, lengths=int(np.mean(eval_lengths_stochastic))
                ), episode=dict(
                    returns=deterministic_return, lengths=int(np.mean(eval_lengths_deterministic))
                ), evaluation_error=dict(
                    # Signed, absolute and relative gap between the true
                    # return and the return obtained in the learned model.
                    stochastic_error=true_return - stochastic_return,
                    stochastic_abs=np.abs(true_return - stochastic_return),
                    stochastic_rel=np.abs(true_return - stochastic_return) / true_return,
                    deterministic_error=true_return - deterministic_return,
                    deterministic_abs=np.abs(true_return - deterministic_return),
                    deterministic_rel=np.abs(true_return - deterministic_return) / true_return
                )
            ))

        # Uniformly sample a minibatch (with replacement) of expert transitions.
        indices = np.random.randint(low=0, high=len(expert_state), size=batch_size)
        expert_state_ = expert_state[indices]
        expert_action_ = expert_action[indices]
        expert_next_state_ = expert_next_state[indices]
        _, loss, grad_norm = bc_loss.get_loss(expert_state_, expert_action_, expert_next_state_,
                                              fetch='train loss grad_norm')

        if t % 100 == 0:
            train_mse_loss = policy.get_mse_loss(expert_state_, expert_action_, expert_next_state_)
            eval_mse_loss = policy.get_mse_loss(eval_state, eval_action, eval_next_state)
            log_kvs(prefix='BC', kvs=dict(
                iter=t, grad_norm=grad_norm, loss=loss, mse_loss=dict(train=train_mse_loss, eval=eval_mse_loss)
            ))

    np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())

    # Evaluate the actor in the learned (deterministic) model for several
    # discount factors and store mean + per-episode returns as YAML.
    dict_result = dict()
    for gamma in [0.9, 0.99, 0.999, 1.0]:
        eval_returns, eval_lengths = evaluate_on_virtual_env(actor, env_eval_deterministic, gamma=gamma)
        dict_result[gamma] = [float(np.mean(eval_returns)), eval_returns]
        logger.info('[%s]: %.4f', gamma, np.mean(eval_returns))

    save_path = os.path.join(FLAGS.log_dir, 'evaluate.yml')
    # Close (and flush) the file deterministically rather than relying on GC.
    with open(save_path, 'w') as f:
        yaml.dump(dict_result, f, default_flow_style=False)
def main():
    """Compare states the actor visits in the real env and in learned models.

    Loads a pre-trained actor plus two learned environment models (a BC and
    a GAIL checkpoint, each with its own normalizers).  The actor is rolled
    out for 2000 samples in the real env and in each model, stochastic and
    deterministic; rewards of the BC rollouts are checked against
    ``env.mb_step``.  All visited states are embedded jointly with t-SNE,
    and the result is written to ``FLAGS.log_dir`` as ``visualize.png``
    (one panel per source, coloured by ``step``) plus the raw states as
    ``data.npz``.
    """
    FLAGS.set_seed()
    FLAGS.freeze()

    env = create_env(FLAGS.env.id,
                     seed=FLAGS.seed,
                     rescale_action=FLAGS.env.rescale_action)
    dim_state = env.observation_space.shape[0]
    dim_action = env.action_space.shape[0]

    bc_normalizers = Normalizers(dim_action=dim_action, dim_state=dim_state)
    bc_policy = GaussianMLPPolicy(dim_state,
                                  dim_action,
                                  FLAGS.TRPO.policy_hidden_sizes,
                                  output_diff=FLAGS.TRPO.output_diff,
                                  normalizers=bc_normalizers)

    gail_normalizers = Normalizers(dim_action=dim_action, dim_state=dim_state)
    gail_policy = GaussianMLPPolicy(dim_state,
                                    dim_action,
                                    FLAGS.TRPO.policy_hidden_sizes,
                                    output_diff=FLAGS.TRPO.output_diff,
                                    normalizers=gail_normalizers)

    actor = Actor(dim_state, dim_action, FLAGS.SAC.actor_hidden_sizes)
    tf.get_default_session().run(tf.global_variables_initializer())

    loader = nn.ModuleDict({'actor': actor})
    policy_load = f'dataset/mb2/{FLAGS.env.id}/policy.npy'
    loader.load_state_dict(np.load(policy_load, allow_pickle=True)[()])
    logger.warning('Load expert policy from %s' % policy_load)

    # NOTE: the model checkpoints below are hard-coded Walker2d-v2 runs.
    bc_policy_load = "benchmarks/mbrl_benchmark/mbrl2_bc_30_1000/mbrl2_bc-Walker2d-v2-100-2020-05-22-16-02-12/final.npy"
    loader = nn.ModuleDict({
        'policy': bc_policy,
        'normalizers': bc_normalizers
    })
    loader.load_state_dict(np.load(bc_policy_load, allow_pickle=True)[()])
    logger.warning('Load bc policy from %s' % bc_policy_load)

    gail_policy_load = "benchmarks/mbrl_benchmark/mbrl2_gail_grad_penalty/mbrl2_gail-Walker2d-v2-100-2020-05-22-12-10-07/final.npy"
    loader = nn.ModuleDict({
        'policy': gail_policy,
        'normalizers': gail_normalizers
    })
    loader.load_state_dict(np.load(gail_policy_load, allow_pickle=True)[()])
    logger.warning('Load gail policy from %s' % gail_policy_load)

    eval_gamma = 0.999
    eval_returns, eval_lengths = evaluate_on_true_env(actor,
                                                      env,
                                                      gamma=eval_gamma)
    logger.warning(
        'Test policy true value = %.4f true length = %d (gamma = %f)',
        np.mean(eval_returns), np.mean(eval_lengths), eval_gamma)

    real_runner = Runner(env,
                         max_steps=env.max_episode_steps,
                         rescale_action=False)
    # virtual env
    env_bc_stochastic = VirtualEnv(bc_policy,
                                   env,
                                   n_envs=1,
                                   stochastic_model=True)
    env_bc_deterministic = VirtualEnv(bc_policy,
                                      env,
                                      n_envs=1,
                                      stochastic_model=False)
    runner_bc_stochastic = VirtualRunner(env_bc_stochastic,
                                         max_steps=env.max_episode_steps,
                                         rescale_action=False)
    runner_bc_deterministic = VirtualRunner(env_bc_deterministic,
                                            max_steps=env.max_episode_steps,
                                            rescale_action=False)

    env_gail_stochastic = VirtualEnv(gail_policy,
                                     env,
                                     n_envs=1,
                                     stochastic_model=True)
    env_gail_deterministic = VirtualEnv(gail_policy,
                                        env,
                                        n_envs=1,
                                        stochastic_model=False)
    # NOTE(review): unlike `real_runner` and the BC runners, these do not
    # pass `rescale_action=False`; confirm the default is what is intended.
    runner_gail_stochastic = VirtualRunner(env_gail_stochastic,
                                           max_steps=env.max_episode_steps)
    runner_gail_deterministic = VirtualRunner(env_gail_deterministic,
                                              max_steps=env.max_episode_steps)

    n_samples = int(2e3)

    def collect(runner, description):
        """Roll out `actor` with `runner`, log episode stats, return data."""
        rollout, infos = runner.run(actor,
                                    n_samples=n_samples,
                                    stochastic=False)
        returns = [info['return'] for info in infos]
        lengths = [info['length'] for info in infos]
        logger.info(
            'Collect %d samples for %s avg return = %.4f avg length = %d',
            len(rollout), description, np.mean(returns), np.mean(lengths))
        return rollout

    def check_model_reward(rollout):
        """Assert the model rollout's rewards agree with `env.mb_step`."""
        reward_ref, _ = env.mb_step(rollout.state, rollout.action,
                                    rollout.next_state)
        np.testing.assert_allclose(reward_ref,
                                   rollout.reward,
                                   rtol=1e-4,
                                   atol=1e-4)

    data_actor = collect(real_runner, 'actor')
    data_bc_stochastic = collect(runner_bc_stochastic,
                                 'bc stochastic policy')
    check_model_reward(data_bc_stochastic)
    data_bc_deterministic = collect(runner_bc_deterministic,
                                    'bc deterministic policy')
    check_model_reward(data_bc_deterministic)
    data_gail_stochastic = collect(runner_gail_stochastic,
                                   'gail stochastic policy')
    data_gail_deterministic = collect(runner_gail_deterministic,
                                      'gail deterministic policy')

    # One (title, rollout) entry per panel; the order fixes both the panel
    # order and the row layout of the joint t-SNE input.
    groups = [
        ('expert', data_actor),
        ('bc_stochastic', data_bc_stochastic),
        ('bc_deterministic', data_bc_deterministic),
        ('gail_stochastic', data_gail_stochastic),
        ('gail_deterministic', data_gail_deterministic),
    ]

    t_sne = manifold.TSNE(init='pca', random_state=2020)
    states = np.concatenate([rollout.state for _, rollout in groups], axis=0)
    steps = np.concatenate([rollout.step for _, rollout in groups], axis=0)
    # Standardize with the BC model's state statistics before embedding.
    loc, scale = bc_normalizers.state.eval('mean std')
    states = (states - loc) / (1e-6 + scale)
    embedding = t_sne.fit_transform(states)

    fig, axarrs = plt.subplots(nrows=1,
                               ncols=len(groups),
                               figsize=[6 * len(groups), 4],
                               squeeze=False,
                               sharex=True,
                               sharey=True,
                               dpi=300)
    scatters = []
    start = 0
    for col, (title, rollout) in enumerate(groups):
        # Slice by the actual number of rows each rollout contributed.
        end = start + len(rollout.state)
        scatters.append(axarrs[0][col].scatter(embedding[start:end, 0],
                                               embedding[start:end, 1],
                                               c=steps[start:end]))
        axarrs[0][col].set_title(title)
        start = end
    plt.colorbar(scatters[0], ax=axarrs.flatten())
    plt.tight_layout()
    plt.savefig(f'{FLAGS.log_dir}/visualize.png', bbox_inches='tight')
    plt.close(fig)

    np.savez(f'{FLAGS.log_dir}/data.npz',
             **{title: rollout.state for title, rollout in groups})
 def set_state(self, state):
     """Unsupported: emit a warning and otherwise do nothing.

     ``state`` is accepted only to keep the interface and is ignored.
     """
     message = '`set_state` is not implemented'
     logger.warning(message)
def _discounted_episode_returns(timesteps, gamma):
    """Return the discounted return of each completed episode in ``timesteps``.

    Every element must expose ``reward`` (only ``reward[0]`` is read) and
    ``done``. The discount restarts at 1 after each step whose ``done`` is
    truthy. Steps after the last terminal step belong to an unfinished episode
    and do not contribute to any returned value.

    Args:
        timesteps: Iterable of transitions, in episode order.
        gamma: Per-step discount factor.

    Returns:
        A list of Python floats, one per completed episode.
    """
    returns = []
    discount = 1.
    running = 0.
    for timestep in timesteps:
        running += discount * timestep.reward[0]
        discount *= gamma
        if timestep.done:
            returns.append(float(running))
            discount = 1.
            running = 0.
    return returns


def _dump_yaml(obj, path):
    """Write ``obj`` to ``path`` as block-style YAML, closing the file afterwards."""
    with open(path, 'w') as f:
        yaml.dump(obj, f, default_flow_style=False)


def main():
    """Evaluate saved ``gail_l2`` policies under several discount factors.

    1. If ``logs/expert-<env id>.yml`` does not exist yet, compute the expert
       dataset's discounted episode returns for each gamma and write them
       there.
    2. For every run directory under ``logs/gail_l2`` whose name contains the
       env id, load ``stage-3000000.npy`` into the policy. Then call
       ``evaluate`` on a separately seeded env for each gamma and write the
       results to ``evaluate.yml`` inside that run directory.

    Both YAML files map gamma -> [mean return, list of per-episode returns].
    """
    FLAGS.set_seed()
    FLAGS.freeze()

    env = create_env(FLAGS.env.id,
                     seed=FLAGS.seed,
                     log_dir=FLAGS.log_dir,
                     absorbing_state=FLAGS.GAIL.learn_absorbing,
                     rescale_action=FLAGS.env.rescale_action)
    # The evaluation env uses a different seed (seed + 1000) than ``env``.
    env_eval = create_env(FLAGS.env.id,
                          seed=FLAGS.seed + 1000,
                          log_dir=FLAGS.log_dir,
                          absorbing_state=FLAGS.GAIL.learn_absorbing,
                          rescale_action=FLAGS.env.rescale_action)
    dim_state = env.observation_space.shape[0]
    dim_action = env.action_space.shape[0]

    normalizers = Normalizers(dim_action=dim_action, dim_state=dim_state)
    policy = GaussianMLPPolicy(dim_state,
                               dim_action,
                               FLAGS.TRPO.policy_hidden_sizes,
                               normalizer=normalizers.state)

    tf.get_default_session().run(tf.global_variables_initializer())

    # Discount factors used for both the expert and the policy evaluations.
    gammas = [0.9, 0.99, 0.999, 1.0]

    expert_result_path = os.path.join('logs', 'expert-%s.yml' % FLAGS.env.id)
    if not os.path.exists(expert_result_path):
        expert_dataset = load_expert_dataset(FLAGS.GAIL.buf_load)
        expert_reward = expert_dataset.get_average_reward()
        logger.info('Expert Reward %f', expert_reward)
        if FLAGS.GAIL.learn_absorbing:
            expert_dataset.add_absorbing_states(env)

        expert_result = dict()
        for gamma in gammas:
            expert_returns = _discounted_episode_returns(
                expert_dataset.buffer(), gamma)
            mean_return = float(np.mean(expert_returns))
            expert_result[gamma] = [mean_return, expert_returns]
            logger.info('Expert gamma = %f %.4f (n_episode = %d)', gamma,
                        mean_return, len(expert_returns))
        # Make sure the output directory exists before writing the cache file.
        os.makedirs(os.path.dirname(expert_result_path), exist_ok=True)
        _dump_yaml(expert_result, expert_result_path)

    # loader policy
    loader = nn.ModuleDict({'policy': policy})
    root_dir = 'logs/gail_l2'
    for save_dir in sorted(os.listdir(root_dir)):
        if FLAGS.env.id not in save_dir:
            continue
        policy_load = os.path.join(root_dir, save_dir, 'stage-3000000.npy')
        # allow_pickle is needed because the checkpoint stores a dict object;
        # only load checkpoints from trusted locations.
        loader.load_state_dict(np.load(policy_load, allow_pickle=True)[()])
        logger.warning('Load {} from {}'.format(loader.keys(), policy_load))

        dict_result = dict()
        for gamma in gammas:
            eval_returns, _ = evaluate(policy, env_eval, gamma=gamma)
            # Cast to plain floats so yaml.dump writes numbers, matching the
            # expert results. NOTE(review): evaluate may return numpy
            # scalars, which would otherwise be dumped as python-object tags.
            eval_returns = [float(r) for r in eval_returns]
            mean_return = float(np.mean(eval_returns))
            dict_result[gamma] = [mean_return, eval_returns]
            logger.info('[%s]: %.4f', gamma, mean_return)

        save_path = os.path.join(root_dir, save_dir, 'evaluate.yml')
        _dump_yaml(dict_result, save_path)