Example #1
def make_env(env_id: str, env_type: str, num_env: int, seed: int, log_dir: str,
             **kwargs):
    if env_type == 'atari':
        make_thunk = make_atari_env
    elif env_type == 'mujoco':
        if kwargs.get('rescale_action', None):
            logger.info('MuJoCo Rescale action...')
        make_thunk = functools.partial(make_mujoco_env,
                                       rescale_action=kwargs.get(
                                           'rescale_action', False))
    else:
        make_thunk = make_gym_env
    if num_env == 1:
        env = DummyVecEnv([
            functools.partial(make_thunk,
                              env_id=env_id,
                              seed=seed,
                              index=index,
                              log_dir=log_dir) for index in range(num_env)
        ])
    else:
        env = SubprocVecEnv([
            functools.partial(make_thunk,
                              env_id=env_id,
                              seed=seed,
                              index=index,
                              log_dir=log_dir) for index in range(num_env)
        ])
    global COUNT
    COUNT += num_env
    return env
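A minimal usage sketch of make_env (not part of the original file; the environment id, log directory and kwargs are placeholders, and the helper factories make_atari_env/make_mujoco_env/make_gym_env plus the vec-env classes are assumed to be importable from the surrounding project):

# Hypothetical call; returns a vectorized environment.
env = make_env('HalfCheetah-v2', env_type='mujoco', num_env=4, seed=0,
               log_dir='/tmp/logs', rescale_action=True)
obs = env.reset()   # stacked observations, one row per sub-environment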
def parse(cls):
    global _initialized

    if _initialized:
        return
    parser = argparse.ArgumentParser(
        description='Stochastic Lower Bound Optimization')
    parser.add_argument('-c',
                        '--config',
                        type=str,
                        help='configuration file (YAML)',
                        nargs='+',
                        action='append')
    parser.add_argument('-s',
                        '--set',
                        type=str,
                        help='additional options',
                        nargs='*',
                        action='append')

    args, unknown = parser.parse_known_args()
    for a in unknown:
        logger.info('unknown arguments: %s', a)
    # logger.info('parsed arguments = %s, unknown arguments: %s', args, unknown)
    if args.config:
        for config in sum(args.config, []):
            cls.merge(yaml.load(open(expand(config))))
    else:
        logger.info('no config file specified.')
    if args.set:
        for instruction in sum(args.set, []):
            path, *value = instruction.split('=')
            cls.set_value(path.split('.'), yaml.load('='.join(value)))

    _initialized = True
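To illustrate how the -s/--set overrides are interpreted by parse above, a small standalone sketch (an assumption that mirrors the splitting logic, not code from the original file):

# e.g. `python main.py -c config.yml -s TRPO.rollout_samples=8000 env.id=Walker2d-v2`
import yaml

instruction = 'TRPO.rollout_samples=8000'
path, *value = instruction.split('=')
print(path.split('.'))                  # ['TRPO', 'rollout_samples'] -> nested config keys
print(yaml.safe_load('='.join(value)))  # 8000; YAML parsing preserves ints/floats/bools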
Example #3
def conj_grad(mat_mul_vec: Callable[[np.ndarray], np.ndarray],
              b,
              n_iters=10,
              residual_tol=1e-10,
              verbose=False):
    p = b.copy()
    r = b.copy()
    x = np.zeros_like(b)
    r_dot_r = r.dot(r)

    for i in range(n_iters):
        if verbose:
            logger.info('[CG] iters = %d, |Res| = %.6f, |x| = %.6f', i,
                        r_dot_r, np.linalg.norm(x))
        z = mat_mul_vec(p)
        v = r_dot_r / p.dot(z)
        x += v * p
        r -= v * z
        new_r_dot_r = r.dot(r)
        if new_r_dot_r < residual_tol:
            break
        mu = new_r_dot_r / r_dot_r
        p = r + mu * p
        r_dot_r = new_r_dot_r
    return x, r_dot_r
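A minimal sanity check of conj_grad on a small symmetric positive-definite system (illustrative only; the matrix, right-hand side and tolerance are assumptions):

import numpy as np

A = np.array([[4.0, 1.0], [1.0, 3.0]])   # SPD, so conjugate gradient applies
b = np.array([1.0, 2.0])
x, residual = conj_grad(lambda v: A @ v, b, n_iters=10)
assert np.allclose(x, np.linalg.solve(A, b), atol=1e-6)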
Example #4
    def validate_advantage(self, advantages_pre, adv_mean, adv_std, feed_dict,
                           name):
        advantages_post = (advantages_pre - adv_mean) / np.maximum(
            adv_std, 1e-8)
        advantages_params_post, advantages_params_pre, advp_mean, advp_std, returns_params, r_mean, r_std = tf.get_default_session(
        ).run([
            self.advantages_params, self.advantages_params_pre, self.adv_mean,
            self.adv_std, self.returns_params, self.r_mean, self.r_std
        ],
              feed_dict=feed_dict)
        print(
            '=====================================================%s====================================================='
            % name)
        logger.info('Task goal_vel: %s', self.task.goal_velocity)
        print('advp:', advp_mean, advp_std)
        print('adv:', adv_mean, adv_std)
        #print ("returns_params:", returns_params.tolist())
        diff_pre = np.linalg.norm(advantages_pre - advantages_params_pre)
        diff_post = np.linalg.norm(advantages_post - advantages_params_post)
        print("diff_pre: %f, diff_post: %f" % (diff_pre, diff_post))
        print("advantage_pre[0:5]:", advantages_pre[0:5])
        print("advantage_params_pre[0:5]:", advantages_params_pre[0:5])
        print("advantage_post[0:5]:", advantages_post[0:5])
        print("advantage_params_post[0:5]:", advantages_params_post[0:5])
        print(
            '===================================================================================================================='
        )
def collect_samples_from_true_env(env, actor, nb_episode=50, subsampling_rate=1, seed=2020):
    set_random_seed(seed)
    state_traj, action_traj, next_state_traj, reward_traj = [], [], [], []
    episode = 0
    while episode < nb_episode:
        state = env.reset()
        done = False
        t = 0
        return_ = 0
        while not done:
            action = actor.get_actions(state[None], fetch='actions_mean')[0]
            next_state, reward, done, info = env.step(action)
            return_ += reward
            state_traj.append(state)
            action_traj.append(action)
            next_state_traj.append(next_state)
            reward_traj.append(reward)
            t += 1
            if done:
                break
            state = next_state
        episode += 1
        logger.info('Collect a trajectory return = %.4f length = %d', return_, t)
    state_traj = np.array(state_traj)[::subsampling_rate]
    action_traj = np.array(action_traj)[::subsampling_rate]
    next_state_traj = np.array(next_state_traj)[::subsampling_rate]
    reward_traj = np.array(reward_traj)[::subsampling_rate]
    return state_traj, action_traj, next_state_traj, reward_traj
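A hedged usage sketch (env is assumed to be a Gym-style environment and actor a policy exposing get_actions, as used inside the function above):

# Collect 10 episodes and keep every 5th transition.
states, actions, next_states, rewards = collect_samples_from_true_env(
    env, actor, nb_episode=10, subsampling_rate=5, seed=2020)
print(states.shape, actions.shape)   # roughly one fifth of the collected transitions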
Example #6
def evaluate(settings, tag):
    for runner, policy, name in settings:
        runner.reset()
        debug_info, ep_infos = runner.run(policy, FLAGS.rollout.n_test_samples)
        returns = np.array([ep_info['return'] for ep_info in ep_infos])
        logger.info(
            'Tag = %s, Reward on %s (%d episodes): mean = %.6f, std = %.6f',
            tag, name, len(returns), np.mean(returns), np.std(returns))
Example #7
    def enable_flat(self):
        params = self.params
        logger.info('Enabling flattening... %s', [p.name for p in params])
        n_params = n_parameters(params)
        feed_flat = tf.placeholder(tf.float32, [n_params])
        get_flat = parameters_to_vector(params)
        set_flat = tf.group(*[tf.assign(param, value) for param, value in
                              zip(params, vector_to_parameters(feed_flat, params))])
        return feed_flat, set_flat, get_flat
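A usage sketch for the flattening ops returned above (assumptions: a TF1 default session is active and `module` is a hypothetical object that defines enable_flat):

feed_flat, set_flat, get_flat = module.enable_flat()
sess = tf.get_default_session()
theta = sess.run(get_flat)                          # read parameters as one flat vector
sess.run(set_flat, feed_dict={feed_flat: theta})    # write a (possibly modified) vector back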
Example #8
def load_expert_dataset(load_dir):
    logger.info('Load dataset from %s' % load_dir)
    expert_replay_buffer_var = tf.Variable('', name='expert_replay_buffer')
    saver = tf.train.Saver([expert_replay_buffer_var])
    last_checkpoint = os.path.join(load_dir, 'expert_replay_buffer')
    sess = tf.get_default_session()
    saver.restore(sess, last_checkpoint)
    expert_replay_buffer = pickle.loads(sess.run(expert_replay_buffer_var))
    return expert_replay_buffer
Example #9
File: flags.py  Project: liziniu/RLX
    def finalize(cls):
        log_dir = cls.log_dir
        if log_dir is None:
            run_id = cls.run_id
            if run_id is None:
                run_id = '{}-{}-{}-{}'.format(
                    cls.algorithm, cls.env.id, cls.seed,
                    time.strftime('%Y-%m-%d-%H-%M-%S'))

            log_dir = os.path.join("logs", run_id)
            cls.log_dir = log_dir

        if not os.path.exists(log_dir):
            os.makedirs(log_dir)

        assert cls.TRPO.rollout_samples % cls.env.num_env == 0

        if os.path.exists('.git'):
            for t in range(10):
                try:
                    if sys.platform == 'linux':
                        cls.commit = check_output(['git', 'rev-parse', 'HEAD'
                                                   ]).decode('utf-8').strip()
                        check_output(['git', 'add', '.'])
                        check_output([
                            'git', 'checkout-index', '-a', '-f',
                            '--prefix={}/src/'.format(cls.log_dir)
                        ])
                        open(os.path.join(log_dir, 'diff.patch'), 'w').write(
                            check_output(['git', '--no-pager', 'diff',
                                          'HEAD']).decode('utf-8'))
                    else:
                        check_output([
                            'git', 'checkout-index', '-a',
                            '--prefix={}/src/'.format(cls.log_dir)
                        ])
                    break
                except Exception as e:
                    print(e)
                    print('Try again...')
                time.sleep(1)
            else:
                raise RuntimeError('Failed after 10 trials.')

        yaml.dump(cls.as_dict(),
                  open(os.path.join(log_dir, 'config.yml'), 'w'),
                  default_flow_style=False)
        # logger.add_sink(FileSink(os.path.join(log_dir, 'log.json')))
        logger.add_sink(FileSink(os.path.join(log_dir, 'log.txt')))
        logger.add_csvwriter(CSVWriter(os.path.join(log_dir, 'progress.csv')))
        logger.info("log_dir = %s", log_dir)

        cls.set_frozen()
Example #10
def conj_grad(mat_mul_vec: Callable[[np.ndarray], np.ndarray],
              b,
              n_iters=10,
              residual_tol=1e-6,
              verbose=False):
    print("b.norm in CG:", norm(b))
    p = b.copy()
    r = b.copy()
    x = np.zeros_like(b)  #+ 1e-4
    r_dot_r = r.dot(r)

    x_list = [x]
    x_norm_list = [np.linalg.norm(x)]
    res_list = [r_dot_r]
    res_plot = []
    res_plot.append(r_dot_r.copy())
    for i in range(n_iters):
        xAx = p.dot(mat_mul_vec(p))  #x^T A^TA x > 0
        if verbose:
            logger.info(
                '[CG] iters = %d, |Res| = %.6f, |x| = %.6f, x^TAx = %.6f', i,
                r_dot_r, np.linalg.norm(x), xAx)
        z = mat_mul_vec(p)
        old_p = p
        old_r_dot_r = r_dot_r
        v = r_dot_r / p.dot(z)
        x += v * p  #generate new guess
        r -= v * z
        new_r_dot_r = r.dot(r)
        if new_r_dot_r < residual_tol:
            break
        mu = new_r_dot_r / r_dot_r
        p = r + mu * p  #generate new conjugate direction
        r_dot_r = new_r_dot_r

        x_list.append(x.copy())
        x_norm_list.append(np.linalg.norm(x))
        res_list.append(r_dot_r)
        res_plot.append(r_dot_r.copy())
    if verbose:
        logger.info('[CG] iters = %d, |Res| = %.6f, |x| = %.6f', n_iters,
                    r_dot_r, np.linalg.norm(x))
    idx = np.argmin(res_list)
    x = x_list[idx]
    res = res_list[idx]
    print(f"res = {res}, norm = {x_norm_list[idx]}")
    print("x = ", x)
    print('----------------------------------------------------')
    return x, res, res_plot
Example #11
    def update(self, samples: np.ndarray):
        old_mean, old_std, old_n = self.op_mean.numpy(), self.op_std.numpy(), self.op_n.numpy()
        samples = samples - old_mean

        m = samples.shape[0]
        delta = samples.mean(axis=0)
        new_n = old_n + m
        new_mean = old_mean + delta * m / new_n
        new_std = np.sqrt((old_std**2 * old_n + samples.var(axis=0) * m + delta**2 * old_n * m / new_n) / new_n)

        kl_old_new = gaussian_kl(new_mean, new_std, old_mean, old_std).sum()
        self.load_state_dict({'op_mean': new_mean, 'op_std': new_std, 'op_n': new_n})

        if self._verbose:
            logger.info("updating Normalizer<%s>, KL divergence = %.6f", self.name, kl_old_new)
Example #12
    def verify(self, n=2000, eps=1e-4):
        dataset = Dataset(gen_dtype(self, 'state action next_state reward done'), n)
        state = self.reset()
        for _ in range(n):
            action = self.action_space.sample()
            next_state, reward, done, _ = self.step(action)
            dataset.append((state, action, next_state, reward, done))

            state = next_state
            if done:
                state = self.reset()

        rewards_, dones_ = self.mb_step(dataset.state, dataset.action, dataset.next_state)
        diff = dataset.reward - rewards_
        l_inf = np.abs(diff).max()
        logger.info('rewarder difference: %.6f', l_inf)

        assert np.allclose(dones_, dataset.done)
        assert l_inf < eps
Example #13
    def compute_returns_advantages(self, next_value_preds, use_gae=False):
        """Compute returns for trajectory."""

        logger.info('Computing returns and advantages...')

        # TODO(agrawalk): Add more tests and asserts.
        batch = TimeStepAdv(*zip(*self._buffer))
        reward = np.stack(batch.reward).squeeze()
        value_preds = np.stack(batch.value_preds).squeeze()
        returns = np.stack(batch.returns).squeeze()
        mask = np.stack(batch.mask).squeeze()
        # effective_traj_len = traj_len - 2
        # This takes into account:
        #   - the extra observation in buffer.
        #   - 0-indexing for the transitions.
        effective_traj_len = len(reward) - 2

        if use_gae:
            value_preds[-1] = next_value_preds
            gae = 0
            for step in range(effective_traj_len, -1, -1):
                delta = (reward[step] +
                         self.gamma * value_preds[step + 1] * mask[step] -
                         value_preds[step])
                gae = delta + self.gamma * self.tau * mask[step] * gae
                returns[step] = gae + value_preds[step]
        else:
            returns[-1] = next_value_preds
            for step in range(effective_traj_len, -1, -1):
                returns[step] = (reward[step] +
                                 self.gamma * returns[step + 1] * mask[step])

        advantages = returns - value_preds
        keys = ['value_preds', 'returns', 'advantages']
        values = [
            list(entry) for entry in zip(  # pylint: disable=g-complex-comprehension
                value_preds.reshape(-1, 1), returns.reshape(-1, 1),
                advantages.reshape(-1, 1))
        ]
        self.update_buffer(keys, values)

        self._buffer = self._buffer[:-1]
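A toy, self-contained version of the GAE branch above, useful for checking the backward recursion by hand (gamma, tau and the trajectory values are illustrative assumptions, not taken from the project):

import numpy as np

gamma, tau = 0.99, 0.95
reward      = np.array([1.0, 1.0, 1.0, 0.0])
value_preds = np.array([0.5, 0.4, 0.3, 0.0])   # last slot plays the role of next_value_preds
mask        = np.array([1.0, 1.0, 1.0, 0.0])   # 0 where the episode terminates
returns = np.zeros_like(reward)
gae = 0.0
for step in range(len(reward) - 2, -1, -1):
    delta = reward[step] + gamma * value_preds[step + 1] * mask[step] - value_preds[step]
    gae = delta + gamma * tau * mask[step] * gae
    returns[step] = gae + value_preds[step]
advantages = returns - value_preds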
Example #14
    def eval(self, fetch: str, **feed: Dict[str, np.ndarray]):
        cache_key = f'[{" ".join(feed.keys())}] => [{fetch}]'
        if cache_key not in self._callables:
            logger.info('[%s] is making TensorFlow callables, key = %s',
                        self.__class__.__name__, cache_key)
            feed_ops = []
            for key in feed.keys():
                feed_ops.append(self.__dict__['op_' + key])
            if isinstance(fetch, str):
                fetch_ops = [
                    self.__dict__['op_' + key] for key in fetch.split(' ')
                ]
                if len(fetch_ops) == 1:
                    fetch_ops = fetch_ops[0]
            else:
                fetch_ops = fetch
            self.register_callable(
                cache_key,
                tf.get_default_session().make_callable(fetch_ops, feed_ops))
        return self._callables[cache_key](*feed.values())
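A hedged usage sketch of the callable cache above (attribute names such as op_mean/op_std are determined by whatever module defines eval; the normalizer-style call mirrors the usage that appears further below):

# The first call builds and caches a tf.Session callable keyed by the feed/fetch names;
# later calls with the same key reuse it.
loc, scale = normalizer.eval(fetch='mean std')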
    def build(self, expert_obs, expert_acs):
        self.expert_obs = expert_obs
        self.expert_acs = expert_acs
        inputs = np.concatenate([self.expert_obs, self.expert_acs], axis=1)
        self.normalizer.update(inputs)

        self.normalizer_mean, self.normalizer_std = self.normalizer.eval(fetch='mean std')
        self.normalizer_updated = False
        logger.info('mean: {}'.format(self.normalizer_mean))
        logger.info('std:{}'.format(self.normalizer_std))

        self.expert_featexp = self._compute_featexp(self.expert_obs, self.expert_acs)
        feat_dim = self.expert_featexp.shape[0]
        if self.simplex:
            self.widx = np.random.randint(feat_dim)
        else:
            self.w = np.random.randn(feat_dim)
            self.w /= np.linalg.norm(self.w) + 1e-8

        self.reward_bound = 0.
        self.gap = 0.
Example #16
    def finalize(cls):
        log_dir = cls.log_dir
        if log_dir is None:
            run_id = cls.run_id
            if run_id is None:
                run_id = time.strftime('%Y-%m-%d_%H-%M-%S')

            log_dir = os.path.join(cls.ckpt.base, run_id)
            cls.log_dir = log_dir

        if not os.path.exists(log_dir):
            os.makedirs(log_dir)

        for t in range(60):
            try:
                cls.commit = check_output(['git', 'rev-parse',
                                           'HEAD']).decode('utf-8').strip()
                check_output(['git', 'add', '.'])
                check_output([
                    'git', 'checkout-index', '-a', '-f',
                    f'--prefix={log_dir}/src/'
                ])
                break
            except CalledProcessError:
                pass
            time.sleep(1)
        else:
            raise RuntimeError('Failed after 60 trials.')

        yaml.dump(cls.as_dict(),
                  open(os.path.join(log_dir, 'config.yml'), 'w'),
                  default_flow_style=False)
        open(os.path.join(log_dir, 'diff.patch'), 'w').write(
            check_output(['git', '--no-pager', 'diff',
                          'HEAD']).decode('utf-8'))

        logger.add_sink(FileSink(os.path.join(log_dir, 'log.json')))
        logger.info("log_dir = %s", log_dir)

        cls.set_frozen()
Example #17
def conj_grad(mat_mul_vec: Callable[[np.ndarray], np.ndarray],
              b,
              n_iters=10,
              residual_tol=1e-10,
              verbose=False):
    #print ("b.norm:", norm(b))
    p = b.copy()
    r = b.copy()
    x = np.zeros_like(b)
    r_dot_r = r.dot(r)

    #print ("b.dtype", b.dtype)

    for i in range(n_iters):
        if verbose:
            last_r_dot_r = r_dot_r
            logger.info('[CG] iters = %d, |Res| = %.6f, |x| = %.6f', i,
                        r_dot_r, np.linalg.norm(x))
        z = mat_mul_vec(p)
        #print ("z.dtype", z.dtype)
        old_p = p
        old_r_dot_r = r_dot_r
        v = r_dot_r / p.dot(z)
        x += v * p
        r -= v * z
        new_r_dot_r = r.dot(r)
        if new_r_dot_r < residual_tol:
            break
        mu = new_r_dot_r / r_dot_r
        p = r + mu * p
        r_dot_r = new_r_dot_r
        if verbose:
            print(
                f'z={norm(z)}, p={norm(old_p)}, r^2={old_r_dot_r}, pz={old_p.dot(z)}, v={v}, x={norm(x)}, r={norm(r)}, mu={mu}, newp={norm(p)}, newr^2={new_r_dot_r}'
            )
    #print ("x.dtype", x.dtype)
    return x, old_r_dot_r
Example #18
    def train(self, train_runner, collect_runner, warmup_collect_virt,
              warmup_collect_real, optimal_collect_real, returns_pre_warmup,
              val_losses_warmup, val_losses_slbo, train_losses_warmup,
              train_losses_slbo, surprisal, trpo_warmup, trpo_slbo,
              infofilename, extra_runners):
        #tf.get_default_session().run(tf.variables_initializer(self.warmup_policy.parameters()))
        self.sync()
        self.print_params_norm()
        self.invalidate()

        logger.info(
            '--------------------------------------------------- Update Task Parameter ------------------------------------------------'
        )
        self.task_num += 1

        sanitycheck = {}

        # collect virtual data
        logger.info('Rollout on virtual env, policy hat')
        train_runner.reset()
        data, ep_infos = train_runner.run(self.warmup_policy, self.nsample)
        advantages, advantages_params, values, td, coef_mat, coef_mat_returns, reward_ctrl, reward_state, begin_mark = train_runner.compute_advantage(
            self.warmup_vfn, data, self.task)

        # collect real data
        #logger.info('Rollout on real env, policy hat')
        #collect_runner.reset()
        #data, ep_infos = collect_runner.run(self.warmup_policy, self.nsample)
        #advantages, advantages_params, values, td, coef_mat, coef_mat_returns, reward_ctrl, reward_state, begin_mark = collect_runner.compute_advantage(self.warmup_vfn, data, self.task)
        logp = self.warmup_policy(data.state).log_prob(
            data.action).reduce_sum(axis=1).reduce_mean()
        logp = tf.get_default_session().run(logp)
        print("state_mean:", np.mean(data.state))
        print("action_mean:", np.mean(data.action))
        print("warmup_logpac_mean in ADVTASK:", logp)
        #ep_infos, data, advantages, advantages_params, values, td, coef_mat, coef_mat_returns, reward_ctrl, reward_state, begin_mark = warmup_collect_virt
        policy_over_task, cg_residual = self.get_grad_hatenv_hatpolicy(
            data, advantages, values, td, coef_mat, coef_mat_returns,
            reward_ctrl, reward_state, begin_mark)
        virtual_returns_post_warmup = np.mean(
            [info['return'] for info in ep_infos])
        print("policy_over_task.shape:", policy_over_task.shape)
        #exit(0)

        logger.info('Rollout on real env, policy hat')
        collect_runner.reset()
        data, ep_infos = collect_runner.run(self.warmup_policy, self.nsample)
        advantages, advantages_params, values, td, coef_mat, coef_mat_returns, reward_ctrl, reward_state, begin_mark = collect_runner.compute_advantage(
            self.warmup_vfn, data, self.task)
        #ep_infos, data, advantages, advantages_params, values, td, coef_mat, coef_mat_returns, reward_ctrl, reward_state, begin_mark = warmup_collect_real
        returns_post_warmup = np.mean([info['return'] for info in ep_infos])
        flat_grad_policy, flat_grad_task = self.get_grad_optenv_hatpolicy(
            data, advantages, values, td, coef_mat, coef_mat_returns,
            reward_ctrl, reward_state, begin_mark)
        print("flat_grad_policy.shape:", flat_grad_policy.shape)
        print("flat_grad_task.shape:", flat_grad_task.shape)

        logger.info('Rollout on real env, policy star')
        # collect data with optimal policy
        collect_runner.reset()
        data, ep_infos = collect_runner.run(self.policy, self.nsample)
        advantages, advantages_params, values, td, coef_mat, coef_mat_returns, reward_ctrl, reward_state, begin_mark = collect_runner.compute_advantage(
            self.vfn, data, self.task)
        #ep_infos, data, advantages, advantages_params, values, td, coef_mat, coef_mat_returns, reward_ctrl, reward_state, begin_mark = optimal_collect_real
        returns_post_slbo = np.mean([info['return'] for info in ep_infos])
        flat_grad_task_opt = self.get_grad_optenv_optpolicy(
            data, advantages, values, td, coef_mat, coef_mat_returns,
            reward_ctrl, reward_state, begin_mark)
        print("flat_grad_task_opt.shape:", flat_grad_task_opt.shape)

        ### Compute
        matmul = np.reshape(np.matmul(policy_over_task, flat_grad_policy),
                            (self.task.n_dim, ))
        final_grad = (flat_grad_task_opt - flat_grad_task) - matmul

        #max L over psi -> get gradient ascent
        goal_velocity = self.task.goal_velocity
        task_params_before = tf.get_default_session().run(
            self.task.parameters())
        task_params_after = tf.get_default_session().run(
            self.task.parameters())
        # Perform line search
        #self.task.set_parameters(task_params_after)
        goal_velocity_after = tf.get_default_session().run(
            self.task.goal_velocity_params)
        print("final_grad:", final_grad.shape)
        print("before(delete):", task_params_before.shape)
        print("after(delete):", task_params_after.shape)

        _, ep_infos = train_runner.run(self.policy, self.nsample)
        virtual_returns_post_slbo = np.mean(
            [info['return'] for info in ep_infos])
        diff = returns_post_slbo - returns_post_warmup
        virtual_diff = virtual_returns_post_slbo - virtual_returns_post_warmup

        warmup_params = tf.get_default_session().run(
            nn.utils.parameters_to_vector(self.warmup_policy.parameters()))
        policy_params = tf.get_default_session().run(
            nn.utils.parameters_to_vector(self.policy.parameters()))
        warmup_policy_norm = np.linalg.norm(warmup_params)
        policy_norm = np.linalg.norm(policy_params)
        print("alpha and beta:", self.alpha, self.beta)
        print('warmup_policy_norm:', warmup_policy_norm)
        print('policy_norm:', policy_norm)
        print("logpac_norm:", self.logpac_norm)
        print("pg_norm:", self.pg_norm)
        print("Ax_norm:", self.Ax_norm)
        print("b_norm:", self.b_norm)
        print("cg_residual:", cg_residual)
        x_norm = self.policy_over_task_norm
        print("x_norm (policy_over_task_norm):", self.policy_over_task_norm)
        flat_grad_policy_norm = np.linalg.norm(flat_grad_policy)
        print("flat_grad_policy_norm:", np.linalg.norm(flat_grad_policy))
        print("flat_grad_task_opt:", flat_grad_task_opt)
        print("flat_grad_task:", flat_grad_task)
        print("minus:", flat_grad_task_opt - flat_grad_task)
        print("matmul:", matmul)
        print("final_grad:", final_grad)
        print(
            f"loss and adv: {self.policy_loss_value}, {self.quad_loss_value}, {self.return_loss_value}, {self.return_mean_value}, {self.adv_mean_value}"
        )
        print(f'task_params: {task_params_before} -> {task_params_after}')
        print(f'goal_vel: {goal_velocity} -> {goal_velocity_after}')
        print(
            f'returns_pre_warmup={returns_pre_warmup}, returns_post_warmup={returns_post_warmup}, returns_post_slbo={returns_post_slbo}, real_returns_diff={diff}'
        )
        print(
            f'virtual_returns_post_warmup={virtual_returns_post_warmup}, virtual_returns_post_slbo={virtual_returns_post_slbo}, virtual_diff={virtual_diff}'
        )
        print(
            f'val_losses_warmup={np.mean(val_losses_warmup)}, val_losses_slbo={np.mean(val_losses_slbo)}'
        )
        print(
            f'train_losses_warmup={np.mean(train_losses_warmup)}, train_losses_slbo={np.mean(train_losses_slbo)}'
        )
        print(
            f'#val_losses_warmup={len(val_losses_warmup)}, #val_losses_slbo={len(val_losses_slbo)}'
        )
        print(f'surprisal={surprisal}')
        trpo_slbo = np.array(trpo_slbo)
        print('=======================')
        print(trpo_slbo.shape)
        trpo_plot = np.mean(trpo_slbo, 0).tolist()
        print("trpo_plot:", trpo_plot)

        self.info['goal_velocity'].append(goal_velocity)
        self.info['matmul'].append(matmul)
        self.info['flat_grad_task'].append(flat_grad_task)
        self.info['flat_grad_task_opt'].append(flat_grad_task_opt)
        self.info['final_grad'].append(final_grad)
        self.info['task_params_before'].append(task_params_before)
        self.info['task_params_after'].append(task_params_after)
        self.info['returns_pre_warmup'].append(returns_pre_warmup)
        self.info['returns_post_warmup'].append(returns_post_warmup)
        self.info['returns_post_slbo'].append(returns_post_slbo)
        self.info['diff'].append(diff)
        self.info['virtual_returns_post_warmup'].append(
            virtual_returns_post_warmup)
        self.info['virtual_returns_post_slbo'].append(
            virtual_returns_post_slbo)
        self.info['virtual_diff'].append(virtual_diff)
        self.info['cg_residual'].append(cg_residual)
        self.info['val_losses_warmup'].append(np.mean(val_losses_warmup))
        self.info['val_losses_slbo'].append(np.mean(val_losses_slbo))
        self.info['train_losses_warmup'].append(np.mean(train_losses_warmup))
        self.info['train_losses_slbo'].append(np.mean(train_losses_slbo))
        self.info['surprisal'].append(surprisal)
        self.info['warmup_policy_norm'].append(warmup_policy_norm)
        self.info['policy_norm'].append(policy_norm)
        self.info['logpac_norm'].append(self.logpac_norm)
        self.info['pg_norm'].append(self.pg_norm)
        self.info['Ax_norm'].append(self.Ax_norm)
        self.info['b_norm'].append(self.b_norm)
        self.info['x_norm'].append(x_norm)
        self.info['flat_grad_policy_norm'].append(flat_grad_policy_norm)
        self.info['policy_loss_value'].append(self.policy_loss_value)
        self.info['quad_loss_value'].append(self.quad_loss_value)
        self.info['return_loss_value'].append(self.return_loss_value)
        self.info['return_mean_value'].append(self.return_mean_value)
        self.info['adv_mean_value'].append(self.adv_mean_value)
        self.info['sanitycheck'].append(sanitycheck)
        self.invalidate()
        norm_g = np.linalg.norm(final_grad)
        print(f"norm_g = {norm_g}")
        print(f"task_num = {self.task_num}")
        return task_params_before, final_grad, self.info
Example #19
    def __init__(self,
                 dim_state,
                 dim_action,
                 policy,
                 vfn,
                 warmup_policy,
                 warmup_vfn,
                 task,
                 cg_damping=0.1,
                 n_cg_iters=200,
                 alpha=1.0,
                 beta=1.0,
                 nsample=8000,
                 atype='adv'):
        super().__init__()
        self.keys = [
            'restartnext', 'goal_velocity', 'matmul', 'flat_grad_task',
            'flat_grad_task_opt', 'final_grad', 'task_params_before',
            'task_params_after', 'returns_pre_warmup', 'returns_post_warmup',
            'returns_post_slbo', 'diff', 'virtual_returns_post_warmup',
            'virtual_returns_post_slbo', 'virtual_diff', 'cg_residual',
            'val_losses_warmup', 'val_losses_slbo', 'train_losses_warmup',
            'train_losses_slbo', 'surprisal', 'warmup_policy_norm',
            'policy_norm', 'logpac_norm', 'pg_norm', 'Ax_norm', 'b_norm',
            'x_norm', 'flat_grad_policy_norm', 'policy_loss_value',
            'quad_loss_value', 'return_loss_value', 'return_mean_value',
            'adv_mean_value', 'trpo_loss', 'trpo_kl', 'trpo_g', 'trpo_x',
            'trpo_Ax', 'trpo_res', 'alpha', 'sanitycheck'
        ]
        self.info = dict()
        for k in self.keys:
            self.info[k] = []
        self.info['cg_plot'] = {}
        self.info['cg_plot']['mine'] = []

        self.cg_damping = cg_damping
        self.AAx = True
        self.meanAAx = True
        self.task_num = 0
        self.advnormalize = False
        self.retnormalize = False
        self.meanret = True
        self.atype = atype  # gae or ret or 1step or adv
        self.alpha = alpha
        self.beta = beta
        self.nsample = nsample
        self.task = task
        self.n_cg_iters = n_cg_iters

        self.policy = policy
        self.vfn = vfn
        self.old_policy: nn.Module = policy.clone()
        self.warmup_policy = warmup_policy
        self.old_warmup_policy: nn.Module = warmup_policy.clone()
        self.warmup_vfn = warmup_vfn
        self.warmup_policy1: nn.Module = warmup_policy.clone()
        self.warmup_policy2: nn.Module = warmup_policy.clone()
        self.warmup_policy3: nn.Module = warmup_policy.clone()

        with self.scope:
            # placeholder
            self.op_advantages = tf.placeholder(dtype=tf.float32,
                                                shape=[None],
                                                name='advantages')
            self.op_advantages_mean = tf.placeholder(dtype=tf.float32,
                                                     name='advantages_mean')
            self.op_advantages_std = tf.placeholder(dtype=tf.float32,
                                                    name='advantages_std')
            self.op_states = tf.placeholder(dtype=tf.float32,
                                            shape=[None, dim_state],
                                            name='states')
            self.op_actions = tf.placeholder(dtype=tf.float32,
                                             shape=[None, dim_action],
                                             name='actions')
            self.op_tangents = tf.placeholder(
                dtype=tf.float32,
                shape=[nn.utils.n_parameters(self.warmup_policy.parameters())])
            self.op_feed_params = tf.placeholder(dtype=tf.float32,
                                                 shape=[None],
                                                 name='feed_params')

            self.op_reward_ctrl = tf.placeholder(
                dtype=tf.float32, shape=[None, None],
                name='reward_ctrl')  #nstep,nenv
            self.op_reward_state = tf.placeholder(
                dtype=tf.float32,
                shape=[None, None, self.task.n_params],
                name='reward_state')  #nstep,nenv,n
            self.op_coef_mat = tf.placeholder(
                dtype=tf.float32, shape=[None, None, None],
                name='coef_mat')  #nenv,nstep,nstep
            self.op_coef_mat_returns = tf.placeholder(
                dtype=tf.float32,
                shape=[None, None, None],
                name='coef_mat_returns')  #nenv,nstep,nstep
            self.op_td = tf.placeholder(dtype=tf.float32,
                                        shape=[None, None],
                                        name='td')  #nstep,nenv
            self.op_values = tf.placeholder(dtype=tf.float32,
                                            shape=[None, None],
                                            name='values')  #nstep,nenv
            self.op_begin_mark = tf.placeholder(dtype=tf.float32,
                                                shape=[None, None],
                                                name='begin_mark')  #nstep,nenv

            # Simulate GAE
            EPS = 1e-6
            #goal_velocity_params = tf.reshape(self.task.goal_velocity_params, (1, 1, self.task.n_dim))
            goal_velocity_params = self.task.goal_velocity_params
            coef = self.task.coef
            func = self.task.func
            logger.info(f'reward coef in ADVTASK is {coef}')
            logger.info(f'reward func in ADVTASK is {func}')
            if func == 'abs':
                logger.info('we are in abs func!')
                reward_params_gae = self.op_reward_ctrl - tf.reduce_sum(
                    tf.abs(self.op_reward_state - goal_velocity_params) * coef,
                    axis=2) + self.op_td  #nstep,nenv
                reward_params = self.op_reward_ctrl - tf.reduce_sum(
                    tf.abs(self.op_reward_state - goal_velocity_params) * coef,
                    axis=2)
            elif func == 'linear':
                logger.info('we are in linear func!')
                reward_params_gae = self.op_reward_ctrl + tf.reduce_sum(
                    tf.multiply(self.op_reward_state + EPS,
                                goal_velocity_params * coef),
                    axis=2) + self.op_td  #nstep,nenv
                reward_params = self.op_reward_ctrl + tf.reduce_sum(
                    tf.multiply(self.op_reward_state + EPS,
                                goal_velocity_params * coef),
                    axis=2)
            else:
                raise Exception(
                    f'{FLAGS.task.reward} reward function is not available!')

            if self.atype == '1step':
                advantages_params = reward_params_gae  #r_t + \gamma V(s') - V(s)
            elif self.atype == 'gae':
                reward_params_gae = tf.transpose(
                    reward_params_gae)  #nenv,nstep
                advantages_params = tf.squeeze(
                    tf.matmul(self.op_coef_mat,
                              tf.expand_dims(reward_params_gae, 2)),
                    [2])  #nenv,nstep,nstep * nenv,nstep,1 -> nenv,nstep
                advantages_params = tf.transpose(
                    advantages_params)  #nstep,nenv
            elif self.atype == 'ret':
                reward_params = tf.transpose(reward_params)  #nenv,nstep
                advantages_params = tf.squeeze(
                    tf.matmul(self.op_coef_mat_returns,
                              tf.expand_dims(reward_params, 2)),
                    [2])  #nenv,nstep,nstep * nenv,nstep,1 -> nenv,nstep
                advantages_params = tf.transpose(
                    advantages_params)  #nstep,nenv
            elif self.atype == 'adv':
                reward_params = tf.transpose(reward_params)  #nenv,nstep  r_t
                advantages_params = tf.squeeze(
                    tf.matmul(self.op_coef_mat_returns,
                              tf.expand_dims(reward_params, 2)),
                    [2])  #nenv,nstep,nstep * nenv,nstep,1 -> nenv,nstep
                advantages_params = tf.transpose(
                    advantages_params)  #nstep,nenv
                advantages_params = advantages_params - self.op_values  #\sum r_t - V(s)

            self.advantages_params = tf.reshape(advantages_params,
                                                (-1, ))  #nstep*nenv
            self.advantages_params_pre = self.advantages_params
            self.adv_mean, self.adv_var = tf.nn.moments(tf.stop_gradient(
                self.advantages_params),
                                                        axes=0)
            self.adv_std = tf.sqrt(self.adv_var)
            if self.advnormalize:
                self.advantages_params = (self.advantages_params -
                                          self.adv_mean) / tf.maximum(
                                              self.adv_std, 1e-8)
            self.op_advantages_params = self.advantages_params.reduce_mean()

            # reward function and returns
            if func == 'abs':
                reward_params = self.op_reward_ctrl - tf.reduce_sum(
                    tf.abs(self.op_reward_state - goal_velocity_params) * coef,
                    axis=2)
            elif func == 'linear':
                reward_params = self.op_reward_ctrl + tf.reduce_sum(
                    tf.multiply(self.op_reward_state + EPS,
                                goal_velocity_params * coef),
                    axis=2)
            else:
                raise Exception(
                    f'{FLAGS.task.reward} reward function is not available!')
            reward_params = tf.transpose(reward_params)  #nenv,nstep
            returns_params = tf.squeeze(
                tf.matmul(self.op_coef_mat_returns,
                          tf.expand_dims(reward_params, 2)),
                [2])  #nenv,nstep,nstep * nenv,nstep,1 -> nenv,nstep
            returns_params = tf.transpose(returns_params)  #nstep,nenv
            #NOTE that NOT: self.returns_params = tf.reshape(returns_params, (-1,)) #nstep*nenv

            if self.meanret:
                self.returns_params = tf.reduce_sum(
                    returns_params * self.op_begin_mark) / tf.reduce_sum(
                        self.op_begin_mark * 0. + 1.)
            else:
                self.returns_params = tf.reduce_sum(
                    returns_params * self.op_begin_mark) / tf.reduce_sum(
                        self.op_begin_mark)
            returns_params_nozero = self.returns_params
            self.r_mean = self.adv_mean
            self.r_var = self.adv_var
            self.r_std = self.adv_std
            self.op_returns_params = self.returns_params.reduce_mean()

        # build loss from input placeholder
        self.op_policy_loss_warmup, self.op_policy_loss_warmup_quad, self.op_return_loss_warmup, self.op_policy_loss, self.op_return_loss= \
            self.build_loss(self.op_states, self.op_actions, self.advantages_params, self.returns_params)

        # compute jacobian and hessian part
        self.op_task_b, self.op_task_hvp = self.compute_task_jacobian(
            self.op_policy_loss_warmup, self.op_policy_loss_warmup_quad,
            self.op_tangents)

        # compute sync op and other gradients
        self.op_sync_old_warmup, self.op_flat_grad_policy_warmup, self.op_flat_grad_task_warmup = self.compute_grad(
            self.warmup_policy, self.old_warmup_policy,
            self.op_policy_loss_warmup, self.op_return_loss_warmup)
        self.op_sync_old, _, self.op_flat_grad_task_opt = self.compute_grad(
            self.policy, self.old_policy, self.op_policy_loss,
            self.op_return_loss)
        self.op_logpac_plot = self.warmup_policy(self.op_states).log_prob(
            self.op_actions).reduce_sum(axis=1).reduce_mean()

        # build sync all op
        assign_ops = []
        for old_v, new_v in zip(self.warmup_policy1.parameters(),
                                self.warmup_policy.parameters()):
            assign_ops.append(tf.assign(old_v, new_v))
        for old_v, new_v in zip(self.warmup_policy2.parameters(),
                                self.warmup_policy.parameters()):
            assign_ops.append(tf.assign(old_v, new_v))
        for old_v, new_v in zip(self.warmup_policy3.parameters(),
                                self.warmup_policy.parameters()):
            assign_ops.append(tf.assign(old_v, new_v))
        self.op_sync_all = tf.group(*assign_ops)
def main():
    FLAGS.set_seed()
    FLAGS.freeze()

    env = create_env(FLAGS.env.id,
                     seed=FLAGS.seed,
                     rescale_action=FLAGS.env.rescale_action)
    dim_state = env.observation_space.shape[0]
    dim_action = env.action_space.shape[0]

    bc_normalizers = Normalizers(dim_action=dim_action, dim_state=dim_state)
    bc_policy = GaussianMLPPolicy(dim_state,
                                  dim_action,
                                  FLAGS.TRPO.policy_hidden_sizes,
                                  output_diff=FLAGS.TRPO.output_diff,
                                  normalizers=bc_normalizers)

    gail_normalizers = Normalizers(dim_action=dim_action, dim_state=dim_state)
    gail_policy = GaussianMLPPolicy(dim_state,
                                    dim_action,
                                    FLAGS.TRPO.policy_hidden_sizes,
                                    output_diff=FLAGS.TRPO.output_diff,
                                    normalizers=gail_normalizers)

    actor = Actor(dim_state, dim_action, FLAGS.SAC.actor_hidden_sizes)
    tf.get_default_session().run(tf.global_variables_initializer())

    loader = nn.ModuleDict({'actor': actor})
    policy_load = f'dataset/mb2/{FLAGS.env.id}/policy.npy'
    loader.load_state_dict(np.load(policy_load, allow_pickle=True)[()])
    logger.warning('Load expert policy from %s' % policy_load)

    bc_policy_load = "benchmarks/mbrl_benchmark/mbrl2_bc_30_1000/mbrl2_bc-Walker2d-v2-100-2020-05-22-16-02-12/final.npy"
    loader = nn.ModuleDict({
        'policy': bc_policy,
        'normalizers': bc_normalizers
    })
    loader.load_state_dict(np.load(bc_policy_load, allow_pickle=True)[()])
    logger.warning('Load bc policy from %s' % bc_policy_load)

    gail_policy_load = "benchmarks/mbrl_benchmark/mbrl2_gail_grad_penalty/mbrl2_gail-Walker2d-v2-100-2020-05-22-12-10-07/final.npy"
    loader = nn.ModuleDict({
        'policy': gail_policy,
        'normalizers': gail_normalizers
    })
    loader.load_state_dict(np.load(gail_policy_load, allow_pickle=True)[()])
    logger.warning('Load gail policy from %s' % gail_policy_load)

    eval_gamma = 0.999
    eval_returns, eval_lengths = evaluate_on_true_env(actor,
                                                      env,
                                                      gamma=eval_gamma)
    logger.warning(
        'Test policy true value = %.4f true length = %d (gamma = %f)',
        np.mean(eval_returns), np.mean(eval_lengths), eval_gamma)

    real_runner = Runner(env,
                         max_steps=env.max_episode_steps,
                         rescale_action=False)
    # virtual env
    env_bc_stochastic = VirtualEnv(bc_policy,
                                   env,
                                   n_envs=1,
                                   stochastic_model=True)
    env_bc_deterministic = VirtualEnv(bc_policy,
                                      env,
                                      n_envs=1,
                                      stochastic_model=False)
    runner_bc_stochastic = VirtualRunner(env_bc_stochastic,
                                         max_steps=env.max_episode_steps,
                                         rescale_action=False)
    runner_bc_deterministic = VirtualRunner(env_bc_deterministic,
                                            max_steps=env.max_episode_steps,
                                            rescale_action=False)

    env_gail_stochastic = VirtualEnv(gail_policy,
                                     env,
                                     n_envs=1,
                                     stochastic_model=True)
    env_gail_deterministic = VirtualEnv(gail_policy,
                                        env,
                                        n_envs=1,
                                        stochastic_model=False)
    runner_gail_stochastic = VirtualRunner(env_gail_stochastic,
                                           max_steps=env.max_episode_steps)
    runner_gail_deterministic = VirtualRunner(env_gail_deterministic,
                                              max_steps=env.max_episode_steps)

    data_actor, ep_infos = real_runner.run(actor,
                                           n_samples=int(2e3),
                                           stochastic=False)
    returns = [info['return'] for info in ep_infos]
    lengths = [info['length'] for info in ep_infos]
    logger.info(
        'Collect %d samples for actor avg return = %.4f avg length = %d',
        len(data_actor), np.mean(returns), np.mean(lengths))

    data_bc_stochastic, ep_infos = runner_bc_stochastic.run(actor,
                                                            n_samples=int(2e3),
                                                            stochastic=False)
    returns = [info['return'] for info in ep_infos]
    lengths = [info['length'] for info in ep_infos]
    logger.info(
        'Collect %d samples for bc stochastic policy avg return = %.4f avg length = %d',
        len(data_bc_stochastic), np.mean(returns), np.mean(lengths))

    reward_ref, _ = env.mb_step(data_bc_stochastic.state,
                                data_bc_stochastic.action,
                                data_bc_stochastic.next_state)
    np.testing.assert_allclose(reward_ref,
                               data_bc_stochastic.reward,
                               rtol=1e-4,
                               atol=1e-4)

    data_bc_deterministic, ep_infos = runner_bc_deterministic.run(
        actor, n_samples=int(2e3), stochastic=False)
    returns = [info['return'] for info in ep_infos]
    lengths = [info['length'] for info in ep_infos]
    logger.info(
        'Collect %d samples for bc deterministic policy avg return = %.4f avg length = %d',
        len(data_bc_deterministic), np.mean(returns), np.mean(lengths))

    reward_ref, _ = env.mb_step(data_bc_deterministic.state,
                                data_bc_deterministic.action,
                                data_bc_deterministic.next_state)
    np.testing.assert_allclose(reward_ref,
                               data_bc_deterministic.reward,
                               rtol=1e-4,
                               atol=1e-4)

    data_gail_stochastic, ep_infos = runner_gail_stochastic.run(
        actor, n_samples=int(2e3), stochastic=False)
    returns = [info['return'] for info in ep_infos]
    lengths = [info['length'] for info in ep_infos]
    logger.info(
        'Collect %d samples for gail stochastic policy avg return = %.4f avg length = %d',
        len(data_gail_stochastic), np.mean(returns), np.mean(lengths))
    data_gail_deterministic, ep_infos = runner_gail_deterministic.run(
        actor, n_samples=int(2e3), stochastic=False)
    returns = [info['return'] for info in ep_infos]
    lengths = [info['length'] for info in ep_infos]
    logger.info(
        'Collect %d samples for gail deterministic policy avg return = %.4f avg length = %d',
        len(data_gail_deterministic), np.mean(returns), np.mean(lengths))

    t_sne = manifold.TSNE(init='pca', random_state=2020)
    data = np.concatenate([
        data.state for data in [
            data_actor, data_bc_stochastic, data_bc_deterministic,
            data_gail_stochastic, data_gail_deterministic
        ]
    ],
                          axis=0)
    step = np.concatenate([
        data.step for data in [
            data_actor, data_bc_stochastic, data_bc_deterministic,
            data_gail_stochastic, data_gail_deterministic
        ]
    ],
                          axis=0)
    loc, scale = bc_normalizers.state.eval('mean std')
    data = (data - loc) / (1e-6 + scale)
    embedding = t_sne.fit_transform(data)

    fig, axarrs = plt.subplots(nrows=1,
                               ncols=5,
                               figsize=[6 * 5, 4],
                               squeeze=False,
                               sharex=True,
                               sharey=True,
                               dpi=300)
    start = 0
    indices = 0
    g2c = {}
    for title in [
            'expert', 'bc_stochastic', 'bc_deterministic', 'gail_stochastic',
            'gail_deterministic'
    ]:
        g2c[title] = axarrs[0][indices].scatter(embedding[start:start + 2000,
                                                          0],
                                                embedding[start:start + 2000,
                                                          1],
                                                c=step[start:start + 2000])
        axarrs[0][indices].set_title(title)
        indices += 1
        start += 2000
    plt.colorbar(list(g2c.values())[0], ax=axarrs.flatten())
    plt.tight_layout()
    plt.savefig(f'{FLAGS.log_dir}/visualize.png', bbox_inches='tight')

    data = {
        'expert': data_actor.state,
        'bc_stochastic': data_bc_stochastic.state,
        'bc_deterministic': data_bc_deterministic.state,
        'gail_stochastic': data_gail_stochastic.state,
        'gail_deterministic': data_gail_deterministic.state
    }
    np.savez(f'{FLAGS.log_dir}/data.npz', **data)
Example #21
def main():
    FLAGS.set_seed()
    FLAGS.freeze()

    env = create_env(FLAGS.env.id,
                     seed=FLAGS.seed,
                     log_dir=FLAGS.log_dir,
                     absorbing_state=FLAGS.GAIL.learn_absorbing,
                     rescale_action=FLAGS.env.rescale_action)
    env_eval = create_env(FLAGS.env.id,
                          seed=FLAGS.seed + 1000,
                          log_dir=FLAGS.log_dir,
                          absorbing_state=FLAGS.GAIL.learn_absorbing,
                          rescale_action=FLAGS.env.rescale_action)
    dim_state = env.observation_space.shape[0]
    dim_action = env.action_space.shape[0]

    # load expert dataset
    subsampling_rate = env.max_episode_steps // FLAGS.GAIL.trajectory_size
    set_random_seed(2020)
    expert_dataset = load_expert_dataset(FLAGS.GAIL.buf_load)
    expert_reward = expert_dataset.get_average_reward()
    logger.info('Expert Reward %f', expert_reward)
    if FLAGS.GAIL.learn_absorbing:
        expert_dataset.add_absorbing_states(env)
    expert_dataset.subsample_trajectories(FLAGS.GAIL.traj_limit)
    logger.info('Original dataset size {}'.format(len(expert_dataset)))
    expert_dataset.subsample_transitions(subsampling_rate)
    logger.info('Subsampled dataset size {}'.format(len(expert_dataset)))
    logger.info('np random: %d random : %d', np.random.randint(1000),
                random.randint(0, 1000))
    expert_batch = expert_dataset.sample(10)
    expert_state = np.stack([t.obs for t in expert_batch])
    expert_action = np.stack([t.action for t in expert_batch])
    logger.info('Sampled obs: %.4f, acs: %.4f', np.mean(expert_state),
                np.mean(expert_action))
    del expert_batch, expert_state, expert_action
    set_random_seed(FLAGS.seed)

    normalizers = Normalizers(dim_action=dim_action, dim_state=dim_state)
    policy = GaussianMLPPolicy(dim_state,
                               dim_action,
                               FLAGS.TRPO.policy_hidden_sizes,
                               normalizer=normalizers.state)
    vfn = MLPVFunction(dim_state, FLAGS.TRPO.vf_hidden_sizes,
                       normalizers.state)
    algo = TRPO(vfn=vfn,
                policy=policy,
                dim_state=dim_state,
                dim_action=dim_action,
                **FLAGS.TRPO.algo.as_dict())

    if FLAGS.GAIL.reward_type == 'nn':
        expert_batch = expert_dataset.buffer()
        expert_state = np.stack([t.obs for t in expert_batch])
        loc, scale = np.mean(expert_state, axis=0,
                             keepdims=True), np.std(expert_state,
                                                    axis=0,
                                                    keepdims=True)
        del expert_batch, expert_state
        discriminator = Discriminator(dim_state,
                                      dim_action,
                                      normalizers=normalizers,
                                      subsampling_rate=subsampling_rate,
                                      loc=loc,
                                      scale=scale,
                                      **FLAGS.GAIL.discriminator.as_dict())
    elif FLAGS.GAIL.reward_type in {'simplex', 'l2'}:
        discriminator = LinearReward(
            dim_state, dim_action, simplex=FLAGS.GAIL.reward_type == 'simplex')
    else:
        raise NotImplementedError
    tf.get_default_session().run(tf.global_variables_initializer())

    if not FLAGS.GAIL.reward_type == 'nn':
        expert_batch = expert_dataset.buffer()
        expert_state = np.stack([t.obs for t in expert_batch])
        expert_action = np.stack([t.action for t in expert_batch])
        discriminator.build(expert_state, expert_action)
        del expert_batch, expert_state, expert_action

    saver = nn.ModuleDict({
        'policy': policy,
        'vfn': vfn,
        'normalizers': normalizers,
        'discriminator': discriminator
    })
    runner = Runner(env,
                    max_steps=env.max_episode_steps,
                    gamma=FLAGS.TRPO.gamma,
                    lambda_=FLAGS.TRPO.lambda_,
                    add_absorbing_state=FLAGS.GAIL.learn_absorbing)
    print(saver)

    max_ent_coef = FLAGS.TRPO.algo.ent_coef
    eval_gamma = 0.999
    for t in range(0, FLAGS.GAIL.total_timesteps,
                   FLAGS.TRPO.rollout_samples * FLAGS.GAIL.g_iters):
        time_st = time.time()
        if t % FLAGS.GAIL.eval_freq == 0:
            eval_returns, eval_lengths = evaluate(policy, env_eval)
            eval_returns_discount, eval_lengths_discount = evaluate(
                policy, env_eval, gamma=eval_gamma)
            log_kvs(prefix='Evaluate',
                    kvs=dict(iter=t,
                             episode=dict(returns=np.mean(eval_returns),
                                          lengths=int(np.mean(eval_lengths))),
                             discounted_episode=dict(
                                 returns=np.mean(eval_returns_discount),
                                 lengths=int(np.mean(eval_lengths_discount)))))

        # Generator
        generator_dataset = None
        for n_update in range(FLAGS.GAIL.g_iters):
            data, ep_infos = runner.run(policy, FLAGS.TRPO.rollout_samples)
            if FLAGS.TRPO.normalization:
                normalizers.state.update(data.state)
                normalizers.action.update(data.action)
                normalizers.diff.update(data.next_state - data.state)
            if t == 0 and n_update == 0 and not FLAGS.GAIL.learn_absorbing:
                data_ = data.copy()
                data_ = data_.reshape(
                    [FLAGS.TRPO.rollout_samples // env.n_envs, env.n_envs])
                for e in range(env.n_envs):
                    samples = data_[:, e]
                    masks = 1 - (samples.done | samples.timeout)[...,
                                                                 np.newaxis]
                    masks = masks[:-1]
                    assert np.allclose(samples.state[1:] * masks,
                                       samples.next_state[:-1] * masks)
            t += FLAGS.TRPO.rollout_samples
            data.reward = discriminator.get_reward(data.state, data.action)
            advantages, values = runner.compute_advantage(vfn, data)
            train_info = algo.train(max_ent_coef, data, advantages, values)
            fps = int(FLAGS.TRPO.rollout_samples / (time.time() - time_st))
            train_info['reward'] = np.mean(data.reward)
            train_info['fps'] = fps

            expert_batch = expert_dataset.sample(256)
            expert_state = np.stack([t.obs for t in expert_batch])
            expert_action = np.stack([t.action for t in expert_batch])
            train_info['mse_loss'] = policy.get_mse_loss(
                expert_state, expert_action)
            log_kvs(prefix='TRPO', kvs=dict(iter=t, **train_info))

            generator_dataset = data

        # Discriminator
        if FLAGS.GAIL.reward_type in {'nn', 'vb'}:
            for n_update in range(FLAGS.GAIL.d_iters):
                batch_size = FLAGS.GAIL.d_batch_size
                d_train_infos = dict()
                for generator_subset in generator_dataset.iterator(batch_size):
                    expert_batch = expert_dataset.sample(batch_size)
                    expert_state = np.stack([t.obs for t in expert_batch])
                    expert_action = np.stack([t.action for t in expert_batch])
                    expert_mask = np.stack([
                        t.mask for t in expert_batch
                    ]).flatten() if FLAGS.GAIL.learn_absorbing else None
                    train_info = discriminator.train(
                        expert_state,
                        expert_action,
                        generator_subset.state,
                        generator_subset.action,
                        expert_mask,
                    )
                    for k, v in train_info.items():
                        if k not in d_train_infos:
                            d_train_infos[k] = []
                        d_train_infos[k].append(v)
                d_train_infos = {
                    k: np.mean(v)
                    for k, v in d_train_infos.items()
                }
                if n_update == FLAGS.GAIL.d_iters - 1:
                    log_kvs(prefix='Discriminator',
                            kvs=dict(iter=t, **d_train_infos))
        else:
            train_info = discriminator.train(generator_dataset.state,
                                             generator_dataset.action)
            log_kvs(prefix='Discriminator', kvs=dict(iter=t, **train_info))

        if t % FLAGS.TRPO.save_freq == 0:
            np.save('{}/stage-{}'.format(FLAGS.log_dir, t), saver.state_dict())
            np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())
    np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())

    dict_result = dict()
    for gamma in [0.9, 0.99, 0.999, 1.0]:
        eval_returns, eval_lengths = evaluate(policy, env_eval, gamma=gamma)
        dict_result[gamma] = [float(np.mean(eval_returns)), eval_returns]
        logger.info('[%s]: %.4f', gamma, np.mean(eval_returns))

    save_path = os.path.join(FLAGS.log_dir, 'evaluate.yml')
    yaml.dump(dict_result, open(save_path, 'w'), default_flow_style=False)
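# A minimal, self-contained illustration (not the repository's `evaluate`) of
# how the discounted returns reported above can be computed from a single
# trajectory's rewards; `rewards` and `gamma` are assumed inputs.
import numpy as np


def discounted_return(rewards: np.ndarray, gamma: float) -> float:
    # G = sum_t gamma^t * r_t
    return float(np.sum(rewards * gamma ** np.arange(len(rewards))))


# e.g. discounted_return(np.ones(1000), 0.999) ~= 632.3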
def main():
    FLAGS.set_seed()
    FLAGS.freeze()

    collect_mb = FLAGS.env.env_type == 'mb'
    if collect_mb:
        env_id = 'MB' + FLAGS.env.id
        logger.warning('Collect dataset for imitating environments')
    else:
        env_id = FLAGS.env.id
        logger.warning('Collect dataset for imitating policies')
    env = create_env(env_id,
                     FLAGS.seed,
                     FLAGS.log_dir,
                     rescale_action=FLAGS.env.rescale_action)
    dim_state = env.observation_space.shape[0]
    dim_action = env.action_space.shape[0]

    actor = Actor(dim_state,
                  dim_action,
                  hidden_sizes=FLAGS.SAC.actor_hidden_sizes)

    tf.get_default_session().run(tf.global_variables_initializer())

    loader = nn.ModuleDict({'actor': actor})
    loader.load_state_dict(
        np.load(FLAGS.ckpt.policy_load, allow_pickle=True)[()])
    logger.info('Load policy from %s' % FLAGS.ckpt.policy_load)

    state_traj, action_traj, next_state_traj, reward_traj, len_traj = [], [], [], [], []
    returns = []
    while len(state_traj) < 50:
        states = np.zeros([env.max_episode_steps, dim_state], dtype=np.float32)
        actions = np.zeros([env.max_episode_steps, dim_action],
                           dtype=np.float32)
        next_states = np.zeros([env.max_episode_steps, dim_state],
                               dtype=np.float32)
        rewards = np.zeros([env.max_episode_steps], dtype=np.float32)
        state = env.reset()
        done = False
        t = 0
        while not done:
            action = actor.get_actions(state[None], fetch='actions_mean')
            next_state, reward, done, info = env.step(action)

            states[t] = state
            actions[t] = action
            rewards[t] = reward
            next_states[t] = next_state
            t += 1
            if done:
                break
            state = next_state
        # keep only long (at least 700 steps), positive-return trajectories
        if t < 700 or np.sum(rewards) < 0:
            continue
        state_traj.append(states)
        action_traj.append(actions)
        next_state_traj.append(next_states)
        reward_traj.append(rewards)
        len_traj.append(t)

        returns.append(np.sum(rewards))
        logger.info('# %d: collect a trajectory return = %.4f length = %d',
                    len(state_traj), np.sum(rewards), t)

    state_traj = np.array(state_traj)
    action_traj = np.array(action_traj)
    next_state_traj = np.array(next_state_traj)
    reward_traj = np.array(reward_traj)
    len_traj = np.array(len_traj)
    assert len(state_traj.shape) == len(action_traj.shape) == 3
    assert len(reward_traj.shape) == 2 and len(len_traj.shape) == 1

    dataset = {
        'a_B_T_Da': action_traj,
        'len_B': len_traj,
        'obs_B_T_Do': state_traj,
        'r_B_T': reward_traj
    }
    if collect_mb:
        dataset['next_obs_B_T_Do'] = next_state_traj
    logger.info('Expert avg return = %.4f avg length = %d', np.mean(returns),
                np.mean(len_traj))

    if collect_mb:
        root_dir = 'dataset/mb2'
    else:
        root_dir = 'dataset/sac'

    save_dir = f'{root_dir}/{FLAGS.env.id}'
    os.makedirs(save_dir, exist_ok=True)
    shutil.copy(FLAGS.ckpt.policy_load, os.path.join(save_dir, 'policy.npy'))

    save_path = f'{root_dir}/{FLAGS.env.id}.h5'
    f = h5py.File(save_path, 'w')
    f.update(dataset)
    f.close()
    logger.info('save dataset into %s' % save_path)
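# A minimal sketch of reading the dataset saved above back with h5py; the file
# path below is only a hypothetical example.
import h5py

with h5py.File('dataset/sac/HalfCheetah-v2.h5', 'r') as f:
    obs = f['obs_B_T_Do'][()]   # [B, T, dim_state]
    acs = f['a_B_T_Da'][()]     # [B, T, dim_action]
    lens = f['len_B'][()]       # [B]
    rews = f['r_B_T'][()]       # [B, T]
print(obs.shape, acs.shape, lens.shape, rews.shape)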
def main():
    FLAGS.set_seed()
    FLAGS.freeze()

    env = create_env(FLAGS.env.id, seed=FLAGS.seed, rescale_action=FLAGS.env.rescale_action)
    dim_state = env.observation_space.shape[0]
    dim_action = env.action_space.shape[0]

    # expert actor
    actor = Actor(dim_state, dim_action, init_std=0.)
    subsampling_rate = env.max_episode_steps // FLAGS.GAIL.trajectory_size
    expert_state, expert_action, expert_next_state, expert_reward = collect_samples_from_true_env(
        env=env, actor=actor, nb_episode=FLAGS.GAIL.traj_limit, subsampling_rate=subsampling_rate)
    logger.info('Collect %d samples avg return = %.4f', len(expert_state), np.mean(expert_reward))
    eval_state, eval_action, eval_next_state, eval_reward = collect_samples_from_true_env(
        env=env, actor=actor, nb_episode=3, seed=FLAGS.seed)
    loc, scale = np.mean(expert_state, axis=0, keepdims=True), np.std(expert_state, axis=0, keepdims=True)
    logger.info('loc = {}\nscale={}'.format(loc, scale))

    normalizers = Normalizers(dim_action=dim_action, dim_state=dim_state)
    policy = GaussianMLPPolicy(dim_state, dim_action, FLAGS.TRPO.policy_hidden_sizes,
                               output_diff=FLAGS.TRPO.output_diff, normalizers=normalizers)
    bc_loss = BehavioralCloningLoss(dim_state, dim_action, policy, lr=float(FLAGS.BC.lr), train_std=FLAGS.BC.train_std)

    tf.get_default_session().run(tf.global_variables_initializer())
    set_random_seed(FLAGS.seed)

    saver = nn.ModuleDict({'policy': policy, 'normalizers': normalizers})
    print(saver)

    # update normalizers
    normalizers.state.update(expert_state)
    normalizers.action.update(expert_action)
    normalizers.diff.update(expert_next_state - expert_state)

    eval_gamma = 0.999
    eval_returns, eval_lengths = evaluate_on_true_env(actor, env, gamma=eval_gamma)
    logger.warning('Test policy true value = %.4f true length = %d (gamma = %f)',
                   np.mean(eval_returns), np.mean(eval_lengths), eval_gamma)

    # virtual env
    env_eval_stochastic = VirtualEnv(policy, env, n_envs=4, stochastic_model=True)
    env_eval_deterministic = VirtualEnv(policy, env, n_envs=4, stochastic_model=False)

    batch_size = FLAGS.BC.batch_size
    true_return = np.mean(eval_returns)
    for t in range(FLAGS.BC.max_iters):
        if t % FLAGS.BC.eval_freq == 0:
            eval_returns_stochastic, eval_lengths_stochastic = evaluate_on_virtual_env(
                actor, env_eval_stochastic, gamma=eval_gamma)
            eval_returns_deterministic, eval_lengths_deterministic = evaluate_on_virtual_env(
                actor, env_eval_deterministic, gamma=eval_gamma)
            log_kvs(prefix='Evaluate', kvs=dict(
                iter=t, stochastic_episode=dict(
                    returns=np.mean(eval_returns_stochastic), lengths=int(np.mean(eval_lengths_stochastic))
                ), episode=dict(
                    returns=np.mean(eval_returns_deterministic), lengths=int(np.mean(eval_lengths_deterministic))
                ),  evaluation_error=dict(
                    stochastic_error=true_return-np.mean(eval_returns_stochastic),
                    stochastic_abs=np.abs(true_return-np.mean(eval_returns_stochastic)),
                    stochastic_rel=np.abs(true_return-np.mean(eval_returns_stochastic))/true_return,
                    deterministic_error=true_return-np.mean(eval_returns_deterministic),
                    deterministic_abs=np.abs(true_return - np.mean(eval_returns_deterministic)),
                    deterministic_rel=np.abs(true_return-np.mean(eval_returns_deterministic))/true_return
                )
            ))

        indices = np.random.randint(low=0, high=len(expert_state), size=batch_size)
        expert_state_ = expert_state[indices]
        expert_action_ = expert_action[indices]
        expert_next_state_ = expert_next_state[indices]
        _, loss, grad_norm = bc_loss.get_loss(expert_state_, expert_action_, expert_next_state_,
                                              fetch='train loss grad_norm')

        if t % 100 == 0:
            train_mse_loss = policy.get_mse_loss(expert_state_, expert_action_, expert_next_state_)
            eval_mse_loss = policy.get_mse_loss(eval_state, eval_action, eval_next_state)
            log_kvs(prefix='BC', kvs=dict(
                iter=t, grad_norm=grad_norm, loss=loss, mse_loss=dict(train=train_mse_loss, eval=eval_mse_loss)
            ))

    np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())

    dict_result = dict()
    for gamma in [0.9, 0.99, 0.999, 1.0]:
        eval_returns, eval_lengths = evaluate_on_virtual_env(actor, env_eval_deterministic, gamma=gamma)
        dict_result[gamma] = [float(np.mean(eval_returns)), eval_returns]
        logger.info('[%s]: %.4f', gamma, np.mean(eval_returns))

    save_path = os.path.join(FLAGS.log_dir, 'evaluate.yml')
    yaml.dump(dict_result, open(save_path, 'w'), default_flow_style=False)
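# `BehavioralCloningLoss` is the repository's own module; purely as an
# illustration of the kind of objective behavioral cloning of a Gaussian model
# minimizes, here is a generic diagonal-Gaussian negative log-likelihood in
# NumPy (all names below are assumptions, not the repo's API).
import numpy as np


def gaussian_nll(target, mean, log_std):
    # -log N(target | mean, diag(exp(log_std))^2), averaged over the batch
    inv_var = np.exp(-2.0 * log_std)
    per_sample = 0.5 * np.sum(
        (target - mean) ** 2 * inv_var + 2.0 * log_std + np.log(2.0 * np.pi),
        axis=-1)
    return float(np.mean(per_sample))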
Exemplo n.º 24
0
def main():
    FLAGS.set_seed()
    FLAGS.freeze()

    env = make_env(FLAGS.env.id)
    dim_state = int(np.prod(env.observation_space.shape))
    dim_action = int(np.prod(env.action_space.shape))

    env.verify()

    normalizers = Normalizers(dim_action=dim_action, dim_state=dim_state)

    dtype = gen_dtype(env, 'state action next_state reward done timeout')
    train_set = Dataset(dtype, FLAGS.rollout.max_buf_size)
    dev_set = Dataset(dtype, FLAGS.rollout.max_buf_size)

    policy = GaussianMLPPolicy(dim_state,
                               dim_action,
                               normalizer=normalizers.state,
                               **FLAGS.policy.as_dict())
    # batched noises
    noise = OUNoise(env.action_space,
                    theta=FLAGS.OUNoise.theta,
                    sigma=FLAGS.OUNoise.sigma,
                    shape=(1, dim_action))
    vfn = MLPVFunction(dim_state, [64, 64], normalizers.state)
    model = DynamicsModel(dim_state, dim_action, normalizers,
                          FLAGS.model.hidden_sizes)

    virt_env = VirtualEnv(model,
                          make_env(FLAGS.env.id),
                          FLAGS.plan.n_envs,
                          opt_model=FLAGS.slbo.opt_model)
    virt_runner = Runner(
        virt_env, **{
            **FLAGS.runner.as_dict(), 'max_steps': FLAGS.plan.max_steps
        })

    criterion_map = {
        'L1': nn.L1Loss(),
        'L2': nn.L2Loss(),
        'MSE': nn.MSELoss(),
    }
    criterion = criterion_map[FLAGS.model.loss]
    loss_mod = MultiStepLoss(model, normalizers, dim_state, dim_action,
                             criterion, FLAGS.model.multi_step)
    loss_mod.build_backward(FLAGS.model.lr, FLAGS.model.weight_decay)
    algo = TRPO(vfn=vfn,
                policy=policy,
                dim_state=dim_state,
                dim_action=dim_action,
                **FLAGS.TRPO.as_dict())

    tf.get_default_session().run(tf.global_variables_initializer())

    runners = {
        'test':
        make_real_runner(4),
        'collect':
        make_real_runner(1),
        'dev':
        make_real_runner(1),
        'train':
        make_real_runner(FLAGS.plan.n_envs)
        if FLAGS.algorithm == 'MF' else virt_runner,
    }
    settings = [(runners['test'], policy, 'Real Env'),
                (runners['train'], policy, 'Virt Env')]

    saver = nn.ModuleDict({'policy': policy, 'model': model, 'vfn': vfn})
    print(saver)

    if FLAGS.ckpt.model_load:
        saver.load_state_dict(np.load(FLAGS.ckpt.model_load)[()])
        logger.warning('Load model from %s', FLAGS.ckpt.model_load)

    if FLAGS.ckpt.buf_load:
        n_samples = 0
        for i in range(FLAGS.ckpt.buf_load_index):
            data = pickle.load(
                open(f'{FLAGS.ckpt.buf_load}/stage-{i}.inc-buf.pkl', 'rb'))
            add_multi_step(data, train_set)
            n_samples += len(data)
        logger.warning('Loading %d samples from %s', n_samples,
                       FLAGS.ckpt.buf_load)

    max_ent_coef = FLAGS.TRPO.ent_coef

    for T in range(FLAGS.slbo.n_stages):
        logger.info('------ Starting Stage %d --------', T)
        evaluate(settings, 'episode')

        if not FLAGS.use_prev:
            train_set.clear()
            dev_set.clear()

        # collect data
        recent_train_set, ep_infos = runners['collect'].run(
            noise.make(policy), FLAGS.rollout.n_train_samples)
        add_multi_step(recent_train_set, train_set)
        add_multi_step(
            runners['dev'].run(noise.make(policy),
                               FLAGS.rollout.n_dev_samples)[0],
            dev_set,
        )

        returns = np.array([ep_info['return'] for ep_info in ep_infos])
        if len(returns) > 0:
            logger.info("episode: %s", np.mean(returns))

        if T == 0:  # check
            samples = train_set.sample_multi_step(100, 1,
                                                  FLAGS.model.multi_step)
            for i in range(FLAGS.model.multi_step - 1):
                masks = 1 - (samples.done[i] | samples.timeout[i])[...,
                                                                   np.newaxis]
                assert np.allclose(samples.state[i + 1] * masks,
                                   samples.next_state[i] * masks)

        # recent_states = obsvs
        # ref_actions = policy.eval('actions_mean actions_std', states=recent_states)
        if FLAGS.rollout.normalizer == 'policy' or (
                FLAGS.rollout.normalizer == 'uniform' and T == 0):
            normalizers.state.update(recent_train_set.state)
            normalizers.action.update(recent_train_set.action)
            normalizers.diff.update(recent_train_set.next_state -
                                    recent_train_set.state)

        if T == 50:
            max_ent_coef = 0.

        for i in range(FLAGS.slbo.n_iters):
            if i % FLAGS.slbo.n_evaluate_iters == 0 and i != 0:
                # cur_actions = policy.eval('actions_mean actions_std', states=recent_states)
                # kl_old_new = gaussian_kl(*ref_actions, *cur_actions).sum(axis=1).mean()
                # logger.info('KL(old || cur) = %.6f', kl_old_new)
                evaluate(settings, 'iteration')

            losses = deque(maxlen=FLAGS.slbo.n_model_iters)
            grad_norm_meter = AverageMeter()
            n_model_iters = FLAGS.slbo.n_model_iters
            for _ in range(n_model_iters):
                samples = train_set.sample_multi_step(
                    FLAGS.model.train_batch_size, 1, FLAGS.model.multi_step)
                _, train_loss, grad_norm = loss_mod.get_loss(
                    samples.state,
                    samples.next_state,
                    samples.action,
                    ~samples.done & ~samples.timeout,
                    fetch='train loss grad_norm')
                losses.append(train_loss.mean())
                grad_norm_meter.update(grad_norm)
                # ideally, we should define an Optimizer class, which takes parameters as inputs.
                # The `update` method of `Optimizer` would invalidate all parameters during updates
                # (a minimal sketch of such a wrapper follows this function).
                for param in model.parameters():
                    param.invalidate()

            if i % FLAGS.model.validation_freq == 0:
                samples = dev_set.sample_multi_step(
                    FLAGS.model.train_batch_size, 1, FLAGS.model.multi_step)
                loss = loss_mod.get_loss(samples.state, samples.next_state,
                                         samples.action,
                                         ~samples.done & ~samples.timeout)
                loss = loss.mean()
                if np.isnan(loss) or np.isnan(np.mean(losses)):
                    logger.info('nan! %s %s', np.isnan(loss),
                                np.isnan(np.mean(losses)))
                logger.info(
                    '# Iter %3d: Loss = [train = %.3f, dev = %.3f], after %d steps, grad_norm = %.6f',
                    i, np.mean(losses), loss, n_model_iters,
                    grad_norm_meter.get())

            for n_updates in range(FLAGS.slbo.n_policy_iters):
                if FLAGS.algorithm != 'MF' and FLAGS.slbo.start == 'buffer':
                    runners['train'].set_state(
                        train_set.sample(FLAGS.plan.n_envs).state)
                else:
                    runners['train'].reset()

                data, ep_infos = runners['train'].run(
                    policy, FLAGS.plan.n_trpo_samples)
                advantages, values = runners['train'].compute_advantage(
                    vfn, data)
                dist_mean, dist_std, vf_loss = algo.train(
                    max_ent_coef, data, advantages, values)
                returns = [info['return'] for info in ep_infos]
                logger.info(
                    '[TRPO] # %d: n_episodes = %d, returns: {mean = %.0f, stderr = %.0f}, '
                    'dist std = %.10f, dist mean = %.10f, vf_loss = %.3f',
                    n_updates, len(returns), np.mean(returns),
                    np.std(returns) / np.sqrt(len(returns)), dist_std,
                    dist_mean, vf_loss)

        if T % FLAGS.ckpt.n_save_stages == 0:
            np.save(f'{FLAGS.log_dir}/stage-{T}', saver.state_dict())
            np.save(f'{FLAGS.log_dir}/final', saver.state_dict())
        if FLAGS.ckpt.n_save_stages == 1:
            pickle.dump(recent_train_set,
                        open(f'{FLAGS.log_dir}/stage-{T}.inc-buf.pkl', 'wb'))
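# The comment inside the model-fitting loop above suggests an Optimizer class
# that owns its parameters and invalidates them after every update.  A minimal
# sketch of such a wrapper (hypothetical, not part of the repository):
class Optimizer:
    def __init__(self, params, apply_gradients_fn):
        self._params = list(params)
        self._apply_gradients_fn = apply_gradients_fn

    def update(self, *args, **kwargs):
        result = self._apply_gradients_fn(*args, **kwargs)
        # cached parameter values are stale after the update
        for param in self._params:
            param.invalidate()
        return result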
def main():
    FLAGS.set_seed()
    FLAGS.freeze()

    env = create_env(FLAGS.env.id,
                     seed=FLAGS.seed,
                     log_dir=FLAGS.log_dir,
                     absorbing_state=FLAGS.GAIL.learn_absorbing,
                     rescale_action=FLAGS.env.rescale_action)
    env_eval = create_env(FLAGS.env.id,
                          seed=FLAGS.seed + 1000,
                          log_dir=FLAGS.log_dir,
                          absorbing_state=FLAGS.GAIL.learn_absorbing,
                          rescale_action=FLAGS.env.rescale_action)
    dim_state = env.observation_space.shape[0]
    dim_action = env.action_space.shape[0]

    normalizers = Normalizers(dim_action=dim_action, dim_state=dim_state)
    policy = GaussianMLPPolicy(dim_state,
                               dim_action,
                               FLAGS.TRPO.policy_hidden_sizes,
                               normalizer=normalizers.state)
    bc_loss = BehavioralCloningLoss(dim_state,
                                    dim_action,
                                    policy,
                                    lr=float(FLAGS.BC.lr),
                                    train_std=FLAGS.BC.train_std)

    expert_actor = Actor(dim_state, dim_action, FLAGS.SAC.actor_hidden_sizes)
    tf.get_default_session().run(tf.global_variables_initializer())

    loader = nn.ModuleDict({'actor': expert_actor})
    if FLAGS.BC.dagger:
        loader.load_state_dict(
            np.load(FLAGS.ckpt.policy_load, allow_pickle=True)[()])
        logger.warning('Load expert policy from %s' % FLAGS.ckpt.policy_load)
    runner = Runner(env, max_steps=env.max_episode_steps, rescale_action=False)

    subsampling_rate = env.max_episode_steps // FLAGS.GAIL.trajectory_size
    # load expert dataset
    set_random_seed(2020)
    expert_dataset = load_expert_dataset(FLAGS.GAIL.buf_load)
    expert_reward = expert_dataset.get_average_reward()
    logger.info('Expert Reward %f', expert_reward)
    if FLAGS.GAIL.learn_absorbing:
        expert_dataset.add_absorbing_states(env)
    expert_dataset.subsample_trajectories(FLAGS.GAIL.traj_limit)
    logger.info('Original dataset size {}'.format(len(expert_dataset)))
    expert_dataset.subsample_transitions(subsampling_rate)
    logger.info('Subsampled dataset size {}'.format(len(expert_dataset)))
    logger.info('np random: %d random : %d', np.random.randint(1000),
                random.randint(0, 1000))
    expert_batch = expert_dataset.sample(10)
    expert_state = np.stack([t.obs for t in expert_batch])
    expert_action = np.stack([t.action for t in expert_batch])
    logger.info('Sampled obs: %.4f, acs: %.4f', np.mean(expert_state),
                np.mean(expert_action))
    del expert_batch, expert_state, expert_action
    set_random_seed(FLAGS.seed)

    saver = nn.ModuleDict({'policy': policy, 'normalizers': normalizers})
    print(saver)

    batch_size = FLAGS.BC.batch_size
    eval_gamma = 0.999
    for t in range(FLAGS.BC.max_iters):
        if t % FLAGS.BC.eval_freq == 0:
            eval_returns, eval_lengths = evaluate(policy, env_eval)
            eval_returns_discount, eval_lengths_discount = evaluate(
                policy, env_eval, gamma=eval_gamma)
            log_kvs(prefix='Evaluate',
                    kvs=dict(iter=t,
                             episode=dict(returns=np.mean(eval_returns),
                                          lengths=int(np.mean(eval_lengths))),
                             discounted_episode=dict(
                                 returns=np.mean(eval_returns_discount),
                                 lengths=int(np.mean(eval_lengths_discount)))))

        expert_batch = expert_dataset.sample(batch_size)
        expert_state = np.stack([t.obs for t in expert_batch])
        expert_action = np.stack([t.action for t in expert_batch])
        _, loss, grad_norm = bc_loss.get_loss(expert_state,
                                              expert_action,
                                              fetch='train loss grad_norm')

        if FLAGS.BC.dagger and t % FLAGS.BC.collect_freq == 0 and t > 0:
            if t // FLAGS.BC.collect_freq == 1:
                collect_policy = expert_actor
                stochastic = False
                logger.info('Collect samples with expert actor...')
            else:
                collect_policy = policy
                stochastic = True
                logger.info('Collect samples with learned policy...')
            runner.reset()
            data, ep_infos = runner.run(collect_policy,
                                        FLAGS.BC.n_collect_samples, stochastic)
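            # DAgger relabeling: overwrite the executed actions with the expert's
            # mean actions for the visited states before aggregating them into
            # the dataset below.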
            data.action = expert_actor.get_actions(data.state,
                                                   fetch='actions_mean')
            returns = [info['return'] for info in ep_infos]
            lengths = [info['length'] for info in ep_infos]
            for i in range(len(data)):
                expert_dataset.push_back(data[i].state, data[i].action,
                                         data[i].next_state, data[i].reward,
                                         data[i].mask, data[i].timeout)
            logger.info('Collect %d samples avg return = %.4f avg length = %d',
                        len(data), np.mean(returns), np.mean(lengths))
        if t % 100 == 0:
            mse_loss = policy.get_mse_loss(expert_state, expert_action)
            log_kvs(prefix='BC',
                    kvs=dict(iter=t,
                             loss=loss,
                             grad_norm=grad_norm,
                             mse_loss=mse_loss))

    np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())

    dict_result = dict()
    for gamma in [0.9, 0.99, 0.999, 1.0]:
        eval_returns, eval_lengths = evaluate(policy, env_eval, gamma=gamma)
        dict_result[gamma] = [float(np.mean(eval_returns)), eval_returns]
        logger.info('[%s]: %.4f', gamma, np.mean(eval_returns))

    save_path = os.path.join(FLAGS.log_dir, 'evaluate.yml')
    yaml.dump(dict_result, open(save_path, 'w'), default_flow_style=False)
    def __init__(self,
                 dim_state: int,
                 dim_action: int,
                 hidden_sizes: List[int],
                 normalizers: Normalizers,
                 lr: float,
                 ent_coef: float,
                 loc=None,
                 scale=None,
                 neural_distance=False,
                 gradient_penalty_coef=0.,
                 l2_regularization_coef=0.,
                 max_grad_norm=None,
                 subsampling_rate=20.):
        super().__init__()
        self.ent_coef = ent_coef
        self.neural_distance = neural_distance
        self.gradient_penalty_coef = gradient_penalty_coef
        self.l2_regularization_coef = l2_regularization_coef
        self.subsampling_rate = subsampling_rate

        with self.scope:
            self.op_true_states = tf.placeholder(tf.float32, [None, dim_state],
                                                 "true_state")
            self.op_true_actions = tf.placeholder(tf.float32,
                                                  [None, dim_action],
                                                  "true_action")
            self.op_fake_states = tf.placeholder(tf.float32, [None, dim_state],
                                                 "fake_state")
            self.op_fake_actions = tf.placeholder(tf.float32,
                                                  [None, dim_action],
                                                  "fake_actions")
            self.op_true_masks = tf.placeholder(tf.float32, [None], "mask")

            if self.neural_distance or self.gradient_penalty_coef > 0.:
                logger.info('Use predefined normalization.')
                if loc is None:
                    loc = np.zeros([1, dim_state], dtype=np.float32)
                if scale is None:
                    # default: identity scaling over state dimensions
                    scale = np.ones([1, dim_state], dtype=np.float32)
                logger.info('Normalizer loc:{} \n scale:{}'.format(loc, scale))
                state_process_fn = lambda states_: (states_ - loc) / (1e-3 +
                                                                      scale)
            else:
                logger.info('Use given normalizer.')
                state_process_fn = lambda states_: normalizers.state(states_)
            action_process_fn = lambda action_: action_
            # raw logits (no output activation) are used for both the GAN-style
            # classifier and the Wasserstein (neural_distance) variant
            activ_fn = 'none'

            self.classifier = BinaryClassifier(
                dim_state,
                dim_action,
                hidden_sizes,
                state_process_fn=state_process_fn,
                action_process_fn=action_process_fn,
                activ_fn=activ_fn)

            self.op_loss, self.op_classifier_loss, self.op_entropy_loss, self.op_grad_penalty, self.op_regularization, \
                self.op_true_logits, self.op_fake_logits, self.op_true_weight = self(
                    self.op_true_states, self.op_true_actions,
                    self.op_fake_states, self.op_fake_actions,
                    self.op_true_masks)
            self.op_true_prob = tf.nn.sigmoid(self.op_true_logits)
            self.op_fake_prob = tf.nn.sigmoid(self.op_fake_logits)

            optimizer = tf.train.AdamOptimizer(lr)
            params = self.classifier.parameters()
            grads_and_vars = optimizer.compute_gradients(self.op_loss,
                                                         var_list=params)
            if max_grad_norm is not None:
                clip_grads, op_grad_norm = tf.clip_by_global_norm(
                    [grad for grad, _ in grads_and_vars], max_grad_norm)
                clip_grads_and_vars = [
                    (grad, var)
                    for grad, (_, var) in zip(clip_grads, grads_and_vars)
                ]
            else:
                op_grad_norm = tf.global_norm(
                    [grad for grad, _ in grads_and_vars])
                clip_grads_and_vars = grads_and_vars
            self.op_train = optimizer.apply_gradients(clip_grads_and_vars)
            if self.neural_distance:
                logger.info('Discriminator uses Wasserstein distance.')
            logger.info('{}'.format(self.classifier.parameters()))
            logger.info('Use gradient penalty regularization (coef = %f)',
                        gradient_penalty_coef)
            self.op_grad_norm = op_grad_norm
            # neural reward function
            reference = tf.reduce_mean(self.op_fake_logits)
            self.op_unscaled_neural_reward = self.op_fake_logits
            unscaled_reward = self.op_fake_logits - reference
            reward_scale = tf.reduce_max(unscaled_reward) - tf.reduce_min(
                unscaled_reward)
            self.op_scaled_neural_reward = unscaled_reward / (1e-6 +
                                                              reward_scale)
            # gail reward function
            self.op_gail_reward = -tf.log(1 - self.op_fake_prob + 1e-6)
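# A NumPy restatement of the two reward transforms defined at the end of the
# constructor above, purely for illustration (`logits` is an assumed 1-D array
# of discriminator outputs on generator samples):
import numpy as np


def scaled_neural_reward(logits):
    # shift by the batch mean, then divide by the batch range (spread ~ 1)
    unscaled = logits - np.mean(logits)
    return unscaled / (1e-6 + (np.max(unscaled) - np.min(unscaled)))


def gail_reward(logits):
    # r = -log(1 - D(s, a)) with D = sigmoid(logit), matching op_gail_reward
    prob = 1.0 / (1.0 + np.exp(-logits))
    return -np.log(1.0 - prob + 1e-6)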
Exemplo n.º 27
0
    def train(self, ent_coef, samples, advantages, values):
        returns = advantages + values
        advantages = (advantages - advantages.mean()) / np.maximum(
            advantages.std(), 1e-8)
        assert np.isfinite(advantages).all()
        self.sync_old()
        old_loss, grad, dist_std, mean_kl, dist_mean = self.get_loss(
            samples.state,
            samples.action,
            samples.next_state,
            advantages,
            ent_coef,
            fetch='loss flat_grad dist_std mean_kl dist_mean')

        if np.allclose(grad, 0):
            logger.info('Zero gradient, not updating...')
            return

        def fisher_vec_prod(x):
            return self.get_hessian_vec_prod(
                samples.state, samples.action, x,
                samples.next_state) + self.cg_damping * x

        assert np.isfinite(grad).all()
        nat_grad, cg_residual = conj_grad(fisher_vec_prod,
                                          grad,
                                          n_iters=self.n_cg_iters,
                                          verbose=False)
        grad_norm = np.linalg.norm(grad)
        nat_grad_norm = np.linalg.norm(nat_grad)

        assert np.isfinite(nat_grad).all()

        old_params = self.flatten.get_flat()
        step_size = np.sqrt(2 * self.max_kl /
                            nat_grad.dot(fisher_vec_prod(nat_grad)))

        for _ in range(10):
            new_params = old_params + nat_grad * step_size
            self.flatten.set_flat(new_params)
            loss, mean_kl = self.get_loss(samples.state,
                                          samples.action,
                                          samples.next_state,
                                          advantages,
                                          ent_coef,
                                          fetch='loss mean_kl')
            improve = loss - old_loss
            if not np.isfinite([loss, mean_kl]).all():
                logger.info('Got non-finite loss.')
            elif mean_kl > self.max_kl * 1.5:
                logger.info(
                    'Violated kl constraints, shrinking step... mean_kl = %.6f, max_kl = %.6f',
                    mean_kl, self.max_kl)
            elif improve < 0:
                logger.info(
                    "Surrogate didn't improve, shrinking step... %.6f => %.6f",
                    old_loss, loss)
            else:
                break
            step_size *= 0.5
        else:
            logger.info("Couldn't find a good step.")
            self.flatten.set_flat(old_params)
        for param in self.policy.parameters():
            param.invalidate()

        # optimize value function
        vf_dataset = Dataset.fromarrays(
            [samples.state, samples.action, returns],
            dtype=[('state', ('f8', self.dim_state)),
                   ('action', ('f8', self.dim_action)), ('return_', 'f8')])
        vf_loss = self.train_vf(vf_dataset)

        info = dict(dist_mean=dist_mean,
                    dist_std=dist_std,
                    vf_loss=np.mean(vf_loss),
                    grad_norm=grad_norm,
                    nat_grad_norm=nat_grad_norm,
                    cg_residual=cg_residual,
                    step_size=step_size)
        return info
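# A toy NumPy illustration of the natural-gradient step size used above: for a
# known Fisher matrix F and gradient g, the natural gradient x = F^{-1} g is
# scaled so that the quadratic KL estimate 0.5 * (s * x)^T F (s * x) hits
# max_kl.  F, g and max_kl below are made-up numbers.
import numpy as np

F = np.array([[2.0, 0.3], [0.3, 1.0]])
g = np.array([0.5, -1.0])
max_kl = 0.01

nat_grad = np.linalg.solve(F, g)
step_size = np.sqrt(2 * max_kl / nat_grad.dot(F @ nat_grad))
step = step_size * nat_grad
print(0.5 * step.dot(F @ step))  # ~= max_kl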
Exemplo n.º 28
0
def main():
    FLAGS.set_seed()
    FLAGS.freeze()

    env = make_env(FLAGS.env.id,
                   FLAGS.env.env_type,
                   num_env=FLAGS.env.num_env,
                   seed=FLAGS.seed,
                   log_dir=FLAGS.log_dir)
    state_spec = env.observation_space
    action_spec = env.action_space

    logger.info('[{}]: state_spec:{}, action_spec:{}'.format(
        FLAGS.env.id, state_spec.shape, action_spec.n))

    dtype = gen_dtype(env,
                      'state action next_state mu reward done timeout info')
    buffer = ReplayBuffer(env.n_envs,
                          FLAGS.ACER.n_steps,
                          stacked_frame=FLAGS.env.env_type == 'atari',
                          dtype=dtype,
                          size=FLAGS.ACER.buffer_size)

    if len(state_spec.shape) == 3:
        policy = CNNPolicy(state_spec, action_spec)
    else:
        policy = MLPPolicy(state_spec, action_spec)

    algo = ACER(state_spec,
                action_spec,
                policy,
                lr=FLAGS.ACER.lr,
                lrschedule=FLAGS.ACER.lrschedule,
                total_timesteps=FLAGS.ACER.total_timesteps,
                ent_coef=FLAGS.ACER.ent_coef,
                q_coef=FLAGS.ACER.q_coef,
                trust_region=FLAGS.ACER.trust_region)
    runner = Runner(env,
                    max_steps=env.max_episode_steps,
                    gamma=FLAGS.ACER.gamma)
    saver = nn.ModuleDict({'policy': policy})
    print(saver)

    tf.get_default_session().run(tf.global_variables_initializer())
    algo.update_old_policy(0.)

    n_steps = FLAGS.ACER.n_steps
    n_batches = n_steps * env.n_envs
    n_stages = FLAGS.ACER.total_timesteps // n_batches

    returns = collections.deque(maxlen=40)
    lengths = collections.deque(maxlen=40)
    replay_reward = collections.deque(maxlen=40)
    time_st = time.time()
    for t in range(n_stages):
        data, ep_infos = runner.run(policy, n_steps)
        returns.extend([info['return'] for info in ep_infos])
        lengths.extend([info['length'] for info in ep_infos])

        if t == 0:  # check runner
            indices = np.arange(0, n_batches, env.n_envs)
            for _ in range(env.n_envs):
                samples = data[indices]
                masks = 1 - (samples.done | samples.timeout)
                masks = masks[:-1]
                masks = np.reshape(masks,
                                   [-1] + [1] * len(samples.state.shape[1:]))
                np.testing.assert_allclose(samples.state[1:] * masks,
                                           samples.next_state[:-1] * masks)
                indices += 1

        buffer.store_episode(data)
        if t == 1:  # check buffer
            data_ = buffer.sample(idx=[1 for _ in range(env.n_envs)])
            check_data_equal(data_, data, ('state', 'action', 'next_state',
                                           'mu', 'reward', 'done', 'timeout'))

        # on-policy training
        qret = runner.compute_qret(policy, data)
        train_info = algo.train(data, qret, t * n_batches)
        replay_reward.append(np.mean(data.reward))
        # off-policy training
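        # The number of replay updates per stage is Poisson(replay_ratio), so on
        # average there are FLAGS.ACER.replay_ratio off-policy updates per
        # on-policy update, as in the ACER paper.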
        if t * n_batches > FLAGS.ACER.replay_start:
            n = np.random.poisson(FLAGS.ACER.replay_ratio)
            for _ in range(n):
                data = buffer.sample()
                qret = runner.compute_qret(policy, data)
                algo.train(data, qret, t * n_batches)
                replay_reward.append(np.mean(data.reward))

        if t * n_batches % FLAGS.ACER.log_interval == 0:
            fps = int(t * n_batches / (time.time() - time_st))
            kvs = dict(iter=t * n_batches,
                       episode=dict(
                           returns=np.mean(returns) if len(returns) > 0 else 0,
                           lengths=np.mean(lengths).astype(np.int32)
                           if len(lengths) > 0 else 0),
                       **train_info,
                       replay_reward=np.mean(replay_reward)
                       if len(replay_reward) > 0 else 0.,
                       fps=fps)
            log_kvs(prefix='ACER', kvs=kvs)

        if t * n_batches % FLAGS.ACER.save_freq == 0:
            np.save('{}/stage-{}'.format(FLAGS.log_dir, t), saver.state_dict())
            np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())
    np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())
def main():
    FLAGS.set_seed()
    FLAGS.freeze()

    env = create_env(FLAGS.env.id,
                     seed=FLAGS.seed,
                     rescale_action=FLAGS.env.rescale_action)
    dim_state = env.observation_space.shape[0]
    dim_action = env.action_space.shape[0]

    normalizers = Normalizers(dim_action=dim_action, dim_state=dim_state)
    policy = GaussianMLPPolicy(dim_state,
                               dim_action,
                               FLAGS.TRPO.policy_hidden_sizes,
                               output_diff=FLAGS.TRPO.output_diff,
                               normalizers=normalizers)
    bc_loss = BehavioralCloningLoss(dim_state,
                                    dim_action,
                                    policy,
                                    lr=float(FLAGS.BC.lr),
                                    train_std=FLAGS.BC.train_std)

    actor = Actor(dim_state, dim_action, FLAGS.SAC.actor_hidden_sizes)
    tf.get_default_session().run(tf.global_variables_initializer())

    subsampling_rate = env.max_episode_steps // FLAGS.GAIL.trajectory_size
    # load expert dataset
    set_random_seed(2020)
    expert_dataset = load_expert_dataset(FLAGS.GAIL.buf_load)
    expert_state = np.stack([t.obs for t in expert_dataset.buffer()])
    expert_next_state = np.stack([t.next_obs for t in expert_dataset.buffer()])
    expert_done = np.stack([t.done for t in expert_dataset.buffer()])
    np.testing.assert_allclose(
        expert_next_state[:-1] * (1 - expert_done[:-1][:, None]),
        expert_state[1:] * (1 - expert_done[:-1][:, None]))
    del expert_state, expert_next_state, expert_done
    expert_reward = expert_dataset.get_average_reward()
    logger.info('Expert Reward %f', expert_reward)
    if FLAGS.GAIL.learn_absorbing:
        expert_dataset.add_absorbing_states(env)
    eval_batch = expert_dataset.sample(1024)
    eval_state = np.stack([t.obs for t in eval_batch])
    eval_action = np.stack([t.action for t in eval_batch])
    eval_next_state = np.stack([t.next_obs for t in eval_batch])
    logger.info('Sampled obs: %.4f, acs: %.4f', np.mean(eval_state),
                np.mean(eval_action))
    expert_dataset.subsample_trajectories(FLAGS.GAIL.traj_limit)
    logger.info('Original dataset size {}'.format(len(expert_dataset)))
    expert_dataset.subsample_transitions(subsampling_rate)
    logger.info('Subsampled dataset size {}'.format(len(expert_dataset)))
    logger.info('np random: %d random : %d', np.random.randint(1000),
                random.randint(0, 1000))
    set_random_seed(FLAGS.seed)

    loader = nn.ModuleDict({'actor': actor})
    loader.load_state_dict(
        np.load(FLAGS.ckpt.policy_load, allow_pickle=True)[()])
    logger.warning('Load expert policy from %s' % FLAGS.ckpt.policy_load)
    saver = nn.ModuleDict({'policy': policy, 'normalizers': normalizers})
    print(saver)

    # update normalizers
    expert_state = np.stack([t.obs for t in expert_dataset.buffer()])
    expert_action = np.stack([t.action for t in expert_dataset.buffer()])
    expert_next_state = np.stack([t.next_obs for t in expert_dataset.buffer()])
    normalizers.state.update(expert_state)
    normalizers.action.update(expert_action)
    normalizers.diff.update(expert_next_state - expert_state)
    del expert_state, expert_action, expert_next_state

    eval_gamma = 0.999
    eval_returns, eval_lengths = evaluate_on_true_env(actor,
                                                      env,
                                                      gamma=eval_gamma)
    logger.warning(
        'Test policy true value = %.4f true length = %d (gamma = %f)',
        np.mean(eval_returns), np.mean(eval_lengths), eval_gamma)

    # virtual env
    env_eval_stochastic = VirtualEnv(policy,
                                     env,
                                     n_envs=4,
                                     stochastic_model=True)
    env_eval_deterministic = VirtualEnv(policy,
                                        env,
                                        n_envs=4,
                                        stochastic_model=False)

    batch_size = FLAGS.BC.batch_size
    true_return = np.mean(eval_returns)
    for t in range(FLAGS.BC.max_iters):
        if t % FLAGS.BC.eval_freq == 0:
            eval_returns_stochastic, eval_lengths_stochastic = evaluate_on_virtual_env(
                actor, env_eval_stochastic, gamma=eval_gamma)
            eval_returns_deterministic, eval_lengths_deterministic = evaluate_on_virtual_env(
                actor, env_eval_deterministic, gamma=eval_gamma)
            log_kvs(
                prefix='Evaluate',
                kvs=dict(
                    iter=t,
                    stochastic_episode=dict(
                        returns=np.mean(eval_returns_stochastic),
                        lengths=int(np.mean(eval_lengths_stochastic))),
                    episode=dict(returns=np.mean(eval_returns_deterministic),
                                 lengths=int(
                                     np.mean(eval_lengths_deterministic))),
                    evaluation_error=dict(
                        stochastic_error=true_return -
                        np.mean(eval_returns_stochastic),
                        stochastic_abs=np.abs(
                            true_return - np.mean(eval_returns_stochastic)),
                        stochastic_rel=np.abs(true_return -
                                              np.mean(eval_returns_stochastic))
                        / true_return,
                        deterministic_error=true_return -
                        np.mean(eval_returns_deterministic),
                        deterministic_abs=np.abs(
                            true_return - np.mean(eval_returns_deterministic)),
                        deterministic_rel=np.abs(true_return - np.mean(
                            eval_returns_deterministic)) / true_return)))

        expert_batch = expert_dataset.sample(batch_size)
        expert_state = np.stack([t.obs for t in expert_batch])
        expert_action = np.stack([t.action for t in expert_batch])
        expert_next_state = np.stack([t.next_obs for t in expert_batch])
        _, loss, grad_norm = bc_loss.get_loss(expert_state,
                                              expert_action,
                                              expert_next_state,
                                              fetch='train loss grad_norm')

        if t % 100 == 0:
            train_mse_loss = policy.get_mse_loss(expert_state, expert_action,
                                                 expert_next_state)
            eval_mse_loss = policy.get_mse_loss(eval_state, eval_action,
                                                eval_next_state)
            log_kvs(prefix='BC',
                    kvs=dict(iter=t,
                             grad_norm=grad_norm,
                             loss=loss,
                             mse_loss=dict(train=train_mse_loss,
                                           eval=eval_mse_loss)))

    np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())

    dict_result = dict()
    for gamma in [0.9, 0.99, 0.999, 1.0]:
        eval_returns, eval_lengths = evaluate_on_virtual_env(
            actor, env_eval_stochastic, gamma=gamma)
        dict_result[gamma] = [float(np.mean(eval_returns)), eval_returns]
        logger.info('[%s]: %.4f', gamma, np.mean(eval_returns))

    save_path = os.path.join(FLAGS.log_dir, 'evaluate.yml')
    yaml.dump(dict_result, open(save_path, 'w'), default_flow_style=False)
Exemplo n.º 30
0
def main():
    FLAGS.set_seed()
    FLAGS.freeze()

    env = create_env(FLAGS.env.id,
                     FLAGS.seed,
                     rescale_action=FLAGS.env.rescale_action)
    dim_state = env.observation_space.shape[0]
    dim_action = env.action_space.shape[0]

    subsampling_rate = env.max_episode_steps // FLAGS.GAIL.trajectory_size
    # load expert dataset
    set_random_seed(2020)
    expert_dataset = load_expert_dataset(FLAGS.GAIL.buf_load)
    expert_state = np.stack([t.obs for t in expert_dataset.buffer()])
    expert_next_state = np.stack([t.next_obs for t in expert_dataset.buffer()])
    expert_done = np.stack([t.done for t in expert_dataset.buffer()])
    np.testing.assert_allclose(
        expert_next_state[:-1] * (1 - expert_done[:-1][:, None]),
        expert_state[1:] * (1 - expert_done[:-1][:, None]))
    del expert_state, expert_next_state, expert_done
    expert_reward = expert_dataset.get_average_reward()
    logger.info('Expert Reward %f', expert_reward)
    if FLAGS.GAIL.learn_absorbing:
        expert_dataset.add_absorbing_states(env)
    eval_batch = expert_dataset.sample(1024)
    eval_state = np.stack([t.obs for t in eval_batch])
    eval_action = np.stack([t.action for t in eval_batch])
    eval_next_state = np.stack([t.next_obs for t in eval_batch])
    logger.info('Sampled obs: %.4f, acs: %.4f', np.mean(eval_state),
                np.mean(eval_action))
    expert_dataset.subsample_trajectories(FLAGS.GAIL.traj_limit)
    logger.info('Original dataset size {}'.format(len(expert_dataset)))
    expert_dataset.subsample_transitions(subsampling_rate)
    logger.info('Subsampled dataset size {}'.format(len(expert_dataset)))
    logger.info('np random: %d random : %d', np.random.randint(1000),
                random.randint(0, 1000))
    set_random_seed(FLAGS.seed)

    # expert actor
    actor = Actor(dim_state,
                  dim_action,
                  hidden_sizes=FLAGS.SAC.actor_hidden_sizes)
    # generator
    normalizers = Normalizers(dim_action=dim_action, dim_state=dim_state)
    policy = GaussianMLPPolicy(dim_state,
                               dim_action,
                               FLAGS.TRPO.policy_hidden_sizes,
                               output_diff=FLAGS.TRPO.output_diff,
                               normalizers=normalizers)
    vfn = MLPVFunction(dim_state, dim_action, FLAGS.TRPO.vf_hidden_sizes,
                       normalizers.state)
    algo = TRPO(vfn=vfn,
                policy=policy,
                dim_state=dim_state,
                dim_action=dim_action,
                **FLAGS.TRPO.algo.as_dict())

    subsampling_rate = env.max_episode_steps // FLAGS.GAIL.trajectory_size
    if FLAGS.GAIL.reward_type == 'nn':
        expert_batch = expert_dataset.buffer()
        expert_state = np.stack([t.obs for t in expert_batch])
        loc, scale = np.mean(expert_state, axis=0,
                             keepdims=True), np.std(expert_state,
                                                    axis=0,
                                                    keepdims=True)
        del expert_batch, expert_state
        logger.info('loc = {}\nscale={}'.format(loc, scale))
        discriminator = Discriminator(dim_state,
                                      dim_action,
                                      normalizers=normalizers,
                                      subsampling_rate=subsampling_rate,
                                      loc=loc,
                                      scale=scale,
                                      **FLAGS.GAIL.discriminator.as_dict())
    else:
        raise NotImplementedError
    bc_loss = BehavioralCloningLoss(dim_state,
                                    dim_action,
                                    policy,
                                    lr=FLAGS.BC.lr,
                                    train_std=FLAGS.BC.train_std)
    tf.get_default_session().run(tf.global_variables_initializer())

    loader = nn.ModuleDict({'actor': actor})
    loader.load_state_dict(
        np.load(FLAGS.ckpt.policy_load, allow_pickle=True)[()])
    logger.info('Load policy from %s' % FLAGS.ckpt.policy_load)
    saver = nn.ModuleDict({
        'policy': policy,
        'vfn': vfn,
        'normalizers': normalizers,
        'discriminator': discriminator
    })
    print(saver)

    # update normalizers
    expert_state = np.stack([t.obs for t in expert_dataset.buffer()])
    expert_action = np.stack([t.action for t in expert_dataset.buffer()])
    expert_next_state = np.stack([t.next_obs for t in expert_dataset.buffer()])
    normalizers.state.update(expert_state)
    normalizers.action.update(expert_action)
    normalizers.diff.update(expert_next_state - expert_state)
    del expert_state, expert_action, expert_next_state

    eval_gamma = 0.999
    eval_returns, eval_lengths = evaluate_on_true_env(actor,
                                                      env,
                                                      gamma=eval_gamma)
    logger.warning(
        'Test policy true value = %.4f true length = %d (gamma = %f)',
        np.mean(eval_returns), np.mean(eval_lengths), eval_gamma)

    # pretrain
    for n_updates in range(FLAGS.GAIL.pretrain_iters):
        expert_batch = expert_dataset.sample(FLAGS.BC.batch_size)
        expert_state = np.stack([t.obs for t in expert_batch])
        expert_action = np.stack([t.action for t in expert_batch])
        expert_next_state = np.stack([t.next_obs for t in expert_batch])
        _, loss, grad_norm = bc_loss.get_loss(expert_state,
                                              expert_action,
                                              expert_next_state,
                                              fetch='train loss grad_norm')
        if n_updates % 100 == 0:
            mse_loss = policy.get_mse_loss(expert_state, expert_action,
                                           expert_next_state)
            logger.info(
                '[Pretrain] iter = %d grad_norm = %.4f loss = %.4f mse_loss = %.4f',
                n_updates, grad_norm, loss, mse_loss)

    # virtual env
    virtual_env = VirtualEnv(policy,
                             env,
                             n_envs=FLAGS.env.num_env,
                             stochastic_model=True)
    virtual_runner = VirtualRunner(virtual_env,
                                   max_steps=env.max_episode_steps,
                                   gamma=FLAGS.TRPO.gamma,
                                   lambda_=FLAGS.TRPO.lambda_,
                                   rescale_action=False)
    env_eval_stochastic = VirtualEnv(policy,
                                     env,
                                     n_envs=4,
                                     stochastic_model=True)
    env_eval_deterministic = VirtualEnv(policy,
                                        env,
                                        n_envs=4,
                                        stochastic_model=False)

    max_ent_coef = FLAGS.TRPO.algo.ent_coef
    true_return = np.mean(eval_returns)
    for t in range(0, FLAGS.GAIL.total_timesteps,
                   FLAGS.TRPO.rollout_samples * FLAGS.GAIL.g_iters):
        time_st = time.time()
        if t % FLAGS.GAIL.eval_freq == 0:
            eval_returns_stochastic, eval_lengths_stochastic = evaluate_on_virtual_env(
                actor, env_eval_stochastic, gamma=eval_gamma)
            eval_returns_deterministic, eval_lengths_deterministic = evaluate_on_virtual_env(
                actor, env_eval_deterministic, gamma=eval_gamma)
            log_kvs(
                prefix='Evaluate',
                kvs=dict(
                    iter=t,
                    stochastic_episode=dict(
                        returns=np.mean(eval_returns_stochastic),
                        lengths=int(np.mean(eval_lengths_stochastic))),
                    episode=dict(returns=np.mean(eval_returns_deterministic),
                                 lengths=int(
                                     np.mean(eval_lengths_deterministic))),
                    evaluation_error=dict(
                        stochastic_error=true_return -
                        np.mean(eval_returns_stochastic),
                        stochastic_abs=np.abs(
                            true_return - np.mean(eval_returns_stochastic)),
                        stochastic_rel=np.abs(true_return -
                                              np.mean(eval_returns_stochastic))
                        / true_return,
                        deterministic_error=true_return -
                        np.mean(eval_returns_deterministic),
                        deterministic_abs=np.abs(
                            true_return - np.mean(eval_returns_deterministic)),
                        deterministic_rel=np.abs(true_return - np.mean(
                            eval_returns_deterministic)) / true_return)))
        # Generator
        generator_dataset = None
        for n_update in range(FLAGS.GAIL.g_iters):
            data, ep_infos = virtual_runner.run(actor,
                                                FLAGS.TRPO.rollout_samples,
                                                stochastic=False)
            # if FLAGS.TRPO.normalization:
            #     normalizers.state.update(data.state)
            #     normalizers.action.update(data.action)
            #     normalizers.diff.update(data.next_state - data.state)
            if t == 0:
                np.testing.assert_allclose(data.reward,
                                           env.mb_step(data.state, data.action,
                                                       data.next_state)[0],
                                           atol=1e-4,
                                           rtol=1e-4)
            if t == 0 and n_update == 0 and not FLAGS.GAIL.learn_absorbing:
                data_ = data.copy()
                data_ = data_.reshape(
                    [FLAGS.TRPO.rollout_samples // env.n_envs, env.n_envs])
                for e in range(env.n_envs):
                    samples = data_[:, e]
                    masks = 1 - (samples.done | samples.timeout)[...,
                                                                 np.newaxis]
                    masks = masks[:-1]
                    assert np.allclose(samples.state[1:] * masks,
                                       samples.next_state[:-1] * masks)
            t += FLAGS.TRPO.rollout_samples
            data.reward = discriminator.get_reward(data.state, data.action,
                                                   data.next_state)
            advantages, values = virtual_runner.compute_advantage(vfn, data)
            train_info = algo.train(max_ent_coef, data, advantages, values)
            fps = int(FLAGS.TRPO.rollout_samples / (time.time() - time_st))
            train_info['reward'] = np.mean(data.reward)
            train_info['fps'] = fps

            expert_batch = expert_dataset.sample(256)
            expert_state = np.stack([t.obs for t in expert_batch])
            expert_action = np.stack([t.action for t in expert_batch])
            expert_next_state = np.stack([t.next_obs for t in expert_batch])
            train_mse_loss = policy.get_mse_loss(expert_state, expert_action,
                                                 expert_next_state)
            eval_mse_loss = policy.get_mse_loss(eval_state, eval_action,
                                                eval_next_state)
            train_info['mse_loss'] = dict(train=train_mse_loss,
                                          eval=eval_mse_loss)
            log_kvs(prefix='TRPO', kvs=dict(iter=t, **train_info))

            generator_dataset = data

        # Discriminator
        for n_update in range(FLAGS.GAIL.d_iters):
            batch_size = FLAGS.GAIL.d_batch_size
            d_train_infos = dict()
            for generator_subset in generator_dataset.iterator(batch_size):
                expert_batch = expert_dataset.sample(batch_size)
                expert_state = np.stack([t.obs for t in expert_batch])
                expert_action = np.stack([t.action for t in expert_batch])
                expert_next_state = np.stack(
                    [t.next_obs for t in expert_batch])
                expert_mask = None
                train_info = discriminator.train(
                    expert_state,
                    expert_action,
                    expert_next_state,
                    generator_subset.state,
                    generator_subset.action,
                    generator_subset.next_state,
                    expert_mask,
                )
                for k, v in train_info.items():
                    if k not in d_train_infos:
                        d_train_infos[k] = []
                    d_train_infos[k].append(v)
            d_train_infos = {k: np.mean(v) for k, v in d_train_infos.items()}
            if n_update == FLAGS.GAIL.d_iters - 1:
                log_kvs(prefix='Discriminator',
                        kvs=dict(iter=t, **d_train_infos))

        if t % FLAGS.TRPO.save_freq == 0:
            np.save('{}/stage-{}'.format(FLAGS.log_dir, t), saver.state_dict())
            np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())
    np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())

    dict_result = dict()
    for gamma in [0.9, 0.99, 0.999, 1.0]:
        eval_returns, eval_lengths = evaluate_on_virtual_env(
            actor, env_eval_stochastic, gamma=gamma)
        dict_result[gamma] = [float(np.mean(eval_returns)), eval_returns]
        logger.info('[%s]: %.4f', gamma, np.mean(eval_returns))

    save_path = os.path.join(FLAGS.log_dir, 'evaluate.yml')
    yaml.dump(dict_result, open(save_path, 'w'), default_flow_style=False)