示例#1
0
    def __init__(self, env, steps, save_file):
        """Remember the run settings and build the sampler used for collection.

        :param env: environment handed to the sampler; its ``horizon``
            attribute is passed as the sampler's ``max_path_length``
        :param steps: stored as ``_steps`` and also passed as the sampler's
            ``replay_pool_size``
        :param save_file: stored as ``_save_file``
        """
        self._env, self._steps, self._save_file = env, steps, save_file

        # Gather the sampler settings first so the constructor call stays short.
        sampler_settings = dict(
            policy=DummyPolicy(),
            env=self._env,
            n_envs=1,
            replay_pool_size=steps,
            max_path_length=self._env.horizon,
            sampling_method='uniform',
            save_rollouts=True,
            save_rollouts_observations=True,
            save_env_infos=True,
        )
        self._sampler = Sampler(**sampler_settings)
示例#2
0
 def _create_sampler(self):
     """Build a single-environment sampler from this object's policy and env.

     :return: a new ``Sampler`` with a replay pool of ``int(1e5)`` entries and
         ``max_path_length`` taken from the environment's ``horizon``
     """
     pool_size = int(1e5)
     sampler_settings = dict(
         policy=self._policy,
         env=self._env,
         n_envs=1,
         replay_pool_size=pool_size,
         max_path_length=self._env.horizon,
         sampling_method='uniform',
         save_rollouts=True,
         save_rollouts_observations=True,
         save_env_infos=True,
     )
     return Sampler(**sampler_settings)
示例#3
0
class GatherRandomData(object):
    """Step a sampler with random actions and pickle the finished rollouts.

    Rollouts are written in numbered files derived from ``save_file``.
    """

    def __init__(self, env, steps, save_file):
        """Remember the run settings and build the sampler used for collection.

        :param env: environment handed to the sampler; its ``horizon``
            attribute is passed as the sampler's ``max_path_length``
        :param steps: number of sampler steps taken by ``run``; also passed as
            the sampler's ``replay_pool_size``
        :param save_file: base file name from which the numbered output file
            names are derived
        """
        self._env, self._steps, self._save_file = env, steps, save_file

        sampler_settings = dict(
            policy=DummyPolicy(),
            env=self._env,
            n_envs=1,
            replay_pool_size=steps,
            max_path_length=self._env.horizon,
            sampling_method='uniform',
            save_rollouts=True,
            save_rollouts_observations=True,
            save_env_infos=True,
        )
        self._sampler = Sampler(**sampler_settings)

    def _itr_save_file(self, itr):
        """Return the save-file name with ``_<itr>`` inserted before the extension.

        :param itr: integer file index, zero-padded to at least two digits
        """
        stem, suffix = os.path.splitext(self._save_file)
        return '{0}_{1:02d}{2}'.format(stem, itr, suffix)

    def run(self):
        """Take ``steps`` random-action sampler steps, dumping rollouts as it goes.

        Accumulated rollouts are written every 5000 steps (when there are
        any), and whatever is left over is written once the loop finishes.
        """
        pending = []
        file_idx = 0

        self._sampler.reset()
        taken = 0
        while taken < self._steps:
            self._sampler.step(taken, take_random_actions=True)
            taken += 1

            pending.extend(self._sampler.get_recent_paths())
            at_checkpoint = taken > 0 and taken % 5000 == 0
            if at_checkpoint and pending:
                mypickle.dump({'rollouts': pending},
                              self._itr_save_file(file_idx))
                # Start a fresh list; the dumped one is left untouched.
                pending = []
                file_idx += 1

        if pending:
            mypickle.dump({'rollouts': pending}, self._itr_save_file(file_idx))
示例#4
0
    def __init__(self, eval_itr, eval_save_dir, **kwargs):
        """Store the evaluation settings and build a single-environment sampler.

        :param eval_itr: stored as ``_eval_itr``
        :param eval_save_dir: stored as ``_eval_save_dir``
        :param kwargs: must contain ``save_dir``, ``env_eval``, ``policy``,
            ``max_path_length``, ``replay_pool_sampling``, ``env_dict`` and
            ``replay_pool_params``; ``load_dir`` is optional and defaults to
            ``save_dir``
        """
        self._eval_itr = eval_itr
        self._eval_save_dir = eval_save_dir

        save_dir = kwargs['save_dir']
        self._save_dir = save_dir
        self._load_dir = kwargs.get('load_dir', save_dir)
        env_eval = kwargs['env_eval']
        self._env_eval = env_eval
        policy = kwargs['policy']
        self._policy = policy

        max_path_length = kwargs['max_path_length']
        # Replay pool holds one and a half maximum-length paths, plus one.
        pool_size = int(np.ceil(1.5 * max_path_length) + 1)

        self._sampler = Sampler(
            policy=policy,
            env=env_eval,
            n_envs=1,
            replay_pool_size=pool_size,
            max_path_length=max_path_length,
            sampling_method=kwargs['replay_pool_sampling'],
            save_rollouts=True,
            save_rollouts_observations=True,
            save_env_infos=True,
            env_dict=kwargs['env_dict'],
            replay_pool_params=kwargs['replay_pool_params'],
        )
示例#5
0
文件: gcg.py 项目: XieKaixuan/gcg
    def __init__(self, **kwargs):
        """Configure the trainer from keyword arguments and build its samplers.

        Entries read with ``kwargs[...]`` must be present; ``load_dir``,
        ``save_eval_rollouts_observations``, ``offpolicy`` and
        ``rollouts_per_eval`` fall back to defaults when absent.
        """
        save_dir = kwargs['save_dir']
        self._save_dir = save_dir
        self._load_dir = kwargs.get('load_dir', save_dir)
        self._env = kwargs['env']
        self._policy = kwargs['policy']

        self._batch_size = kwargs['batch_size']
        self._save_rollouts = kwargs['save_rollouts']
        self._save_rollouts_observations = kwargs['save_rollouts_observations']

        # Training sampler: every setting comes from the caller.
        train_sampler_settings = dict(
            policy=kwargs['policy'],
            env=kwargs['env'],
            n_envs=kwargs['n_envs'],
            replay_pool_size=kwargs['replay_pool_size'],
            max_path_length=kwargs['max_path_length'],
            sampling_method=kwargs['replay_pool_sampling'],
            save_rollouts=kwargs['save_rollouts'],
            save_rollouts_observations=kwargs['save_rollouts_observations'],
            save_env_infos=kwargs['save_env_infos'],
            env_dict=kwargs['env_dict'],
            replay_pool_params=kwargs['replay_pool_params'],
        )
        self._sampler = Sampler(**train_sampler_settings)

        # Evaluation sampler: only built when an evaluation env is supplied.
        env_eval = kwargs['env_eval']
        if env_eval is None:
            self._eval_sampler = None
        else:
            eval_pool_size = int(np.ceil(1.5 * kwargs['max_path_length']) + 1)
            keep_eval_observations = kwargs.get(
                'save_eval_rollouts_observations', False)
            self._eval_sampler = Sampler(
                policy=kwargs['policy'],
                env=env_eval,
                n_envs=1,
                replay_pool_size=eval_pool_size,
                max_path_length=kwargs['max_path_length'],
                sampling_method=kwargs['replay_pool_sampling'],
                save_rollouts=True,
                save_rollouts_observations=keep_eval_observations,
                save_env_infos=kwargs['save_env_infos'],
                replay_pool_params=kwargs['replay_pool_params'],
            )

        offpolicy_folders = kwargs.get('offpolicy')
        if offpolicy_folders is not None:
            self._add_offpolicy(offpolicy_folders,
                                max_to_add=kwargs['num_offpolicy'])

        # Schedule settings; all are coerced to int except
        # ``train_every_n_steps``, which is stored as given.
        self._total_steps = int(kwargs['total_steps'])
        self._sample_after_n_steps = int(kwargs['sample_after_n_steps'])
        self._onpolicy_after_n_steps = int(kwargs['onpolicy_after_n_steps'])
        self._learn_after_n_steps = int(kwargs['learn_after_n_steps'])
        self._train_every_n_steps = kwargs['train_every_n_steps']
        self._eval_every_n_steps = int(kwargs['eval_every_n_steps'])
        self._rollouts_per_eval = int(kwargs.get('rollouts_per_eval', 1))
        self._save_every_n_steps = int(kwargs['save_every_n_steps'])
        self._update_target_after_n_steps = int(
            kwargs['update_target_after_n_steps'])
        self._update_target_every_n_steps = int(
            kwargs['update_target_every_n_steps'])
        self._log_every_n_steps = int(kwargs['log_every_n_steps'])

        # These periods must be multiples of the number of environments.
        n_envs = self._sampler.n_envs
        assert self._learn_after_n_steps % n_envs == 0
        if self._train_every_n_steps >= 1:
            assert int(self._train_every_n_steps) % n_envs == 0
        assert self._save_every_n_steps % n_envs == 0
        assert self._update_target_every_n_steps % n_envs == 0
示例#6
0
文件: gcg.py 项目: XieKaixuan/gcg
class GCG(object):
    """Training driver built around a policy and one or two samplers.

    It steps a training sampler, trains the policy on sampled batches,
    optionally runs evaluation rollouts with a second sampler, and
    periodically logs and saves policies and rollouts to ``save_dir``.
    """

    def __init__(self, **kwargs):
        """Read the configuration from ``kwargs`` and build the samplers.

        Entries read with ``kwargs[...]`` are required; ``load_dir``,
        ``save_eval_rollouts_observations``, ``offpolicy`` and
        ``rollouts_per_eval`` are optional (``num_offpolicy`` is required
        only when ``offpolicy`` is given).

        :raises KeyError: if a required entry is missing
        :raises AssertionError: if ``learn_after_n_steps``,
            ``train_every_n_steps`` (when >= 1), ``save_every_n_steps`` or
            ``update_target_every_n_steps`` is not a multiple of the training
            sampler's ``n_envs``
        """

        self._save_dir = kwargs['save_dir']
        self._load_dir = kwargs.get('load_dir', self._save_dir)
        self._env = kwargs['env']
        self._policy = kwargs['policy']

        self._batch_size = kwargs['batch_size']
        self._save_rollouts = kwargs['save_rollouts']
        self._save_rollouts_observations = kwargs['save_rollouts_observations']

        self._sampler = Sampler(
            policy=kwargs['policy'],
            env=kwargs['env'],
            n_envs=kwargs['n_envs'],
            replay_pool_size=kwargs['replay_pool_size'],
            max_path_length=kwargs['max_path_length'],
            sampling_method=kwargs['replay_pool_sampling'],
            save_rollouts=kwargs['save_rollouts'],
            save_rollouts_observations=kwargs['save_rollouts_observations'],
            save_env_infos=kwargs['save_env_infos'],
            env_dict=kwargs['env_dict'],
            replay_pool_params=kwargs['replay_pool_params'])

        # The evaluation sampler is optional: it exists only when an
        # evaluation environment was supplied.
        if kwargs['env_eval'] is not None:
            self._eval_sampler = Sampler(
                policy=kwargs['policy'],
                env=kwargs['env_eval'],
                n_envs=1,
                replay_pool_size=int(
                    np.ceil(1.5 * kwargs['max_path_length']) + 1),
                max_path_length=kwargs['max_path_length'],
                sampling_method=kwargs['replay_pool_sampling'],
                save_rollouts=True,
                save_rollouts_observations=kwargs.get(
                    'save_eval_rollouts_observations', False),
                save_env_infos=kwargs['save_env_infos'],
                replay_pool_params=kwargs['replay_pool_params'])
        else:
            self._eval_sampler = None

        if kwargs.get('offpolicy', None) is not None:
            self._add_offpolicy(kwargs['offpolicy'],
                                max_to_add=kwargs['num_offpolicy'])

        self._total_steps = int(kwargs['total_steps'])
        self._sample_after_n_steps = int(kwargs['sample_after_n_steps'])
        self._onpolicy_after_n_steps = int(kwargs['onpolicy_after_n_steps'])
        self._learn_after_n_steps = int(kwargs['learn_after_n_steps'])
        # Not coerced to int: values below 1 mean several train steps per
        # environment step (see ``train``).
        self._train_every_n_steps = kwargs['train_every_n_steps']
        self._eval_every_n_steps = int(kwargs['eval_every_n_steps'])
        self._rollouts_per_eval = int(kwargs.get('rollouts_per_eval', 1))
        self._save_every_n_steps = int(kwargs['save_every_n_steps'])
        self._update_target_after_n_steps = int(
            kwargs['update_target_after_n_steps'])
        self._update_target_every_n_steps = int(
            kwargs['update_target_every_n_steps'])
        self._log_every_n_steps = int(kwargs['log_every_n_steps'])

        # ``train`` advances ``step`` in increments of ``n_envs``, so a period
        # that is not a multiple of ``n_envs`` would not fire on schedule.
        n_envs = self._sampler.n_envs
        assert self._learn_after_n_steps % n_envs == 0
        if self._train_every_n_steps >= 1:
            assert int(self._train_every_n_steps) % n_envs == 0
        assert self._save_every_n_steps % n_envs == 0
        assert self._update_target_every_n_steps % n_envs == 0

    #############
    ### Files ###
    #############

    def _train_rollouts_file_name(self, itr):
        """Return the path of the train-rollouts pickle for ``itr`` in ``save_dir``."""
        return os.path.join(self._save_dir,
                            'itr_{0:04d}_train_rollouts.pkl'.format(itr))

    def _eval_rollouts_file_name(self, itr):
        """Return the path of the eval-rollouts pickle for ``itr`` in ``save_dir``."""
        return os.path.join(self._save_dir,
                            'itr_{0:04d}_eval_rollouts.pkl'.format(itr))

    def _train_policy_file_name(self, itr):
        """Return the path of the train-policy checkpoint for ``itr`` in ``save_dir``."""
        return os.path.join(self._save_dir,
                            'itr_{0:04d}_train_policy.ckpt'.format(itr))

    def _inference_policy_file_name(self, itr):
        """Return the path of the inference-policy checkpoint for ``itr`` in ``save_dir``."""
        return os.path.join(self._save_dir,
                            'itr_{0:04d}_inference_policy.ckpt'.format(itr))

    def _load_train_policy_file_name(self, itr):
        """Return the path of the train-policy checkpoint for ``itr`` in ``load_dir``."""
        return os.path.join(self._load_dir,
                            'itr_{0:04d}_train_policy.ckpt'.format(itr))

    def _load_inference_policy_file_name(self, itr):
        """Return the path of the inference-policy checkpoint for ``itr`` in ``load_dir``."""
        return os.path.join(self._load_dir,
                            'itr_{0:04d}_inference_policy.ckpt'.format(itr))

    @staticmethod
    def _checkpoint_glob(file_name):
        """Return a glob pattern matching ``file_name`` with any extension.

        NOTE(review): presumably the policy writes several files that share
        the checkpoint stem -- confirm against the policy's ``save``.
        """
        return os.path.splitext(file_name)[0] + '*'

    @staticmethod
    def _count_consecutive_itrs(pattern_for_itr):
        """Return the first iteration number whose glob pattern matches nothing.

        :param pattern_for_itr: callable mapping an iteration number to a
            glob pattern
        """
        itr = 0
        while glob.glob(pattern_for_itr(itr)):
            itr += 1
        return itr

    ############
    ### Save ###
    ############

    def _save_train_rollouts(self, itr, rollouts):
        """Pickle ``rollouts`` under the key ``'rollouts'`` as train rollouts of ``itr``."""
        fname = self._train_rollouts_file_name(itr)
        mypickle.dump({'rollouts': rollouts}, fname)

    def _save_eval_rollouts(self, itr, rollouts):
        """Pickle ``rollouts`` under the key ``'rollouts'`` as eval rollouts of ``itr``."""
        fname = self._eval_rollouts_file_name(itr)
        mypickle.dump({'rollouts': rollouts}, fname)

    def _save_train_policy(self, itr):
        """Save the policy with ``train=True`` for iteration ``itr``."""
        self._policy.save(self._train_policy_file_name(itr), train=True)

    def _save_inference_policy(self, itr):
        """Save the policy with ``train=False`` for iteration ``itr``."""
        self._policy.save(self._inference_policy_file_name(itr), train=False)

    def _save_train(self, itr):
        """Save both the train and the inference policy for iteration ``itr``."""
        self._save_train_policy(itr)
        self._save_inference_policy(itr)

    def _save_inference(self, itr, train_rollouts, eval_rollouts):
        """Save the train and eval rollouts for iteration ``itr``."""
        self._save_train_rollouts(itr, train_rollouts)
        self._save_eval_rollouts(itr, eval_rollouts)

    def _save(self, itr, train_rollouts, eval_rollouts):
        """Save policies and rollouts for iteration ``itr``."""
        self._save_train(itr)
        self._save_inference(itr, train_rollouts, eval_rollouts)

    ###############
    ### Restore ###
    ###############

    def _add_offpolicy(self, folders, max_to_add):
        """Add train rollouts found in ``folders`` to the training sampler.

        :param folders: iterable of existing directories; every file whose
            name contains ``'train_rollouts.pkl'`` is handed to the sampler
        :param max_to_add: forwarded to the sampler's ``add_rollouts``
        :raises AssertionError: if a folder does not exist
        """
        for folder in folders:
            assert (os.path.exists(folder))
            logger.info('Loading offpolicy data from {0}'.format(folder))
            rollout_filenames = [
                os.path.join(folder, fname) for fname in os.listdir(folder)
                if 'train_rollouts.pkl' in fname
            ]
            self._sampler.add_rollouts(rollout_filenames,
                                       max_to_add=max_to_add)
        logger.info('Added {0} samples'.format(len(self._sampler)))

    def _get_train_itr(self):
        """Return the number of consecutive inference-policy checkpoints in ``save_dir``."""
        return self._count_consecutive_itrs(
            lambda itr: self._checkpoint_glob(
                self._inference_policy_file_name(itr)))

    def _get_inference_itr(self):
        """Return the number of consecutive train-rollout files in ``save_dir``."""
        return self._count_consecutive_itrs(
            lambda itr: self._train_rollouts_file_name(itr) + '*')

    def _restore_train_rollouts(self):
        """Add every consecutive saved train-rollouts file to the training sampler.

        Files are collected from iteration 0 upwards until one is missing.
        Returns nothing.
        """
        itr = 0
        rollout_filenames = []
        while True:
            fname = self._train_rollouts_file_name(itr)
            if not os.path.exists(fname):
                break

            rollout_filenames.append(fname)
            itr += 1

        logger.info(
            'Restoring {0} iterations of train rollouts....'.format(itr))
        self._sampler.add_rollouts(rollout_filenames)
        logger.info('Done restoring rollouts!')

    def _restore_train_policy(self):
        """Restore the latest train policy found in ``load_dir``, if any.

        Returns nothing; when no checkpoint exists the policy is untouched.
        """
        itr = self._count_consecutive_itrs(
            lambda i: self._checkpoint_glob(
                self._load_train_policy_file_name(i)))

        if itr > 0:
            logger.info(
                'Loading train policy from {0} iteration {1}...'.format(
                    self._load_dir, itr - 1))
            self._policy.restore(self._load_train_policy_file_name(itr - 1),
                                 train=True)
            logger.info('Loaded train policy!')

    def _restore_inference_policy(self):
        """Restore the latest inference policy found in ``load_dir``, if any.

        Returns nothing; when no checkpoint exists the policy is untouched.
        """
        itr = self._count_consecutive_itrs(
            lambda i: self._checkpoint_glob(
                self._load_inference_policy_file_name(i)))

        if itr > 0:
            logger.info(
                'Loading inference policy from iteration {0}...'.format(itr -
                                                                        1))
            self._policy.restore(
                self._load_inference_policy_file_name(itr - 1), train=False)
            logger.info('Loaded inference policy!')

    def _restore(self):
        """Restore rollouts and the train policy, then return the next save index.

        :return: number of iterations already saved in ``save_dir``
        :raises AssertionError: if the number of saved policies differs from
            the number of saved train-rollout files
        """
        self._restore_train_rollouts()
        self._restore_train_policy()

        train_itr = self._get_train_itr()
        inference_itr = self._get_inference_itr()
        # The message must follow the condition after a comma: wrapping both
        # in parentheses asserts a non-empty tuple, which is always true.
        assert train_itr == inference_itr, \
            'Train itr is {0} but inference itr is {1}'.format(
                train_itr, inference_itr)
        return train_itr

    ########################
    ### Training methods ###
    ########################

    def _collect_eval_rollouts(self, step):
        """Run the evaluation sampler until ``rollouts_per_eval`` batches of paths finish.

        :param step: current training step, used as the first step index
            given to the evaluation sampler for each batch
        :return: list of the collected evaluation rollouts
        """
        collected = []
        for _ in range(self._rollouts_per_eval):
            recent_paths = []
            eval_step = step
            # Keep stepping until the sampler reports finished paths.
            while len(recent_paths) == 0:
                self._eval_sampler.step(eval_step, explore=False)
                recent_paths = self._eval_sampler.get_recent_paths()
                eval_step += 1
            collected += recent_paths
        return collected

    def _train_on_batch(self, step, use_target):
        """Sample one batch from the training sampler and train the policy on it."""
        timeit.start('batch')
        steps, observations, goals, actions, rewards, dones, _ = \
            self._sampler.sample(self._batch_size)
        timeit.stop('batch')
        timeit.start('train')
        self._policy.train_step(step,
                                steps=steps,
                                observations=observations,
                                goals=goals,
                                actions=actions,
                                rewards=rewards,
                                dones=dones,
                                use_target=use_target)
        timeit.stop('train')

    def _log_progress(self, step):
        """Write the tabular log for ``step`` and the timing report, then restart timing."""
        logger.record_tabular('Step', step)
        self._sampler.log()
        # The evaluation sampler is None when no evaluation env was given.
        if self._eval_sampler is not None:
            self._eval_sampler.log(prefix='Eval')
        self._policy.log()
        logger.dump_tabular(print_func=logger.info)
        timeit.stop('total')
        for line in str(timeit).split('\n'):
            logger.debug(line)
        # Restart the timers so the next report covers only the new interval.
        timeit.reset()
        timeit.start('total')

    def train(self):
        """Run the training loop for ``total_steps`` environment steps.

        Restores any previous state first, then on each iteration optionally
        samples, evaluates, trains, updates the target, logs and saves
        according to the configured schedule. A final save is made after the
        loop ends.
        """
        ### restore where we left off
        save_itr = self._restore()

        target_updated = False
        eval_rollouts = []

        self._sampler.reset()
        if self._eval_sampler is not None:
            self._eval_sampler.reset()

        timeit.reset()
        timeit.start('total')
        for step in range(0, self._total_steps, self._sampler.n_envs):
            ### sample and add to buffer
            if step > self._sample_after_n_steps:
                timeit.start('sample')
                self._sampler.step(
                    step,
                    take_random_actions=(step <= self._onpolicy_after_n_steps),
                    explore=True)
                timeit.stop('sample')

            ### sample and DON'T add to buffer (for validation)
            if (self._eval_sampler is not None and step > 0
                    and step % self._eval_every_n_steps == 0):
                timeit.start('eval')
                eval_rollouts += self._collect_eval_rollouts(step)
                timeit.stop('eval')

            if step >= self._learn_after_n_steps:
                ### training step
                if self._train_every_n_steps >= 1:
                    # One update every ``train_every_n_steps`` steps.
                    on_schedule = step % int(self._train_every_n_steps) == 0
                    num_updates = 1 if on_schedule else 0
                else:
                    # A fractional period means several updates per step.
                    num_updates = int(1. / self._train_every_n_steps)
                for _ in range(num_updates):
                    self._train_on_batch(step, use_target=target_updated)

                ### update target network
                if (step > self._update_target_after_n_steps
                        and step % self._update_target_every_n_steps == 0):
                    self._policy.update_target()
                    target_updated = True

                ### log
                if step % self._log_every_n_steps == 0:
                    self._log_progress(step)

            ### save model
            if step > 0 and step % self._save_every_n_steps == 0:
                logger.info('Saving files for itr {0}'.format(save_itr))
                self._save(save_itr, self._sampler.get_recent_paths(),
                           eval_rollouts)
                save_itr += 1
                eval_rollouts = []

        self._save(save_itr, self._sampler.get_recent_paths(), eval_rollouts)
示例#7
0
class EvalGCG(GCG):
    """Run a saved inference policy and append the rollouts to a pickle file.

    Rollouts for iteration ``eval_itr`` are stored in ``eval_save_dir``.
    """

    def __init__(self, eval_itr, eval_save_dir, **kwargs):
        """Store the evaluation settings and build a single-environment sampler.

        The base-class constructor is not called; only the attributes set
        below exist on the instance.

        :param eval_itr: iteration whose inference policy is evaluated
        :param eval_save_dir: directory for the evaluation rollouts file
        :param kwargs: must contain ``save_dir``, ``env_eval``, ``policy``,
            ``max_path_length``, ``replay_pool_sampling``, ``env_dict`` and
            ``replay_pool_params``; ``load_dir`` is optional and defaults to
            ``save_dir``
        """
        self._eval_itr = eval_itr
        self._eval_save_dir = eval_save_dir

        self._save_dir = kwargs['save_dir']
        self._load_dir = kwargs.get('load_dir', self._save_dir)
        self._env_eval = kwargs['env_eval']
        self._policy = kwargs['policy']

        self._sampler = Sampler(
            policy=kwargs['policy'],
            env=kwargs['env_eval'],
            n_envs=1,
            replay_pool_size=int(np.ceil(1.5 * kwargs['max_path_length']) + 1),
            max_path_length=kwargs['max_path_length'],
            sampling_method=kwargs['replay_pool_sampling'],
            save_rollouts=True,
            save_rollouts_observations=True,
            save_env_infos=True,
            env_dict=kwargs['env_dict'],
            replay_pool_params=kwargs['replay_pool_params']
        )

    #############
    ### Files ###
    #############

    def _eval_rollouts_file_name(self, itr):
        """Return the path of the eval-rollouts pickle for ``itr`` in ``eval_save_dir``."""
        return os.path.join(self._eval_save_dir, 'itr_{0:04d}_eval_rollouts.pkl'.format(itr))

    ############
    ### Save ###
    ############

    def _save_eval_rollouts(self, rollouts):
        """Pickle ``rollouts`` under the key ``'rollouts'`` for ``eval_itr``.

        Unlike the base-class method, this takes no iteration argument; the
        file is always the one for ``self._eval_itr``.
        """
        fname = self._eval_rollouts_file_name(self._eval_itr)
        mypickle.dump({'rollouts': rollouts}, fname)

    ###############
    ### Restore ###
    ###############

    def _load_eval_rollouts(self, itr):
        """Return the rollouts previously saved for ``itr``, or ``[]`` if none exist."""
        fname = self._eval_rollouts_file_name(itr)
        if os.path.exists(fname):
            rollouts = mypickle.load(fname)['rollouts']
        else:
            rollouts = []
        return rollouts

    ############
    ### Eval ###
    ############

    def _eval_reset(self, **kwargs):
        """Reset the sampler, forwarding any keyword arguments to it."""
        self._sampler.reset(**kwargs)

    def _eval_step(self):
        """Advance the sampler one step without random actions or exploration."""
        self._sampler.step(step=0,
                           take_random_actions=False,
                           explore=False)

    def _eval_save(self, rollouts, new_rollouts):
        """Append ``new_rollouts`` to ``rollouts`` and save the combined list.

        ``rollouts`` is extended in place and also returned.

        :raises AssertionError: if ``new_rollouts`` is empty
        """
        assert (len(new_rollouts) > 0)

        logger.info('Saving rollouts')
        rollouts += new_rollouts
        self._save_eval_rollouts(rollouts)

        return rollouts

    def eval(self):
        """Restore the inference policy for ``eval_itr`` and collect rollouts.

        Exits the process with a non-zero status when no checkpoint for
        ``eval_itr`` exists. Otherwise loads any previously saved evaluation
        rollouts and then steps the sampler in a loop, saving the accumulated
        rollouts whenever new paths finish.
        """
        ### Load policy
        policy_fname = self._load_inference_policy_file_name(self._eval_itr)
        if len(glob.glob(os.path.splitext(policy_fname)[0] + '*')) == 0:
            logger.error('Policy for {0} does not exist'.format(policy_fname))
            # A missing policy is a failure, so report it through the exit
            # status instead of exiting with success.
            sys.exit(1)
        logger.info('Restoring policy for itr {0}'.format(self._eval_itr))
        self._policy.restore(policy_fname, train=False)

        ### Load previous eval rollouts
        logger.info('Loading previous eval rollouts')
        rollouts = self._load_eval_rollouts(self._eval_itr)
        logger.info('Loaded {0} rollouts'.format(len(rollouts)))

        self._eval_reset()

        logger.info('')
        logger.info('Rollout {0}'.format(len(rollouts)))
        while True:
            self._eval_step()

            new_rollouts = self._sampler.get_recent_paths()
            if len(new_rollouts) > 0:
                rollouts = self._eval_save(rollouts, new_rollouts)

                logger.info('')
                logger.info('Rollout {0}'.format(len(rollouts)))