Esempio n. 1
0
def _last_fixed_edge(graph: MultiDiGraph, fixed_particle: np.ndarray) -> tuple:
    """
    Return the last row of fixed_particle and the length of the geometry that
    get_geometry returns for that row's columns 1:4.
    """
    last_edge_fixed = fixed_particle[-1]
    last_edge_fixed_length = get_geometry(graph, last_edge_fixed[1:4]).length
    return last_edge_fixed, last_edge_fixed_length


def fixed_lag_stitch_post_split(graph: MultiDiGraph,
                                fixed_particles: MMParticles,
                                new_particles: MMParticles,
                                new_weights: np.ndarray,
                                mm_model: MapMatchingModel,
                                max_rejections: int) -> MMParticles:
    """
    Stitch together fixed_particles with samples from new_particles according to joint fixed-lag posterior
    :param graph: encodes road network, simplified and projected to UTM
    :param fixed_particles: trajectories before stitching time. NOTE: this object is reused as the
        output container, so its entries are replaced in place by the stitched trajectories
    :param new_particles: trajectories after stitching time (to be resampled)
        one observation time overlap with fixed_particles
    :param new_weights: weights applied to new_particles
    :param mm_model: MapMatchingModel
    :param max_rejections: number of rejections to attempt before doing full fixed-lag stitching
        0 will do full fixed-lag stitching and track ess_stitch
    :return: MMParticles object (the same object that was passed as fixed_particles)
    :raises ValueError: if stitching leaves every particle as None
    """
    n = len(fixed_particles)
    full_fixed_lag_resample = max_rejections == 0
    # Prior evaluations at or below this tolerance are treated as zero
    prior_tol = 1e-5

    # Per new particle: index of the first row at the first observation time after the overlap
    min_resample_time = new_particles.observation_times[1]
    min_resample_time_indices = [
        np.where(particle[:, 0] == min_resample_time)[0][0]
        if particle is not None else 0 for particle in new_particles
    ]
    # Last column of that row is what gets passed to the distance prior
    original_stitching_distances = np.array([
        new_particles[j][min_resample_time_indices[j], -1]
        if new_particles[j] is not None else 0
        for j in range(n)
    ])

    max_fixed_time = fixed_particles._first_non_none_particle[-1, 0]

    stitch_time_interval = min_resample_time - max_fixed_time

    distance_prior_evals = mm_model.distance_prior_evaluate(
        original_stitching_distances, stitch_time_interval)

    # Columns 5:7 of the first (overlapping) row vs. of the min_resample_time row
    fixed_last_coords = np.array([
        part[0, 5:7] if part is not None else [0, 0] for part in new_particles
    ])
    new_coords = np.array([
        new_particles[j][min_resample_time_indices[j], 5:7]
        if new_particles[j] is not None else [0, 0]
        for j in range(n)
    ])
    deviation_prior_evals = mm_model.deviation_prior_evaluate(
        fixed_last_coords, new_coords, original_stitching_distances)

    # Prior evaluation for the stitch that new_particles already carry (presumably the
    # prior they were proposed under); it is divided out of the weights below.
    # Left at zero wherever prior_norm is negligible.
    original_prior_evals = np.zeros(n)
    pos_inds = new_particles.prior_norm > prior_tol
    original_prior_evals[pos_inds] = distance_prior_evals[pos_inds] \
                                     * deviation_prior_evals[pos_inds] \
                                     * new_particles.prior_norm[pos_inds]

    # out_particles IS fixed_particles: its entries are overwritten as the loop goes,
    # which is why the rejection path keeps an untouched copy for a possible restart.
    out_particles = fixed_particles

    # Initiate some required quantities depending on whether to do rejection sampling or not
    if full_fixed_lag_resample:
        ess_stitch_track = np.zeros(n)
    else:
        ess_stitch_track = None

        pos_prior_bound = mm_model.pos_distance_prior_bound(
            stitch_time_interval)
        prior_bound = mm_model.distance_prior_bound(stitch_time_interval)
        store_out_parts = fixed_particles.copy()

    # Divide out the original prior. Anything at or below prior_tol gets zero weight, so
    # the two masks are exact complements. (Previously a value exactly equal to the
    # tolerance was neither divided nor zeroed.)
    valid_prior = original_prior_evals > prior_tol
    adjusted_weights = new_weights.copy()
    adjusted_weights[valid_prior] /= original_prior_evals[valid_prior]
    adjusted_weights[~valid_prior] = 0
    adjusted_weights /= np.sum(adjusted_weights)

    resort_to_full = False

    # Iterate through particles
    for j in range(n):
        fixed_particle = fixed_particles[j]

        # A None particle means the fixed-lag approximation has already failed
        if fixed_particle is None:
            out_particles[j] = None
            if full_fixed_lag_resample:
                ess_stitch_track[j] = 0
            continue

        last_edge_fixed, last_edge_fixed_length = _last_fixed_edge(
            graph, fixed_particle)

        if full_fixed_lag_resample:
            # Full resampling
            out_particles[j], ess_stitch_track[j] = full_fixed_lag_stitch(
                fixed_particle, last_edge_fixed, last_edge_fixed_length,
                new_particles, adjusted_weights, stitch_time_interval,
                min_resample_time_indices, mm_model, True)

        else:
            # Rejection sampling using the positive-distance bound
            out_particles[j] = rejection_fixed_lag_stitch(
                fixed_particle,
                last_edge_fixed,
                last_edge_fixed_length,
                new_particles,
                adjusted_weights,
                stitch_time_interval,
                min_resample_time_indices,
                pos_prior_bound,
                mm_model,
                max_rejections,
                break_on_zero=True)
            if out_particles[j] is None:
                # Rejection sampling reached max_rejections -> try full resampling
                out_particles[j] = full_fixed_lag_stitch(
                    fixed_particle, last_edge_fixed, last_edge_fixed_length,
                    new_particles, adjusted_weights, stitch_time_interval,
                    min_resample_time_indices, mm_model, False)

            # An integer 0 (enabled by break_on_zero=True) triggers a restart of every
            # particle with prior_bound below. NOTE(review): exact meaning of the 0
            # sentinel is defined in rejection_fixed_lag_stitch — confirm there.
            if isinstance(out_particles[j], int) and out_particles[j] == 0:
                resort_to_full = True
                break

    if resort_to_full:
        # Only reachable in rejection mode. Restart from the saved copy because the
        # entries processed before the break have already been overwritten.
        for j in range(n):
            fixed_particle = store_out_parts[j]

            # A None particle means the fixed-lag approximation has already failed
            if fixed_particle is None:
                out_particles[j] = None
                continue

            last_edge_fixed, last_edge_fixed_length = _last_fixed_edge(
                graph, fixed_particle)

            # Rejection sampling with full bound
            out_particles[j] = rejection_fixed_lag_stitch(
                fixed_particle, last_edge_fixed, last_edge_fixed_length,
                new_particles, adjusted_weights, stitch_time_interval,
                min_resample_time_indices, prior_bound, mm_model,
                max_rejections)
            if out_particles[j] is None:
                # Rejection sampling reached max_rejections -> try full resampling
                out_particles[j] = full_fixed_lag_stitch(
                    fixed_particle, last_edge_fixed, last_edge_fixed_length,
                    new_particles, adjusted_weights, stitch_time_interval,
                    min_resample_time_indices, mm_model, False)

    if full_fixed_lag_resample:
        out_particles.ess_stitch = np.append(out_particles.ess_stitch,
                                             np.atleast_2d(ess_stitch_track),
                                             axis=0)

    # Replace failed (None) particles by uniformly resampling among the successful ones
    none_inds = np.array([p is None for p in out_particles])
    good_inds = ~none_inds
    n_good = good_inds.sum()

    if n_good == 0:
        raise ValueError(
            "Map-matching failed: all stitching probabilities zero, "
            "try increasing the lag or number of particles")

    if n_good < n:
        none_inds_res_indices = np.random.choice(n,
                                                 n - n_good,
                                                 p=good_inds / n_good)
        for i, j in enumerate(np.where(none_inds)[0]):
            out_particles[j] = out_particles[none_inds_res_indices[i]]
        if full_fixed_lag_resample:
            out_particles.ess_stitch[-1,
                                     none_inds] = 1 / (new_weights**2).sum()

    return out_particles
Esempio n. 2
0
def _first_fixed_edge_info(graph: MultiDiGraph,
                           fixed_particle: np.ndarray,
                           next_time) -> tuple:
    """
    Gather the per-particle quantities the backward samplers need.
    :param graph: road network passed through to get_geometry
    :param fixed_particle: trajectory array being extended backwards
    :param next_time: observation time to locate in column 0 of fixed_particle
    :return: (first row of fixed_particle,
              length of the geometry get_geometry returns for that row's columns 1:4,
              index of the first row whose column 0 equals next_time)
    """
    first_edge_fixed = fixed_particle[0]
    first_edge_fixed_length = get_geometry(graph, first_edge_fixed[1:4]).length
    fixed_next_time_index = np.where(fixed_particle[:, 0] == next_time)[0][0]
    return first_edge_fixed, first_edge_fixed_length, fixed_next_time_index


def backward_simulate(graph: MultiDiGraph,
                      filter_particles: MMParticles,
                      filter_weights: np.ndarray,
                      time_interval_arr: np.ndarray,
                      mm_model: MapMatchingModel,
                      max_rejections: int,
                      verbose: bool = False,
                      store_ess_back: bool = None,
                      store_norm_quants: bool = False) -> MMParticles:
    """
    Given particle filter output, run backwards simulation to output smoothed trajectories
    :param graph: encodes road network, simplified and projected to UTM
    :param filter_particles: marginal outputs from particle filter
    :param filter_weights: weights
    :param time_interval_arr: times between observations, must be length one less than filter_particles
    :param mm_model: MapMatchingModel
    :param max_rejections: number of rejections to attempt before doing full fixed-lag stitching
        0 will do full backward simulation and track ess_back
    :param verbose: print ess_pf or ess_back
    :param store_ess_back: whether to store ess_back (if possible) in MMParticles object;
        defaults to True exactly when full backward simulation is used. ess_back is only
        computed under full backward simulation, otherwise None is stored.
    :param store_norm_quants: if True normalisation quantities returned in out_particles
    :return: MMParticles object
    :raises ValueError: if time_interval_arr has the wrong length, or if every backward
        sample fails at some observation
    """
    n_samps = filter_particles[-1].n
    num_obs = len(filter_particles)

    if len(time_interval_arr) + 1 != num_obs:
        raise ValueError(
            "time_interval_arr must be length one less than that of filter_particles"
        )

    full_sampling = max_rejections == 0
    if store_ess_back is None:
        store_ess_back = full_sampling

    # Multinomial resample end particles if weighted
    if np.all(filter_weights[-1] == filter_weights[-1][0]):
        out_particles = filter_particles[-1].copy()
    else:
        out_particles = multinomial(filter_particles[-1], filter_weights[-1])

    if full_sampling:
        # Row i holds the backward ESS at observation i. The loop below fills rows
        # num_obs-2 .. 0, so the final filter ESS belongs in the LAST row (writing row 0
        # here, as before, was always overwritten at i == 0).
        ess_back = np.zeros((num_obs, n_samps))
        ess_back[-1] = 1 / (filter_weights[-1]**2).sum()
    else:
        ess_back = None

    if num_obs < 2:
        return out_particles

    if store_norm_quants:
        norm_quants = np.zeros(
            (num_obs - 1, *filter_particles[0].prior_norm.shape))

    for i in range(num_obs - 2, -1, -1):
        next_time = filter_particles[i + 1].latest_observation_time

        if not full_sampling:
            pos_prior_bound = mm_model.pos_distance_prior_bound(
                time_interval_arr[i])
            prior_bound = mm_model.distance_prior_bound(time_interval_arr[i])
            # Untouched copy so a restart with prior_bound sees the pre-step trajectories
            store_out_parts = out_particles.copy()

        # Divide the filter weights by the (first column of the) prior normaliser;
        # entries with zero weight or zero normaliser are excluded
        prior_norm_1d = filter_particles[i].prior_norm
        if prior_norm_1d.ndim == 2:
            prior_norm_1d = prior_norm_1d[:, 0]
        adjusted_weights = filter_weights[i].copy()
        usable = np.logical_and(adjusted_weights != 0, prior_norm_1d != 0)
        adjusted_weights[usable] /= prior_norm_1d[usable]
        adjusted_weights[~usable] = 0
        adjusted_weights /= adjusted_weights.sum()

        if store_norm_quants:
            sampled_inds = np.zeros(n_samps, dtype=int)

        resort_to_full = False
        for j in range(n_samps):
            fixed_particle = out_particles[j].copy()
            first_edge_fixed, first_edge_fixed_length, fixed_next_time_index = \
                _first_fixed_edge_info(graph, fixed_particle, next_time)

            if full_sampling:
                back_output = full_backward_sample(
                    fixed_particle,
                    first_edge_fixed,
                    first_edge_fixed_length,
                    filter_particles[i],
                    adjusted_weights,
                    time_interval_arr[i],
                    fixed_next_time_index,
                    mm_model,
                    return_ess_back=True,
                    return_sampled_index=store_norm_quants)

                if store_norm_quants:
                    out_particles[j], ess_back[
                        i, j], sampled_inds[j] = back_output
                else:
                    out_particles[j], ess_back[i, j] = back_output

            else:
                back_output = rejection_backward_sample(
                    fixed_particle,
                    first_edge_fixed,
                    first_edge_fixed_length,
                    filter_particles[i],
                    adjusted_weights,
                    time_interval_arr[i],
                    fixed_next_time_index,
                    pos_prior_bound,
                    mm_model,
                    max_rejections,
                    return_sampled_index=store_norm_quants,
                    break_on_zero=True)

                first_back_output = back_output[
                    0] if store_norm_quants else back_output

                if first_back_output is None:
                    # Rejection sampling reached max_rejections -> full backward sample
                    back_output = full_backward_sample(
                        fixed_particle,
                        first_edge_fixed,
                        first_edge_fixed_length,
                        filter_particles[i],
                        adjusted_weights,
                        time_interval_arr[i],
                        fixed_next_time_index,
                        mm_model,
                        return_ess_back=False,
                        return_sampled_index=store_norm_quants)

                # Integer 0 (enabled by break_on_zero=True) -> restart this step for all
                # particles using prior_bound. NOTE(review): sentinel semantics live in
                # rejection_backward_sample — confirm there.
                if isinstance(first_back_output,
                              int) and first_back_output == 0:
                    resort_to_full = True
                    break

                if store_norm_quants:
                    out_particles[j], sampled_inds[j] = back_output
                else:
                    out_particles[j] = back_output

        if resort_to_full:
            if store_norm_quants:
                sampled_inds = np.zeros(n_samps, dtype=int)
            for j in range(n_samps):
                fixed_particle = store_out_parts[j]
                first_edge_fixed, first_edge_fixed_length, fixed_next_time_index = \
                    _first_fixed_edge_info(graph, fixed_particle, next_time)

                back_output = rejection_backward_sample(
                    fixed_particle,
                    first_edge_fixed,
                    first_edge_fixed_length,
                    filter_particles[i],
                    adjusted_weights,
                    time_interval_arr[i],
                    fixed_next_time_index,
                    prior_bound,
                    mm_model,
                    max_rejections,
                    return_sampled_index=store_norm_quants,
                    break_on_zero=False)

                first_back_output = back_output[
                    0] if store_norm_quants else back_output

                if first_back_output is None:
                    back_output = full_backward_sample(
                        fixed_particle,
                        first_edge_fixed,
                        first_edge_fixed_length,
                        filter_particles[i],
                        adjusted_weights,
                        time_interval_arr[i],
                        fixed_next_time_index,
                        mm_model,
                        return_ess_back=False,
                        return_sampled_index=store_norm_quants)

                if store_norm_quants:
                    out_particles[j], sampled_inds[j] = back_output
                else:
                    out_particles[j] = back_output

        if store_norm_quants:
            norm_quants[i] = filter_particles[i].prior_norm[sampled_inds]

        # Replace failed particles by uniformly resampling among the successful ones
        none_inds = np.array([p is None or None in p for p in out_particles])
        good_particles = ~none_inds
        n_good = good_particles.sum()
        if n_good < n_samps:
            if n_good == 0:
                # Previously this divided 0/0 and failed inside np.random.choice
                raise ValueError(
                    "Backward simulation failed: no valid particles at observation index "
                    + str(i))
            none_inds_res_indices = np.random.choice(n_samps,
                                                     n_samps - n_good,
                                                     p=good_particles / n_good)
            for i_none, j_none in enumerate(np.where(none_inds)[0]):
                out_particles[j_none] = out_particles[
                    none_inds_res_indices[i_none]].copy()
                if store_norm_quants:
                    norm_quants[:, j_none] = norm_quants[:,
                                                         none_inds_res_indices[
                                                             i_none]]
            # Write into the local array; out_particles.ess_back is not guaranteed to be
            # this array yet, and ess_back is None unless full_sampling
            if store_ess_back and ess_back is not None:
                ess_back[i, none_inds] = n_samps

        if verbose:
            if full_sampling:
                print(
                    str(filter_particles[i].latest_observation_time) +
                    " Av Backward ESS: " + str(np.mean(ess_back[i])))
            else:
                print(str(filter_particles[i].latest_observation_time))

    if store_ess_back:
        out_particles.ess_back = ess_back

    if store_norm_quants:
        out_particles.dev_norm_quants = norm_quants

    return out_particles