class PendulumDynamics:
    def __init__(self, latent_space, backward=False):
        assert latent_space.env.unwrapped.spec.id == "InvertedPendulum-v2"
        self.latent_space = latent_space
        self.backward = backward
        self.dynamics = ExactDynamicsMujoco(
            self.latent_space.env.unwrapped.spec.id,
            tolerance=1e-2,
            max_iters=100)
        self.low = np.array([-1, -np.pi, -10, -10])
        self.high = np.array([1, np.pi, 10, 10])

    def step(self, state, action, sample=True):
        obs = state  # self.latent_space.decoder(state)
        obs = np.clip(obs, self.low, self.high)
        if self.backward:
            obs = self.dynamics.inverse_dynamics(obs, action)
        else:
            obs = self.dynamics.dynamics(obs, action)
        state = obs  # self.latent_space.encoder(obs)
        return state

    def learn(self, *args, return_initial_loss=False, **kwargs):
        print("Using exact dynamics...")
        if return_initial_loss:
            return 0, 0
        return 0
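
A minimal usage sketch for this wrapper, assuming a MuJoCo-enabled gym installation and the ExactDynamicsMujoco helper used above. The IdentityLatentSpace stand-in below is hypothetical; it only provides the .env attribute (and trivial encoder/decoder) that PendulumDynamics expects.

import gym


class IdentityLatentSpace:
    """Hypothetical stand-in for the latent space object assumed above."""

    def __init__(self, env):
        self.env = env

    def encoder(self, obs):
        return obs

    def decoder(self, state):
        return state


env = gym.make("InvertedPendulum-v2")
latent_space = IdentityLatentSpace(env)

forward_model = PendulumDynamics(latent_space, backward=False)
backward_model = PendulumDynamics(latent_space, backward=True)

obs = env.reset()
action = env.action_space.sample()
next_obs = forward_model.step(obs, action)        # one step forward in time
prev_obs = backward_model.step(next_obs, action)  # approximate inverse step
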
Example 3
    def add_play_data(self, env, play_data):
        dynamics = ExactDynamicsMujoco(env.unwrapped.spec.id)
        observations, actions = play_data["observations"], play_data["actions"]
        n_traj = len(observations)
        assert len(actions) == n_traj
        for i in range(n_traj):
            l_traj = len(observations[i])
            for t in range(l_traj):
                obs = observations[i][t]
                action = actions[i][t]
                next_obs = dynamics.dynamics(obs, action)
                self.append(obs, action, next_obs)
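
The play_data layout this method expects can be read off the indexing above: parallel lists of trajectories under the keys "observations" and "actions". A minimal illustrative payload (placeholder values, shapes matching InvertedPendulum-v2) might look like this:

import numpy as np

# Two trajectories of lengths 3 and 2; observations are 4-dimensional and
# actions 1-dimensional, as in InvertedPendulum-v2. Values are placeholders.
play_data = {
    "observations": [
        [np.zeros(4), np.zeros(4), np.zeros(4)],
        [np.zeros(4), np.zeros(4)],
    ],
    "actions": [
        [np.zeros(1), np.zeros(1), np.zeros(1)],
        [np.zeros(1), np.zeros(1)],
    ],
}

# replay_buffer is presumably an instance of the experience replay class that
# defines add_play_data / append above:
# replay_buffer.add_play_data(env, play_data)
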
Example 4
class InversePolicyMDN:
    """
    InversePolicyMDN
    """
    def __init__(
        self,
        env,
        solver,
        experience_replay,
        tensorboard_log=None,
        learning_rate=3e-4,
        n_layers=10,
        layer_size=256,
        n_out=1,
        gauss_stdev=None,
    ):
        self.env = env
        self.solver = solver
        self.experience_replay = experience_replay
        self.dynamics = ExactDynamicsMujoco(env.unwrapped.spec.id,
                                            tolerance=1e-3,
                                            max_iters=100)

        assert len(self.env.action_space.shape) == 1
        self.action_dim = self.env.action_space.shape[0]
        self.observation_shape = list(self.env.observation_space.shape)
        self.n_out = n_out
        self.gauss_stdev = gauss_stdev
        self.layer_size = layer_size
        self.n_layers = n_layers

        self._define_input_placeholders()
        self._define_model()

        self.loss = self._define_loss()
        self.learning_rate, self.global_step = get_learning_rate(
            learning_rate, None, 1)

        self.optimizer = tf.train.AdamOptimizer(self.learning_rate)
        self.gradients = self.optimizer.compute_gradients(loss=self.loss)
        self.optimization_op = self.optimizer.apply_gradients(
            self.gradients, global_step=self.global_step)

        self.tensorboard_log = tensorboard_log
        if self.tensorboard_log is not None:
            self._define_tensorboard_metrics()

        self.sess = None

    def step(self, in_state, sample=True):
        batch_in_states = np.expand_dims(in_state, 0)

        if sample:
            (out, ) = self.sess.run([self.out_sample],
                                    feed_dict={self.in_state: batch_in_states})
            out = out[0]
        else:
            factors, means = self.sess.run(
                [self.mixture_factors, [c.loc for c in self.components]],
                feed_dict={self.in_state: batch_in_states},
            )
            out = means[np.argmax(factors[0])][0]
        out = np.clip(out, self.env.action_space.low,
                      self.env.action_space.high)
        return out

    def _define_input_placeholders(self):
        self.true_action = tf.placeholder(tf.float32, [None, self.action_dim],
                                          name="action")
        self.in_state = tf.placeholder(tf.float32,
                                       [None] + self.observation_shape,
                                       name="in_state")

    def _define_model(self):
        activation = tf.nn.relu

        x = tf.reshape(self.in_state, [-1, np.prod(self.observation_shape)])
        x_1 = x
        for i in range(self.n_layers):
            x = tf.keras.layers.Dense(self.layer_size,
                                      activation=activation,
                                      name="hidden_{}".format(i + 1))(x)

        means = x
        means = tf.keras.layers.Dense(self.n_out * self.action_dim,
                                      activation=None,
                                      name="output")(means)
        # note: this requires a 1-D action space
        self.mixture_means = tf.split(means, [self.action_dim] * self.n_out,
                                      -1)

        if self.gauss_stdev is None:
            stddevs = x
            stddevs = tf.keras.layers.Dense(self.n_out,
                                            activation=None)(stddevs)
            stddevs = tf.exp(stddevs)
            stddevs = tf.clip_by_value(stddevs,
                                       clip_value_min=1e-10,
                                       clip_value_max=1e10)
            # same stddev for all components
            # self.mixture_stddevs = tf.reshape(
            #     stddevs, (-1, self.n_out, self.state_size)
            # )
            self.mixture_stddevs = tf.split(stddevs, [1] * self.n_out, -1)
        else:
            self.mixture_stddevs = None

        factors = tf.concat([x, x_1], -1)
        # factors = tf.keras.layers.Dense(
        #     self.layer_size // 2, activation=activation, name="factors_hidden"
        # )(factors)
        factors = tf.keras.layers.Dense(self.n_out,
                                        activation=None,
                                        name="factors_out")(factors)
        # if we don't clip these, we get NaNs in the loss computation
        factors = tf.clip_by_value(factors,
                                   clip_value_min=0.1,
                                   clip_value_max=10)
        self.mixture_factors = tf.nn.softmax(factors)

        if self.gauss_stdev is None:
            self.components = [
                tfd.MultivariateNormalDiag(
                    loc=mean,
                    scale_diag=tf.tile(stddev, [1, self.action_dim]),
                    validate_args=True,
                    allow_nan_stats=False,
                ) for mean, stddev in zip(self.mixture_means,
                                          self.mixture_stddevs)
            ]
        else:
            self.components = [
                tfd.MultivariateNormalDiag(
                    loc=mean,
                    scale_diag=tf.fill(tf.shape(mean), self.gauss_stdev),
                    validate_args=True,
                    allow_nan_stats=False,
                ) for mean in self.mixture_means
            ]

        self.mixture_distribution = tfd.Categorical(probs=self.mixture_factors,
                                                    allow_nan_stats=False)
        self.out_distribution = tfd.Mixture(
            cat=self.mixture_distribution,
            components=self.components,
            validate_args=True,
            allow_nan_stats=False,
        )
        self.out_sample = self.out_distribution.sample()

    def _define_loss(self):
        neg_logprob = -tf.reduce_mean(
            self.out_distribution.log_prob(self.true_action))
        self.logprobs = [
            tf.reduce_mean(dist.log_prob(self.true_action))
            for dist in self.components
        ]
        self.mixture_entropy = tf.reduce_mean(
            self.mixture_distribution.entropy())
        return neg_logprob

    def _define_tensorboard_metrics(self):
        tf.summary.scalar("loss/loss", self.loss)
        tf.summary.scalar("loss/mixture_entropy", self.mixture_entropy)
        tf.summary.scalar("learning_rate", self.learning_rate)
        for i, lp in enumerate(self.logprobs):
            tf.summary.scalar("loss/logprob_{}".format(i + 1), lp)
        tensorboard_log_gradients(self.gradients)

    def _apply_policy(self, observations):
        next_states, actions = [], []
        for obs in observations:
            action = self.solver.predict(obs)[0]
            try:
                obs = self.dynamics.dynamics(obs, action)
                next_states.append(np.copy(obs))
                actions.append(np.copy(action))
            except Exception as e:
                print("_apply_policy", e)
        return next_states, actions

    def learn(
        self,
        n_epochs=1,
        batch_size=16,
        return_initial_loss=False,
        verbose=True,
        reinitialize=False,
    ):
        """
        Main training loop
        """
        if self.sess is None:
            self.sess = get_tf_session()
            reinitialize = True
        if reinitialize:
            self.sess.run(tf.global_variables_initializer())

        if self.tensorboard_log is not None:
            summaries_op = tf.summary.merge_all()
            summary_writer = tf.summary.FileWriter(self.tensorboard_log,
                                                   self.sess.graph)
        else:
            summaries_op = tf.no_op()

        n_batches = len(self.experience_replay) // batch_size

        first_epoch_losses = []
        last_epoch_losses = []

        for epoch in range(n_epochs):
            for batch in range(n_batches):
                obs, _, _ = self.experience_replay.sample(batch_size,
                                                          normalize=False)
                batch_next_states, batch_actions = self._apply_policy(obs)

                (
                    batch_loss,
                    _,
                    batch_lr,
                    summary,
                    step,
                    mixture_entropy,
                ) = self.sess.run(
                    [
                        self.loss,
                        self.optimization_op,
                        self.learning_rate,
                        summaries_op,
                        self.global_step,
                        self.mixture_entropy,
                    ],
                    feed_dict={
                        self.in_state: batch_next_states,
                        self.true_action: batch_actions,
                    },
                )

                if epoch == 0:
                    first_epoch_losses.append(batch_loss)
                if epoch == n_epochs - 1:
                    last_epoch_losses.append(batch_loss)

                if self.tensorboard_log is not None:
                    summary_writer.add_summary(summary, step)

                if verbose:
                    print(
                        "Epoch: {}/{}...".format(epoch + 1, n_epochs),
                        "Batch: {}/{}...".format(batch + 1, n_batches),
                        "Training loss: {:.4f}   (ent {:.4f})   ".format(
                            batch_loss, mixture_entropy),
                        "(learning_rate = {:.6f})".format(batch_lr),
                    )

        if return_initial_loss:
            return np.mean(first_epoch_losses), np.mean(last_epoch_losses)
        return np.mean(last_epoch_losses)
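
A minimal usage sketch, assuming a gym env, a solver exposing predict(obs) (e.g. a stable-baselines SAC model), and an experience_replay buffer compatible with the sample(batch_size, normalize=False) call used in learn(); all three come from the surrounding project and are not constructed here.

# The MDN learns to answer: "which action did the current policy take to end
# up in this state?"  Smaller-than-default network sizes are used here only to
# keep the example quick.
inverse_policy = InversePolicyMDN(
    env,
    solver,
    experience_replay,
    n_layers=3,
    layer_size=64,
    n_out=1,
)
final_loss = inverse_policy.learn(n_epochs=1, batch_size=16, verbose=True)

obs = env.reset()
sampled_action = inverse_policy.step(obs, sample=True)  # sample from the mixture
modal_action = inverse_policy.step(obs, sample=False)   # mean of the most likely component
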
Example 7
def latent_rlsp(
    _run,
    env,
    current_obs,
    max_horizon,
    experience_replay,
    latent_space,
    inverse_transition_model,
    policy_horizon_factor=1,
    epochs=1,
    learning_rate=0.2,
    r_prior=None,
    r_vec=None,
    threshold=1e-2,
    n_trajectories=50,
    n_trajectories_forward_factor=1.0,
    print_level=0,
    callback=None,
    reset_solver=False,
    solver_iterations=1000,
    trajectory_video_path=None,
    continue_training_dynamics=False,
    continue_training_latent_space=False,
    tf_graphs=None,
    clip_mujoco_obs=False,
    horizon_curriculum=False,
    inverse_model_parameters=dict(),
    latent_space_model_parameters=dict(),
    inverse_policy_parameters=dict(),
    reweight_gradient=False,
    max_epochs_per_horizon=20,
    init_from_policy=None,
    solver_str="sac",
    reward_action_norm_factor=0,
):
    """
    Deep RLSP algorithm.
    """
    assert solver_str in ("sac", "ppo")

    def update_policy(r_vec, solver, horizon, obs_backward):
        print("Updating policy")
        obs_backward.extend(current_obs)
        wrapped_env = LatentSpaceRewardWrapper(
            env,
            latent_space,
            r_vec,
            inferred_weight=None,
            init_observations=obs_backward,
            time_horizon=max_horizon * policy_horizon_factor,
            use_task_reward=False,
            reward_action_norm_factor=reward_action_norm_factor,
        )

        if reset_solver:
            print("resetting solver")
            if solver_str == "sac":
                solver = get_sac(wrapped_env,
                                 learning_starts=0,
                                 verbose=int(print_level >= 2))
            else:
                solver = get_ppo(wrapped_env, verbose=int(print_level >= 2))
        else:
            solver.set_env(wrapped_env)
        solver.learn(total_timesteps=solver_iterations, log_interval=100)
        return solver

    def update_inverse_policy(solver):
        print("Updating inverse policy")
        # inverse_policy = InversePolicyMDN(
        #     env, solver, experience_replay, **inverse_policy_parameters["model"]
        # )
        # inverse_policy.solver = solver
        first_epoch_loss, last_epoch_loss = inverse_policy.learn(
            return_initial_loss=True,
            verbose=print_level >= 2,
            **inverse_policy_parameters["learn"],
        )
        print("Inverse policy loss: {} --> {}".format(first_epoch_loss,
                                                      last_epoch_loss))
        return first_epoch_loss, last_epoch_loss

    def _get_transitions_from_rollout(observations, actions, dynamics,
                                      experience_replay):
        for obs, action in zip(observations, actions):
            try:
                # action = env.action_space.sample()
                next_obs = dynamics.dynamics(obs, action)
                experience_replay.append(obs, action, next_obs)
            except Exception as e:
                print("_get_transitions_from_rollout", e)

    def update_experience_replay(
        traj_actions_backward,
        traj_actions_forward,
        traj_observations_backward,
        traj_observations_forward,
        solver,
        experience_replay,
    ):
        """
        Replays the observed action sequences in the true dynamics model and adds
        the resulting transitions to the experience replay, so that the learned
        models keep being trained on them.

        This is currently only used for debugging.
        """
        dynamics = ExactDynamicsMujoco(env.unwrapped.spec.id,
                                       tolerance=1e-3,
                                       max_iters=100)

        for states, actions in zip(traj_observations_backward,
                                   traj_actions_backward):
            _get_transitions_from_rollout(states, actions, dynamics,
                                          experience_replay)

        for states, actions in zip(traj_observations_forward,
                                   traj_actions_forward):
            _get_transitions_from_rollout(states, actions, dynamics,
                                          experience_replay)

        # experience_replay.add_policy_rollouts(
        #     env, solver, 10, env.spec.max_episode_steps
        # )

    def compute_grad(r_vec, epoch, env, horizon):
        init_obs_list = []
        feature_counts_backward = np.zeros_like(r_vec)

        weight_sum_bwd = 0
        obs_counts_backward = np.zeros_like(current_obs[0])

        if trajectory_video_path is not None:
            trajectory_rgbs = []

        states_backward = []
        states_forward = []
        obs_backward = []
        traj_actions_backward = []
        traj_observations_backward = []

        np.random.shuffle(current_obs)
        current_obs_cycle = itertools.cycle(current_obs)

        print("Horizon", horizon)

        if print_level >= 2:
            print("Backward")

        for i_traj in range(n_trajectories):
            obs = np.copy(next(current_obs_cycle))
            if i_traj == 0 and trajectory_video_path is not None:
                rgb = render_mujoco_from_obs(env, obs)
                trajectory_rgbs.append(rgb)
            state = latent_space.encoder(obs)  # t = T

            feature_counts_backward_traj = np.copy(state)
            obs_counts_backward_traj = np.copy(obs)
            actions_backward = []
            observations_backward = []

            # simulate trajectory into the past
            # iterates for t = T, T-1, T-2, ..., 1
            for t in range(horizon, 0, -1):
                action = inverse_policy.step(obs)
                if print_level >= 2 and i_traj == 0:
                    print("inverse action", action)
                actions_backward.append(action)
                with timer.start("backward"):
                    obs = inverse_transition_model.step(obs,
                                                        action,
                                                        sample=True)
                    if clip_mujoco_obs:
                        obs = clipper.clip(obs)[0]
                    state = latent_space.encoder(obs)

                feature_counts_backward_traj += state
                states_backward.append(state)
                obs_backward.append(obs)

                # debugging
                observations_backward.append(obs)
                if print_level >= 2 and i_traj == 0:
                    with np.printoptions(suppress=True):
                        print(obs)
                obs_counts_backward_traj += obs
                if i_traj == 0 and trajectory_video_path is not None:
                    rgb = render_mujoco_from_obs(env, obs)
                    trajectory_rgbs.append(rgb)

            init_obs_list.append(obs)
            weight = similarity(obs, initial_obs) if reweight_gradient else 1
            weight_sum_bwd += weight
            if print_level >= 2:
                print("similarity weight", weight)

            feature_counts_backward += feature_counts_backward_traj * weight
            obs_counts_backward += obs_counts_backward_traj * weight
            traj_actions_backward.append(actions_backward)
            traj_observations_backward.append(observations_backward)

        if trajectory_video_path is not None:
            trajectory_rgbs.extend([np.zeros_like(rgb)] * 5)

        init_obs_cycle = itertools.cycle(init_obs_list)
        feature_counts_forward = np.zeros_like(r_vec)

        weight_sum_fwd = 0
        obs_counts_forward = np.zeros_like(current_obs[0])

        n_trajectories_forward = int(n_trajectories *
                                     n_trajectories_forward_factor)
        traj_actions_forward = []
        traj_observations_forward = []

        if print_level >= 2:
            print("Forward")
        for i_traj in range(n_trajectories_forward):
            obs = np.copy(next(init_obs_cycle))  # t = 0

            if i_traj == 0 and trajectory_video_path is not None:
                rgb = render_mujoco_from_obs(env, obs)
                trajectory_rgbs.append(rgb)

            weight = similarity(obs, initial_obs) if reweight_gradient else 1
            weight_sum_fwd += weight

            env = init_env_from_obs(env, obs)
            state = latent_space.encoder(obs)

            feature_counts_forward_traj = np.copy(state)
            obs_counts_forward_traj = np.copy(obs)
            actions_forward = []
            observations_forward = []
            failed = False

            # iterates for t = 0, ..., T-1
            for t in range(horizon):
                action = solver.predict(obs)[0]
                if print_level >= 2 and i_traj == 0:
                    print("forward action", action)
                actions_forward.append(action)
                with timer.start("forward"):
                    if not failed:
                        try:
                            new_obs = env.step(action)[0]
                            obs = new_obs
                        except Exception as e:
                            failed = True
                            print("compute_grad", e)

                        if clip_mujoco_obs:
                            obs, clipped = clipper.clip(obs)
                            if clipped:
                                env = init_env_from_obs(env, obs)
                    state = latent_space.encoder(obs)
                feature_counts_forward_traj += state
                states_forward.append(state)

                # debugging
                observations_forward.append(obs)
                if print_level >= 2 and i_traj == 0:
                    with np.printoptions(suppress=True):
                        print(obs)
                obs_counts_forward_traj += obs

                if i_traj == 0 and trajectory_video_path is not None:
                    rgb = env.render("rgb_array")
                    trajectory_rgbs.append(rgb)

            feature_counts_forward += weight * feature_counts_forward_traj
            obs_counts_forward += weight * obs_counts_forward_traj
            traj_actions_forward.append(actions_forward)
            traj_observations_forward.append(observations_forward)

        if trajectory_video_path is not None:
            video_path = os.path.join(trajectory_video_path,
                                      "epoch_{}_traj.avi".format(epoch))
            save_video(trajectory_rgbs, video_path, fps=2.0)
            print("Saved video to", video_path)

        # Normalize the gradient per-action,
        # so that its magnitude is not sensitive to the horizon
        feature_counts_forward /= weight_sum_fwd * horizon
        feature_counts_backward /= weight_sum_bwd * horizon

        dL_dr_vec = feature_counts_backward - feature_counts_forward

        print()
        print("\tfeature_counts_backward", feature_counts_backward)
        print("\tfeature_counts_forward", feature_counts_forward)
        print("\tn_trajectories_forward", n_trajectories_forward)
        print("\tn_trajectories", n_trajectories)
        print("\tdL_dr_vec", dL_dr_vec)
        print()

        # debugging
        if env.unwrapped.spec.id == "InvertedPendulum-v2":
            obs_counts_forward /= weight_sum_fwd * horizon
            obs_counts_backward /= weight_sum_bwd * horizon
            obs_counts_backward_enc = latent_space.encoder(obs_counts_backward)
            obs_counts_forward_enc = latent_space.encoder(obs_counts_forward)
            obs_counts_backward_old_reward = np.dot(r_vec,
                                                    obs_counts_backward_enc)
            obs_counts_forward_old_reward = np.dot(r_vec,
                                                   obs_counts_forward_enc)
            r_vec_new = r_vec + learning_rate * dL_dr_vec
            obs_counts_backward_new_reward = np.dot(r_vec_new,
                                                    obs_counts_backward_enc)
            obs_counts_forward_new_reward = np.dot(r_vec_new,
                                                   obs_counts_forward_enc)
            print("\tdebugging info")
            print("\tweight_sum_bwd", weight_sum_bwd)
            print("\tobs_counts_backward", obs_counts_backward)
            print("\tweight_sum_fwd", weight_sum_fwd)
            print("\tobs_counts_forward", obs_counts_forward)
            print("\told reward")
            print("\tobs_counts_backward", obs_counts_backward_old_reward)
            print("\tobs_counts_forward", obs_counts_forward_old_reward)
            print("\tnew reward")
            print("\tobs_counts_backward", obs_counts_backward_new_reward)
            print("\tobs_counts_forward", obs_counts_forward_new_reward)

            _run.log_scalar("debug_forward_pos", obs_counts_forward[0], epoch)
            _run.log_scalar("debug_forward_angle", obs_counts_forward[1],
                            epoch)
            _run.log_scalar("debug_forward_velocity", obs_counts_forward[2],
                            epoch)
            _run.log_scalar("debug_forward_angular_velocity",
                            obs_counts_forward[3], epoch)
            _run.log_scalar("debug_backward_pos", obs_counts_backward[0],
                            epoch)
            _run.log_scalar("debug_backward_angle", obs_counts_backward[1],
                            epoch)
            _run.log_scalar("debug_backward_velocity", obs_counts_backward[2],
                            epoch)
            _run.log_scalar("debug_backward_angular_velocity",
                            obs_counts_backward[3], epoch)
            init_obs_array = np.array(init_obs_list)
            _run.log_scalar("debug_init_state_pos",
                            init_obs_array[:, 0].mean(), epoch)
            _run.log_scalar("debug_init_state_angle",
                            init_obs_array[:, 1].mean(), epoch)
            _run.log_scalar("debug_init_state_velocity",
                            init_obs_array[:, 2].mean(), epoch)
            _run.log_scalar("debug_init_state_angular_velocity",
                            init_obs_array[:, 3].mean(), epoch)

        # Gradient of the prior
        if r_prior is not None:
            dL_dr_vec += r_prior.logdistr_grad(r_vec)

        return (
            dL_dr_vec,
            feature_counts_forward,
            feature_counts_backward,
            traj_actions_backward,
            traj_actions_forward,
            traj_observations_backward,
            traj_observations_forward,
            states_forward,
            states_backward,
            obs_backward,
        )

    timer = Timer()
    clipper = MujocoObsClipper(env.unwrapped.spec.id)

    if trajectory_video_path is not None:
        os.makedirs(trajectory_video_path, exist_ok=True)

    current_state = [latent_space.encoder(obs) for obs in current_obs]
    initial_obs = env.reset()

    if r_vec is None:
        r_vec = sum(current_state)
        r_vec /= np.linalg.norm(r_vec)

    with np.printoptions(precision=4, suppress=True, threshold=10):
        print("Initial reward vector: {}".format(r_vec))

    dynamics = ExactDynamicsMujoco(env.unwrapped.spec.id,
                                   tolerance=1e-3,
                                   max_iters=100)

    wrapped_env = LatentSpaceRewardWrapper(env, latent_space, r_vec)

    if init_from_policy is not None:
        print(f"Loading policy from {init_from_policy}")
        solver = SAC.load(init_from_policy)
    else:
        if solver_str == "sac":
            solver = get_sac(wrapped_env,
                             learning_starts=0,
                             verbose=int(print_level >= 2))
        else:
            solver = get_ppo(wrapped_env, verbose=int(print_level >= 2))

    inverse_policy_graph = tf.Graph()

    with inverse_policy_graph.as_default():
        inverse_policy = InversePolicyMDN(env, solver, experience_replay,
                                          **inverse_policy_parameters["model"])

    gradients = []
    obs_backward = []

    solver = update_policy(r_vec, solver, 1, obs_backward)
    with inverse_policy_graph.as_default():
        (
            inverse_policy_initial_loss,
            inverse_policy_final_loss,
        ) = update_inverse_policy(solver)

    epoch = 0
    for horizon in range(1, max_horizon + 1):
        if not horizon_curriculum:
            horizon = max_horizon
            max_horizon = env.spec.max_episode_steps
            max_epochs_per_horizon = epochs
            threshold = -float("inf")

        last_n_grad_norm = float("inf")
        # initialize negatively in case we don't continue to train
        backward_final_loss = -float("inf")
        latent_final_loss = -float("inf")
        inverse_policy_final_loss = -float("inf")

        backward_threshold = float("inf")
        latent_threshold = float("inf")
        inverse_policy_threshold = float("inf")

        gradients = []
        print(f"New horizon: {horizon}")
        epochs_per_horizon = 0

        while epochs_per_horizon < max_epochs_per_horizon and (
                last_n_grad_norm > threshold
                or backward_final_loss > backward_threshold
                or latent_final_loss > latent_threshold
                or inverse_policy_final_loss > inverse_policy_threshold):
            epochs_per_horizon += 1
            epoch += 1
            if epoch > epochs:
                print(f"Stopping after {epoch} epochs.")
                return r_vec

            (
                dL_dr_vec,
                feature_counts_forward,
                feature_counts_backward,
                traj_actions_backward,
                traj_actions_forward,
                traj_observations_backward,
                traj_observations_forward,
                states_forward,
                states_backward,
                obs_backward,
            ) = compute_grad(r_vec, epoch, env, horizon)

            print("threshold", threshold)

            if clip_mujoco_obs:
                print("clipper.counter", clipper.counter)
                clipper.counter = 0

            if print_level >= 1:
                _print_timing(timer)

            grad_mean_n = 10
            gradients.append(dL_dr_vec)
            last_n_grad_norm = np.linalg.norm(
                np.mean(gradients[-grad_mean_n:], axis=0))
            # Clip gradient by norm
            grad_norm = np.linalg.norm(gradients[-1])
            if grad_norm > 10:
                dL_dr_vec = 10 * dL_dr_vec / grad_norm

            # Gradient ascent
            r_vec = r_vec + learning_rate * dL_dr_vec

            with np.printoptions(precision=3, suppress=True, threshold=10):
                print(
                    f"Epoch {epoch}; Reward vector: {r_vec} ",
                    "(norm {:.3f}); grad_norm {:.3f}; last_n_grad_norm: {:.3f}"
                    .format(np.linalg.norm(r_vec), grad_norm,
                            last_n_grad_norm),
                )

            latent_initial_loss = None
            forward_initial_loss = None
            backward_initial_loss = None
            if continue_training_dynamics or continue_training_latent_space:
                assert tf_graphs is not None
                update_experience_replay(
                    traj_actions_backward,
                    traj_actions_forward,
                    traj_observations_backward,
                    traj_observations_forward,
                    solver,
                    experience_replay,
                )

                if continue_training_dynamics:
                    with tf_graphs["inverse"].as_default():
                        (
                            backward_initial_loss,
                            backward_final_loss,
                        ) = inverse_transition_model.learn(
                            return_initial_loss=True,
                            verbose=print_level >= 2,
                            **inverse_model_parameters["learn"],
                        )
                    print("Backward model loss:  {} --> {}".format(
                        backward_initial_loss, backward_final_loss))

                if continue_training_latent_space:
                    with tf_graphs["latent"].as_default():
                        latent_initial_loss, latent_final_loss = latent_space.learn(
                            experience_replay,
                            return_initial_loss=True,
                            verbose=print_level >= 2,
                            **latent_space_model_parameters["learn"],
                        )
                    print("Latent space loss:  {} --> {}".format(
                        latent_initial_loss, latent_final_loss))

            solver = update_policy(r_vec, solver, horizon, obs_backward)
            with inverse_policy_graph.as_default():
                (
                    inverse_policy_initial_loss,
                    inverse_policy_final_loss,
                ) = update_inverse_policy(solver)

            if callback:
                callback(locals(), globals())

    return r_vec
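
For reference, the reward update at the core of the loop above, distilled into a standalone function. phi_backward and phi_forward are placeholder names for the normalized, similarity-weighted latent feature counts that compute_grad produces.

import numpy as np


def rlsp_reward_update(r_vec, phi_backward, phi_forward, learning_rate=0.2,
                       r_prior=None, max_grad_norm=10.0):
    """One gradient-ascent step on the reward vector, mirroring the loop above."""
    grad = phi_backward - phi_forward                # dL/dr_vec
    if r_prior is not None:
        grad = grad + r_prior.logdistr_grad(r_vec)   # gradient of the prior
    grad_norm = np.linalg.norm(grad)
    if grad_norm > max_grad_norm:                    # clip gradient by norm
        grad = max_grad_norm * grad / grad_norm
    return r_vec + learning_rate * grad
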