Ejemplo n.º 1
0
    def run(self):
        """Main inference loop: step the environment, periodically evaluate,
        log, and save, until self._total_steps is reached.
        """
        # Resume the step counter and save iteration from any prior state.
        start_step, save_itr = self._run_init_inference()
        last_eval_step = 0

        step = start_step
        while step < self._total_steps:
            step += 1

            # The sampler may advance the counter by more than one step,
            # so take the returned value rather than assuming +1.
            if step >= self._sample_after_n_steps:
                step = self._run_env_step(step)

            # Evaluate at most every _eval_every_n_steps, and only once a
            # complete rollout has been stored in the replay pool.
            if step - last_eval_step >= self._eval_every_n_steps and self._replay_pool.finished_storing_rollout:
                self._run_env_eval(step, do_sampler_step=True, calculate_holdout=False)
                last_eval_step = step

            if step % self._log_every_n_steps == 0:
                self._run_log(step)

            if step % self._save_every_n_steps == 0:
                logger.info('Saving files for itr {0}'.format(save_itr))
                self._save_inference(save_itr,
                                     self._replay_pool.get_recent_rollouts(),
                                     self._replay_pool_eval.get_recent_rollouts())
                save_itr += 1

        # Final save once the step budget is exhausted.
        if step >= self._total_steps:
            logger.info('Saving files for itr {0}'.format(save_itr))
            self._save_inference(save_itr,
                                 self._replay_pool.get_recent_rollouts(),
                                 self._replay_pool_eval.get_recent_rollouts())
Ejemplo n.º 2
0
    def _restore_rollouts(self, train_or_eval):
        if train_or_eval == 'train':
            rp = self._replay_pool
            fname_func = self._fm.train_rollouts_fname
        elif train_or_eval == 'eval':
            rp = self._replay_pool_eval
            fname_func = self._fm.eval_rollouts_fname
        else:
            raise ValueError(
                'train_or_eval must be train or eval, not {0}'.format(
                    train_or_eval))

        if rp is None:
            return

        itr = 0
        rollout_filenames = []
        while True:
            fname = fname_func(itr)
            if not os.path.exists(fname):
                break

            rollout_filenames.append(fname)
            itr += 1

        logger.info('Restoring {0} iterations of {1} rollouts....'.format(
            itr, train_or_eval))
        rp.store_rollouts(rollout_filenames)
        logger.info('Done restoring rollouts!')
Ejemplo n.º 3
0
 def _add_offpolicy(self, folders, max_to_add):
     """Load offpolicy training rollouts from each folder into the sampler.

     :param folders: folders to scan for files containing 'train_rollouts.pkl'
     :param max_to_add: cap on how many samples the sampler may add
     """
     for folder in folders:
         assert (os.path.exists(folder))
         logger.info('Loading offpolicy data from {0}'.format(folder))
         rollout_filenames = []
         for fname in os.listdir(folder):
             if 'train_rollouts.pkl' in fname:
                 rollout_filenames.append(os.path.join(folder, fname))
         self._sampler.add_rollouts(rollout_filenames, max_to_add=max_to_add)
     logger.info('Added {0} samples'.format(len(self._sampler)))
Ejemplo n.º 4
0
    def run(self):
        """Main training loop: evaluate, train, log, and save on their
        respective step schedules until self._total_steps is reached.
        """
        # Resume step counter / save iteration from any prior checkpoints.
        start_step, save_itr = self._run_init_train()

        step = start_step
        while step < self._total_steps:
            step += 1

            if step % self._eval_every_n_steps == 0:
                self._run_env_eval(step,
                                   do_sampler_step=False,
                                   calculate_holdout=True)

            # Gradient updates only begin after the warm-up period.
            if step >= self._learn_after_n_steps:
                self._run_train_step(step)

            if step % self._log_every_n_steps == 0:
                self._run_log(step)

            if step % self._save_every_n_steps == 0:
                logger.info('Saving files for itr {0}'.format(save_itr))
                self._save_train(save_itr)
                save_itr += 1

        # Final save once the step budget is exhausted.
        if step >= self._total_steps:
            logger.info('Saving files for itr {0}'.format(save_itr))
            self._save_train(save_itr)
Ejemplo n.º 5
0
    def run(self):
        """Run the sampler for self._steps steps, and every 1000 steps log
        rollout-length statistics and pickle recent rollouts to numbered
        on-policy files.
        """
        self._sampler.reset()
        step = 0
        itr = 0
        logger.info('Step {0}'.format(step))
        while step < self._steps:
            self._sampler.step(step, take_random_actions=False)
            step += 1

            if step > 0 and step % 1000 == 0:
                logger.info('Step {0}'.format(step))
                rollouts = self._sampler.get_recent_paths()
                if len(rollouts) > 0:
                    # rollout length = number of recorded 'dones' entries
                    lengths = [len(r['dones']) for r in rollouts]
                    logger.info('Lengths: {0:.1f} +- {1:.1f}'.format(
                        np.mean(lengths), np.std(lengths)))
                    mypickle.dump({'rollouts': rollouts},
                                  self._onpolicy_file_name(itr))
                    itr += 1

        # Flush any rollouts completed since the last 1000-step checkpoint.
        rollouts = self._sampler.get_recent_paths()
        if len(rollouts) > 0:
            logger.info('Step {0}'.format(step))
            lengths = [len(r['dones']) for r in rollouts]
            logger.info('Lengths: {0:.1f} +- {1:.1f}'.format(
                np.mean(lengths), np.std(lengths)))
            mypickle.dump({'rollouts': rollouts},
                          self._onpolicy_file_name(itr))
Ejemplo n.º 6
0
    def train(self):
        """Train the model for alg['total_steps'] steps, pulling batches from
        the background batch queue; periodically updates the target network,
        logs tabular stats/timings, and saves policy checkpoints.
        """
        # starts the background producer that fills self._batch_queue
        self._start_train_batch()

        logger.info('Training model')

        alg_args = self._params['alg']
        total_steps = int(alg_args['total_steps'])
        save_every_n_steps = int(alg_args['save_every_n_steps'])
        update_target_after_n_steps = int(
            alg_args['update_target_after_n_steps'])
        update_target_every_n_steps = int(
            alg_args['update_target_every_n_steps'])
        log_every_n_steps = int(alg_args['log_every_n_steps'])

        timeit.reset()
        timeit.start('total')
        save_itr = 0
        for step in range(total_steps):
            timeit.start('sample')
            # steps, observations, actions, rewards, dones, _ = self._replay_pool.sample(batch_size)
            steps, observations, actions, rewards, dones, _ = self._batch_queue.get(
            )
            timeit.stop('sample')
            timeit.start('train')
            self._model.train_step(step,
                                   steps=steps,
                                   observations=observations,
                                   actions=actions,
                                   rewards=rewards,
                                   dones=dones,
                                   use_target=True)
            timeit.stop('train')

            ### update target network
            if step > update_target_after_n_steps and step % update_target_every_n_steps == 0:
                self._model.update_target()

            ### log
            if step > 0 and step % log_every_n_steps == 0:
                logger.record_tabular('Step', step)
                self._model.log()
                logger.dump_tabular(print_func=logger.info)

                # restart the timing window so timings are per log interval
                timeit.stop('total')
                for line in str(timeit).split('\n'):
                    logger.debug(line)
                timeit.reset()
                timeit.start('total')

            ### save model
            if step > 0 and step % save_every_n_steps == 0:
                logger.info('Saving files for itr {0}'.format(save_itr))
                self._save_train_policy(save_itr)
                save_itr += 1

        ### always save the end
        self._save_train_policy(save_itr)

        self._stop_train_batch()
Ejemplo n.º 7
0
    def _save(self, rollouts, new_rollouts):
        """Append new_rollouts to rollouts (in place), save the combined list,
        and return it.

        :param rollouts: accumulated rollouts so far (mutated in place)
        :param new_rollouts: non-empty list of rollouts to add
        :return: the extended rollouts list
        """
        assert len(new_rollouts) > 0

        logger.info('Saving rollouts')
        # extend keeps the in-place mutation semantics of `rollouts += ...`
        rollouts.extend(new_rollouts)
        self._save_rollouts(self._eval_itr, rollouts)

        return rollouts
Ejemplo n.º 8
0
 def _eval_reset(self, **kwargs):
     """Reset the sampler, retrying until a reset succeeds.

     Any exception from reset is logged and the operator must press enter
     (e.g. after physically fixing the setup) before the next attempt.
     """
     while True:
         try:
             self._sampler.reset(**kwargs)
             break
         except Exception as e:
             logger.warn('Reset exception {0}'.format(str(e)))
             logger.info('Press enter to continue')
             # blocks on operator input before retrying
             input()
             logger.info('')
Ejemplo n.º 9
0
 def _add_offpolicy(self, folders, max_to_add):
     """Load rosbag data from each folder and add it via _add_rosbags.

     :param folders: folders to scan for '.bag' files
     :param max_to_add: accepted for interface compatibility (not used here)
     """
     for folder in folders:
         assert (os.path.exists(folder))
         logger.info('Loading rosbag data from {0}'.format(folder))
         bag_fnames = [os.path.join(folder, f) for f in os.listdir(folder) if '.bag' in f]
         self._add_rosbags(sorted(bag_fnames))
     logger.info('Added {0} samples'.format(len(self._sampler)))
Ejemplo n.º 10
0
def create_training_data(save_folder, image_shape, rescale, bordersize,
                         holdout_pct):
    """
    Build train/holdout image+label numpy archives from labelled images.

    :param save_folder: where images are saved
    :param image_shape: shape of the image
    :param rescale: make rescale times bigger, for ease of labelling
    :param bordersize: how much to pad with 0s, for ease of labelling
    :param holdout_pct: how much data to holdout
    """

    ### read image, label pairs
    label_fnames = glob.glob(os.path.join(save_folder, 'label*'))
    random.shuffle(label_fnames)

    height, width, channels = image_shape

    images_train, labels_train = [], []
    images_holdout, labels_holdout = [], []
    for i, label_fname in enumerate(label_fnames):
        image_fname = label_fname.replace('label_', '')

        image = np.asarray(Image.open(image_fname))
        label = np.asarray(Image.open(label_fname))

        # strip the labelling border; guard bordersize == 0, for which the
        # original [0:-0] slice would produce an empty array
        if bordersize > 0:
            image = image[bordersize:-bordersize, bordersize:-bordersize]
            label = label[bordersize:-bordersize, bordersize:-bordersize]

        # reduce image back down
        image = utils.imresize(image, (height, width, channels))
        assert (tuple(image.shape) == tuple(image_shape))

        # reduce label back down and binarize
        label = utils.imresize(label, (height, width, 1), Image.BILINEAR)
        label = label[:, :, 0]
        label = (label > 0.5)
        assert (tuple(label.shape) == (height, width))

        # the first holdout_pct fraction (after shuffling) becomes holdout
        if i / float(len(label_fnames)) > holdout_pct:
            images_train.append(image)
            labels_train.append(label)
        else:
            images_holdout.append(image)
            labels_holdout.append(label)

    np.save(os.path.join(save_folder, 'data_train_images.npy'),
            np.array(images_train))
    np.save(os.path.join(save_folder, 'data_train_labels.npy'),
            np.array(labels_train))

    np.save(os.path.join(save_folder, 'data_holdout_images.npy'),
            np.array(images_holdout))
    np.save(os.path.join(save_folder, 'data_holdout_labels.npy'),
            np.array(labels_holdout))

    logger.info('Saved train and holdout')
Ejemplo n.º 11
0
 def _log(self, msg, lvl):
     if not self.suppress_output:
         if lvl == "info":
             logger.info(msg)
         elif lvl == "debug":
             logger.debug(msg)
         elif lvl == "warn":
             logger.warn(msg)
         elif lvl == "error":
             logger.error(msg)
         else:
             print("NOT VALID LOG LEVEL")
Ejemplo n.º 12
0
 def _add_offpolicy(self, folders, max_to_add):
     for folder in folders:
         assert (os.path.exists(folder),
                 'offpolicy folder {0} does not exist'.format(folder))
         logger.info('Loading offpolicy data from {0}'.format(folder))
         rollout_filenames = [
             os.path.join(folder, fname) for fname in os.listdir(folder)
             if FileManager.train_rollouts_fname_suffix in fname
         ]
         self._replay_pool.store_rollouts(rollout_filenames,
                                          max_to_add=max_to_add)
     logger.info('Added {0} samples'.format(len(self._replay_pool)))
Ejemplo n.º 13
0
    def _restore_train_policy(self):
        """
        :return: iteration that it is currently on
        """
        itr = 0
        while len(glob.glob(self._train_policy_file_name(itr) + '*')) > 0:
            itr += 1

        if itr > 0:
            logger.info('Loading train policy from iteration {0}...'.format(itr - 1))
            self._policy.restore(self._train_policy_file_name(itr - 1), train=True)
            logger.info('Loaded train policy!')
Ejemplo n.º 14
0
    def _eval_step(self):
        """Take one non-exploratory sampler step; on failure, trash the
        partial rollouts and wait for the operator before resetting."""
        try:
            self._sampler.step(step=0,
                               take_random_actions=False,
                               explore=False)
        except Exception as e:
            logger.warn('Sampler exception {0}'.format(str(e)))
            # discard the partially-collected rollouts from the failed step
            self._sampler.trash_current_rollouts()

            logger.info('Press enter to continue')
            input()
            # keep_rosbag=False: the interrupted rollout's bag is discarded
            self._eval_reset(keep_rosbag=False)
Ejemplo n.º 15
0
    def _run_init_train(self):
        """Restore the latest train policy checkpoint (if any) and compute the
        counters to resume training from.

        :return: (start_step, save_itr) tuple
        """
        train_itr = self._fm.get_train_itr()
        if train_itr > 0:
            logger.info('Restore train iteration {0}'.format(train_itr - 1))
            self._policy.restore(self._fm.train_policy_fname(train_itr - 1), train=True)

        save_itr = train_itr
        # saves happen every _save_every_n_steps, so the resume step follows
        # directly from the number of completed save iterations
        start_step = save_itr * self._save_every_n_steps

        timeit.reset()
        timeit.start('total')

        return start_step, save_itr
Ejemplo n.º 16
0
    def _restore_inference_policy(self):
        """
        :return: iteration that it is currently on
        """
        itr = 0
        while len(glob.glob(os.path.splitext(self._load_inference_policy_file_name(itr))[0] + '*')) > 0:
            itr += 1
        itr -= 1

        if itr > 0:
            logger.info('Loading inference policy from iteration {0}...'.format(itr))
            self._policy.restore(self._load_inference_policy_file_name(itr), train=False)
            logger.info('Loaded inference policy!')

        return itr
Ejemplo n.º 17
0
    def _restore_train_policy(self):
        """
        :return: iteration that it is currently on
        """
        itr = 0
        while len(glob.glob(os.path.splitext(self._load_train_policy_file_name(itr))[0] + '*')) > 0:
            itr += 1
        itr -= 1

        if itr >= 0:
            logger.info('Loading train policy from {0} iteration {1}...'.format(self._load_dir, itr))
            self._policy.restore(self._load_train_policy_file_name(itr), train=True)
            logger.info('Loaded train policy!')

        return itr
Ejemplo n.º 18
0
    def _restore_train_rollouts(self):
        """
        Restore all consecutively-numbered saved train rollout files into the
        sampler.

        :return: iteration that it is currently on
        """
        # Collect files for iterations 0, 1, 2, ... until the first missing.
        rollout_filenames = []
        itr = 0
        fname = self._train_rollouts_file_name(itr)
        while os.path.exists(fname):
            rollout_filenames.append(fname)
            itr += 1
            fname = self._train_rollouts_file_name(itr)

        logger.info('Restoring {0} iterations of train rollouts....'.format(itr))
        self._sampler.add_rollouts(rollout_filenames)
        logger.info('Done restoring rollouts!')
Ejemplo n.º 19
0
    def evaluate(self, eval_on_holdout=False):
        """Evaluate the model's collision predictions over an entire replay
        pool and save the resulting plots.

        :param eval_on_holdout: if True, evaluate on the holdout replay pool
            instead of the training pool
        """
        logger.info('Evaluating model')

        if eval_on_holdout:
            replay_pool = self._replay_holdout_pool
        else:
            replay_pool = self._replay_pool

        # get collision idx in obs_vec
        vec_spec = self._env.observation_vec_spec
        obs_vec_start_idxs = np.cumsum([space.flat_dim for space in vec_spec.values()]) - 1
        coll_idx = obs_vec_start_idxs[list(vec_spec.keys()).index('coll')]

        # model will be evaluated on 1e3 inputs at a time (accounting for bnn samples)
        batch_size = 1000 // self._num_bnn_samples
        assert (batch_size > 1)
        rp_gen = replay_pool.sample_all_generator(batch_size=batch_size, include_env_infos=True)

        # keep everything in dict d
        d = defaultdict(list)
        for steps, (observations_im, observations_vec), actions, rewards, dones, env_infos in rp_gen:
            # history frames are the model input; frames after the history
            # provide the future-collision labels
            observations = (observations_im[:, :self._model.obs_history_len, :],
                            observations_vec[:, :self._model.obs_history_len, :])
            coll_labels = (np.cumsum(observations_vec[:, self._model.obs_history_len:, coll_idx], axis=1) >= 1.).astype(float)

            # repeat each input once per bnn sample
            observations_repeat = (np.repeat(observations[0], self._num_bnn_samples, axis=0),
                                   np.repeat(observations[1], self._num_bnn_samples, axis=0))
            actions_repeat = np.repeat(actions, self._num_bnn_samples, axis=0)

            yhats, bhats = self._model.get_model_outputs(observations_repeat, actions_repeat)
            coll_preds = np.reshape(yhats['coll'], (len(steps), self._num_bnn_samples, -1))

            d['coll_labels'].append(coll_labels)
            d['coll_preds'].append(coll_preds)
            d['env_infos'].append(env_infos)
            # Note: you can save more things (e.g. actions) if you want to do something with them later

        for k, v in d.items():
            d[k] = np.concatenate(v)

        # d['coll_preds'] has shape (num_replays, self._num_bnn_samples, horizon)
        # d['coll_labels'] has shape (num_replays, horizon)
        plotter = BnnPlotter(d['coll_preds'], d['coll_labels'])
        plotter.save_all_plots(self._save_dir)
        # removed leftover debugging hook (`import IPython; IPython.embed()`),
        # which blocked on an interactive shell at the end of every evaluation
Ejemplo n.º 20
0
    def reset(self, offline=False):
        """Reset the (robot) environment and return the first observation.

        :param offline: if True, only reset internal state -- no ROS commands
        :return: initial observation
        """
        if offline:
            self._is_collision = False
            return self._get_observation()

        assert (self.ros_is_good())

        if self._ros_rolloutbag.is_open:
            # should've been closed in step when done
            logger.debug('Trashing bag')
            self._ros_rolloutbag.trash()

        if self._press_enter_on_reset:
            logger.info('Resetting, press enter to continue')
            input()
        else:
            if self._is_collision:
                logger.debug('Resetting (collision)')
            else:
                logger.debug('Resetting (no collision)')

            # wait for a human to unflip the car before driving again
            if self._ros_msgs['collision/flip'].data:
                logger.warn('Car has flipped, please unflip it to continue')
                while self._ros_msgs['collision/flip'].data:
                    rospy.sleep(0.1)
                logger.warn('Car is now unflipped. Continuing...')
                rospy.sleep(1.)

            # back up with a random steering angle, then come to a stop
            backup_steer = np.random.uniform(*self._backup_steer_range)
            self._set_steer(backup_steer)
            self._set_motor(self._backup_motor, self._backup_duration)
            self._set_steer(0.)
            self._set_vel(0.)

        rospy.sleep(0.5)

        self._last_step_time = rospy.Time.now()
        self._is_collision = False
        self._t = 0

        # start recording the next rollout
        self._ros_rolloutbag.open()

        assert (self.ros_is_good())

        return self._get_observation()
Ejemplo n.º 21
0
    def _eval_model(self):
        """Restore the traversability model and repeatedly plot predictions
        against ground truth for train and holdout batches.

        NOTE(review): the `while True` has no exit condition -- the loop runs
        until the process is interrupted.
        """
        logger.info('Restoring model')
        self._trav_graph.restore()

        logger.info('Evaluating model')

        while True:
            obs, labels, probs, obs_holdout, labels_holdout, probs_holdout = self._trav_graph.eval(
            )

            for obs_t, labels_t, probs_t, obs_holdout_t, labels_holdout_t, probs_holdout_t in \
                    zip(obs, labels, probs, obs_holdout, labels_holdout, probs_holdout):
                # 2x4 grid: row 0 = training sample, row 1 = holdout sample;
                # columns: input, ground truth, prediction, |gt - prediction|
                f, axes = plt.subplots(2, 4, figsize=(20, 5))

                axes[0, 0].imshow(obs_t)
                axes[0, 1].imshow(labels_t[..., 0],
                                  cmap='Greys',
                                  vmin=0,
                                  vmax=1)
                axes[0, 2].imshow(probs_t[..., 1],
                                  cmap='Greys',
                                  vmin=0,
                                  vmax=1)
                axes[0, 3].imshow(abs(labels_t[..., 0] - probs_t[..., 1]),
                                  cmap='Greys',
                                  vmin=0,
                                  vmax=1)

                axes[1, 0].imshow(obs_holdout_t)
                axes[1, 1].imshow(labels_holdout_t[..., 0],
                                  cmap='Greys',
                                  vmin=0,
                                  vmax=1)
                axes[1, 2].imshow(probs_holdout_t[..., 1],
                                  cmap='Greys',
                                  vmin=0,
                                  vmax=1)
                axes[1, 3].imshow(abs(labels_holdout_t[..., 0] -
                                      probs_holdout_t[..., 1]),
                                  cmap='Greys',
                                  vmin=0,
                                  vmax=1)

                plt.show()
Ejemplo n.º 22
0
    def _eval_save(self, rollouts, new_rollouts):
        """Interactively ask whether to keep new_rollouts; on 'y', append them
        to rollouts (in place) and save.

        :return: the (possibly extended) rollouts list
        """
        logger.info('')
        logger.info('Keep rollout?')
        if input() == 'y':
            logger.info('Saving rollouts')
            # extend keeps the in-place semantics of `rollouts += ...`
            rollouts.extend(new_rollouts)
            self._save_eval_rollouts(rollouts)
        else:
            logger.info('NOT saving rollouts')

        return rollouts
Ejemplo n.º 23
0
    def reset(self, offline=False, keep_rosbag=False):
        """Reset the (robot) environment and return (observation, goal).

        :param offline: if True, only reset internal state -- no ROS commands
        :param keep_rosbag: if True, close (and thereby keep) the currently
            open rollout bag instead of trashing it
        :return: (observation, goal) tuple
        """
        if offline:
            self._is_collision = False
            return self._get_observation(), self._get_goal()

        assert (self.ros_is_good())

        if self._ros_rolloutbag.is_open:
            if keep_rosbag:
                self._ros_rolloutbag.close()
            else:
                # should've been closed in step when done
                logger.debug('Trashing bag')
                self._ros_rolloutbag.trash()

        if self._press_enter_on_reset:
            logger.info('Resetting, press enter to continue')
            input()
        else:
            if self._is_collision:
                logger.debug('Resetting (collision)')
            else:
                logger.debug('Resetting (no collision)')

            # back up with a random steering angle, then come to a stop
            if self._backup_duration > 0:
                backup_steer = np.random.uniform(*self._backup_steer_range)
                self._set_steer(backup_steer)
                self._set_motor(self._backup_motor, self._backup_duration)
            self._set_steer(0.)
            self._set_vel(0.)

        rospy.sleep(0.5)

        self._last_step_time = rospy.Time.now()
        self._is_collision = False
        self._t = 0

        # start recording the next rollout
        self._ros_rolloutbag.open()

        assert (self.ros_is_good())

        return self._get_observation(), self._get_goal()
Ejemplo n.º 24
0
 def _add_offpolicy(self, folders, max_to_add):
     for folder in folders:
         assert (os.path.exists(folder),
                 'offpolicy folder {0} does not exist'.format(folder))
         logger.info('Loading rosbag data from {0}'.format(folder))
         rosbag_filenames = sorted([
             os.path.join(folder, fname) for fname in os.listdir(folder)
             if '.bag' in fname
         ])
         train_rosbag_filenames, holdout_rosbag_filenames = self._split_rollouts(
             rosbag_filenames)
         logger.info('Adding train...')
         self._add_rosbags(self._sampler, self._replay_pool,
                           train_rosbag_filenames)
         logger.info('Adding holdout...')
         self._add_rosbags(self._sampler_eval, self._replay_pool_eval,
                           holdout_rosbag_filenames)
     logger.info('Added {0} train samples'.format(len(self._replay_pool)))
     logger.info('Added {0} holdout samples'.format(
         len(self._replay_pool_eval)))
Ejemplo n.º 25
0
    def _load_data(self, folder):
        """
        Loads all .pkl files that can be found recursively from this folder
        and returns a replay pool filled with their rollouts.

        :param folder: root folder to search (must exist)
        :return: ReplayPool (or BootstrapReplayPool when self.num_bootstraps
            is set) containing all loaded rollouts
        """
        assert (os.path.exists(folder))

        rollouts = []
        num_load_success, num_load_fail = 0, 0
        for fname in glob.iglob('{0}/**/*.pkl'.format(folder), recursive=True):
            try:
                rollouts += mypickle.load(fname)['rollouts']
                num_load_success += 1
            except Exception:
                # was a bare `except:`, which also swallowed KeyboardInterrupt
                # and SystemExit; unreadable files are simply counted
                num_load_fail += 1
        num_total = num_load_success + num_load_fail
        if num_total > 0:
            logger.info('Files successfully loaded: {0:.2f}%'.format(
                100. * num_load_success / float(num_total)))
        else:
            # was a ZeroDivisionError when no .pkl files were found
            logger.info('No .pkl files found in {0}'.format(folder))

        num_bootstraps = self.num_bootstraps
        if num_bootstraps is not None:
            logger.info('Creating {0} bootstraps'.format(num_bootstraps))
            ReplayPoolClass = lambda **kwargs: BootstrapReplayPool(
                num_bootstraps, **kwargs)
        else:
            ReplayPoolClass = ReplayPool

        replay_pool = ReplayPoolClass(
            env_spec=self._env.spec,
            env_horizon=self._env.horizon,
            N=self._model.N,
            gamma=self._model.gamma,
            # 1.1x headroom over the exact number of stored transitions
            size=int(1.1 * sum([len(r['dones']) for r in rollouts])),
            obs_history_len=self._model.obs_history_len,
            sampling_method='uniform',
            save_rollouts=False,
            save_rollouts_observations=False,
            save_env_infos=True,
            replay_pool_params={})

        replay_pool.store_rollouts(0, rollouts)

        return replay_pool
Ejemplo n.º 26
0
def create_labels(save_folder, image_shape, rescale, bordersize):
    """
    Create binary label images from VIA polygon annotations.

    :param save_folder: where images are saved
    :param image_shape: shape of the image
    :param rescale: make rescale times bigger, for ease of labelling
    :param bordersize: how much to pad with 0s, for ease of labelling
    """
    import ast  # local import: only needed for parsing the annotation column

    ### load csv
    csv_fname = os.path.join(save_folder, 'via_region_data.csv')
    assert (os.path.exists(csv_fname))
    csv = pandas.read_csv(csv_fname)

    ### image indices
    xdim = int(rescale) * image_shape[0] + 2 * bordersize
    ydim = int(rescale) * image_shape[1] + 2 * bordersize
    x, y = np.meshgrid(np.arange(xdim), np.arange(ydim))
    indices = np.vstack((y.flatten(), x.flatten())).T

    ### create labels
    num_labelled = 0
    for i in range(len(csv)):
        fname = csv['#filename'][i]
        # literal_eval instead of eval: the csv cell is external text and only
        # needs to be parsed as a literal dict, never executed as code
        region_shape_attrs = ast.literal_eval(csv['region_shape_attributes'][i])
        if 'name' not in region_shape_attrs.keys():
            continue

        assert (region_shape_attrs['name'] == 'polygon')
        xy = np.stack((region_shape_attrs['all_points_x'],
                       region_shape_attrs['all_points_y'])).T
        label = 1 - Path(xy).contains_points(indices).reshape(
            ydim, xdim).T  # 0 is no collision
        label = label.astype(np.uint8)

        label_fname = os.path.join(
            save_folder, 'label_' + os.path.splitext(fname)[0] + '.jpg')
        Image.fromarray(label).save(label_fname)

        num_labelled += 1

    logger.info('{0} were labelled'.format(num_labelled))
Ejemplo n.º 27
0
    def _restore_train_rollouts(self):
        """
        :return: iteration that it is currently on
        """
        rosbag_num = 0
        rosbag_filenames = []
        while True:
            fname = self._rosbag_file_name(rosbag_num)
            if not os.path.exists(fname):
                break

            rosbag_num += 1
            if fname in self._added_rosbag_filenames:
                continue  # don't add already added rosbag filenames

            rosbag_filenames.append(fname)

        if len(rosbag_filenames) > 0:
            logger.info('Restoring {0} rosbags....'.format(rosbag_num))
            self._add_rosbags(rosbag_filenames)
            logger.info('Done restoring rosbags!')
Ejemplo n.º 28
0
    def eval_model(self):
        """Interactively visualize model predictions versus ground truth for
        training and holdout batches until the user types 'quit'."""
        logger.info('Creating model')
        # NOTE(review): labeller_params is not defined in this method --
        # presumably a module-level global; confirm before refactoring
        trav_graph = TraversabilityGraph(self._obs_shape, self._save_folder,
                                         **labeller_params)

        logger.info('Restoring model')
        trav_graph.restore()

        logger.info('Evaluating model')

        while True:
            obs, labels, probs, obs_holdout, labels_holdout, probs_holdout = trav_graph.eval(
            )

            for obs_t, labels_t, probs_t, obs_holdout_t, labels_holdout_t, probs_holdout_t in \
                    zip(obs, labels, probs, obs_holdout, labels_holdout, probs_holdout):
                # 2x4 grid: row 0 = training sample, row 1 = holdout sample;
                # columns: input, ground truth, prediction, |gt - prediction|
                f, axes = plt.subplots(2, 4, figsize=(20, 5))

                axes[0, 0].imshow(obs_t)
                axes[0, 1].imshow(labels_t[..., 0],
                                  cmap='Greys',
                                  vmin=0,
                                  vmax=1)
                axes[0, 2].imshow(probs_t, cmap='Greys', vmin=0, vmax=1)
                axes[0, 3].imshow(abs(labels_t[..., 0] - probs_t),
                                  cmap='Greys',
                                  vmin=0,
                                  vmax=1)

                axes[1, 0].imshow(obs_holdout_t)
                axes[1, 1].imshow(labels_holdout_t[..., 0],
                                  cmap='Greys',
                                  vmin=0,
                                  vmax=1)
                axes[1, 2].imshow(probs_holdout_t,
                                  cmap='Greys',
                                  vmin=0,
                                  vmax=1)
                axes[1,
                     3].imshow(abs(labels_holdout_t[..., 0] - probs_holdout_t),
                               cmap='Greys',
                               vmin=0,
                               vmax=1)

                axes[0, 0].set_title('Input')
                axes[0, 1].set_title('Ground truth segmentation')
                axes[0, 2].set_title('Predicted segmentation')
                axes[0, 3].set_title('abs(Ground truth - predicted)')

                axes[0, 0].set_ylabel('Training')
                axes[1, 0].set_ylabel('Holdout')

                # non-blocking show so the prompt below is reachable
                plt.show(block=False)
                plt.pause(0.1)
                response = input(
                    'Press enter to continue, or "quit" to exit\n')
                if response == 'quit':
                    return
                plt.close(f)
0
def extract_images_from_pkls(pkl_folder, save_folder, maxsaved, image_shape,
                             rescale, bordersize):
    """
    :param pkl_folder: folder containing pkls with training images
    :param save_folder: where to save the resulting images
    :param maxsaved: how many images to save
    :param image_shape: shape of the image
    :param rescale: make rescale times bigger, for ease of labelling
    :param bordersize: how much to pad with 0s, for ease of labelling
    """
    # deterministic sampling across runs
    random.seed(0)

    fnames = glob.glob(os.path.join(pkl_folder, '*.pkl'))
    random.shuffle(fnames)
    logger.info('{0} files to read'.format(len(fnames)))
    # NOTE(review): cycling means files repeat if maxsaved exceeds the data;
    # an empty pkl_folder would make next(fnames) raise StopIteration
    fnames = itertools.cycle(fnames)

    im_num = 0
    while im_num < maxsaved:
        fname = next(fnames)
        # pick a random observation from a random rollout in this pkl
        rollout = random.choice(mypickle.load(fname)['rollouts'])
        obs = random.choice(rollout['observations_im'])

        height, width, channels = image_shape

        # upscale and zero-pad to make hand-labelling easier
        im = np.reshape(obs, image_shape)
        im = utils.imresize(im, (rescale * height, rescale * width, channels))
        im = np.pad(im, ((bordersize, bordersize), (bordersize, bordersize),
                         (0, 0)), 'constant')
        if im.shape[-1] == 1:
            # single-channel images are saved as 2D grayscale
            im = im[:, :, 0]
        Image.fromarray(im).save(
            os.path.join(save_folder, 'image_{0:06d}.jpg'.format(im_num)))
        im_num += 1

    logger.info('Saved {0} images'.format(im_num))
Ejemplo n.º 30
0
    def _eval_pred_all(self, eval_on_holdout):
        """Compute (or load cached) collision predictions and labels over an
        entire replay pool.

        :param eval_on_holdout: if True, evaluate the holdout pool; otherwise
            the training pool
        :return: dict with keys 'coll_labels', 'coll_preds', 'env_infos',
            'dones', each a concatenated numpy array
        """
        pkl_file_name = self._eval_train_rollouts_file_name if not eval_on_holdout else self._eval_holdout_rollouts_file_name

        # results are cached on disk; if present, load instead of recomputing
        if os.path.exists(pkl_file_name):
            logger.info('Load evaluation rollouts for {0}'.format(
                'holdout' if eval_on_holdout else 'train'))
            d = mypickle.load(pkl_file_name)
        else:
            logger.info('Evaluating model on {0}'.format(
                'holdout' if eval_on_holdout else 'train'))

            replay_pool = self._replay_holdout_pool if eval_on_holdout else self._replay_pool

            # get collision idx in obs_vec
            vec_spec = self._env.observation_vec_spec
            obs_vec_start_idxs = np.cumsum(
                [space.flat_dim for space in vec_spec.values()]) - 1
            coll_idx = obs_vec_start_idxs[list(vec_spec.keys()).index('coll')]

            # model will be evaluated on 1e3 inputs at a time (accounting for bnn samples)
            batch_size = 1000 // self._num_bnn_samples
            assert (batch_size > 1)
            rp_gen = replay_pool.sample_all_generator(batch_size=batch_size,
                                                      include_env_infos=True)

            # keep everything in dict d
            d = defaultdict(list)
            for steps, (observations_im, observations_vec
                        ), actions, rewards, dones, env_infos in rp_gen:
                # model input is the observation history; the frames after
                # the history provide the future-collision labels
                observations = (
                    observations_im[:, :self._model.obs_history_len, :],
                    observations_vec[:, :self._model.obs_history_len, :])
                coll_labels = (np.cumsum(
                    observations_vec[:, self._model.obs_history_len:,
                                     coll_idx],
                    axis=1) >= 1.).astype(float)

                # repeat each input once per bnn sample
                observations_repeat = (np.repeat(observations[0],
                                                 self._num_bnn_samples,
                                                 axis=0),
                                       np.repeat(observations[1],
                                                 self._num_bnn_samples,
                                                 axis=0))
                actions_repeat = np.repeat(actions,
                                           self._num_bnn_samples,
                                           axis=0)

                yhats, bhats = self._model.get_model_outputs(
                    observations_repeat, actions_repeat)
                coll_preds = np.reshape(
                    yhats['coll'], (len(steps), self._num_bnn_samples, -1))

                d['coll_labels'].append(coll_labels)
                d['coll_preds'].append(coll_preds)
                d['env_infos'].append(env_infos)
                d['dones'].append(dones)
                # Note: you can save more things (e.g. actions) if you want to do something with them later

            for k, v in d.items():
                d[k] = np.concatenate(v)

            # cache results for future calls
            mypickle.dump(d, pkl_file_name)

        return d