def train_measurement(buddy, ekf_model, dataloader, log_interval=10):
    """Run one pass over `dataloader`, stepping the "measurement" optimizer.

    Each batch is unpacked as (noisy_states, observations, log_likelihoods)
    and passed through `ekf_model.measurement_model`; a squared error against
    `log_likelihoods` is then minimized through `buddy.minimize`.

    NOTE(review): `pred_likelihoods` is never assigned in this function (its
    only mention as an assignment target is commented out). Unless a
    module-level name exists outside this view, the loss line raises
    NameError on the first batch. The measurement outputs `z` and `R` are not
    used by the loss as written.

    Args:
        buddy: Training helper; this function uses its `_device`, `minimize`,
            `optimizer_steps`, `log_scope` and `log` members.
        ekf_model: Model exposing `measurement_model(observations, states)`.
        dataloader: Iterable yielding batches.
        log_interval: Metrics are logged whenever `buddy.optimizer_steps` is
            a multiple of this value.
    """
    # Per-batch loss values, averaged for the end-of-epoch printout
    losses = []

    # Train measurement model only for 1 epoch
    for batch_idx, batch in enumerate(tqdm_notebook(dataloader)):
        # Transfer to GPU and pull out batch data
        batch_gpu = utils.to_device(batch, buddy._device)
        noisy_states, observations, log_likelihoods = batch_gpu

        # noisy_states = noisy_states[:, np.newaxis, :]
        z, R = ekf_model.measurement_model(observations, noisy_states)
        # The measurement output is expected to be rank 2
        assert len(z.shape) == 2
        # pred_likelihoods = pred_likelihoods.squeeze(dim=1)
        # todo: get actual x!

        # NOTE(review): `pred_likelihoods` is undefined at this point -- see
        # the docstring; confirm what prediction should be compared here.
        loss = torch.mean((pred_likelihoods - log_likelihoods)**2)
        losses.append(utils.to_numpy(loss))

        buddy.minimize(loss,
                       optimizer_name="measurement",
                       checkpoint_interval=10000)

        if buddy.optimizer_steps % log_interval == 0:
            with buddy.log_scope("measurement"):
                buddy.log("Training loss", loss)

                # Summary statistics of predictions vs. labels
                buddy.log("Pred likelihoods mean", pred_likelihoods.mean())
                buddy.log("Pred likelihoods std", pred_likelihoods.std())

                buddy.log("Label likelihoods mean", log_likelihoods.mean())
                buddy.log("Label likelihoods std", log_likelihoods.std())

    print("Epoch loss:", np.mean(losses))
Ejemplo n.º 2
0
def rollout(model, trajectories, max_timesteps=300):
    """Roll `model` forward along each trajectory, feeding back its own output.

    All trajectories are cut to a common length: the shortest trajectory or
    `max_timesteps`, whichever is smaller. The first predicted state of each
    trajectory is its true first state; every later state is the model's
    prediction from the previous predicted state plus the observation and
    control at that timestep.

    Args:
        model: Callable taking (states, observations, controls) as torch
            inputs; must expose `parameters()` so its device can be found.
        trajectories: Sequence of (states, observations, controls) tuples.
        max_timesteps: Upper bound on the rollout length.

    Returns:
        Tuple `(predicted_states, actual_states)` of numpy arrays.
    """
    traj_count = len(trajectories)
    horizon = np.min([len(states) for states, _, _ in trajectories] +
                     [max_timesteps])

    rollouts = [[states[0]] for states, _, _ in trajectories]
    ground_truth = [states[:horizon] for states, _, _ in trajectories]

    for step in range(1, horizon):
        # Stack the per-trajectory inputs for this timestep into batches
        obs_batch = {}
        for _, observations, _ in trajectories:
            utils.DictIterator(obs_batch).append(
                utils.DictIterator(observations)[step])
        utils.DictIterator(obs_batch).convert_to_numpy()

        state_batch = np.array([history[step - 1] for history in rollouts])
        control_batch = np.array(
            [controls[step] for _, _, controls in trajectories])

        # Run the model on whatever device its parameters live on
        device = next(model.parameters()).device
        model_inputs = utils.to_torch(
            [state_batch, obs_batch, control_batch], device=device)
        step_pred = utils.to_numpy(model(*model_inputs))
        assert step_pred.shape == (traj_count, 2)

        for index in range(traj_count):
            rollouts[index].append(step_pred[index])

    return np.array(rollouts), np.array(ground_truth)
def train_dynamics(buddy,
                   pf_model,
                   dataloader,
                   log_interval=10,
                   optim_name="dynamics"):
    """Train `pf_model.dynamics_model` for one pass over `dataloader`.

    Each batch is unpacked as (prev_states, observations, controls,
    new_states). Gaussian noise (std 0.05) is added in place to the previous
    states, a singleton axis is inserted at position 1 before the dynamics
    model is called with `noisy=False`, and the MSE between the prediction
    and `new_states` is minimized.

    Args:
        buddy: Training helper; uses `_device`, `minimize`,
            `optimizer_steps`, `log_scope` and `log`.
        pf_model: Model exposing `dynamics_model(states, controls, noisy=...)`.
        dataloader: Iterable yielding batches.
        log_interval: Log whenever `buddy.optimizer_steps` is a multiple of
            this value.
        optim_name: Optimizer name and logging scope.
    """
    batch_losses = []

    # Single epoch over the data
    for batch in tqdm(dataloader):
        prev_states, _unused_observations, controls, new_states = \
            utils.to_device(batch, buddy._device)

        # Perturb the inputs in place with Gaussian noise
        noise = np.random.normal(0, 0.05, size=prev_states.shape)
        prev_states += utils.to_torch(noise, device=buddy._device)

        # Insert a singleton axis for the model call, drop it afterwards
        new_states_pred = pf_model.dynamics_model(
            prev_states[:, np.newaxis, :],
            controls,
            noisy=False,
        ).squeeze(dim=1)

        mse_pos = F.mse_loss(new_states_pred, new_states)
        loss = mse_pos
        batch_losses.append(utils.to_numpy(loss))

        buddy.minimize(loss,
                       optimizer_name=optim_name,
                       checkpoint_interval=1000)

        if buddy.optimizer_steps % log_interval != 0:
            continue

        with buddy.log_scope(optim_name):
            buddy.log("MSE position", mse_pos)
            # Only the first state component is reported
            buddy.log("Label pos std", new_states.std(dim=0)[0])
            buddy.log("Predicted pos std", new_states_pred.std(dim=0)[0])
            buddy.log("Label pos mean", new_states.mean(dim=0)[0])
            buddy.log("Predicted pos mean", new_states_pred.mean(dim=0)[0])

    print("Epoch loss:", np.mean(batch_losses))
def train_measurement(buddy,
                      kf_model,
                      dataloader,
                      log_interval=10,
                      optim_name="ekf_measurement",
                      checkpoint_interval=500,
                      loss_type="mse"):
    """Train `kf_model.measurement_model` for one pass over `dataloader`.

    Each batch is unpacked as (noisy_state, observation, _, state). The
    measurement model returns a state update and `R`; the loss is chosen by
    `loss_type`:

    * "mse":   MSE between the state update and the true state.
    * "nll":   mean negative `utility.gaussian_log_likelihood` of the state
               under the update and `R`.
    * "mixed": sum of the two.

    Args:
        buddy: Training helper; uses `_device`, `minimize`,
            `optimizer_steps`, `log_scope`, `log` and `log_model_grad_norm`.
        kf_model: Model exposing `measurement_model(observation, state)`.
        dataloader: Iterable yielding batches.
        log_interval: Log whenever `buddy.optimizer_steps` is a multiple of
            this value.
        optim_name: Optimizer name and logging scope.
        checkpoint_interval: Forwarded to `buddy.minimize`.
        loss_type: One of "mse", "mixed", "nll".

    Raises:
        AssertionError: If `loss_type` is not one of the supported values.
    """
    assert loss_type in ["mse", "mixed", "nll"]

    epoch_losses = []

    for batch in dataloader:
        noisy_state, observation, _, state = fannypack.utils.to_device(
            batch, buddy._device)
        state_update, R = kf_model.measurement_model(observation, noisy_state)
        mse = F.mse_loss(state_update, state)

        if loss_type == "mse":
            loss = mse
        else:
            # Both remaining modes need the mean negative log-likelihood
            nll = torch.mean(-1.0 * utility.gaussian_log_likelihood(
                state_update, state, R))
            loss = nll if loss_type == "nll" else mse + nll

        buddy.minimize(loss,
                       optimizer_name=optim_name,
                       checkpoint_interval=checkpoint_interval)
        epoch_losses.append(utils.to_numpy(loss))

        if buddy.optimizer_steps % log_interval != 0:
            continue

        label_np = fannypack.utils.to_numpy(state)
        with buddy.log_scope(optim_name):
            buddy.log("loss", loss)
            buddy.log("label_mean", label_np.mean())
            buddy.log("label_std", fannypack.utils.to_numpy(state).std())
            buddy.log("pred_mean",
                      fannypack.utils.to_numpy(state_update).mean())
            buddy.log("pred_std",
                      fannypack.utils.to_numpy(state_update).std())
            buddy.log_model_grad_norm()

    print("Epoch loss:", np.mean(epoch_losses))
Ejemplo n.º 5
0
    def _update(self, observations, controls):
        """Advance the estimate by one step using `self.model`.

        The previous estimate gets two leading singleton axes, is passed
        through the model together with `observations` and `controls` under
        `torch.no_grad()`, and the squeezed result becomes the new estimate.

        The estimate is converted to numpy *before* it is cached in
        `self.prev_estimate`. Previously the torch output was cached, and the
        next call's `np.array(self.prev_estimate)` cannot convert a tensor
        that lives on a CUDA device.

        Args:
            observations: Observation input forwarded to the model.
            controls: Control input forwarded to the model.

        Returns:
            The new state estimate as a numpy array.
        """
        # Pre-process model inputs: add two leading singleton axes
        states_prev = np.array(self.prev_estimate)[np.newaxis, np.newaxis]

        # Prediction
        with torch.no_grad():
            states_new = self.model(
                *utils.to_torch([states_prev, observations, controls],
                                device=self.buddy._device))

        # Post-process & return: keep the cached estimate as numpy so the
        # next call's pre-processing works regardless of the model's device
        estimate = np.squeeze(utils.to_numpy(states_new))
        self.prev_estimate = estimate
        return estimate
Ejemplo n.º 6
0
    def _update(self, observations, controls):
        """Advance the particle filter by one step and return its estimate.

        Calls `self.pf_model.forward` with the stored particles and log
        weights plus the torch-converted inputs (resampling and noisy
        dynamics enabled), then stores the returned particles and log weights
        on `self`.

        Args:
            observations: Observation input for this step.
            controls: Control input for this step.

        Returns:
            The state estimate as a squeezed numpy array.
        """
        torch_inputs = utils.to_torch([observations, controls],
                                      device=self.buddy._device)

        # Run model
        forward_out = self.pf_model.forward(self.particles,
                                            self.log_weights,
                                            *torch_inputs,
                                            resample=True,
                                            noisy_dynamics=True)
        state_estimates, next_particles, next_log_weights = forward_out

        # Carry the filter belief over to the next call
        self.particles = next_particles
        self.log_weights = next_log_weights

        estimates_np = utils.to_numpy(state_estimates)
        return np.squeeze(estimates_np)
Ejemplo n.º 7
0
def train(buddy, model, dataloader, log_interval=10, state_noise_std=0.2):
    """Train `model` for one pass over `dataloader` with an MSE objective.

    Each batch is unpacked as (prev_states, observations, controls,
    new_states). Gaussian noise with std `state_noise_std` is added in place
    to the previous states before the model predicts the new states.

    Args:
        buddy: Training helper; uses `_device`, `minimize`, `_steps`,
            `log_scope` and `log`.
        model: Callable taking (prev_states, observations, controls).
        dataloader: Iterable yielding batches.
        log_interval: Log whenever `buddy._steps` is a multiple of this value.
        state_noise_std: Standard deviation of the noise added to the inputs.
    """
    epoch_losses = []

    # Single epoch over the data
    for batch in tqdm(dataloader):
        prev_states, observations, controls, new_states = utils.to_device(
            batch, buddy._device)

        # Perturb the input states in place
        noise = np.random.normal(0, state_noise_std, size=prev_states.shape)
        prev_states += utils.to_torch(noise, device=buddy._device)

        new_states_pred = model(prev_states, observations, controls)

        squared_error = (new_states_pred - new_states)**2
        loss = torch.mean(squared_error)
        epoch_losses.append(utils.to_numpy(loss))

        buddy.minimize(loss, checkpoint_interval=10000)

        if buddy._steps % log_interval != 0:
            continue

        with buddy.log_scope("baseline_training"):
            buddy.log("Training loss", loss)
            # Only the first state component is reported
            buddy.log("Training pos std", new_states.std(dim=0)[0])
            buddy.log("Predicted pos std", new_states_pred.std(dim=0)[0])
            buddy.log("Training pos mean", new_states.mean(dim=0)[0])
            buddy.log("Predicted pos mean", new_states_pred.mean(dim=0)[0])

    print("Epoch loss:", np.mean(epoch_losses))
Ejemplo n.º 8
0
def rollout_lstm(model, trajectories, max_timesteps=300):
    """Run a sequence model over whole trajectories in one batched call.

    All trajectories are cut to a common length: the shortest trajectory or
    `max_timesteps`, whichever is smaller. Observations and controls for
    timesteps `1..timesteps-1` are batched and passed to `model`; the true
    initial state is prepended to the model output along the time axis.

    Args:
        model: Callable taking (observations, controls) as torch inputs; must
            expose `parameters()` so its device can be found.
        trajectories: Non-empty sequence of (states, observations, controls)
            tuples.
        max_timesteps: Upper bound on the rollout length.

    Returns:
        Tuple `(predicted_states, actual_states)`, indexed as
        (batch, sequence length, state).
    """
    timesteps = np.min([len(s) for s, _, _ in trajectories] + [max_timesteps])

    trajectory_count = len(trajectories)

    state_dim = trajectories[0][0].shape[-1]
    actual_states = np.zeros((trajectory_count, timesteps, state_dim))

    batched_observations = {}
    batched_controls = []

    # Trajectories is a list of (states, observations, controls)
    for i, (states, observations, controls) in enumerate(trajectories):
        utils.DictIterator(batched_observations).append(
            utils.DictIterator(observations)[1:timesteps])
        batched_controls.append(controls[1:timesteps])

        # Validate the *truncated* states: trajectories longer than
        # `timesteps` are expected and are simply cut down, so checking the
        # untruncated shape would reject valid input.
        truncated_states = states[:timesteps]
        assert truncated_states.shape == (timesteps, state_dim)
        actual_states[i] = truncated_states

    utils.DictIterator(batched_observations).convert_to_numpy()
    batched_controls = np.array(batched_controls)

    # Propagate through model, on whatever device its parameters live on
    device = next(model.parameters()).device
    model_output = utils.to_numpy(
        model(
            utils.to_torch(batched_observations, device),
            utils.to_torch(batched_controls, device),
        ))

    # Prepend the true initial state along the time axis
    predicted_states = np.concatenate(
        [actual_states[:, 0:1, :], model_output], axis=1)

    # Indexing: batch, sequence length, state
    return predicted_states, actual_states
def train_e2e(buddy,
              pf_model,
              dataloader,
              log_interval=2,
              loss_type="mse",
              optim_name="e2e",
              resample=False,
              know_image_blackout=False):
    """Train the particle filter end-to-end for one pass over `dataloader`.

    Each batch is unpacked as (particles, states, observations, controls).
    Particles start with uniform log weights and are propagated through
    `pf_model.forward` for every timestep, with gradients flowing through
    time. The per-timestep losses are averaged and minimized once per batch.

    The final "Epoch loss" is the mean of the per-batch losses. (It used to
    be computed from the last batch's timestep losses only, and failed with
    an unbound local when the dataloader was empty.)

    Args:
        buddy: Training helper; uses `_device`, `minimize`,
            `optimizer_steps`, `log_scope` and `log`.
        pf_model: Particle filter model exposing `forward(...)`.
        dataloader: Iterable yielding batches.
        log_interval: Log whenever `buddy.optimizer_steps` is a multiple of
            this value.
        loss_type: "mse" (error of the state estimate) or "gmm"
            (`dpf.gmm_loss` over the particles).
        optim_name: Optimizer name and logging scope.
        resample: Forwarded to `pf_model.forward`.
        know_image_blackout: When true, `know_image_blackout=True` is passed
            to `pf_model.forward`; otherwise the keyword is omitted.

    Raises:
        AssertionError: If `loss_type` is not "gmm" or "mse".
    """
    # Only pass the blackout flag when requested, so models whose forward()
    # does not accept that keyword keep working.
    forward_kwargs = {"know_image_blackout": True} if know_image_blackout else {}

    epoch_losses = []

    # Train for 1 epoch
    for batch_idx, batch in enumerate(tqdm(dataloader)):
        # Transfer to GPU and pull out batch data
        batch_gpu = utils.to_device(batch, buddy._device)
        batch_particles, batch_states, batch_obs, batch_controls = batch_gpu

        # N = batch size, M = particle count
        N, timesteps, control_dim = batch_controls.shape
        N, timesteps, state_dim = batch_states.shape
        N, M, state_dim = batch_particles.shape
        assert batch_controls.shape == (N, timesteps, control_dim)

        # Give all particles equal weights
        particles = batch_particles
        log_weights = torch.ones((N, M), device=buddy._device) * (-np.log(M))

        # Accumulate losses from each timestep
        losses = []
        for t in range(1, timesteps):
            state_estimates, new_particles, new_log_weights = pf_model.forward(
                particles,
                log_weights,
                utils.DictIterator(batch_obs)[:, t - 1, :],
                batch_controls[:, t, :],
                resample=resample,
                noisy_dynamics=True,
                **forward_kwargs)

            if loss_type == "gmm":
                loss = dpf.gmm_loss(particles_states=new_particles,
                                    log_weights=new_log_weights,
                                    true_states=batch_states[:, t, :],
                                    gmm_variances=np.array([0.1]))
            elif loss_type == "mse":
                loss = torch.mean((state_estimates - batch_states[:, t, :])**2)
            else:
                assert False, "Invalid loss"

            losses.append(loss)

            # Enable backprop through time: carry the un-detached belief
            particles = new_particles
            log_weights = new_log_weights

        batch_loss = torch.mean(torch.stack(losses))
        epoch_losses.append(utils.to_numpy(batch_loss))
        buddy.minimize(batch_loss,
                       optimizer_name=optim_name,
                       checkpoint_interval=1000)

        if buddy.optimizer_steps % log_interval == 0:
            with buddy.log_scope(optim_name):
                buddy.log("Training loss", np.mean(utils.to_numpy(losses)))
                buddy.log("Log weights mean", log_weights.mean())
                buddy.log("Log weights std", log_weights.std())
                buddy.log("Particle states mean", particles.mean())
                buddy.log("particle states std", particles.std())

    print("Epoch loss:", np.mean(epoch_losses))
def train_dynamics_recurrent(buddy,
                             pf_model,
                             dataloader,
                             log_interval=10,
                             loss_type="l1",
                             optim_name="dynamics_recurrent"):
    """Train the dynamics model on multi-step rollouts for one epoch.

    Each batch is unpacked as (states, observations, controls). Starting from
    the true state at timestep 0, the dynamics model is applied repeatedly to
    its own output (with `noisy=False`); at every timestep the prediction is
    compared to the true state, and the mean of the per-timestep losses is
    minimized once per batch.

    Args:
        buddy: Training helper; uses `_device`, `minimize`,
            `optimizer_steps`, `log_scope` and `log`.
        pf_model: Model exposing `dynamics_model(states, controls, noisy=...)`.
        dataloader: Iterable yielding batches.
        log_interval: Log whenever `buddy.optimizer_steps` is a multiple of
            this value.
        loss_type: One of 'l1', 'l2', 'huber', 'peter'. The 'peter' branch is
            disabled by an unconditional `assert False`.
        optim_name: Optimizer name and logging scope.

    Raises:
        AssertionError: If `loss_type` is unsupported, if 'peter' is selected,
            or if a shape check fails.
    """

    assert loss_type in ('l1', 'l2', 'huber', 'peter')

    # Train dynamics only for 1 epoch
    # Train for 1 epoch
    epoch_losses = []
    for batch_idx, batch in enumerate(tqdm(dataloader)):
        # Transfer to GPU and pull out batch data
        batch_gpu = utils.to_device(batch, buddy._device)
        batch_states, batch_obs, batch_controls = batch_gpu
        # N = batch size, M = particle count
        N, timesteps, control_dim = batch_controls.shape
        N, timesteps, state_dim = batch_states.shape
        # Since N and timesteps were just rebound from batch_states, this
        # checks that controls and states agree on batch size and length
        assert batch_controls.shape == (N, timesteps, control_dim)

        # Track current states as they're propagated through our dynamics model
        prev_states = batch_states[:, 0, :]
        assert prev_states.shape == (N, state_dim)

        # Accumulate losses from each timestep
        losses = []
        # Only filled by the (disabled) 'peter' branch
        magnitude_losses = []
        direction_losses = []

        # Compute some state deltas for debugging: mean squared one-step
        # change of the labels, one value per transition
        label_deltas = np.mean(utils.to_numpy(batch_states[:, 1:, :] -
                                              batch_states[:, :-1, :])**2,
                               axis=(0, 2))
        assert label_deltas.shape == (timesteps - 1, )
        pred_deltas = []

        for t in range(1, timesteps):
            # Propagate current states through dynamics model
            controls = batch_controls[:, t, :]
            new_states = pf_model.dynamics_model(
                prev_states[:, np.newaxis, :],  # Add particle dimension
                controls,
                noisy=False,
            ).squeeze(dim=1)  # Remove particle dimension
            assert new_states.shape == (N, state_dim)

            # Compute deltas (only consumed by the disabled 'peter' branch)
            pred_delta = prev_states - new_states
            label_delta = batch_states[:, t - 1, :] - batch_states[:, t, :]
            assert pred_delta.shape == (N, state_dim)
            assert label_delta.shape == (N, state_dim)

            # Compute and add loss
            if loss_type == 'l1':
                # timestep_loss = F.l1_loss(pred_delta, label_delta)
                timestep_loss = F.l1_loss(new_states, batch_states[:, t, :])
            elif loss_type == 'l2':
                # timestep_loss = F.mse_loss(pred_delta, label_delta)
                timestep_loss = F.mse_loss(new_states, batch_states[:, t, :])
            elif loss_type == 'huber':
                # Note that the units our states are in will affect results
                # for Huber
                timestep_loss = F.smooth_l1_loss(batch_states[:, t, :],
                                                 new_states)
            elif loss_type == 'peter':
                # Use a Peter loss
                # Currently broken
                # NOTE(review): everything below this assert is unreachable in
                # a normal run, and `direction_loss` is never assigned -- with
                # asserts stripped (python -O) this branch would raise
                # NameError.
                assert False

                pred_magnitude = torch.norm(pred_delta, dim=1)
                label_magnitude = torch.norm(label_delta, dim=1)
                assert pred_magnitude.shape == (N, )
                assert label_magnitude.shape == (N, )

                # pred_direction = pred_delta / (pred_magnitude + 1e-8)
                # label_direction = label_delta / (label_magnitude + 1e-8)
                # assert pred_direction.shape == (N, state_dim)
                # assert label_direction.shape == (N, state_dim)

                # Compute loss
                magnitude_loss = F.mse_loss(pred_magnitude, label_magnitude)
                # direction_loss =
                timestep_loss = magnitude_loss + direction_loss

                magnitude_losses.append(magnitude_loss)
                direction_losses.append(direction_loss)

            else:
                assert False
            losses.append(timestep_loss)

            # Compute delta and update states; the prediction (not the label)
            # is fed into the next timestep
            pred_deltas.append(
                np.mean(utils.to_numpy(new_states - prev_states)**2))
            prev_states = new_states

        pred_deltas = np.array(pred_deltas)
        assert pred_deltas.shape == (timesteps - 1, )

        # One optimizer step per batch, on the mean over timesteps
        loss = torch.mean(torch.stack(losses))
        # NOTE(review): the loss tensor itself is stored here (not a numpy
        # copy as in the sibling training loops) -- confirm this is intended.
        epoch_losses.append(loss)
        buddy.minimize(loss,
                       optimizer_name=optim_name,
                       checkpoint_interval=1000)

        if buddy.optimizer_steps % log_interval == 0:
            with buddy.log_scope(optim_name):
                buddy.log("Training loss", loss)

                buddy.log("Label delta mean", label_deltas.mean())
                buddy.log("Label delta std", label_deltas.std())

                buddy.log("Pred delta mean", pred_deltas.mean())
                buddy.log("Pred delta std", pred_deltas.std())

                # These lists stay empty unless the 'peter' branch runs
                if magnitude_losses:
                    buddy.log("Magnitude loss",
                              torch.mean(torch.tensor(magnitude_losses)))
                if direction_losses:
                    buddy.log("Direction loss",
                              torch.mean(torch.tensor(direction_losses)))

    print("Epoch loss:", np.mean(utils.to_numpy(epoch_losses)))
Ejemplo n.º 11
0
def rollout_kf(kf_model,
               trajectories,
               start_time=0,
               max_timesteps=300,
               noisy_dynamics=False,
               true_initial=False,
               init_state_noise=0.2,
               save_data_name=None):
    """Roll a Kalman filter model along a set of trajectories.

    All trajectories are cut to the window `[start_time, end_time)`, where
    `end_time` is the shortest trajectory length or
    `start_time + max_timesteps`, whichever is smaller. The filter is
    initialised either from the (noised) first true state or from a forward
    pass on the first observation, then stepped once per remaining timestep.

    Changes relative to the previous version of this function:

    * The RMSE is computed over the whole rollout window. Both arrays already
      start at `start_time`, so slicing them by `start_time` again skipped
      extra steps (and could leave nothing) when `start_time > 0`.
    * Saving only falls back to the rename-and-retry path on `OSError`
      instead of on any exception, and the HDF5 file is closed even if a
      write fails.
    * The per-step conversion of a full state trajectory to torch (an unused
      leftover) is gone.

    Args:
        kf_model: Filter model exposing `eval()`, `parameters()` and
            `forward(states, sigmas, observations, controls)`.
        trajectories: Non-empty sequence of (states, observations, controls)
            tuples.
        start_time: First timestep of the rollout window.
        max_timesteps: Maximum window length.
        noisy_dynamics: Accepted for interface compatibility; not used here.
        true_initial: If true, start from the trajectory's first state plus
            Gaussian noise; otherwise run one forward pass on the first
            observation with zero state, zero covariance and all-ones
            controls.
        init_state_noise: Std of the initial state noise; its square fills
            the diagonal of the initial covariance when `true_initial`.
        save_data_name: If given, results are written to
            "rollout/<save_data_name>.h5".

    Returns:
        Tuple `(predicted_states, actual_states, predicted_sigmas,
        contact_states)`; the first three are numpy arrays, the last is a
        list with the final control component of each trajectory over the
        window.
    """
    # To make things easier, we're going to cut all our trajectories to the
    # same length :)
    kf_model.eval()
    end_time = np.min([len(s) for s, _, _ in trajectories] +
                      [start_time + max_timesteps])

    print("endtime: ", end_time)

    actual_states = [
        states[start_time:end_time] for states, _, _ in trajectories
    ]

    contact_states = [
        action[start_time:end_time][:, -1]
        for states, obs, action in trajectories
    ]

    state_dim = len(actual_states[0][0])
    N = len(trajectories)
    controls_dim = trajectories[0][2][0].shape

    device = next(kf_model.parameters()).device

    initial_states = np.zeros((N, state_dim))
    initial_sigmas = np.zeros((N, state_dim, state_dim))
    initial_obs = {}

    if true_initial:
        for i in range(N):
            initial_states[i] = trajectories[i][0][0] + np.random.normal(
                0.0, scale=init_state_noise, size=initial_states[i].shape)
            initial_sigmas[i] = np.eye(state_dim) * init_state_noise**2
        (initial_states, initial_sigmas) = utils.to_torch(
            (initial_states, initial_sigmas), device=device)
    else:
        # Put into measurement model!
        dummy_controls = torch.ones((N, ) + controls_dim, ).to(device)
        for i in range(N):
            utils.DictIterator(initial_obs).append(
                utils.DictIterator(trajectories[i][1])[0])

        utils.DictIterator(initial_obs).convert_to_numpy()

        (initial_obs, initial_states, initial_sigmas) = utils.to_torch(
            (initial_obs, initial_states, initial_sigmas), device=device)

        states_tuple = kf_model.forward(
            initial_states,
            initial_sigmas,
            initial_obs,
            dummy_controls,
        )

        initial_states = states_tuple[0]
        initial_sigmas = states_tuple[1]

    states = initial_states
    sigmas = initial_sigmas

    predicted_states = [[utils.to_numpy(initial_states[i])]
                        for i in range(N)]
    predicted_sigmas = [[utils.to_numpy(initial_sigmas[i])]
                        for i in range(N)]

    for t in tqdm(range(start_time + 1, end_time)):
        # Batch this timestep's observations and controls
        o = {}
        c = []
        for _, observations, controls in trajectories:
            utils.DictIterator(o).append(utils.DictIterator(observations)[t])
            c.append(controls[t])

        utils.DictIterator(o).convert_to_numpy()
        c = np.array(c)
        (o, c) = utils.to_torch((o, c), device=device)

        estimates = kf_model.forward(
            states,
            sigmas,
            o,
            c,
        )

        # `.data` carries the belief forward without autograd history
        states = estimates[0].data
        sigmas = estimates[1].data

        for i in range(N):
            predicted_states[i].append(utils.to_numpy(states[i]))
            predicted_sigmas[i].append(utils.to_numpy(sigmas[i]))

    predicted_states = np.array(predicted_states)
    actual_states = np.array(actual_states)
    predicted_sigmas = np.array(predicted_sigmas)

    # Index 0 of both arrays already corresponds to `start_time`, so the
    # error is taken over the full window.
    rmse_x = np.sqrt(
        np.mean((predicted_states[:, :, 0] - actual_states[:, :, 0])**2))

    rmse_y = np.sqrt(
        np.mean((predicted_states[:, :, 1] - actual_states[:, :, 1])**2))

    print("rsme x: \n{} \n y:\n{}".format(rmse_x, rmse_y))

    if save_data_name is not None:
        import os

        import h5py
        filename = "rollout/" + save_data_name + ".h5"

        try:
            f = h5py.File(filename, 'w')
        except OSError:
            # Could not open the target for writing: move the existing file
            # aside and retry once.
            new_dest = "rollout/old/{}.h5".format(save_data_name)
            os.rename(filename, new_dest)
            f = h5py.File(filename, 'w')

        # Context manager guarantees the file is closed even on write errors
        with f:
            f.create_dataset("predicted_states", data=predicted_states)
            f.create_dataset("actual_states", data=actual_states)
            f.create_dataset("predicted_sigmas", data=predicted_sigmas)

    return predicted_states, actual_states, predicted_sigmas, contact_states
Ejemplo n.º 12
0
def rollout_kf_full(
    kf_model,
    trajectories,
    start_time=0,
    max_timesteps=300,
    true_initial=False,
    init_state_noise=0.2,
):
    """Roll a Kalman filter model along trajectories, recording intermediates.

    Besides the filtered states and covariances, the values that
    `kf_model` exposes after each `forward` call (`dynamics_states`,
    `dynamics_sigma`, `measurement_states`, `measurement_sigma`,
    `dynamics_jac`) are collected per timestep.

    Args:
        kf_model: Filter model exposing `eval()`, `parameters()`,
            `forward(states, sigmas, observations, controls)`,
            `measurement_model.forward(observations, states)` and the
            attributes listed above.
        trajectories: Sequence of (states, observations, controls) tuples.
        start_time: First timestep of the rollout window.
        max_timesteps: Maximum window length.
        true_initial: If true, start from the trajectory's first state plus
            Gaussian noise; otherwise initialise from the measurement model
            applied to the first observation.
        init_state_noise: Std of the initial state noise; its square fills
            the diagonal of the initial covariance when `true_initial`.

    Returns:
        Dict of numpy arrays with keys 'dyn_states', 'dyn_sigmas',
        'meas_states', 'meas_sigmas', 'dyn_jac', 'predicted_states',
        'predicted_sigmas', 'actual_states', 'contact_states', 'actions'.
    """
    # To make things easier, we're going to cut all our trajectories to the
    # same length :)

    kf_model.eval()
    end_time = np.min([len(s) for s, _, _ in trajectories] +
                      [start_time + max_timesteps])

    print("endtime: ", end_time)

    actual_states = [
        states[start_time:end_time] for states, _, _ in trajectories
    ]

    # Last component of each control vector over the rollout window
    contact_states = [
        action[start_time:end_time][:, -1]
        for states, obs, action in trajectories
    ]

    # `get_actions` is defined outside this block
    actions = get_actions(trajectories, start_time, max_timesteps)

    state_dim = len(actual_states[0][0])
    N = len(trajectories)
    controls_dim = trajectories[0][2][0].shape

    device = next(kf_model.parameters()).device

    initial_states = np.zeros((N, state_dim))
    initial_sigmas = np.zeros((N, state_dim, state_dim))
    initial_obs = {}

    if true_initial:
        # Start from the first true state plus noise, with a matching
        # diagonal covariance
        for i in range(N):
            initial_states[i] = trajectories[i][0][0] + np.random.normal(
                0.0, scale=init_state_noise, size=initial_states[i].shape)
            initial_sigmas[i] = np.eye(state_dim) * init_state_noise**2
        (initial_states, initial_sigmas) = utils.to_torch(
            (initial_states, initial_sigmas), device=device)
    else:
        print("put in measurement model")
        # Put into measurement model!
        # NOTE(review): dummy_controls is not used anywhere in this function
        dummy_controls = torch.ones((N, ) + controls_dim, ).to(device)
        for i in range(N):
            utils.DictIterator(initial_obs).append(
                utils.DictIterator(trajectories[i][1])[0])

        utils.DictIterator(initial_obs).convert_to_numpy()

        (initial_obs, initial_states, initial_sigmas) = utils.to_torch(
            (initial_obs, initial_states, initial_sigmas), device=device)

        state, state_sigma = kf_model.measurement_model.forward(
            initial_obs, initial_states)
        initial_states = state
        initial_sigmas = state_sigma
        # Redundant: predicted_states is rebuilt identically just below
        predicted_states = [[utils.to_numpy(initial_states[i])]
                            for i in range(len(trajectories))]

    states = initial_states
    sigmas = initial_sigmas

    # Every per-trajectory history starts with the initial belief
    predicted_states = [[utils.to_numpy(initial_states[i])]
                        for i in range(len(trajectories))]
    predicted_sigmas = [[utils.to_numpy(initial_sigmas[i])]
                        for i in range(len(trajectories))]

    predicted_dyn_states = [[utils.to_numpy(initial_states[i])]
                            for i in range(len(trajectories))]
    predicted_dyn_sigmas = [[utils.to_numpy(initial_sigmas[i])]
                            for i in range(len(trajectories))]

    predicted_meas_states = [[utils.to_numpy(initial_states[i])]
                             for i in range(len(trajectories))]
    predicted_meas_sigmas = [[utils.to_numpy(initial_sigmas[i])]
                             for i in range(len(trajectories))]

    # jacobian is not initialized, so its history is one entry shorter than
    # the others
    predicted_jac = [[] for i in range(len(trajectories))]

    for t in tqdm(range(start_time + 1, end_time)):
        s = []
        o = {}
        c = []

        for i, traj in enumerate(trajectories):
            # NOTE(review): this rebinds `s` to the trajectory's states, so
            # after the loop `s` holds the last trajectory's states rather
            # than a per-trajectory list; it is not passed to the model.
            s, observations, controls = traj

            o_t = utils.DictIterator(observations)[t]
            utils.DictIterator(o).append(o_t)
            c.append(controls[t])

        s = np.array(s)
        utils.DictIterator(o).convert_to_numpy()
        c = np.array(c)
        (s, o, c) = utils.to_torch((s, o, c), device=device)

        estimates = kf_model.forward(
            states,
            sigmas,
            o,
            c,
        )

        state_estimates = estimates[0].data
        sigma_estimates = estimates[1].data

        states = state_estimates
        sigmas = sigma_estimates

        # Intermediates are read from the model right after forward();
        # presumably forward() refreshes them each call -- TODO confirm
        dynamics_states = kf_model.dynamics_states
        dynamics_sigma = kf_model.dynamics_sigma
        measurement_states = kf_model.measurement_states
        measurement_sigma = kf_model.measurement_sigma
        dynamics_jac = kf_model.dynamics_jac

        for i in range(len(trajectories)):
            predicted_dyn_states[i].append(utils.to_numpy(dynamics_states[i]))
            # NOTE(review): unlike every other quantity here, dynamics_sigma
            # is appended without indexing by i -- confirm whether it is
            # shared across the batch or whether `[i]` is missing.
            predicted_dyn_sigmas[i].append(utils.to_numpy(dynamics_sigma))
            predicted_meas_states[i].append(
                utils.to_numpy(measurement_states[i]))
            predicted_meas_sigmas[i].append(
                utils.to_numpy(measurement_sigma[i]))
            predicted_jac[i].append(utils.to_numpy(dynamics_jac[i]))
            predicted_states[i].append(utils.to_numpy(state_estimates[i]))
            predicted_sigmas[i].append(utils.to_numpy(sigma_estimates[i]))

    results = {}

    results['dyn_states'] = np.array(predicted_dyn_states)
    results['dyn_sigmas'] = np.array(predicted_dyn_sigmas)
    results['meas_states'] = np.array(predicted_meas_states)
    results['meas_sigmas'] = np.array(predicted_meas_sigmas)
    results['dyn_jac'] = np.array(predicted_jac)
    results['predicted_states'] = np.array(predicted_states)
    results['predicted_sigmas'] = np.array(predicted_sigmas)
    results['actual_states'] = np.array(actual_states)
    results['contact_states'] = np.array(contact_states)
    results['actions'] = np.array(actions)

    predicted_states = np.array(predicted_states)
    actual_states = np.array(actual_states)

    # NOTE(review): both arrays already begin at `start_time`, so slicing by
    # `start_time` again drops additional steps when start_time > 0 --
    # confirm whether that is intended.
    rmse_x = np.sqrt(
        np.mean((predicted_states[:, start_time:, 0] -
                 actual_states[:, start_time:, 0])**2))

    rmse_y = np.sqrt(
        np.mean((predicted_states[:, start_time:, 1] -
                 actual_states[:, start_time:, 1])**2))

    print("rsme x: \n{} \n y:\n{}".format(rmse_x, rmse_y))

    return results
Ejemplo n.º 13
0
def rollout(pf_model,
            trajectories,
            start_time=0,
            max_timesteps=300,
            particle_count=100,
            noisy_dynamics=True):
    """Roll a particle filter model along a set of trajectories.

    All trajectories are cut to the window `[start_time, end_time)`, where
    `end_time` is the shortest trajectory length or
    `start_time + max_timesteps`, whichever is smaller. Every particle of a
    trajectory starts at that trajectory's true state at `start_time` with
    uniform log weights; the filter is then stepped with resampling enabled.

    Args:
        pf_model: Model exposing `parameters()` and
            `forward(particles, log_weights, observations, controls, ...)`.
        trajectories: Sequence of (states, observations, controls) tuples.
        start_time: First timestep of the rollout window.
        max_timesteps: Maximum window length.
        particle_count: Number of particles per trajectory.
        noisy_dynamics: Forwarded to `pf_model.forward`.

    Returns:
        Tuple `(predicted_states, actual_states)` of numpy arrays.
    """
    # Cut everything to a common window so the results stack cleanly
    end_time = np.min([len(states) for states, _, _ in trajectories] +
                      [start_time + max_timesteps])
    rollouts = [[states[start_time]] for states, _, _ in trajectories]
    ground_truth = [
        states[start_time:end_time] for states, _, _ in trajectories
    ]

    state_dim = len(ground_truth[0][0])
    N, M = len(trajectories), particle_count

    device = next(pf_model.parameters()).device

    # All particles of a trajectory start at its true initial state
    particles = np.zeros((N, M, state_dim))
    for index in range(N):
        particles[index, :] = rollouts[index][0]
    particles = utils.to_torch(particles, device=device)
    log_weights = torch.ones((N, M), device=device) * (-np.log(M))

    for t in tqdm_notebook(range(start_time + 1, end_time)):
        previous = t - start_time - 1

        # Batch this timestep's inputs across trajectories
        o = {}
        for _, observations, _ in trajectories:
            utils.DictIterator(o).append(utils.DictIterator(observations)[t])
        utils.DictIterator(o).convert_to_numpy()

        s = np.array([history[previous] for history in rollouts])
        c = np.array([controls[t] for _, _, controls in trajectories])
        (s, o, c) = utils.to_torch((s, o, c), device=device)

        state_estimates, particles, log_weights = pf_model.forward(
            particles,
            log_weights,
            o,
            c,
            resample=True,
            noisy_dynamics=noisy_dynamics)

        for index, history in enumerate(rollouts):
            history.append(utils.to_numpy(state_estimates[index]))

    return np.array(rollouts), np.array(ground_truth)
def train_dynamics_recurrent(buddy,
                             kf_model,
                             dataloader,
                             log_interval=10,
                             loss_type="l2",
                             optim_name="dynamics_recurr",
                             checkpoint_interval=10000,
                             init_state_noise=0.5):
    """Train the dynamics model by unrolling it over whole sub-trajectories.

    For each batch the rollout starts from the ground-truth state at t=0 and
    afterwards feeds the model its own previous prediction. The optimized loss
    is the mean over timesteps of the per-step L1 or MSE error between the
    predicted and the ground-truth state.

    Parameters
    ----------
    buddy
        Training helper. This function reads ``_device`` and
        ``optimizer_steps`` and calls ``minimize``, ``log_scope`` and ``log``.
    kf_model
        Model exposing ``dynamics_model(states, controls, noisy=...)``.
    dataloader
        Iterable of ``(states, observations, controls)`` batches. States and
        controls are indexed as ``(N, timesteps, dim)``; observations are
        moved to the device but not used here.
    log_interval : int
        Diagnostics are logged whenever ``buddy.optimizer_steps`` is a
        multiple of this value.
    loss_type : str
        ``"l1"`` or ``"l2"``.
    optim_name : str
        Optimizer name passed to ``buddy.minimize``; also the logging scope.
    checkpoint_interval : int
        Forwarded to ``buddy.minimize``.
    init_state_noise : float
        Currently unused; kept so existing callers keep working.
    """
    assert loss_type in ('l1', 'l2')
    step_loss_fn = F.l1_loss if loss_type == "l1" else F.mse_loss

    for batch in tqdm(dataloader):
        batch_gpu = utils.to_device(batch, buddy._device)
        batch_states, _, batch_controls = batch_gpu

        N, timesteps, control_dim = batch_controls.shape
        N, timesteps, state_dim = batch_states.shape
        # N and timesteps were just re-bound from the states, so this checks
        # that controls and states agree on batch size and sequence length.
        assert batch_controls.shape == (N, timesteps, control_dim)

        # Mean squared ground-truth step per timestep (debugging only)
        label_deltas = np.mean(utils.to_numpy(batch_states[:, 1:, :] -
                                              batch_states[:, :-1, :])**2,
                               axis=(0, 2))
        assert label_deltas.shape == (timesteps - 1, )

        prev_states = batch_states[:, 0, :]
        losses = []
        pred_deltas = []

        for t in range(1, timesteps):
            new_states = kf_model.dynamics_model(
                prev_states,
                batch_controls[:, t, :],
                noisy=False,
            )
            losses.append(step_loss_fn(new_states, batch_states[:, t, :]))

            # Mean squared predicted step, the counterpart of label_deltas
            pred_deltas.append(
                np.mean(utils.to_numpy(new_states - prev_states)**2))

            # Recurrent unroll: the next step consumes this prediction
            prev_states = new_states

        pred_deltas = np.array(pred_deltas)
        assert pred_deltas.shape == (timesteps - 1, )

        # Note: the loss tensor is deliberately not accumulated across
        # batches; holding on to it would keep every batch's autograd nodes
        # alive until the end of the epoch.
        loss = torch.mean(torch.stack(losses))

        buddy.minimize(loss,
                       optimizer_name=optim_name,
                       checkpoint_interval=checkpoint_interval)

        if buddy.optimizer_steps % log_interval == 0:
            with buddy.log_scope(optim_name):
                buddy.log("Training loss", loss)

                buddy.log("Label delta mean", label_deltas.mean())
                buddy.log("Label delta std", label_deltas.std())

                buddy.log("Pred delta mean", pred_deltas.mean())
                buddy.log("Pred delta std", pred_deltas.std())
def rollout(pf_model,
            trajectories,
            start_time=0,
            max_timesteps=300,
            particle_count=100,
            noisy_dynamics=True,
            true_initial=False):
    """Run the particle filter over a set of trajectories.

    All trajectories are cut to a common window ``[start_time, end_time)``,
    where ``end_time`` is the smaller of the shortest trajectory length and
    ``start_time + max_timesteps``.

    Parameters
    ----------
    pf_model
        Particle filter module. ``forward(particles, log_weights,
        observations, controls, resample=..., noisy_dynamics=...)`` must
        return ``(state_estimates, particles, log_weights)``.
    trajectories
        Sequence of ``(states, observations, controls)`` tuples.
    start_time : int
        Index of the first timestep of the window.
    max_timesteps : int
        Maximum window length.
    particle_count : int
        Number of particles per trajectory.
    noisy_dynamics : bool
        Forwarded to ``pf_model.forward``.
    true_initial : bool
        If True, particles are initialized at the ground-truth state at
        ``start_time`` plus N(0, 0.1) noise; otherwise they are drawn from
        N(0, 1) around zero.

    Returns
    -------
    predicted_states, actual_states : np.ndarray
        Per-trajectory state estimates and ground-truth states over the
        window. The first predicted state is the mean of the initial
        particles.
    """
    # To make things easier, we're going to cut all our trajectories to the
    # same length :)
    end_time = np.min([len(s) for s, _, _ in trajectories] +
                      [start_time + max_timesteps])
    actual_states = [
        states[start_time:end_time] for states, _, _ in trajectories
    ]

    state_dim = len(actual_states[0][0])
    N = len(trajectories)
    M = particle_count

    device = next(pf_model.parameters()).device

    particles = np.zeros((N, M, state_dim))
    if true_initial:
        for i in range(N):
            # Seed from the first state of the evaluated window, i.e. the
            # state at start_time (not the state at time 0).
            particles[i, :] = actual_states[i][0]
        particles += np.random.normal(0, 0.1, size=particles.shape)
    else:
        # Distribute initial particles randomly
        particles += np.random.normal(0, 1.0, size=particles.shape)

    # The initial state estimate is simply the mean of the initial particles
    predicted_states = [[np.mean(particles[i], axis=0)] for i in range(N)]

    particles = utils.to_torch(particles, device=device)
    log_weights = torch.ones((N, M), device=device) * (-np.log(M))

    for t in tqdm(range(start_time + 1, end_time)):
        # Gather the observation and control at time t across trajectories
        o = {}
        c = []
        for _, observations, controls in trajectories:
            utils.DictIterator(o).append(utils.DictIterator(observations)[t])
            c.append(controls[t])

        utils.DictIterator(o).convert_to_numpy()
        c = np.array(c)
        o, c = utils.to_torch((o, c), device=device)

        state_estimates, particles, log_weights = pf_model.forward(
            particles,
            log_weights,
            o,
            c,
            resample=True,
            noisy_dynamics=noisy_dynamics)

        for i in range(N):
            predicted_states[i].append(utils.to_numpy(state_estimates[i]))

    predicted_states = np.array(predicted_states)
    actual_states = np.array(actual_states)
    return predicted_states, actual_states
Ejemplo n.º 16
0
def log_basic(
    estimator: Estimator,
    buddy: Buddy,
    viz: VisData,
    filter_length: int = 1,
    smooth: bool = False,
    plot_means: bool = True,
    ramp_pred: bool = False,
) -> None:
    """Log prediction and filtering plots for a Gaussian estimator.

    Two groups of images are logged through ``buddy``. The first filters the
    leading ``filter_length`` steps, optionally smooths, predicts onward from
    the last filtered step and logs under the ``0_predict`` scope. The second
    re-runs the filter on the same leading steps and logs under the
    ``1_filter-xy-trajectory`` and ``1_filter-t-trajectory`` scopes.

    Parameters
    ----------
    estimator : Estimator
        The estimator being visualized.
    buddy : Buddy
        Training helper used for logging scopes and images.
    viz : VisData
        Visualization data (times, observations, controls).
    filter_length : int, default=1
        Number of leading timesteps handed to the filter. Must be at least 1
        and is clipped to ``len(viz.np_t)``.
    smooth : bool, default=False
        If True, use ``estimator.get_smooth()`` instead of the filtered
        estimates as the basis for prediction.
    plot_means : bool, default=True
        If True, log plots of the means in addition to the samples.
    ramp_pred : bool, default=False
        If True and the estimator has a ``_ramp_iters`` attribute, the
        prediction horizon grows with ``buddy.optimizer_steps``.
    """
    assert filter_length >= 1
    filter_length = min(filter_length, len(viz.np_t))
    data_var = 0.0  # variance of independent Gaussian injected noise
    noise_std = np.sqrt(data_var)

    # Optionally ramp the end index of the prediction window
    idx_p = None
    if ramp_pred and hasattr(estimator, "_ramp_iters"):
        ramp_steps = buddy.optimizer_steps // estimator._ramp_iters  # type: ignore
        idx_p = min(ramp_steps + filter_length + 1, len(viz.t))

    def _filter_prefix(hidden0, obs_noise, **estimator_kwargs):
        # Run the estimator over the first `filter_length` timesteps.
        return estimator(
            viz.t[:filter_length],
            viz.y[:filter_length] + obs_noise,
            viz.u[:filter_length],
            hidden0,
            **estimator_kwargs,
        )

    def _log_xy(tag, curves, **plot_kwargs):
        # Draw an xy comparison and log the rendered figure under `tag`.
        with ph.plot_context(viz.plot_settings) as (fig, _ax):
            ph.plot_xy_compare(curves, **plot_kwargs)
            image = ph.plot_as_image(fig)
        buddy.log_image(tag, image)

    def _log_prediction_xy(tag, predicted, filtered):
        # Prediction vs. ground truth; the filtered segment is only drawn
        # when more than one step was filtered.
        if filter_length == 1:
            _log_xy(
                tag,
                [predicted, viz.np_y[:, :, 0:2]],
                style_list=["r-", "b-"],
                startmark_list=["ro", "bo"],
                endmark_list=["rx", "bx"],
            )
            return
        _log_xy(
            tag,
            [
                predicted,
                viz.np_y[(filter_length - 1) : idx_p, :, 0:2],
                filtered,
                viz.np_y[:filter_length, :, 0:2],
            ],
            style_list=["r-", "b-", "r--", "b--"],
            startmark_list=[None, None, "ro", "bo"],
            endmark_list=["rx", "bx", None, None],
        )

    # ---------- #
    # PREDICTION #
    # ---------- #

    # Filter the leading segment: once for hidden states, once for observations
    z0 = estimator.get_initial_hidden_state(viz.y0.shape[0])
    noise = torch.randn_like(viz.y[0:filter_length]) * noise_std
    z_mu_f, z_cov_f = _filter_prefix(z0, noise, return_hidden=True)
    y_mu_f, y_cov_f = _filter_prefix(z0, noise)

    # Smoothed estimates replace the filtered ones when requested
    if smooth:
        z_mu_s, z_cov_s = estimator.get_smooth()  # type: ignore
        y_mu_s, y_cov_s = estimator.latent_to_observation(z_mu_s, z_cov_s)
    else:
        z_mu_s, z_cov_s = z_mu_f, z_cov_f
        y_mu_s, y_cov_s = y_mu_f, y_cov_f

    # Predict onward from the last filtered/smoothed step
    horizon = slice(filter_length - 1, idx_p)
    y_mu_p, y_cov_p = estimator.predict(
        z_mu_s[-1], z_cov_s[-1], viz.t[horizon], viz.u[horizon],
    )

    # Draw samples from the observation distributions, then convert the means
    y_samp_s = to_numpy(reparameterize_gauss(y_mu_s, y_cov_s))
    y_samp_p = to_numpy(reparameterize_gauss(y_mu_p, y_cov_p))
    y_mu_s = to_numpy(y_mu_s)
    y_mu_p = to_numpy(y_mu_p)

    with buddy.log_scope("0_predict"):
        _log_prediction_xy("samples-xy-trajectory", y_samp_p, y_samp_s)
        if plot_means:
            _log_prediction_xy("means-xy-trajectory", y_mu_p, y_mu_s)

        # x and y timeseries, samples first and then means
        for dim in (0, 1):
            _log_timeseries(buddy, viz, y_samp_s, y_samp_p, filter_length, dim, "samples")
        if plot_means:
            for dim in (0, 1):
                _log_timeseries(buddy, viz, y_mu_s, y_mu_p, filter_length, dim, "means")

    # -------------- #
    # FILTERING ONLY #
    # -------------- #

    z0 = estimator.get_initial_hidden_state(viz.y0.shape[0])
    noise = noise_std * torch.randn_like(viz.y)
    z_mu_f, z_cov_f = _filter_prefix(z0, noise[:filter_length], return_hidden=True)

    # Map to observation space and sample
    y_mu_f, y_cov_f = estimator.latent_to_observation(z_mu_f, z_cov_f)
    y_samp_f = to_numpy(reparameterize_gauss(y_mu_f, y_cov_f))
    y_mu_f = to_numpy(y_mu_f)

    # xy plots against clean and noise-perturbed ground truth
    with buddy.log_scope("1_filter-xy-trajectory"):
        _log_xy("samples-no-noise", [y_samp_f, viz.np_y[:, :, 0:2]])
        _log_xy("samples-meas-noise", [y_samp_f, viz.np_y + to_numpy(noise)])
        if plot_means:
            _log_xy("means-no-noise", [y_mu_f, viz.np_y[:, :, 0:2]])
            _log_xy("means-meas-noise", [y_mu_f, viz.np_y + to_numpy(noise)])

    # per-dimension timeseries plots against clean and noisy ground truth
    with buddy.log_scope("1_filter-t-trajectory"):
        viz_noise = replace(viz, np_y=viz.np_y + to_numpy(noise))

        filter_series = [(y_samp_f, "samples-no-noise", "samples-meas-noise")]
        if plot_means:
            filter_series.append((y_mu_f, "means-no-noise", "means-meas-noise"))

        for values, clean_tag, noisy_tag in filter_series:
            for data, tag in ((viz, clean_tag), (viz_noise, noisy_tag)):
                for dim in (0, 1):
                    _log_filter(buddy, data, values, filter_length, dim, tag)
def rollout_and_eval(pf_model,
                     trajectories,
                     start_time=0,
                     max_timesteps=300,
                     particle_count=100,
                     noisy_dynamics=True,
                     true_initial=False):
    """Run the particle filter over trajectories and plot the result.

    All trajectories are cut to a common window ``[start_time, end_time)``,
    where ``end_time`` is the smaller of the shortest trajectory length and
    ``start_time + max_timesteps``. For every state dimension one figure is
    shown with predicted vs. ground-truth curves and a particle scatter every
    20 timesteps, and the RMSE is printed. The function returns nothing.

    Parameters
    ----------
    pf_model
        Particle filter module. ``forward(particles, log_weights,
        observations, controls, resample=..., noisy_dynamics=...)`` must
        return ``(state_estimates, particles, log_weights)``.
    trajectories
        Sequence of ``(states, observations, controls)`` tuples.
    start_time : int
        Index of the first timestep of the window.
    max_timesteps : int
        Maximum window length.
    particle_count : int
        Number of particles per trajectory.
    noisy_dynamics : bool
        Forwarded to ``pf_model.forward``.
    true_initial : bool
        If True, particles start at the ground-truth state at ``start_time``
        with a shared N(0, 0.2) offset per trajectory plus N(0, 0.2) noise per
        particle; otherwise they are drawn from N(0, 1) around zero.
    """
    # To make things easier, we're going to cut all our trajectories to the
    # same length :)
    end_time = np.min([len(s) for s, _, _ in trajectories] +
                      [start_time + max_timesteps])
    actual_states = [
        states[start_time:end_time] for states, _, _ in trajectories
    ]

    state_dim = len(actual_states[0][0])
    N = len(trajectories)
    M = particle_count

    device = next(pf_model.parameters()).device

    particles = np.zeros((N, M, state_dim))
    if true_initial:
        for i in range(N):
            # Seed from the first state of the evaluated window, i.e. the
            # state at start_time (not the state at time 0).
            particles[i, :] = actual_states[i][0]
        # Shared per-trajectory offset, then independent per-particle noise
        particles += np.random.normal(0, 0.2, size=[N, 1, state_dim])
        particles += np.random.normal(0, 0.2, size=particles.shape)
    else:
        # Distribute initial particles randomly
        particles += np.random.normal(0, 1.0, size=particles.shape)

    # The initial state estimate is simply the mean of the initial particles
    # (N, t, state_dim)
    predicted_states = [[np.mean(particles[i], axis=0)] for i in range(N)]

    particles = utils.to_torch(particles, device=device)
    log_weights = torch.ones((N, M), device=device) * (-np.log(M))

    # (N, t, M, state_dim)
    particles_history = [[utils.to_numpy(particles[i])] for i in range(N)]
    # (N, t, M); stored as weights (not log-weights) at every timestep,
    # including the initial one
    weights_history = [[np.exp(utils.to_numpy(log_weights[i]))]
                       for i in range(N)]

    for t in tqdm(range(start_time + 1, end_time)):
        # Gather the observation and control at time t across trajectories
        o = {}
        c = []
        for _, observations, controls in trajectories:
            utils.DictIterator(o).append(utils.DictIterator(observations)[t])
            c.append(controls[t])

        utils.DictIterator(o).convert_to_numpy()
        c = np.array(c)
        o, c = utils.to_torch((o, c), device=device)

        state_estimates, particles, log_weights = pf_model.forward(
            particles,
            log_weights,
            o,
            c,
            resample=True,
            noisy_dynamics=noisy_dynamics)

        for i in range(N):
            predicted_states[i].append(utils.to_numpy(state_estimates[i]))

            particles_history[i].append(utils.to_numpy(particles[i]))
            weights_history[i].append(np.exp(utils.to_numpy(log_weights[i])))

    predicted_states = np.array(predicted_states)
    actual_states = np.array(actual_states)

    ### Eval
    timesteps = len(actual_states[0])
    # The first `burn_in` timesteps are excluded from the RMSE
    burn_in = 10

    def color(i):
        colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k']
        return colors[i % len(colors)]

    for j in range(state_dim):
        plt.figure(figsize=(8, 6))
        # NOTE: weights_history is zipped in so particles can be shaded by
        # weight when debugging; the scatter below uses a fixed alpha.
        for i, (pred, actual, particle_hist, _weights) in enumerate(
                zip(predicted_states, actual_states, particles_history,
                    weights_history)):
            # Only the first trajectory contributes legend entries
            predicted_label_arg = {}
            actual_label_arg = {}
            if i == 0:
                predicted_label_arg['label'] = "Predicted"
                actual_label_arg['label'] = "Ground Truth"
            plt.plot(range(timesteps),
                     pred[:, j],
                     c=color(i),
                     alpha=0.5,
                     **predicted_label_arg)
            plt.plot(range(timesteps),
                     actual[:, j],
                     c=color(i),
                     **actual_label_arg)

            # Scatter the particle cloud every 20 timesteps
            for step in range(0, timesteps, 20):
                particle_ys = particle_hist[step][:, j]
                particle_xs = [step] * len(particle_ys)
                plt.scatter(particle_xs, particle_ys, c=color(i), alpha=0.02)

        rmse = np.sqrt(
            np.mean((predicted_states[:, burn_in:, j] -
                     actual_states[:, burn_in:, j])**2))
        print(rmse)

        plt.title(f"State #{j} // RMSE = {rmse}")
        plt.xlabel("Timesteps")
        plt.ylabel("Value")
        plt.legend()
        plt.show()
Ejemplo n.º 18
0
    def forward(self,
                states_prev,
                log_weights_prev,
                observations,
                controls,
                resample=True,
                noisy_dynamics=True,
                know_image_blackout=False):
        """Fuse the image and force particle filters for one timestep.

        Both sub-filters are run on the same input particles. Their state
        estimates are blended, and their particle sets concatenated, using
        the log-weights ("betas") produced by ``self.weight_model``.

        Parameters
        ----------
        states_prev : torch.Tensor
            Input particles, shape ``(N, M, state_dim)``.
        log_weights_prev : torch.Tensor
            Input particle log-weights, shape ``(N, M)``.
        observations
            Passed to both sub-filters and to the weight model. Must contain
            an ``'image'`` entry when ``know_image_blackout`` is True.
        controls
            Passed to both sub-filters.
        resample : bool
            If True, each sub-filter outputs ``M`` particles and ``M``
            particles are resampled from the ``2 * M`` concatenated ones,
            with uniform output weights. If False, each sub-filter outputs
            ``M // 2`` particles (``M`` must be even) and the concatenated
            weights are normalized instead.
        noisy_dynamics : bool
            NOTE(review): accepted but not forwarded to the sub-filters;
            confirm whether they should receive it.
        know_image_blackout : bool
            If True, batch rows whose image has (near-)zero absolute sum get
            their betas overridden with ``log(1e-9)`` for the image filter
            and ``log(1 - 1e-9)`` for the force filter.

        Returns
        -------
        state_estimates, states, log_weights
            The blended state estimate, the output particles and their
            log-weights.
        """
        N, M, state_dim = states_prev.shape
        assert log_weights_prev.shape == (N, M)

        device = states_prev.device

        # If we aren't resampling, contract our particles within each
        # individual particle filter so the concatenation has M particles
        if resample:
            output_particles = M
        else:
            assert M % 2 == 0
            output_particles = M // 2

        # Propagate particles through each particle filter
        image_state_estimates, image_states_pred, image_log_weights_pred = self.image_model(
            states_prev,
            log_weights_prev,
            observations,
            controls,
            output_particles=output_particles,
            resample=False)
        force_state_estimates, force_states_pred, force_log_weights_pred = self.force_model(
            states_prev,
            log_weights_prev,
            observations,
            controls,
            output_particles=output_particles,
            resample=False)

        # Get weights
        image_log_beta, force_log_beta = self.weight_model(observations)
        assert image_log_beta.shape == (N, 1)
        assert force_log_beta.shape == (N, 1)

        # Keep the learned betas (before any blackout override) for inspection
        self._betas = utils.to_numpy([image_log_beta, force_log_beta])

        # Ignore image if blacked out
        if know_image_blackout:
            blackout_rows = (torch.sum(
                torch.abs(observations['image'].reshape((N, -1))),
                dim=1) < 1e-8).unsqueeze(1)
            assert blackout_rows.shape == (N, 1)

            # Out-of-place selection keeps autograd intact (in-place masking
            # breaks it) and, unlike multiplying by a 0/1 mask, cannot turn a
            # non-finite learned beta into NaN on blacked-out rows.
            image_log_beta = torch.where(
                blackout_rows,
                torch.full_like(image_log_beta, float(np.log(1e-9))),
                image_log_beta)
            force_log_beta = torch.where(
                blackout_rows,
                torch.full_like(force_log_beta, float(np.log(1. - 1e-9))),
                force_log_beta)

        # Weight state estimates from each filter
        # NOTE(review): this is computed before the freeze detaches below, so
        # the freeze flags do not stop gradients flowing through
        # state_estimates -- confirm this is intended.
        state_estimates = torch.exp(image_log_beta) * image_state_estimates \
            + torch.exp(force_log_beta) * force_state_estimates

        # Model freezing
        if self.freeze_image_model:
            image_state_estimates = image_state_estimates.detach()
            image_states_pred = image_states_pred.detach()
            image_log_weights_pred = image_log_weights_pred.detach()

        if self.freeze_force_model:
            force_state_estimates = force_state_estimates.detach()
            force_states_pred = force_states_pred.detach()
            force_log_weights_pred = force_log_weights_pred.detach()

        if self.freeze_weight_model:
            image_log_beta = image_log_beta.detach()
            force_log_beta = force_log_beta.detach()

        # Concatenate particles from each filter; each filter's particle
        # weights are scaled by that filter's beta
        states_pred = torch.cat([
            image_states_pred,
            force_states_pred,
        ],
                                dim=1)
        log_weights_pred = torch.cat([
            image_log_weights_pred + image_log_beta,
            force_log_weights_pred + force_log_beta,
        ],
                                     dim=1)

        if resample:
            assert log_weights_pred.shape == (N, 2 * M)
            assert states_pred.shape == (N, 2 * M, state_dim)

            # Resample particles
            distribution = torch.distributions.Categorical(
                logits=log_weights_pred)
            state_indices = distribution.sample((M, )).T
            assert state_indices.shape == (N, M)

            # Pick the sampled particles for every batch row in one gather:
            # states[i, j, :] = states_pred[i, state_indices[i, j], :]
            gather_index = state_indices.unsqueeze(-1).expand(
                -1, -1, state_dim)
            states = torch.gather(states_pred, 1, gather_index)
            assert states.shape == (N, M, state_dim)

            # Uniform weights
            log_weights = torch.zeros((N, M), device=device) - np.log(M)
        else:
            states = states_pred

            # Normalize predicted weights
            log_weights = log_weights_pred - \
                torch.logsumexp(log_weights_pred, dim=1)[:, np.newaxis]

        return state_estimates, states, log_weights