def train_measurement(buddy, ekf_model, dataloader, log_interval=10):
    losses = []

    # Train measurement model only for 1 epoch
    for batch_idx, batch in enumerate(tqdm_notebook(dataloader)):
        # Transfer to GPU and pull out batch data
        batch_gpu = utils.to_device(batch, buddy._device)
        noisy_states, observations, log_likelihoods = batch_gpu

        z, R = ekf_model.measurement_model(observations, noisy_states)
        assert len(z.shape) == 2

        # Assumption: score each noisy state under the predicted Gaussian (z, R)
        # with the same Gaussian log-likelihood helper the EKF trainers below use;
        # the snippet does not show how `pred_likelihoods` was originally computed
        pred_likelihoods = utility.gaussian_log_likelihood(z, noisy_states, R)

        loss = torch.mean((pred_likelihoods - log_likelihoods)**2)
        losses.append(utils.to_numpy(loss))

        buddy.minimize(loss,
                       optimizer_name="measurement",
                       checkpoint_interval=10000)

        if buddy.optimizer_steps % log_interval == 0:
            with buddy.log_scope("measurement"):
                buddy.log("Training loss", loss)

                buddy.log("Pred likelihoods mean", pred_likelihoods.mean())
                buddy.log("Pred likelihoods std", pred_likelihoods.std())

                buddy.log("Label likelihoods mean", log_likelihoods.mean())
                buddy.log("Label likelihoods std", log_likelihoods.std())

    print("Epoch loss:", np.mean(losses))
def train_dynamics(buddy,
                   pf_model,
                   dataloader,
                   log_interval=10,
                   optim_name="dynamics"):
    losses = []

    # Train dynamics model only, for 1 epoch
    for batch_idx, batch in enumerate(tqdm(dataloader)):
        # Transfer to GPU and pull out batch data
        batch_gpu = utils.to_device(batch, buddy._device)
        prev_states, _unused_observations, controls, new_states = batch_gpu

        prev_states += utils.to_torch(np.random.normal(0,
                                                       0.05,
                                                       size=prev_states.shape),
                                      device=buddy._device)
        prev_states = prev_states[:, np.newaxis, :]
        new_states_pred = pf_model.dynamics_model(prev_states,
                                                  controls,
                                                  noisy=False)
        new_states_pred = new_states_pred.squeeze(dim=1)

        mse_pos = F.mse_loss(new_states_pred, new_states)
        # mse_pos = torch.mean((new_states_pred - new_states) ** 2, axis=0)
        loss = mse_pos
        losses.append(utils.to_numpy(loss))

        buddy.minimize(loss,
                       optimizer_name=optim_name,
                       checkpoint_interval=1000)

        if buddy.optimizer_steps % log_interval == 0:
            with buddy.log_scope(optim_name):
                # buddy.log("Training loss", loss)
                buddy.log("MSE position", mse_pos)

                label_std = new_states.std(dim=0)
                buddy.log("Label pos std", label_std[0])

                pred_std = new_states_pred.std(dim=0)
                buddy.log("Predicted pos std", pred_std[0])

                label_mean = new_states.mean(dim=0)
                buddy.log("Label pos mean", label_mean[0])

                pred_mean = new_states_pred.mean(dim=0)
                buddy.log("Predicted pos mean", pred_mean[0])

            # print(".", end="")
    print("Epoch loss:", np.mean(losses))
def train(buddy, model, dataloader, log_interval=10, state_noise_std=0.2):
    losses = []

    # Train for 1 epoch
    for batch_idx, batch in enumerate(tqdm(dataloader)):
        # Transfer to GPU and pull out batch data
        batch_gpu = utils.to_device(batch, buddy._device)
        prev_states, observations, controls, new_states = batch_gpu
        prev_states += utils.to_torch(np.random.normal(0,
                                                       state_noise_std,
                                                       size=prev_states.shape),
                                      device=buddy._device)

        new_states_pred = model(prev_states, observations, controls)

        # mse_pos, mse_vel = torch.mean((new_states_pred - new_states) ** 2, axis=0)
        # loss = (mse_pos + mse_vel) / 2
        loss = torch.mean((new_states_pred - new_states)**2)
        losses.append(utils.to_numpy(loss))

        buddy.minimize(loss, checkpoint_interval=10000)

        if buddy.optimizer_steps % log_interval == 0:
            with buddy.log_scope("baseline_training"):
                buddy.log("Training loss", loss)
                # buddy.log("MSE position", mse_pos)
                # buddy.log("MSE velocity", mse_vel)

                label_std = new_states.std(dim=0)
                buddy.log("Training pos std", label_std[0])
                # buddy.log("Training vel std", label_std[1])

                pred_std = new_states_pred.std(dim=0)
                buddy.log("Predicted pos std", pred_std[0])
                # buddy.log("Predicted vel std", pred_std[1])

                label_mean = new_states.mean(dim=0)
                buddy.log("Training pos mean", label_mean[0])
                # buddy.log("Training vel mean", label_mean[1])

                pred_mean = new_states_pred.mean(dim=0)
                buddy.log("Predicted pos mean", pred_mean[0])
                # buddy.log("Predicted vel mean", pred_mean[1])

    print("Epoch loss:", np.mean(losses))
def train_e2e(buddy,
              pf_model,
              dataloader,
              log_interval=2,
              loss_type="mse",
              optim_name="e2e",
              resample=False,
              know_image_blackout=False):
    # Train for 1 epoch
    epoch_losses = []
    for batch_idx, batch in enumerate(tqdm(dataloader)):
        # Transfer to GPU and pull out batch data
        batch_gpu = utils.to_device(batch, buddy._device)
        batch_particles, batch_states, batch_obs, batch_controls = batch_gpu

        # N = batch size, M = particle count
        N, timesteps, control_dim = batch_controls.shape
        N, timesteps, state_dim = batch_states.shape
        N, M, state_dim = batch_particles.shape
        assert batch_controls.shape == (N, timesteps, control_dim)

        # Give all particles equal weights
        particles = batch_particles
        log_weights = torch.ones((N, M), device=buddy._device) * (-np.log(M))

        # Accumulate losses from each timestep
        losses = []
        for t in range(1, timesteps):
            prev_particles = particles
            prev_log_weights = log_weights

            if know_image_blackout:
                state_estimates, new_particles, new_log_weights = pf_model.forward(
                    prev_particles,
                    prev_log_weights,
                    utils.DictIterator(batch_obs)[:, t - 1, :],
                    batch_controls[:, t, :],
                    resample=resample,
                    noisy_dynamics=True,
                    know_image_blackout=True)
            else:
                state_estimates, new_particles, new_log_weights = pf_model.forward(
                    prev_particles,
                    prev_log_weights,
                    utils.DictIterator(batch_obs)[:, t - 1, :],
                    batch_controls[:, t, :],
                    resample=resample,
                    noisy_dynamics=True,
                )

            if loss_type == "gmm":
                loss = dpf.gmm_loss(particles_states=new_particles,
                                    log_weights=new_log_weights,
                                    true_states=batch_states[:, t, :],
                                    gmm_variances=np.array([0.1]))
            elif loss_type == "mse":
                loss = torch.mean((state_estimates - batch_states[:, t, :])**2)
            else:
                assert False, "Invalid loss"

            losses.append(loss)

            # Enable backprop through time
            particles = new_particles
            log_weights = new_log_weights

            # # Disable backprop through time
            # particles = new_particles.detach()
            # log_weights = new_log_weights.detach()

            # assert state_estimates.shape == batch_states[:, t, :].shape

        batch_loss = torch.mean(torch.stack(losses))
        epoch_losses.append(utils.to_numpy(batch_loss))

        buddy.minimize(batch_loss,
                       optimizer_name=optim_name,
                       checkpoint_interval=1000)

        if buddy.optimizer_steps % log_interval == 0:
            with buddy.log_scope(optim_name):
                buddy.log("Training loss", batch_loss)
                buddy.log("Log weights mean", log_weights.mean())
                buddy.log("Log weights std", log_weights.std())
                buddy.log("Particle states mean", particles.mean())
                buddy.log("Particle states std", particles.std())

    print("Epoch loss:", np.mean(epoch_losses))
def train_dynamics_recurrent(buddy,
                             pf_model,
                             dataloader,
                             log_interval=10,
                             loss_type="l1",
                             optim_name="dynamics_recurrent"):

    assert loss_type in ('l1', 'l2', 'huber', 'peter')

    # Train dynamics model only, for 1 epoch
    epoch_losses = []
    for batch_idx, batch in enumerate(tqdm(dataloader)):
        # Transfer to GPU and pull out batch data
        batch_gpu = utils.to_device(batch, buddy._device)
        batch_states, batch_obs, batch_controls = batch_gpu
        # N = batch size, M = particle count
        N, timesteps, control_dim = batch_controls.shape
        N, timesteps, state_dim = batch_states.shape
        assert batch_controls.shape == (N, timesteps, control_dim)

        # Track current states as they're propagated through our dynamics model
        prev_states = batch_states[:, 0, :]
        assert prev_states.shape == (N, state_dim)

        # Accumulate losses from each timestep
        losses = []
        magnitude_losses = []
        direction_losses = []

        # Compute some state deltas for debugging
        label_deltas = np.mean(utils.to_numpy(batch_states[:, 1:, :] -
                                              batch_states[:, :-1, :])**2,
                               axis=(0, 2))
        assert label_deltas.shape == (timesteps - 1, )
        pred_deltas = []

        for t in range(1, timesteps):
            # Propagate current states through dynamics model
            controls = batch_controls[:, t, :]
            new_states = pf_model.dynamics_model(
                prev_states[:, np.newaxis, :],  # Add particle dimension
                controls,
                noisy=False,
            ).squeeze(dim=1)  # Remove particle dimension
            assert new_states.shape == (N, state_dim)

            # Compute deltas
            pred_delta = prev_states - new_states
            label_delta = batch_states[:, t - 1, :] - batch_states[:, t, :]
            assert pred_delta.shape == (N, state_dim)
            assert label_delta.shape == (N, state_dim)

            # Compute and add loss
            if loss_type == 'l1':
                # timestep_loss = F.l1_loss(pred_delta, label_delta)
                timestep_loss = F.l1_loss(new_states, batch_states[:, t, :])
            elif loss_type == 'l2':
                # timestep_loss = F.mse_loss(pred_delta, label_delta)
                timestep_loss = F.mse_loss(new_states, batch_states[:, t, :])
            elif loss_type == 'huber':
                # Note that the units our states are in will affect results
                # for Huber
                timestep_loss = F.smooth_l1_loss(batch_states[:, t, :],
                                                 new_states)
            elif loss_type == 'peter':
                # Magnitude/direction ("peter") loss; currently broken, since
                # direction_loss below is never computed
                assert False, "'peter' loss is not implemented"

                pred_magnitude = torch.norm(pred_delta, dim=1)
                label_magnitude = torch.norm(label_delta, dim=1)
                assert pred_magnitude.shape == (N, )
                assert label_magnitude.shape == (N, )

                # pred_direction = pred_delta / (pred_magnitude + 1e-8)
                # label_direction = label_delta / (label_magnitude + 1e-8)
                # assert pred_direction.shape == (N, state_dim)
                # assert label_direction.shape == (N, state_dim)

                # Compute loss
                magnitude_loss = F.mse_loss(pred_magnitude, label_magnitude)
                # direction_loss =
                timestep_loss = magnitude_loss + direction_loss

                magnitude_losses.append(magnitude_loss)
                direction_losses.append(direction_loss)

            else:
                assert False
            losses.append(timestep_loss)

            # Compute delta and update states
            pred_deltas.append(
                np.mean(utils.to_numpy(new_states - prev_states)**2))
            prev_states = new_states

        pred_deltas = np.array(pred_deltas)
        assert pred_deltas.shape == (timesteps - 1, )

        loss = torch.mean(torch.stack(losses))
        epoch_losses.append(utils.to_numpy(loss))
        buddy.minimize(loss,
                       optimizer_name=optim_name,
                       checkpoint_interval=1000)

        if buddy.optimizer_steps % log_interval == 0:
            with buddy.log_scope(optim_name):
                buddy.log("Training loss", loss)

                buddy.log("Label delta mean", label_deltas.mean())
                buddy.log("Label delta std", label_deltas.std())

                buddy.log("Pred delta mean", pred_deltas.mean())
                buddy.log("Pred delta std", pred_deltas.std())

                if magnitude_losses:
                    buddy.log("Magnitude loss",
                              torch.mean(torch.tensor(magnitude_losses)))
                if direction_losses:
                    buddy.log("Direction loss",
                              torch.mean(torch.tensor(direction_losses)))

    print("Epoch loss:", np.mean(utils.to_numpy(epoch_losses)))
def train_dynamics_recurrent(buddy,
                             kf_model,
                             dataloader,
                             log_interval=10,
                             loss_type="l2",
                             optim_name="dynamics_recurr",
                             checkpoint_interval=10000,
                             init_state_noise=0.5):
    epoch_losses = []

    assert loss_type in ('l1', 'l2')

    for batch_idx, batch in enumerate(tqdm(dataloader)):
        batch_gpu = utils.to_device(batch, buddy._device)
        batch_states, batch_obs, batch_controls = batch_gpu

        N, timesteps, control_dim = batch_controls.shape
        N, timesteps, state_dim = batch_states.shape
        assert batch_controls.shape == (N, timesteps, control_dim)

        prev_states = batch_states[:, 0, :]

        losses = []
        magnitude_losses = []
        direction_losses = []

        # Compute some state deltas for debugging
        label_deltas = np.mean(utils.to_numpy(batch_states[:, 1:, :] -
                                              batch_states[:, :-1, :])**2,
                               axis=(0, 2))
        assert label_deltas.shape == (timesteps - 1, )
        pred_deltas = []

        for t in range(1, timesteps):
            controls = batch_controls[:, t, :]
            new_states = kf_model.dynamics_model(
                prev_states,
                controls,
                noisy=False,
            )

            pred_delta = prev_states - new_states
            label_delta = batch_states[:, t - 1, :] - batch_states[:, t, :]

            # todo: maybe switch back to l2
            if loss_type == "l1":
                timestep_loss = F.l1_loss(new_states, batch_states[:, t, :])
            else:
                timestep_loss = F.mse_loss(new_states, batch_states[:, t, :])

            losses.append(timestep_loss)

            pred_deltas.append(
                np.mean(utils.to_numpy(new_states - prev_states)**2))
            prev_states = new_states

        pred_deltas = np.array(pred_deltas)
        assert pred_deltas.shape == (timesteps - 1, )

        loss = torch.mean(torch.stack(losses))
        epoch_losses.append(utils.to_numpy(loss))

        buddy.minimize(loss,
                       optimizer_name=optim_name,
                       checkpoint_interval=checkpoint_interval)

        if buddy.optimizer_steps % log_interval == 0:
            with buddy.log_scope(optim_name):
                buddy.log("Training loss", loss)

                buddy.log("Label delta mean", label_deltas.mean())
                buddy.log("Label delta std", label_deltas.std())

                buddy.log("Pred delta mean", pred_deltas.mean())
                buddy.log("Pred delta std", pred_deltas.std())

                if magnitude_losses:
                    buddy.log("Magnitude loss",
                              torch.mean(torch.tensor(magnitude_losses)))
                if direction_losses:
                    buddy.log("Direction loss",
                              torch.mean(torch.tensor(direction_losses)))
def train_fusion(buddy,
                 fusion_model,
                 dataloader,
                 log_interval=2,
                 optim_name="fusion",
                 measurement_init=True,
                 init_state_noise=0.2,
                 one_loss=True,
                 nll=False):
    # todo: change loss to selection/mixed
    for batch_idx, batch in enumerate(dataloader):
        # Transfer to GPU and pull out batch data
        batch_gpu = utils.to_device(batch, buddy._device)
        _, batch_states, batch_obs, batch_controls = batch_gpu
        # N = batch size
        N, timesteps, control_dim = batch_controls.shape
        N, timesteps, state_dim = batch_states.shape
        assert batch_controls.shape == (N, timesteps, control_dim)

        state = batch_states[:, 0, :]
        state_sigma = torch.eye(state.shape[-1],
                                device=buddy._device) * init_state_noise**2
        state_sigma = state_sigma.unsqueeze(0).repeat(N, 1, 1)

        if measurement_init:
            state, state_sigma = fusion_model.measurement_only(
                utils.DictIterator(batch_obs)[:, 0, :], state)

        else:
            dist = torch.distributions.Normal(
                torch.tensor([0.]),
                torch.ones(state.shape) * init_state_noise)
            noise = dist.sample().to(state.device)
            state += noise

        losses_image = []
        losses_force = []
        losses_fused = []
        losses_nll = []
        losses_total = []

        for t in range(1, timesteps - 1):
            prev_state = state
            prev_state_sigma = state_sigma

            state, state_sigma, force_state, image_state = fusion_model.forward(
                prev_state,
                prev_state_sigma,
                utils.DictIterator(batch_obs)[:, t, :],
                batch_controls[:, t, :],
            )

            loss_image = torch.mean((image_state - batch_states[:, t, :])**2)
            loss_force = torch.mean((force_state - batch_states[:, t, :])**2)
            loss_fused = torch.mean((state - batch_states[:, t, :])**2)

            losses_force.append(loss_force.item())
            losses_image.append(loss_image.item())
            losses_fused.append(loss_fused.item())

            if nll:
                loss_nll = torch.mean(-1.0 * utility.gaussian_log_likelihood(
                    state, batch_states[:, t, :], state_sigma))
                losses_nll.append(loss_nll)
                losses_total.append(loss_nll)

            elif one_loss:
                losses_total.append(loss_fused)
            else:
                losses_total.append(loss_image + loss_force + loss_fused)

        loss = torch.mean(torch.stack(losses_total))

        buddy.minimize(loss,
                       optimizer_name=optim_name,
                       checkpoint_interval=10000)

        if buddy.optimizer_steps % log_interval == 0:
            with buddy.log_scope("fusion"):
                buddy.log("Training loss", loss.item())
                buddy.log("Image loss", np.mean(np.array(losses_image)))
                buddy.log("Force loss", np.mean(np.array(losses_force)))
                buddy.log("Fused loss", np.mean(np.array(losses_fused)))
                buddy.log_model_grad_norm()
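# `fusion_model.forward` above returns per-modality state estimates plus a fused
# one. For background only: one standard way to fuse two independent Gaussian
# state estimates is precision weighting, sketched here; this is not necessarily
# what the fusion model itself does.
import torch

def fuse_gaussians(mean_a, cov_a, mean_b, cov_b):
    # mean_*: (N, d); cov_*: (N, d, d)
    prec_a = torch.inverse(cov_a)
    prec_b = torch.inverse(cov_b)
    fused_cov = torch.inverse(prec_a + prec_b)
    fused_mean = (fused_cov @ (prec_a @ mean_a.unsqueeze(-1) +
                               prec_b @ mean_b.unsqueeze(-1))).squeeze(-1)
    return fused_mean, fused_cov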
def train_e2e(buddy,
              ekf_model,
              dataloader,
              log_interval=2,
              optim_name="ekf",
              measurement_init=True,
              checkpoint_interval=1000,
              init_state_noise=0.2,
              loss_type="mse"):
    # Train for 1 epoch
    for batch_idx, batch in enumerate(dataloader):
        # Transfer to GPU and pull out batch data
        batch_gpu = utils.to_device(batch, buddy._device)
        _, batch_states, batch_obs, batch_controls = batch_gpu
        # N = batch size
        N, timesteps, control_dim = batch_controls.shape
        N, timesteps, state_dim = batch_states.shape
        assert batch_controls.shape == (N, timesteps, control_dim)

        state = batch_states[:, 0, :]
        state_sigma = torch.eye(state.shape[-1],
                                device=buddy._device) * init_state_noise**2
        state_sigma = state_sigma.unsqueeze(0).repeat(N, 1, 1)

        if measurement_init:
            state, state_sigma = ekf_model.measurement_model.forward(
                utils.DictIterator(batch_obs)[:, 0, :], batch_states[:, 0, :])
        else:
            dist = torch.distributions.Normal(
                torch.tensor([0.]),
                torch.ones(state.shape) * init_state_noise)
            noise = dist.sample().to(state.device)
            state += noise

        # Accumulate losses from each timestep
        losses = []
        for t in range(1, timesteps - 1):
            prev_state = state
            prev_state_sigma = state_sigma

            state, state_sigma = ekf_model.forward(
                prev_state,
                prev_state_sigma,
                utils.DictIterator(batch_obs)[:, t, :],
                batch_controls[:, t, :],
            )

            assert state.shape == batch_states[:, t, :].shape

            mse = torch.mean((state - batch_states[:, t, :])**2)

            assert loss_type in ['nll', 'mse', 'mixed']
            if loss_type == 'nll':
                nll = -1.0 * utility.gaussian_log_likelihood(
                    state, batch_states[:, t, :], state_sigma)
                nll = torch.mean(nll)
                loss = nll
            elif loss_type == 'mse':
                loss = mse
            else:
                nll = -1.0 * utility.gaussian_log_likelihood(
                    state, batch_states[:, t, :], state_sigma)
                nll = torch.mean(nll)
                loss = nll + mse

            losses.append(loss)

        loss = torch.mean(torch.stack(losses))
        buddy.minimize(loss,
                       optimizer_name=optim_name,
                       checkpoint_interval=checkpoint_interval)

        if buddy.optimizer_steps % log_interval == 0:
            with buddy.log_scope(optim_name):
                buddy.log("Training loss", loss.item())
                buddy.log_model_grad_norm()
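# Both end-to-end trainers backpropagate through the full unrolled filter (the
# belief from step t feeds step t+1). A toy sketch of the alternative hinted at by
# the commented-out detach lines in the particle-filter trainer -- truncating
# backprop-through-time by detaching the carried state each step; `model` here is
# a placeholder callable, not part of the codebase:
import torch

def unrolled_loss(model, state, controls, labels, truncate_bptt=False):
    # model: callable (state, control) -> new state; controls, labels: (T, N, d)
    losses = []
    for t in range(controls.shape[0]):
        state = model(state, controls[t])
        losses.append(torch.mean((state - labels[t]) ** 2))
        if truncate_bptt:
            state = state.detach()  # stop gradients from flowing across timesteps
    return torch.mean(torch.stack(losses))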