Esempio n. 1
0
    def inference(self):
        """Run the online inference (data collection) loop.

        Restores the previous inference state, then repeatedly steps the
        sampler until ``inference_step`` exceeds ``self._total_steps``.
        Along the way it:

        * takes random actions while ``inference_step`` is at most
          ``_learn_after_n_steps`` or ``_onpolicy_after_n_steps``;
        * on a sampler exception, trashes the current rollouts, subtracts
          their steps, waits until ``self._env.ros_is_good`` and resets the
          sampler;
        * every ``_log_every_n_steps`` steps dumps tabular logs and timings;
        * every ``_inference_save_every_n_steps`` steps asks the operator on
          stdin whether to keep the collected rollouts. Kept rollouts are
          saved with ``_save_inference``. If the trainer has produced a newer
          train itr, that policy checkpoint is loaded, falling back to the
          checkpoint of ``train_itr - 1`` if loading fails.

        When the loop ends, every rollout collected since the last save is
        saved as one final inference iteration. That includes rollouts
        accumulated in ``train_rollouts`` and any still held by the sampler.
        """
        ### restore where we left off
        self._restore_inference()
        inference_itr = self._get_inference_itr()
        inference_step = self._get_inference_step()
        train_itr = self._get_train_itr()

        self._run_rsync()

        assert (self._eval_sampler is None)  # TODO: temporary
        train_rollouts = []
        eval_rollouts = []

        self._reset_sampler()

        timeit.reset()
        timeit.start('total')
        while True:
            train_step = self._get_train_step()
            if inference_step > self._total_steps:
                break

            ### sample and add to buffer
            if inference_step > self._sample_after_n_steps:
                timeit.start('sample')
                try:
                    self._sampler.step(
                        inference_step,
                        take_random_actions=(
                            inference_step <= self._learn_after_n_steps
                            or inference_step <= self._onpolicy_after_n_steps),
                        explore=True)
                    inference_step += self._sampler.n_envs
                except Exception as e:
                    # Deliberate best effort: a sampler failure must not end
                    # data collection. Discard the partial rollouts, un-count
                    # their steps, wait until ROS reports healthy, restart.
                    logger.warn('Sampler exception {0}'.format(str(e)))
                    trashed_steps = self._sampler.trash_current_rollouts()
                    inference_step -= trashed_steps
                    logger.warn('Trashed {0} steps'.format(trashed_steps))
                    while not self._env.ros_is_good(
                            print=False):  # TODO hard coded
                        time.sleep(0.25)
                    self._reset_sampler()
                    logger.warn('Continuing...')
                timeit.stop('sample')
            else:
                # before sampling starts, only the step counter advances
                inference_step += self._sampler.n_envs

            ### sample and DON'T add to buffer (for validation)
            if self._eval_sampler is not None and inference_step > 0 and inference_step % self._eval_every_n_steps == 0:
                timeit.start('eval')
                eval_rollouts_step = []
                eval_step = inference_step
                # keep stepping until the eval sampler finishes a rollout
                while len(eval_rollouts_step) == 0:
                    self._eval_sampler.step(eval_step, explore=False)
                    eval_rollouts_step = self._eval_sampler.get_recent_paths()
                    eval_step += 1
                eval_rollouts += eval_rollouts_step
                timeit.stop('eval')

            ### log
            if inference_step % self._log_every_n_steps == 0:
                logger.info('train itr {0:04d} inference itr {1:04d}'.format(
                    train_itr, inference_itr))
                logger.record_tabular('Train step', train_step)
                logger.record_tabular('Inference step', inference_step)
                self._sampler.log()
                if self._eval_sampler:
                    self._eval_sampler.log(prefix='Eval')
                logger.dump_tabular(print_func=logger.info)
                timeit.stop('total')
                for line in str(timeit).split('\n'):
                    logger.debug(line)
                timeit.reset()
                timeit.start('total')

            ### save rollouts / load model
            train_rollouts += self._sampler.get_recent_paths()
            if (inference_step > 0
                    and inference_step % self._inference_save_every_n_steps == 0
                    and train_rollouts):
                response = input('Keep rollouts?')
                if response != 'y':
                    # operator rejected them: drop and keep sampling
                    train_rollouts = []
                    continue

                ### reset to stop rollout
                self._sampler.reset()

                ### save rollouts
                logger.debug('Saving files for itr {0}'.format(inference_itr))
                self._save_inference(inference_itr, train_rollouts,
                                     eval_rollouts)
                inference_itr += 1
                train_rollouts = []
                eval_rollouts = []

                ### load model
                with self._rsync_lock:  # to ensure the ckpt has been fully transferred over
                    new_train_itr = self._get_train_itr()
                    if train_itr < new_train_itr:
                        logger.debug('Loading policy for itr {0}'.format(
                            new_train_itr - 1))
                        try:
                            self._policy.restore(
                                self._inference_policy_file_name(
                                    new_train_itr - 1),
                                train=False)
                            train_itr = new_train_itr
                        except Exception:
                            # Catch Exception (not a bare except) so Ctrl-C /
                            # SystemExit still stop inference. On a failed
                            # load, restore the checkpoint of train_itr - 1.
                            logger.debug(
                                'Failed to load model for itr {0}'.format(
                                    new_train_itr - 1))
                            self._policy.restore(
                                self._inference_policy_file_name(train_itr -
                                                                 1),
                                train=False)
                            logger.debug('As backup, restored itr {0}'.format(
                                train_itr - 1))

        ### save everything collected since the last save: the rollouts already
        ### accumulated in train_rollouts plus any still held by the sampler
        train_rollouts += self._sampler.get_recent_paths()
        self._save_inference(inference_itr, train_rollouts, eval_rollouts)
Esempio n. 2
0
    def _add_rosbags(self, rosbag_filenames):
        """Convert rosbags to rollouts and add them to the sampler.

        Each bag is read into per-topic message lists. A bag is skipped with
        a warning if any of the following holds:

        * it cannot be opened;
        * it has no ``collision/all`` messages;
        * it has more than one collision message;
        * it has a collision that is not its last ``collision/all`` message;
        * it has fewer than two ``encoder/both`` readings with magnitude
          above 1e-4 (the car never moved).

        Otherwise the env is fed the t-th message of every topic, and the
        sampler is stepped offline with the ``cmd/steer`` / ``cmd/vel``
        messages at index t-1.

        Args:
            rosbag_filenames: iterable of rosbag paths. Every path is appended
                to ``self._added_rosbag_filenames``, including skipped ones.
        """
        timesteps_kept = 0
        timesteps_total = 0
        for fname in rosbag_filenames:
            self._added_rosbag_filenames.append(fname)
            bag_name = os.path.basename(fname)

            ### read bag file
            try:
                bag = rosbag.Bag(fname, 'r')
            except Exception:
                logger.warn('{0}: could not open'.format(bag_name))
                continue
            d_bag = defaultdict(list)
            try:
                for topic, msg, _ in bag.read_messages():
                    d_bag[topic].append(msg)
            finally:
                # close the bag even if reading fails part-way
                bag.close()

            # One timestep per 'mode' message. Clamp at 0 so a bag without
            # 'mode' messages does not subtract from the totals.
            bag_length = len(d_bag['mode'])
            num_steps = max(bag_length - 1, 0)
            timesteps_total += num_steps

            ### trim to whenever collision occurs
            colls = np.array([msg.data for msg in d_bag['collision/all']])
            if len(colls) == 0:
                logger.warn('{0}: empty bag'.format(bag_name))
                continue
            if colls.max() > 0:
                if colls.sum() > 1:
                    logger.warn('{0}: has multiple collisions'.format(
                        bag_name))
                    continue
                if colls[-1] != 1:
                    logger.warn(
                        '{0}: has collision, but does not end in collision'.
                        format(bag_name))
                    continue

            ### make sure it moved at least a little bit
            encoders = np.array([msg.data for msg in d_bag['encoder/both']])
            if (abs(encoders) > 1e-4).sum() < 2:
                logger.warn('{0}: car never moved'.format(bag_name))
                continue

            ### update env and step
            def update_env(t):
                # Feed the t-th message of every topic to the env. A topic
                # that cannot be applied (e.g. it has fewer than t+1
                # messages) is reported and skipped, rather than blocking
                # the run in an interactive IPython shell.
                for key, msgs in d_bag.items():
                    try:
                        self._env.ros_msg_update(msgs[t], [key])
                    except Exception as e:
                        logger.warn(
                            '{0}: could not update env with {1} at t={2}: {3}'
                            .format(bag_name, key, t, e))

            update_env(0)
            if len(self._sampler) == 0:
                logger.warn('Resetting!')
                self._sampler.reset(offline=True)

            for t in range(1, bag_length):
                update_env(t)
                # action for step t = command messages at index t-1
                action = np.array([
                    d_bag['cmd/steer'][t - 1].data,
                    d_bag['cmd/vel'][t - 1].data
                ])
                self._sampler.step(len(self._sampler),
                                   actions=[action],
                                   offline=True)

            if not self._sampler.is_done_nexts:
                logger.warn(
                    '{0}: did not end in done, manually resetting'.format(
                        bag_name))
                self._sampler.reset(offline=True)

            timesteps_kept += num_steps

        # avoid ZeroDivisionError when no bag contributed any timestep
        kept_fraction = (timesteps_kept / float(timesteps_total)
                         if timesteps_total > 0 else 0.)
        logger.info('Adding {0:d} timesteps ({1:.2f} kept)'.format(
            timesteps_kept, kept_fraction))
Esempio n. 3
0
    def _add_rosbags(self, sampler, rosbag_filenames):
        """Convert recorded crazyflie rosbags into offline rollouts for ``sampler``.

        Messages of each bag are read until the first ``joystop`` message
        with ``stop == 0``. A bag is skipped with a warning if any of the
        following holds:

        * it cannot be opened;
        * it contains a ``joystop`` message with ``stop == 1``, which flags
          incorrect collision detection;
        * it has no ``cf/0/data`` messages;
        * its first ``cf/0/coll`` collision comes within the first 10
          timesteps.

        Otherwise the env is replayed up to and including the first
        collision, and ``sampler`` is stepped offline with the recorded
        ``cf/0/motion`` command (x, y, yaw, dz) of the previous timestep.

        Args:
            sampler: sampler to fill. It is used through ``env``, ``len()``,
                ``reset``, ``step`` and ``is_done_nexts``.
            rosbag_filenames: iterable of bag paths. Each is appended to
                ``self._added_rosbag_filenames``, even if it is later skipped.
        """
        # topics that must all have a message at a timestep for it to be usable
        synced_topics = ('cf/0/data', 'cf/0/motion', 'cf/0/coll', 'cf/0/image')

        timesteps_kept = 0
        timesteps_total = 0
        for fname in rosbag_filenames:
            self._added_rosbag_filenames.append(fname)
            bag_name = os.path.basename(fname)

            ### read bag file
            try:
                bag = rosbag.Bag(fname, 'r', compression='bz2')
            except Exception as e:
                logger.warn('{0}: could not open'.format(bag_name))
                print(e)
                continue

            d_bag = defaultdict(list)
            bad_collision_detection = False
            try:
                for topic, msg, _ in bag.read_messages():
                    if topic == 'joystop' and msg.stop == 1:
                        # The whole bag is unusable. Previously `continue`
                        # only skipped this one message.
                        bad_collision_detection = True
                        break
                    elif topic == 'joystop' and msg.stop == 0:
                        break
                    else:
                        d_bag[topic].append(msg)
            finally:
                # close the bag even if reading fails part-way
                bag.close()

            if bad_collision_detection:
                logger.warn(
                    '{0}: has incorrect collision detection. Skipping.'.
                    format(bag_name))
                continue

            if len(d_bag['cf/0/data']) == 0:
                logger.warn('{0}: has no entries. Skipping.'.format(bag_name))
                continue

            timesteps_total += len(d_bag['cf/0/data']) - 1

            ### trim to the first collision
            # Note: indexing the defaultdict inserts empty lists for missing
            # topics, exactly as the original min([...]) expression did.
            bag_length = min(len(d_bag[topic]) for topic in synced_topics)
            coll_msgs = d_bag['cf/0/coll']
            # number of timesteps up to and including the first collision
            num_parsed = next(
                (t + 1 for t in range(bag_length) if coll_msgs[t].data == 1),
                bag_length)

            if num_parsed < 10:
                logger.warn(
                    '{0}: had a collision too early: at timestep {1}. Skipping'
                    .format(bag_name, num_parsed))
                continue

            logger.info('Added rosbag: {0}, with {1}/{2} timesteps'.format(
                bag_name, num_parsed, len(coll_msgs)))
            timesteps_kept += num_parsed - 1

            ### update env and step
            def update_env(t):
                # Feed the t-th message of every topic to the env. A topic
                # that cannot be applied is reported and skipped, rather than
                # blocking the run in an interactive IPython shell.
                for key, msgs in d_bag.items():
                    try:
                        sampler.env.ros_msg_update(msgs[t], [key])
                    except Exception as e:
                        logger.warn(
                            '{0}: issue updating env with {1} at t={2}: {3}'
                            .format(bag_name, key, t, e))

            # Makes sure no statements are printed while replaying.
            # NOTE(review): this silences self._env while the replay goes
            # through sampler.env. Confirm these are the same env.
            self._env.suppress_output = True
            try:
                update_env(0)
                if len(sampler) == 0:
                    logger.warn('Resetting!')
                    sampler.reset(offline=True)

                # Replay only the pre-collision segment that was counted and
                # logged as added above.
                for t in range(1, num_parsed):
                    update_env(t)
                    motion = d_bag['cf/0/motion'][t - 1]
                    action = np.array(
                        [motion.x, motion.y, motion.yaw, motion.dz])
                    sampler.step(len(sampler), action=action, offline=True)
            finally:
                self._env.suppress_output = False

            if not sampler.is_done_nexts:
                logger.warn(
                    '{0}: did not end in done, manually resetting'.format(
                        bag_name))
                sampler.reset(offline=True)

        # avoid ZeroDivisionError when no bag contributed any timestep
        kept_fraction = (timesteps_kept / float(timesteps_total)
                         if timesteps_total > 0 else 0.)
        logger.info('Adding {0:d} timesteps ({1:.2f} kept)'.format(
            timesteps_kept, kept_fraction))
Esempio n. 4
0
    def __init__(self, params=None):
        """Set up the real-world RC car environment.

        Args:
            params: configuration dict. Missing optional keys are filled in
                place with these defaults:

                * ``dt``: 0.25
                * ``horizon``: 5 minutes worth of ``dt`` steps
                * ``ros_namespace``: '/rccar/'
                * ``obs_shape``: (36, 64, 1)
                * ``steer_limits``: [-0.9, 0.9]
                * ``speed_limits``: [0.2, 0.2]
                * ``backup_motor``: -0.22
                * ``backup_duration``: 1.6
                * ``backup_steer_range``: (-0.8, 0.8)
                * ``press_enter_on_reset``: False

                ``collision_reward`` and ``collision_reward_only`` have no
                default and must be supplied.

        If ROS was not imported, only the non-ROS state is initialised and a
        warning is logged. Otherwise a ROS node is started, a subscriber is
        registered for every topic under ``ros_namespace``, and the command
        publishers are created.
        """
        # A fresh dict per call: with a shared `{}` default, the setdefault
        # calls below would mutate one dict shared across all instances.
        if params is None:
            params = {}
        params.setdefault('dt', 0.25)
        params.setdefault('horizon',
                          int(5. * 60. / params['dt']))  # 5 minutes worth
        params.setdefault('ros_namespace', '/rccar/')
        params.setdefault('obs_shape', (36, 64, 1))
        params.setdefault('steer_limits', [-0.9, 0.9])
        params.setdefault('speed_limits', [0.2, 0.2])
        params.setdefault('backup_motor', -0.22)
        params.setdefault('backup_duration', 1.6)
        params.setdefault('backup_steer_range', (-0.8, 0.8))
        params.setdefault('press_enter_on_reset', False)

        self._use_vel = True
        self._obs_shape = params['obs_shape']
        self._steer_limits = params['steer_limits']
        self._speed_limits = params['speed_limits']
        # speed is fixed when the limits collapse to a single value and the
        # car is velocity-controlled
        self._fixed_speed = (self._speed_limits[0] == self._speed_limits[1]
                             and self._use_vel)
        self._collision_reward = params['collision_reward']
        self._collision_reward_only = params['collision_reward_only']

        self._dt = params['dt']
        self.horizon = params['horizon']

        self._setup_spec()
        # image observations must be single-channel or 3-channel
        assert (self.observation_im_space.shape[-1] == 1
                or self.observation_im_space.shape[-1] == 3)
        self.spec = EnvSpec(observation_im_space=self.observation_im_space,
                            action_space=self.action_space,
                            action_selection_space=self.action_selection_space,
                            observation_vec_spec=self.observation_vec_spec,
                            action_spec=self.action_spec,
                            action_selection_spec=self.action_selection_spec,
                            goal_spec=self.goal_spec)

        self._last_step_time = None
        self._is_collision = False
        self._backup_motor = params['backup_motor']
        self._backup_duration = params['backup_duration']
        self._backup_steer_range = params['backup_steer_range']
        self._press_enter_on_reset = params['press_enter_on_reset']

        ### ROS
        if not ROS_IMPORTED:
            logger.warn('ROS not imported')
            return

        rospy.init_node('RWrccarEnv', anonymous=True)
        rospy.sleep(1)  # presumably gives the node time to register

        self._ros_namespace = params['ros_namespace']
        self._ros_topics_and_types = {
            'camera/image_raw/compressed': sensor_msgs.msg.CompressedImage,
            'mode': std_msgs.msg.Int32,
            'steer': std_msgs.msg.Float32,
            'motor': std_msgs.msg.Float32,
            'encoder/left': std_msgs.msg.Float32,
            'encoder/right': std_msgs.msg.Float32,
            'encoder/both': std_msgs.msg.Float32,
            'orientation/quat': geometry_msgs.msg.Quaternion,
            'orientation/rpy': geometry_msgs.msg.Vector3,
            'imu': geometry_msgs.msg.Accel,
            'collision/all': std_msgs.msg.Int32,
            'collision/flip': std_msgs.msg.Int32,
            'collision/jolt': std_msgs.msg.Int32,
            'collision/stuck': std_msgs.msg.Int32,
            'collision/bumper': std_msgs.msg.Int32,
            'cmd/steer': std_msgs.msg.Float32,
            'cmd/motor': std_msgs.msg.Float32,
            'cmd/vel': std_msgs.msg.Float32,
        }
        self._ros_msgs = dict()
        self._ros_msg_times = dict()
        # every topic is routed to ros_msg_update, tagged with its topic name
        for topic, msg_type in self._ros_topics_and_types.items():
            rospy.Subscriber(self._ros_namespace + topic, msg_type,
                             self.ros_msg_update, (topic, ))

        def make_publisher(topic, msg_type):
            # all command publishers live under the env's ROS namespace
            return rospy.Publisher(self._ros_namespace + topic,
                                   msg_type,
                                   queue_size=10)

        self._ros_steer_pub = make_publisher('cmd/steer', std_msgs.msg.Float32)
        self._ros_vel_pub = make_publisher('cmd/vel', std_msgs.msg.Float32)
        self._ros_motor_pub = make_publisher('cmd/motor', std_msgs.msg.Float32)
        self._ros_pid_enable_pub = make_publisher('pid/enable',
                                                  std_msgs.msg.Empty)
        self._ros_pid_disable_pub = make_publisher('pid/disable',
                                                   std_msgs.msg.Empty)

        self._ros_rolloutbag = RolloutRosbag()
        self._t = 0