Example #1
    def step(self, action, offline=False):
        if not offline:
            assert (self.ros_is_good())

        lb, ub = self.action_space.bounds
        action = np.clip(action, lb, ub)

        cmd_steer, cmd_vel = action
        self._set_steer(cmd_steer)
        self._set_vel(cmd_vel)

        if not offline:
            rospy.sleep(
                max(
                    0., self._dt -
                    (rospy.Time.now() - self._last_step_time).to_sec()))
            self._last_step_time = rospy.Time.now()

        next_observation = self._get_observation()
        reward = self._get_reward()
        done = self._get_done()
        env_info = dict()

        self._t += 1

        if not offline:
            self._ros_rolloutbag.write_all(self._ros_topics_and_types.keys(),
                                           self._ros_msgs, self._ros_msg_times)
            if done:
                logger.debug('Done after {0} steps'.format(self._t))
                self._t = 0
                self._ros_rolloutbag.close()

        return next_observation, reward, done, env_info
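As a usage note, the step() method above follows the (observation, reward, done, info) contract, so a rollout can be driven by a simple loop. The following is a minimal, hypothetical sketch; `policy.get_action` is an assumed interface, not part of the example.

def run_episode(env, policy):
    # Minimal sketch of an episode loop for an env with the step() API above.
    # `policy.get_action` is an assumption for illustration only.
    observation = env.reset()
    done = False
    total_reward = 0.
    while not done:
        action = policy.get_action(observation)
        observation, reward, done, env_info = env.step(action)
        total_reward += reward
    return total_reward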
Example #2
File: bnn.py Project: gkahn13/gcg-old
    def train(self):
        self._start_train_batch()

        logger.info('Training model')

        alg_args = self._params['alg']
        total_steps = int(alg_args['total_steps'])
        save_every_n_steps = int(alg_args['save_every_n_steps'])
        update_target_after_n_steps = int(
            alg_args['update_target_after_n_steps'])
        update_target_every_n_steps = int(
            alg_args['update_target_every_n_steps'])
        log_every_n_steps = int(alg_args['log_every_n_steps'])

        timeit.reset()
        timeit.start('total')
        save_itr = 0
        for step in range(total_steps):
            timeit.start('sample')
            # steps, observations, actions, rewards, dones, _ = self._replay_pool.sample(batch_size)
            steps, observations, actions, rewards, dones, _ = self._batch_queue.get(
            )
            timeit.stop('sample')
            timeit.start('train')
            self._model.train_step(step,
                                   steps=steps,
                                   observations=observations,
                                   actions=actions,
                                   rewards=rewards,
                                   dones=dones,
                                   use_target=True)
            timeit.stop('train')

            ### update target network
            if step > update_target_after_n_steps and step % update_target_every_n_steps == 0:
                self._model.update_target()

            ### log
            if step > 0 and step % log_every_n_steps == 0:
                logger.record_tabular('Step', step)
                self._model.log()
                logger.dump_tabular(print_func=logger.info)

                timeit.stop('total')
                for line in str(timeit).split('\n'):
                    logger.debug(line)
                timeit.reset()
                timeit.start('total')

            ### save model
            if step > 0 and step % save_every_n_steps == 0:
                logger.info('Saving files for itr {0}'.format(save_itr))
                self._save_train_policy(save_itr)
                save_itr += 1

        ### always save the end
        self._save_train_policy(save_itr)

        self._stop_train_batch()
Example #3
 def _log(self, msg, lvl):
     if not self.suppress_output:
         if lvl == "info":
             logger.info(msg)
         elif lvl == "debug":
             logger.debug(msg)
         elif lvl == "warn":
             logger.warn(msg)
         elif lvl == "error":
             logger.error(msg)
         else:
             print("NOT VALID LOG LEVEL")
Example #4
    def _train_load_data(self, inference_itr):
        new_inference_itr = self._get_inference_itr()
        if inference_itr < new_inference_itr:
            for i in range(inference_itr, new_inference_itr):
                try:
                    logger.debug('Loading files for itr {0}'.format(i))
                    self._sampler.add_rollouts(
                        [self._train_rollouts_file_name(i)])
                    inference_itr = i + 1
                except:
                    logger.debug('Failed to load files for itr {0}'.format(i))

        return inference_itr
Example #5
    def _train_load_data(self, inference_itr):
        new_inference_itr = self._get_inference_itr()
        if inference_itr < new_inference_itr:
            for i in range(inference_itr, new_inference_itr):
                try:
                    logger.debug('Loading files for itrs [{0}, {1}]'.format(
                        inference_itr + 1, new_inference_itr))
                    self._restore_train_rollouts()
                    inference_itr = new_inference_itr
                except:
                    logger.debug('Failed to load files for itr {0}'.format(i))

        return inference_itr
Example #6
 def reset(self, pos=None, hpr=None, hard_reset=False):
     if self._do_back_up and not hard_reset and \
             pos is None and hpr is None:
         if self._collision:
             self._back_up()
     else:
         if hard_reset:
             logger.debug('Hard resetting!')
         if pos is None and hpr is None:
             pos, hpr = self._next_restart_pos_hpr()
         self._place_vehicle(pos=pos, hpr=hpr)
     self._collision = False
     self._env_time_step = 0
     return self._get_observation(), self._get_goal()
Example #7
    def __init__(self,
                 eval_itr,
                 num_rollouts,
                 eval_params,
                 exp_name,
                 env_eval_params,
                 policy_params,
                 rp_eval_params,
                 seed=None,
                 log_level='info',
                 log_fname='log_eval.txt'):
        self._eval_itr = eval_itr
        self._num_rollouts = num_rollouts

        ### create file manager and setup logger
        self._fm = FileManager(exp_name,
                               is_continue=True,
                               log_level=log_level,
                               log_fname=log_fname,
                               log_folder='eval_itr_{0:04d}'.format(
                                   self._eval_itr))

        logger.debug('Git current')
        logger.debug(
            subprocess.check_output('git status | head -n 1',
                                    shell=True).decode('utf-8').strip())
        logger.debug(
            subprocess.check_output('git log -n 1| head -n 1',
                                    shell=True).decode('utf-8').strip())

        logger.debug('Seed {0}'.format(seed))
        utils.set_seed(seed)

        ### create environments
        self._env_eval = env_eval_params['class'](
            params=env_eval_params['kwargs'])

        ### create policy
        self._policy = policy_params['class'](env_spec=self._env_eval.spec,
                                              exploration_strategies=[],
                                              inference_only=True,
                                              **policy_params['kwargs'])

        ### create replay pools
        self._save_async = True
        self._replay_pool_eval = ReplayPool(
            env_spec=self._env_eval.spec,
            obs_history_len=self._policy.obs_history_len,
            N=self._policy.N,
            labeller=None,
            size=int(5 * self._env_eval.horizon),
            save_rollouts=True,
            save_rollouts_observations=True,
            save_env_infos=True)

        ### create samplers
        self._sampler_eval = Sampler(env=self._env_eval,
                                     policy=self._policy,
                                     replay_pool=self._replay_pool_eval)
Example #8
 def _run_log(self, step):
     logger.record_tabular('Step', step)
     self._env.log()
     self._replay_pool.log()
     if self._env_eval:
         self._env_eval.log(prefix='Eval')
     if self._replay_pool_eval:
         self._replay_pool_eval.log(prefix='Eval')
     self._policy.log()
     logger.dump_tabular(print_func=logger.info)
     timeit.stop('total')
     for line in str(timeit).split('\n'):
         logger.debug(line)
     timeit.reset()
     timeit.start('total')
Example #9
 def _graph_optimize(self, tf_cost, tf_policy_vars):
     tf_lr_ph = tf.placeholder(tf.float32, (), name="learning_rate")
     update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
     num_parameters = 0
     with tf.control_dependencies(update_ops):
         optimizer = tf.train.AdamOptimizer(learning_rate=tf_lr_ph,
                                            epsilon=1e-4)
         gradients = optimizer.compute_gradients(tf_cost,
                                                 var_list=tf_policy_vars)
         for i, (grad, var) in enumerate(gradients):
             num_parameters += int(np.prod(var.get_shape().as_list()))
             if grad is not None:
                 gradients[i] = (tf.clip_by_norm(grad,
                                                 self._grad_clip_norm), var)
         tf_opt = optimizer.apply_gradients(gradients,
                                            global_step=self.global_step)
     logger.debug('Number of parameters: {0:e}'.format(
         float(num_parameters)))
     return tf_opt, tf_lr_ph
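_graph_optimize returns the apply-gradients op together with a learning-rate placeholder, so callers are expected to feed the learning rate at every training step. The snippet below is a hypothetical usage sketch under the TensorFlow 1.x API used above; the argument names are assumptions, not part of the project.

def run_train_step(sess, tf_opt, tf_cost, tf_lr_ph, feed_dict, learning_rate):
    # Hypothetical sketch: one optimization step with the ops returned by
    # _graph_optimize. `sess` is a tf.Session; all arguments are assumptions.
    feed = dict(feed_dict)
    feed[tf_lr_ph] = learning_rate  # feed the learning-rate placeholder
    _, cost_value = sess.run([tf_opt, tf_cost], feed_dict=feed)
    return cost_value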
Example #10
    def step(self, action, offline=False):
        if not offline:
            assert (self.ros_is_good())

        action = np.asarray(action)
        if not (np.logical_and(action >= self.action_space.low,
                               action <= self.action_space.high).all()):
            logger.warn(
                'Action {0} will be clipped to be within bounds: {1}, {2}'.
                format(action, self.action_space.low, self.action_space.high))
            action = np.clip(action, self.action_space.low,
                             self.action_space.high)

        cmd_steer, cmd_vel = action
        self._set_steer(cmd_steer)
        self._set_vel(cmd_vel)

        if not offline:
            rospy.sleep(
                max(
                    0., self._dt -
                    (rospy.Time.now() - self._last_step_time).to_sec()))
            self._last_step_time = rospy.Time.now()

        next_observation = self._get_observation()
        goal = self._get_goal()
        reward = self._get_reward()
        done = self._get_done()
        env_info = dict()

        self._t += 1

        if not offline:
            self._ros_rolloutbag.write_all(self._ros_topics_and_types.keys(),
                                           self._ros_msgs, self._ros_msg_times)
            if done:
                logger.debug('Done after {0} steps'.format(self._t))
                self._t = 0
                self._ros_rolloutbag.close()

        return next_observation, goal, reward, done, env_info
Example #11
    def ros_is_good(self, print=True):
        # check that all non-command topics are coming in at a continuous rate
        for topic in self._ros_topics_and_types.keys():
            if 'cmd' not in topic and 'collision' not in topic:
                elapsed = (rospy.Time.now() -
                           self._ros_msg_times[topic]).to_sec()
                if elapsed > self._dt:
                    if print:
                        logger.debug(
                            'Topic {0} was received {1} seconds ago (dt is {2})'
                            .format(topic, elapsed, self._dt))
                    return False

        # check if in python mode
        if self._ros_msgs.get(
                'mode') is None or self._ros_msgs['mode'].data != 2:
            if print:
                logger.debug('In mode {0}'.format(self._ros_msgs.get('mode')))
            return False

        # check if battery is low
        if self._ros_msgs.get('battery/low') is None or self._ros_msgs[
                'battery/low'].data == 1:
            if print:
                logger.debug('Low battery!')
            return False

        return True
Example #12
    def ros_is_good(self, print=True):
        # check that all non-command topics are coming in at a continuous rate
        for topic in self._ros_topics_and_types.keys():
            if 'cmd' not in topic and 'collision' not in topic:
                if topic not in self._ros_msg_times:
                    if print:
                        logger.debug(
                            'Topic {0} has never been received'.format(topic))
                    return False
                elapsed = (rospy.Time.now() -
                           self._ros_msg_times[topic]).to_sec()
                if elapsed > self._dt:
                    if print:
                        logger.debug(
                            'Topic {0} was received {1} seconds ago (dt is {2})'
                            .format(topic, elapsed, self._dt))
                    return False

        # check if in python mode
        if self._ros_msgs.get(
                'mode') is None or self._ros_msgs['mode'].data != 2:
            if print:
                logger.debug('In mode {0}'.format(self._ros_msgs.get('mode')))
            return False

        if self._ros_msgs['collision/flip'].data:
            if print:
                logger.warn('Car has flipped, please unflip it to continue')
            self._is_collision = False  # otherwise will stay flipped forever
            return False

        return True
Example #13
    def __init__(self, **kwargs):
        self._batch_size = None
        self._tfrecord_train_fnames = []
        self._tfrecord_holdout_fnames = []

        inference_only = kwargs.get('inference_only', False)
        if not inference_only:
            self._batch_size = kwargs['batch_size']

            for folder in kwargs['tfrecord_folders']:
                tfrecord_fnames = [
                    os.path.join(folder, fname) for fname in os.listdir(folder)
                    if os.path.splitext(fname)[1] == '.tfrecord'
                ]

                for tfrecord_fname in tfrecord_fnames:
                    logger.debug('Tfrecord {0}'.format(tfrecord_fname))
                    if os.path.splitext(
                            FileManager.train_rollouts_fname_suffix
                    )[0] in os.path.splitext(tfrecord_fname)[0]:
                        self._tfrecord_train_fnames.append(tfrecord_fname)
                    elif os.path.splitext(
                            FileManager.eval_rollouts_fname_suffix
                    )[0] in os.path.splitext(tfrecord_fname)[0]:
                        self._tfrecord_holdout_fnames.append(tfrecord_fname)
                    else:
                        raise ValueError(
                            'tfrecord {0} does not end in {1} or {2}'.format(
                                tfrecord_fname,
                                os.path.splitext(
                                    FileManager.train_rollouts_fname_suffix)
                                [0],
                                os.path.splitext(
                                    FileManager.eval_rollouts_fname_suffix)
                                [0]))

            random.shuffle(self._tfrecord_train_fnames)
            random.shuffle(self._tfrecord_holdout_fnames)

        super(GCGPolicyTfrecord, self).__init__(**kwargs)
Example #14
    def reset(self, offline=False):
        if offline:
            self._is_collision = False
            return self._get_observation()

        assert (self.ros_is_good())

        if self._ros_rolloutbag.is_open:
            # should've been closed in step when done
            logger.debug('Trashing bag')
            self._ros_rolloutbag.trash()

        if self._press_enter_on_reset:
            logger.info('Resetting, press enter to continue')
            input()
        else:
            if self._is_collision:
                logger.debug('Resetting (collision)')
            else:
                logger.debug('Resetting (no collision)')

            if self._ros_msgs['collision/flip'].data:
                logger.warn('Car has flipped, please unflip it to continue')
                while self._ros_msgs['collision/flip'].data:
                    rospy.sleep(0.1)
                logger.warn('Car is now unflipped. Continuing...')
                rospy.sleep(1.)

            backup_steer = np.random.uniform(*self._backup_steer_range)
            self._set_steer(backup_steer)
            self._set_motor(self._backup_motor, self._backup_duration)
            self._set_steer(0.)
            self._set_vel(0.)

        rospy.sleep(0.5)

        self._last_step_time = rospy.Time.now()
        self._is_collision = False
        self._t = 0

        self._ros_rolloutbag.open()

        assert (self.ros_is_good())

        return self._get_observation()
Example #15
    def reset(self, offline=False, keep_rosbag=False):
        if offline:
            self._is_collision = False
            return self._get_observation(), self._get_goal()

        assert (self.ros_is_good())

        if self._ros_rolloutbag.is_open:
            if keep_rosbag:
                self._ros_rolloutbag.close()
            else:
                # should've been closed in step when done
                logger.debug('Trashing bag')
                self._ros_rolloutbag.trash()

        if self._press_enter_on_reset:
            logger.info('Resetting, press enter to continue')
            input()
        else:
            if self._is_collision:
                logger.debug('Resetting (collision)')
            else:
                logger.debug('Resetting (no collision)')

            if self._backup_duration > 0:
                backup_steer = np.random.uniform(*self._backup_steer_range)
                self._set_steer(backup_steer)
                self._set_motor(self._backup_motor, self._backup_duration)
            self._set_steer(0.)
            self._set_vel(0.)

        rospy.sleep(0.5)

        self._last_step_time = rospy.Time.now()
        self._is_collision = False
        self._t = 0

        self._ros_rolloutbag.open()

        assert (self.ros_is_good())

        return self._get_observation(), self._get_goal()
Example #16
File: gcg.py Project: XieKaixuan/gcg
    def train(self):
        ### restore where we left off
        save_itr = self._restore()

        target_updated = False
        eval_rollouts = []

        self._sampler.reset()
        if self._eval_sampler is not None:
            self._eval_sampler.reset()

        timeit.reset()
        timeit.start('total')
        for step in range(0, self._total_steps, self._sampler.n_envs):
            ### sample and add to buffer
            if step > self._sample_after_n_steps:
                timeit.start('sample')
                self._sampler.step(
                    step,
                    take_random_actions=(step <= self._onpolicy_after_n_steps),
                    explore=True)
                timeit.stop('sample')

            ### sample and DON'T add to buffer (for validation)
            if self._eval_sampler is not None and step > 0 and step % self._eval_every_n_steps == 0:
                timeit.start('eval')
                for _ in range(self._rollouts_per_eval):
                    eval_rollouts_step = []
                    eval_step = step
                    while len(eval_rollouts_step) == 0:
                        self._eval_sampler.step(eval_step, explore=False)
                        eval_rollouts_step = self._eval_sampler.get_recent_paths(
                        )
                        eval_step += 1
                    eval_rollouts += eval_rollouts_step
                timeit.stop('eval')

            if step >= self._learn_after_n_steps:
                ### training step
                if self._train_every_n_steps >= 1:
                    if step % int(self._train_every_n_steps) == 0:
                        timeit.start('batch')
                        steps, observations, goals, actions, rewards, dones, _ = \
                            self._sampler.sample(self._batch_size)
                        timeit.stop('batch')
                        timeit.start('train')
                        self._policy.train_step(step,
                                                steps=steps,
                                                observations=observations,
                                                goals=goals,
                                                actions=actions,
                                                rewards=rewards,
                                                dones=dones,
                                                use_target=target_updated)
                        timeit.stop('train')
                else:
                    for _ in range(int(1. / self._train_every_n_steps)):
                        timeit.start('batch')
                        steps, observations, goals, actions, rewards, dones, _ = \
                            self._sampler.sample(self._batch_size)
                        timeit.stop('batch')
                        timeit.start('train')
                        self._policy.train_step(step,
                                                steps=steps,
                                                observations=observations,
                                                goals=goals,
                                                actions=actions,
                                                rewards=rewards,
                                                dones=dones,
                                                use_target=target_updated)
                        timeit.stop('train')

                ### update target network
                if step > self._update_target_after_n_steps and step % self._update_target_every_n_steps == 0:
                    self._policy.update_target()
                    target_updated = True

                ### log
                if step % self._log_every_n_steps == 0:
                    logger.record_tabular('Step', step)
                    self._sampler.log()
                    self._eval_sampler.log(prefix='Eval')
                    self._policy.log()
                    logger.dump_tabular(print_func=logger.info)
                    timeit.stop('total')
                    for line in str(timeit).split('\n'):
                        logger.debug(line)
                    timeit.reset()
                    timeit.start('total')

            ### save model
            if step > 0 and step % self._save_every_n_steps == 0:
                logger.info('Saving files for itr {0}'.format(save_itr))
                self._save(save_itr, self._sampler.get_recent_paths(),
                           eval_rollouts)
                save_itr += 1
                eval_rollouts = []

        self._save(save_itr, self._sampler.get_recent_paths(), eval_rollouts)
Example #17
    def inference(self):
        ### restore where we left off
        self._restore_inference()
        inference_itr = self._get_inference_itr()
        inference_step = self._get_inference_step()
        train_itr = self._get_train_itr()

        self._run_rsync()

        train_rollouts = []
        eval_rollouts = []

        self._inference_reset_sampler()

        timeit.reset()
        timeit.start('total')
        while True:
            train_step = self._get_train_step()
            if inference_step > self._total_steps:
                break

            ### sample and add to buffer
            if inference_step > self._sample_after_n_steps:
                timeit.start('sample')
                inference_step = self._inference_step(inference_step)
                timeit.stop('sample')
            else:
                inference_step += self._sampler.n_envs

            ### sample and DON'T add to buffer (for validation)
            if self._eval_sampler is not None and inference_step > 0 and inference_step % self._eval_every_n_steps == 0:
                timeit.start('eval')
                eval_rollouts_step = []
                eval_step = inference_step
                while len(eval_rollouts_step) == 0:
                    self._eval_sampler.step(eval_step, explore=False)
                    eval_rollouts_step = self._eval_sampler.get_recent_paths()
                    eval_step += 1
                eval_rollouts += eval_rollouts_step
                timeit.stop('eval')

            ### log
            if inference_step % self._log_every_n_steps == 0:
                logger.info('train itr {0:04d} inference itr {1:04d}'.format(
                    train_itr, inference_itr))
                logger.record_tabular('Train step', train_step)
                logger.record_tabular('Inference step', inference_step)
                self._sampler.log()
                if self._eval_sampler:
                    self._eval_sampler.log(prefix='Eval')
                logger.dump_tabular(print_func=logger.info)
                timeit.stop('total')
                for line in str(timeit).split('\n'):
                    logger.debug(line)
                timeit.reset()
                timeit.start('total')

            ### save rollouts / load model
            train_rollouts += self._sampler.get_recent_paths()
            if inference_step > 0 and inference_step % self._inference_save_every_n_steps == 0:
                self._inference_reset_sampler()

                ### save rollouts
                logger.debug('Saving files for itr {0}'.format(inference_itr))
                self._save_inference(inference_itr, train_rollouts,
                                     eval_rollouts)
                inference_itr += 1
                train_rollouts = []
                eval_rollouts = []

                ### load model
                with self._rsync_lock:  # to ensure the ckpt has been fully transferred over
                    new_train_itr = self._get_train_itr()
                    if train_itr < new_train_itr:
                        logger.debug(
                            'Loading policy for itr {0}'.format(new_train_itr -
                                                                1))
                        try:
                            self._policy.restore(
                                self._inference_policy_file_name(
                                    new_train_itr - 1),
                                train=False)
                            train_itr = new_train_itr
                        except:
                            logger.debug(
                                'Failed to load model for itr {0}'.format(
                                    new_train_itr - 1))
                            self._policy.restore(
                                self._inference_policy_file_name(train_itr -
                                                                 1),
                                train=False)
                            logger.debug('As backup, restored itr {0}'.format(
                                train_itr - 1))

        self._save_inference(inference_itr, self._sampler.get_recent_paths(),
                             eval_rollouts)
Example #18
    def train(self):
        ### restore where we left off
        init_inference_step = len(self._sampler)  # don't count offpolicy
        self._restore_train()
        train_itr = self._get_train_itr()
        train_step = self._get_train_step()
        inference_itr = self._get_inference_itr()

        target_updated = False

        timeit.reset()
        timeit.start('total')
        while True:
            inference_step = len(self._sampler) - init_inference_step
            if inference_step > self._total_steps or train_step > self._train_total_steps:
                break

            if inference_step >= self._learn_after_n_steps:
                ### training step
                train_step += 1
                timeit.start('batch')
                steps, observations, goals, actions, rewards, dones, _ = \
                    self._sampler.sample(self._batch_size)
                timeit.stop('batch')
                timeit.start('train')
                self._policy.train_step(train_step,
                                        steps=steps,
                                        observations=observations,
                                        goals=goals,
                                        actions=actions,
                                        rewards=rewards,
                                        dones=dones,
                                        use_target=target_updated)
                timeit.stop('train')

                ### update target network
                if train_step > self._update_target_after_n_steps and train_step % self._update_target_every_n_steps == 0:
                    self._policy.update_target()
                    target_updated = True

                ### log
                if train_step % self._log_every_n_steps == 0:
                    logger.info(
                        'train itr {0:04d} inference itr {1:04d}'.format(
                            train_itr, inference_itr))
                    logger.record_tabular('Train step', train_step)
                    logger.record_tabular('Inference step', inference_step)
                    self._policy.log()
                    logger.dump_tabular(print_func=logger.info)
                    timeit.stop('total')
                    for line in str(timeit).split('\n'):
                        logger.debug(line)
                    timeit.reset()
                    timeit.start('total')
            else:
                time.sleep(1)

            ### save model
            if train_step > 0 and train_step % self._train_save_every_n_steps == 0:
                logger.debug('Saving files for itr {0}'.format(train_itr))
                self._save_train(train_itr)
                train_itr += 1

            ### reset model
            if train_step > 0 and self._train_reset_every_n_steps is not None and \
                                    train_step % self._train_reset_every_n_steps == 0:
                logger.debug('Resetting model')
                self._policy.reset_weights()

            ### load data
            inference_itr = self._train_load_data(inference_itr)
Example #19
    def __init__(self, params={}):
        params.setdefault('dt', 0.25)
        params.setdefault('horizon',
                          int(5. * 60. / params['dt']))  # 5 minutes worth
        params.setdefault('ros_namespace', '/crazyflie/')
        params.setdefault('obs_shape', (72, 96, 1))
        params.setdefault('yaw_limits', [-120, 120])  #default yaw rate range
        params.setdefault('fixed_alt', 0.4)
        params.setdefault('fixed_velocity_range', [0.4, 0.4])
        params.setdefault('press_enter_on_reset', False)
        params.setdefault('prompt_save_rollout_on_coll', False)
        params.setdefault('enable_adjustment_on_start', True)
        params.setdefault('use_joy_commands', True)
        params.setdefault('joy_start_btn', 1)  #A
        params.setdefault('joy_stop_btn', 2)  #B
        params.setdefault('joy_coll_stop_btn', 0)  #X
        params.setdefault('joy_trash_rollout_btn', 3)  # Y
        params.setdefault('joy_topic', '/joy')
        params.setdefault('collision_reward', 1)
        params.setdefault('collision_reward_only', True)

        self._obs_shape = params['obs_shape']
        self._yaw_limits = params['yaw_limits']
        self._fixed_alt = params['fixed_alt']
        self._collision_reward = params['collision_reward']
        self._collision_reward_only = params['collision_reward_only']
        self._fixed_velocity_range = params['fixed_velocity_range']
        self._fixed_velocity = np.random.uniform(self._fixed_velocity_range[0],
                                                 self._fixed_velocity_range[1])
        self._dt = params['dt']
        self.horizon = params['horizon']

        # start stop and pause
        self._enable_adjustment_on_start = params['enable_adjustment_on_start']
        self._use_joy_commands = params['use_joy_commands']
        self._joy_topic = params['joy_topic']
        self._joy_stop_btn = params['joy_stop_btn']
        self._joy_coll_stop_btn = params['joy_coll_stop_btn']
        self._joy_start_btn = params['joy_start_btn']
        self._joy_trash_rollout_btn = params['joy_trash_rollout_btn']
        self._press_enter_on_reset = params['press_enter_on_reset']
        self._prompt_save_rollout_on_coll = params[
            'prompt_save_rollout_on_coll']
        self._start_pressed = False
        self._stop_pressed = False
        self._trash_rollout = False
        self._coll_stop_pressed = False
        self._curr_joy = None
        self._curr_motion = crazyflie.msg.CFMotion()

        self._setup_spec()
        assert (self.observation_im_space.shape[-1] == 1
                or self.observation_im_space.shape[-1] == 3)
        self.spec = EnvSpec(observation_im_space=self.observation_im_space,
                            action_space=self.action_space,
                            action_selection_space=self.action_selection_space,
                            observation_vec_spec=self.observation_vec_spec,
                            action_spec=self.action_spec,
                            action_selection_spec=self.action_selection_spec,
                            goal_spec=self.goal_spec)

        self._last_step_time = None
        self._is_collision = False

        rospy.init_node('CrazyflieEnv', anonymous=True)
        time.sleep(0.5)

        self._ros_namespace = params['ros_namespace']
        self._ros_topics_and_types = dict([
            ('cf/0/image', sensor_msgs.msg.CompressedImage),
            ('cf/0/data', crazyflie.msg.CFData),
            ('cf/0/coll', std_msgs.msg.Bool),
            ('cf/0/motion', crazyflie.msg.CFMotion)
        ])
        self._ros_msgs = dict()
        self._ros_msg_times = dict()
        for topic, type in self._ros_topics_and_types.items():
            rospy.Subscriber(topic, type, self.ros_msg_update, (topic, ))

        self._ros_motion_pub = rospy.Publisher("/cf/0/motion",
                                               crazyflie.msg.CFMotion,
                                               queue_size=10)
        self._ros_command_pub = rospy.Publisher("/cf/0/command",
                                                crazyflie.msg.CFCommand,
                                                queue_size=10)
        self._ros_stop_pub = rospy.Publisher('/joystop',
                                             crazyflie.msg.JoyStop,
                                             queue_size=10)
        if self._use_joy_commands:
            logger.debug("Environment using joystick commands")
            self._ros_joy_sub = rospy.Subscriber(self._joy_topic,
                                                 sensor_msgs.msg.Joy,
                                                 self._joy_cb)

        # I don't think this is needed
        # self._ros_pid_enable_pub = rospy.Publisher(self._ros_namespace + 'pid/enable', std_msgs.msg.Empty,
        #                                            queue_size=10)
        # self._ros_pid_disable_pub = rospy.Publisher(self._ros_namespace + 'pid/disable', std_msgs.msg.Empty,
        #                                             queue_size=10)

        self._ros_rolloutbag = RolloutRosbag()
        self._t = 0

        self.suppress_output = False
        self.resetting = False
        self._send_override = False  # set true only when resetting but still wanting to send background thread motion commands
        threading.Thread(target=self._background_thread).start()

        time.sleep(1.0)  #waiting for some messages before resetting

        self.delete_this_variable = 0
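Because every parameter above is registered with setdefault(), the constructor can be called with only a partial params dict. Below is a hypothetical construction sketch; `CrazyflieEnv` is the assumed class name (taken from the rospy.init_node call), and the overridden values are illustrative only.

# Hypothetical sketch: constructing the env with a few overridden parameters;
# all unspecified keys fall back to the setdefault() values above.
params = {
    'dt': 0.25,
    'obs_shape': (72, 96, 1),
    'press_enter_on_reset': True,
    'use_joy_commands': False,
}
env = CrazyflieEnv(params=params)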
Example #20
    def inference(self):
        ### restore where we left off
        self._restore_inference()
        inference_itr = self._get_inference_itr()
        inference_step = self._get_inference_step()
        train_itr = self._get_train_itr()

        self._run_rsync()

        assert (self._eval_sampler is None)  # TODO: temporary
        train_rollouts = []
        eval_rollouts = []

        self._reset_sampler()

        timeit.reset()
        timeit.start('total')
        while True:
            train_step = self._get_train_step()
            if inference_step > self._total_steps:
                break

            ### sample and add to buffer
            if inference_step > self._sample_after_n_steps:
                timeit.start('sample')
                try:
                    self._sampler.step(
                        inference_step,
                        take_random_actions=(
                            inference_step <= self._learn_after_n_steps
                            or inference_step <= self._onpolicy_after_n_steps),
                        explore=True)
                    inference_step += self._sampler.n_envs
                except Exception as e:
                    logger.warn('Sampler exception {0}'.format(str(e)))
                    trashed_steps = self._sampler.trash_current_rollouts()
                    inference_step -= trashed_steps
                    logger.warn('Trashed {0} steps'.format(trashed_steps))
                    while not self._env.ros_is_good(
                            print=False):  # TODO hard coded
                        time.sleep(0.25)
                    self._reset_sampler()
                    logger.warn('Continuing...')
                timeit.stop('sample')
            else:
                inference_step += self._sampler.n_envs

            ### sample and DON'T add to buffer (for validation)
            if self._eval_sampler is not None and inference_step > 0 and inference_step % self._eval_every_n_steps == 0:
                timeit.start('eval')
                eval_rollouts_step = []
                eval_step = inference_step
                while len(eval_rollouts_step) == 0:
                    self._eval_sampler.step(eval_step, explore=False)
                    eval_rollouts_step = self._eval_sampler.get_recent_paths()
                    eval_step += 1
                eval_rollouts += eval_rollouts_step
                timeit.stop('eval')

            ### log
            if inference_step % self._log_every_n_steps == 0:
                logger.info('train itr {0:04d} inference itr {1:04d}'.format(
                    train_itr, inference_itr))
                logger.record_tabular('Train step', train_step)
                logger.record_tabular('Inference step', inference_step)
                self._sampler.log()
                if self._eval_sampler:
                    self._eval_sampler.log(prefix='Eval')
                logger.dump_tabular(print_func=logger.info)
                timeit.stop('total')
                for line in str(timeit).split('\n'):
                    logger.debug(line)
                timeit.reset()
                timeit.start('total')

            ### save rollouts / load model
            train_rollouts += self._sampler.get_recent_paths()
            if inference_step > 0 and inference_step % self._inference_save_every_n_steps == 0 and \
                            len(train_rollouts) > 0:
                response = input('Keep rollouts?')
                if response != 'y':
                    train_rollouts = []
                    continue

                ### reset to stop rollout
                self._sampler.reset()

                ### save rollouts
                logger.debug('Saving files for itr {0}'.format(inference_itr))
                self._save_inference(inference_itr, train_rollouts,
                                     eval_rollouts)
                inference_itr += 1
                train_rollouts = []
                eval_rollouts = []

                ### load model
                with self._rsync_lock:  # to ensure the ckpt has been fully transferred over
                    new_train_itr = self._get_train_itr()
                    if train_itr < new_train_itr:
                        logger.debug(
                            'Loading policy for itr {0}'.format(new_train_itr -
                                                                1))
                        try:
                            self._policy.restore(
                                self._inference_policy_file_name(
                                    new_train_itr - 1),
                                train=False)
                            train_itr = new_train_itr
                        except:
                            logger.debug(
                                'Failed to load model for itr {0}'.format(
                                    new_train_itr - 1))
                            self._policy.restore(
                                self._inference_policy_file_name(train_itr -
                                                                 1),
                                train=False)
                            logger.debug('As backup, restored itr {0}'.format(
                                train_itr - 1))

        self._save_inference(inference_itr, self._sampler.get_recent_paths(),
                             eval_rollouts)
Example #21
    def train(self):
        ### restore where we left off
        self._restore_train()
        train_itr = self._get_train_itr()
        train_step = self._get_train_step()
        inference_itr = self._get_inference_itr()
        init_inference_step = len(self._sampler)

        target_updated = False

        timeit.reset()
        timeit.start('total')
        while True:
            inference_step = len(self._sampler) - init_inference_step
            if inference_step > self._total_steps:
                break

            if inference_step >= self._learn_after_n_steps:
                ### update preprocess
                if train_step % self._update_preprocess_every_n_steps == 0:
                    self._policy.update_preprocess(self._sampler.statistics)

                ### training step
                train_step += 1
                timeit.start('batch')
                batch = self._sampler.sample(self._batch_size)
                timeit.stop('batch')
                timeit.start('train')
                self._policy.train_step(train_step,
                                        *batch,
                                        use_target=target_updated)
                timeit.stop('train')

                ### update target network
                if train_step > self._update_target_after_n_steps and train_step % self._update_target_every_n_steps == 0:
                    self._policy.update_target()
                    target_updated = True

                ### log
                if train_step % self._log_every_n_steps == 0:
                    logger.info(
                        'train itr {0:04d} inference itr {1:04d}'.format(
                            train_itr, inference_itr))
                    logger.record_tabular('Train step', train_step)
                    logger.record_tabular('Inference step', inference_step)
                    self._policy.log()
                    logger.dump_tabular(print_func=logger.info)
                    timeit.stop('total')
                    for line in str(timeit).split('\n'):
                        logger.debug(line)
                    timeit.reset()
                    timeit.start('total')
            else:
                time.sleep(1)

            ### save model
            if train_step > 0 and train_step % self._train_save_every_n_steps == 0:
                logger.debug('Saving files for itr {0}'.format(train_itr))
                self._save_train(train_itr)
                train_itr += 1

            ### reset model
            if train_step > 0 and self._train_reset_every_n_steps is not None and \
                                    train_step % self._train_reset_every_n_steps == 0:
                logger.debug('Resetting model')
                self._policy.reset_weights()

            ### load data
            new_inference_itr = self._get_inference_itr()
            if inference_itr < new_inference_itr:
                for i in range(inference_itr, new_inference_itr):
                    try:
                        logger.debug('Loading files for itr {0}'.format(i))
                        self._sampler.add_rollouts(
                            [self._train_rollouts_file_name(i)])
                        inference_itr = i + 1
                    except:
                        logger.debug(
                            'Failed to load files for itr {0}'.format(i))
Example #22
    def __init__(self,
                 exp_name,
                 env_params,
                 env_eval_params,
                 rp_params,
                 rp_eval_params,
                 labeller_params,
                 policy_params,
                 alg_params,
                 log_level='info',
                 log_fname='log.txt',
                 seed=None,
                 is_continue=False,
                 params_txt=None):
        ### create file manager and setup logger
        self._fm = FileManager(exp_name,
                               is_continue=is_continue,
                               log_level=log_level,
                               log_fname=log_fname)

        logger.debug('Git current')
        logger.debug(
            subprocess.check_output('git status | head -n 1',
                                    shell=True).decode('utf-8').strip())
        logger.debug(
            subprocess.check_output('git log -n 1| head -n 1',
                                    shell=True).decode('utf-8').strip())

        logger.debug('Seed {0}'.format(seed))
        utils.set_seed(seed)

        ### copy params for posterity
        if params_txt:
            with open(self._fm.params_fname, 'w') as f:
                f.write(params_txt)

        ### create environments
        self._env = env_params['class'](params=env_params['kwargs'])
        self._env_eval = env_eval_params['class'](
            params=env_eval_params['kwargs']) if env_eval_params else self._env

        ### create policy
        self._policy = policy_params['class'](
            env_spec=self._env.spec,
            exploration_strategies=alg_params['exploration_strategies'],
            **policy_params['kwargs'])

        ### create labeller
        self._labeller = labeller_params['class'](
            env_spec=self._env.spec,
            policy=self._policy,
            **labeller_params['kwargs']) if labeller_params['class'] else None

        ### create replay pools
        self._replay_pool = rp_params['class'](
            env_spec=self._env.spec,
            obs_history_len=self._policy.obs_history_len,
            N=self._policy.N,
            labeller=self._labeller,
            **rp_params['kwargs'])
        self._replay_pool_eval = rp_eval_params['class'](
            env_spec=self._env_eval.spec if self._env_eval else self._env.spec,
            obs_history_len=self._policy.obs_history_len,
            N=self._policy.N,
            labeller=self._labeller,
            **rp_eval_params['kwargs']) if rp_eval_params else None

        ### create samplers
        self._sampler = Sampler(env=self._env,
                                policy=self._policy,
                                replay_pool=self._replay_pool)
        self._sampler_eval = Sampler(
            env=self._env_eval,
            policy=self._policy,
            replay_pool=self._replay_pool_eval
        ) if self._env_eval is not None and self._replay_pool_eval is not None else None

        ### create algorithm
        self._total_steps = int(alg_params['total_steps'])
        self._sample_after_n_steps = int(alg_params['sample_after_n_steps'])
        self._onpolicy_after_n_steps = int(
            alg_params['onpolicy_after_n_steps'])
        self._learn_after_n_steps = int(alg_params['learn_after_n_steps'])
        self._train_every_n_steps = alg_params['train_every_n_steps']
        self._eval_every_n_steps = int(alg_params['eval_every_n_steps'])
        self._rollouts_per_eval = int(alg_params.get('rollouts_per_eval', 1))
        self._save_every_n_steps = int(alg_params['save_every_n_steps'])
        self._save_async = alg_params.get('save_async', False)
        self._update_target_after_n_steps = int(
            alg_params['update_target_after_n_steps'])
        self._update_target_every_n_steps = int(
            alg_params['update_target_every_n_steps'])
        self._log_every_n_steps = int(alg_params['log_every_n_steps'])
        self._batch_size = alg_params['batch_size']
        if alg_params['offpolicy'] is not None:
            self._add_offpolicy(alg_params['offpolicy'],
                                max_to_add=alg_params['num_offpolicy'])
        if alg_params['init_inference_ckpt'] is not None:
            self._policy.restore(alg_params['init_inference_ckpt'],
                                 train=False)
        if alg_params['init_train_ckpt'] is not None:
            self._policy.restore(alg_params['init_train_ckpt'], train=True)
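The constructor above reads a fixed set of keys from alg_params. The dictionary below is a hypothetical sketch covering those keys; the numeric values are illustrative assumptions, not values taken from the project.

# Hypothetical alg_params sketch covering the keys read in __init__ above;
# all values are placeholders for illustration.
alg_params = {
    'total_steps': int(1e6),
    'sample_after_n_steps': 0,
    'onpolicy_after_n_steps': int(1e4),
    'learn_after_n_steps': int(1e3),
    'train_every_n_steps': 1,
    'eval_every_n_steps': int(1e4),
    'rollouts_per_eval': 1,
    'save_every_n_steps': int(1e5),
    'save_async': False,
    'update_target_after_n_steps': int(1e3),
    'update_target_every_n_steps': int(1e3),
    'log_every_n_steps': int(1e3),
    'batch_size': 32,
    'exploration_strategies': [],
    'offpolicy': None,
    'num_offpolicy': None,
    'init_inference_ckpt': None,
    'init_train_ckpt': None,
}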