import numpy as np
import torch
import torch.optim as optim

# These imports assume the rlkit-style layout used by the PEARL codebase;
# adjust the paths to match this project. DynamicsModel is the project's own
# base class and is assumed to be defined elsewhere in the repository.
import rlkit.torch.pytorch_util as ptu
from rlkit.torch.networks import FlattenMlp


class MlpModel(DynamicsModel):
    def __init__(self, env, n_layers=3, hidden_layer_size=64,
                 optimizer_class=optim.Adam, learning_rate=1e-3,
                 reward_weight=1, **kwargs):
        super().__init__(env=env, **kwargs)
        self.env = env
        obs_dim = int(np.prod(env.observation_space.shape))
        action_dim = int(np.prod(env.action_space.shape))
        self.input_dim = obs_dim
        self.action_dim = action_dim
        self.next_obs_dim = obs_dim
        self.n_layers = n_layers
        self.hidden_layer_size = hidden_layer_size
        self.learning_rate = learning_rate
        self.reward_weight = reward_weight
        # reset() is inherited from DynamicsModel and is assumed to initialize
        # self.state, which step() reads below.
        self.reset()
        self.reward_dim = 1
        # terminal_dim = 1
        # A single MLP jointly predicts the next observation and the reward.
        self.net = FlattenMlp(
            hidden_sizes=[hidden_layer_size] * n_layers,
            input_size=self.input_dim + self.action_dim,
            output_size=self.next_obs_dim + self.reward_dim,
        )
        self.net_optimizer = optimizer_class(self.net.parameters(), lr=learning_rate)

    def to(self, device=None):
        if device is None:
            device = ptu.device
        self.net.to(device)

    def _forward(self, state, action):
        # The last reward_dim outputs are the reward head; terminals are not
        # predicted, so terminal is always 0.
        output = self.net(state, action)
        next_state = output[:, :-self.reward_dim]
        reward = output[:, -self.reward_dim:]
        terminal = 0
        env_info = {}
        return next_state, reward, terminal, env_info

    def step(self, action):
        action = ptu.from_numpy(action[np.newaxis, :])
        next_state, reward, terminal, env_info = self._forward(self.state, action)
        # Keep the torch tensor as the internal state; return numpy copies.
        self.state = next_state
        next_state = np.squeeze(ptu.get_numpy(next_state))
        reward = np.squeeze(ptu.get_numpy(reward))
        return next_state, reward, terminal, env_info

    def train(self, paths):
        states = ptu.from_numpy(paths["observations"])
        actions = ptu.from_numpy(paths["actions"])
        rewards = ptu.from_numpy(paths["rewards"])
        next_states = ptu.from_numpy(paths["next_observations"])
        terminals = paths["terminals"]  # unused: _forward does not predict terminals
        next_state_preds, reward_preds, terminal_preds, env_infos = self._forward(states, actions)
        self.net_optimizer.zero_grad()
        # Weighted joint MSE loss on the transition and reward predictions.
        self.transition_model_loss = torch.mean((next_state_preds - next_states) ** 2)
        self.reward_model_loss = torch.mean((reward_preds - rewards) ** 2)
        self.net_loss = self.transition_model_loss + self.reward_weight * self.reward_model_loss
        self.net_loss.backward()
        self.net_optimizer.step()
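# ---------------------------------------------------------------------------
# Illustration (not part of the original file): a minimal, self-contained
# PyTorch sketch of the joint next-state/reward head that MlpModel builds
# above -- one MLP whose last output unit is the reward, trained with the same
# weighted MSE objective as MlpModel.train(). All sizes, the fake batch, and
# the function name are illustrative assumptions.
def _joint_head_demo():
    import torch.nn as nn

    obs_dim, action_dim, reward_dim = 4, 2, 1
    net = nn.Sequential(
        nn.Linear(obs_dim + action_dim, 64),
        nn.ReLU(),
        nn.Linear(64, obs_dim + reward_dim),  # next-state head + reward head
    )
    opt = optim.Adam(net.parameters(), lr=1e-3)

    # Fake batch standing in for paths["observations"], paths["actions"], etc.
    s, a = torch.randn(32, obs_dim), torch.randn(32, action_dim)
    s_next, r = torch.randn(32, obs_dim), torch.randn(32, reward_dim)

    out = net(torch.cat([s, a], dim=-1))
    next_state_pred = out[:, :-reward_dim]
    reward_pred = out[:, -reward_dim:]
    reward_weight = 1.0
    loss = ((next_state_pred - s_next) ** 2).mean() \
        + reward_weight * ((reward_pred - r) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()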
import json
import os
import os.path as osp

import numpy as np
import torch

# These imports also assume the rlkit-style layout of the PEARL codebase;
# deep_update_dict, default_config, ENVS, and PEARLFineTuningHelper are
# project-level definitions, so adjust the paths to match this repository.
import rlkit.torch.pytorch_util as ptu
from rlkit.envs import ENVS
from rlkit.envs.wrappers import NormalizedBoxEnv
from rlkit.torch.networks import FlattenMlp, MlpEncoder, RecurrentEncoder
from rlkit.torch.sac.policies import MakeDeterministic, TanhGaussianPolicy
from rlkit.torch.sac.agent import PEARLAgent
from configs.default import default_config


def main(env_name, seed, deterministic, traj_prior, start_ft_after, ft_steps,
         avoid_freezing_z, lr, batch_size, avoid_loading_critics):
    # Merge the per-environment config into the default variant.
    config = "configs/{}.json".format(env_name)
    variant = default_config
    with open(config) as f:
        exp_params = json.load(f)
    variant = deep_update_dict(exp_params, variant)
    exp_name = variant['env_name']
    print("Experiment: {}".format(exp_name))

    env = NormalizedBoxEnv(ENVS[exp_name](**variant['env_params']))
    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))
    print("Observation space:")
    print(env.observation_space)
    print(obs_dim)
    print("Action space:")
    print(env.action_space)
    print(action_dim)
    print("-" * 10)

    # instantiate networks
    latent_dim = variant['latent_size']
    reward_dim = 1
    context_encoder_input_dim = (
        2 * obs_dim + action_dim + reward_dim
        if variant['algo_params']['use_next_obs_in_context']
        else obs_dim + action_dim + reward_dim
    )
    context_encoder_output_dim = (
        latent_dim * 2
        if variant['algo_params']['use_information_bottleneck']
        else latent_dim
    )
    net_size = variant['net_size']
    recurrent = variant['algo_params']['recurrent']
    encoder_model = RecurrentEncoder if recurrent else MlpEncoder

    context_encoder = encoder_model(
        hidden_sizes=[200, 200, 200],
        input_size=context_encoder_input_dim,
        output_size=context_encoder_output_dim,
    )
    qf1 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    qf2 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    target_qf1 = qf1.copy()
    target_qf2 = qf2.copy()
    policy = TanhGaussianPolicy(
        hidden_sizes=[net_size, net_size, net_size],
        obs_dim=obs_dim + latent_dim,
        latent_dim=latent_dim,
        action_dim=action_dim,
    )
    agent = PEARLAgent(
        latent_dim,
        context_encoder,
        policy,
        **variant['algo_params']
    )

    # deterministic eval
    if deterministic:
        agent = MakeDeterministic(agent)

    # load trained weights (otherwise simulate random policy)
    path_to_exp = "output/{}/pearl_{}".format(env_name, seed - 1)
    print("Based on experiment: {}".format(path_to_exp))
    context_encoder.load_state_dict(torch.load(os.path.join(path_to_exp, 'context_encoder.pth')))
    policy.load_state_dict(torch.load(os.path.join(path_to_exp, 'policy.pth')))
    if not avoid_loading_critics:
        qf1.load_state_dict(torch.load(os.path.join(path_to_exp, 'qf1.pth')))
        qf2.load_state_dict(torch.load(os.path.join(path_to_exp, 'qf2.pth')))
        target_qf1.load_state_dict(torch.load(os.path.join(path_to_exp, 'target_qf1.pth')))
        target_qf2.load_state_dict(torch.load(os.path.join(path_to_exp, 'target_qf2.pth')))

    # optional GPU mode
    ptu.set_gpu_mode(variant['util_params']['use_gpu'], variant['util_params']['gpu_id'])
    if ptu.gpu_enabled():
        device = ptu.device  # set by ptu.set_gpu_mode above
        agent.to(device)
        policy.to(device)
        context_encoder.to(device)
        qf1.to(device)
        qf2.to(device)
        target_qf1.to(device)
        target_qf2.to(device)

    helper = PEARLFineTuningHelper(
        env=env,
        agent=agent,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        num_exp_traj_eval=traj_prior,
        start_fine_tuning=start_ft_after,
        fine_tuning_steps=ft_steps,
        should_freeze_z=(not avoid_freezing_z),
        replay_buffer_size=int(1e6),
        batch_size=batch_size,
        discount=0.99,
        policy_lr=lr,
        qf_lr=lr,
        temp_lr=lr,
        target_entropy=-action_dim,
    )
    helper.fine_tune(variant=variant, seed=seed)
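# ---------------------------------------------------------------------------
# Hypothetical entry point (not part of the original file): the PEARL launch
# scripts expose functions like main() through click, so a plausible CLI
# wrapper is sketched below. Every option name and default here is an
# illustrative assumption, chosen only to mirror main()'s signature.
import click

@click.command()
@click.argument('env_name')
@click.option('--seed', default=1)
@click.option('--deterministic', is_flag=True, default=False)
@click.option('--traj_prior', default=1)
@click.option('--start_ft_after', default=0)
@click.option('--ft_steps', default=1)
@click.option('--avoid_freezing_z', is_flag=True, default=False)
@click.option('--lr', default=3e-4)
@click.option('--batch_size', default=256)
@click.option('--avoid_loading_critics', is_flag=True, default=False)
def cli(**kwargs):
    main(**kwargs)

if __name__ == "__main__":
    cli()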