Example #1
0
    def __init__(self, particle, gas, efield, bfield, seed):
        """Set up the tracker and its unscented Kalman filter.

        Parameters
        ----------
        particle : Particle
            The particle to be tracked
        gas : Gas, or subclass
            The gas in the detector
        efield : array-like
            The electric field, converted with ``np.asarray``
        bfield : array-like
            The magnetic field, converted with ``np.asarray``
        seed : array-like
            Initial guess for the state vector. It is stored in the filter as a new float array
            (a copy), so the caller's object is never shared with the filter state.
        """
        self.particle = particle
        self.gas = gas
        self.efield = np.asarray(efield)
        self.bfield = np.asarray(bfield)
        self.kfilter = UKF(dim_x=self.sv_dim, dim_z=self.meas_dim, fx=self.update_state_vector,
                           hx=self.generate_measurement, dtx=self.find_timestep)

        # Scale the filter's initial matrices in place. NOTE(review): this assumes UKF creates
        # Q, R and P as numeric arrays (e.g. identity) — confirm against the UKF implementation.
        self.kfilter.Q *= 1e-6  # Q: conventionally the process-noise covariance
        self.kfilter.R *= 1e-2  # R: conventionally the measurement-noise covariance
        # Convert like efield/bfield, but as a float copy: a list or integer-typed seed would
        # otherwise become the filter state as-is, and an ndarray seed would share memory with
        # the caller's array.
        self.kfilter.x = np.array(seed, dtype=float)
        self.kfilter.P *= 100  # P: initial state covariance, widened to reflect an uncertain seed
Example #2
0
class Tracker(object):
    """Class responsible for fitting particle tracks.

    This is basically a wrapper around the UnscentedKalmanFilter class that provides an interface between the physics
    simulation functions and the Kalman filter functions.

    Parameters
    ----------
    particle : Particle
        The particle to be tracked. The filter callbacks (`update_state_vector`, `find_timestep`) overwrite this
        object's position/momentum/state vector while the filter runs.
    gas : Gas, or subclass
        The gas in the detector
    efield : array-like
        The electric field, in SI units
    bfield : array-like
        The magnetic field, in Tesla
    seed : array-like
        The initial seed for the Kalman filter. This is a guess for the state vector at time 0. It is stored in the
        filter as a new float array (a copy), so the caller's object is never shared with the filter state.

    Attributes
    ----------
    meas_dim : int
        The dimensionality of the measurements
    sv_dim : int
        The dimensionality of the state vectors
    """

    meas_dim = 3
    sv_dim = 6

    def __init__(self, particle, gas, efield, bfield, seed):
        self.particle = particle
        self.gas = gas
        self.efield = np.asarray(efield)
        self.bfield = np.asarray(bfield)
        self.kfilter = UKF(dim_x=self.sv_dim, dim_z=self.meas_dim, fx=self.update_state_vector,
                           hx=self.generate_measurement, dtx=self.find_timestep)

        # Scale the filter's initial matrices in place. NOTE(review): this assumes UKF creates
        # Q, R and P as numeric arrays (e.g. identity) — confirm against the UKF implementation.
        self.kfilter.Q *= 1e-6  # Q: conventionally the process-noise covariance
        self.kfilter.R *= 1e-2  # R: conventionally the measurement-noise covariance
        # Convert like efield/bfield, but as a float copy: a list or integer-typed seed would
        # otherwise become the filter state as-is, and an ndarray seed would share memory with
        # the caller's array.
        self.kfilter.x = np.array(seed, dtype=float)
        self.kfilter.P *= 100  # P: initial state covariance, widened to reflect an uncertain seed

    def update_state_vector(self, state, dt):
        """Find the next state vector.

        This implements the f(x) prediction function for the Kalman filter.

        The state vector is of the form (x, y, z, px, py, pz).

        Side effect: ``self.particle`` is first loaded with `state` and then overwritten with the
        state returned by the simulation step.

        Parameters
        ----------
        state : array-like
            The current state
        dt : float
            The time step to the next state

        Returns
        -------
        ndarray
            The next state, rebuilt from the particle's position and momentum after the step
        """
        # Load the current state into the shared particle so the simulation starts from it.
        self.particle.position = state[0:3]
        self.particle.momentum = state[3:6]

        # NOTE(review): `sim` is a module imported elsewhere in this file.
        new_state = sim.find_next_state(self.particle, self.gas, self.efield, self.bfield, dt)
        self.particle.state_vector = new_state
        return np.hstack([self.particle.position, self.particle.momentum])

    def generate_measurement(self, state):
        """Convert a state vector to a measurement vector.

        This implements the h(x) measurement function for the Kalman filter. The measurement is
        the position part of the state, i.e. its first three components.

        Parameters
        ----------
        state : array-like
            The state vector

        Returns
        -------
        array-like
            The corresponding measurement vector (a slice of `state`, so its type follows `state`)
        """
        return state[0:3]

    def find_timestep(self, sv, dpos):
        """Calculate a time step from a position step.

        This is done by dividing by the relativistic velocity.

        Side effect: ``self.particle.state_vector`` is set to `sv` so that ``self.particle.beta``
        reflects this state.

        Parameters
        ----------
        sv : array-like
            The state vector
        dpos : float
            The position step

        Returns
        -------
        float
            The time step
        """
        self.particle.state_vector = sv
        # NOTE(review): `c_lgt` is a module-level constant defined elsewhere (presumably the speed of light).
        return dpos / (self.particle.beta * c_lgt)

    def track(self, meas):
        """Run the tracker on the given data set.

        Parameters
        ----------
        meas : array-like
            The measured data points

        Returns
        -------
        res : ndarray
            The fitted state vectors
        covar : ndarray
            The covariance matrices of the state vectors
        times : ndarray
            The time at each step, for plotting on the x axis
        """
        res, covar, times = self.kfilter.batch_filter(meas)
        return res, covar, times