class SoftActorCritic(RL_algorithm):

    def __init__(self, config, env, replay, networks):
        """ Basically a wrapper class for SAC from rlkit.

        Args:
            config: Configuration dictionary
            env: Environment
            replay: Replay buffer
            networks: dict containing two sub-dicts, 'individual' and
                'population', which contain the networks.
        """
        super().__init__(config, env, replay, networks)

        self._variant_pop = config['rl_algorithm_config']['algo_params_pop']
        self._variant_spec = config['rl_algorithm_config']['algo_params']

        self._ind_qf1 = networks['individual']['qf1']
        self._ind_qf2 = networks['individual']['qf2']
        self._ind_qf1_target = networks['individual']['qf1_target']
        self._ind_qf2_target = networks['individual']['qf2_target']
        self._ind_policy = networks['individual']['policy']

        self._pop_qf1 = networks['population']['qf1']
        self._pop_qf2 = networks['population']['qf2']
        self._pop_qf1_target = networks['population']['qf1_target']
        self._pop_qf2_target = networks['population']['qf2_target']
        self._pop_policy = networks['population']['policy']

        self._batch_size = config['rl_algorithm_config']['batch_size']
        self._nmbr_indiv_updates = config['rl_algorithm_config']['indiv_updates']
        self._nmbr_pop_updates = config['rl_algorithm_config']['pop_updates']

        self._algorithm_ind = SoftActorCritic_rlkit(
            env=self._env,
            policy=self._ind_policy,
            qf1=self._ind_qf1,
            qf2=self._ind_qf2,
            target_qf1=self._ind_qf1_target,
            target_qf2=self._ind_qf2_target,
            use_automatic_entropy_tuning=False,
            **self._variant_spec
        )

        self._algorithm_pop = SoftActorCritic_rlkit(
            env=self._env,
            policy=self._pop_policy,
            qf1=self._pop_qf1,
            qf2=self._pop_qf2,
            target_qf1=self._pop_qf1_target,
            target_qf2=self._pop_qf2_target,
            use_automatic_entropy_tuning=False,
            **self._variant_pop
        )

        # self._algorithm_ind.to(ptu.device)
        # self._algorithm_pop.to(ptu.device)

    def episode_init(self):
        """ Initializations to be done before the first episode.

        In this case it basically creates a fresh instance of SAC for the
        individual networks and copies the values of the target network.
        """
        self._algorithm_ind = SoftActorCritic_rlkit(
            env=self._env,
            policy=self._ind_policy,
            qf1=self._ind_qf1,
            qf2=self._ind_qf2,
            target_qf1=self._ind_qf1_target,
            target_qf2=self._ind_qf2_target,
            use_automatic_entropy_tuning=False,
            # alt_alpha=self._alt_alpha,
            **self._variant_spec
        )
        if self._config['rl_algorithm_config']['copy_from_gobal']:
            utils.copy_pop_to_ind(networks_pop=self._networks['population'],
                                  networks_ind=self._networks['individual'])
        # We only have to do this because the version of rlkit we use
        # creates a target network internally.
        # vf_dict = self._algorithm_pop.target_vf.state_dict()
        # self._algorithm_ind.target_vf.load_state_dict(vf_dict)
        # self._algorithm_ind.target_vf.eval()
        # self._algorithm_ind.to(ptu.device)

    def single_train_step(self, train_ind=True, train_pop=False):
        """ A single training step.

        Args:
            train_ind: Boolean. If True the individual networks will be
                trained.
            train_pop: Boolean. If True the population networks will be
                trained.
""" if train_ind: # Get only samples from the species buffer self._replay.set_mode('species') # self._algorithm_ind.num_updates_per_train_call = self._variant_spec['num_updates_per_epoch'] # self._algorithm_ind._try_to_train() for _ in range(self._nmbr_indiv_updates): batch = self._replay.random_batch(self._batch_size) self._algorithm_ind.train(batch) if train_pop: # Get only samples from the population buffer self._replay.set_mode('population') # self._algorithm_pop.num_updates_per_train_call = self._variant_pop['num_updates_per_epoch'] # self._algorithm_pop._try_to_train() for _ in range(self._nmbr_pop_updates): batch = self._replay.random_batch(self._batch_size) self._algorithm_pop.train(batch) @staticmethod def create_networks(env, config): """ Creates all networks necessary for SAC. These networks have to be created before instantiating this class and used in the constructor. Args: config: A configuration dictonary containing population and individual networks Returns: A dictonary which contains the networks. """ network_dict = { 'individual' : SoftActorCritic._create_networks(env=env, config=config), 'population' : SoftActorCritic._create_networks(env=env, config=config), } return network_dict @staticmethod def _create_networks(env, config): """ Creates all networks necessary for SAC. These networks have to be created before instantiating this class and used in the constructor. TODO: Maybe this should be reworked one day... Args: config: A configuration dictonary. Returns: A dictonary which contains the networks. """ obs_dim = int(np.prod(env.observation_space.shape)) action_dim = int(np.prod(env.action_space.shape)) net_size = config['rl_algorithm_config']['net_size'] hidden_sizes = [net_size] * config['rl_algorithm_config']['network_depth'] # hidden_sizes = [net_size, net_size, net_size] qf1 = FlattenMlp( hidden_sizes=hidden_sizes, input_size=obs_dim + action_dim, output_size=1, ).to(device=ptu.device) qf2 = FlattenMlp( hidden_sizes=hidden_sizes, input_size=obs_dim + action_dim, output_size=1, ).to(device=ptu.device) qf1_target = FlattenMlp( hidden_sizes=hidden_sizes, input_size=obs_dim + action_dim, output_size=1, ).to(device=ptu.device) qf2_target = FlattenMlp( hidden_sizes=hidden_sizes, input_size=obs_dim + action_dim, output_size=1, ).to(device=ptu.device) policy = TanhGaussianPolicy( hidden_sizes=hidden_sizes, obs_dim=obs_dim, action_dim=action_dim, ).to(device=ptu.device) clip_value = 1.0 for p in qf1.parameters(): p.register_hook(lambda grad: torch.clamp(grad, -clip_value, clip_value)) for p in qf2.parameters(): p.register_hook(lambda grad: torch.clamp(grad, -clip_value, clip_value)) for p in policy.parameters(): p.register_hook(lambda grad: torch.clamp(grad, -clip_value, clip_value)) return {'qf1' : qf1, 'qf2' : qf2, 'qf1_target' : qf1_target, 'qf2_target' : qf2_target, 'policy' : policy} @staticmethod def get_q_network(networks): """ Returns the q network from a dict of networks. This method extracts the q-network from the dictonary of networks created by the function create_networks. Args: networks: Dict containing the networks. Returns: The q-network as torch object. """ return networks['qf1'] @staticmethod def get_policy_network(networks): """ Returns the policy network from a dict of networks. This method extracts the policy network from the dictonary of networks created by the function create_networks. Args: networks: Dict containing the networks. Returns: The policy network as torch object. """ return networks['policy']
class SACDeadTrainer(TorchTrainer):
    def __init__(
            self,
            env,
            qf1,
            qf2,
            target_qf1,
            target_qf2,
            qf_dead,
            global_policy,
            discount=0.99,
            reward_scale=1.0,
            policy_lr=1e-3,
            qf_lr=1e-3,
            optimizer_class=RAdam,  # optim.Adam
            soft_target_tau=1e-2,
            target_update_period=1,
            plotter=None,
            render_eval_paths=False,
            use_automatic_entropy_tuning=True,
            target_entropy=None,
    ):
        super().__init__()

        sac_trainer_params = {
            'discount': discount,
            'reward_scale': reward_scale,
            'policy_lr': policy_lr,
            'qf_lr': qf_lr,
            'optimizer_class': optimizer_class,
            'soft_target_tau': soft_target_tau,
            'target_update_period': target_update_period,
            'plotter': plotter,
            'render_eval_paths': render_eval_paths,
            'use_automatic_entropy_tuning': use_automatic_entropy_tuning,
            'target_entropy': target_entropy,
        }
        dead_trainer_params = {
            'pass_per_iteration': 1,
        }

        self.sacTrainer = SACTrainer(env=env,
                                     policy=global_policy.tanhGaussianPolicy,
                                     qf1=qf1,
                                     qf2=qf2,
                                     target_qf1=target_qf1,
                                     target_qf2=target_qf2,
                                     **sac_trainer_params)

        self.deadTrainer = DeadTrainer(policy_dead=global_policy.deadPredictionPolicy,
                                       qf_dead=qf_dead,
                                       **dead_trainer_params)

        ###################################
        self.env = env
        self.policy = global_policy

        self.plotter = plotter
        self.render_eval_paths = render_eval_paths

        self.eval_statistics = OrderedDict()
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True

    def train(self, batch):
        """
        :param batch: dict with three fields: 'normal', 'safe' and 'dead'.
            The 'normal' sub-batch is used for the standard SAC update; the
            'safe' and 'dead' sub-batches are used by the dead-state trainer.
        :return: None
        """
        batch_normal = batch['normal']
        # batch_dead = {key: np.concatenate((batch['dead'][key], batch['safe'][key])) for key in batch['dead']}

        self.sacTrainer.train(batch_normal)
        self.deadTrainer.train(np_batch=batch['safe'], np_batch_dead=batch['dead'])

        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            self.eval_statistics = OrderedDict({
                **self.sacTrainer.eval_statistics,
                **self.deadTrainer.eval_statistics,
            })

        self._n_train_steps_total += 1

    def train_from_torch(self, batch):
        raise NotImplementedError('Blank method. Should not be called.')

    def get_diagnostics(self):
        return self.eval_statistics

    def end_epoch(self, epoch):
        self.sacTrainer.end_epoch(epoch)
        self.deadTrainer.end_epoch(epoch)
        self._need_to_update_eval_statistics = True

    @property
    def networks(self):
        return self.sacTrainer.networks + self.deadTrainer.networks

    @property
    def optimizers(self) -> Iterable[Optimizer]:
        return self.sacTrainer.optimizers + self.deadTrainer.optimizers

    def get_snapshot(self):
        d1 = self.sacTrainer.get_snapshot()
        d2 = self.deadTrainer.get_snapshot()
        return {**d1, **d2}
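

# Illustrative sketch (an assumption, not the project's actual sampling code):
# SACDeadTrainer.train expects a single dict with 'normal', 'safe' and 'dead'
# sub-batches. A mode-switching replay buffer, like the one used by
# SoftActorCritic above, could assemble such a batch as shown here; the
# 'normal'/'safe'/'dead' mode names are hypothetical.
def _example_dead_trainer_step(trainer, replay, batch_size):
    """Hypothetical single training step for a SACDeadTrainer instance."""
    batch = {}
    for mode in ('normal', 'safe', 'dead'):
        replay.set_mode(mode)  # mode names are assumed, not confirmed by this module
        batch[mode] = replay.random_batch(batch_size)
    trainer.train(batch)
    return trainer.get_diagnostics()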
class SAC(Agent):
    def __init__(self, env, eval_env, mem, nets, train_step_params):
        super().__init__(env, eval_env, mem, nets, train_step_params)
        self._mem = mem
        self._env = env
        self._eval_env = eval_env
        (self._policy_net, self._q1_net, self._q2_net, self._target_q1_net,
         self._target_q2_net) = (nets['policy_net'], nets['q1_net'], nets['q2_net'],
                                 nets['target_q1_net'], nets['target_q2_net'])
        self._train_step_params = train_step_params
        self._alg = SACTrainer(env=self._env,
                               policy=self._policy_net,
                               qf1=self._q1_net,
                               qf2=self._q2_net,
                               target_qf1=self._target_q1_net,
                               target_qf2=self._target_q2_net,
                               **train_step_params)

    def _train_step(self, n_trains_per_step, batch_size):
        if self._mem.num_steps_can_sample() < batch_size:
            return
        for _ in range(n_trains_per_step):
            batch = self._mem.random_batch(batch_size)
            self._alg.train(batch)

    def learn(self, results_path, seed_mem=True, **kwargs):
        replay_sample_ratio = kwargs['replay_sample_ratio']
        n_epochs = kwargs['n_epochs']
        episode_len = kwargs['episode_len']
        eval_episode_len = kwargs['eval_episode_len']
        start_steps = kwargs['start_steps']
        n_trains_per_step = kwargs['n_trains_per_step']
        eval_interval = kwargs['eval_interval']
        batch_size = kwargs['train_batch_size']
        checkpoint_interval = kwargs['checkpoint_interval']
        log_interval = kwargs['log_interval']
        artifact_path = results_path + 'artifact'

        # TODO: log summary of current config
        if seed_mem:
            s = self._env.reset()
            for k in range(start_steps):
                a = self._env.action_space.sample()
                ns, r, done, _ = self._env.step(a)
                self._mem.add_sample(observation=s,
                                     action=a,
                                     reward=r,
                                     next_observation=ns,
                                     terminal=1 if done else 0,
                                     env_info=dict())
                if done:
                    s = self._env.reset()
                else:
                    s = ns

        # TODO: use RAY Sampler for parallel simulation sampling
        train_rt = Agent.results_tracker(id='train_performance')
        s = self._env.reset()
        # TODO: track and log policy loss
        for _ in range(n_epochs):
            for i in range(1, episode_len + 1):
                a = self.pred(s)
                ns, r, done, _ = self._env.step(a)
                self.life_tracker['total_n_train_steps'] += 1
                train_rt['train_interval_timesteps'] += 1
                train_rt['train_rewards'].append(r)
                self._mem.add_sample(observation=s,
                                     action=a,
                                     reward=r,
                                     next_observation=ns,
                                     terminal=1.0 if done else 0.0,
                                     env_info=dict())
                if i % replay_sample_ratio == 0:
                    self._train_step(n_trains_per_step, batch_size)
                    self.life_tracker['total_n_train_batches'] += 1
                if i % checkpoint_interval == 0:
                    self.save(artifact_path)
                if done:
                    s = self._env.reset()
                    self.life_tracker['total_n_train_episodes'] += 1
                    train_rt['n_train_episodes_since_last_log'] += 1
                    if self.life_tracker['total_n_train_episodes'] % log_interval == 0:
                        if self.life_tracker['total_n_train_episodes'] % eval_interval == 0:
                            self.life_tracker['total_n_evals'] += 1
                            train_rt['eval_interval_timesteps'] = eval_episode_len
                            s = self._eval_env.reset()
                            for _ in range(eval_episode_len):
                                a = self.pred(s)
                                ns, r, done, _ = self._eval_env.step(a)
                                train_rt['eval_rewards'].append(r)
                                if done:
                                    s = self._eval_env.reset()
                                else:
                                    s = ns
                        self.log_performance(train_rt)
                        train_rt = Agent.results_tracker(id='train_performance')
                else:
                    s = ns

    def pred(self, state, deterministic=False):
        state = torch.from_numpy(state).float().to(torch_util.device)
        with torch.no_grad():
            return self._policy_net(
                state, deterministic=deterministic)[0].cpu().detach().numpy()

    def save(self, filepath):
        if not os.path.isdir(filepath):
            os.makedirs(filepath)
        nets = [
            self._policy_net, self._q1_net, self._q2_net, self._target_q1_net,
            self._target_q2_net
        ]
        net_fps = [
            'policy-net.pt', 'q1-net.pt', 'q2-net.pt', 'target-q1-net.pt',
            'target-q2-net.pt'
        ]
        for i, fn in enumerate(net_fps):
            torch.save(nets[i], f'{filepath}{os.sep}{fn}')
        comps = [self._train_step_params]
        comp_fps = ['train-step-params.pkl']
        for i, fn in enumerate(comp_fps):
            with open(f'{filepath}{os.sep}{fn}', 'wb') as f:
                pickle.dump(comps[i], f)
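

# Hedged counterpart to SAC.save (not present in the original class): a minimal
# sketch of how the checkpoint written above could be restored. The filenames
# mirror save(); whether the surrounding code actually reloads agents this way
# is an assumption.
def load_sac_checkpoint(filepath):
    """Returns (nets_dict, train_step_params) loaded from `filepath`."""
    net_fps = {
        'policy_net': 'policy-net.pt',
        'q1_net': 'q1-net.pt',
        'q2_net': 'q2-net.pt',
        'target_q1_net': 'target-q1-net.pt',
        'target_q2_net': 'target-q2-net.pt',
    }
    # torch.save above stored whole network objects, so torch.load returns them directly.
    nets = {key: torch.load(f'{filepath}{os.sep}{fn}') for key, fn in net_fps.items()}
    with open(f'{filepath}{os.sep}train-step-params.pkl', 'rb') as f:
        train_step_params = pickle.load(f)
    return nets, train_step_params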