    def _init_params_to_attrs(self, params):
        self._num_latent_classes = params.num_latent_classes  # number of discrete latent classes
        self._latent_dim = params.latent_dim

        self._beta_kl = params.beta_kl

        known_latent_mu = params.known_latent_default_mu
        known_latent_log_sigma = params.known_latent_default_log_sigma

        if known_latent_mu is None or known_latent_log_sigma is None:
            logger.warn("[Gaussian Latent]: Assuming unknown latent values")
            self._unknown_latent = True

            known_latent_mu = np.zeros(
                (self._num_latent_classes, self._latent_dim))
            known_latent_log_sigma = np.zeros(
                (self._num_latent_classes, self._latent_dim))

        self._latent_default_mu = np.broadcast_to(
            np.array(known_latent_mu, dtype=np.float32),
            (self._num_latent_classes, self._latent_dim))
        self._latent_default_log_sigma = np.broadcast_to(
            np.array(known_latent_log_sigma, dtype=np.float32),
            (self._num_latent_classes, self._latent_dim))

        # the online default uses the per-class mean for mu and a unit sigma (log_sigma = 0)
        self._online_latent_default_mu = self._latent_default_mu.mean(axis=0)
        self._online_latent_default_log_sigma = np.zeros(self._latent_dim,
                                                         dtype=np.float32)
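
    # A minimal sketch (not in the original class) of how _beta_kl and the
    # stored per-class defaults could weight a KL regularizer between a
    # diagonal-Gaussian posterior and the default prior; the helper name and
    # the q_mu / q_log_sigma arguments are assumptions for illustration.
    def _kl_to_default(self, q_mu, q_log_sigma, latent_class):
        p_mu = self._latent_default_mu[latent_class]
        p_log_sigma = self._latent_default_log_sigma[latent_class]
        # closed-form KL(N(q_mu, q_sigma^2) || N(p_mu, p_sigma^2)), summed over dims
        var_ratio = np.exp(2.0 * (q_log_sigma - p_log_sigma))
        mean_term = ((q_mu - p_mu) ** 2) * np.exp(-2.0 * p_log_sigma)
        kl = 0.5 * np.sum(var_ratio + mean_term - 1.0
                          + 2.0 * (p_log_sigma - q_log_sigma), axis=-1)
        return self._beta_kl * kl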
    def ros_is_good(self, display):
        for topic in self._ros_topics_and_types.keys():
            # make sure we have received at least one msg of each type recently
            if 'coll' not in topic:
                if topic not in self._ros_msg_times:
                    if display:
                        logger.warn('Topic {0} has never been received'.format(topic))
                    return False
                elapsed = (rospy.Time.now() - self._ros_msg_times[topic]).to_sec()
                if elapsed > 2:  # self._dt * 50:
                    if display:
                        logger.warn(
                            'Topic {0} was received {1} seconds ago (dt is {2})'.format(topic, elapsed, self._dt))
                    return False
        return True
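
    # Hypothetical usage sketch (not in the source): a control loop could gate
    # each step on the freshness check above; the 20 Hz rate and the
    # _get_obs / _compute_action / _apply_action helpers are assumptions.
    def _run_guarded_loop(self):
        rate = rospy.Rate(20)
        while not rospy.is_shutdown():
            if not self.ros_is_good(display=True):
                rate.sleep()
                continue  # skip this cycle until every non-collision topic is fresh
            action = self._compute_action(self._get_obs())  # assumed helpers
            self._apply_action(action)
            rate.sleep()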
    def _step(self, action, **kwargs):

        with timeit("copter_step", reset_on_stop=True):
            obs, _, done = self._copter.step(action, no_rew_pub=True, **kwargs)

        pos = obs.obs[0]
        true_reward = -self._get_cost(pos)
        reward = self.reward_mask(pos, self.trajectory.get_i())

        # TODO
        self._copter.pub_rew(rx=true_reward,
                             ry=reward)  # publish both the true and masked rewards for accurate data collection

        # print(pos[:3], self.ac_goal_pos[:3], self.ac_goal_pos_index)
        # print('reward', reward)
        # print(action, np.array2string(pos, separator=', '), self._copter._curr_goal_pos, reward, true_reward, done,
        #       self.trajectory.get_i(), self.trajectory.curr_goal)

        self._copter.set_target(self.trajectory.next(pos), self.fix_waypoint3)

        done = done or self.trajectory.is_finished()  # returns true when the full trajectory has been run through

        ### online control safety measures

        # stop if out of frame
        if pos[0] < 1e-4 and pos[1] < 1e-4:
            # target is out of frame
            done = True
            logger.warn("[TELLO CTRL] Target Left frame! Terminating rollout")

        # if terminating, send stop signal
        if self._use_data_capture and not self._copter.offline and done:
            self._copter.sig_rollout(end=True)

        goal = self.get_goal()

        def arr3D_2str(arr, names=("x", "y", "z")):
            s = "{%s: %.5f, %s: %.5f, %s: %.5f }" \
                % (names[0], arr[0], names[1], arr[1], names[2], arr[2])
            return s

        logger.debug("[TPC] T: %.4f sec // OBS: %s -- GOAL: %s -- ACT: %s" %
                     (timeit.elapsed("copter_step"), arr3D_2str(pos), arr3D_2str(goal.goal_obs[0, 0]), arr3D_2str(action.act[0])))

        return obs, goal, done
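
    # A minimal episode-rollout sketch assuming the (obs, goal, done) return
    # signature of _step above; the policy object, its get_action method, and
    # self.reset() are illustrative assumptions, not from the source.
    def _run_episode(self, policy, max_steps=500):
        obs = self.reset()
        for _ in range(max_steps):
            action = policy.get_action(obs)
            obs, goal, done = self._step(action)
            if done:
                break
        return obs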
    def ros_msg_update(self, msg, args):
        topic = args[0]

        if 'coll' in topic:
            # when collision is True
            if msg.data:
                self.collided = True
                self._done = True
        elif 'data' in topic:
            pass
        elif 'target_vector' in topic:
            # update state
            self._obs[0] = msg.vector.x * self._coef_cx  # cx
            self._obs[1] = msg.vector.y * self._coef_cy  # cy
            self._obs[2] = msg.vector.z * self._coef_area  # area
        else:
            logger.warn("[CF]: Unhandled ROS msg")

        self._ros_msgs[topic] = msg
        self._ros_msg_times[topic] = rospy.Time.now()
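
    # Sketch of how the subscribers could be wired so ros_msg_update receives
    # the topic name through callback_args; the _setup_ros_subscribers name is
    # an assumption, but rospy.Subscriber(..., callback_args=...) passes
    # (msg, callback_args) to the callback, matching args[0] above.
    def _setup_ros_subscribers(self):
        self._ros_msgs = {}
        self._ros_msg_times = {}
        self._ros_subs = []
        for topic, msg_type in self._ros_topics_and_types.items():
            self._ros_subs.append(
                rospy.Subscriber(topic, msg_type,
                                 callback=self.ros_msg_update,
                                 callback_args=(topic,)))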
def split_data_by_episodes(samples,
                           horizon,
                           n_obs=0,
                           n_acs=0,
                           truncate_size=-1):
    assert horizon >= 1

    # latent labels are optional in the dataset
    has_latent = True
    if 'latent' not in samples:
        logger.warn("No latent in dataset!")
        has_latent = False

    # how many time steps to ignore in the beginning when the history is not long enough
    pre_remove = 0
    # pre_remove = max(n_obs, n_acs)

    obs_all = samples['obs']
    acs_all = samples['acs']
    if has_latent:
        latent_all = samples['latent']

    mu_obs = np.mean(obs_all, axis=0)
    sigma_obs = np.std(obs_all, axis=0)

    if 'episode_sizes' in samples:
        episode_sizes = samples['episode_sizes'].flatten().astype(int)
    else:
        episode_sizes = np.array([obs_all.shape[0]])

    ep_sizes = np.cumsum(episode_sizes)

    # truncate
    if 0 < truncate_size < np.sum(episode_sizes):
        ep_max = ep_sizes.size
        for i in range(ep_sizes.size):
            if ep_sizes[i] > truncate_size:
                ep_max = i
                break

        print("Truncating from %d to %d episodes (%d to %d samples)" %
              (ep_sizes.size, ep_max, np.sum(episode_sizes),
               np.sum(episode_sizes[:, :ep_max])))
        episode_sizes = episode_sizes[:, :ep_max]
        ep_sizes = ep_sizes[:ep_max]

    obs_list = np.split(obs_all, ep_sizes, axis=0)
    acs_list = np.split(acs_all, ep_sizes, axis=0)
    if has_latent:
        latent_list = np.split(latent_all, ep_sizes, axis=0)

    obs_start_list = []
    acs_start_list = []
    next_obs_list = []
    latent_start_list = []
    prev_obs_start_list = []
    prev_acs_start_list = []
    obs_seq_list = []
    ac_seq_list = []
    done_list = []

    new_episode_sizes = []

    for ep in range(ep_sizes.size):
        obs = obs_list[ep]
        acs = acs_list[ep]
        if has_latent:
            latent = latent_list[ep]

        ### Action sequences

        # row i corresponds to actions taken until action i
        prev_acs = split_dim_np(fill_n_prev(acs, horizon - 1),
                                axis=1,
                                new_shape=[horizon - 1] + list(acs.shape[1:]))
        prev_obs_horizon = split_dim_np(fill_n_prev(obs,
                                                    horizon,
                                                    initial_zero=False),
                                        axis=1,
                                        new_shape=[horizon] +
                                        list(obs.shape[1:]))

        # build the action sequence list, dropping the first horizon-1 elements
        #  (and the last element, since we don't know its outcome)
        ac_seq = np.concatenate([acs[:, None], prev_acs],
                                axis=1)[horizon - 1:-1]

        # un-reversing and removing the initial ones
        ac_seq = ac_seq[pre_remove:, ::-1, :]

        ### Observation sequences

        # obs sequence in the future
        obs_seq = np.concatenate([obs[:, None], prev_obs_horizon],
                                 axis=1)[horizon:]
        # un-reversing
        obs_seq = obs_seq[pre_remove:, ::-1, :]

        # drop episodes that are too short to form sequences
        if obs_seq.shape[0] < 2:
            print("[] Trashing rollout %d due to lack of samples" % int(ep))
            continue

        ### Action histories

        prev_acs = split_dim_np(fill_n_prev(acs, n_acs),
                                axis=1,
                                new_shape=[n_acs] + list(acs.shape[1:]))
        acs_start = acs[pre_remove:-horizon]
        prev_acs_start = prev_acs[pre_remove:-horizon]

        ### Observation histories

        # we don't use initial_zero=False here because the first pre_remove steps are removed anyway
        prev_obs = split_dim_np(fill_n_prev(obs, n_obs),
                                axis=1,
                                new_shape=[n_obs] + list(obs.shape[1:]))
        obs_start = obs[pre_remove:-horizon]
        next_obs = obs[pre_remove + 1:obs.shape[0] - horizon + 1]  # avoids an empty slice when horizon == 1
        if has_latent:
            latent_start = latent[pre_remove:-horizon].astype(int)
        else:
            latent_start = np.zeros((obs_start.shape[0], 0)).astype(int)
        prev_obs_start = prev_obs[pre_remove:-horizon]

        obs_seq_list.append(obs_seq)  # (N x H+1 x dO)
        ac_seq_list.append(ac_seq)  # (N x H x dU)
        obs_start_list.append(obs_start)  # (N x dO)
        acs_start_list.append(acs_start)  # (N x dU)
        next_obs_list.append(next_obs)  # (N x dO)
        latent_start_list.append(latent_start)  # (N x 1)
        prev_obs_start_list.append(prev_obs_start)  # (N x nobs x dO)
        prev_acs_start_list.append(prev_acs_start)  # (N x nacs x dU)
        done_list.append(
            np.array([False for _ in range(obs_start.shape[0] - 1)] + [True],
                     dtype=bool))
        new_episode_sizes.append(obs_start.shape[0])

    # remove bad eps:
    episode_sizes = np.array(new_episode_sizes)

    # some input statistics
    delta_obs = np.concatenate(next_obs_list, axis=0) - np.concatenate(
        obs_start_list, axis=0)
    mu_delta_obs = np.mean(delta_obs, axis=0)
    sigma_delta_obs = np.std(delta_obs, axis=0)
    next_obs_sigma_list = [
        np.tile(sigma_delta_obs[None], (next_obs.shape[0], 1))
        for next_obs in next_obs_list
    ]

    return_dict = {
        'mu_obs': mu_obs,
        'sigma_obs': sigma_obs,
        'mu_delta_obs': mu_delta_obs,
        'sigma_delta_obs': sigma_delta_obs,
        'episode_sizes': episode_sizes,
        'done': done_list,
        'obs_full': obs_list,
        'act_full': acs_list,
        # 'obs_seq': obs_seq_list,
        'latent': latent_start_list,
        'act_seq': ac_seq_list,
        'obs': obs_start_list,
        'act': acs_start_list,
        'prev_obs': prev_obs_start_list,
        'prev_act': prev_acs_start_list,
        'next_obs': next_obs_list,
        'next_obs_sigma': next_obs_sigma_list,
        'goal_obs': obs_seq_list,
    }

    return return_dict
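
# Hypothetical usage sketch: the .npz filename and the horizon / history sizes
# are illustrative; the 'obs' / 'acs' keys match what the function reads above.
if __name__ == '__main__':
    data = dict(np.load('rollouts.npz'))  # assumed data file
    split = split_data_by_episodes(data, horizon=5, n_obs=2, n_acs=2)
    # per-episode lists; concatenate along axis 0 for flat training arrays
    obs = np.concatenate(split['obs'], axis=0)             # (N, dO)
    act_seq = np.concatenate(split['act_seq'], axis=0)     # (N, H, dU)
    next_obs = np.concatenate(split['next_obs'], axis=0)   # (N, dO)
    print(obs.shape, act_seq.shape, next_obs.shape)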