class PendulumDynamics:
    """Drop-in replacement for a learned dynamics model that queries the
    exact MuJoCo dynamics of InvertedPendulum-v2, forward or backward."""

    def __init__(self, latent_space, backward=False):
        assert latent_space.env.unwrapped.spec.id == "InvertedPendulum-v2"
        self.latent_space = latent_space
        self.backward = backward
        self.dynamics = ExactDynamicsMujoco(
            self.latent_space.env.unwrapped.spec.id,
            tolerance=1e-2,
            max_iters=100)
        # Observation bounds: cart position, pole angle, and their velocities.
        self.low = np.array([-1, -np.pi, -10, -10])
        self.high = np.array([1, np.pi, 10, 10])

    def step(self, state, action, sample=True):
        obs = state  # self.latent_space.decoder(state)
        obs = np.clip(obs, self.low, self.high)
        if self.backward:
            obs = self.dynamics.inverse_dynamics(obs, action)
        else:
            obs = self.dynamics.dynamics(obs, action)
        state = obs  # self.latent_space.encoder(obs)
        return state

    def learn(self, *args, return_initial_loss=False, **kwargs):
        print("Using exact dynamics...")
        if return_initial_loss:
            return 0, 0
        return 0
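# Usage sketch (illustrative, not from the original file): the class above can
# stand in wherever a learned transition model is expected. `latent_space` is
# assumed to be a trained latent-space model wrapping InvertedPendulum-v2; its
# encoder/decoder are bypassed because step() works on raw observations.
#
#     backward_model = PendulumDynamics(latent_space, backward=True)
#     obs = latent_space.env.reset()
#     action = latent_space.env.action_space.sample()
#     prev_obs = backward_model.step(obs, action)  # one exact step into the past
#     loss = backward_model.learn()                # no-op stub, returns 0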
def add_play_data(self, env, play_data):
    """Append transitions from recorded play data, recomputing each next
    observation with the exact MuJoCo dynamics."""
    dynamics = ExactDynamicsMujoco(env.unwrapped.spec.id)
    observations, actions = play_data["observations"], play_data["actions"]
    n_traj = len(observations)
    assert len(actions) == n_traj
    for i in range(n_traj):
        l_traj = len(observations[i])
        for t in range(l_traj):
            obs = observations[i][t]
            action = actions[i][t]
            next_obs = dynamics.dynamics(obs, action)
            self.append(obs, action, next_obs)
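# Illustrative input layout (inferred from the indexing above, not an official
# schema): play_data holds one list of observations and one list of actions
# per trajectory, aligned so that actions[i][t] was taken from
# observations[i][t].
#
#     play_data = {
#         "observations": [[obs_0, obs_1, ...], ...],  # one inner list per trajectory
#         "actions": [[act_0, act_1, ...], ...],
#     }
#     experience_replay.add_play_data(env, play_data)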
class InversePolicyMDN:
    """Mixture density network (MDN) modeling the inverse policy: given the
    state that was reached, predict the action the current policy took."""

    def __init__(
        self,
        env,
        solver,
        experience_replay,
        tensorboard_log=None,
        learning_rate=3e-4,
        n_layers=10,
        layer_size=256,
        n_out=1,
        gauss_stdev=None,
    ):
        self.env = env
        self.solver = solver
        self.experience_replay = experience_replay
        self.dynamics = ExactDynamicsMujoco(env.unwrapped.spec.id,
                                            tolerance=1e-3,
                                            max_iters=100)
        assert len(self.env.action_space.shape) == 1
        self.action_dim = self.env.action_space.shape[0]
        self.observation_shape = list(self.env.observation_space.shape)
        self.n_out = n_out
        self.gauss_stdev = gauss_stdev
        self.layer_size = layer_size
        self.n_layers = n_layers
        self._define_input_placeholders()
        self._define_model()
        self.loss = self._define_loss()
        self.learning_rate, self.global_step = get_learning_rate(
            learning_rate, None, 1)
        self.optimizer = tf.train.AdamOptimizer(self.learning_rate)
        self.gradients = self.optimizer.compute_gradients(loss=self.loss)
        self.optimization_op = self.optimizer.apply_gradients(
            self.gradients, global_step=self.global_step)
        self.tensorboard_log = tensorboard_log
        if self.tensorboard_log is not None:
            self._define_tensorboard_metrics()
        self.sess = None

    def step(self, in_state, sample=True):
        batch_in_states = np.expand_dims(in_state, 0)
        if sample:
            (out, ) = self.sess.run([self.out_sample],
                                    feed_dict={self.in_state: batch_in_states})
            out = out[0]
        else:
            factors, means = self.sess.run(
                [self.mixture_factors, [c.loc for c in self.components]],
                feed_dict={self.in_state: batch_in_states},
            )
            out = means[np.argmax(factors[0])][0]
        out = np.clip(out, self.env.action_space.low,
                      self.env.action_space.high)
        return out

    def _define_input_placeholders(self):
        self.true_action = tf.placeholder(tf.float32, [None, self.action_dim],
                                          name="action")
        self.in_state = tf.placeholder(tf.float32,
                                       [None] + self.observation_shape,
                                       name="in_state")

    def _define_model(self):
        activation = tf.nn.relu
        x = tf.reshape(self.in_state, [-1, np.prod(self.observation_shape)])
        x_1 = x
        for i in range(self.n_layers):
            x = tf.keras.layers.Dense(self.layer_size,
                                      activation=activation,
                                      name="hidden_{}".format(i + 1))(x)
        means = x
        means = tf.keras.layers.Dense(self.n_out * self.action_dim,
                                      activation=None,
                                      name="output")(means)
        # note: this requires a 1d action space
        self.mixture_means = tf.split(means, [self.action_dim] * self.n_out, -1)
        if self.gauss_stdev is None:
            stddevs = x
            stddevs = tf.keras.layers.Dense(self.n_out, activation=None)(stddevs)
            stddevs = tf.exp(stddevs)
            stddevs = tf.clip_by_value(stddevs,
                                       clip_value_min=1e-10,
                                       clip_value_max=1e10)
            # same stddev for all components
            # self.mixture_stddevs = tf.reshape(
            #     stddevs, (-1, self.n_out, self.state_size)
            # )
            self.mixture_stddevs = tf.split(stddevs, [1] * self.n_out, -1)
        else:
            self.mixture_stddevs = None
        factors = tf.concat([x, x_1], -1)
        # factors = tf.keras.layers.Dense(
        #     self.layer_size // 2, activation=activation, name="factors_hidden"
        # )(factors)
        factors = tf.keras.layers.Dense(self.n_out,
                                        activation=None,
                                        name="factors_out")(factors)
        # if we don't clip these, we get NaNs in the loss computation
        factors = tf.clip_by_value(factors, clip_value_min=0.1,
                                   clip_value_max=10)
        self.mixture_factors = tf.nn.softmax(factors)
        if self.gauss_stdev is None:
            self.components = [
                tfd.MultivariateNormalDiag(
                    loc=mean,
                    scale_diag=tf.tile(stddev, [1, self.action_dim]),
                    validate_args=True,
                    allow_nan_stats=False,
                ) for mean, stddev in zip(self.mixture_means,
                                          self.mixture_stddevs)
            ]
        else:
            self.components = [
                tfd.MultivariateNormalDiag(
                    loc=mean,
                    scale_diag=tf.fill(tf.shape(mean), self.gauss_stdev),
                    validate_args=True,
                    allow_nan_stats=False,
                ) for mean in self.mixture_means
            ]
        self.mixture_distribution = tfd.Categorical(probs=self.mixture_factors,
                                                    allow_nan_stats=False)
        self.out_distribution = tfd.Mixture(
            cat=self.mixture_distribution,
            components=self.components,
            validate_args=True,
            allow_nan_stats=False,
        )
        self.out_sample = self.out_distribution.sample()

    def _define_loss(self):
        neg_logprob = -tf.reduce_mean(
            self.out_distribution.log_prob(self.true_action))
        self.logprobs = [
            tf.reduce_mean(dist.log_prob(self.true_action))
            for dist in self.components
        ]
        self.mixture_entropy = tf.reduce_mean(
            self.mixture_distribution.entropy())
        return neg_logprob

    def _define_tensorboard_metrics(self):
        tf.summary.scalar("loss/loss", self.loss)
        tf.summary.scalar("loss/mixture_entropy", self.mixture_entropy)
        tf.summary.scalar("learning_rate", self.learning_rate)
        for i, lp in enumerate(self.logprobs):
            tf.summary.scalar("loss/logprob_{}".format(i + 1), lp)
        tensorboard_log_gradients(self.gradients)

    def _apply_policy(self, observations):
        next_states, actions = [], []
        for obs in observations:
            action = self.solver.predict(obs)[0]
            try:
                obs = self.dynamics.dynamics(obs, action)
                next_states.append(np.copy(obs))
                actions.append(np.copy(action))
            except Exception as e:
                print("_apply_policy", e)
        return next_states, actions

    def learn(
        self,
        n_epochs=1,
        batch_size=16,
        return_initial_loss=False,
        verbose=True,
        reinitialize=False,
    ):
        """
        Main training loop.
        """
        if self.sess is None:
            self.sess = get_tf_session()
            reinitialize = True
        if reinitialize:
            self.sess.run(tf.global_variables_initializer())
        if self.tensorboard_log is not None:
            summaries_op = tf.summary.merge_all()
            summary_writer = tf.summary.FileWriter(self.tensorboard_log,
                                                   self.sess.graph)
        else:
            summaries_op = tf.no_op()
        n_batches = len(self.experience_replay) // batch_size
        first_epoch_losses = []
        last_epoch_losses = []
        for epoch in range(n_epochs):
            for batch in range(n_batches):
                obs, _, _ = self.experience_replay.sample(batch_size,
                                                          normalize=False)
                batch_next_states, batch_actions = self._apply_policy(obs)
                (
                    batch_loss,
                    _,
                    batch_lr,
                    summary,
                    step,
                    mixture_entropy,
                ) = self.sess.run(
                    [
                        self.loss,
                        self.optimization_op,
                        self.learning_rate,
                        summaries_op,
                        self.global_step,
                        self.mixture_entropy,
                    ],
                    feed_dict={
                        self.in_state: batch_next_states,
                        self.true_action: batch_actions,
                    },
                )
                if epoch == 0:
                    first_epoch_losses.append(batch_loss)
                if epoch == n_epochs - 1:
                    last_epoch_losses.append(batch_loss)
                if self.tensorboard_log is not None:
                    summary_writer.add_summary(summary, step)
                if verbose:
                    print(
                        "Epoch: {}/{}...".format(epoch + 1, n_epochs),
                        "Batch: {}/{}...".format(batch + 1, n_batches),
                        "Training loss: {:.4f} (ent {:.4f}) ".format(
                            batch_loss, mixture_entropy),
                        "(learning_rate = {:.6f})".format(batch_lr),
                    )
        if return_initial_loss:
            return np.mean(first_epoch_losses), np.mean(last_epoch_losses)
        return np.mean(last_epoch_losses)
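# The loss above is the mixture negative log-likelihood
#     L = -E[ log sum_k pi_k(s') N(a | mu_k(s'), sigma_k(s')^2 I) ],
# so the network learns to predict the action a that the solver's policy took
# to reach next state s'. A minimal training sketch (hyperparameters are
# illustrative, not from the original code):
#
#     with tf.Graph().as_default():
#         inverse_policy = InversePolicyMDN(env, solver, experience_replay,
#                                           n_layers=2, layer_size=64, n_out=3)
#         final_loss = inverse_policy.learn(n_epochs=5, batch_size=16)
#         action = inverse_policy.step(obs, sample=True)  # sample a backward action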
def latent_rlsp(
    _run,
    env,
    current_obs,
    max_horizon,
    experience_replay,
    latent_space,
    inverse_transition_model,
    policy_horizon_factor=1,
    epochs=1,
    learning_rate=0.2,
    r_prior=None,
    r_vec=None,
    threshold=1e-2,
    n_trajectories=50,
    n_trajectories_forward_factor=1.0,
    print_level=0,
    callback=None,
    reset_solver=False,
    solver_iterations=1000,
    trajectory_video_path=None,
    continue_training_dynamics=False,
    continue_training_latent_space=False,
    tf_graphs=None,
    clip_mujoco_obs=False,
    horizon_curriculum=False,
    inverse_model_parameters=dict(),
    latent_space_model_parameters=dict(),
    inverse_policy_parameters=dict(),
    reweight_gradient=False,
    max_epochs_per_horizon=20,
    init_from_policy=None,
    solver_str="sac",
    reward_action_norm_factor=0,
):
    """
    Deep RLSP algorithm.
    """
    assert solver_str in ("sac", "ppo")

    def update_policy(r_vec, solver, horizon, obs_backward):
        print("Updating policy")
        obs_backward.extend(current_obs)
        wrapped_env = LatentSpaceRewardWrapper(
            env,
            latent_space,
            r_vec,
            inferred_weight=None,
            init_observations=obs_backward,
            time_horizon=max_horizon * policy_horizon_factor,
            use_task_reward=False,
            reward_action_norm_factor=reward_action_norm_factor,
        )
        if reset_solver:
            print("resetting solver")
            if solver_str == "sac":
                solver = get_sac(wrapped_env,
                                 learning_starts=0,
                                 verbose=int(print_level >= 2))
            else:
                solver = get_ppo(wrapped_env, verbose=int(print_level >= 2))
        else:
            solver.set_env(wrapped_env)
        solver.learn(total_timesteps=solver_iterations, log_interval=100)
        return solver

    def update_inverse_policy(solver):
        print("Updating inverse policy")
        # inverse_policy = InversePolicyMDN(
        #     env, solver, experience_replay, **inverse_policy_parameters["model"]
        # )
        # inverse_policy.solver = solver
        first_epoch_loss, last_epoch_loss = inverse_policy.learn(
            return_initial_loss=True,
            verbose=print_level >= 2,
            **inverse_policy_parameters["learn"],
        )
        print("Inverse policy loss: {} --> {}".format(first_epoch_loss,
                                                      last_epoch_loss))
        return first_epoch_loss, last_epoch_loss

    def _get_transitions_from_rollout(observations, actions, dynamics,
                                      experience_replay):
        for obs, action in zip(observations, actions):
            try:
                # action = env.action_space.sample()
                next_obs = dynamics.dynamics(obs, action)
                experience_replay.append(obs, action, next_obs)
            except Exception as e:
                print("_get_transitions_from_rollout", e)

    def update_experience_replay(
        traj_actions_backward,
        traj_actions_forward,
        traj_observations_backward,
        traj_observations_forward,
        solver,
        experience_replay,
    ):
        """
        Rolls out the action sequences observed in the true dynamics model
        and keeps training the dynamics models on these.

        This is currently only used for debugging.
        """
        dynamics = ExactDynamicsMujoco(env.unwrapped.spec.id,
                                       tolerance=1e-3,
                                       max_iters=100)
        for states, actions in zip(traj_observations_backward,
                                   traj_actions_backward):
            _get_transitions_from_rollout(states, actions, dynamics,
                                          experience_replay)
        for states, actions in zip(traj_observations_forward,
                                   traj_actions_forward):
            _get_transitions_from_rollout(states, actions, dynamics,
                                          experience_replay)
        # experience_replay.add_policy_rollouts(
        #     env, solver, 10, env.spec.max_episode_steps
        # )

    def compute_grad(r_vec, epoch, env, horizon):
        init_obs_list = []
        feature_counts_backward = np.zeros_like(r_vec)
        weight_sum_bwd = 0
        obs_counts_backward = np.zeros_like(current_obs[0])
        if trajectory_video_path is not None:
            trajectory_rgbs = []
        states_backward = []
        states_forward = []
        obs_backward = []
        traj_actions_backward = []
        traj_observations_backward = []
        np.random.shuffle(current_obs)
        current_obs_cycle = itertools.cycle(current_obs)
        print("Horizon", horizon)
        if print_level >= 2:
            print("Backward")
        for i_traj in range(n_trajectories):
            obs = np.copy(next(current_obs_cycle))
            if i_traj == 0 and trajectory_video_path is not None:
                rgb = render_mujoco_from_obs(env, obs)
                trajectory_rgbs.append(rgb)
            state = latent_space.encoder(obs)  # t = T
            feature_counts_backward_traj = np.copy(state)
            obs_counts_backward_traj = np.copy(obs)
            actions_backward = []
            observations_backward = []
            # simulate trajectory into the past
            # iterates for t = T, T-1, T-2, ..., 1
            for t in range(horizon, 0, -1):
                action = inverse_policy.step(obs)
                if print_level >= 2 and i_traj == 0:
                    print("inverse action", action)
                actions_backward.append(action)
                with timer.start("backward"):
                    obs = inverse_transition_model.step(obs, action,
                                                        sample=True)
                if clip_mujoco_obs:
                    obs = clipper.clip(obs)[0]
                state = latent_space.encoder(obs)
                feature_counts_backward_traj += state
                states_backward.append(state)
                obs_backward.append(obs)
                # debugging
                observations_backward.append(obs)
                if print_level >= 2 and i_traj == 0:
                    with np.printoptions(suppress=True):
                        print(obs)
                obs_counts_backward_traj += obs
                if i_traj == 0 and trajectory_video_path is not None:
                    rgb = render_mujoco_from_obs(env, obs)
                    trajectory_rgbs.append(rgb)
            init_obs_list.append(obs)
            weight = similarity(obs, initial_obs) if reweight_gradient else 1
            weight_sum_bwd += weight
            if print_level >= 2:
                print("similarity weight", weight)
            feature_counts_backward += feature_counts_backward_traj * weight
            obs_counts_backward += obs_counts_backward_traj * weight
            traj_actions_backward.append(actions_backward)
            traj_observations_backward.append(observations_backward)
        if trajectory_video_path is not None:
            trajectory_rgbs.extend([np.zeros_like(rgb)] * 5)
        init_obs_cycle = itertools.cycle(init_obs_list)
        feature_counts_forward = np.zeros_like(r_vec)
        weight_sum_fwd = 0
        obs_counts_forward = np.zeros_like(current_obs[0])
        n_trajectories_forward = int(n_trajectories *
                                     n_trajectories_forward_factor)
        traj_actions_forward = []
        traj_observations_forward = []
        if print_level >= 2:
            print("Forward")
        for i_traj in range(n_trajectories_forward):
            obs = np.copy(next(init_obs_cycle))  # t = 0
            if i_traj == 0 and trajectory_video_path is not None:
                rgb = render_mujoco_from_obs(env, obs)
                trajectory_rgbs.append(rgb)
            weight = similarity(obs, initial_obs) if reweight_gradient else 1
            weight_sum_fwd += weight
            env = init_env_from_obs(env, obs)
            state = latent_space.encoder(obs)
            feature_counts_forward_traj = np.copy(state)
            obs_counts_forward_traj = np.copy(obs)
            actions_forward = []
            observations_forward = []
            failed = False
            # iterates for t = 0, ..., T-1
            for t in range(horizon):
                action = solver.predict(obs)[0]
                if print_level >= 2 and i_traj == 0:
                    print("forward action", action)
                actions_forward.append(action)
                with timer.start("forward"):
                    if not failed:
                        try:
                            new_obs = env.step(action)[0]
                            obs = new_obs
                        except Exception as e:
                            failed = True
                            print("compute_grad", e)
                if clip_mujoco_obs:
                    obs, clipped = clipper.clip(obs)
                    if clipped:
                        env = init_env_from_obs(env, obs)
                state = latent_space.encoder(obs)
                feature_counts_forward_traj += state
                states_forward.append(state)
                # debugging
                observations_forward.append(obs)
                if print_level >= 2 and i_traj == 0:
                    with np.printoptions(suppress=True):
                        print(obs)
                obs_counts_forward_traj += obs
                if i_traj == 0 and trajectory_video_path is not None:
                    rgb = env.render("rgb_array")
                    trajectory_rgbs.append(rgb)
            feature_counts_forward += weight * feature_counts_forward_traj
            obs_counts_forward += weight * obs_counts_forward_traj
            traj_actions_forward.append(actions_forward)
            traj_observations_forward.append(observations_forward)
        if trajectory_video_path is not None:
            video_path = os.path.join(trajectory_video_path,
                                      "epoch_{}_traj.avi".format(epoch))
            save_video(trajectory_rgbs, video_path, fps=2.0)
            print("Saved video to", video_path)
        # Normalize the gradient per-action,
        # so that its magnitude is not sensitive to the horizon
        feature_counts_forward /= weight_sum_fwd * horizon
        feature_counts_backward /= weight_sum_bwd * horizon
        dL_dr_vec = feature_counts_backward - feature_counts_forward
        print()
        print("\tfeature_counts_backward", feature_counts_backward)
        print("\tfeature_counts_forward", feature_counts_forward)
        print("\tn_trajectories_forward", n_trajectories_forward)
        print("\tn_trajectories", n_trajectories)
        print("\tdL_dr_vec", dL_dr_vec)
        print()
        # debugging
        if env.unwrapped.spec.id == "InvertedPendulum-v2":
            obs_counts_forward /= weight_sum_fwd * horizon
            obs_counts_backward /= weight_sum_bwd * horizon
            obs_counts_backward_enc = latent_space.encoder(obs_counts_backward)
            obs_counts_forward_enc = latent_space.encoder(obs_counts_forward)
            obs_counts_backward_old_reward = np.dot(r_vec,
                                                    obs_counts_backward_enc)
            obs_counts_forward_old_reward = np.dot(r_vec,
                                                   obs_counts_forward_enc)
            r_vec_new = r_vec + learning_rate * dL_dr_vec
            obs_counts_backward_new_reward = np.dot(r_vec_new,
                                                    obs_counts_backward_enc)
            obs_counts_forward_new_reward = np.dot(r_vec_new,
                                                   obs_counts_forward_enc)
            print("\tdebugging info")
            print("\tweight_sum_bwd", weight_sum_bwd)
            print("\tobs_counts_backward", obs_counts_backward)
            print("\tweight_sum_fwd", weight_sum_fwd)
            print("\tobs_counts_forward", obs_counts_forward)
            print("\told reward")
            print("\tobs_counts_backward", obs_counts_backward_old_reward)
            print("\tobs_counts_forward", obs_counts_forward_old_reward)
            print("\tnew reward")
            print("\tobs_counts_backward", obs_counts_backward_new_reward)
            print("\tobs_counts_forward", obs_counts_forward_new_reward)
            _run.log_scalar("debug_forward_pos", obs_counts_forward[0], epoch)
            _run.log_scalar("debug_forward_angle", obs_counts_forward[1],
                            epoch)
            _run.log_scalar("debug_forward_velocity", obs_counts_forward[2],
                            epoch)
            _run.log_scalar("debug_forward_angular_velocity",
                            obs_counts_forward[3], epoch)
            _run.log_scalar("debug_backward_pos", obs_counts_backward[0],
                            epoch)
            _run.log_scalar("debug_backward_angle", obs_counts_backward[1],
                            epoch)
            _run.log_scalar("debug_backward_velocity", obs_counts_backward[2],
                            epoch)
            _run.log_scalar("debug_backward_angular_velocity",
                            obs_counts_backward[3], epoch)
            init_obs_array = np.array(init_obs_list)
            _run.log_scalar("debug_init_state_pos",
                            init_obs_array[:, 0].mean(), epoch)
            _run.log_scalar("debug_init_state_angle",
                            init_obs_array[:, 1].mean(), epoch)
            _run.log_scalar("debug_init_state_velocity",
                            init_obs_array[:, 2].mean(), epoch)
            _run.log_scalar("debug_init_state_angular_velocity",
                            init_obs_array[:, 3].mean(), epoch)
        # Gradient of the prior
        if r_prior is not None:
            dL_dr_vec += r_prior.logdistr_grad(r_vec)
        return (
            dL_dr_vec,
            feature_counts_forward,
            feature_counts_backward,
            traj_actions_backward,
            traj_actions_forward,
            traj_observations_backward,
            traj_observations_forward,
            states_forward,
            states_backward,
            obs_backward,
        )

    timer = Timer()
    clipper = MujocoObsClipper(env.unwrapped.spec.id)
    if trajectory_video_path is not None:
        os.makedirs(trajectory_video_path, exist_ok=True)
    current_state = [latent_space.encoder(obs) for obs in current_obs]
    initial_obs = env.reset()
    if r_vec is None:
        r_vec = sum(current_state)
        r_vec /= np.linalg.norm(r_vec)
    with np.printoptions(precision=4, suppress=True, threshold=10):
        print("Initial reward vector: {}".format(r_vec))
    dynamics = ExactDynamicsMujoco(env.unwrapped.spec.id,
                                   tolerance=1e-3,
                                   max_iters=100)
    wrapped_env = LatentSpaceRewardWrapper(env, latent_space, r_vec)
    if init_from_policy is not None:
        print(f"Loading policy from {init_from_policy}")
        solver = SAC.load(init_from_policy)
    else:
        if solver_str == "sac":
            solver = get_sac(wrapped_env,
                             learning_starts=0,
                             verbose=int(print_level >= 2))
        else:
            solver = get_ppo(wrapped_env, verbose=int(print_level >= 2))
    inverse_policy_graph = tf.Graph()
    with inverse_policy_graph.as_default():
        inverse_policy = InversePolicyMDN(env, solver, experience_replay,
                                          **inverse_policy_parameters["model"])
    gradients = []
    obs_backward = []
    solver = update_policy(r_vec, solver, 1, obs_backward)
    with inverse_policy_graph.as_default():
        (
            inverse_policy_initial_loss,
            inverse_policy_final_loss,
        ) = update_inverse_policy(solver)
    epoch = 0
    for horizon in range(1, max_horizon + 1):
        if not horizon_curriculum:
            horizon = max_horizon
            max_horizon = env.spec.max_episode_steps
            max_epochs_per_horizon = epochs
            threshold = -float("inf")
        last_n_grad_norm = float("inf")
        # initialize negatively in case we don't continue to train
        backward_final_loss = -float("inf")
        latent_final_loss = -float("inf")
        inverse_policy_final_loss = -float("inf")
        backward_threshold = float("inf")
        latent_threshold = float("inf")
        inverse_policy_threshold = float("inf")
        gradients = []
        print(f"New horizon: {horizon}")
        epochs_per_horizon = 0
        while epochs_per_horizon < max_epochs_per_horizon and (
                last_n_grad_norm > threshold
                or backward_final_loss > backward_threshold
                or latent_final_loss > latent_threshold
                or inverse_policy_final_loss > inverse_policy_threshold):
            epochs_per_horizon += 1
            epoch += 1
            if epoch > epochs:
                print(f"Stopping after {epoch} epochs.")
                return r_vec
            (
                dL_dr_vec,
                feature_counts_forward,
                feature_counts_backward,
                traj_actions_backward,
                traj_actions_forward,
                traj_observations_backward,
                traj_observations_forward,
                states_forward,
                states_backward,
                obs_backward,
            ) = compute_grad(r_vec, epoch, env, horizon)
            print("threshold", threshold)
            if clip_mujoco_obs:
                print("clipper.counter", clipper.counter)
                clipper.counter = 0
            if print_level >= 1:
                _print_timing(timer)
            grad_mean_n = 10
            gradients.append(dL_dr_vec)
            last_n_grad_norm = np.linalg.norm(
                np.mean(gradients[-grad_mean_n:], axis=0))
            # Clip gradient by norm
            grad_norm = np.linalg.norm(gradients[-1])
            if grad_norm > 10:
                dL_dr_vec = 10 * dL_dr_vec / grad_norm
            # Gradient ascent
            r_vec = r_vec + learning_rate * dL_dr_vec
            with np.printoptions(precision=3, suppress=True, threshold=10):
                print(
                    f"Epoch {epoch}; Reward vector: {r_vec} ",
                    "(norm {:.3f}); grad_norm {:.3f}; last_n_grad_norm: {:.3f}"
                    .format(np.linalg.norm(r_vec), grad_norm, last_n_grad_norm),
                )
            latent_initial_loss = None
            forward_initial_loss = None
            backward_initial_loss = None
            if continue_training_dynamics or continue_training_latent_space:
                assert tf_graphs is not None
                update_experience_replay(
                    traj_actions_backward,
                    traj_actions_forward,
                    traj_observations_backward,
                    traj_observations_forward,
                    solver,
                    experience_replay,
                )
                if continue_training_dynamics:
                    with tf_graphs["inverse"].as_default():
                        (
                            backward_initial_loss,
                            backward_final_loss,
                        ) = inverse_transition_model.learn(
                            return_initial_loss=True,
                            verbose=print_level >= 2,
                            **inverse_model_parameters["learn"],
                        )
                    print("Backward model loss: {} --> {}".format(
                        backward_initial_loss, backward_final_loss))
                if continue_training_latent_space:
                    with tf_graphs["latent"].as_default():
                        latent_initial_loss, latent_final_loss = latent_space.learn(
                            experience_replay,
                            return_initial_loss=True,
                            verbose=print_level >= 2,
                            **latent_space_model_parameters["learn"],
                        )
                    print("Latent space loss: {} --> {}".format(
                        latent_initial_loss, latent_final_loss))
            solver = update_policy(r_vec, solver, horizon, obs_backward)
            with inverse_policy_graph.as_default():
                (
                    inverse_policy_initial_loss,
                    inverse_policy_final_loss,
                ) = update_inverse_policy(solver)
            if callback:
                callback(locals(), globals())
    return r_vec
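# The update in compute_grad follows the Deep RLSP gradient: with latent
# features phi(s), the reward parameters move along
#     r  <-  r + alpha * ( E_backward[phi] - E_forward[phi] + grad log p(r) ),
# where the backward expectation comes from trajectories simulated into the
# past and the forward expectation from rollouts of the current policy.
# A minimal invocation sketch (every non-scalar argument is assumed to be a
# pre-trained object from the surrounding codebase; hyperparameters are
# illustrative):
#
#     r_vec = latent_rlsp(
#         _run, env, current_obs, max_horizon=10,
#         experience_replay=experience_replay,
#         latent_space=latent_space,
#         inverse_transition_model=inverse_transition_model,
#         epochs=50, learning_rate=0.01,
#         inverse_policy_parameters={"model": {}, "learn": {}},
#     )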