Exemplo n.º 1
0
class WorkerData(Worker):
    """Worker that samples the environment and broadcasts the processed data.

    Each step draws paths through ``env_sampler``, runs them through
    ``dynamics_sample_processor`` and pushes the result to every handle in
    ``data_buffers``. Policy parameters are fetched from ``policy_ps`` before
    each wrapped step.

    ``verbose``, ``step_counter``, ``n_itr`` and ``stop_cond`` are read here
    but never assigned in this class; presumably the ``Worker`` base class
    provides them (not visible in this file section).
    """

    def __init__(self, policy_ps, data_buffers, time_sleep, name, exp_dir,
                 n_itr, stop_cond):
        """Store the remote handles and the pacing interval.

        Args:
            policy_ps: handle exposing ``pull.remote()``; see ``pull``.
            data_buffers: iterable of handles exposing ``push.remote(...)``.
            time_sleep: target duration of one step, in seconds; ``step``
                sleeps for whatever part of it sampling did not consume.
            name, exp_dir, n_itr, stop_cond: forwarded unchanged to
                ``Worker.__init__``.
        """
        super().__init__(name, exp_dir, n_itr, stop_cond)
        self.policy_ps = policy_ps
        self.data_buffers = data_buffers
        self.time_sleep = time_sleep
        # Populated later by prepare_start().
        self.env = self.env_sampler = self.dynamics_sample_processor = None

    def prepare_start(self, policy_pickle, env_pickle, baseline_pickle,
                      feed_dict, config, initial_random_samples):
        """Build the sampler and processor, then take and push a first batch.

        Args:
            policy_pickle, env_pickle, baseline_pickle: pickled objects that
                are restored with ``pickle.loads``.
                NOTE(review): assumes these payloads are trusted.
            feed_dict: mapping with keyword arguments under the keys
                ``'env_sampler'`` and ``'dynamics_sample_processor'``.
            config: passed to ``tf.Session(config=...)``.
            initial_random_samples: forwarded as ``random`` to the first
                ``step`` call.

        Returns:
            The constant ``1``.
        """
        import tensorflow as tf
        session = tf.Session(config=config)
        self.sess = session
        with session.as_default():
            # ---- Construct instances ----
            from asynch_mb.samplers.sampler import Sampler
            from asynch_mb.samplers.mb_sample_processor import ModelSampleProcessor

            environment = pickle.loads(env_pickle)
            sampling_policy = pickle.loads(policy_pickle)
            value_baseline = pickle.loads(baseline_pickle)
            session.run(tf.initializers.global_variables())

            self.env = environment
            sampler_kwargs = feed_dict['env_sampler']
            self.env_sampler = Sampler(
                env=environment, policy=sampling_policy, **sampler_kwargs)
            processor_kwargs = feed_dict['dynamics_sample_processor']
            self.dynamics_sample_processor = ModelSampleProcessor(
                baseline=value_baseline, **processor_kwargs)

            # ---- Step and push ----
            first_batch = self.step(random=initial_random_samples)
            self.push(first_batch)

        return 1

    def step_wrapper(self):
        """Pull fresh policy parameters, sample once and push the result.

        Returns:
            The constant tuple ``(1, 1)``.
        """
        self.pull()
        self.push(self.step())
        return 1, 1

    def step(self, random=False):
        """Collect and process one batch of environment samples.

        The call is padded with ``time.sleep`` so that it lasts at least
        ``self.time_sleep`` seconds.

        Args:
            random: forwarded to ``env_sampler.obtain_samples``.

        Returns:
            Whatever ``dynamics_sample_processor.process_samples`` returns.
        """
        started_at = time.time()

        # ---- Obtaining samples from the environment ----
        if self.verbose:
            logger.log("Data is obtaining samples...")
        env_paths = self.env_sampler.obtain_samples(
            log=True, random=random, log_prefix='Data-EnvSampler-')

        # ---- Processing environment samples ----
        if self.verbose:
            logger.log("Data is processing environment samples...")
        samples_data = self.dynamics_sample_processor.process_samples(
            env_paths, log=True, log_prefix='Data-EnvTrajs-')

        elapsed = time.time() - started_at

        # Sleep only for the part of the budget that sampling left unused.
        remaining = max(self.time_sleep - elapsed, 0)
        time.sleep(remaining)

        logger.logkv('Data-TimeStep', elapsed)
        logger.logkv('Data-TimeSleep', remaining)

        return samples_data

    def pull(self):
        """Fetch policy parameters from ``policy_ps`` and load them into the policy."""
        started_at = time.time()
        params_ref = self.policy_ps.pull.remote()
        policy_params = ray.get(params_ref)
        assert isinstance(policy_params, dict)
        self.env_sampler.policy.set_shared_params(policy_params)
        logger.logkv('Data-TimePull', time.time() - started_at)

    def push(self, samples_data):
        """Send ``samples_data`` to every data buffer.

        The data is stored once with ``ray.put`` and the resulting reference
        is what each buffer receives. The values returned by
        ``push.remote`` are discarded, so this method does not wait for the
        buffers.
        """
        started_at = time.time()
        shared_ref = ray.put(samples_data)
        for buffer_handle in self.data_buffers:
            buffer_handle.push.remote(shared_ref)
        logger.logkv('Data-TimePush', time.time() - started_at)

    def set_stop_cond(self):
        """Set the shared stop condition once ``step_counter`` reaches ``n_itr``."""
        budget_exhausted = self.step_counter >= self.n_itr
        if budget_exhausted:
            ray.get(self.stop_cond.set.remote())
Exemplo n.º 2
0
class WorkerData(Worker):
    """Worker that samples the environment and hands batches to the next queue.

    Processed batches accumulate in ``samples_data_arr`` until ``push``
    pickles the whole list onto ``queue_next`` and starts a new list. Every
    step also saves an iteration snapshot through ``logger``.

    ``queue``, ``queue_next``, ``verbose``, ``itr_counter``, ``n_itr`` and
    ``stop_cond`` are read here but never assigned in this class; presumably
    the ``Worker`` base class provides them (not visible in this file
    section).
    """

    def __init__(self, simulation_sleep, video=False):
        """Initialise the worker.

        Args:
            simulation_sleep: target duration of one step, in seconds.
            video: when truthy, the base class is initialised with
                ``snapshot_mode='gap'`` and a gap derived from
                ``simulation_sleep``; otherwise with no arguments.
        """
        base_kwargs = {}
        if video:
            # FIXME (carried over from the original author): the gap formula
            # is flagged as provisional.
            # NOTE(review): it divides by simulation_sleep, so a value of 0
            # raises ZeroDivisionError here - confirm the intended formula.
            base_kwargs = {
                'snapshot_mode': 'gap',
                'snapshot_gap': int(30/1250/simulation_sleep),
            }
        super().__init__(**base_kwargs)
        self.simulation_sleep = simulation_sleep
        # Populated later by construct_from_feed_dict().
        self.env = self.env_sampler = self.dynamics_sample_processor = None
        self.samples_data_arr = []

    def construct_from_feed_dict(
            self,
            policy_pickle,
            env_pickle,
            baseline_pickle,
            dynamics_model_pickle,
            feed_dict
    ):
        """Restore the pickled components and build the sampler and processor.

        Args:
            policy_pickle, env_pickle, baseline_pickle: pickled objects that
                are restored with ``pickle.loads``.
                NOTE(review): assumes these payloads are trusted.
            dynamics_model_pickle: accepted but not used by this worker.
            feed_dict: mapping with keyword arguments under the keys
                ``'env_sampler'`` and ``'dynamics_sample_processor'``.
        """
        from asynch_mb.samplers.sampler import Sampler
        from asynch_mb.samplers.mb_sample_processor import ModelSampleProcessor

        environment = pickle.loads(env_pickle)
        sampling_policy = pickle.loads(policy_pickle)
        value_baseline = pickle.loads(baseline_pickle)

        self.env = environment
        sampler_kwargs = feed_dict['env_sampler']
        self.env_sampler = Sampler(
            env=environment, policy=sampling_policy, **sampler_kwargs)
        processor_kwargs = feed_dict['dynamics_sample_processor']
        self.dynamics_sample_processor = ModelSampleProcessor(
            baseline=value_baseline, **processor_kwargs)

    def prepare_start(self):
        """Take the first step using the flag read from ``queue``, then push."""
        self.step(self.queue.get())
        self.push()

    def step(self, random=False):
        """Collect one batch, buffer it, pace the loop and save a snapshot.

        The call is padded with ``time.sleep`` so that sampling plus sleeping
        lasts at least ``self.simulation_sleep`` seconds.

        Args:
            random: forwarded to ``env_sampler.obtain_samples``.
        """
        started_at = time.time()

        # ---- Obtaining samples from the environment ----
        if self.verbose:
            logger.log("Data is obtaining samples...")
        env_paths = self.env_sampler.obtain_samples(
            log=True, random=random, log_prefix='Data-EnvSampler-')

        # ---- Processing environment samples ----
        if self.verbose:
            logger.log("Data is processing environment samples...")
        batch = self.dynamics_sample_processor.process_samples(
            env_paths, log=True, log_prefix='Data-EnvTrajs-')

        self.samples_data_arr.append(batch)
        elapsed = time.time() - started_at

        # Sleep only for the part of the budget that sampling left unused.
        remaining = max(self.simulation_sleep - elapsed, 0)
        time.sleep(remaining)

        logger.logkv('Data-TimeStep', elapsed)
        logger.logkv('Data-TimeSleep', remaining)

        # Persist the current policy/env for this iteration.
        snapshot = self.get_itr_snapshot()
        logger.save_itr_params(self.itr_counter, snapshot)

    def _synch(self, policy_state_pickle):
        """Unpickle a policy state dict and load it into the sampler's policy."""
        started_at = time.time()
        policy_state = pickle.loads(policy_state_pickle)
        assert isinstance(policy_state, dict)
        self.env_sampler.policy.set_shared_params(policy_state)
        elapsed = time.time() - started_at

        logger.logkv('Data-TimeSynch', elapsed)

    def push(self):
        """Pickle all buffered batches onto ``queue_next`` and reset the buffer."""
        started_at = time.time()
        payload = pickle.dumps(self.samples_data_arr)
        self.queue_next.put(payload)
        self.samples_data_arr = []
        elapsed = time.time() - started_at

        logger.logkv('Data-TimePush', elapsed)

    def set_stop_cond(self):
        """Set the stop condition once ``itr_counter`` reaches ``n_itr``."""
        budget_exhausted = self.itr_counter >= self.n_itr
        if budget_exhausted:
            self.stop_cond.set()

    def get_itr_snapshot(self):
        """Return the iteration number, current policy and env as a dict."""
        return {
            'itr': self.itr_counter,
            'policy': self.env_sampler.policy,
            'env': self.env,
        }
Exemplo n.º 3
0
class WorkerData(Worker):
    """Worker that samples the environment with a dynamics-model-based policy.

    Processed batches accumulate in ``samples_data_arr`` until ``push``
    pickles the whole list onto ``queue_next`` and starts a new list.
    ``_synch`` refreshes the parameters of the policy's ``dynamics_model``.

    ``queue``, ``queue_next``, ``verbose``, ``itr_counter``, ``n_itr`` and
    ``stop_cond`` are read here but never assigned in this class; presumably
    the ``Worker`` base class provides them (not visible in this file
    section).
    """

    def __init__(self, simulation_sleep):
        """Initialise the worker.

        Args:
            simulation_sleep: target duration of one step, in seconds.
        """
        super().__init__()
        self.simulation_sleep = simulation_sleep
        # Populated later by construct_from_feed_dict().
        self.env = self.env_sampler = self.dynamics_sample_processor = None
        self.samples_data_arr = []

    def construct_from_feed_dict(
        self,
        policy_pickle,
        env_pickle,
        baseline_pickle,  # UNUSED
        dynamics_model_pickle,
        feed_dict,
    ):
        """Restore the pickled env and policy and build sampler and processor.

        Args:
            policy_pickle, env_pickle: pickled objects that are restored with
                ``pickle.loads``.
                NOTE(review): assumes these payloads are trusted.
            baseline_pickle: accepted but not used by this worker.
            dynamics_model_pickle: accepted but not used by this method.
            feed_dict: mapping with keyword arguments under the keys
                ``'sampler'`` and ``'sample_processor'``.
        """
        from asynch_mb.samplers.sampler import Sampler
        from asynch_mb.samplers.mb_sample_processor import ModelSampleProcessor

        environment = pickle.loads(env_pickle)
        sampling_policy = pickle.loads(policy_pickle)

        self.env = environment
        sampler_kwargs = feed_dict['sampler']
        self.env_sampler = Sampler(
            env=environment, policy=sampling_policy, **sampler_kwargs)
        processor_kwargs = feed_dict['sample_processor']
        self.dynamics_sample_processor = ModelSampleProcessor(
            **processor_kwargs)

    def prepare_start(self):
        """Take the first step using the flags read from ``queue``, then push."""
        self.step(self.queue.get())
        self.push()

    def step(self, random_sinusoid=(False, False)):
        """Collect one batch, buffer it and pace the loop.

        When ``itr_counter`` is 1 and the policy's dynamics model has no
        normalization yet, the method first reads a pickled dynamics model
        from ``queue`` and installs it on the policy.

        Args:
            random_sinusoid: indexable whose element 0 is forwarded as
                ``random`` and element 1 as ``sinusoid`` to
                ``env_sampler.obtain_samples``.
        """
        started_at = time.time()

        if self.itr_counter == 1:
            if self.env_sampler.policy.dynamics_model.normalization is None:
                if self.verbose:
                    logger.log('Data starts first step...')
                # NOTE(review): assumes the queued payload is trusted.
                model_pickle = self.queue.get()
                received_model = pickle.loads(model_pickle)
                self.env_sampler.policy.dynamics_model = received_model
                if self.verbose:
                    logger.log('Data first step done...')

        # ---- Obtaining samples from the environment ----
        if self.verbose:
            logger.log("Data is obtaining samples...")
        env_paths = self.env_sampler.obtain_samples(
            log=True,
            random=random_sinusoid[0],
            sinusoid=random_sinusoid[1],
            log_prefix='Data-EnvSampler-',
        )

        # ---- Processing environment samples ----
        if self.verbose:
            logger.log("Data is processing samples...")
        batch = self.dynamics_sample_processor.process_samples(
            env_paths, log=True, log_prefix='Data-EnvTrajs-')

        self.samples_data_arr.append(batch)
        elapsed = time.time() - started_at

        # Sleep only for the part of the budget that sampling left unused.
        remaining = max(self.simulation_sleep - elapsed, 0)
        time.sleep(remaining)

        logger.logkv('Data-TimeStep', elapsed)
        logger.logkv('Data-TimeSleep', remaining)

    def _synch(self, dynamics_model_state_pickle):
        """Unpickle a state dict and load it into the policy's dynamics model."""
        started_at = time.time()
        model_state = pickle.loads(dynamics_model_state_pickle)
        assert isinstance(model_state, dict)
        self.env_sampler.policy.dynamics_model.set_shared_params(model_state)
        elapsed = time.time() - started_at

        logger.logkv('Data-TimeSynch', elapsed)

    def push(self):
        """Pickle all buffered batches onto ``queue_next`` and reset the buffer."""
        started_at = time.time()
        payload = pickle.dumps(self.samples_data_arr)
        self.queue_next.put(payload)
        self.samples_data_arr = []
        elapsed = time.time() - started_at

        logger.logkv('Data-TimePush', elapsed)

    def set_stop_cond(self):
        """Set the stop condition once ``itr_counter`` reaches ``n_itr``."""
        budget_exhausted = self.itr_counter >= self.n_itr
        if budget_exhausted:
            self.stop_cond.set()