Code example #1
    def step(self, random=False):
        time_step = time.time()
        '''------------- Obtaining samples from the environment -----------'''

        if self.verbose:
            logger.log("Data is obtaining samples...")
        env_paths = self.env_sampler.obtain_samples(
            log=True,
            random=random,
            log_prefix='Data-EnvSampler-',
        )
        '''-------------- Processing environment samples -------------------'''

        if self.verbose:
            logger.log("Data is processing environment samples...")
        samples_data = self.dynamics_sample_processor.process_samples(
            env_paths,
            log=True,
            log_prefix='Data-EnvTrajs-',
        )

        time_step = time.time() - time_step

        time_sleep = max(self.time_sleep - time_step, 0)
        time.sleep(time_sleep)

        logger.logkv('Data-TimeStep', time_step)
        logger.logkv('Data-TimeSleep', time_sleep)

        return samples_data
Code example #2
    def pull(self, check_init=False):
        time_synch = time.time()
        samples_data_arr = ray.get(self.data_buffer.pull.remote())
        if check_init or not self.remaining_model_idx:
            # block wait until some data comes
            time_wait = time.time()
            while not samples_data_arr:
                samples_data_arr = ray.get(self.data_buffer.pull.remote())
            logger.logkv('Model-TimeBlockWait', time.time() - time_wait)
        if samples_data_arr:
            obs = np.concatenate([samples_data['observations'] for samples_data in samples_data_arr])
            act = np.concatenate([samples_data['actions'] for samples_data in samples_data_arr])
            obs_next = np.concatenate([samples_data['next_observations'] for samples_data in samples_data_arr])
            self.dynamics_model.update_buffer(
                obs=obs,
                act=act,
                obs_next=obs_next,
                check_init=check_init,
            )

            # Reset variables for early stopping condition
            self.with_new_data = True
            self.remaining_model_idx = list(range(self.dynamics_model.num_models))
            self.valid_loss_rolling_average = None
            logger.logkv('Model-TimePull', time.time() - time_synch)

        return len(samples_data_arr)
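
The pull() above polls a remote Ray data buffer with ray.get(self.data_buffer.pull.remote()). Below is a minimal sketch, not taken from the repository, of what such a buffer actor could look like: push() accumulates sample dicts from the data worker, and pull() hands them over and clears the buffer so each batch is trained on exactly once.

import ray

@ray.remote
class DataBuffer:
    """Hypothetical data-buffer actor compatible with the pull() shown above."""

    def __init__(self):
        self._samples = []  # list of samples_data dicts

    def push(self, samples_data):
        # Called remotely by the data worker after each environment step.
        self._samples.append(samples_data)

    def pull(self):
        # Hand over everything accumulated so far and reset the buffer.
        out, self._samples = self._samples, []
        return out

# Hypothetical usage: data_buffer = DataBuffer.remote(); the model worker then
# calls ray.get(data_buffer.pull.remote()) exactly as in the snippet above.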
Code example #3
    def step(self):
        time_step = time.time()
        """ -------------------- Sampling --------------------------"""

        if self.verbose:
            logger.log("Policy is obtaining samples ...")
        paths = self.model_sampler.obtain_samples(log=True,
                                                  log_prefix='Policy-')
        """ ----------------- Processing Samples ---------------------"""

        if self.verbose:
            logger.log("Policy is processing samples ...")
        samples_data = self.model_sample_processor.process_samples(
            paths, log='all', log_prefix='Policy-')

        if type(paths) is list:
            self.log_diagnostics(paths, prefix='Policy-')
        else:
            self.log_diagnostics(sum(paths.values(), []), prefix='Policy-')
        """ ------------------ Policy Update ---------------------"""

        if self.verbose:
            logger.log("Policy optimization...")
        # This needs to take all samples_data so that it can construct graph for meta-optimization.
        self.algo.optimize_policy(samples_data,
                                  log=True,
                                  verbose=False,
                                  prefix='Policy-')

        self.policy = self.model_sampler.policy

        logger.logkv('Policy-TimeStep', time.time() - time_step)
Code example #4
    def log_diagnostics(self, paths, prefix=''):
        """
        Log extra information per iteration based on the collected paths
        """
        log_stds = np.vstack(
            [path["agent_infos"]["log_std"] for path in paths])
        logger.logkv(prefix + 'AveragePolicyStd', np.mean(np.exp(log_stds)))
Code example #5
File: worker_data.py (Project: zzyunzhi/asynch-mb)
    def push(self):
        time_push = time.time()
        self.queue_next.put(pickle.dumps(self.samples_data_arr))
        self.samples_data_arr = []
        time_push = time.time() - time_push

        logger.logkv('Data-TimePush', time_push)
Code example #6
    def step(self, random=False):
        time_step = time.time()

        '''------------- Obtaining samples from the environment -----------'''

        if self.verbose:
            logger.log("Data is obtaining samples...")
        env_paths = self.env_sampler.obtain_samples(
            log=True,
            random=random,
            log_prefix='Data-EnvSampler-',
        )

        '''-------------- Processing environment samples -------------------'''

        if self.verbose:
            logger.log("Data is processing environment samples...")
        samples_data = self.dynamics_sample_processor.process_samples(
            env_paths,
            log=True,
            log_prefix='Data-EnvTrajs-',
        )

        self.samples_data_arr.append(samples_data)
        time_step = time.time() - time_step

        time_sleep = max(self.simulation_sleep - time_step, 0)
        time.sleep(time_sleep)

        logger.logkv('Data-TimeStep', time_step)
        logger.logkv('Data-TimeSleep', time_sleep)

        # save snapshot
        params = self.get_itr_snapshot()
        logger.save_itr_params(self.itr_counter, params)
Code example #7
File: ppo.py (Project: zzyunzhi/asynch-mb)
    def optimize_policy(self,
                        samples_data,
                        log=True,
                        prefix='',
                        verbose=False):
        """
        Performs MAML outer step

        Args:
            samples_data (list) : list of lists of lists of samples (each is a dict) split by gradient update and
             meta task
            log (bool) : whether to log statistics

        Returns:
            None
        """
        input_dict = self._extract_input_dict(samples_data,
                                              self._optimization_keys,
                                              prefix='train')

        if verbose: logger.log("Optimizing")
        loss_before = self.optimizer.optimize(input_val_dict=input_dict)

        if verbose: logger.log("Computing statistics")
        loss_after = self.optimizer.loss(input_val_dict=input_dict)

        if log:
            logger.logkv(prefix + 'LossBefore', loss_before)
            logger.logkv(prefix + 'LossAfter', loss_after)
Code example #8
File: worker_model.py (Project: zzyunzhi/asynch-mb)
    def _synch(self, samples_data_arr, check_init=False):
        time_synch = time.time()
        if self.verbose:
            logger.log('Model at {} is synchronizing...'.format(
                self.itr_counter))
        obs = np.concatenate([
            samples_data['observations'] for samples_data in samples_data_arr
        ])
        act = np.concatenate(
            [samples_data['actions'] for samples_data in samples_data_arr])
        obs_next = np.concatenate([
            samples_data['next_observations']
            for samples_data in samples_data_arr
        ])
        self.dynamics_model.update_buffer(
            obs=obs,
            act=act,
            obs_next=obs_next,
            check_init=check_init,
        )

        # Reset variables for early stopping condition
        logger.logkv('Model-AvgEpochs',
                     self.sum_model_itr / self.dynamics_model.num_models)
        self.sum_model_itr = 0
        self.with_new_data = True
        self.remaining_model_idx = list(range(self.dynamics_model.num_models))
        self.valid_loss_rolling_average = None
        time_synch = time.time() - time_synch

        logger.logkv('Model-TimeSynch', time_synch)
Code example #9
File: worker_data.py (Project: zzyunzhi/asynch-mb)
    def _synch(self, policy_state_pickle):
        time_synch = time.time()
        policy_state = pickle.loads(policy_state_pickle)
        assert isinstance(policy_state, dict)
        self.env_sampler.policy.set_shared_params(policy_state)
        time_synch = time.time() - time_synch

        logger.logkv('Data-TimeSynch', time_synch)
Code example #10
    def push(self, samples_data):
        time_push = time.time()
        # broadcast samples to all data buffers
        samples_data_id = ray.put(samples_data)
        for data_buffer in self.data_buffers:
            # ray.get(data_buffer.push.remote(samples_data))
            data_buffer.push.remote(samples_data_id)
        logger.logkv('Data-TimePush', time.time() - time_push)
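
The key detail in example #10 is that ray.put stores the samples in the shared object store once, and only the resulting object reference is sent to each buffer, so the payload is serialized a single time no matter how many consumers there are. A small illustrative sketch of the same idea (buffers is a hypothetical list of actor handles; this is not repository code):

import ray

def broadcast_samples(buffers, samples_data):
    """Illustrative: serialize samples_data once and share the reference."""
    samples_id = ray.put(samples_data)                       # one copy in the object store
    refs = [buf.push.remote(samples_id) for buf in buffers]  # every actor gets the same reference
    # Passing samples_data itself to each .remote() call would serialize it once per actor.
    # Fire-and-forget is enough here; ray.get(refs) would block until all buffers
    # have stored the update (cf. the commented-out line in example #10).
    return refs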
Code example #11
    def pull(self):
        time_synch = time.time()
        if self.verbose:
            logger.log('Policy is synchronizing...')
        model_params = ray.get(self.model_ps.pull.remote())
        assert isinstance(model_params, dict)
        self.model_sampler.dynamics_model.set_shared_params(model_params)
        if hasattr(self.model_sampler, 'vec_env'):
            self.model_sampler.vec_env.dynamics_model.set_shared_params(
                model_params)
        logger.logkv('Policy-TimePull', time.time() - time_synch)
Code example #12
    def _synch(self, dynamics_model_state_pickle):
        time_synch = time.time()
        if self.verbose:
            logger.log('Policy is synchronizing...')
        dynamics_model_state = pickle.loads(dynamics_model_state_pickle)
        assert isinstance(dynamics_model_state, dict)
        self.model_sampler.dynamics_model.set_shared_params(
            dynamics_model_state)
        if hasattr(self.model_sampler, 'vec_env'):
            self.model_sampler.vec_env.dynamics_model.set_shared_params(
                dynamics_model_state)
        time_synch = time.time() - time_synch

        logger.logkv('Policy-TimeSynch', time_synch)
Code example #13
File: worker_model.py (Project: zzyunzhi/asynch-mb)
    def push(self):
        time_push = time.time()
        state_pickle = pickle.dumps(
            self.dynamics_model.get_shared_param_values())
        assert state_pickle is not None
        while self.queue_next.qsize() > 5:
            try:
                logger.log('Model is offloading data from queue_next...')
                _ = self.queue_next.get_nowait()
            except Empty:
                break
        self.queue_next.put(state_pickle)
        time_push = time.time() - time_push

        logger.logkv('Model-TimePush', time_push)
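
Example #13 (and example #14 below, which does the same for policy parameters) implements a "keep only the freshest snapshot" policy on an inter-process queue: before putting new pickled weights, the sender drains the backlog so stale snapshots never pile up when the consumer is slow. A standalone sketch of the same pattern with standard-library primitives (hypothetical names, not repository code):

from multiprocessing import Queue
from queue import Empty
import pickle

def publish_latest(q: Queue, params, max_backlog: int = 5):
    """Put pickled params, discarding stale entries if the reader lags behind."""
    payload = pickle.dumps(params)
    while q.qsize() > max_backlog:      # qsize() is approximate, but sufficient here
        try:
            q.get_nowait()              # drop an old snapshot
        except Empty:
            break                       # raced with the consumer; queue is empty now
    q.put(payload)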
Code example #14
    def push(self):
        time_push = time.time()
        policy_state_pickle = pickle.dumps(
            self.policy.get_shared_param_values())
        assert policy_state_pickle is not None
        while self.queue_next.qsize() > 5:
            try:
                logger.log('Policy is offloading data from queue_next...')
                _ = self.queue_next.get_nowait()
            except Empty:
                # very rare chance to reach here
                break
        self.queue_next.put(policy_state_pickle)
        time_push = time.time() - time_push

        logger.logkv('Policy-TimePush', time_push)
Code example #15
    def train(self):
        """
        Trains policy on env using algo
        """

        time_total = time.time()
        ''' --------------- worker looping --------------- '''

        futures = [worker.start.remote() for worker in self.workers]

        logger.log('Start looping...')
        ray.get(futures)

        logger.logkv('Trainer-TimeTotal', time.time() - time_total)
        logger.dumpkvs()
        logger.log('***** Training finished ******')
Code example #16
    def step(self, obs=None, act=None, obs_next=None):
        time_model_fit = time.time()
        """ --------------- fit dynamics model --------------- """

        if self.verbose:
            logger.log('Model at iteration {} is training for one epoch...'.format(self.step_counter))
        self.remaining_model_idx, self.valid_loss_rolling_average = self.dynamics_model.fit_one_epoch(
            remaining_model_idx=self.remaining_model_idx,
            valid_loss_rolling_average_prev=self.valid_loss_rolling_average,
            with_new_data=self.with_new_data,
            verbose=self.verbose,
            log_tabular=True,
            prefix='Model-',
        )
        self.with_new_data = False

        logger.logkv('Model-TimeStep', time.time() - time_model_fit)
Code example #17
File: worker_data.py (Project: zzyunzhi/asynch-mb)
    def step(self, random=False):
        time_step = time.time()
        '''------------- Obtaining samples from the environment -----------'''

        if self.verbose:
            logger.log("Data is obtaining samples...")
        env_paths = self.env_sampler.obtain_samples(
            log=True,
            random=random,
            log_prefix='Data-EnvSampler-',
        )
        '''-------------- Processing environment samples -------------------'''

        if self.verbose:
            logger.log("Data is processing samples...")
        if type(env_paths) is dict or type(env_paths) is OrderedDict:
            env_paths = list(env_paths.values())
            idxs = np.random.choice(range(len(env_paths)),
                                    size=self.num_rollouts_per_iter,
                                    replace=False)
            env_paths = sum([env_paths[idx] for idx in idxs], [])

        elif type(env_paths) is list:
            idxs = np.random.choice(range(len(env_paths)),
                                    size=self.num_rollouts_per_iter,
                                    replace=False)
            env_paths = [env_paths[idx] for idx in idxs]

        else:
            raise TypeError
        samples_data = self.dynamics_sample_processor.process_samples(
            env_paths,
            log=True,
            log_prefix='Data-EnvTrajs-',
        )

        self.samples_data_arr.append(samples_data)
        time_step = time.time() - time_step

        time_sleep = max(self.simulation_sleep - time_step, 0)
        time.sleep(time_sleep)

        logger.logkv('Data-TimeStep', time_step)
        logger.logkv('Data-TimeSleep', time_sleep)
Code example #18
    def train(self):
        """
        Trains policy on env using algo
        """
        worker_data_queue, worker_model_queue, worker_policy_queue = self.queues
        worker_data_remote, worker_model_remote, worker_policy_remote = self.remotes

        for p in self.ps:
            p.start()
        ''' --------------- worker warm-up --------------- '''

        logger.log('Prepare start...')

        worker_data_remote.send('prepare start')
        worker_data_queue.put(self.initial_random_samples)
        assert worker_data_remote.recv() == 'loop ready'

        worker_model_remote.send('prepare start')
        assert worker_model_remote.recv() == 'loop ready'

        worker_policy_remote.send('prepare start')
        assert worker_policy_remote.recv() == 'loop ready'

        time_total = time.time()
        ''' --------------- worker looping --------------- '''

        logger.log('Start looping...')
        for remote in self.remotes:
            remote.send('start loop')
        ''' --------------- collect info --------------- '''

        for remote in self.remotes:
            assert remote.recv() == 'loop done'
        logger.log('\n------------all workers exit loops -------------')
        for remote in self.remotes:
            assert remote.recv() == 'worker closed'

        for p in self.ps:
            p.terminate()

        logger.logkv('Trainer-TimeTotal', time.time() - time_total)
        logger.dumpkvs()
        logger.log("*****Training finished")
Code example #19
File: worker_model.py (Project: zzyunzhi/asynch-mb)
    def process_queue(self):
        do_push = 0
        samples_data_arr = []
        while True:
            try:
                if not self.remaining_model_idx:
                    logger.log(
                        'Model at iteration {} is block waiting for data'.
                        format(self.itr_counter))
                    # FIXME: check stop_cond
                    time_wait = time.time()
                    samples_data_arr_pickle = self.queue.get()
                    time_wait = time.time() - time_wait
                    logger.logkv('Model-TimeBlockWait', time_wait)
                    self.remaining_model_idx = list(
                        range(self.dynamics_model.num_models))
                else:
                    if self.verbose:
                        logger.log('Model try get_nowait.........')
                    samples_data_arr_pickle = self.queue.get_nowait()
                if samples_data_arr_pickle == 'push':
                    # Only push once before executing another step
                    if do_push == 0:
                        do_push = 1
                        self.push()
                else:
                    samples_data_arr.extend(
                        pickle.loads(samples_data_arr_pickle))
            except Empty:
                break

        do_synch = len(samples_data_arr)
        if do_synch:
            self._synch(samples_data_arr)

        do_step = 1

        if self.verbose:
            logger.log(
                'Model finishes processing queue with {}, {}, {}......'.format(
                    do_push, do_synch, do_step))

        return do_push, do_synch, do_step
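
process_queue() reports what housekeeping it performed so the surrounding loop (see step_wrapper() in example #24) can account for it. Purely as an illustration of how the pieces in these snippets could be composed, and not the repository's actual wiring, a model worker's step_wrapper() might look like this:

# Illustrative composition only; the real step_wrapper() in asynch-mb may differ.
def step_wrapper(self):
    do_push, do_synch, do_step = self.process_queue()  # drain messages; may synch and push
    if do_step:
        self.step()                                     # fit the dynamics model for one epoch (example #16)
    if not do_push:
        self.push()                                     # publish parameters if not already done (example #13)
    return do_synch, do_step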
Code example #20
File: trpo_maml.py (Project: zzyunzhi/asynch-mb)
    def optimize_policy(self,
                        all_samples_data,
                        log=True,
                        prefix='',
                        verbose=False):
        """
        Performs MAML outer step

        Args:
            all_samples_data (list) : list of lists of lists of samples (each is a dict) split by gradient update and
             meta task
            log (bool) : whether to log statistics

        Returns:
            None
        """
        meta_op_input_dict = self._extract_input_dict_meta_op(
            all_samples_data, self._optimization_keys)
        if verbose:
            logger.log("Computing KL before")
        mean_kl_before = self.optimizer.constraint_val(meta_op_input_dict)

        if verbose:
            logger.log("Computing loss before")
        loss_before = self.optimizer.loss(meta_op_input_dict)
        if verbose:
            logger.log("Optimizing")
        self.optimizer.optimize(meta_op_input_dict)
        if verbose:
            logger.log("Computing loss after")
        loss_after = self.optimizer.loss(meta_op_input_dict)

        if verbose:
            logger.log("Computing KL after")
        mean_kl = self.optimizer.constraint_val(meta_op_input_dict)
        if log:
            logger.logkv(prefix + 'MeanKLBefore', mean_kl_before)
            logger.logkv(prefix + 'MeanKL', mean_kl)

            logger.logkv(prefix + 'LossBefore', loss_before)
            logger.logkv(prefix + 'LossAfter', loss_after)
            logger.logkv(prefix + 'dLoss', loss_before - loss_after)
Code example #21
File: test_pr2_env.py (Project: zzyunzhi/asynch-mb)
    def log_diagnostics(self, paths, prefix=''):
        dist = [-path["env_infos"]['reward_dist'] for path in paths]
        final_dist = [-path["env_infos"]['reward_dist'][-1] for path in paths]
        ctrl_cost = [-path["env_infos"]['reward_ctrl'] for path in paths]

        logger.logkv(prefix + 'AvgDist', np.mean(dist))
        logger.logkv(prefix + 'AvgFinalDist', np.mean(final_dist))
        logger.logkv(prefix + 'AvgCtrlCost', np.mean(ctrl_cost))
Code example #22
File: worker_data.py (Project: zzyunzhi/asynch-mb)
    def step(self, random_sinusoid=(False, False)):
        time_step = time.time()

        if self.itr_counter == 1 and self.env_sampler.policy.dynamics_model.normalization is None:
            if self.verbose:
                logger.log('Data starts first step...')
            self.env_sampler.policy.dynamics_model = pickle.loads(
                self.queue.get())
            if self.verbose:
                logger.log('Data first step done...')
        '''------------- Obtaining samples from the environment -----------'''

        if self.verbose:
            logger.log("Data is obtaining samples...")
        env_paths = self.env_sampler.obtain_samples(
            log=True,
            random=random_sinusoid[0],
            sinusoid=random_sinusoid[1],
            log_prefix='Data-EnvSampler-',
        )
        '''-------------- Processing environment samples -------------------'''

        if self.verbose:
            logger.log("Data is processing samples...")
        samples_data = self.dynamics_sample_processor.process_samples(
            env_paths,
            log=True,
            log_prefix='Data-EnvTrajs-',
        )

        self.samples_data_arr.append(samples_data)
        time_step = time.time() - time_step

        time_sleep = max(self.simulation_sleep - time_step, 0)
        time.sleep(time_sleep)

        logger.logkv('Data-TimeStep', time_step)
        logger.logkv('Data-TimeSleep', time_sleep)
Code example #23
File: worker_policy.py (Project: zzyunzhi/asynch-mb)
    def step(self):
        time_step = time.time()
        ''' --------------- MAML steps --------------- '''

        self.policy.switch_to_pre_update()  # Switch to pre-update policy
        all_samples_data = []

        for step in range(self.num_inner_grad_steps + 1):
            if self.verbose:
                logger.log("Policy Adaptation-Step %d **" % step)
            """ -------------------- Sampling --------------------------"""

            paths = self.model_sampler.obtain_samples(log=True,
                                                      log_prefix='Policy-',
                                                      buffer=None)
            """ ----------------- Processing Samples ---------------------"""

            samples_data = self.model_sample_processor.process_samples(
                paths, log='all', log_prefix='Policy-')
            all_samples_data.append(samples_data)

            self.log_diagnostics(sum(list(paths.values()), []),
                                 prefix='Policy-')
            """ ------------------- Inner Policy Update --------------------"""

            if step < self.num_inner_grad_steps:
                self.algo._adapt(samples_data)
        """ ------------------ Outer Policy Update ---------------------"""

        if self.verbose:
            logger.log("Policy is optimizing...")
        # This needs to take all samples_data so that it can construct graph for meta-optimization.
        self.algo.optimize_policy(all_samples_data, prefix='Policy-')

        time_step = time.time() - time_step
        self.policy = self.model_sampler.policy

        logger.logkv('Policy-TimeStep', time_step)
Code example #24
File: base.py (Project: zzyunzhi/asynch-mb)
    def start(self):
        logger.log(f"\n================ {self.name} starts ===============")
        time_start = time.time()
        with self.sess.as_default():
            # loop
            while not ray.get(self.stop_cond.is_set.remote()):
                do_synch, do_step = self.step_wrapper()
                self.synch_counter += do_synch
                self.step_counter += do_step

                # logging
                logger.logkv(self.name + '-TimeSoFar',
                             time.time() - time_start)
                logger.logkv(self.name + '-TotalStep', self.step_counter)
                logger.logkv(self.name + '-TotalSynch', self.synch_counter)
                logger.dumpkvs()

                self.set_stop_cond()

        logger.log(
            f"\n================== {self.name} closed ===================")
コード例 #25
0
    def obtain_samples(self, log=False, log_prefix='', random=False):
        """
        Collect batch_size trajectories from each task

        Args:
            log (boolean): whether to log sampling times
            log_prefix (str) : prefix for logger
            random (boolean): whether the actions are random

        Returns: 
            (dict) : A dict of paths of size [meta_batch_size] x (batch_size) x [5] x (max_path_length)
        """

        # initial setup / preparation
        paths = OrderedDict()
        for i in range(self.meta_batch_size):
            paths[i] = []

        n_samples = 0
        running_paths = [
            _get_empty_running_paths_dict()
            for _ in range(self.vec_env.num_envs)
        ]

        pbar = ProgBar(self.total_samples)
        policy_time, env_time = 0, 0

        policy = self.policy
        policy.reset(dones=[True] * self.meta_batch_size)

        # initial reset of meta_envs
        obses = self.vec_env.reset()

        while n_samples < self.total_samples:

            # execute policy
            t = time.time()
            obs_per_task = np.split(np.asarray(obses), self.meta_batch_size)
            if random:
                actions = np.stack([[self.env.action_space.sample()]
                                    for _ in range(len(obses))],
                                   axis=0)
                agent_infos = [[{
                    'mean':
                    np.zeros_like(self.env.action_space.sample()),
                    'log_std':
                    np.zeros_like(self.env.action_space.sample())
                }] * self.envs_per_task] * self.meta_batch_size
            else:
                actions, agent_infos = policy.get_actions(obs_per_task)
            policy_time += time.time() - t

            # step environments
            t = time.time()
            actions = np.concatenate(actions)  # stack meta batch
            next_obses, rewards, dones, env_infos = self.vec_env.step(actions)
            env_time += time.time() - t

            #  stack agent_infos and if no infos were provided (--> None) create empty dicts
            agent_infos, env_infos = self._handle_info_dicts(
                agent_infos, env_infos)

            new_samples = 0
            for idx, observation, action, reward, env_info, agent_info, done in zip(
                    itertools.count(), obses, actions, rewards, env_infos,
                    agent_infos, dones):
                # append new samples to running paths
                if isinstance(reward, np.ndarray):
                    reward = reward[0]
                running_paths[idx]["observations"].append(observation)
                running_paths[idx]["actions"].append(action)
                running_paths[idx]["rewards"].append(reward)
                running_paths[idx]["dones"].append(done)
                running_paths[idx]["env_infos"].append(env_info)
                running_paths[idx]["agent_infos"].append(agent_info)

                # if running path is done, add it to paths and empty the running path
                if done:
                    paths[idx // self.envs_per_task].append(
                        dict(
                            observations=np.asarray(
                                running_paths[idx]["observations"]),
                            actions=np.asarray(running_paths[idx]["actions"]),
                            rewards=np.asarray(running_paths[idx]["rewards"]),
                            dones=np.asarray(running_paths[idx]["dones"]),
                            env_infos=utils.stack_tensor_dict_list(
                                running_paths[idx]["env_infos"]),
                            agent_infos=utils.stack_tensor_dict_list(
                                running_paths[idx]["agent_infos"]),
                        ))
                    new_samples += len(running_paths[idx]["rewards"])
                    running_paths[idx] = _get_empty_running_paths_dict()

            pbar.update(new_samples)
            n_samples += new_samples
            obses = next_obses
        pbar.stop()

        self.total_timesteps_sampled += self.total_samples
        if log:
            logger.logkv(log_prefix + "PolicyExecTime", policy_time)
            logger.logkv(log_prefix + "EnvExecTime", env_time)

        return paths
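
obtain_samples() relies on a small helper, _get_empty_running_paths_dict(), to reset the per-environment accumulators. A plausible definition consistent with the keys used above (illustrative only; the repository's helper may differ):

def _get_empty_running_paths_dict():
    # One growing list per key that the sampling loop appends to.
    return dict(observations=[], actions=[], rewards=[],
                dones=[], env_infos=[], agent_infos=[])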
Code example #26
File: mbmpo_trainer.py (Project: zzyunzhi/asynch-mb)
    def train(self):
        """
        Trains policy on env using algo

        Pseudocode:
            for itr in n_itr:
                for step in num_inner_grad_steps:
                    sampler.sample()
                    algo.compute_updated_dists()
                algo.optimize_policy()
                sampler.update_goals()
        """
        with self.sess.as_default() as sess:

            # initialize uninitialized vars  (only initialize vars that were not loaded)
            # uninit_vars = [var for var in tf.global_variables() if not sess.run(tf.is_variable_initialized(var))]
            # sess.run(tf.variables_initializer(uninit_vars))
            sess.run(tf.global_variables_initializer())

            if type(self.meta_steps_per_iter) is tuple:
                meta_steps_per_iter = np.linspace(self.meta_steps_per_iter[0], self.meta_steps_per_iter[1],
                                                  self.n_itr).astype(np.int)
            else:
                meta_steps_per_iter = [self.meta_steps_per_iter] * self.n_itr
            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                itr_start_time = time.time()
                logger.log("\n ---------------- Iteration %d ----------------" % itr)

                time_env_sampling_start = time.time()

                if self.initial_random_samples and itr == 0:
                    logger.log("Obtaining random samples from the environment...")
                    env_paths = self.env_sampler.obtain_samples(log=True, random=True, log_prefix='EnvSampler-')

                else:
                    logger.log("Obtaining samples from the environment using the policy...")
                    env_paths = self.env_sampler.obtain_samples(log=True, log_prefix='EnvSampler-')

                logger.record_tabular('Time-EnvSampling', time.time() - time_env_sampling_start)
                logger.log("Processing environment samples...")

                # first processing just for logging purposes
                time_env_samp_proc = time.time()
                if type(env_paths) is dict or type(env_paths) is collections.OrderedDict:
                    env_paths = list(env_paths.values())
                    idxs = np.random.choice(range(len(env_paths)),
                                            size=self.num_rollouts_per_iter,
                                            replace=False)
                    env_paths = sum([env_paths[idx] for idx in idxs], [])

                elif type(env_paths) is list:
                    idxs = np.random.choice(range(len(env_paths)),
                                            size=self.num_rollouts_per_iter,
                                            replace=False)
                    env_paths = [env_paths[idx] for idx in idxs]

                else:
                    raise TypeError
                samples_data = self.dynamics_sample_processor.process_samples(env_paths, log=True,
                                                                              log_prefix='EnvTrajs-')

                self.env.log_diagnostics(env_paths, prefix='EnvTrajs-')

                logger.record_tabular('Time-EnvSampleProc', time.time() - time_env_samp_proc)

                ''' --------------- fit dynamics model --------------- '''

                time_fit_start = time.time()

                logger.log("Training dynamics model for %i epochs ..." % (self.dynamics_model_max_epochs))
                self.dynamics_model.fit(samples_data['observations'],
                                        samples_data['actions'],
                                        samples_data['next_observations'],
                                        epochs=self.dynamics_model_max_epochs, verbose=True, log_tabular=True)

                buffer = None if not self.sample_from_buffer else samples_data

                logger.record_tabular('Time-ModelFit', time.time() - time_fit_start)

                ''' ------------ log real performance --------------- '''

                if self.log_real_performance:
                    logger.log("Evaluating the performance of the real policy")
                    self.policy.switch_to_pre_update()
                    env_paths = self.env_sampler.obtain_samples(log=True, log_prefix='PrePolicy-')
                    samples_data = self.model_sample_processor.process_samples(env_paths, log='all',
                                                                               log_prefix='PrePolicy-')
                    self.algo._adapt(samples_data)
                    env_paths = self.env_sampler.obtain_samples(log=True, log_prefix='PostPolicy-')
                    self.model_sample_processor.process_samples(env_paths, log='all', log_prefix='PostPolicy-')

                ''' --------------- MAML steps --------------- '''

                times_dyn_sampling = []
                times_dyn_sample_processing = []
                times_meta_sampling = []
                times_inner_step = []
                times_total_inner_step = []
                times_outer_step = []
                times_maml_steps = []


                for meta_itr in range(meta_steps_per_iter[itr]):

                    logger.log("\n ---------------- Meta-Step %d ----------------" % int(sum(meta_steps_per_iter[:itr])
                                                                                         + meta_itr))
                    self.policy.switch_to_pre_update()  # Switch to pre-update policy

                    all_samples_data, all_paths = [], []
                    list_sampling_time, list_inner_step_time, list_outer_step_time, list_proc_samples_time = [], [], [], []
                    time_maml_steps_start = time.time()
                    start_total_inner_time = time.time()
                    for step in range(self.num_inner_grad_steps+1):
                        logger.log("\n ** Adaptation-Step %d **" % step)

                        """ -------------------- Sampling --------------------------"""

                        logger.log("Obtaining samples...")
                        time_env_sampling_start = time.time()
                        paths = self.model_sampler.obtain_samples(log=True,
                                                                  log_prefix='Step_%d-' % step,
                                                                  buffer=buffer)
                        list_sampling_time.append(time.time() - time_env_sampling_start)
                        all_paths.append(paths)

                        """ ----------------- Processing Samples ---------------------"""

                        logger.log("Processing samples...")
                        time_proc_samples_start = time.time()
                        samples_data = self.model_sample_processor.process_samples(paths, log='all', log_prefix='Step_%d-' % step)
                        all_samples_data.append(samples_data)
                        list_proc_samples_time.append(time.time() - time_proc_samples_start)

                        self.log_diagnostics(sum(list(paths.values()), []), prefix='Step_%d-' % step)

                        """ ------------------- Inner Policy Update --------------------"""

                        time_inner_step_start = time.time()
                        if step < self.num_inner_grad_steps:
                            logger.log("Computing inner policy updates...")
                            self.algo._adapt(samples_data)

                        list_inner_step_time.append(time.time() - time_inner_step_start)
                    total_inner_time = time.time() - start_total_inner_time

                    time_maml_opt_start = time.time()

                    """ ------------------ Outer Policy Update ---------------------"""

                    logger.log("Optimizing policy...")
                    # This needs to take all samples_data so that it can construct graph for meta-optimization.
                    time_outer_step_start = time.time()
                    self.algo.optimize_policy(all_samples_data)

                    times_inner_step.append(list_inner_step_time)
                    times_total_inner_step.append(total_inner_time)
                    times_outer_step.append(time.time() - time_outer_step_start)
                    times_meta_sampling.append(np.sum(list_sampling_time))
                    times_dyn_sampling.append(list_sampling_time)
                    times_dyn_sample_processing.append(list_proc_samples_time)
                    times_maml_steps.append(time.time() - time_maml_steps_start)


                """ ------------------- Logging Stuff --------------------------"""
                logger.logkv('Itr', itr)
                if self.log_real_performance:
                    logger.logkv('n_timesteps', self.env_sampler.total_timesteps_sampled/(3 * self.policy.meta_batch_size) * self.num_rollouts_per_iter)
                else:
                    logger.logkv('n_timesteps', self.env_sampler.total_timesteps_sampled/self.policy.meta_batch_size * self.num_rollouts_per_iter)
                logger.logkv('AvgTime-OuterStep', np.mean(times_outer_step))
                logger.logkv('AvgTime-InnerStep', np.mean(times_inner_step))
                logger.logkv('AvgTime-TotalInner', np.mean(times_total_inner_step))
                logger.logkv('AvgTime-SampleProc', np.mean(times_dyn_sample_processing))
                logger.logkv('AvgTime-Sampling', np.mean(times_dyn_sampling))
                logger.logkv('AvgTime-MAMLSteps', np.mean(times_maml_steps))

                logger.logkv('Time', time.time() - start_time)
                logger.logkv('ItrTime', time.time() - itr_start_time)

                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr)
                logger.save_itr_params(itr, params)
                logger.log("Saved")

                logger.dumpkvs()
                if itr == 0:
                    sess.graph.finalize()

        logger.log("Training finished")
        self.sess.close()
Code example #27
    def train(self):
        """
        Trains policy on env using algo

        Pseudocode:
            for itr in n_itr:
                for step in num_inner_grad_steps:
                    sampler.sample()
                    algo.compute_updated_dists()
                algo.optimize_policy()
                sampler.update_goals()
        """
        with self.sess.as_default() as sess:

            # initialize uninitialized vars  (only initialize vars that were not loaded)
            # uninit_vars = [var for var in tf.global_variables() if not sess.run(tf.is_variable_initialized(var))]
            # sess.run(tf.variables_initializer(uninit_vars))
            sess.run(tf.global_variables_initializer())

            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                itr_start_time = time.time()
                logger.log("\n ---------------- Iteration %d ----------------" % itr)

                time_env_sampling_start = time.time()

                if self.initial_random_samples and itr == 0:
                    logger.log("Obtaining random samples from the environment...")
                    env_paths = self.env_sampler.obtain_samples(log=True, random=True, log_prefix='Data-EnvSampler-')

                else:
                    logger.log("Obtaining samples from the environment using the policy...")
                    env_paths = self.env_sampler.obtain_samples(log=True, log_prefix='Data-EnvSampler-')

                # Add sleeping time to match parallel experiment
                # time.sleep(10)

                logger.record_tabular('Data-TimeEnvSampling', time.time() - time_env_sampling_start)
                logger.log("Processing environment samples...")

                # first processing just for logging purposes
                time_env_samp_proc = time.time()

                samples_data = self.dynamics_sample_processor.process_samples(env_paths, log=True,
                                                                              log_prefix='Data-EnvTrajs-')

                self.env.log_diagnostics(env_paths, prefix='Data-EnvTrajs-')

                logger.record_tabular('Data-TimeEnvSampleProc', time.time() - time_env_samp_proc)

                ''' --------------- fit dynamics model --------------- '''

                time_fit_start = time.time()

                self.dynamics_model.update_buffer(samples_data['observations'],
                                                  samples_data['actions'],
                                                  samples_data['next_observations'],
                                                  check_init=True)

                buffer = None if not self.sample_from_buffer else samples_data

                logger.record_tabular('Model-TimeModelFit', time.time() - time_fit_start)

                ''' --------------- MAML steps --------------- '''
                times_dyn_sampling = []
                times_dyn_sample_processing = []
                times_optimization = []
                times_step = []
                remaining_model_idx = list(range(self.dynamics_model.num_models))
                valid_loss_rolling_average_prev = None

                with_new_data = True
                for id_step in range(self.repeat_steps):

                    for epoch in range(self.num_epochs_per_step):
                        logger.log("Training dynamics model for %i epochs ..." % 1)
                        remaining_model_idx, valid_loss_rolling_average = self.dynamics_model.fit_one_epoch(
                                                                                remaining_model_idx,
                                                                                valid_loss_rolling_average_prev,
                                                                                with_new_data,
                                                                                log_tabular=True,
                                                                                prefix='Model-')
                        with_new_data = False

                    for step in range(self.num_grad_policy_per_step):

                        logger.log("\n ---------------- Grad-Step %d ----------------" % int(itr * self.repeat_steps * self.num_grad_policy_per_step +
                                                                                            id_step * self.num_grad_policy_per_step
                                                                                             + step))
                        step_start_time = time.time()

                        """ -------------------- Sampling --------------------------"""

                        logger.log("Obtaining samples from the model...")
                        time_env_sampling_start = time.time()
                        paths = self.model_sampler.obtain_samples(log=True, log_prefix='Policy-', buffer=buffer)
                        sampling_time = time.time() - time_env_sampling_start

                        """ ----------------- Processing Samples ---------------------"""

                        logger.log("Processing samples from the model...")
                        time_proc_samples_start = time.time()
                        samples_data = self.model_sample_processor.process_samples(paths, log='all', log_prefix='Policy-')
                        proc_samples_time = time.time() - time_proc_samples_start

                        if type(paths) is list:
                            self.log_diagnostics(paths, prefix='Policy-')
                        else:
                            self.log_diagnostics(sum(paths.values(), []), prefix='Policy-')

                        """ ------------------ Policy Update ---------------------"""

                        logger.log("Optimizing policy...")
                        # This needs to take all samples_data so that it can construct graph for meta-optimization.
                        time_optimization_step_start = time.time()
                        self.algo.optimize_policy(samples_data)
                        optimization_time = time.time() - time_optimization_step_start

                        times_dyn_sampling.append(sampling_time)
                        times_dyn_sample_processing.append(proc_samples_time)
                        times_optimization.append(optimization_time)
                        times_step.append(time.time() - step_start_time)

                """ ------------------- Logging Stuff --------------------------"""
                logger.logkv('Iteration', itr)
                logger.logkv('n_timesteps', self.env_sampler.total_timesteps_sampled)
                logger.logkv('Policy-TimeSampleProc', np.sum(times_dyn_sample_processing))
                logger.logkv('Policy-TimeSampling', np.sum(times_dyn_sampling))
                logger.logkv('Policy-TimeAlgoOpt', np.sum(times_optimization))
                logger.logkv('Policy-TimeStep', np.sum(times_step))

                logger.logkv('Time', time.time() - start_time)
                logger.logkv('ItrTime', time.time() - itr_start_time)

                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr)
                logger.save_itr_params(itr, params)
                logger.log("Saved")

                logger.dumpkvs()
                if itr == 0:
                    sess.graph.finalize()

            logger.logkv('Trainer-TimeTotal', time.time() - start_time)

        logger.log("Training finished")
        self.sess.close()
Code example #28
    def pull(self):
        time_synch = time.time()
        policy_params = ray.get(self.policy_ps.pull.remote())
        assert isinstance(policy_params, dict)
        self.env_sampler.policy.set_shared_params(policy_params)
        logger.logkv('Data-TimePull', time.time() - time_synch)
Code example #29
    def push(self):
        time_push = time.time()
        params = self.dynamics_model.get_shared_param_values()
        assert params is not None
        ray.get(self.model_ps.push.remote(params))  # FIXME: wait here until push succeeds?
        logger.logkv('Model-TimePush', time.time() - time_push)
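
Examples #11, #28, and #29 all talk to Ray parameter-server actors through a matching push/pull pair. A minimal sketch of such a parameter server (hypothetical, shown only to make the calling convention concrete):

import ray

@ray.remote
class ParamServer:
    """Stores the latest parameter dict published by a worker."""

    def __init__(self):
        self._params = None

    def push(self, params):
        self._params = params    # overwrite with the freshest snapshot

    def pull(self):
        return self._params      # None until the first push

# Hypothetical usage: model_ps = ParamServer.remote(); the model worker publishes with
# ray.get(model_ps.push.remote(params)) and the policy worker reads back with
# ray.get(model_ps.pull.remote()), as in the snippets above.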
Code example #30
    def _log_path_stats(self, paths, log=False, log_prefix=''):
        # compute log stats
        average_discounted_return = [
            sum(path["discounted_rewards"]) for path in paths
        ]
        undiscounted_returns = [sum(path["rewards"]) for path in paths]

        if log == 'reward':
            logger.logkv(log_prefix + 'AverageReturn',
                         np.mean(undiscounted_returns))

        elif log == 'all' or log is True:
            logger.logkv(log_prefix + 'AverageDiscountedReturn',
                         np.mean(average_discounted_return))
            logger.logkv(log_prefix + 'AverageReturn',
                         np.mean(undiscounted_returns))
            logger.logkv(log_prefix + 'NumTrajs', len(paths))
            logger.logkv(log_prefix + 'StdReturn',
                         np.std(undiscounted_returns))
            logger.logkv(log_prefix + 'MaxReturn',
                         np.max(undiscounted_returns))
            logger.logkv(log_prefix + 'MinReturn',
                         np.min(undiscounted_returns))
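
_log_path_stats() assumes each path already carries a "discounted_rewards" array whose sum is the discounted return. A short sketch of how a sample processor might produce it (illustrative; gamma is the discount factor and is not a name taken from this snippet):

import numpy as np

def add_discounted_rewards(path, gamma=0.99):
    """Illustrative: gamma**t * r_t, so sum(path['discounted_rewards']) is the discounted return."""
    rewards = np.asarray(path["rewards"], dtype=float)
    path["discounted_rewards"] = rewards * gamma ** np.arange(len(rewards))
    return path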